Prevent extra arguments to Dataset classes
danbraunai-apollo opened this issue · 1 comment
danbraunai-apollo commented
We currently have, e.g.:
from typing import Literal, Optional

from pydantic import BaseModel, Field


class ModularArithmeticDatasetConfig(BaseModel):
    """Config for the modular arithmetic dataset.

    We make the fields optional so that they can instead be loaded from a pre-saved config
    file (see `rib/loader.create_modular_arithmetic_dataset`).
    """

    source: Literal["custom"] = "custom"
    name: Literal["modular_arithmetic"] = "modular_arithmetic"
    return_set: Literal["train", "test", "all", "both"] = Field(
        ...,
        # Trailing space so the two concatenated strings don't run together.
        description="The dataset to return. If 'both', returns both the train and test datasets. "
        "If 'all', returns the combined train and test datasets.",
    )
    modulus: Optional[int] = Field(None, description="The modulus to use for the dataset.")
    fn_name: Optional[Literal["add", "subtract", "x2xyy2"]] = Field(
        None,
        description="The function to use for the dataset. One of 'add', 'subtract', or 'x2xyy2'.",
    )
    frac_train: Optional[float] = Field(
        None, description="Fraction of the dataset to use for training."
    )
    seed: Optional[int] = Field(None, description="The random seed value for reproducibility.")
But this doesn't prevent one from passing extra arguments such as `return_set_frac` to this class: by default, pydantic silently ignores unknown fields. We would like it to raise an error instead.
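A minimal sketch of the current behaviour, assuming pydantic v2 and using a hypothetical, cut-down `LooseConfig` in place of the full class: the unknown key is accepted without an error and then dropped.

```python
from typing import Literal, Optional

from pydantic import BaseModel, Field


class LooseConfig(BaseModel):
    """Hypothetical cut-down stand-in for ModularArithmeticDatasetConfig (default model_config)."""

    return_set: Literal["train", "test", "all", "both"] = Field(...)
    modulus: Optional[int] = None


# An unsupported key is accepted without complaint...
cfg = LooseConfig(return_set="train", modulus=113, return_set_frac=0.5)

# ...and silently dropped: it never appears on the model.
print(cfg.model_dump())  # {'return_set': 'train', 'modulus': 113}
print(hasattr(cfg, "return_set_frac"))  # False
```

This is pydantic's default `extra="ignore"` behaviour, which is why a typo or an unsupported option can go unnoticed.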
We just need `model_config = ConfigDict(extra="forbid")` in all the relevant classes (probably every pydantic class).
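One way to apply this to every class is a shared base model. `StrictBaseModel` is a hypothetical name, the sketch assumes pydantic v2 (where `ConfigDict` lives), and only two fields of the config above are reproduced.

```python
from typing import Literal, Optional

from pydantic import BaseModel, ConfigDict, Field, ValidationError


class StrictBaseModel(BaseModel):
    """Hypothetical shared base: any unknown field raises instead of being ignored."""

    model_config = ConfigDict(extra="forbid")


class ModularArithmeticDatasetConfig(StrictBaseModel):
    # Only a few fields reproduced; the rest would follow the same pattern.
    return_set: Literal["train", "test", "all", "both"] = Field(...)
    modulus: Optional[int] = Field(None, description="The modulus to use for the dataset.")


try:
    ModularArithmeticDatasetConfig(return_set="train", return_set_frac=0.5)
    extra_error = None
except ValidationError as err:
    # pydantic reports the offending key with the error type "extra_forbidden".
    extra_error = err.errors()[0]

print(extra_error["type"], extra_error["loc"])  # extra_forbidden ('return_set_frac',)
```

Subclasses inherit `model_config`, so each config class only needs to change its base class rather than repeat the setting; valid configs such as `ModularArithmeticDatasetConfig(return_set="test", modulus=7)` still construct normally.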
nix-apollo commented
Agreed, forbidding extra arguments is probably good. It will sometimes break things if we remove a config argument and then try to load an old config that still contains it, but I don't think we care much about that.
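A sketch of the old-config problem under pydantic v2, assuming a hypothetical `DatasetConfig` from which a `frac_train` field has been removed: the stale key now fails validation. `load_legacy_config` is a hypothetical workaround, not something in rib, that drops keys the current model doesn't define before validating.

```python
from typing import Any, Optional

from pydantic import BaseModel, ConfigDict, ValidationError


class DatasetConfig(BaseModel):
    """Hypothetical current config: an older `frac_train` field has been removed."""

    model_config = ConfigDict(extra="forbid")

    modulus: Optional[int] = None
    seed: Optional[int] = None


# Hypothetical dict as it might have been saved before `frac_train` was removed.
old_config: dict[str, Any] = {"modulus": 113, "seed": 0, "frac_train": 0.3}

try:
    DatasetConfig.model_validate(old_config)
    old_config_loads = True
except ValidationError:
    old_config_loads = False  # extra="forbid" rejects the stale key


def load_legacy_config(raw: dict[str, Any]) -> DatasetConfig:
    """Hypothetical escape hatch: keep only keys the current model still defines."""
    known = {k: v for k, v in raw.items() if k in DatasetConfig.model_fields}
    return DatasetConfig.model_validate(known)


legacy = load_legacy_config(old_config)
print(old_config_loads, legacy.model_dump())  # False {'modulus': 113, 'seed': 0}
```

Keeping such a helper out of the normal loading path preserves the strictness for new configs while leaving a deliberate opt-in for old ones.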
Separately, it would be nice if the modular arithmetic dataset supported these arguments (e.g. `return_set_frac`), but that's another issue.