coreylowman/dfdx

Please minimize the requirements for the optimizers

emchristiansen opened this issue · 1 comment

Adam and the other optimizers expect the module they're updating to impl TensorCollection, e.g. in this signature:

impl<M: TensorCollection<E, D>, D: Device<E>, E: Dtype> Optimizer<M, D, E> for Adam<M, E, D> {
    fn update(
        &mut self,
        module: &mut M,
        gradients: &Gradients<E, D>,
    ) -> Result<(), OptimizerUpdateError<D::Err>>;

   ...
}

Here, TensorCollection requires the implementation of iter_tensors:

/// Type alias that specifies how a module's type changes when using a different dtype and/or
/// device.
type To<E2: Dtype, D2: Device<E2>>;

/// Specifies how to iterate through tensors or modules contained within this module, and how
/// to construct this module given values for its fields. Returns `Err(_)` to indicate an error,
/// `Ok(None)` to indicate that there is no error and a module has not been built, and
/// `Ok(Some(_))` contains `Self::To<E2, D2>`.
fn iter_tensors<V: ModuleVisitor<Self, E, D>>(
    visitor: &mut V,
) -> Result<Option<Self::To<V::E2, V::D2>>, V::Err>;

AFAICT this requirement is overkill, as the optimizers don't need to reconstruct the module from its constituent tensors; they just need to mutate the tensors in place.
Also, though not stated here, it appears to assume there are default constructors for the constituent tensors, which is not generally true (see #839).

So, perhaps the optimizers should be refactored to rely on a trait that merely visits and mutates tensors inside existing modules?
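
For concreteness, here is a minimal sketch of what such a trait might look like. All names below (TensorMut, VisitTensorsMut, Linear, sgd_step) are hypothetical stand-ins and not part of dfdx's actual API; the point is only that an optimizer can be bounded on a "visit tensors mutably" trait, with no ability (or need) to reconstruct the module:

// Hypothetical sketch; none of these names exist in dfdx.

/// Stand-in for a concrete tensor type.
pub struct TensorMut {
    pub data: Vec<f32>,
}

/// A narrower requirement than `TensorCollection`: the module only hands out
/// mutable references to its tensors; it never has to be rebuilt from parts
/// and needs no default constructors.
pub trait VisitTensorsMut {
    fn visit_tensors_mut(&mut self, f: &mut dyn FnMut(&mut TensorMut));
}

/// Example module with two parameters.
pub struct Linear {
    pub weight: TensorMut,
    pub bias: TensorMut,
}

impl VisitTensorsMut for Linear {
    fn visit_tensors_mut(&mut self, f: &mut dyn FnMut(&mut TensorMut)) {
        f(&mut self.weight);
        f(&mut self.bias);
    }
}

/// An optimizer bounded only on the visiting trait; a bare SGD step,
/// assuming gradients are supplied in visit order.
pub fn sgd_step<M: VisitTensorsMut>(module: &mut M, grads: &[Vec<f32>], lr: f32) {
    let mut i = 0;
    module.visit_tensors_mut(&mut |t| {
        for (p, g) in t.data.iter_mut().zip(grads[i].iter()) {
            *p -= lr * g;
        }
        i += 1;
    });
}

Under a scheme like this the optimizer never constructs anything, so the default-constructor assumption from #839 also goes away.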

Yep, this will be addressed in the nn rewrite.