wangkuiyi/gotorch

Do we really need Module.Init?

wangkuiyi opened this issue · 1 comment

From the definition below, I understand that the only purpose for user-defined modules to call Module.Init in their constructors (their "newers") is to let each sub-module know about its parent.

gotorch/nn/module.go

Lines 46 to 64 in 047d424

func (m *Module) Init(outer IModule) {
	if m.outer != nil {
		return
	}
	moduleType := reflect.TypeOf(m).Elem()
	fv := reflect.ValueOf(outer).Elem()
	for i := 0; i < fv.NumField(); i++ {
		v := fv.Field(i)
		f := fv.Type().Field(i)
		if f.Type == moduleType && f.Name == moduleType.Name() {
			if v.Addr() == reflect.ValueOf(m) {
				// Calling Init in a valid Module: struct{*Module} or struct{Module}
				m.outer = outer
				m.isTraining = true
			}
		}
	}
	torchCheck(m.outer != nil, "GoTorch requires defining modules via embedding a `Module` struct by value")
}
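
For reference, a user-defined module and its constructor (the "newer" mentioned above) look roughly like the sketch below. The MyNetModule type, its fields, and the nn.Linear signature are illustrative and written from memory of the repository's examples, not authoritative:

```go
package main

import (
	torch "github.com/wangkuiyi/gotorch"
	"github.com/wangkuiyi/gotorch/nn"
)

// MyNetModule is a hypothetical user-defined module.  It embeds nn.Module by
// value, which is exactly what the torchCheck at the end of Init enforces.
type MyNetModule struct {
	nn.Module                  // Init records the address of the outer MyNetModule here
	L1, L2    *nn.LinearModule // sub-modules held in exported fields
}

// MyNet is the newer.  Calling Init(m) lets the embedded Module store m as its
// outer module, so methods defined on Module, such as To and ZeroGrad, can
// later reach L1 and L2 via reflection.
func MyNet() *MyNetModule {
	m := &MyNetModule{
		L1: nn.Linear(784, 100, false),
		L2: nn.Linear(100, 10, false),
	}
	m.Init(m)
	return m
}

// Forward chains the sub-modules; it is unrelated to Init itself.
func (m *MyNetModule) Forward(x torch.Tensor) torch.Tensor {
	return m.L2.Forward(m.L1.Forward(x))
}
```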

Is the purpose of Init to meet the requirement that, when the user calls a module's To or ZeroGrad method, we can trace up to the top ancestor of the sub-module hierarchy and make sure that every module in the hierarchy moves to the specified device or has its parameter gradients cleared?

If this is the reasoning behind Module.Init, I am afraid the implementations of To and ZeroGrad do not trace up to the root; instead, I see them simply calling m.outer.

gotorch/nn/module.go

Lines 97 to 101 in 047d424

func (m *Module) To(device torch.Device) {
	// TODO(shendiaomo): to be implemented after the `To` method of `Tensors` is ready
	moduleType := reflect.TypeOf((*IModule)(nil)).Elem()
	tensorType := reflect.TypeOf((*torch.Tensor)(nil)).Elem()
	sv := reflect.ValueOf(m.outer).Elem() // Elem gets what the pointer points to.

Oh. I got it -- I misunderstood the design.

The purpose of Module.Init is to store the address of the newly instantiated user-defined module, say BatchNormModule, into the outer field of its embedded Module. Later, when users call BatchNormModule.To("cuda"), they are actually calling the promoted Module.To, which can follow Module.outer to recover the address of the BatchNormModule instance and recursively move all tensor-typed fields and sub-modules to the specified device "cuda".
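
To make that concrete, here is a self-contained toy that imitates the mechanism with stand-in Tensor and Module types; none of it is gotorch code, and recursing into sub-module fields is omitted for brevity. Init records the outer struct's address, and a reflection-based To follows that pointer to reach the tensor-typed fields:

```go
package main

import (
	"fmt"
	"reflect"
)

// Tensor and Module are stripped-down stand-ins, not gotorch's types.
type Tensor struct{ device string }

type Module struct{ outer interface{} }

// Init imitates what gotorch's Module.Init records: the address of the
// outermost user-defined struct.
func (m *Module) Init(outer interface{}) { m.outer = outer }

// To follows outer back to the user-defined struct and uses reflection to
// visit every Tensor-typed field, the way the real Module.To can reach the
// enclosing module's tensors through Module.outer.
func (m *Module) To(device string) {
	sv := reflect.ValueOf(m.outer).Elem() // the user-defined struct value
	tensorType := reflect.TypeOf(Tensor{})
	for i := 0; i < sv.NumField(); i++ {
		if sv.Type().Field(i).Type == tensorType {
			t := sv.Field(i).Addr().Interface().(*Tensor)
			t.device = device // "move" the tensor to the target device
		}
	}
}

// BatchNormModule is a toy user-defined module with tensor-typed fields.
type BatchNormModule struct {
	Module
	RunningMean Tensor
	RunningVar  Tensor
}

func main() {
	bn := &BatchNormModule{}
	bn.Init(bn)   // Module.outer now points at bn
	bn.To("cuda") // the promoted Module.To finds and moves bn's tensors
	fmt.Println(bn.RunningMean.device, bn.RunningVar.device) // prints: cuda cuda
}
```

The key design point is that methods promoted from the embedded Module would otherwise only see the Module value itself; the stored outer pointer is what gives them access to the fields of the enclosing user-defined struct.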