Torch "View"
Closed this issue · 4 comments
Prerequisites
Please make sure to check off these prerequisites before submitting a bug report.
- Test that the bug appears on the current version of the master branch. Make sure to include the commit hash of the commit you checked out.
- Check that the issue hasn't already been reported, by checking the currently open issues.
- If there are steps to reproduce the problem, make sure to write them down below.
- If relevant, please include the hls4ml project files, which were created directly before and/or after the bug.
Quick summary
In hls4ml/converters/pytorch/reshape.py there is reshape_layers = ['View'] and the matching code in def parse_reshape_layer. If one uses, e.g., variable.view(-1, N) in a PyTorch model, hls4ml fails to convert and compile it. Changing 'View' to 'view' seems to work. Below is the code diff and a reproducer.
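For reference (my own check, not part of the original report), PyTorch only exposes the lowercase tensor method, so a model can only ever call .view:

import torch

x = torch.randn(16, 4, 1024)
print(x.view(-1, 4 * 1024).shape)      # torch.Size([16, 4096]) -- lowercase .view is the tensor method
print(hasattr(torch.Tensor, 'View'))   # False -- there is no capitalized Tensor.View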
Details
Steps to Reproduce
- Clone the hls4ml repository
- Check out the master branch, with commit hash: [5c0c4e6]
- Run the following code
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn

from hls4ml.converters import convert_from_pytorch_model
from hls4ml.utils.config import config_from_pytorch_model

test_root_path = Path(__file__).parent

if __name__ == "__main__":

    class test(nn.Module):
        def __init__(self, n_in, n_out, size_in, momentum=0.2):
            super().__init__()
            self.view_mult = n_out * size_in
            self.conv1 = nn.Conv1d(
                n_in,
                n_out,
                kernel_size=3,
                padding=1,
                bias=False,
            )

        def forward(self, x):
            z = self.conv1(x)
            z = z.view(-1, self.view_mult)
            return z

    n_in = 2
    n_out = 4
    size_in = 1024
    n_batch = 16

    model = test(n_in, n_out, size_in)
    model = model.to(memory_format=torch.channels_last)
    model.eval()
    print(model)

    X_input = np.random.rand(n_batch, n_in, size_in)
    pytorch_prediction = model(torch.Tensor(X_input)).detach().numpy()

    # X_input is channels last
    X_input = np.ascontiguousarray(X_input.transpose(0, 2, 1))

    config = config_from_pytorch_model(model, inputs_channel_last=True, transpose_outputs=False)
    config['Model']['Strategy'] = 'Resource'
    config['Model']['Precision'] = 'ap_fixed<64,24>'
    print(config)

    backend = 'Vivado'
    output_dir = str(test_root_path / f'hls4mlprj_test_{backend}_io_stream')
    hls_model = convert_from_pytorch_model(
        model,
        (None, n_in, size_in),
        hls_config=config,
        output_dir=output_dir,
        backend=backend,
        io_type='io_stream',
    )

    print(list(hls_model.get_layers()))
    hls_model.compile()

    print("pytorch_prediction")
    print(pytorch_prediction)
    print("pytorch_prediction.shape: ", end=" ")
    print(pytorch_prediction.shape)

    # reshape hls prediction to channels last, then transpose, then reshape
    # to match .view
    hls_prediction = np.reshape(
        np.transpose(np.reshape(hls_model.predict(X_input), (n_batch, size_in, n_out)), (0, 2, 1)),
        (n_batch, size_in * n_out),
    )

    print("hls_prediction")
    print(hls_prediction)
    print("hls_prediction.shape: ", end=" ")
    print(hls_prediction.shape)

    rtol = 0
    atol = 5.0e-2
    np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=rtol, atol=atol)
Expected behavior
Success.
Actual behavior
test(
  (conv1): Conv1d(2, 4, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
)
{'Model': {'Precision': 'ap_fixed<64,24>', 'ReuseFactor': 1, 'InputsChannelLast': True, 'TransposeOutputs': False, 'Strategy': 'Resource'}}
Interpreting Model ...
Topology:
Layer name: conv1, layer type: Conv1D, input shape: [[None, 2, 1024]]
Traceback (most recent call last):
  File "/home/hls4ml-user/work/ewstapp_research/isolate/NETWORK/test_view.py", line 56, in <module>
    hls_model = convert_from_pytorch_model(
  File "/home/hls4ml-user/miniconda3/envs/hls4ml/lib/python3.10/site-packages/hls4ml/converters/__init__.py", line 309, in convert_from_pytorch_model
    return pytorch_to_hls(config)
  File "/home/hls4ml-user/miniconda3/envs/hls4ml/lib/python3.10/site-packages/hls4ml/converters/pytorch_to_hls.py", line 314, in pytorch_to_hls
    raise Exception(f'Unsupported function {operation}')
Exception: Unsupported function view
Optional
Possible fix
Apply this patch:
diff --git a/hls4ml/converters/pytorch/reshape.py b/hls4ml/converters/pytorch/reshape.py
index 37191135..c696fa6e 100644
--- a/hls4ml/converters/pytorch/reshape.py
+++ b/hls4ml/converters/pytorch/reshape.py
@@ -3,12 +3,12 @@ import numpy as np
 from hls4ml.converters.pytorch_to_hls import pytorch_handler
 from hls4ml.converters.utils import parse_data_format
 
-reshape_layers = ['View']
+reshape_layers = ['view']
 
 
 @pytorch_handler(*reshape_layers)
 def parse_reshape_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config):
-    assert operation == 'View'
+    assert operation == 'view'
 
     layer = {}
     layer['class_name'] = 'Reshape'
Applying this patch results in success. If this fix seems correct, I can make a PR.
hls4ml doesn't parse modules/operations directly; it parses the torch.fx trace graph. Sometimes a graph node corresponds directly to the operation/module/layer you used, and sometimes it's a lower-level operation you didn't write yourself. We need to investigate whether this operation only ever appears as view (implying the current behavior is a bug), or whether both View and view occur with somewhat different semantics and we need to support both. View could be the result of nn.Flatten, but I'm not sure; it needs checking.
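As a quick way to check this (a sketch of my own, not from the thread): symbolically tracing a module like the reproducer's with torch.fx and printing the graph nodes shows how the operation name is reported:

import torch
import torch.fx
import torch.nn as nn

class TinyView(nn.Module):  # hypothetical module mimicking the reproducer
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(2, 4, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        z = self.conv1(x)
        return z.view(-1, 4 * 1024)

traced = torch.fx.symbolic_trace(TinyView())
for node in traced.graph.nodes:
    print(node.op, node.target)
# Expected to include a line like: call_method view
# i.e. the tensor method shows up in lower case, not as 'View'.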
I have seen this difference in capitalization when the same operation is implemented both as a layer object and as a function. I'll have a look to confirm.
Yes, differences in capitalization between classes and functions are common, e.g., Conv2d vs conv2d. I don't think PyTorch has a View class.
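To illustrate that naming convention (my example, not from the thread): the layer object nn.Conv2d is capitalized, while the functional form torch.nn.functional.conv2d is lower case, and both compute the same thing:

import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)
layer = nn.Conv2d(3, 6, kernel_size=3, padding=1, bias=False)   # class: capitalized
out_module = layer(x)
out_functional = F.conv2d(x, layer.weight, padding=1)           # function: lower case
print(torch.allclose(out_module, out_functional))                # True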
I remembered how we handled this. We decided to follow the capitalization of the classes (first letter capitalized) and to use this map https://github.com/fastmachinelearning/hls4ml/blob/main/hls4ml/converters/pytorch_to_hls.py#L84-L98 to translate the lower-case function names to the right handler name. In this case the distinction is a bit meaningless, but for consistency the solution is to add a mapping from view to View there. I can make the PR for that.
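For illustration only, a rough sketch of the kind of mapping described (the names and structure here are my guess at the shape of the fix; the real map is at the lines linked above, not this code):

# Hypothetical sketch (not the actual hls4ml code): canonicalize the operation
# name reported by torch.fx before looking up the registered handler.
layer_name_map = {
    'view': 'View',  # the entry proposed in this comment
    # ... other lower-case function names map to their capitalized handler names
}

def canonical_operation(operation: str) -> str:
    # Fall back to the original name if there is no special-case entry.
    return layer_name_map.get(operation, operation)

print(canonical_operation('view'))  # 'View'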