huggingface/audio-transformers-course

Gtzan Split Unit 4

Closed · 1 comment

In chapters/en/chapter4/fine-tuning.mdx, the following snippet is presented for loading the GTZAN dataset:

from datasets import load_dataset

gtzan = load_dataset("marsyas/gtzan", "all")
gtzan

This returns a DatasetDict object, not a Dataset object, which causes the next snippet to fail:

gtzan = gtzan.train_test_split(seed=42, shuffle=True, test_size=0.1)
gtzan

When I run the two snippets together as-is, I get:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-4-0475a19d0be5> in <cell line: 1>()
----> 1 gtzan = gtzan.train_test_split(seed=42, shuffle=True, test_size=0.1)
      2 gtzan

AttributeError: 'DatasetDict' object has no attribute 'train_test_split'

I can work around this by calling train_test_split on the "train" split of the DatasetDict object returned by load_dataset:

gtzan = gtzan["train"].train_test_split(seed=42, shuffle=True, test_size=0.1)
gtzan

Output:

DatasetDict({
    train: Dataset({
        features: ['file', 'audio', 'genre'],
        num_rows: 899
    })
    test: Dataset({
        features: ['file', 'audio', 'genre'],
        num_rows: 100
    })
})
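
For a notebook that might receive either object type, the same workaround can be written defensively. This is only a sketch: `as_single_dataset` is a hypothetical helper name, not part of the datasets library, and it relies on nothing more than the behaviour shown in the traceback above (a Dataset has train_test_split, a DatasetDict does not).

```python
def as_single_dataset(obj, split="train"):
    """Return an object that exposes train_test_split.

    Hypothetical helper (not a datasets API): if `obj` already has
    train_test_split it is returned unchanged; otherwise it is treated
    as a dict-like container of splits and `obj[split]` is returned.
    """
    if hasattr(obj, "train_test_split"):
        return obj
    return obj[split]


# Intended use with the object from the first snippet:
# gtzan = as_single_dataset(gtzan).train_test_split(seed=42, shuffle=True, test_size=0.1)
```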

I recommend updating the second code snippet to call train_test_split on the "train" split, unless there is a way to get load_dataset to return the Dataset object itself. I'm not even sure what the "all" argument refers to there. I can make this change myself, but I was instructed on the Discord server to file an issue.
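
On the question of getting a Dataset directly: load_dataset accepts a `split` argument, and passing `split="train"` returns that one split as a Dataset instead of a DatasetDict. A minimal sketch follows, assuming the dataset still loads under the same name and "all" configuration used above. The wrapper name `load_and_split` is hypothetical, and this is not necessarily how the course resolved the issue.

```python
def load_and_split(name="marsyas/gtzan", config="all", test_size=0.1, seed=42):
    """Load only the "train" split as a Dataset, then split off a test set.

    Hypothetical wrapper for illustration; the defaults mirror the
    snippets quoted in this issue.
    """
    # Imported lazily so that defining the function does not require the
    # library; calling it needs `pip install datasets` and network access.
    from datasets import load_dataset

    # split="train" makes load_dataset return a Dataset, not a DatasetDict.
    train_only = load_dataset(name, config, split="train")

    # train_test_split is available on Dataset and returns a DatasetDict
    # with "train" and "test" keys.
    return train_only.train_test_split(seed=seed, shuffle=True, test_size=test_size)
```

Calling `load_and_split()` downloads the data, so it is left uncalled here. Loading behaviour may also differ between versions of the datasets library.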

Thank you for reporting the issue, and adding the details. We've added a fix in #131