Feature: Support pairwise multiplication in mul nodes
YalcinerMustafa opened this issue · 3 comments
I want to verify a neural network that (among other layers) also contains Mul nodes. The Mul nodes in my model compute the pairwise product of a constant vector and the variable input vector from a preceding layer. The constant node is of type float32[100].
See the attached model for more details.
merged_model.zip
At the moment, Marabou silently parses the whole network, including the Mul node, but when a (dummy) verification task is triggered, it fails with the error shown in [1]. It seems to assume that the constant factor is a scalar and passes the whole array as the coefficient of a single addend in an equation.
I wrote a quick fix for my dummy verification task by extending the method MarabouNetworkONNX.mulEquations() with the diff shown in [2], so that the variable `multiple` is no longer assumed to be a scalar, consistent with the ONNX specification [3]. With these changes, it performs a pairwise multiplication if the second input is recognized as a vector.
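In equation form, the diff in [2] emits one linear equation per element, where x and y denote the flattened input and output variables of the Mul node, n is their length (100 for the float32[100] constant), and c_i is the i-th entry of the constant vector (or c_i = c for every i when the constant is a scalar). The three terms correspond to addAddend(multiple[i], input1[i]), addAddend(-1, outputVariables[i]) and setScalar(0.0):

$$
c_i \, x_i - y_i = 0, \qquad i = 0, \dots, n - 1
$$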
Let me know if there is something I am missing here.
My questions are:
- Is this a useful feature from your point of view?
- Do you have any particular implementation suggestion?
- Shall I open a pull request, such that you can review the changes?
Thank you in advance.
[1]
```
TypeError: addAddend(): incompatible function arguments. The following argument types are supported:
1. (self: maraboupy.MarabouCore.Equation, arg0: float, arg1: int) -> None
Invoked with: <maraboupy.MarabouCore.Equation object at 0x7f4cc4ecc6b0>, array([0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0.,
1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.,
0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0.,
1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.,
0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0.,
1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1., 0., 1.],
dtype=float32), 0
```
[2]
```diff
diff --git a/maraboupy/MarabouNetworkONNX.py b/maraboupy/MarabouNetworkONNX.py
index 0ebf69ee..ca3a9b49 100644
--- a/maraboupy/MarabouNetworkONNX.py
+++ b/maraboupy/MarabouNetworkONNX.py
@@ -1119,6 +1119,7 @@ class MarabouNetworkONNX(MarabouNetwork.MarabouNetwork):
             return
 
         multiple = self.constantMap[inputName2]
+        multiple_is_array = hasattr(multiple,"__len__")
         input1 = self.varMap[inputName1]
         outputVariables = self.makeNewVariables(nodeName)
         input1 = input1.reshape(-1)
@@ -1126,7 +1127,10 @@ class MarabouNetworkONNX(MarabouNetwork.MarabouNetwork):
 
         for i in range(len(input1)):
             e = MarabouUtils.Equation()
-            e.addAddend(multiple, input1[i])
+            if multiple_is_array:
+                e.addAddend(multiple[i], input1[i])
+            else:
+                e.addAddend(multiple, input1[i])
             e.addAddend(-1, outputVariables[i])
             e.setScalar(0.0)
             self.addEquation(e)
```
[3] https://github.com/onnx/onnx/blob/main/docs/Operators.md#mul
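The check in [2] covers a constant vector whose length equals the flattened input. The ONNX Mul operator [3] allows multidirectional (NumPy-style) broadcasting, so the constant may also have a shape such as (1, 100), or one that expands the input. Note also that a 0-d NumPy array has a `__len__` attribute even though it cannot be indexed, so `hasattr(multiple, "__len__")` would treat it as a vector if a scalar constant happens to be stored that way in `constantMap` (an assumption about how the parser stores scalar initializers). Below is a minimal sketch of a hypothetical helper, `mul_equation_terms` (not part of maraboupy), that lets NumPy do the broadcasting and returns one (coefficient, input variable, output variable) triple per output element. It assumes the output variables were allocated for the broadcast output shape in row-major order.

```python
import numpy as np


def mul_equation_terms(input_vars, input_shape, constant, output_vars):
    """Hypothetical helper: terms of y = x * c under ONNX/NumPy broadcasting.

    input_vars:  variable indices of the variable input x
    input_shape: ONNX shape of x
    constant:    the constant factor c (Python scalar, 0-d or n-d array)
    output_vars: variable indices allocated for y, assumed row-major
    """
    const = np.asarray(constant, dtype=np.float64)
    # Common output shape; raises ValueError if the shapes are incompatible.
    out_shape = np.broadcast_shapes(tuple(input_shape), const.shape)
    # Broadcasting the array of variable indices repeats an input variable
    # wherever x is expanded along a dimension.
    x = np.broadcast_to(np.asarray(input_vars).reshape(input_shape), out_shape).reshape(-1)
    # Broadcasting the constant handles the scalar, 0-d and vector cases alike.
    c = np.broadcast_to(const, out_shape).reshape(-1)
    y = np.asarray(output_vars).reshape(-1)
    if y.size != x.size:
        raise ValueError(f"expected {x.size} output variables, got {y.size}")
    return [(float(ci), int(xi), int(yi)) for ci, xi, yi in zip(c, x, y)]


# Hypothetical use inside mulEquations(), reusing the calls from the diff in [2]:
# for coeff, x_var, y_var in mul_equation_terms(...):
#     e = MarabouUtils.Equation()
#     e.addAddend(coeff, x_var)
#     e.addAddend(-1, y_var)
#     e.setScalar(0.0)
#     self.addEquation(e)
```

Letting NumPy compute the broadcast keeps the scalar, vector and higher-rank cases on a single code path, and incompatible shapes raise a ValueError from np.broadcast_shapes instead of silently producing mismatched equations.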
Hi @YalcinerMustafa, thanks a lot for finding this issue. We are actually in the process of replacing the Python ONNX parser with a Python binding of a C++ ONNX parser. Therefore, your suggested change would be useful in the short term, but it will soon be replaced.
@MatthewDaggitt is in charge of this effort. @MatthewDaggitt, could you please incorporate the suggested change into the C++ parser?
Hello @wu-haoze and @MatthewDaggitt,
thank you for the quick response and for your efforts to improve Marabou.
I see. I will work with a local patch for now and hopefully switch to the Python bindings of the new C++ parser when they are ready.
Best,
Mustafa
@YalcinerMustafa I am running into a similar issue and was wondering how your local patch works.