tracel-ai/burn

Default tensor dtype can be configured on the device instead of by a type.

Instead of defining the default dtype on the backend implementation type, we could set it as a context.

type Backend = Cuda<bf16>;

to

let device: Device<Cuda> = Default::default();
device.dtype_global(DType::BF16);
let tensor = Tensor::random(...);
assert_eq!(tensor.dtype(), DType::BF16);

or

let device: Device<Cuda> = Default::default();
device.with_dtype(DType::BF16, || {
    let tensor = Tensor::random(...);
    assert_eq!(tensor.dtype(), DType::BF16);
});

The challenge would be defining the default scoping behavior: should the default dtype apply to the current device across all threads, which could have unintended side effects, or should it be keyed on the StreamId (i.e., the thread ID on native)? The scoped approach also avoids duplicating all tensor initialization APIs just to allow a custom default dtype.
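
As a rough illustration of the thread-keyed option, here is a minimal sketch of a scoped override built on a thread-local default. The free-standing with_dtype and current_dtype functions and the trimmed-down DType enum are hypothetical, not existing burn API:

use std::cell::Cell;

#[derive(Clone, Copy, PartialEq, Debug)]
enum DType {
    F32,
    F16,
    BF16,
}

thread_local! {
    // Per-thread default, mirroring the idea of keying on StreamId (the thread ID on native).
    static DEFAULT_DTYPE: Cell<DType> = Cell::new(DType::F32);
}

// Runs `f` with `dtype` as the thread-local default, restoring the previous value afterwards.
fn with_dtype<R>(dtype: DType, f: impl FnOnce() -> R) -> R {
    let previous = DEFAULT_DTYPE.with(|cell| cell.replace(dtype));
    let result = f();
    DEFAULT_DTYPE.with(|cell| cell.set(previous));
    result
}

// What a tensor creation op would consult instead of a hard-coded element type.
fn current_dtype() -> DType {
    DEFAULT_DTYPE.with(|cell| cell.get())
}

fn main() {
    assert_eq!(current_dtype(), DType::F32);
    with_dtype(DType::BF16, || {
        assert_eq!(current_dtype(), DType::BF16);
    });
    assert_eq!(current_dtype(), DType::F32);
}

Because the default lives in a thread-local, an override on one thread cannot leak into work submitted from another thread, which sidesteps the cross-thread side effects mentioned above.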

Hmmmmm. Ok, I do like policy-driven devices. We see them in other contexts, like lowering precision or enabling/disabling async ordering nondeterminism. I am generally pro policy at the Device level.

Things get a bit weird when we're working with multiple devices and/or proxy devices (one approach to NUMA, or to remote devices), and with the interaction between different policies.

let device: Device<Cuda> = Default::default();
device.with_dtype(DType::BF16, || {
   let tensor = Tensor::random(...);
   assert_eq!(tensor.dtype(), DType::BF16);
});

This would mean I would end up doing

let tensor_f32 = device.with_dtype(DType::F32, || {
    Tensor::random(...)
});

let tensor_bf16 = device.with_dtype(DType::BF16, || {
    Tensor::random(...)
});

let tensor_f16 = device.with_dtype(DType::F16, || {
    Tensor::random(...)
});

What are the reasons for not going for

let tensor_f32 = Tensor::random(..., DType::F32);
let tensor_bf16 = Tensor::random(..., DType::BF16);
let tensor_f16 = Tensor::random(..., DType::F16);

?

@torsteingrindvik

We would have to create the following:

let tensor_f32 = Tensor::random_dtype(..., DType::F32);

So now every tensor creation operation has two variants. This is fine, but if you don't provide a way to easily change the default dtype, then the DType needs to be passed as config across an entire codebase. What I'm suggesting is to attach those defaults to a device, which the user can change at runtime to execute a block of code with different defaults. Those defaults could cover element types, memory management strategy (allocations), etc.
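
To make the trade-off concrete, here is a minimal sketch of device-held defaults. The stripped-down Device type, its with_dtype method, and the stand-in random function are hypothetical and do not mirror burn's actual types; the point is that a single creation signature reads the device default, so no _dtype variants are needed and no DType has to be threaded through the codebase as config:

use std::sync::{Arc, Mutex};

#[derive(Clone, Copy, PartialEq, Debug)]
enum DType {
    F32,
    F16,
    BF16,
}

// Per-device defaults that creation ops consult; other policies
// (e.g. an allocation strategy) could live alongside the dtype.
struct Device {
    default_dtype: Arc<Mutex<DType>>,
}

impl Device {
    fn new() -> Self {
        Self {
            default_dtype: Arc::new(Mutex::new(DType::F32)),
        }
    }

    // Runs `f` with `dtype` as this device's default, then restores the previous value.
    fn with_dtype<R>(&self, dtype: DType, f: impl FnOnce() -> R) -> R {
        let previous = std::mem::replace(&mut *self.default_dtype.lock().unwrap(), dtype);
        let result = f();
        *self.default_dtype.lock().unwrap() = previous;
        result
    }

    fn default_dtype(&self) -> DType {
        *self.default_dtype.lock().unwrap()
    }
}

// A single creation signature: it reads the device default instead of
// requiring a _dtype variant or an explicit dtype argument everywhere.
fn random(device: &Device) -> DType {
    device.default_dtype() // stand-in for "allocate a tensor of this dtype"
}

fn main() {
    let device = Device::new();
    assert_eq!(random(&device), DType::F32);
    device.with_dtype(DType::BF16, || {
        assert_eq!(random(&device), DType::BF16);
    });
    assert_eq!(random(&device), DType::F32);
}

Unlike the thread-local sketch earlier, the default here lives on the device itself, so an override is visible to every thread using that device; which of the two scopes is the right default is exactly the open question raised above.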