Impossible to create a module with a parameter that lacks a const shape?
emchristiansen opened this issue · 1 comment
emchristiansen commented
I'd like to create a module from a struct like this:
#[derive(Debug, Clone)]
pub struct MyNet<E, D>
where
E: Dtype,
D: Device<E>,
{
pub metadata: MyMetadata,
pub logits: Tensor<(usize,), E, D, NoneTape>,
}
However, I run into two issues when impl'ing TensorCollection:
- Minor issue: TensorCollection assumes the module can be constructed purely from the Tensor values and nothing else, meaning I actually need to refactor the struct to remove metadata, which is inconvenient.
- Major issue: To register logits when calling visitor.visit_fields, I need to call Self::tensor, which expects a TensorOptions. But TensorOptions can only be constructed if S: ConstShape, as that is assumed by all the construction methods and TensorOptions is marked non-exhaustive. And even if this were relaxed, I wouldn't be able to construct logits without knowing its (runtime) shape, which I'm unable to determine given the limited context allowed by the signature.
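To make the major issue concrete, here is a hypothetical, heavily simplified model in plain Rust with no dfdx dependency. ConstShape, Const1, Options, build_tensor and build_runtime_tensor are illustrative names assumed for this sketch, not dfdx's actual API, and a Vec<f32> stands in for a tensor. It only shows why an options constructor bounded on a compile-time shape cannot size a tensor whose length is known only at runtime, when the hook receives neither an instance nor a size argument.

```rust
use std::marker::PhantomData;

// HYPOTHETICAL model of the constraint described in the issue.
// These are NOT dfdx's real types; they only mirror the shape of the problem.

/// Stand-in for a ConstShape-like trait: the element count is a compile-time constant.
trait ConstShape {
    const NUMEL: usize;
}

/// A 1-D shape whose length N is fixed at compile time.
struct Const1<const N: usize>;

impl<const N: usize> ConstShape for Const1<N> {
    const NUMEL: usize = N;
}

/// Stand-in for TensorOptions: a recipe for building a fresh tensor
/// (modeled as a Vec<f32>). It can only be created for ConstShape shapes.
struct Options<S> {
    init: fn() -> Vec<f32>,
    _shape: PhantomData<S>,
}

impl<S: ConstShape> Options<S> {
    /// Zero-initialized tensor: the length comes from S::NUMEL alone,
    /// so no runtime information (and no existing instance) is needed.
    fn zeros() -> Self {
        Options {
            init: || vec![0.0; S::NUMEL],
            _shape: PhantomData,
        }
    }
}

/// Stand-in for the collection hook: it takes no `&self` and no size
/// argument, so the only length it can use is S::NUMEL.
fn build_tensor<S: ConstShape>() -> Vec<f32> {
    (Options::<S>::zeros().init)()
}

/// A runtime shape such as (usize,) has no compile-time length, so it cannot
/// implement ConstShape and build_tensor cannot size it. The length has to be
/// supplied from outside, e.g. as an explicit argument like this one.
fn build_runtime_tensor(len: usize) -> Vec<f32> {
    vec![0.0; len]
}
```

For example, build_tensor::<Const1<3>>() yields three zeros without any runtime input, whereas a (usize,)-style length can only be handled by something like build_runtime_tensor, which needs the length passed in.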
Is this a correct diagnosis of the limitations of the module API?
If so, do you suggest a workaround?
coreylowman commented
This will be addressed with the nn rewrite.