mlverse/torch

Cloning of torch module behaves unexpectedly

sebffischer opened this issue · 3 comments

Thanks a lot already for fixing the class of the cloned torch module!
I still see some differences between the clone and the original: in the example below, the weight of the cloned linear layer is missing the "nn_parameter" class.

library(mlr3torch)
#> Loading required package: mlr3
#> Loading required package: mlr3pipelines
#> Loading required package: torch

a = nn_linear(1, 1)

b = a$clone(deep = TRUE)

a
#> An `nn_module` containing 2 parameters.
#> 
#> ── Parameters ──────────────────────────────────────────────────────────────────
#> • weight: Float [1:1, 1:1]
#> • bias: Float [1:1]

b
#> An `nn_module` containing 2 parameters.
#> 
#> ── Parameters ──────────────────────────────────────────────────────────────────
#> • weight: Float [1:1, 1:1]
#> • bias: Float [1:1]

b$parameters$weight |> attributes()
#> $class
#> [1] "torch_tensor" "R7"

a$parameters$weight |> attributes()
#> $class
#> [1] "torch_tensor" "R7"           "nn_parameter"

Created on 2023-12-18 with reprex v2.0.2

I tried to fix this myself. Because clone is autogenerated, I believe I have to change the Declarations file, but when I try to regenerate the functions I get:

> torchgen::generate("~/gh/torch")
Starting code generation ...
Error in `purrr::map_chr()`:
In index: 1.
Caused by error:
! Result must be length 1, not 0.

This still breaks when one wants to do something like nn_parameter(torch_tensor(1))$clone().
While this is admittedly less important, I think the (autogenerated) torch_clone() method itself should just preserve the attributes.

I see, you are right! Perhaps you can add a method like clone_with_attributes around here:

print = function(n = 30) {

Once it has the behavior you want, I can do the plumbing to make it the default clone.
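A minimal sketch of what such a helper might do, assuming it is enough to copy the R-level attributes from the original tensor onto the result of $clone(). It is written as a standalone function rather than as the tensor method suggested above, and clone_with_attributes is only the placeholder name from this thread, not an existing torch function.

```r
library(torch)

# Hypothetical helper (not part of torch): clone a tensor, then copy the
# R-level attributes (e.g. the "nn_parameter" class) from the original.
clone_with_attributes <- function(x) {
  out <- x$clone()
  attributes(out) <- attributes(x)
  out
}

p <- nn_parameter(torch_tensor(1))
q <- clone_with_attributes(p)

# Check whether the clone kept the "nn_parameter" class
inherits(q, "nn_parameter")
```

This only restores the R-side attributes; whether the clone should also be detached from the autograd graph is a separate question that the sketch does not address.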

TBH I haven't run torchgen like this much; I almost always do load_all() and then generate() with the defaults. So maybe there's a hardcoded path that doesn't work properly when running from a different directory.