ApolloResearch/rib

Prevent extra arguments to Dataset classes

danbraunai-apollo opened this issue · 1 comment

We currently have e.g.

class ModularArithmeticDatasetConfig(BaseModel):
    """Config for the modular arithmetic dataset.

    We set fields to optional so that we have the option of loading them in from a pre-saved config
    file (see `rib/loader.create_modular_arithmetic_dataset`)
    """

    source: Literal["custom"] = "custom"
    name: Literal["modular_arithmetic"] = "modular_arithmetic"
    return_set: Literal["train", "test", "all", "both"] = Field(
        ...,
        description="The dataset to return. If 'both', returns both the train and test datasets. "
        "If 'all', returns the combined train and test datasets.",
    )
    modulus: Optional[int] = Field(None, description="The modulus to use for the dataset.")
    fn_name: Optional[Literal["add", "subtract", "x2xyy2"]] = Field(
        None,
        description="The function to use for the dataset. One of 'add', 'subtract', or 'x2xyy2'.",
    )
    frac_train: Optional[float] = Field(
        None, description="Fraction of the dataset to use for training."
    )
    seed: Optional[int] = Field(None, description="The random seed value for reproducibility.")

But this doesn't prevent one from passing extra arguments like "return_set_frac" to this class: by default pydantic silently ignores fields it doesn't recognise. We would like the class to raise an error instead.

We just need model_config = ConfigDict(extra="forbid") in all the relevant classes (I think probably every pydantic class).
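One way this could look, as a minimal sketch assuming pydantic v2: since `model_config` is inherited by subclasses, a shared base class would avoid repeating the setting in every config. `StrictBaseModel` and `rejects_extra` are hypothetical names for illustration, not something in the repo, and the config below is trimmed to two fields.

```python
from typing import Literal, Optional

from pydantic import BaseModel, ConfigDict, Field, ValidationError


class StrictBaseModel(BaseModel):
    """Hypothetical shared base: subclasses inherit extra="forbid"."""

    model_config = ConfigDict(extra="forbid")


class ModularArithmeticDatasetConfig(StrictBaseModel):
    # Trimmed version of the config above, for illustration only.
    return_set: Literal["train", "test", "all", "both"] = Field(...)
    modulus: Optional[int] = Field(None, description="The modulus to use for the dataset.")


def rejects_extra() -> bool:
    """Return True if an unknown argument raises a ValidationError."""
    try:
        ModularArithmeticDatasetConfig(return_set="train", return_set_frac=0.5)
    except ValidationError:
        return True
    return False
```

With this in place, `rejects_extra()` should return `True`, while a call that uses only declared fields still validates as before.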

Agreed, forbidding extra arguments is probably good. It will occasionally break things when we remove a config argument and then try to load an old config that still contains it, but I think we don't care very much.
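To illustrate the old-config failure mode, here is a small sketch assuming pydantic v2. `ExampleConfig`, `old_field`, and `load_errors` are made-up names: `old_field` stands in for an argument that was removed from the class but is still present in a saved config.

```python
from pydantic import BaseModel, ConfigDict, ValidationError


class ExampleConfig(BaseModel):
    """Hypothetical config from which `old_field` has since been removed."""

    model_config = ConfigDict(extra="forbid")

    seed: int = 0


def load_errors(raw: dict) -> list:
    """Validate a raw config dict and return pydantic's error details (empty if valid)."""
    try:
        ExampleConfig.model_validate(raw)
    except ValidationError as e:
        return e.errors()
    return []


# A config saved before `old_field` was removed now fails to load.
stale = {"seed": 1, "old_field": True}
errors = load_errors(stale)
```

Each entry in `errors` names the offending key in its `loc` tuple, so a stale config is at least easy to diagnose and fix by hand.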

Separately, it would be nice if the modular arithmetic dataset supported these arguments, but that's another issue.