tlc-pack/relax

[Discuss][IR] Relax PrimValue

YuchenJin opened this issue · 12 comments

This is a draft proposal to show high-level ideas about supporting constant POD (plain-old-data) values (e.g., int, float, bool) in Relax IR.

Relax has first-class support for interacting with the TVM FFI in the graph by calling into PackedFuncs via call_dps_packed and call_packed. When integrating with third-party libraries such as LibTorch, library functions may take constant arguments of POD data types. For example, ATen::unique takes the optional boolean arguments sorted and return_inverse, and ATen::add takes an optional integer argument alpha.

So far, we can represent such POD values by wrapping them in a TVM attribute, but this requires special handling in codegen, and such attributes are hard to generate automatically.

To bridge the Relax IR and the runtime, we propose introducing PrimValue to Relax to express primitive POD values in the IR.

IR Representation

In this proposal, we introduce an expression R.PrimValue to Relax. A PrimValue is an IR construct that can represent POD values such as int, float, bool, and string.

It can be implemented as a wrapper around a PrimExpr, which can be a tir::IntImm, tir::FloatImm, tir::Bool, or tir::StringImm.

class PrimValue(Expr):
    # a constant value that can be a tir::IntImm, FloatImm, Bool, or StringImm
    value: PrimExpr
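To make the wrapper idea concrete, here is a minimal pure-Python sketch that does not use TVM at all. The classes IntImm, FloatImm, Bool, and StringImm below are hypothetical stand-ins for the tir immediates, and prim_value_from_py is a hypothetical helper showing how a parser might map a Python literal (such as alpha=2) to the matching immediate. The default dtypes ("int64", "float64") are assumptions for illustration, not TVM's defaults.

```python
from dataclasses import dataclass
from typing import Union


# Hypothetical stand-ins for tir::IntImm / FloatImm / Bool / StringImm.
@dataclass(frozen=True)
class IntImm:
    dtype: str
    value: int


@dataclass(frozen=True)
class FloatImm:
    dtype: str
    value: float


@dataclass(frozen=True)
class Bool:
    value: bool
    dtype: str = "bool"


@dataclass(frozen=True)
class StringImm:
    value: str


Imm = Union[IntImm, FloatImm, Bool, StringImm]


@dataclass(frozen=True)
class PrimValue:
    """Relax-level expression wrapping exactly one constant immediate."""
    value: Imm


def prim_value_from_py(x: object) -> PrimValue:
    """Map a Python literal to a PrimValue (hypothetical parser helper)."""
    # bool must be checked before int: in Python, True is also an int.
    if isinstance(x, bool):
        return PrimValue(Bool(x))
    if isinstance(x, int):
        return PrimValue(IntImm("int64", x))  # dtype is an assumption
    if isinstance(x, float):
        return PrimValue(FloatImm("float64", x))  # dtype is an assumption
    if isinstance(x, str):
        return PrimValue(StringImm(x))
    raise TypeError(f"not a POD literal: {x!r}")
```

The bool-before-int check matters in any Python-based frontend, since otherwise sorted=True would silently become an integer immediate.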

Type of PrimValue

We can reuse PrimType to represent the type of POD values. Note that PrimType is not a subtype of relax.ObjectType: at runtime, a PrimValue maps to a TVMPODValue_ rather than to a TVM runtime Object.
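A toy model of this subtyping rule, assuming a tiny hand-written type hierarchy: the class names only mirror the Relax names and are not the real relax type classes. TensorType and TupleType sit under ObjectType, while PrimType is deliberately kept outside it.

```python
class Type:
    """Root of this toy hierarchy (hypothetical, not relax.Type)."""


class ObjectType(Type):
    """Values backed by a TVM runtime Object (reference semantics)."""


class TensorType(ObjectType):
    pass


class TupleType(ObjectType):
    pass


class PrimType(Type):
    """POD values (TVMPODValue_ at runtime); deliberately NOT an ObjectType."""


def is_subtype(lhs: type, rhs: type) -> bool:
    # In this sketch, subtyping is plain Python class inheritance.
    return issubclass(lhs, rhs)
```

One consequence is that any pass or runtime path that assumes "every value is an ObjectType" (for example, to store it in a runtime container) has to treat PrimType separately, which is the complication raised later in this thread.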

TVMScript example

The following code block demonstrates a simple Relax function that uses call_dps_packed and call_packed to call into libtorch functions:

import tvm
from tvm.script import relax as R

@tvm.script.ir_module
class MyMod:
    @R.function
    def main(c0: Tensor((32, 32), "float32"), c1: Tensor((32, 32), "float32")):
        # libtorch functions with the "out" suffix follow the DPS (destination-passing style) calling convention
        x = R.call_dps_packed("libtorch_add_out", (c0, c1), alpha=2, output_shape=(32, 32), dtype="float32")

        # bind a PrimValue to a Var
        approximate: PrimType = R.PrimValue("tanh")
        y = R.call_dps_packed("libtorch_gelu_out", x, approximate, output_shape=(32, 32), dtype="float32")

        out = R.call_packed("libtorch_unique", x, sorted=True, return_inverse=True, type_args=(Tensor(ndim=1, dtype="float32")))
        return out

  • alpha=2 on the right-hand side of the first binding is parsed as a PrimValue whose value is tir::IntImm(2).
  • approximate is a relax.Var of PrimType, and it can be passed into call_dps_packed as an argument.
  • Runtime: the Relax VM already supports this, because a VM register can store any TVMRetValue, which includes TVMPODValue_ and std::string.
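The runtime point can be illustrated with a toy register machine. This is a simplified model, not the Relax VM: ToyVM, load_const, and unique_stub are hypothetical names, a Python list plays the role of a tensor, and the assumption that keyword arguments such as sorted=True are lowered to positional POD arguments is ours. It shows that registers can hold POD values next to tensors and that a PackedFunc-style call simply reads its arguments from registers.

```python
from typing import Any, Callable, Dict, List


class ToyVM:
    """Toy register machine: a register slot may hold any value (tensor or POD)."""

    def __init__(self) -> None:
        self.registers: List[Any] = []
        self.funcs: Dict[str, Callable[..., Any]] = {}

    def register_packed(self, name: str, fn: Callable[..., Any]) -> None:
        self.funcs[name] = fn

    def load_const(self, value: Any) -> int:
        # Store a constant (e.g. a PrimValue's payload) in a new register; return its index.
        self.registers.append(value)
        return len(self.registers) - 1

    def call_packed(self, name: str, arg_regs: List[int]) -> Any:
        # PackedFunc-style call: arguments are passed positionally from registers.
        args = [self.registers[r] for r in arg_regs]
        return self.funcs[name](*args)


def unique_stub(data: List[Any], sorted_flag: bool, return_inverse: bool) -> Any:
    """Hypothetical stand-in for libtorch_unique operating on a Python list."""
    uniq = sorted(set(data)) if sorted_flag else list(dict.fromkeys(data))
    if not return_inverse:
        return uniq
    index = {v: i for i, v in enumerate(uniq)}
    return uniq, [index[v] for v in data]


vm = ToyVM()
vm.register_packed("libtorch_unique", unique_stub)
x = vm.load_const([3, 1, 3, 2])
# sorted=True and return_inverse=True become two POD registers, in positional order.
out = vm.call_packed("libtorch_unique", [x, vm.load_const(True), vm.load_const(True)])
```

Here out is ([1, 2, 3], [2, 0, 2, 1]): the unique sorted values plus, for each input element, its index in that list.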

Other considerations:

Scope: For now, this design note restricts PrimValue to representing int, float, bool, or string values, and to being standalone. Possible future directions include:

  • Allow PrimValue to interact with TensorType. For example, we could define the semantics of tensor / int as in PyTorch and NumPy.
  • Allow PrimValue to interact with TupleType. Do we want to represent Tuple(int) in the IR? One caveat is that TVM runtime containers (e.g., runtime.Tuple) do not support POD data types today.

Thanks @YuchenJin for the proposal! I have a few comments on that.

  • With PrimValue, we now have two options for a Relax high-level operator API. Say we are designing the operator batch_norm, which takes five tensors (data, gamma, beta, moving_mean, and moving_var), together with an axis and an epsilon value:
    • One option is to pass all of them as operator inputs, since axis and epsilon can be treated as PrimValues of type PrimType. With this design, a call to the batch_norm operator would be
      Call(Op::Get("relax.nn.batch_norm"), {data, gamma, beta, moving_mean, moving_var, axis, epsilon}, Attrs())
    • The other option is to wrap axis and epsilon in BatchNormAttrs, as Relay does. Here the call would look like
      Call(Op::Get("relax.nn.batch_norm"), {data, gamma, beta, moving_mean, moving_var}, BatchNormAttrs(axis, epsilon))
    So I suppose we should discuss and agree on a standard for operator design. My understanding is that, to connect with the LibTorch interface as fallback support, we have virtually no choice but to wrap integers/booleans/floats as PrimValues so that they can serve as Relax Exprs and be passed as call_dps_packed inputs. High-level operator calls have no such limitation, so they can keep using attributes. Even so, we still need to articulate the rationale for our design choice.
  • As Yuchen mentioned above, it is still uncertain whether to support a tuple of integers as a valid Relax expression. So I’m curious about LibTorch support for cases where an operator may take a list of PrimValues as input (e.g., a softmax operator that takes a list of axes). First, I’m not sure what the API for such operators looks like in LibTorch. Second, if such cases do exist, what is the current way to call, say, a LibTorch softmax?
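The two batch_norm options can be compared with a small pure-Python model of a call node. Call, BatchNormAttrs, and the string stand-ins for tensor variables are hypothetical simplifications, not TVM's C++ classes. The point is where axis and epsilon live: in the argument list, where passes see them as ordinary expressions, or in a separate attribute object fixed at compile time.

```python
from dataclasses import dataclass
from typing import Any, Optional, Tuple


@dataclass(frozen=True)
class BatchNormAttrs:
    axis: int
    epsilon: float


@dataclass(frozen=True)
class Call:
    op: str
    args: Tuple[Any, ...]
    attrs: Optional[Any] = None


tensors = ("data", "gamma", "beta", "moving_mean", "moving_var")

# Option 1: axis and epsilon are PrimValue arguments (plain literals here).
call_args = Call("relax.nn.batch_norm", tensors + (1, 1e-5))

# Option 2: axis and epsilon are compile-time attributes, as in Relay.
call_attrs = Call("relax.nn.batch_norm", tensors, BatchNormAttrs(axis=1, epsilon=1e-5))


def get_epsilon(call: Call) -> float:
    # A pass must know which convention the operator uses to locate epsilon.
    if call.attrs is not None:
        return call.attrs.epsilon
    return call.args[6]
```

Whichever convention is chosen, making it uniform across operators keeps helpers like get_epsilon from needing per-operator special cases.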

I echo the thanks for writing up this proposal. As we discussed at the community meeting, I am a bit hesitant about including PrimValues as a separate type/value in the language, because there would then be more than one way to represent the same value in Relax: both rank-0 tensors (which represent scalars) and PrimValues. The choice of run-time representation (NDArray vs. a TVM POD value) is a lower-level concern and perhaps should not require additional complexity in the front-end language, including in the type system. I would recommend using the NDArray representation wherever possible and converting in whichever operators do require PrimValues. We should profile the cost of such conversions: if the cost is small (keeping in mind that we might have to convert multiple times in a model), I would say it is not worth introducing the additional complexity into the front-end language. (The situation I would prefer to avoid is users being frustrated that instantiating a constant produces the "wrong kind" of constant for whatever operator they're using.)

If we do introduce PrimValues as a separate kind of value in the language, it might be worth having a policy for when to prefer a PrimValue and when to prefer a 0-rank tensor. @psrivas2 suggested one such policy: use tensors for all Relax-level operators, and use PrimValues only for interacting with low-level libraries.

That said, POD values can represent strings while NDArrays cannot, so one option might be to add a string constant node (or even allow strings to be passed as an argument to the Constant expr) and define string values (represented using the existing TVM runtime string container).

I would also be hesitant to introduce a type that is not a subtype of Object, since it would complicate some of the type system mechanics. I suppose this would be analogous to how primitive values in Java and C# are not Objects either. I think the situation of wrapping these primitive values in tensors/other containers would be analogous to boxing/unboxing in Java and C# (Java version, C# version).
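The boxing analogy can be sketched in Python, assuming a hypothetical Box object type that stands in for an Object-derived wrapper (it is not an existing TVM class). A raw POD value must be boxed before it can go into a container that only accepts Objects, and unboxed when handed to a callee that expects the raw value.

```python
from dataclasses import dataclass
from typing import Any, List


class Object:
    """Stand-in for a TVM runtime Object (reference type)."""


@dataclass(frozen=True)
class Box(Object):
    value: Any  # the wrapped POD value (int, float, bool, or str)


class ObjectTuple:
    """Container that, like a runtime ADT tuple, only holds Objects."""

    def __init__(self, fields: List[Object]) -> None:
        for f in fields:
            if not isinstance(f, Object):
                raise TypeError(f"{f!r} is not an Object; box it first")
        self.fields = list(fields)


def box(value: Any) -> Box:
    return Box(value)


def unbox(obj: Object) -> Any:
    if not isinstance(obj, Box):
        raise TypeError("expected a boxed POD value")
    return obj.value
```

For example, ObjectTuple([box(2), box("tanh")]) is accepted, while ObjectTuple([2]) raises TypeError, which mirrors why POD values cannot go directly into today's runtime containers.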

Thanks everyone for the discussions at the community meeting today, and thanks @slyubomirsky and @psrivas2 for proposing alternative plans and summarizing the tradeoffs!

Introducing a new type indeed needs careful consideration. One problem with using 0-rank tensors to represent POD values is that tensors are device-specific, while these POD values always live on the host. PyTorch and other libraries make this distinction between host-only values and device-specific tensors.

To reduce the complexity of writing passes, we can indeed restrict the Relax-level operators (for example, comparison, addition, multiplication, and other operators) to take only tensors, and use PrimValues only for interacting with low-level libraries, as @psrivas2 suggested. The common case for our compiler is tensor arithmetic that needs to be optimized, and we want to keep that common case simple so that optimization passes do not need to worry about more general cases. For PrimValue, being able to represent calls into libraries that take PrimValues as arguments is sufficient, and there is no strong need to rewrite these parts during compilation.

Follow-up question: Are PrimValues mutable? E.g., can a PackedFunc mutate them in place? I would assume (and hope) not. This would mean we have a distinction between what is passed by value and what is passed by reference, just like a lot of other languages (this might be useful for users).

Thank you all for the great discussion!

For @MasterJH5574,

So I’m curious about the LibTorch support on the cases where an operator may take a list of PrimValues as input (e.g., softmax operator which takes a list of axes).

It is unclear whether they provide general support for lists, but they do seem to support lists of integers. Take a look at at::IntArrayRef in conv2d as an example.

at::Tensor &at::_slow_conv2d_forward_out(at::Tensor &output, const at::Tensor &self, const at::Tensor &weight, at::IntArrayRef kernel_size, const c10::optional<at::Tensor> &bias, at::IntArrayRef stride, at::IntArrayRef padding)

And secondly, if such cases do exist, what is the current way to call a LibTorch softmax, for example?

I believe this is what we need to figure out. I would like to have some support for tuples, but I'm not sure it has to be runtime.Tuple.
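One possible lowering, assuming the IR gains a tuple of integer PrimValues: flatten it into a contiguous 64-bit integer buffer plus a length, which is the (pointer, size) pair that a C++ at::IntArrayRef views. lower_int_tuple is a hypothetical helper, and Python's standard array module stands in for the contiguous buffer; this does not decide whether runtime.Tuple should be involved.

```python
from array import array
from typing import Sequence, Tuple


def lower_int_tuple(values: Sequence[object]) -> Tuple[array, int]:
    """Pack a tuple of integer PrimValues into a contiguous integer buffer.

    Returns the buffer and its length, mirroring the (data, size) pair
    that an at::IntArrayRef views on the C++ side.
    """
    buf = array("q")  # "q" = signed long long, 64-bit on common platforms
    for v in values:
        # Reject bool explicitly, since Python treats True/False as ints.
        if isinstance(v, bool) or not isinstance(v, int):
            raise TypeError(f"expected an integer PrimValue, got {v!r}")
        buf.append(v)
    return buf, len(buf)
```

For instance, lower_int_tuple((1, 1)) would produce the two-element buffer that a stride or padding argument of _slow_conv2d_forward_out expects.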

For @slyubomirsky,

Are PrimValues mutable? E.g., can a PackedFunc mutate them in place?

Based on my investigation so far, they seem to be immutable, since their main purpose is to pass configuration values.

For @YuchenJin,

One problem of using 0-rank tensors to represent POD values is that tensor is device-specific, while these POD values are always on host

Do you assume that the LibTorch fallback would always run on the host device?

Thanks for the discussions so far!

Are PrimValues mutable? E.g., can a PackedFunc mutate them in place?

If we restrict the Relax-level operators to take only device-aware tensors, and restrict PrimValues to external function calls, then PrimValues are immutable: they are constant values passed as arguments to packed functions.

Do you assume that Libtorch fallback would be always running on the host device?

The LibTorch fallback functions can run on devices such as GPUs, while the PrimValues are host-side arguments passed to the kernel functions. These arguments are passed by value: the host copies them into a dedicated memory buffer on the device, for example constant memory on a CUDA GPU.
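To illustrate pass-by-value, here is a sketch using Python's standard struct module: the host serializes POD arguments into a fresh flat byte buffer, analogous to a kernel parameter buffer. The layout string "<qd?" (little-endian int64, float64, bool, no padding) and the helpers pack_kernel_args/unpack_kernel_args are assumptions for illustration; real CUDA argument layout follows the kernel's C ABI and alignment rules, which this sketch ignores.

```python
import struct
from typing import Any, Tuple

# Hypothetical layout: int64, float64, bool; little-endian, no padding.
ARG_LAYOUT = "<qd?"


def pack_kernel_args(alpha: int, eps: float, flag: bool) -> bytes:
    # The host copies the *values* into a new buffer; the callee never
    # sees the caller's variables, so it cannot mutate them.
    return struct.pack(ARG_LAYOUT, alpha, eps, flag)


def unpack_kernel_args(buf: bytes) -> Tuple[Any, ...]:
    return struct.unpack(ARG_LAYOUT, buf)


alpha = 2
buf = pack_kernel_args(alpha, 1e-5, True)
alpha = 3  # changing the host variable afterwards does not affect buf
```

The buffer is 17 bytes (8 + 8 + 1) and still decodes to (2, 1e-05, True) after alpha is reassigned, which is the immutability property discussed above.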

For today's Open Dev meeting, I summarized our discussion so far.
I hope this helps.

Motivation

We need a way to pass POD values (e.g., int, string, boolean) to call_packed and call_dps_packed in order to call the LibTorch API.

Requirements

source: https://github.com/pytorch/pytorch/tree/master/aten/src/ATen/native#func

  • POD values: in particular, int, float, bool, and str
  • Scalar : supports binding to any numerical types from Python, including integral types, floating point types, and zero dimensional tensors.
  • Tensor[] : translates into a C++ argument of type ArrayRef<Tensor> (a.k.a. TensorList)
  • int[] : accepts an optional length specifier, e.g., int[2], which has no effect in C++ but extends our Python bindings to accept a bare number, which will be expanded into an appropriately sized list by repeating the number.
  • bool[N] (where N is 1-4)
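A sketch of the int[N] expansion rule quoted above, assuming a hypothetical binding-side helper expand_int_list: a bare integer is repeated N times, while a sequence is passed through unchanged. Rejecting a bare integer when no length specifier is given follows from the quoted text (the specifier is what enables bare numbers), but the exact error behavior is an assumption.

```python
from typing import List, Optional, Sequence, Union


def expand_int_list(arg: Union[int, Sequence[int]], size: Optional[int]) -> List[int]:
    """Normalize an int[N] argument the way the quoted Python bindings describe.

    A bare number is repeated `size` times; a sequence is passed through.
    """
    if isinstance(arg, int):
        if size is None:
            raise ValueError("a bare number needs a length specifier, e.g. int[2]")
        return [arg] * size
    return list(arg)
```

For example, a kernel_size declared as int[2] could be given as 3 and would become [3, 3] before reaching C++, where the declared length has no effect.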

Options

  • Conceivable but not considered
    • Attribute
      • Attributes are a compile-time concept: they should be consumed and disappear during compilation.
      • call_packed and call_dps_packed are runtime constructs
  • O1: N-rank Tensor (NDArray)
    • Property
      • Can be expressed with constant node in the AST
      • Tensor is device-specific
    • Pros:
      • Existing support for POD values, int[], and bool[N]
    • Cons:
      • Strings need special handling: NDArray cannot represent strings
      • Inefficient for CUDA-like APIs, where kernels have to be launched from the host side
        • At runtime, values must be copied device→host, then host→device for the kernel launch
  • O2: Introduce PrimValue Expr
    • Property
      • Will have PrimType
        • Not a subtype of relax.ObjectType: at runtime, a PrimValue maps to TVMPODValue_ rather than to a TVM runtime Object
          • A Relax VM register can store any TVMRetValue, which includes TVMPODValue_ and std::string.
      • Lives on the host side and is passed to kernel functions by value: at runtime, it is transferred to a dedicated memory buffer on the device (e.g., constant memory on a CUDA GPU).
    • Pros:
      • Host-side value
      • Might be a good fit for expressing the Torch Scalar type
    • Cons:
      • Need to define the interaction with TupleType
      • May cause confusion about its relationship with O1
      • Large strings or lists might incur noticeable overhead

Points of discussion from yesterday's Relax community meeting:

  • PrimValues might add considerable complexity to the language, since it would be yet another kind of value. For invoking functions/operators alone, there would be tensors, PrimValues, and attributes.
  • If we stick to a convention of using PrimValues only for third-party low-level libraries, this complexity may be manageable
  • Using zero-rank tensors for constants poses some practical problems, because NDArrays must be associated with specific devices: in certain cases, such as with a GPU, we might be forced to move values to and from the device
    • An alternative would be exposing the device to Relax and allowing for heterogeneous computations to be described in Relax, but this would be a very complex new feature.
    • We could also have special-case handling for these low-level libraries to avoid this copying
    • PrimValues are always on the host device, so they get around this issue
  • Alternatively, we could try to simplify matters by, e.g., getting rid of attributes and using PrimValues for that purpose
    • It would be important to ensure that we could still use those values at compile time for type checking and shape rules
    • Rather than using attributes in an ad hoc way in type checking, we could instead have schema validation for operators and be able to systematically validate the use of PrimValues and tensors. This would be similar to PyTorch's operator dispatch system.
  • One issue with PrimValues is that they are not TVM Objects and so cannot be placed inside a runtime ADT container
    • We could modify the runtime containers to take TVMRetValues, which would be able to hold these PrimValues.
    • Alternatively, we could box the PrimValues in either a new or an existing object type
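The first alternative (letting runtime containers hold TVMRetValue-like entries instead of only Objects) can be modeled as a container of tagged values. This is a hypothetical sketch: the TypeCode enum only loosely mirrors the idea of TVM's argument type codes and does not use its real names or numbering, and RetValueTuple is not an existing class.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Any, List


class TypeCode(Enum):
    # Loosely modeled on the idea of TVM type codes (not the real values).
    INT = "int"
    FLOAT = "float"
    BOOL = "bool"
    STR = "str"
    OBJECT = "object"


@dataclass(frozen=True)
class RetValue:
    """A tagged union: `code` says how to interpret `value`."""
    code: TypeCode
    value: Any


def to_ret_value(x: Any) -> RetValue:
    if isinstance(x, bool):  # check bool before int
        return RetValue(TypeCode.BOOL, x)
    if isinstance(x, int):
        return RetValue(TypeCode.INT, x)
    if isinstance(x, float):
        return RetValue(TypeCode.FLOAT, x)
    if isinstance(x, str):
        return RetValue(TypeCode.STR, x)
    return RetValue(TypeCode.OBJECT, x)


class RetValueTuple:
    """Container whose fields may be PODs or Objects, with no boxing."""

    def __init__(self, *items: Any) -> None:
        self.fields: List[RetValue] = [to_ret_value(x) for x in items]
```

The tradeoff versus boxing: containers become heterogeneous, so every consumer must dispatch on the tag, whereas boxing keeps containers uniform at the cost of an extra allocation and indirection per POD value.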

Some possible directions:

  • Big adjustment to existing operator implementation approach
    • Make PrimValues/PrimTypes first-class citizens and able to be contained in tuples, etc
    • Phase out attributes in favor of PrimValues
    • Don’t try to use Relay operator strategies: Relax has the ability to call into TOPI directly
  • Alternative: Require boxing for PrimValues to allow them to appear inside other data structures
  • Tension with introducing PrimValues: Relax is a domain-specific language, but we are adding many constructs that greatly increase its generality, which may create great difficulties for pass writers
    • One response is that PrimValues are mainly meant for interacting with low-level libraries, so we could avoid permitting PrimValues to interact with the higher-level constructs in the language (e.g., by requiring wrapping them in tensor values)
    • On the other hand, the more general use of PrimValues to replace current uses for attributes could lead to greater flexibility and convenience (allowing computed values to be used for things that are now required to be compile-time attributes). The type system can express tensors of unknown rank or shape, so the language would have the flexibility to deal with these issues.

We will continue discussions because we have not reached consensus on how to scope PrimValues in the language. Additionally, we noted that this is a good time to consider revisions to the operator system because most Relax operators have not yet been implemented. Similarly, if we wish to tackle the issue of NDArrays' being associated with specific devices, this is a good time for it.

Personal opinion: If we aim to replace attributes with PrimValues, we would probably also want to build default values for arguments into the language. Replacing attributes would, unfortunately, be a laborious undertaking.

Thank you, @slyubomirsky, for the great summary!

Based on our discussion yesterday, @YuchenJin and I defined the following action items:

  • A0: Introduce PrimValue and PrimType, restricting their usage to call_dps_packed as a first step.
  • A1: Define interaction between PrimValues and other values such as Tensor and Tuple
  • A2: Potentially unify attributes with PrimValue and introduce operator schema.
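For A2, here is a minimal sketch of schema-based validation, assuming a hypothetical OpSchema that records, for each parameter, whether it is a tensor or a PrimValue kind, plus an optional default. It is loosely inspired by the PyTorch schema idea mentioned earlier but does not parse PyTorch's actual schema syntax. The defaults also address the point that replacing attributes would require default argument values.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple

_MISSING = object()  # sentinel: parameter has no default


@dataclass(frozen=True)
class Param:
    name: str
    kind: str  # "Tensor", "int", "float", "bool", or "str"
    default: Any = _MISSING


@dataclass(frozen=True)
class OpSchema:
    name: str
    params: Tuple[Param, ...]


class Tensor:
    """Stand-in for a Relax tensor value."""


_PY_TYPES = {"Tensor": Tensor, "int": int, "float": float, "bool": bool, "str": str}


def bind_args(schema: OpSchema, args: List[Any]) -> Dict[str, Any]:
    """Check positional args against the schema and fill in defaults."""
    if len(args) > len(schema.params):
        raise TypeError(f"{schema.name}: too many arguments")
    bound: Dict[str, Any] = {}
    for i, p in enumerate(schema.params):
        if i < len(args):
            value = args[i]
        elif p.default is not _MISSING:
            value = p.default
        else:
            raise TypeError(f"{schema.name}: missing argument '{p.name}'")
        expected = _PY_TYPES[p.kind]
        # bool is a subclass of int in Python, so reject it for "int" explicitly.
        if not isinstance(value, expected) or (p.kind == "int" and isinstance(value, bool)):
            raise TypeError(f"{schema.name}: '{p.name}' expects {p.kind}")
        bound[p.name] = value
    return bound


batch_norm = OpSchema("relax.nn.batch_norm", (
    Param("data", "Tensor"), Param("gamma", "Tensor"), Param("beta", "Tensor"),
    Param("moving_mean", "Tensor"), Param("moving_var", "Tensor"),
    Param("axis", "int", 1), Param("epsilon", "float", 1e-5),
))
```

With such a schema, type checking and shape rules could look up axis and epsilon by name, whether they arrived as explicit PrimValue arguments or as defaults.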

Yesterday, I believe we reached consensus on A0, leaving A1 and A2 as topics for future discussion.
Since LibTorch integration requires A0 and part of A1 (in particular, Tuple), I'll start working on A0, and I'd like to continue our discussion of A1 and A2.
Thank you all for the fruitful discussion!