coreylowman/dfdx

Impossible to create a module with a parameter that lacks a const shape?

emchristiansen opened this issue · 1 comment

I'd like to create a module from a struct like this:

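```rust
use dfdx::prelude::*;

// `MyMetadata` is a user-defined, non-tensor type (not shown).
#[derive(Debug, Clone)]
pub struct MyNet<E, D>
where
    E: Dtype,
    D: Device<E>,
{
    pub metadata: MyMetadata,
    // Rank-1 tensor whose length is only known at runtime.
    pub logits: Tensor<(usize,), E, D, NoneTape>,
}
```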

However, I run into two issues when implementing TensorCollection:

  1. Minor issue: TensorCollection assumes the module can be constructed purely from its Tensor values and nothing else, so I would have to refactor the struct to remove metadata, which is inconvenient.
  2. Major issue: To register logits when calling visitor.visit_fields, I need to call Self::tensor, which expects a TensorOptions. However, a TensorOptions can only be constructed when S: ConstShape, because every constructor assumes that bound and TensorOptions is marked #[non_exhaustive], so it cannot be built with a struct literal from outside dfdx. Even if this were relaxed, I couldn't construct logits without knowing its runtime shape, and the method signature gives me no context from which to determine it.
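The second point can be illustrated with a small, self-contained model. The types below (ConstShape, Rank1, Options) are hypothetical stand-ins, not dfdx's real definitions: an options type whose only constructor requires a compile-time shape, with a private field standing in for #[non_exhaustive] (which only blocks struct-literal construction from other crates).

```rust
mod lib {
    use std::marker::PhantomData;

    /// A shape whose element count is known at compile time (hypothetical stand-in).
    pub trait ConstShape {
        const NUMEL: usize;
    }

    /// A rank-1 shape with compile-time length `N`.
    pub struct Rank1<const N: usize>;

    impl<const N: usize> ConstShape for Rank1<N> {
        const NUMEL: usize = N;
    }

    /// Options that cannot be built with a struct literal outside this
    /// module, because `numel` is private.
    pub struct Options<S> {
        numel: usize,
        _shape: PhantomData<S>,
    }

    impl<S: ConstShape> Options<S> {
        /// The only constructor: it reads the size from `S::NUMEL`,
        /// so it does not exist for runtime shapes.
        pub fn zeros() -> Self {
            Options { numel: S::NUMEL, _shape: PhantomData }
        }
    }

    impl<S> Options<S> {
        /// Builds a zero-filled buffer of the recorded size.
        pub fn build(&self) -> Vec<f32> {
            vec![0.0; self.numel]
        }
    }
}

use lib::{Options, Rank1};

fn main() {
    // Compile-time shape: the constructor is available.
    let opts: Options<Rank1<3>> = Options::zeros();
    assert_eq!(opts.build().len(), 3);

    // Runtime shape: `(usize,)` does not implement `ConstShape`, so
    // `Options::<(usize,)>::zeros()` would not compile, and the private
    // field prevents building an `Options` by hand. Even with a relaxed
    // bound, the length would have to come from an existing value,
    // not from the type.
}
```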

Is this a correct diagnosis of the limitations of the module API?
If so, do you suggest a workaround?
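For contrast, here is a hypothetical sketch (plain Rust, not dfdx's TensorCollection API) of a reconstruction step that avoids both issues: the module is rebuilt from an existing instance rather than from options alone, so the runtime length of logits and the non-tensor metadata are both in scope. Vec<f32> stands in for a tensor with shape (usize,), and map_tensors is an illustrative name, not a dfdx method.

```rust
/// Hypothetical non-tensor data carried by the module.
#[derive(Debug, Clone, PartialEq)]
pub struct MyMetadata {
    pub label: String,
}

/// Simplified stand-in for the struct in the issue.
#[derive(Debug, Clone)]
pub struct MyNet {
    pub metadata: MyMetadata,
    pub logits: Vec<f32>,
}

impl MyNet {
    /// Rebuilds the module by mapping each tensor field through `f`.
    /// `f` sees the existing tensor, so its runtime length is known,
    /// and `metadata` is carried over unchanged.
    pub fn map_tensors<F>(&self, mut f: F) -> MyNet
    where
        F: FnMut(&[f32]) -> Vec<f32>,
    {
        MyNet {
            metadata: self.metadata.clone(),
            logits: f(&self.logits),
        }
    }
}

fn main() {
    let net = MyNet {
        metadata: MyMetadata { label: "example".to_string() },
        logits: vec![1.0, 2.0, 3.0, 4.0, 5.0],
    };
    // Re-initialize to zeros while keeping the runtime length and metadata.
    let reset = net.map_tensors(|t| vec![0.0; t.len()]);
    assert_eq!(reset.logits.len(), 5);
    assert_eq!(reset.metadata, net.metadata);
}
```

The closure is FnMut so that a struct with several tensor fields could call it once per field.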

This will be addressed with the nn rewrite.