Cloning of torch module behaves unexpectedly
sebffischer opened this issue · 3 comments
Thanks a lot already for fixing the class of the cloned torch module!
I still see some differences between the clone and the original object. In the example below, the weight
of the cloned linear layer is missing the "nn_parameter" class.
library(mlr3torch)
#> Loading required package: mlr3
#> Loading required package: mlr3pipelines
#> Loading required package: torch
a = nn_linear(1, 1)
b = a$clone(deep = TRUE)
a
#> An `nn_module` containing 2 parameters.
#>
#> ── Parameters ──────────────────────────────────────────────────────────────────
#> • weight: Float [1:1, 1:1]
#> • bias: Float [1:1]
b
#> An `nn_module` containing 2 parameters.
#>
#> ── Parameters ──────────────────────────────────────────────────────────────────
#> • weight: Float [1:1, 1:1]
#> • bias: Float [1:1]
b$parameters$weight |> attributes()
#> $class
#> [1] "torch_tensor" "R7"
a$parameters$weight |> attributes()
#> $class
#> [1] "torch_tensor" "R7" "nn_parameter"
Created on 2023-12-18 with reprex v2.0.2
I tried to fix this myself. Because clone is autogenerated, I believe I had to change the Declarations file,
but when I try to regenerate the functions I get:
> torchgen::generate("~/gh/torch")
Starting code generation ...
Error in `purrr::map_chr()`:
ℹ In index: 1.
Caused by error:
! Result must be length 1, not 0.
This still breaks when one does something like nn_parameter(torch_tensor(1))$clone().
Admittedly this might be less important, but I think the (autogenerated) torch_clone()
method itself should simply preserve the attributes.
I see, you are right! Perhaps you can add a method like clone_with_attributes, with the behavior you want,
around here:
Line 14 in 0e9fdd7
Then I can do the plumbing to make it the default clone.
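A minimal sketch of what such a helper could look like, assuming that copying the class vector is enough (in the reprex, class is the only attribute the weight tensor carries). clone_with_attributes is the hypothetical name from the suggestion, not an existing torch function; it wraps the tensor's $clone() method and then restores the original class, so "nn_parameter" survives. Wiring it into the deep clone of an nn_module is the plumbing mentioned in the suggestion and is not shown here.

```r
library(torch)

# Hypothetical helper (not part of torch): clone a tensor, then restore the
# class vector of the original, e.g. c("torch_tensor", "R7", "nn_parameter").
# Assumption: the class attribute is the only attribute that must be preserved.
clone_with_attributes <- function(x) {
  y <- x$clone()
  class(y) <- class(x)
  y
}

# Usage: compare the classes of a plain clone and of the wrapped clone.
p <- nn_parameter(torch_tensor(1))
class(p$clone())
class(clone_with_attributes(p))
```

Copying class(x) wholesale, rather than appending "nn_parameter", keeps the helper generic for any subclass a tensor may carry.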
TBH I haven't run torchgen like this much; I almost always call load_all() and then generate()
with the defaults. So maybe there's a hardcoded path that doesn't work properly when running from a different directory.