`DistributeFilesDataset`, allow kwargs in `get_sub_epoch_dataset`
Icemole opened this issue · 10 comments
In our codebase we have several dataset implementations, and I want to implement the DistributeFilesDataset in such a way that it wraps the dataset in a generic way. However, right now the get_sub_epoch_dataset callable only receives a list of files, as seen here. Passing the dataset itself to the function, so that it can be wrapped generically, would be ideal for us. Would there be any way of allowing the callable to take additional kwargs?
As I envision it, in the same way we provide the files and the callable in the DistributeFilesDataset dictionary, we could also provide additional arbitrary kwargs. These would include not only the original dataset dict, but also whether the user wants to do local caching (that is, use CachedFile) and so on.
Can you give an example of how this can be useful? I don't really understand the use case where you would want different kwargs depending on what files you get. Or maybe I misunderstood.
Or phrased differently: what I understand you want to do should already be possible? Just use get_sub_epoch_dataset=functools.partial(my_get_sub_epoch_dataset, **whatever_extra_kwargs_you_want)? Then my_get_sub_epoch_dataset(files=files, **whatever_extra_kwargs_you_want) will be called.
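A minimal sketch of that pattern (the names and values here are illustrative, not from the actual setup): all extra options are bound once via functools.partial, and only files varies per sub-epoch.

import functools
from typing import Any, Dict, List


def my_get_sub_epoch_dataset(files: List[str], *, local_caching: bool = False) -> Dict[str, Any]:
    # Only `files` changes per sub-epoch; all other options were bound via functools.partial.
    from returnn.util.file_cache import CachedFile

    return {
        "class": "HDFDataset",
        "files": [CachedFile(fn) if local_caching else fn for fn in files],
    }


train = {
    "class": "DistributeFilesDataset",
    "files": all_hdf_files,  # assumed: the full file list; RETURNN splits it over the sub-epochs
    "get_sub_epoch_dataset": functools.partial(my_get_sub_epoch_dataset, local_caching=True),
    "partition_epoch": 20,
}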
Oh, I can look into functools.partial, that's definitely a possibility. As you might have inferred, this is loosely related to rwth-i6/sisyphus#192 in the sense that sisyphus crashes because I can't pickle a closure. I need (edit: needed) a closure here because in our codebase, depending on the dataset that the user wants to use, the dataset dictionary will have one shape or another.
In this sense, I think DistributeFilesDataset is a wrapper that can be implemented without taking the underlying dataset into consideration, so I'm trying to implement it generically. However, the example provided in the docs returns a hardcoded dictionary:
def get_sub_epoch_dataset(files_subepoch: List[str]) -> Dict[str, Any]:
    from returnn.util.file_cache import CachedFile

    return {
        "class": "HDFDataset",
        "files": [CachedFile(fn) for fn in files_subepoch],
    }
What I'm trying to do is to provide this dictionary as a parameter, and just modify the "files" key. That's easy to do, but I wasn't able to do it until you told me about functools.partial, I'll try that :)
So there's an issue: I need a version of the config that sisyphus can pickle (which is why I can't use any closure: AttributeError: Can't pickle local object 'get_sub_epoch_dataset.<locals>.get_sub_epoch_files'), but I also need a serializable version of the config to make the RETURNN config happy.
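For illustration, a minimal reproduction of that pickling limitation (hypothetical function names, not the actual code):

import pickle
from typing import Any, Dict, List


def make_get_sub_epoch_dataset(dataset: Dict[str, Any]):
    # The inner function is a closure defined locally, so pickle cannot reference it by module path.
    def get_sub_epoch_files(files: List[str]) -> Dict[str, Any]:
        return {**dataset, "files": files}

    return get_sub_epoch_files


fn = make_get_sub_epoch_dataset({"class": "HDFDataset"})
pickle.dumps(fn)
# AttributeError: Can't pickle local object 'make_get_sub_epoch_dataset.<locals>.get_sub_epoch_files'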
If I pass get_sub_epoch_dataset=functools.partial(my_get_sub_epoch_dataset, **whatever_extra_kwargs_you_want), the functools.partial(...) bit doesn't seem to be resolved before writing the RETURNN config, and so the RETURNN config file has this object/string: 'get_sub_epoch_dataset': functools.partial(<function get_sub_epoch_dataset at 0x7fc1ca851e10>, {'dataset': {'train': ...}}). I thought this would have been properly serialized once passed to the RETURNN config, but it doesn't seem to be serialized.
Some other people have the same issue online: see faustomorales/vit-keras#15, pymc-devs/pymc#6167. The solution seems to be not using functools.partial and relying on native methods or cloudpickle/dill instead.
I wouldn't like to have to rely on "native methods" (that is, hardcoding each of the functions in a file, depending on the value of the parameters) because I think that would imply implementing an exponential number of functions, depending on the number of **whatever_extra_kwargs_you_want.
What I'm trying to do is to provide this dictionary as a parameter, and just modify the "files" key. That's easy to do, but I wasn't able to do it until you told me about functools.partial, I'll try that :)
But that's wrong. That's not what you should do. files should really be a list of files (or a list of nested structure over files) but nothing else (no other kwargs in there). Whatever other options/kwargs you need, you should only pass to functools.partial, but not to files. But you can pass whatever you need there. So in your get_sub_epoch_dataset, you will have everything you need then.
If I pass get_sub_epoch_dataset=functools.partial(my_get_sub_epoch_dataset, **whatever_extra_kwargs_you_want), the functools.partial(...) bit doesn't seem to be resolved before writing the RETURNN config, and so the RETURNN config file has this object/string: 'get_sub_epoch_dataset': functools.partial(<function get_sub_epoch_dataset at 0x7fc1ca851e10>, {'dataset': {'train': ...}}). I thought this would have been properly serialized once passed to the RETURNN config, but it doesn't seem to be serialized.
It sounds like it was serialized, but just incorrectly. Actually, everything seems to be correctly serialized except the <function get_sub_epoch_dataset at 0x7fc1ca851e10>.
I think you need to use CodeWrapper and DelayedFormat, sth like CodeWrapper(DelayedFormat('lambda files: get_sub_epoch_dataset(files, **({}))', whatever_extra_kwargs_you_want)). And then you need to use sth like i6_core.serialization.Import or some other mechanism to make get_sub_epoch_dataset available in the config (e.g. via ReturnnConfig python_epilog). See some example usages in i6_experiments. E.g., to extend this a bit, here is an example of creating my ReturnnConfig (from i6_experiments.users.zeyer.train_v3.train):
returnn_train_config = ReturnnConfig(
    returnn_train_config_dict,
    python_epilog=[
        serialization.Collection(
            [
                serialization.NonhashedCode(get_import_py_code()),
                serialization.NonhashedCode(
                    nn.ReturnnConfigSerializer.get_base_extern_data_py_code_str_direct(extern_data_raw)
                ),
                *serialize_model_def(model_def, unhashed_package_root=unhashed_package_root),
                serialization.Import(
                    train_def, import_as="_train_def", unhashed_package_root=unhashed_package_root
                ),
                # Consider the imports as non-hashed. We handle any logic changes via the explicit hash below.
                serialization.Import(_returnn_v2_get_model, import_as="get_model", use_for_hash=False),
                serialization.Import(_returnn_v2_train_step, import_as="train_step", use_for_hash=False),
                serialization.ExplicitHash(
                    {
                        # Increase the version whenever some incompatible change is made in this train() function,
                        # which influences the outcome, but would otherwise not influence the hash.
                        "version": 3,
                        # Whatever the caller provides. This could also include another version,
                        # but this is up to the caller.
                        "extra": extra_hash,
                        "setup_base_name": setup_base_name,
                    }
                ),
                serialization.PythonEnlargeStackWorkaroundNonhashedCode,
                serialization.PythonCacheManagerFunctionNonhashedCode,
                serialization.PythonModelineNonhashedCode,
            ]
            + list(epilog)
        )
    ],
    ...
)
Just like I import get_model there via serialization.Import, you can import get_sub_epoch_dataset from somewhere. I used i6_experiments.common.setups.serialization here, but the code was also moved to i6_core.serialization and I think you can simply use that instead.
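Put together, the suggestion above could look roughly like this (a sketch only; the exact import paths and the Import/DelayedFormat signatures depend on the i6_core/i6_experiments versions in use, and all_files and get_sub_epoch_dataset are assumed to exist at module level):

from i6_core.returnn.config import CodeWrapper, ReturnnConfig
from i6_core.serialization import Import  # or i6_experiments.common.setups.serialization
from sisyphus.delayed_ops import DelayedFormat

extra_kwargs = {"local_caching": True}  # whatever extra options get_sub_epoch_dataset needs

returnn_config = ReturnnConfig(
    config={
        "train": {
            "class": "DistributeFilesDataset",
            "files": all_files,  # assumed: the full file list
            # Serialized as a lambda that forwards `files` plus the bound kwargs:
            "get_sub_epoch_dataset": CodeWrapper(
                DelayedFormat("lambda files: get_sub_epoch_dataset(files, **({}))", extra_kwargs)
            ),
            "partition_epoch": 20,
        }
    },
    # Makes the module-level get_sub_epoch_dataset importable inside the generated config.
    python_epilog=[Import(get_sub_epoch_dataset, import_as="get_sub_epoch_dataset")],
)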
What I'm trying to do is to provide this dictionary as a parameter, and just modify the "files" key. That's easy to do, but I wasn't able to do it until you told me about functools.partial, I'll try that :)
But that's wrong. That's not what you should do. files should really be a list of files (or list of nested structure over files) but nothing else (no other kwargs in there). Whatever other options/kwargs you need, you should only pass to functools.partial, but not to files. But you can pass whatever you need there. So in your get_sub_epoch_dataset, you will have everything you need then.
So what I'm trying to do specifically is this:
def get_sub_epoch_dataset(
    files_sub_epoch: List[Tuple[str, str]], dataset: Dict[str, Any], num_workers: int, local_caching: bool, **kwargs
) -> Dict[str, Any]:
    """
    Obtains the files for each sub-epoch, respecting the structure of the original dataset given in :param:`dataset`.

    :param files_sub_epoch: List of feature/alignment pairs to be loaded in the specific sub-epoch.
    :param dataset: Dataset dictionary to be used for training.
    :param num_workers: Number of parallel workers for data processing. Only used to know which keys to index.
    :param local_caching: Whether to cache the files locally or not.
        If `True`, the files will be cached through :class:`returnn.util.file_cache.CachedFile`.
    :return: Dataset to be used in the specific sub-epoch.
    """
    new_dataset = copy.deepcopy(dataset)
    if local_caching:
        from returnn.util.file_cache import CachedFile

    feature_files, alignment_files = tuple(zip(*files_sub_epoch))
    train_set = (
        new_dataset["train"]["datasets"] if num_workers == 1 else new_dataset["train"]["dataset"]["datasets"]
    )
    train_set["features"]["files"] = [CachedFile(f) if local_caching else f for f in feature_files]
    train_set["alignments"]["files"] = [CachedFile(f) if local_caching else f for f in alignment_files]
    return new_dataset
And then I'm using that function to declare the training dataset as follows:
distribute_files_dataset["train"] = {
    "class": "DistributeFilesDataset",
    "files": list(zip(train_dataset["features"]["files"], train_dataset["alignments"]["files"])),
    "get_sub_epoch_dataset": functools.partial(
        get_sub_epoch_dataset,
        {
            "dataset": dataset,
            "num_workers": num_workers,
            "local_caching": local_caching,
        },
    ),
    "partition_epoch": partition_epoch,
    "seq_ordering": train_set_files_sorting,
}
I first examined i6_core.serialization and tried replacing functools.partial with i6_core.serialization.PartialImport, which seems to be exactly what I need, but I'm running into issues (see rwth-i6/i6_core#513).
Then I tried what you proposed, replacing functools.partial with CodeWrapper(DelayedFormat("lambda files: get_sub_epoch_dataset(files, dataset={}, ...", [dataset, ...])) and using a ReturnnConfig object with the python_epilog updated:
ReturnnConfig(
    config={},
    python_epilog=[
        serialization.Import(get_sub_epoch_dataset, import_as="get_sub_epoch_dataset"),
        distribute_files_dataset,
    ],
)
But I'm also running into errors, all related to the fact that the Python code can't be serialized.
I also tried without the DelayedFormat, since I actually have all parameters except files available at the time of writing the RETURNN config, but I'm also running into errors similar to the approach above:
RuntimeError: Could not serialize lambda files: get_sub_epoch_dataset(files, dataset={'train': {'class': 'MetaDataset', ...}}}, num_workers=1, local_caching=True)
I don't understand why you put distribute_files_dataset into the python_epilog. That should go into the normal config as usual.
The python_epilog is only needed for any of the serialization logic (e.g. serialization.Collection, serialization.Import).
But also, as said, you cannot just refer to get_sub_epoch_dataset in the config like that. Maybe this simpler variant works?
...
"get_sub_epoch_dataset": functools.partial(
    CodeWrapper("get_sub_epoch_dataset"),
    **{
        "dataset": dataset,
        "num_workers": num_workers,
        "local_caching": local_caching,
    },
),
...
Or the variant with DelayedFormat, as I was suggesting before. In any case, you then need to make get_sub_epoch_dataset available via serialization.Import via the python_epilog.
But I'm also running into errors, all related to the fact that the python code can't be serialized.
I'm not sure what errors. But also, it's very wrong that you put distribute_files_dataset into the python_epilog.
I also tried without the DelayedFormat, since I actually have all parameters except files available at the time of writing the RETURNN config, but I'm also running into similar errors to the approach above
I'm not really sure how you can do that without DelayedFormat or functools.partial. I'm also not sure where exactly you get the error then.
Btw, your usage of num_workers == 1 seems very misleading. I guess you actually want to check whether the MultiProcDataset was used? But then checking for num_workers == 1 doesn't really make sense, as you could also use MultiProcDataset with num_workers = 1 (which is a valid use case, although maybe not so much with DistributeFilesDataset). Maybe you want to check num_workers == 0? num_workers tells you how many sub procs there are. But if you actually want to check whether MultiProcDataset was used, why not just check for it directly, like ...dataset["class"] == "MultiProcDataset"? Seems much clearer to me.
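A sketch of that direct check, assuming the dict nesting used in the earlier get_sub_epoch_dataset snippet:

# Detect the MultiProcDataset wrapper directly instead of inferring it from num_workers.
train_cfg = new_dataset["train"]
if train_cfg.get("class") == "MultiProcDataset":
    train_set = train_cfg["dataset"]["datasets"]  # the wrapped dataset sits one level deeper
else:
    train_set = train_cfg["datasets"]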
I don't understand why you put distribute_files_dataset into the python_epilog. That should go into the normal config as usual. [...]
Thanks for the feedback. I think I managed to make it work via serialization.PartialImport, which I found suitable for this purpose. I'm now testing the actual functionality.
For future reference, what I ended up doing was:
distribute_files_dataset = copy.deepcopy(my_dataset)
distribute_files_dataset["train"] = {
    "class": "DistributeFilesDataset",
    "files": feat_align_pairs,
    "get_sub_epoch_dataset": returnn.config.CodeWrapper("get_sub_epoch_dataset"),
    "partition_epoch": partition_epoch,
    "seq_ordering": train_set_files_sorting,
}
And then:
distribute_files_train_config = returnn.ReturnnConfig(
    config=distribute_files_train,
    python_prolog=[
        serialization.Collection(
            serializer_objects=[
                serialization.PartialImport(
                    code_object_path=get_sub_epoch_dataset,
                    unhashed_package_root="apptek_asr.meta.nn",
                    hashed_arguments={
                        "dataset": dataset_without_files,
                        "num_workers": num_workers,
                        "local_caching": local_caching,
                    },
                    unhashed_arguments={},
                    import_as="get_sub_epoch_dataset",
                ),
            ]
        )
    ],
)

returnn_training_config.update(distribute_files_train_config)
Since this was solved without the need for further implementation, I'll close this issue.
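Conceptually, the generated RETURNN config should then contain something along these lines (an illustration of the idea only, not the verbatim PartialImport output; the module path and placeholder values are assumptions), which is why the bare CodeWrapper("get_sub_epoch_dataset") reference in the dataset dict resolves:

# Illustration only: PartialImport imports the original function and binds the hashed arguments,
# exposing the bound callable under the requested name.
import functools
from apptek_asr.meta.nn import get_sub_epoch_dataset as _get_sub_epoch_dataset_orig

get_sub_epoch_dataset = functools.partial(
    _get_sub_epoch_dataset_orig,
    dataset=...,  # dataset_without_files, as written out by sisyphus
    num_workers=...,
    local_caching=...,
)

train = {
    "class": "DistributeFilesDataset",
    "files": [...],
    "get_sub_epoch_dataset": get_sub_epoch_dataset,  # the CodeWrapper just emits this name
    ...
}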
Btw, your usage of num_workers == 1 seems very misleading. [...]
This is what we had in our codebase. I'll check with my colleagues in case we need to change this. Thanks for the feedback!
For future reference, the error I got when using the suggestion from @albertz:
[FIRST APPROACH]
"get_sub_epoch_dataset": functools.partial(
    CodeWrapper("get_sub_epoch_dataset"),
    dataset=dataset,
    num_workers=num_workers,
    local_caching=local_caching,
),
But I get this error in the manager:
...
590 returnn.config.CodeWrapper("get_sub_epoch_dataset"),
591 dataset=dataset,
592 num_workers=num_workers,
593 local_caching=local_caching,
594 ),
595 "partition_epoch": partition_epoch,
596 "seq_ordering": train_set_files_sorting,
597 }
599 return distribute_files_dataset
TypeError: the first argument must be callable
Indeed, the CodeWrapper object isn't callable; we'd have to do many things for that to work, and then I'd prefer to just use Import or PartialImport.
I also tried directly:
[SECOND APPROACH]
"get_sub_epoch_dataset": functools.partial(
    get_sub_epoch_dataset,
    dataset=dataset_without_files,
    num_workers=num_workers,
    local_caching=local_caching,
),
But it also crashed, this time in the create_files task of the ReturnnTrainingJob:
RuntimeError: Could not serialize functools.partial(<function get_sub_epoch_dataset at 0x7f80eff8e3b0>, dataset={'train': ...}
Not only because the function obviously couldn't be serialized properly before the call to functools.partial, but also because the tk.Paths used for the model training couldn't be serialized either.
Finally, I tried directly inserting the functools.partial function as stated above in the Python prolog of the RETURNN config and calling it via CodeWrapper:
python_epilog=functools.partial(
    CodeWrapper("get_sub_epoch_dataset"),
    dataset=dataset,
    num_workers=num_workers,
    local_caching=local_caching,
),
...
"get_sub_epoch_dataset": returnn.config.CodeWrapper("get_sub_epoch_dataset")
But it failed with exactly the same error: TypeError: the first argument must be callable.
I also tried the equivalent without the CodeWrapper in functools.partial(CodeWrapper(...), ...), but it also failed with a similar error to [SECOND APPROACH]:
RuntimeError: Could not serialize functools.partial(<function get_sub_epoch_dataset at 0x7f80eff8e3b0>, dataset={'train': ...}
I don't exactly understand. When you use functools.partial(CodeWrapper("get_sub_epoch_dataset"), ...), when this is in the config part of ReturnnConfig, it should be serialized to functools.partial(get_sub_epoch_dataset, ...) (just via repr). Is this the case? Where exactly do you get this TypeError? What does the serialized config look like?
After trying it myself: Oh, I see, this error happens already when creating the functools.partial object, not at serialization or when loading the config. This error comes from a check within functools.partial.
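A minimal illustration of that check (the class here is a hypothetical stand-in for CodeWrapper, which only wraps a code string and is not callable):

import functools


class FakeCodeWrapper:
    # Stand-in for CodeWrapper: it wraps a code string and defines no __call__.
    def __init__(self, code: str):
        self.code = code


functools.partial(FakeCodeWrapper("get_sub_epoch_dataset"))
# TypeError: the first argument must be callable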
It would be good if you provided such important details. They are always relevant:
- Where exactly does the error happen (manager or job worker? also the stack trace),
- what the serialized config looks like,
- what your code looks like.
So then we cannot use functools.partial this way here, as CodeWrapper is not callable.
But we don't need to do it this way. You can also just use DelayedFormat here for the config part, i.e. "get_sub_epoch_dataset": DelayedFormat(...). Or just "get_sub_epoch_dataset": CodeWrapper("get_sub_epoch_dataset").
Finally, I tried directly inserting the functools.partial function as stated above in the Python prolog of the RETURNN config and calling it via "get_sub_epoch_dataset": returnn.config.CodeWrapper("get_sub_epoch_dataset")
But the first approach failed with exactly the same error, and the second approach failed with a very similar error (can't serialize <function ...>, ...).
I again don't exactly understand what error you get with what approach at what point. Why do you say "first approach"? There is only a single approach here in this section, namely "get_sub_epoch_dataset": CodeWrapper("get_sub_epoch_dataset"). And what error? And where? And what second approach? Or do you refer to this as the third approach here? So this works?