Freeze class-level experiment attributes
Problem
Currently there is nothing preventing a user from changing the values of class-level attributes of `ExperimentConfig`s at runtime. This was fine in the old implementation, but now that each `light_engine` is spawned using `"forkserver"`, we run into a problem. The code:
```python
import torch.multiprocessing as mp


class BaseBabyAIGoToObjExperimentConfig:
    MY_GREAT_VARIABLE = 3


def my_func(config):
    print(config.MY_GREAT_VARIABLE)


if __name__ == "__main__":
    mp = mp.get_context("forkserver")  # Broken
    # mp = mp.get_context("fork")  # Works

    BaseBabyAIGoToObjExperimentConfig.MY_GREAT_VARIABLE = 5
    cfg = BaseBabyAIGoToObjExperimentConfig()

    p = mp.Process(target=my_func, kwargs=dict(config=cfg))
    print("main", cfg.MY_GREAT_VARIABLE)
    p.start()
    p.join()
```
will print `5` and then `3` when `mp = mp.get_context("forkserver")`, but `5` and then `5` when `mp = mp.get_context("fork")`. (With `"fork"` the child inherits a copy of the parent's memory, including the modified class attribute; with `"forkserver"` the config instance is pickled and sent to a freshly started process, where the class is re-imported with its original defaults, since pickling an instance does not capture class-level attributes.) This means that a user might change the value of a class-level attribute before running training, but this change will not propagate to the training process. We seem to need `"forkserver"` for some CUDA reasons (?), so it's probably best to either:
- Disallow the user from changing class-level variables, or
- Detect any such change and throw an error if the runner is called with such a modified config (see the sketch after this list).
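For the second option, one possible approach is to snapshot the class-level attributes at definition time and compare against a fresh snapshot before the runner spawns any processes. A minimal sketch, where `SnapshotMeta`, `_class_attr_snapshot`, and `check_unmodified` are hypothetical names and not part of the existing codebase:

```python
import copy


def _class_attr_snapshot(cls):
    # Record all public, non-callable class-level attributes.
    return {
        name: copy.deepcopy(value)
        for name, value in vars(cls).items()
        if not name.startswith("_") and not callable(value)
    }


class SnapshotMeta(type):
    def __init__(cls, name, bases, namespace):
        super().__init__(name, bases, namespace)
        # Snapshot taken once, when the class is defined.
        cls._attr_snapshot = _class_attr_snapshot(cls)


class Config(metaclass=SnapshotMeta):
    LR = 0.01


def check_unmodified(cls):
    # A check the runner could perform before spawning worker processes.
    if _class_attr_snapshot(cls) != cls._attr_snapshot:
        raise RuntimeError(
            f"Class-level attributes of {cls.__name__} were modified after definition."
        )


if __name__ == "__main__":
    check_unmodified(Config)  # Fine
    Config.LR = 0.1
    check_unmodified(Config)  # Raises RuntimeError
```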
Solution
This requires:
Having some means by which to stop changes to class-level variables. One approach that goes part of the way there is the following pattern:
```python
class FrozenClassVariables(type):
    def __setattr__(cls, attr, value):
        raise RuntimeError("Cannot edit class-level attributes.")


class SomeClass(object, metaclass=FrozenClassVariables):
    yar = 3


if __name__ == "__main__":
    try:
        SomeClass.yar = 6  # Error
    except Exception as _:
        print("Threw exception")

    SomeClass().bar = 12  # No error
```
I'm not sure how to make `ExperimentConfig`s automatically have this metaclass, but I presume it's possible.
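For what it's worth, metaclasses are inherited by subclasses, so it might suffice to declare the `ExperimentConfig` base class itself with this metaclass; every user-defined config would then pick up the freezing behavior automatically. A minimal sketch (the `ExperimentConfig` body below is a hypothetical stand-in, not the real class):

```python
class FrozenClassVariables(type):
    def __setattr__(cls, attr, value):
        raise RuntimeError(
            f"Cannot edit class-level attribute {attr!r} of {cls.__name__}."
        )


class ExperimentConfig(metaclass=FrozenClassVariables):
    # Hypothetical stand-in for the real base class.
    pass


class MyExperimentConfig(ExperimentConfig):
    # Assignments in the class body still work; only later reassignment
    # through the class object is blocked.
    NUM_PROCESSES = 8


if __name__ == "__main__":
    try:
        MyExperimentConfig.NUM_PROCESSES = 16  # Raises RuntimeError
    except RuntimeError as e:
        print("Blocked:", e)

    cfg = MyExperimentConfig()
    cfg.num_workers = 4  # Instance attributes remain mutable
```

Class body assignments are unaffected because they go through `type.__new__` rather than the metaclass's `__setattr__`, so users can still define their configs normally.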
Dependencies
None