RecursionError when attempting to unpickle objax objects
bytbox opened this issue · 2 comments
This is with objax-1.4.0, from PyPI. The problem does not occur with objax-1.3.1.
For example:
```
(env) theseus:~/mlqm/de$ python
Python 3.9.2 (default, Feb 20 2021, 18:40:11)
[GCC 10.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import pickle, objax
>>> try:
...     pickle.loads(pickle.dumps(objax.nn.Linear(3,1)))
... except RecursionError:
...     print('Recursion error')
...
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Recursion error
```
The error is coming from `loads`, not `dumps`. The resulting stack trace ends with:
```
  File "/home/scott/mlqm/env/lib/python3.9/site-packages/objax/variable.py", line 143, in __getattr__
    return getattr(self.value, name)
  File "/home/scott/mlqm/env/lib/python3.9/site-packages/objax/variable.py", line 179, in value
    return self._value
  File "/home/scott/mlqm/env/lib/python3.9/site-packages/objax/variable.py", line 143, in __getattr__
    return getattr(self.value, name)
  File "/home/scott/mlqm/env/lib/python3.9/site-packages/objax/variable.py", line 179, in value
    return self._value
  File "/home/scott/mlqm/env/lib/python3.9/site-packages/objax/variable.py", line 143, in __getattr__
    return getattr(self.value, name)
  File "/home/scott/mlqm/env/lib/python3.9/site-packages/objax/variable.py", line 179, in value
    return self._value
  File "/home/scott/mlqm/env/lib/python3.9/site-packages/objax/variable.py", line 143, in __getattr__
    return getattr(self.value, name)
  File "/home/scott/mlqm/env/lib/python3.9/site-packages/objax/variable.py", line 179, in value
    return self._value
RecursionError: maximum recursion depth exceeded
```
I think it's probably related to the fact that we added a `__getattr__` method to variables, which may cause issues when restoring from pickle.
Here is a Stack Overflow discussion about `__getattr__` and pickle: https://stackoverflow.com/questions/49380224/how-to-make-classes-with-getattr-pickable
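The loop can be reproduced without objax. The sketch below uses a hypothetical `Delegating` class that only mimics the relevant shape of an objax variable (a `value` property returning `self._value`, plus a `__getattr__` that forwards to `self.value`); it is an assumption for illustration, not objax's actual code. When pickle restores an object it creates the instance without calling `__init__`, so `__dict__` is still empty when pickle looks up `__setstate__`. That lookup falls through to `__getattr__`, which reads `self.value`, which reads the missing `self._value`, which falls back to `__getattr__` again, and so on.

```python
import pickle


class Delegating:
    """Hypothetical stand-in for an objax variable (not objax's real class)."""

    def __init__(self, value):
        self._value = value

    @property
    def value(self):
        # Normal attribute lookup; if '_value' is missing, Python falls back to __getattr__.
        return self._value

    def __getattr__(self, name):
        # Called only when normal lookup fails: forward to the wrapped value.
        return getattr(self.value, name)


data = pickle.dumps(Delegating([1, 2, 3]))  # dumping works: _value is set on the live object
try:
    pickle.loads(data)  # the new instance has an empty __dict__ when __setstate__ is looked up
    outcome = "ok"
except RecursionError:
    outcome = "RecursionError"
print(outcome)
```

Because `getattr(obj, name, default)` only swallows `AttributeError`, the `RecursionError` escapes all the way out of `pickle.loads`, matching the traceback above.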
I think the fix would be to change how the `value` property works in variables:
```python
class TrainVar(BaseVar):
    @property
    def value(self) -> JaxArray:
        return self.__dict__['_value']  # instead of self._value
```
and possibly similar changes for `TrainRef` and `StateVar`.
Would you be able to try these changes to see if they work or not?
This doesn't seem to work:
```
Traceback (most recent call last):
  File "/home/scott/objax/tests/pickle.py", line 33, in test_on_linear
    lin_ = pickle.loads(pickled)
  File "/home/scott/objax/objax/variable.py", line 143, in __getattr__
    return getattr(self.value, name)
  File "/home/scott/objax/objax/variable.py", line 180, in value
    return self.__dict__['_value']
KeyError: '_value'
```
And indeed, if I print out `self.__dict__` on the line before the `return` in `value()`, it's just empty.
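The `KeyError` fits the same picture: at the moment pickle looks up `__setstate__`, the freshly created instance has no `_value` yet, and a dictionary lookup raises `KeyError` rather than the `AttributeError` that `getattr(obj, name, default)` knows how to swallow. A minimal sketch, again using a hypothetical stand-in class (`DictLookupVar`) rather than objax's real `TrainVar`:

```python
import pickle


class DictLookupVar:
    """Hypothetical stand-in with the proposed value property (not objax code)."""

    def __init__(self, value):
        self._value = value

    @property
    def value(self):
        # Proposed change: read the instance dict directly instead of self._value.
        return self.__dict__['_value']

    def __getattr__(self, name):
        # Forward unknown attributes to the wrapped value.
        return getattr(self.value, name)


data = pickle.dumps(DictLookupVar([1, 2, 3]))
try:
    pickle.loads(data)
    error = None
except KeyError as exc:
    # __dict__ is empty during the __setstate__ lookup, so '_value' is missing.
    error = exc
print(repr(error))  # KeyError('_value')
```

The recursion is gone, but the lookup still fails, just with a different exception type.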
However, I can fix the error by following that link a little more closely and changing `BaseVar` to raise an `AttributeError` when `_value` is not available. That passes all tests and appears to fix this error. (But I'm not super familiar with objax internals, so maybe it's obvious that this is the wrong thing to do. Let me know!)
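One way the `AttributeError` approach could look, following the pattern from the Stack Overflow thread linked above: `__getattr__` refuses to delegate while `_value` is absent from the instance dict, so pickle's `__setstate__` lookup gets a plain `AttributeError`, treats the method as missing, and then restores `__dict__` from the pickled state. This is a hypothetical sketch on a stand-in class (`GuardedVar`); the exact change in the pull request may differ.

```python
import pickle


class GuardedVar:
    """Hypothetical stand-in for BaseVar with a guarded __getattr__ (not objax code)."""

    def __init__(self, value):
        self._value = value

    @property
    def value(self):
        return self._value

    def __getattr__(self, name):
        # During unpickling __init__ is skipped, so '_value' may not exist yet.
        # Raising AttributeError lets pickle treat the lookup as "not found".
        if '_value' not in self.__dict__:
            raise AttributeError(name)
        return getattr(self.value, name)


restored = pickle.loads(pickle.dumps(GuardedVar([1, 2, 3])))
print(restored.value)     # [1, 2, 3]
print(restored.count(2))  # 1, forwarded to the wrapped list via __getattr__
```

Checking `self.__dict__` directly is safe inside `__getattr__` because `__dict__` is always found by normal lookup, so the guard itself cannot re-enter `__getattr__`.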
A pull request follows...