sisl/NNet

NNet class vs writeNNet format class


I think the list/array of weights created by the NNet class constructor has a different shape from the one expected by the writeNNet function. In any case, when I try to write the network from an NNet object to a file using writeNNet, I get an error.

You are welcome to have a look at test/TestNnetExtensions.py in my fork. Note that I wrote a new constructor that takes the attributes as arguments (this seems more convenient, certainly for my purposes) and turned the original constructor (from a file) into a class method, but I did not touch the code of that method. I considered changing the shape in the NNet class to make it more consistent, but that would also affect the evaluation methods and complicate things in other ways.
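The constructor refactor described above follows a common Python pattern: the regular `__init__` takes the network's attributes directly, and loading from a file becomes an alternative constructor exposed as a `classmethod`. The sketch below is a hypothetical illustration of that pattern only; `SimpleNet`, `from_file`, `num_layers` and `fake_reader` are made-up names, not the code in the fork or in sisl/NNet, and the `reader` argument merely stands in for a file parser such as readNNet.

```python
import numpy as np


class SimpleNet:
    """Hypothetical stand-in for an NNet-like class (not the repository's code)."""

    def __init__(self, weights, biases):
        # Primary constructor: takes the network attributes directly.
        if len(weights) != len(biases):
            raise ValueError("need one bias vector per weight matrix")
        self.weights = [np.asarray(w, dtype=float) for w in weights]
        self.biases = [np.asarray(b, dtype=float) for b in biases]

    @classmethod
    def from_file(cls, path, reader):
        # Alternative constructor: `reader` is any callable returning
        # (weights, biases) for `path`; it stands in for a real file parser.
        weights, biases = reader(path)
        return cls(weights, biases)

    @property
    def num_layers(self):
        return len(self.weights)


def fake_reader(path):
    # Hypothetical reader that ignores `path` and returns one 2x2 identity layer.
    return [np.eye(2)], [np.zeros(2)]


net = SimpleNet.from_file("model.nnet", fake_reader)
direct = SimpleNet([np.eye(2)], [np.zeros(2)])
```

Keeping the file-loading path in a `classmethod` means both construction routes end up in the same `__init__`, so any shape validation or conversion only has to live in one place.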

Thanks! I'll have a look at your fork.

Yes, it looks like readNNet/writeNNet have their weight matrices transposed compared to the class constructor's weight matrices. So far I've only tested readNNet/writeNNet in the methods that convert between the nnet format and other formats; I haven't tested writing an nnet file from an existing NNet object. I'll push a fix soon and add some new test cases like yours to show that everything is consistent.
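A minimal NumPy sketch of what a transposed-weights mismatch looks like, assuming one side stores each layer as shape (n_out, n_in) and evaluates `y = W @ x + b`, while the other expects (n_in, n_out) and computes `y = x @ W + b`. Which convention each NNet function actually uses is not stated in this thread; the variable names below are hypothetical.

```python
import numpy as np

# Hypothetical layer with 3 inputs and 2 outputs.
W_out_in = np.arange(6, dtype=float).reshape(2, 3)  # (n_out, n_in): y = W @ x + b
b = np.array([0.5, -0.5])
x = np.array([1.0, 2.0, 3.0])

y_class = W_out_in @ x + b

# The same layer in the other layout is just the transpose.
W_in_out = W_out_in.T  # (n_in, n_out): y = x @ W + b
y_file = x @ W_in_out + b

# Handing the (n_out, n_in) matrix to code expecting (n_in, n_out) fails here
# because 3 != 2. For a square layer (n_in == n_out) it would NOT raise;
# it would silently compute the wrong output.
try:
    x @ W_out_in
    mismatch_detected = False
except ValueError:
    mismatch_detected = True
```

Under these assumptions, the fix is to pick one layout and transpose exactly once at the read/write boundary, rather than in the evaluation code.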

Merged in some changes to the readNNet/writeNNet functions and converter methods so that the weight matrices are defined consistently. There's still some work to do on adding tests to make sure these issues aren't reintroduced in the future, but hopefully things now work for you.
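One way to guard against this kind of regression is a round-trip test: write a network, read it back, and check that the shapes and the evaluated outputs match. The sketch below is a hypothetical illustration, not the actual nnet file format or the repository's readNNet/writeNNet; `write_layers`, `read_layers` and `evaluate` are made-up helpers using a simple comma-separated text layout, and every weight matrix is assumed to have shape (n_out, n_in) with ReLU hidden layers and a linear output layer.

```python
import io

import numpy as np


def write_layers(weights, biases, stream):
    # Header: layer sizes. Then, per layer: one line per output neuron
    # (a row of the (n_out, n_in) matrix), followed by one bias per line.
    sizes = [weights[0].shape[1]] + [w.shape[0] for w in weights]
    stream.write(",".join(str(s) for s in sizes) + "\n")
    for w, b in zip(weights, biases):
        for row in w:
            stream.write(",".join(repr(float(v)) for v in row) + "\n")
        for v in b:
            stream.write(repr(float(v)) + "\n")


def read_layers(stream):
    lines = iter(stream.read().splitlines())
    sizes = [int(s) for s in next(lines).split(",")]
    weights, biases = [], []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        rows = []
        for _ in range(n_out):
            row = [float(v) for v in next(lines).split(",")]
            if len(row) != n_in:
                raise ValueError("row length does not match layer input size")
            rows.append(row)
        weights.append(np.array(rows))
        biases.append(np.array([float(next(lines)) for _ in range(n_out)]))
    return weights, biases


def evaluate(weights, biases, x):
    # ReLU on hidden layers, linear output layer.
    for i, (w, b) in enumerate(zip(weights, biases)):
        x = w @ x + b
        if i < len(weights) - 1:
            x = np.maximum(x, 0.0)
    return x


rng = np.random.default_rng(0)
weights = [rng.standard_normal((4, 3)), rng.standard_normal((2, 4))]
biases = [rng.standard_normal(4), rng.standard_normal(2)]

buf = io.StringIO()
write_layers(weights, biases, buf)
buf.seek(0)
w2, b2 = read_layers(buf)

x = np.array([0.1, -0.2, 0.3])
```

Comparing evaluated outputs (not just shapes) matters because a transposed square layer would keep its shape while changing the network's behavior.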

Looks like it works - thanks!

Great!