[Discuss][IR] Relax PrimValue
YuchenJin opened this issue · 12 comments
This is a draft proposal to show high-level ideas about supporting constant POD (plain-old-data) values (e.g., int, float, bool) in Relax IR.
Relax has first-class support for interacting with the TVM FFI in the graph by calling into PackedFuncs via `call_dps_packed` and `call_packed`. When integrating with third-party libraries such as LibTorch, library functions may take constant arguments of POD data types. For example, `ATen::unique` takes the optional boolean arguments `sorted` and `return_inverse`, and `ATen::add` takes the optional integer argument `alpha`.
So far we can represent such POD values by wrapping them in a TVM attribute, but this needs special handling in codegen, and it is hard to generate such attributes automatically.
To bridge the Relax IR with the runtime, we propose to introduce `PrimValue` to Relax to express primitive POD values in the IR.
IR Representation
In this proposal, we introduce an expression `R.PrimValue` to Relax. A `PrimValue` is an IR construct that can represent POD values such as int, float, bool, and string. It can be implemented as a wrapper around `PrimExpr`, which can be a `tir::IntImm`, `tir::FloatImm`, `tir::Bool`, or `tir::StringImm`.
```python
class PrimValue(Expr):
    # a constant value that can be tir::IntImm/FloatImm/Bool/StringImm
    value: PrimExpr
```
Type of PrimValue
We can reuse `PrimType` to represent POD values. Note that `PrimType` is not a subtype of `relax.ObjectType`, because at runtime a `PrimValue` does not map to a TVM runtime Object but to a `TVMPODValue_`.
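As a rough illustration, constructing such values from Python might look like the sketch below; the `relax.PrimValue` constructor and its acceptance of these `PrimExpr` variants follow the class sketch above and are assumptions of this proposal, not a settled API:

```python
import tvm
from tvm import relax, tir

# Hypothetical constructions following the class sketch above; the exact
# constructor signature is an assumption of this proposal, not a final API.
alpha = relax.PrimValue(tir.IntImm("int64", 2))       # int POD value
approximate = relax.PrimValue(tir.StringImm("tanh"))  # string POD value
# The type of `alpha` would then be a PrimType rather than a subtype of
# relax.ObjectType, e.g. tvm.ir.PrimType("int64").
```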
TVMScript example
The following code block demonstrates a simple Relax function with `call_dps_packed` and `call_packed` calls into libtorch functions:
```python
@tvm.script.ir_module
class MyMod:
    @R.function
    def main(c0: Tensor((32, 32), "float32"), c1: Tensor((32, 32), "float32")):
        # a libtorch function with the "out" suffix follows the DPS calling convention
        x = R.call_dps_packed("libtorch_add_out", (c0, c1), alpha=2, output_shape=(32, 32), dtype="float32")
        # bind a PrimValue to a Var
        approximate: PrimType = R.PrimValue("tanh")
        y = R.call_dps_packed("libtorch_gelu_out", x, approximate, output_shape=(32, 32), dtype="float32")
        out = R.call_packed("libtorch_unique", x, sorted=True, return_inverse=True, type_args=(Tensor(ndim=1, dtype="float32")))
        return out
```
- `alpha=2` on the rhs of the first binding is parsed as a `PrimValue` with `tir::IntImm(2)` as the value.
- `approximate` is a `relax.Var` of `PrimType`, and it can be passed into `call_dps_packed` as an argument.
- Runtime: the Relax VM already supports this, because a VM register can store any `TVMRetValue`, which includes `TVMPODValue_` and `std::string`.
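To see why no extra VM machinery is needed, here is a small standalone sketch of a PackedFunc receiving a POD argument; the function name `libtorch_gelu_out_demo` is made up for illustration:

```python
import tvm

# A toy PackedFunc standing in for a libtorch wrapper. POD arguments cross
# the TVM FFI as plain host values, which is what a PrimValue argument would
# lower to at runtime.
@tvm.register_func("libtorch_gelu_out_demo")
def gelu_out_demo(inp, approximate, out):
    # `approximate` arrives as an ordinary Python string, passed by value.
    assert isinstance(approximate, str)
    # ... a real wrapper would compute gelu(inp, approximate) into `out` ...

f = tvm.get_global_func("libtorch_gelu_out_demo")
```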
Other considerations:
Scope: right now this design note restricts `PrimValue` to represent int, float, bool, or string, and to be standalone. Possible future directions:
- Allow `PrimValue` to interact with `TensorType`. For example, we can define the semantics of `tensor / int` as in PyTorch and NumPy (see the NumPy sketch below).
- Allow `PrimValue` to interact with `TupleType`. Do we want to represent `Tuple(int)` in the IR? One caveat is that TVM containers (e.g., `runtime.Tuple`) do not support POD data types today.
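For the first direction, the intended semantics are the familiar scalar-broadcast rules; a minimal NumPy illustration of `tensor / int`:

```python
import numpy as np

# NumPy broadcasts the scalar divisor across the tensor; this is the behavior
# a hypothetical Relax `tensor / PrimValue` interaction could mirror.
x = np.arange(6, dtype="float32").reshape(2, 3)
print(x / 2)  # elementwise division by the scalar 2
```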
Argument types used in Libtorch can be found here: https://github.com/pytorch/pytorch/tree/master/aten/src/ATen/native#registering-a-function-in-native_functionsyaml
Thanks @YuchenJin for the proposal! I have a few comments on that.
- With `PrimValue`, we now have two options for a Relax high-level operator API. Say we are designing the operator `batch_norm`, which takes five tensors `data`, `gamma`, `beta`, `moving_mean`, and `moving_var`, together with an axis and an epsilon value:
  - One option is to put all of them as operator inputs, since the axis and epsilon can be treated as `PrimValue`s and can have type `PrimType`. With this design, the Call to a `batch_norm` operator will be `Call(Op::Get("relax.nn.batch_norm"), {data, gamma, beta, moving_mean, moving_var, axis, epsilon}, Attrs())`.
  - The other option is to wrap the axis and epsilon with `BatchNormAttrs`, as Relay does. Here the Call will look like `Call(Op::Get("relax.nn.batch_norm"), {data, gamma, beta, moving_mean, moving_var}, BatchNormAttrs(axis, epsilon))`.

  Attributes cannot be used as `call_packed_dps` inputs. As for high-level operator calls, we don't have such limitations and thus can keep using attributes. But even so, we still need to elaborate the reason for our design choice.
- As Yuchen mentioned above, it is still uncertain whether to support a tuple of integers as a valid Relax expression. So I'm curious about the LibTorch support for cases where an operator may take a list of `PrimValue`s as input (e.g., a `softmax` operator which takes a list of axes). First, I'm not quite sure what the API for such operators looks like in LibTorch. And second, if such cases do exist, what is the current way to call a LibTorch softmax, for example?
I echo the thanks for writing up this proposal. As we discussed at the community meeting, I am a bit hesitant about including PrimValues as a separate type/value in the language because it would mean that there would be more than one way to represent the same value in Relax: We would have both rank-0 tensors (which represent scalars) and PrimValues. The choice of run-time representation (NDArray vs a TVM POD value) is a lower-level concern and perhaps should not require additional complexity in the front-end language, including in the type system. I would recommend using the NDArray representation wherever possible and converting in whichever operators do require PrimValues. We should profile the cost of such conversions--if the cost is small (imagining that we might have to do this multiple times in a model), I would say it's really not worth introducing the additional complexity into the front-end language. (The situation I would prefer to avoid is users getting mad that instantiating a constant results in the "wrong kind" of constant for whatever operator they're using, something which is likely to cause annoyance.)
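For reference, a rough sketch of the kind of conversion-cost measurement being suggested; the harness, device, and iteration count here are arbitrary choices, not part of the proposal:

```python
import time
import numpy as np
import tvm

# Measure the cost of unboxing a rank-0 NDArray into a POD float. Substitute
# tvm.cuda(0) for tvm.cpu(0) to include a device-to-host copy in the cost.
dev = tvm.cpu(0)
scalar = tvm.nd.array(np.array(1.5, dtype="float32"), device=dev)

n = 10_000
start = time.perf_counter()
for _ in range(n):
    pod = float(scalar.numpy())  # NDArray -> host scalar conversion
elapsed = time.perf_counter() - start
print(f"{elapsed / n * 1e6:.2f} us per conversion")
```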
If we do introduce PrimValues as a separate value in the language, it might be worth having some policy for when we would prefer to use a PrimValue and when we would prefer to use a 0-rank tensor. @psrivas2 suggested one policy, of using tensors for all Relax-level operators and using PrimValues only for interacting with low-level libraries.
That said, POD values can represent strings while NDArrays cannot, so perhaps one option might be to add a string constant node (or even allow strings to be passed as an argument to the `Constant` expr) and define string values (represented using the existing TVM runtime string container).
I would also be hesitant to introduce a type that is not a subtype of `Object`, since it would complicate some of the type system mechanics. I suppose this would be analogous to how primitive values in Java and C# are not Objects either. I think the situation of wrapping these primitive values in tensors/other containers would be analogous to boxing/unboxing in Java and C# (Java version, C# version).
Thanks everyone for the discussions at the community meeting today, and thanks @slyubomirsky and @psrivas2 for proposing alternative plans and summarizing the tradeoffs!
Introducing a new type indeed needs careful consideration. One problem with using 0-rank tensors to represent POD values is that a tensor is device-specific, while these POD values always live on the host. PyTorch and other libraries make this distinction between host-only values and device-specific tensors.
To reduce the complexity of writing passes, we can indeed restrict the Relax-level operators (e.g., comparison, addition, multiplication) to only take tensors, and use `PrimValue`s only for interacting with low-level libraries, as @psrivas2 suggested. The common case for our compiler is tensor arithmetic that needs to be optimized, and we want to keep that common case simple so optimization passes do not need to worry about more general cases. For `PrimValue`, being able to represent calls into libraries that take `PrimValue` arguments is sufficient; there is no strong need to rewrite these parts during compilation.
Follow-up question: Are PrimValues mutable? E.g., can a PackedFunc mutate them in place? I would assume (and hope) not. This would mean we have a distinction between what is passed by value and what is passed by reference, just like a lot of other languages (this might be useful for users).
Thank you all for the great discussion!
For @MasterJH5574,
So I’m curious about the LibTorch support on the cases where an operator may take a list of PrimValues as input (e.g., softmax operator which takes a list of axes).
It is unclear whether they provide general support for lists, but they do seem to support a list of integers. Take `at::IntArrayRef` in conv2d as an example:

```cpp
at::Tensor &at::_slow_conv2d_forward_out(at::Tensor &output, const at::Tensor &self, const at::Tensor &weight, at::IntArrayRef kernel_size, const c10::optional<at::Tensor> &bias, at::IntArrayRef stride, at::IntArrayRef padding)
```
And secondly, if such cases do exist, what is the current way to call a LibTorch softmax, for example?
I believe this is what we need to figure out. So I would like to have some support for things like tuples, but I am not sure this has to be `runtime.Tuple`.
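As a side note, the TVM runtime does already have one FFI-crossing container for integer sequences; whether something like it should back `PrimValue` lists is exactly the open question here:

```python
import tvm

# ShapeTuple carries a list of integers across the FFI boundary today; a
# possible (but undecided) vehicle for int[]-style arguments like kernel_size.
kernel_size = tvm.runtime.ShapeTuple([3, 3])
print(list(kernel_size))  # [3, 3]
```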
For @slyubomirsky,
Are PrimValues mutable? E.g., can a PackedFunc mutate them in place?
Based on my investigation so far, they seem to be immutable, since their main purpose is to pass configuration values.
For @YuchenJin,
One problem with using 0-rank tensors to represent POD values is that a tensor is device-specific, while these POD values always live on the host
Do you assume that the Libtorch fallback would always run on the host device?
Thanks for the discussions so far!
Are PrimValues mutable? E.g., can a PackedFunc mutate them in place?
If we restrict the Relax-level operators to only take device-aware tensors, and restrict `PrimValue`s to be used only for external function calls, then `PrimValue`s are not mutable: they are constant values passed as arguments to packed functions.
Do you assume that the Libtorch fallback would always run on the host device?
The Libtorch fallback functions can run on devices such as the GPU, while the `PrimValue`s are host-side arguments passed to the kernel functions. These kernel functions take their arguments by value: the host copies the arguments into a dedicated memory buffer on the device, for example the constant memory on a CUDA GPU.
For today's Open Dev meeting, I summarized our discussion so far. Hope this helps.
Motivation
We need a way to pass POD values (e.g., int, string, boolean) to `call_packed` and `call_packed_dps` to feed the libtorch API.
Requirements
source: https://github.com/pytorch/pytorch/tree/master/aten/src/ATen/native#func
- POD values, particularly `int`, `float`, `bool`, `str`
- `Scalar`: supports binding to any numerical type from Python, including integral types, floating point types, and zero-dimensional tensors.
- `Tensor[]`: translates into a C++ argument of type `ArrayRef<Tensor>` (a.k.a. `TensorList`)
- `int[]`: accepts an optional length specifier, e.g., `int[2]`, which has no effect in C++ but extends the Python bindings to accept a bare number, which will be expanded into an appropriately sized list by repeating the number.
- `bool[N]` (where N is 1-4)
Options
- Thinkable but not considered
  - Attribute
    - An attribute is a compile-time concept: it should be consumed and disappear during compilation.
    - `call_packed` and `call_packed_dps` are runtime constructs.
- O1: N-rank Tensor (NDArray)
  - Properties
    - Can be expressed with a constant node in the AST
    - Tensors are device-specific
  - Pros:
    - Already-existing support for POD values and `int[]`, `bool[N]`
  - Cons:
    - Strings need special handling: NDArray cannot represent strings
    - Inefficient for CUDA-like APIs where the kernel launch has to be done from the host side
      - At runtime, values would have to be copied device→host, then host→device for the kernel launch
- O2: Introduce a `PrimValue` Expr
  - Properties
    - Will have `PrimType`
      - Not a subtype of `relax.ObjectType`: at runtime, a `PrimValue` does not map to a TVM runtime Object but to a `TVMPODValue_`
      - A Relax VM register can store any `TVMRetValue`, which includes `TVMPODValue_` and `std::string`.
    - Lives on the host side and is passed to kernel functions by value: the values are transferred to a dedicated memory buffer on the device at runtime (e.g., constant memory on a CUDA GPU).
  - Pros:
    - Host-side value
    - Might be a good fit to express the Torch `Scalar` type
  - Cons:
    - Need to define the interaction with `TupleType`
    - May cause confusion about the relation with O1
    - Large strings or lists might incur noticeable overhead
Points of discussion from yesterday's Relax community meeting:
- `PrimValue`s might add considerable complexity to the language, since they would be yet another kind of value. For invoking functions/operators alone, there would be tensors, `PrimValue`s, and attributes.
  - If we stick to a convention of using `PrimValue`s only for third-party low-level libraries, this complexity may be manageable.
- Using zero-rank tensors for constants poses some practical problems because `NDArray`s have to be associated with specific devices; in certain cases, like with the GPU, we might be forced to move values to and from the device.
  - An alternative would be exposing the device to Relax and allowing heterogeneous computations to be described in Relax, but this would be a very complex new feature.
  - We could also have special-case handling for these low-level libraries to avoid this copying.
  - `PrimValue`s are always on the host device, so they get around this issue.
- Alternatively, we could try to simplify matters by, e.g., getting rid of attributes and using `PrimValue`s for that purpose.
  - It would be important to ensure that we could still use those values at compile time for type checking and shape rules.
  - Rather than using attributes in an ad hoc way in type checking, we could instead have schema validation for operators and be able to systematically validate the use of `PrimValue`s and tensors. This would be similar to PyTorch's operator dispatch system.
- One issue with `PrimValue`s is that they are not TVM `Object`s and so cannot be placed inside a runtime ADT container.
  - We could modify the runtime containers to take `TVMRetValue`s, which would be able to hold these `PrimValue`s.
  - Alternatively, we could box the `PrimValue`s in either a new or an existing object, as sketched below.
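For the boxing option, the Python FFI already does a version of this with an existing object: bare integers are wrapped as `IntImm` objects when they enter an Object-based container (the exact boxed type can vary across TVM versions):

```python
import tvm

# tvm.runtime.convert boxes bare Python ints so they can live inside the
# Object-based Array container; an "existing object" form of the boxing
# discussed above. (In the relax-era codebase the boxed type is IntImm.)
arr = tvm.runtime.convert([1, 2, 3])
print(type(arr))     # an Object container, e.g. tvm.ir.container.Array
print(type(arr[0]))  # a boxed integer object, e.g. IntImm
```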
Some possible directions:
- Big adjustment to the existing operator implementation approach:
  - Make `PrimValue`s/`PrimType`s first-class citizens, able to be contained in tuples, etc.
  - Phase out attributes in favor of `PrimValue`s.
  - Don't try to use Relay operator strategies; Relax has the ability to call into TOPI directly.
- Alternative: require boxing for `PrimValue`s to allow them to appear inside other data structures.
- Tension with introducing `PrimValue`s: Relax is a domain-specific language, but we are adding lots of constructs that greatly increase its generality, which may create great difficulties for pass writers.
  - One response is that `PrimValue`s are mainly meant for interacting with low-level libraries, so we could avoid permitting `PrimValue`s to interact with the higher-level constructs in the language (e.g., by requiring them to be wrapped in tensor values).
  - On the other hand, the more general use of `PrimValue`s to replace current uses of attributes could lead to greater flexibility and convenience (allowing computed values to be used for things that are now required to be compile-time attributes). The type system can express tensors of unknown rank or shape, so the language would have the flexibility to deal with these issues.
We will continue discussions because we have not reached consensus on how to scope `PrimValue`s in the language. Additionally, we noted that this is a good time to consider revisions to the operator system, because most Relax operators have not yet been implemented. Similarly, if we wish to tackle the issue of `NDArray`s being associated with specific devices, this is a good time for it.
Personal opinion: if we aim to replace attributes with `PrimValue`s, we would probably also want to build default values for arguments into the language. Replacing attributes would unfortunately be a generally laborious affair.
Thank you, @slyubomirsky, for the great summary!
Based on our discussion yesterday, @YuchenJin and I defined the following action items:
- A0: Introduce `PrimValue` and `PrimType`, restricting their usage to `call_packed_dps` as the first step.
- A1: Define the interaction between `PrimValue`s and other values such as `Tensor` and `Tuple`.
- A2: Potentially unify attributes with `PrimValue` and introduce an operator schema.
Yesterday, I believe we reached consensus on A0 while leaving A1 and A2 as future discussion topics.
Since libtorch integration requires A0 and part of A1 (particularly, `Tuple`), I'll start working on A0, and I'd like to continue our discussion on A1 and A2.
Thank you all for the fruitful discussion!