Can you provide test code?
Thanks for your good work. Loading the saved CLIP model parameters failed.
Is it in SeqPAR or PromptPAR?
PromptPAR
Training ran normally and both the model and clip_model weights were saved, but the clip_model weights can't be loaded with clip.load():
model = build_model(wt['clip_model'])
Traceback (most recent call last):
File "", line 1, in
File "OpenPAR/PromptPAR/clip/model.py", line 564, in build_model
model.load_state_dict(state_dict,strict=False)
File "CLIP/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for CLIP:
size mismatch for transformer.prompt_text_deep: copying a param with shape torch.Size([12, 3, 5, 768]) from checkpoint, the shape in current model is torch.Size([12, 0, 5, 768]).
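For reference, a quick way to see which checkpoint tensors disagree with the freshly built model (a minimal sketch; the checkpoint filename is a placeholder, and it assumes wt['clip_model'] is a plain state_dict, as build_model expects):

import torch
wt = torch.load('ckpt_max.pth', map_location='cpu')    # placeholder checkpoint path
for name, tensor in wt['clip_model'].items():
    if 'prompt' in name:                                # only the prompt-related tensors
        print(name, tuple(tensor.shape))
# e.g. transformer.prompt_text_deep -> (12, 3, 5, 768) in the checkpoint, whereas the
# model built with the current test config allocates torch.Size([12, 0, 5, 768])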
Also, how can I convert the model to ONNX?
I think the text prompt length parameter in the config was not set correctly during the test.
That's because I used 5 classes with modified text attributes.
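In other words, the prompt settings used when rebuilding the model for testing have to match the ones used for training. A minimal illustration (the attribute names below are placeholders, not the repo's exact config flags; only the checkpoint shape comes from the error above):

from argparse import Namespace
# Placeholder names: whatever the repo calls its text-prompt options must be identical
# at train and test time.
train_cfg = Namespace(use_textprompt=True, text_prompt_len=3)   # checkpoint side
test_cfg = Namespace(use_textprompt=False, text_prompt_len=0)   # mismatched test side
# The checkpoint stores transformer.prompt_text_deep as [12, 3, 5, 768], while a model
# built with the mismatched test config allocates [12, 0, 5, 768] (second dimension 0),
# which is exactly the size mismatch reported by load_state_dict.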
Oh, could you provide a test code example?
The error occurs in clip_model.visual:
train_logits,final_similarity = model(imgs.cuda(),clip_model=clip_model)
Traceback (most recent call last):
File "", line 1, in
File "/CLIP/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/OpenPAR/PromptPAR/base_block.py", line 31, in forward
clip_image_features,all_class,attenmap=clip_model.visual(imgs.type(clip_model.dtype))
File "/CLIP/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/OpenPAR/PromptPAR/clip/model.py", line 326, in forward
x,all_class,attnmap = self.transformer(x)
File "/CLIP/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/OpenPAR/PromptPAR/clip/model.py", line 253, in forward
x,attn_output_weights = block(x,self.visual_mask)
File "/CLIP/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/OpenPAR/PromptPAR/clip/model.py", line 202, in forward
attn_output, attn_output_weights = self.attention(self.ln_1(x),visual_mask)
File "/OpenPAR/PromptPAR/clip/model.py", line 198, in attention
return self.attn(x, x, x, need_weights=True, attn_mask=self.attn_mask)
File "/CLIP/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/CLIP/lib/python3.8/site-packages/torch/nn/modules/activation.py", line 978, in forward
return F.multi_head_attention_forward(
File "/CLIP/lib/python3.8/site-packages/torch/nn/functional.py", line 4235, in multi_head_attention_forward
raise RuntimeError('The size of the 2D attn_mask is not correct.')
RuntimeError: The size of the 2D attn_mask is not correct.
We've updated the test code and README so that you can use the provided code for testing. We can't provide a method to convert to ONNX; you may need to implement it yourself.
Thank you, indeed the test was OK.
Two parameters may be missing: vis_prompt and use_textprompt.
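For anyone hitting the same attn_mask error, here is a small illustration of why dropping the visual-prompt settings breaks the mask check (the token counts are placeholders, not PromptPAR's real numbers): the visual attention mask is built for the prompted sequence length, so if the prompts are disabled at test time the input sequence no longer matches it.

import torch
seq_len_trained = 1 + 196 + 50          # class token + patch tokens + visual prompt tokens (placeholders)
seq_len_test = 1 + 196                  # prompts disabled at test time -> shorter sequence
mask = torch.zeros(seq_len_trained, seq_len_trained)
x = torch.randn(seq_len_test, 2, 768)   # (L, N, E) layout expected by nn.MultiheadAttention
attn = torch.nn.MultiheadAttention(768, 12)
# attn(x, x, x, attn_mask=mask)         # -> RuntimeError: The size of the 2D attn_mask is not correct.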
Converting the PyTorch model to ONNX succeeded, but executing the code below raised the following error:
sess = onnxruntime.InferenceSession('./onnx/clip_model.onnx')
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from ./onnx/clip_model.onnx failed:Node (Concat_5586) Op (Concat) [ShapeInferenceError] All inputs to Concat must have same rank. Input 1 has rank 4 != 3
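For reference, the failing node can also be inspected without onnxruntime (a small sketch, assuming the onnx Python package is installed; the node name comes from the error above):

import onnx
m = onnx.load('./onnx/clip_model.onnx')
for node in m.graph.node:
    if node.name == 'Concat_5586':
        print(node)                               # shows the three inputs listed further below
# onnx.checker.check_model(m, full_check=True)    # full_check runs ONNX shape inference and,
                                                  # in recent onnx releases, raises the same
                                                  # Concat rank complaint as onnxruntime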
The lines that may be causing the error are below:
x=torch.cat([x[:77,:,:],self.prompt_text_deep[layer].to(x.dtype).to(x.device)],dim=0)
if args.use_div :
x = torch.cat([x[:,:1],self.part_class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device) ,x[:,1:]],dim=1)
%5586 = Concat[axis = 0](%5561, %5580, %5585)
%5561 = Slice(%5549, %5558, %5559, %5557, %5560)
%5580 = Cast[to = 1] (%5579)
%5585 = Slice(%5549, %5582, %5583, %5581, %5584)
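A possible workaround (a sketch only, not verified against the repo; it assumes part_class_embedding is a single token of shape (1, width)): give the inserted embedding an explicit three-dimensional shape before the concat instead of relying on broadcasting, so every input the exported Concat sees has the same rank.

# hypothetical rewrite of the use_div branch above; part_class_embedding's shape is an assumption
part = self.part_class_embedding.to(x.dtype).to(x.device)         # assumed shape (1, width)
part = part.reshape(1, 1, x.shape[-1]).expand(x.shape[0], -1, -1) # explicit rank 3: (B, 1, width)
x = torch.cat([x[:, :1], part, x[:, 1:]], dim=1)                  # all three inputs now have rank 3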