ApolloResearch/rib

Make load_dataset an instance method of dataset configs

danbraunai-apollo opened this issue · 1 comments

We currently have:

def load_dataset(
    dataset_config: DatasetConfig,
    model_n_ctx: Optional[int] = None,
    tlens_model_path: Optional[Path] = None,
) -> Dataset:
    """
    Load a dataset based on the provided type and arguments.

    Args:
        dataset_config (DatasetConfig): The dataset config.
        model_n_ctx (Optional[int]): The max context length of the model, used for HFDatasetConfigs. Data
            sequences are packed to dataset_config.n_ctx if it is not None and is <= model_n_ctx,
            otherwise to model_n_ctx.
        tlens_model_path (Optional[Path]): The path to the tlens model, used for modular arithmetic
            to collect config info used to train the model.

    Returns:
        The dataset.
    """

    if isinstance(dataset_config, ModularArithmeticDatasetConfig):
        return create_modular_arithmetic_dataset(
            dataset_config=dataset_config, tlens_model_path=tlens_model_path
        )
    elif isinstance(dataset_config, HFDatasetConfig):
        return create_hf_dataset(dataset_config=dataset_config, model_n_ctx=model_n_ctx)
    elif isinstance(dataset_config, VisionDatasetConfig):
        return create_vision_dataset(dataset_config=dataset_config)
    else:
        assert isinstance(dataset_config, BlockVectorDatasetConfig)
        return create_block_vector_dataset(dataset_config=dataset_config)

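This would be much more natural if each dataset config had its own create_dataset method, so load_dataset would no longer need the isinstance chain. I guess you'd have them all take in *args, **kwargs, with ModularArithmeticDatasetConfig (modadd) also taking tlens_model_path and HFDatasetConfig also taking model_n_ctx.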
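A rough sketch of what that could look like, under several assumptions. Dataclasses stand in for the real pydantic BaseModel configs to keep the example dependency-free, the `name` field is hypothetical, and each method returns a placeholder dict where the real one would call the existing create_* function. It also uses keyword-only `**kwargs` rather than `*args, **kwargs`, and assumes the HF config requires model_n_ctx:

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional


@dataclass
class DatasetConfig(ABC):
    """Stand-in for the real base config; `name` is a hypothetical field."""

    name: str

    @abstractmethod
    def create_dataset(self, **kwargs: Any) -> dict:
        """Build the dataset for this config, ignoring kwargs it does not use."""


@dataclass
class ModularArithmeticDatasetConfig(DatasetConfig):
    def create_dataset(self, tlens_model_path: Optional[Path] = None, **kwargs: Any) -> dict:
        # The real method would call create_modular_arithmetic_dataset here.
        return {"kind": "modular_arithmetic", "tlens_model_path": tlens_model_path}


@dataclass
class HFDatasetConfig(DatasetConfig):
    n_ctx: Optional[int] = None

    def create_dataset(self, model_n_ctx: int, **kwargs: Any) -> dict:
        # Pack to self.n_ctx when it is set and fits the model, else to model_n_ctx.
        # Assumption: model_n_ctx is required for HF datasets in this sketch.
        fits = self.n_ctx is not None and self.n_ctx <= model_n_ctx
        n_ctx = self.n_ctx if fits else model_n_ctx
        return {"kind": "hf", "n_ctx": n_ctx}


@dataclass
class VisionDatasetConfig(DatasetConfig):
    def create_dataset(self, **kwargs: Any) -> dict:
        # The real method would call create_vision_dataset here.
        return {"kind": "vision"}


@dataclass
class BlockVectorDatasetConfig(DatasetConfig):
    def create_dataset(self, **kwargs: Any) -> dict:
        # The real method would call create_block_vector_dataset here.
        return {"kind": "block_vector"}


configs = [
    ModularArithmeticDatasetConfig(name="modadd"),
    HFDatasetConfig(name="hf", n_ctx=128),
    VisionDatasetConfig(name="vision"),
    BlockVectorDatasetConfig(name="block_vector"),
]

# Callers pass every argument they have; each config picks out what it needs.
for cfg in configs:
    dataset = cfg.create_dataset(model_n_ctx=512, tlens_model_path=Path("model.pt"))
    print(cfg.name, dataset)
```

With this shape the caller never inspects the config type. The cost is that the overrides have different signatures from the base method, which is the price of passing config-specific arguments through a shared entry point.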

We've already moved away from doing this with the ablation schedules. I guess it's fine to leave this as is to preserve the "config BaseModel classes don't have methods" property, though I would still prefer this to be implemented.