ResNetV2 from objax.zoo does not specify training arg for ResNetV2Block
a1302z opened this issue
Hi, I tried to use resnet_v2.ResNet18, but I get an error when calling the model: ResNetV2Block.__call__ expects a training argument, which the ResNetV2 implementation does not provide.
As a minimal example:
```python
import numpy as np
from objax.zoo import resnet_v2

fake_data = np.random.randn(2, 3, 224, 224)
model = resnet_v2.ResNet18(
    in_channels=3,
    num_classes=1000,
)
model(fake_data)
```
This produces the following error:
```
Traceback (most recent call last):
File "/home/a/anaconda3/envs/o/lib/python3.10/site-packages/objax/nn/layers.py", line 488, in run_layer
return f(*args, **util.local_kwargs(kwargs, f))
TypeError: ResNetV2Block.__call__() missing 1 required positional argument: 'training'
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/a/anaconda3/envs/o/lib/python3.10/site-packages/objax/nn/layers.py", line 488, in run_layer
return f(*args, **util.local_kwargs(kwargs, f))
File "/home/a/anaconda3/envs/o/lib/python3.10/site-packages/objax/nn/layers.py", line 497, in __call__
args = self.run_layer(i, f, args, kwargs)
File "/home/a/anaconda3/envs/o/lib/python3.10/site-packages/objax/nn/layers.py", line 490, in run_layer
raise type(e)(f'Sequential layer[{layer}] {f} {e}') from e
TypeError: Sequential layer[0] <objax.zoo.resnet_v2.ResNetV2Block object at 0x7f92c4a38310> ResNetV2Block.__call__() missing 1 required positional argument: 'training'
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/a/O/resnet_minimal_example.py", line 9, in <module>
model(fake_data)
File "/home/a/anaconda3/envs/o/lib/python3.10/site-packages/objax/nn/layers.py", line 497, in __call__
args = self.run_layer(i, f, args, kwargs)
File "/home/a/anaconda3/envs/o/lib/python3.10/site-packages/objax/nn/layers.py", line 490, in run_layer
raise type(e)(f'Sequential layer[{layer}] {f} {e}') from e
TypeError: Sequential layer[3] objax.zoo.resnet_v2.ResNetV2BlockGroup(
[0] <objax.zoo.resnet_v2.ResNetV2Block object at 0x7f92c4a38310>
[1] <objax.zoo.resnet_v2.ResNetV2Block object at 0x7f92b3fc7250>
) Sequential layer[0] <objax.zoo.resnet_v2.ResNetV2Block object at 0x7f92c4a38310> ResNetV2Block.__call__() missing 1 required positional argument: 'training'
```
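The traceback shows the dispatch pattern at work: the Sequential container calls each layer as f(*args, **util.local_kwargs(kwargs, f)), so a layer only receives the keyword arguments its signature accepts. Layers 0 to 2 (convolution, padding, pooling) take no training argument and run fine, while the first block group requires one and none was given. The sketch below is a simplified, hypothetical re-implementation of that pattern in plain Python; local_kwargs, MiniSequential, scale and Block are illustrative names, not Objax's actual code.

```python
import inspect


def local_kwargs(kwargs, f):
    """Keep only the keyword arguments that f's signature can accept.

    Hypothetical stand-in for the util.local_kwargs call in the traceback.
    """
    params = inspect.signature(f).parameters
    # A callable that takes **kwargs can accept every keyword.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(kwargs)
    return {k: v for k, v in kwargs.items() if k in params}


class MiniSequential:
    """Toy container that forwards each layer only the kwargs it accepts."""

    def __init__(self, layers):
        self.layers = layers

    def __call__(self, x, **kwargs):
        for layer in self.layers:
            x = layer(x, **local_kwargs(kwargs, layer))
        return x


def scale(x):
    # Stateless layer: has no `training` parameter, so it never receives one.
    return 2 * x


class Block:
    # Mode-dependent layer: `training` is required, like ResNetV2Block.__call__.
    def __call__(self, x, training):
        return x + (1 if training else 0)


model = MiniSequential([scale, Block()])
print(model(3, training=True))   # 7
print(model(3, training=False))  # 6
# model(3) raises TypeError (missing 'training'), mirroring the issue.
```

Filtering by signature lets stateless layers and mode-dependent layers share one call site, but it also means a missing training keyword only surfaces when the first layer that requires it is reached.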
Am I missing something here? If not, I think one way to solve this would be to introduce train() and eval() methods, as in PyTorch, that store the training flag as an instance variable. I am happy to open a pull request if you think this is a good option.
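The PyTorch-style proposal could be sketched as a thin wrapper that stores the mode and injects it into every call. ModeWrapper is a hypothetical name and is not part of Objax; to keep the sketch dependency-free it wraps any callable that accepts a training keyword, such as a model built with resnet_v2.ResNet18(...).

```python
class ModeWrapper:
    """Hypothetical PyTorch-style wrapper: stores the mode instead of passing it."""

    def __init__(self, module):
        self.module = module    # any callable accepting a `training` keyword
        self.training = True    # PyTorch modules also start in training mode

    def train(self):
        self.training = True
        return self             # allows chaining, e.g. wrapper.train()(x)

    def eval(self):
        self.training = False
        return self

    def __call__(self, *args, **kwargs):
        # Inject the stored flag into every call.
        return self.module(*args, training=self.training, **kwargs)


def toy_model(x, training):
    # Stand-in for an Objax model; reports which mode it was called in.
    return ("train" if training else "eval", x)


wrapped = ModeWrapper(toy_model)
print(wrapped.eval()(5))   # ('eval', 5)
print(wrapped.train()(5))  # ('train', 5)
```

Two caveats are worth weighing in Objax specifically. A real wrapper would probably need to subclass objax.Module so that vars() still reaches the wrapped model's variables. Also, a flag stored as a Python attribute is read when a function is traced, so a function compiled with objax.Jit may keep using whichever mode was active at trace time. Passing training explicitly on every call avoids both issues.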
Alex,
Thank you for your interest in Objax!
You need to pass an additional training argument when calling the model, i.e. model(..., training=True) or model(..., training=False). See the example in:
https://github.com/google/objax/blob/master/examples/image_classification/imagenet_resnet50_train.py
Hope this helps!
Thanks, I just figured it out. Sorry for the distraction and thanks for the answer :)