Make load_dataset instance methods of dataset configs
danbraunai-apollo opened this issue · 1 comment
danbraunai-apollo commented
We currently have:
```python
def load_dataset(
    dataset_config: DatasetConfig,
    model_n_ctx: Optional[int] = None,
    tlens_model_path: Optional[Path] = None,
) -> Dataset:
    """
    Load a dataset based on the provided type and arguments.

    Args:
        dataset_config (DatasetConfig): The dataset config.
        model_n_ctx (Optional[int]): The max context length of the model, used for
            HFDatasetConfigs. Data sequences are packed to dataset_config.n_ctx if it is
            not None and is <= model_n_ctx, otherwise to model_n_ctx.
        tlens_model_path (Optional[Path]): The path to the tlens model, used for modular
            arithmetic to collect config info used to train the model.

    Returns:
        The dataset.
    """
    if isinstance(dataset_config, ModularArithmeticDatasetConfig):
        return create_modular_arithmetic_dataset(
            dataset_config=dataset_config, tlens_model_path=tlens_model_path
        )
    elif isinstance(dataset_config, HFDatasetConfig):
        return create_hf_dataset(dataset_config=dataset_config, model_n_ctx=model_n_ctx)
    elif isinstance(dataset_config, VisionDatasetConfig):
        return create_vision_dataset(dataset_config=dataset_config)
    else:
        assert isinstance(dataset_config, BlockVectorDatasetConfig)
        return create_block_vector_dataset(dataset_config=dataset_config)
```
This would be much more natural if each dataset config had a `create_dataset` method. I guess you'd have them all take in `*args, **kwargs`, and have the modadd (modular arithmetic) config also take in `tlens_model_path` and the HF config take in `model_n_ctx`.
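A minimal sketch of the proposal, assuming plain dataclasses in place of the repo's pydantic `BaseModel` configs and placeholder dicts in place of real datasets (the fields `modulus` and `name` and the dict contents are hypothetical). Each config's `create_dataset` takes `*args, **kwargs`, picks out only the keyword it needs, and ignores the rest, so `load_dataset` collapses to one call. `BlockVectorDatasetConfig` would follow the same pattern as the vision config.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional


class DatasetConfig(ABC):
    """Hypothetical stand-in base class; the real configs are pydantic BaseModels."""

    @abstractmethod
    def create_dataset(self, *args: Any, **kwargs: Any) -> dict:
        """Build the dataset for this config, ignoring kwargs it does not need."""


@dataclass
class ModularArithmeticDatasetConfig(DatasetConfig):
    modulus: int = 113  # hypothetical field

    def create_dataset(
        self, *args: Any, tlens_model_path: Optional[Path] = None, **kwargs: Any
    ) -> dict:
        # Placeholder for create_modular_arithmetic_dataset(...): only this config
        # needs tlens_model_path to recover the model's training config.
        return {"kind": "modular_arithmetic", "tlens_model_path": tlens_model_path}


@dataclass
class HFDatasetConfig(DatasetConfig):
    n_ctx: Optional[int] = None

    def create_dataset(
        self, *args: Any, model_n_ctx: Optional[int] = None, **kwargs: Any
    ) -> dict:
        # Pack to self.n_ctx if it is set and <= model_n_ctx, otherwise to model_n_ctx.
        # Assumption: if model_n_ctx is None, fall back to self.n_ctx.
        if self.n_ctx is not None and (model_n_ctx is None or self.n_ctx <= model_n_ctx):
            packed_len = self.n_ctx
        else:
            packed_len = model_n_ctx
        return {"kind": "hf", "packed_len": packed_len}


@dataclass
class VisionDatasetConfig(DatasetConfig):
    name: str = "mnist"  # hypothetical field

    def create_dataset(self, *args: Any, **kwargs: Any) -> dict:
        # Needs neither extra argument; anything passed via **kwargs is ignored.
        return {"kind": "vision", "name": self.name}


def load_dataset(
    dataset_config: DatasetConfig,
    model_n_ctx: Optional[int] = None,
    tlens_model_path: Optional[Path] = None,
) -> dict:
    # No isinstance chain: each config knows how to build its own dataset.
    return dataset_config.create_dataset(
        model_n_ctx=model_n_ctx, tlens_model_path=tlens_model_path
    )


print(load_dataset(HFDatasetConfig(n_ctx=512), model_n_ctx=1024))
# {'kind': 'hf', 'packed_len': 512}
```

Passing both optional arguments as keywords to every config is what makes the `**kwargs` catch-all necessary; the trade-off is that a misspelled keyword is silently ignored instead of raising a `TypeError`.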
danbraunai-apollo commented
We've already moved away from this with the ablation schedules. I guess it's fine to leave this as is to preserve the "config BaseModel classes don't have methods" property, though I'd still prefer this to be implemented.