coreylowman/dfdx

trait TryConcatAlong not satisfied when using constants

opfromthestart opened this issue · 1 comment

I half suspect this is a rustc bug, but I wanted to make sure. When I try to concatenate tensors whose dimensions are specified by named constants rather than literals, I get a compile error. Below is a complete example along with the error it produces. Using X and Y as dimensions fails to compile, but substituting their literal values works: the function normal does not compile, while the function fake does.

#![feature(generic_const_exprs)]
use dfdx::shapes::Rank2;

use dfdx::{
    shapes::Axis,
    tensor::{AutoDevice, Tensor, ZerosTensor},
    tensor_ops::TryConcatAlong,
};

fn main() {
    normal::<10>();
    fake::<10>();
}

const X: usize = 64;
const Y: usize = 32;
const Z: usize = X + Y;

// Does NOT compile: the column dimensions are the named constants X and Y.
fn normal<const B: usize>() {
    let d = AutoDevice::default();
    let t1: Tensor<Rank2<B, X>, _, _, _> = d.zeros();
    let t2: Tensor<Rank2<B, Y>, f32, _, _> = d.zeros();
    let t3: Tensor<Rank2<B, Z>, f32, _, _> = (t1, t2).concat_along(Axis::<1>);
}
// Compiles: the same column dimensions written as the literals 64 and 32.
fn fake<const B: usize>() {
    let d = AutoDevice::default();
    let t1: Tensor<Rank2<B, 64>, f32, _, _> = d.zeros();
    let t2: Tensor<Rank2<B, 32>, f32, _, _> = d.zeros();
    let t3: Tensor<Rank2<B, Z>, f32, _, _> = (t1, t2).concat_along(Axis::<1>);
}

The error is:

error[E0599]: the method `concat_along` exists for tuple `(Tensor<(Const<B>, Const<X>), _, Cuda>, Tensor<(Const<B>, Const<Y>), f32, Cuda>)`, but its trait bounds were not satisfied
  --> src/main.rs:23:55
   |
23 | ... f32, _, _> = (t1, t2).concat_along(Axis::<1>);
   |                           ^^^^^^^^^^^^ method cannot be called due to unsatisfied trait bounds
   |
   = note: the following trait bounds were not satisfied:
           `((Const<B>, Const<X>), (Const<B>, Const<Y>)): TryConcatAlong<_>`
           which is required by `(Tensor<(Const<B>, Const<X>), f32, Cuda>, Tensor<(Const<B>, Const<Y>), f32, Cuda>): TryConcatAlong<_>`

Hmm, yeah, this is weird. I agree that it's a rustc bug with const propagation (not quite sure what the correct term for it is). Will close for now.
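For reference, the shape arithmetic the example relies on is simple: concatenating a [B, X] tensor with a [B, Y] tensor along axis 1 yields a [B, X + Y] tensor. The sketch below shows that in plain, stable Rust on nested arrays, with no dfdx involved. The helper concat_axis1 is hypothetical and is not dfdx's API. Because X, Y and Z are concrete constants here and no generic_const_exprs bounds are needed, this neither reproduces nor works around the rustc behaviour discussed above. It only illustrates what concat_along(Axis::<1>) is expected to compute.

```rust
// Plain-Rust sketch (no dfdx): concatenate two row-major 2D arrays along
// axis 1 (columns). Shapes [B, X] and [B, Y] produce [B, X + Y].
const X: usize = 64;
const Y: usize = 32;
const Z: usize = X + Y;

/// Hypothetical helper, not part of dfdx: joins each row of `a`
/// with the matching row of `b`.
fn concat_axis1<const B: usize>(a: &[[f32; X]; B], b: &[[f32; Y]; B]) -> [[f32; Z]; B] {
    let mut out = [[0.0f32; Z]; B];
    for ((out_row, a_row), b_row) in out.iter_mut().zip(a).zip(b) {
        // The first X columns come from `a`, the remaining Y columns from `b`.
        out_row[..X].copy_from_slice(a_row);
        out_row[X..].copy_from_slice(b_row);
    }
    out
}

fn main() {
    let a = [[1.0f32; X]; 10];
    let b = [[2.0f32; Y]; 10];
    let c = concat_axis1(&a, &b);
    // Prints the resulting shape: 10 rows of Z = 96 columns.
    println!("{} x {}", c.len(), c[0].len());
}
```

Running it prints "10 x 96", matching the Rank2<B, Z> output type annotated in the issue when B = 10.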