coreylowman/dfdx

Consider helpers for accessing tensors from tuples and input wrappers

swfsql opened this issue · 1 comment

swfsql commented

Disclaimer: I'm a beginner in AI and am currently trying to implement a U-Net as a study exercise.
Also, please consider this a draft idea.

The overall goal is to let the code generated by the `Sequential` derive be used in more cases, by allowing a Module to first access a single tensor from a set (or tuple) of tensors.

For example:

```rust
#[derive(Default, Clone, Sequential)]
pub struct ResidualAdd2<T: Clone + std::fmt::Debug> {
    // input = Input
    pub split: SplitInto<(Id, Id)>,
    // input = (Input, Input)
    pub t: On<tuple::_0, T>, // access the first slot from the tuple and then pass it to the `T` module
    // input = (T::Output, Input)
    pub add: Add,
    // input = T::Output = Input
}
```

Note: `Id` forwards its input unchanged, and `Add` calls `TryAdd` on the two tensors.

I'm not sure this is a good way to go, but the `On` module would apply the `T` module only to the first tensor of the input flow, that is, the first element of the two-tensor tuple, and pass the second one through unchanged.
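To make the data flow of `On<tuple::_0, T>` concrete, here is a minimal, self-contained sketch in plain Rust. It does not use dfdx: the `Module` trait, the `tuple::_0`/`tuple::_1` markers, `On::new`, and the toy `Scale` layer are hypothetical stand-ins, `f32` plays the role of a tensor, and `ResidualAdd2::forward` is written by hand to mimic the chaining a `Sequential` derive would generate.

```rust
use std::marker::PhantomData;
use std::ops::Add as StdAdd;

/// Hypothetical, simplified module trait (a stand-in, not dfdx's actual API).
pub trait Module<Input> {
    type Output;
    fn forward(&self, input: Input) -> Self::Output;
}

/// Forwards its input unchanged.
#[derive(Default, Clone, Debug)]
pub struct Id;
impl<X> Module<X> for Id {
    type Output = X;
    fn forward(&self, x: X) -> X { x }
}

/// Clones the input and feeds one copy to each module of the pair.
#[derive(Default, Clone, Debug)]
pub struct SplitInto<T>(pub T);
impl<X: Clone, A: Module<X>, B: Module<X>> Module<X> for SplitInto<(A, B)> {
    type Output = (A::Output, B::Output);
    fn forward(&self, x: X) -> Self::Output {
        let (a, b) = &self.0;
        (a.forward(x.clone()), b.forward(x))
    }
}

/// Adds the two elements of a pair (stand-in for calling `TryAdd`).
#[derive(Default, Clone, Debug)]
pub struct Add;
impl<X: StdAdd<Output = X>> Module<(X, X)> for Add {
    type Output = X;
    fn forward(&self, (a, b): (X, X)) -> X { a + b }
}

/// Hypothetical zero-sized markers selecting a slot of a pair.
#[allow(non_camel_case_types)]
pub mod tuple {
    #[derive(Default, Clone, Debug)]
    pub struct _0;
    #[derive(Default, Clone, Debug)]
    pub struct _1;
}

/// Applies `T` to the slot selected by `Idx` and passes the other slot through.
#[derive(Default, Clone, Debug)]
pub struct On<Idx, T> {
    pub t: T,
    _idx: PhantomData<Idx>,
}
impl<Idx, T> On<Idx, T> {
    pub fn new(t: T) -> Self {
        On { t, _idx: PhantomData }
    }
}
impl<A, B, T: Module<A>> Module<(A, B)> for On<tuple::_0, T> {
    type Output = (T::Output, B);
    fn forward(&self, (a, b): (A, B)) -> Self::Output {
        (self.t.forward(a), b)
    }
}
impl<A, B, T: Module<B>> Module<(A, B)> for On<tuple::_1, T> {
    type Output = (A, T::Output);
    fn forward(&self, (a, b): (A, B)) -> Self::Output {
        (a, self.t.forward(b))
    }
}

/// Toy layer standing in for a real one: multiplies its input by a constant.
#[derive(Default, Clone, Debug)]
pub struct Scale(pub f32);
impl Module<f32> for Scale {
    type Output = f32;
    fn forward(&self, x: f32) -> f32 { x * self.0 }
}

/// Hand-written version of what `Sequential` would chain: split -> t -> add.
pub struct ResidualAdd2<T> {
    pub split: SplitInto<(Id, Id)>,
    pub t: On<tuple::_0, T>,
    pub add: Add,
}
impl<X, T> Module<X> for ResidualAdd2<T>
where
    X: Clone + StdAdd<Output = X>,
    T: Module<X, Output = X>,
{
    type Output = X;
    fn forward(&self, x: X) -> X {
        let pair = self.split.forward(x); // (X, X)
        let pair = self.t.forward(pair); // (T::Output, X)
        self.add.forward(pair) // T::Output + X
    }
}
```

With `t: On::new(Scale(2.0))`, calling `forward(3.0)` computes `3.0 * 2.0 + 3.0 = 9.0`. Because the slot is chosen by a zero-sized marker type, the selection is resolved at compile time and `On` carries no extra runtime data.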

I'm not sure about this, but by not inserting more layer type information directly into the `split` field, we may be able to use some Modules that would otherwise need to be recursive. That said, I wouldn't be too happy dealing with tuple indexes all over the architecture.

If this direction has some merit, it might also be better to store the tensors in named struct fields and derive accessors for them, for example:

```rust
#[input_wrapper] // generates a `mod split { .. }`
pub struct Split<Forward, Skip> {
    pub forward: Forward,
    pub skip: Skip,
}

#[derive(Default, Clone, Sequential)]
pub struct ResidualAdd2<T: Clone + std::fmt::Debug> {
    // input = Input
    pub split: SplitInto<(Id, Id)>,
    // input = (Input, Input)
    pub input_to_wrapper: split::FromTuple, // converts from (A, B) into Split<A, B>
    // input = Split<Input, Input>
    pub t: On<split::forward, T>, // access the field `forward` and then pass it to the `T` module
    // input = Split<T::Output, Input>
    pub input_to_tuple: split::IntoTuple, // converts from Split<A, B> into (A, B)
    // input = (T::Output, Input)
    pub add: Add,
    // input = T::Output = Input
}
```

In this case the effect would be the same, but raw tuple indexes would no longer be used. The "access" concept is inspired by how druid works (as far as generating a module with structs representing each field goes), although it's not clear whether going that way pays off.
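Below is a hand-expanded guess at what `#[input_wrapper]` might generate for `Split`, written as a self-contained sketch in plain Rust rather than against dfdx. The `Module` trait, `On::new`, and the toy `Scale` layer are hypothetical stand-ins, and `f32` plays the role of a tensor. The generated `split` module holds `FromTuple`, `IntoTuple`, and one zero-sized marker per field, and `On<split::forward, T>` rewrites only the `forward` field.

```rust
use std::marker::PhantomData;

/// Hypothetical, simplified module trait (a stand-in, not dfdx's actual API).
pub trait Module<Input> {
    type Output;
    fn forward(&self, input: Input) -> Self::Output;
}

/// The named input wrapper from the proposal.
#[derive(Clone, Debug, PartialEq)]
pub struct Split<Forward, Skip> {
    pub forward: Forward,
    pub skip: Skip,
}

/// Hand-written guess at what `#[input_wrapper]` could expand to.
#[allow(non_camel_case_types)]
pub mod split {
    use super::{Module, Split};

    /// Converts `(A, B)` into `Split<A, B>`.
    #[derive(Default, Clone, Debug)]
    pub struct FromTuple;
    impl<A, B> Module<(A, B)> for FromTuple {
        type Output = Split<A, B>;
        // Bind to `a`/`b`: a binding named `forward` would match the unit struct below.
        fn forward(&self, (a, b): (A, B)) -> Self::Output {
            Split { forward: a, skip: b }
        }
    }

    /// Converts `Split<A, B>` back into `(A, B)`.
    #[derive(Default, Clone, Debug)]
    pub struct IntoTuple;
    impl<A, B> Module<Split<A, B>> for IntoTuple {
        type Output = (A, B);
        fn forward(&self, s: Split<A, B>) -> Self::Output {
            (s.forward, s.skip)
        }
    }

    /// Zero-sized field markers, one per field of `Split`.
    #[derive(Default, Clone, Debug)]
    pub struct forward;
    #[derive(Default, Clone, Debug)]
    pub struct skip;
}

/// Applies `T` to the field of `Split` selected by the marker `Field`.
#[derive(Default, Clone, Debug)]
pub struct On<Field, T> {
    pub t: T,
    _field: PhantomData<Field>,
}
impl<Field, T> On<Field, T> {
    pub fn new(t: T) -> Self {
        On { t, _field: PhantomData }
    }
}
impl<A, B, T: Module<A>> Module<Split<A, B>> for On<split::forward, T> {
    type Output = Split<T::Output, B>;
    fn forward(&self, s: Split<A, B>) -> Self::Output {
        Split { forward: self.t.forward(s.forward), skip: s.skip }
    }
}
impl<A, B, T: Module<B>> Module<Split<A, B>> for On<split::skip, T> {
    type Output = Split<A, T::Output>;
    fn forward(&self, s: Split<A, B>) -> Self::Output {
        Split { forward: s.forward, skip: self.t.forward(s.skip) }
    }
}

/// Toy layer standing in for a real one: multiplies its input by a constant.
#[derive(Default, Clone, Debug)]
pub struct Scale(pub f32);
impl Module<f32> for Scale {
    type Output = f32;
    fn forward(&self, x: f32) -> f32 { x * self.0 }
}
```

One detail a real macro would have to handle: a lowercase unit struct such as `split::forward` also lives in the value namespace, so inside the generated module a pattern binding named `forward` would be treated as a match against that unit struct rather than as a new variable. The sketch sidesteps this by binding to `a` and `b`.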

Thanks for reading! Again, this is only a draft idea.
Edit: I'll try to open a draft PR containing the attribute macro for this.

As stated in the PR, I don't think this is a commonly needed case, and it would make more sense as an experiment in an external library. Feel free to ping me if anyone has questions or suggestions.