JuliaGPU/XLA.jl

Possible TPU crash

staticfloat opened this issue · 0 comments

Base.@pure calc_rounds(L; exponent=3) = ceil(UInt32, exponent * log(L) / log(typemax(UInt32)))
function shuffle(x::XRTArray)
    rounds = calc_rounds(length(x))
    round_idx = XRTArray(rounds)
    while round_idx > XRTArray(0)
        keys = rand(XRTArray{UInt32}, length(x))
        x = sort(x, keys=keys)
        # Commenting this out crashes the TPU
        #round_idx -= XRTArray(1)
    end
    return x
end