Implementing beam propagation in pytorch
askaradeniz opened this issue · 17 comments
It would be useful to have a beam propagation function written in pytorch for deep learning.
We can rewrite the Fresnel and Fraunhofer beam propagators defined here using torch.
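For reference, the textbook forms of the two propagators (as in Goodman's Introduction to Fourier Optics) are sketched below, with $k = 2\pi/\lambda$, propagation distance $z$, source-plane coordinates $(\xi, \eta)$, and $\mathcal{F}$ the 2D Fourier transform. odak's own sign and normalization conventions may differ, so treat this as a guide rather than a description of the existing functions.

```latex
% Fresnel propagation, transfer-function form
U(x, y; z) = \mathcal{F}^{-1}\left\{ \mathcal{F}\{U(\xi, \eta; 0)\}\, H(f_x, f_y) \right\},
\qquad
H(f_x, f_y) = e^{jkz} \exp\!\left[-j\pi\lambda z \left(f_x^2 + f_y^2\right)\right]

% Fraunhofer propagation: the spectrum is sampled at f_x = x/(\lambda z), f_y = y/(\lambda z)
U(x, y; z) = \frac{e^{jkz}}{j\lambda z}\, e^{\frac{jk}{2z}\left(x^2 + y^2\right)}
\,\mathcal{F}\{U(\xi, \eta; 0)\}\Big|_{f_x = x/(\lambda z),\; f_y = y/(\lambda z)}
```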
I have added a new submodule named odak.learn. It is meant to be imported separately from the rest of the library, as in the example below:
import odak
import odak.learn
Note that importing odak.learn also imports torch. I have also added torch to requirements.txt.
We can start our effort on a beam propagator written in torch using this file in the odak.learn submodule.
For now, it is a good idea to fork the repository as is. We can merge the fork later, once you feel confident that you have a working beam propagator running in torch. Make sure you also have a working and verified test routine that can go under the test directory at the root.
If all goes well, I promise to grant you rights to this repository as a collaborator.
It seems to be working with torch now.
To write the beam propagation function in pytorch, I used the functions in the torch.fft module.
The fftshift function was only recently added to pytorch (https://pytorch.org/docs/master/fft.html), so I updated torch to the nightly version and updated the requirements file accordingly.
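A quick sanity check, sketched under the assumption of torch 1.8 or newer (where torch.fft.fftshift exists), is that the torch shift reproduces numpy.fft.fftshift on the same data. An array with one odd and one even dimension is used so the shift is not trivially symmetric.

```python
import numpy as np
import torch

# A small 2D array with an odd (3) and an even (4) dimension.
field_np = np.arange(12, dtype=np.float64).reshape(3, 4)
field_torch = torch.from_numpy(field_np)

# With no axes/dim argument, both functions shift every dimension.
shifted_np = np.fft.fftshift(field_np)
shifted_torch = torch.fft.fftshift(field_torch)

print(np.array_equal(shifted_np, shifted_torch.numpy()))  # True
```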
I have written the function in classical.py under the odak.learn submodule. Basically, it was enough to replace the numpy functions with their torch counterparts.
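To illustrate the "swap numpy for torch" approach, here is a minimal Fresnel transfer-function propagator written only with torch calls. It is a sketch, not the function in classical.py: the name propagate_fresnel_tf, the square pixel pitch dx, SI units, and the use of torch.fft.fftfreq and torch.polar (torch 1.8+) are all assumptions.

```python
import math
import torch


def propagate_fresnel_tf(field, wavelength, distance, dx):
    """Propagate a complex 2D field by `distance` with the Fresnel transfer function.

    Hypothetical helper: assumes square pixels of pitch `dx` and SI units.
    """
    ny, nx = field.shape[-2:]
    k = 2 * math.pi / wavelength
    # Spatial frequencies in unshifted FFT order, matching torch.fft.fft2 output.
    fx = torch.fft.fftfreq(nx, d=dx, dtype=torch.float64)
    fy = torch.fft.fftfreq(ny, d=dx, dtype=torch.float64)
    # Broadcast to an (ny, nx) grid: rows follow fy, columns follow fx.
    f_squared = fx[None, :] ** 2 + fy[:, None] ** 2
    # H = exp(jkz) * exp(-j * pi * lambda * z * (fx^2 + fy^2)), a unit-magnitude phasor.
    phase = k * distance - math.pi * wavelength * distance * f_squared
    transfer = torch.polar(torch.ones_like(phase), phase)
    return torch.fft.ifft2(torch.fft.fft2(field) * transfer)


# Usage: propagate a uniform 64 x 64 field by 5 cm at 532 nm with 8 um pixels.
field = torch.ones(64, 64, dtype=torch.complex128)
output = propagate_fresnel_tf(field, wavelength=532e-9, distance=0.05, dx=8e-6)
```

Because fftfreq already returns frequencies in unshifted FFT order, this form needs no fftshift at all; the shifts only matter when the transfer function is built on a centered frequency grid.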
I also added a test script, test_learn, which is very similar to test_beam_propagation. The main function propagates a beam with torch, and the compare function compares the result against propagation in numpy. The results seem to agree up to 3 decimal places.
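A compare routine along those lines might look like the sketch below. Both propagators are hypothetical stand-ins (the same transfer-function formula written once in numpy and once in torch), not odak's functions, and the 1e-3 tolerance mirrors the three-decimal agreement reported above.

```python
import math
import numpy as np
import torch


def fresnel_numpy(field, wavelength, distance, dx):
    """Reference propagation with numpy (hypothetical helper for this comparison)."""
    ny, nx = field.shape
    fx = np.fft.fftfreq(nx, d=dx)
    fy = np.fft.fftfreq(ny, d=dx)
    f_squared = fx[None, :] ** 2 + fy[:, None] ** 2
    phase = 2 * math.pi / wavelength * distance - math.pi * wavelength * distance * f_squared
    return np.fft.ifft2(np.fft.fft2(field) * np.exp(1j * phase))


def fresnel_torch(field, wavelength, distance, dx):
    """Same formula written with torch.fft (hypothetical helper for this comparison)."""
    ny, nx = field.shape[-2:]
    fx = torch.fft.fftfreq(nx, d=dx, dtype=torch.float64)
    fy = torch.fft.fftfreq(ny, d=dx, dtype=torch.float64)
    f_squared = fx[None, :] ** 2 + fy[:, None] ** 2
    phase = 2 * math.pi / wavelength * distance - math.pi * wavelength * distance * f_squared
    return torch.fft.ifft2(torch.fft.fft2(field) * torch.polar(torch.ones_like(phase), phase))


def compare(size=64, wavelength=532e-9, distance=0.1, dx=8e-6, atol=1e-3):
    """Propagate the same random-phase field with both backends and check agreement."""
    rng = np.random.default_rng(0)
    field = np.exp(1j * 2 * np.pi * rng.random((size, size)))
    result_numpy = fresnel_numpy(field, wavelength, distance, dx)
    result_torch = fresnel_torch(torch.from_numpy(field), wavelength, distance, dx).numpy()
    return np.allclose(result_numpy, result_torch, atol=atol)


if __name__ == '__main__':
    print(compare())
```

With both sides in float64, the match is typically far tighter than the tolerance; the tolerance is there so the same test still passes if one side runs at lower precision.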
Sounds awesome. Do you have a dedicated fork where you try these things, or are you conducting such experiments locally?
How should we proceed with this one? Shall I grant you collaborator access so that you can push it to the main branch, or do you want me to merge the main branch with your fork? Please advise.
I did the experiments in my own fork.
I opened a pull request for this feature. You can review and merge it into the main branch.
I just saw your pull request. However, your version does not build on Travis CI, and I believe this is due to requirements.txt. The torch version you set in that file is 1.8.0+dev..., which is not an official release; Travis CI does not have it in its stack of libraries, and we cannot expect future users to rely on such a version. I believe torch has supported fft since 1.7.0; here is the documentation for that. May I ask you to switch to that torch version and verify your code? Once this is verified, update your requirements.txt with torch version 1.7.0, and finally raise another merge request.
It is probably also a good idea to name the test routine test_learn_beam_propagation.py, as there will hopefully be more additions to the learn submodule.
About this specific line: from_numpy works with Numpy, but I am not sure whether it works with Cupy. Perhaps it is a good idea to add a check like the one below to make sure it works with both Cupy and Numpy:
if np.__name__ == 'cupy':
    sample_field = np.asnumpy(sample_field)
The same goes for other variables that may be affected by the torch-Cupy situation.
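One way to apply the same check to every variable is a small helper. In this sketch, to_torch and its backend argument are hypothetical names, with backend standing in for whichever module odak has bound to np; cupy.asnumpy copies the data to host memory before it is handed to torch.from_numpy.

```python
import numpy
import torch


def to_torch(array, backend=numpy):
    """Convert an array from odak's active backend into a torch tensor (sketch).

    `backend` is a hypothetical parameter standing in for the module odak has
    imported as `np` (numpy or cupy).
    """
    if backend.__name__ == 'cupy':
        # torch.from_numpy only accepts numpy arrays, so copy cupy data to the host first.
        array = backend.asnumpy(array)
    return torch.from_numpy(array)


sample_field = to_torch(numpy.zeros((4, 4), dtype=numpy.complex64))
```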
Thank you for the review. I will make the necessary changes soon.
In pytorch 1.7.0, there is an fft module, but the fftshift and ifftshift functions do not exist. There are implementations of these functions in the neural-holography utilities. Can we use them?
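For reference, fftshift and ifftshift can be written with torch.roll, which torch 1.7.0 already provides. The sketch below is an independent illustration of that idea, not the neural-holography code itself; defaulting dims to the last two axes is an assumption suited to 2D fields.

```python
import torch


def fftshift(x, dims=(-2, -1)):
    """Shift the zero-frequency term to the center along `dims`, like numpy.fft.fftshift."""
    shifts = tuple(x.shape[d] // 2 for d in dims)
    return torch.roll(x, shifts=shifts, dims=dims)


def ifftshift(x, dims=(-2, -1)):
    """Inverse of fftshift; the two differ only along odd-length axes."""
    shifts = tuple(-(x.shape[d] // 2) for d in dims)
    return torch.roll(x, shifts=shifts, dims=dims)


x = torch.arange(15.0).reshape(3, 5)
print(torch.equal(ifftshift(fftshift(x)), x))  # True
```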
I believe so, but please let the authors know that you will be using part of their implementation. Double-check the license of their repository, and check with them as well.
I would also suggest opening another issue as a reminder to ourselves to migrate to torch 1.8.0 in the future. I guess we can do that once it is a stable release.
Once you bring in fftshift and ifftshift, having them in a separate script named toolkit.py in the learn submodule also makes sense to me. What do you suggest?
I also think having a toolkit.py script is a good idea. We can write additional utility functions for learn there.
I am keeping my promise and granting you contributor access to this repository as well.
Thanks!
According to this issue, torch.meshgrid and numpy.meshgrid behave differently: it looks like they produce matrices that are transposes of each other. For this reason, I made a small change at this line:
https://github.com/kunguz/odak/blob/2a7c900e49b37b8bf6ecd4e6ce785c40e06f5ea1/odak/wave/classical.py#L28
https://github.com/kunguz/odak/blob/2a7c900e49b37b8bf6ecd4e6ce785c40e06f5ea1/odak/learn/classical.py#L29
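The difference is easy to reproduce. In the sketch below, numpy.meshgrid uses its default 'xy' indexing, while torch.meshgrid (called without an indexing argument, as in torch 1.7) uses 'ij' indexing, so the outputs are transposes of each other. Swapping both the inputs and the outputs of torch.meshgrid recovers numpy's layout without a transpose; this is one such trick and may not be exactly the change in the commits linked above. Newer torch releases emit a warning when the indexing argument is omitted.

```python
import numpy as np
import torch

x = np.linspace(-1.0, 1.0, 3)  # 3 samples along x
y = np.linspace(-2.0, 2.0, 5)  # 5 samples along y

# numpy default ('xy' indexing): both outputs have shape (len(y), len(x)) = (5, 3).
X_np, Y_np = np.meshgrid(x, y)

# torch.meshgrid ('ij' indexing): outputs have shape (len(x), len(y)) = (3, 5),
# i.e. the transpose of the numpy result.
X_t, Y_t = torch.meshgrid(torch.from_numpy(x), torch.from_numpy(y))

# Swapping the inputs and the outputs gives numpy's layout with no transpose.
Y_s, X_s = torch.meshgrid(torch.from_numpy(y), torch.from_numpy(x))

print(X_t.shape, X_s.shape)  # torch.Size([3, 5]) torch.Size([5, 3])
```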
Nicely done! It is good to be able to ditch transpose with this simple trick.