Encountered an edge_padding_low error while testing
Closed this issue · 5 comments
sh0416 commented
tests/test_flash.py:200: in func
o, bwd = jax.vjp(fwd,q,k,v)
../anaconda3/envs/flash-jax/lib/python3.9/site-packages/jax/_src/api.py:2169: in vjp
return _vjp(
../anaconda3/envs/flash-jax/lib/python3.9/site-packages/jax/_src/api.py:2178: in _vjp
out_primal, out_vjp = ad.vjp(flat_fun, primals_flat)
../anaconda3/envs/flash-jax/lib/python3.9/site-packages/jax/_src/interpreters/ad.py:143: in vjp
out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
../anaconda3/envs/flash-jax/lib/python3.9/site-packages/jax/_src/interpreters/ad.py:132: in linearize
jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals)
../anaconda3/envs/flash-jax/lib/python3.9/site-packages/jax/_src/profiler.py:335: in wrapper
return func(*args, **kwargs)
../anaconda3/envs/flash-jax/lib/python3.9/site-packages/jax/_src/interpreters/partial_eval.py:774: in trace_to_jaxpr_nounits
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
../anaconda3/envs/flash-jax/lib/python3.9/site-packages/jax/_src/linear_util.py:192: in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
tests/test_flash.py:199: in fwd
return mha(q,k,v, is_causal=bool(causal), window_size=window_size)
src/flash_attn_jax/flash.py:237: in flash_mha
o = _flash_mha_vjp(q,k,v,dict(softmax_scale=softmax_scale, is_causal=is_causal, window_size=window_size))
src/flash_attn_jax/flash.py:213: in fwd
out, lse = _flash_mha_fwd(q,k,v, **config)
src/flash_attn_jax/flash.py:53: in _flash_mha_fwd
return tuple(_flash_mha_fwd_p.bind(q,k,v, softmax_scale=softmax_scale, is_causal=is_causal, window_size=window_size))
src/flash_attn_jax/flash.py:135: in mha_fwd_batch
out, lse = _flash_mha_fwd_p.bind(q, k, v, **kwargs)
../anaconda3/envs/flash-jax/lib/python3.9/site-packages/jax/experimental/custom_partitioning.py:500: in _custom_partitioning_lowering_rule
return mlir.lower_fun(
src/flash_attn_jax/flash_hlo.py:111: in _flash_mha_fwd_hlo_lowering
q_padded = mlir.hlo.PadOp(q,z,[0,0,0,0],[0,0,0,dpad],[0,0,0,0]).result
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <jaxlib.mlir.dialects._stablehlo_ops_gen.PadOp object at 0x7fbb843c3a40>, operand = <jaxlib.mlir._mlir_libs._mlir.ir.BlockArgument object at 0x7fbb40581e30>
padding_value = <jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7fbb483102b0>, edge_padding_low = [0, 0, 0, 0], edge_padding_high = [0, 0, 0, 5]
interior_padding = [0, 0, 0, 0]
def __init__(self, operand, padding_value, edge_padding_low, edge_padding_high, interior_padding, *, loc=None, ip=None):
operands = []
results = []
attributes = {}
regions = None
operands.append(_get_op_result_or_value(operand))
operands.append(_get_op_result_or_value(padding_value))
_ods_context = _ods_get_default_loc_context(loc)
attributes["edge_padding_low"] = (edge_padding_low if (
isinstance(edge_padding_low, _ods_ir.Attribute) or
not _ods_ir.AttrBuilder.contains('GenericDenseI64ArrayAttr')) else
_ods_ir.AttrBuilder.get('GenericDenseI64ArrayAttr')(edge_padding_low, context=_ods_context))
attributes["edge_padding_high"] = (edge_padding_high if (
isinstance(edge_padding_high, _ods_ir.Attribute) or
not _ods_ir.AttrBuilder.contains('GenericDenseI64ArrayAttr')) else
_ods_ir.AttrBuilder.get('GenericDenseI64ArrayAttr')(edge_padding_high, context=_ods_context))
attributes["interior_padding"] = (interior_padding if (
isinstance(interior_padding, _ods_ir.Attribute) or
not _ods_ir.AttrBuilder.contains('GenericDenseI64ArrayAttr')) else
_ods_ir.AttrBuilder.get('GenericDenseI64ArrayAttr')(interior_padding, context=_ods_context))
_ods_successors = None
> super().__init__(self.build_generic(attributes=attributes, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip))
E RuntimeError: Invalid attribute value for the key "edge_padding_low" when attempting to create the operation "stablehlo.pad" (Unable to cast Python instance to C++ type (#define PYBIND11_DETAILED_ERROR_MESSAGES or compile in debug mode for details))
../anaconda3/envs/flash-jax/lib/python3.9/site-packages/jaxlib/mlir/dialects/_stablehlo_ops_gen.py:3466: RuntimeError
I got this error while testing with pytest. Is there any simple solution to resolve this error?
sh0416 commented
I installed from source, FYI.
nshepperd commented
Huhh. What jax and jaxlib versions?
sh0416 commented
jax: 0.4.28
jaxlib: 0.4.28
numpy: 1.26.4
python: 3.9.19 (main, May 6 2024, 19:43:03) [GCC 11.2.0]
jax.devices (2 total, 2 local): [cuda(id=0) cuda(id=1)]
process_count: 1
platform: uname_result(system='Linux', node='gpusystem', release='5.15.0-100-generic', version='#110-Ubuntu SMP Wed Feb 7 13:27:48 UTC 2024', machine='x86_64')
This one. It seems that the problem is in jax, not your code.
nshepperd commented
Okie, I'll do some testing, should be easy to fix.
nshepperd commented
Tested this on 0.4.28; this should work now.