dmlc/mshadow

How can the embedding layer work under CUDA 6.5?

quanzongfeng opened this issue · 2 comments

In the embedding layer's backward function, when grad_out.shape_[0] > grad_out.shape_[1] (more precisely, whenever grad_out.shape_[0] >= grad_out.shape_[1] or grad_out.shape_[0] >= 512), it calls SortByKey, which is only implemented when CUDA_VERSION >= 7000. So under CUDA 6.5 it does not work: it just raises an error and aborts.
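For context, cuda.h defines CUDA_VERSION as 1000 * major + 10 * minor, so CUDA 6.5 is 6050 and CUDA 7.0 is 7000. The sketch below restates the guard in SortByKey with two hypothetical helpers (EncodeCudaVersion and SortByKeyAvailable are illustration names, not part of mshadow or the CUDA headers):

```cpp
// Hypothetical helpers (not part of mshadow or the CUDA headers).

// cuda.h defines CUDA_VERSION as 1000 * major + 10 * minor.
constexpr int EncodeCudaVersion(int major, int minor) {
  return 1000 * major + 10 * minor;
}

// Mirrors the preprocessor guard around the Thrust code in SortByKey.
constexpr bool SortByKeyAvailable(int cuda_version) {
  return cuda_version >= 7000;
}

static_assert(EncodeCudaVersion(6, 5) == 6050, "CUDA 6.5 is encoded as 6050");
static_assert(!SortByKeyAvailable(EncodeCudaVersion(6, 5)),
              "CUDA 6.5 falls into the LOG(FATAL) branch");
```

So a build against CUDA 6.5 compiles the #else branch, and any call that reaches SortByKey ends in LOG(FATAL) at runtime rather than failing at compile time.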

How can I make it work under CUDA 6.5? I modified the backward pass to always call
AddTakeGrad(grad_in, data, grad_out);
but the results are worse than on another machine with CUDA 7.0.
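One way to tell whether "worse" means a wrong gradient is to compare against a plain CPU reference. The sketch below assumes row-major float matrices held in std::vector; AddTakeGradReference is a hypothetical name, not the mshadow API. Each row i of grad_out is added into row data[i] of grad_in, so repeated indices sum their contributions. A correct GPU path, whether AddTakeGrad or the sort-based one, should agree with it up to floating-point rounding.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference for the embedding weight gradient.
// grad_in:  vocab_size x dim, row-major (accumulated into, not overwritten)
// data:     n token indices, each in [0, vocab_size)
// grad_out: n x dim, row-major
void AddTakeGradReference(std::vector<float>& grad_in, std::size_t dim,
                          const std::vector<int>& data,
                          const std::vector<float>& grad_out) {
  assert(grad_out.size() == data.size() * dim);
  for (std::size_t i = 0; i < data.size(); ++i) {
    const std::size_t row = static_cast<std::size_t>(data[i]);
    assert((row + 1) * dim <= grad_in.size());
    for (std::size_t j = 0; j < dim; ++j) {
      // Rows that share the same index accumulate their gradients.
      grad_in[row * dim + j] += grad_out[i * dim + j];
    }
  }
}
```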

So how should the code be modified to work under CUDA 6.5 when grad_out.shape_[0] > grad_out.shape_[1]?

The code is as follows:

// Small batches: accumulate the gradient directly.
if ((grad_out.shape_[0] < grad_out.shape_[1]) && (grad_out.shape_[0] < 512)) {
  AddTakeGrad(grad_in, data, grad_out);
} else {
  // Large batches: sort the indices first, then accumulate per group of equal indices.
  Tensor<xpu, 2, int> workspace =
    ctx.requested[embedding::kTempSpace].get_space_typed<xpu, 2, int>(
      mshadow::Shape2(2, data.shape_.Size()), s);
  Tensor<xpu, 1, int> sorted_data = workspace[0];     // token indices, to be sorted
  Tensor<xpu, 1, int> original_index = workspace[1];  // 0..n-1, permuted alongside
  sorted_data = tcast<int>(data);
  original_index = range<int>(0, data.shape_.Size());
  SortByKey(sorted_data, original_index, true);  // only implemented for CUDA >= 7.0
  AddTakeGradLargeBatch(grad_in, sorted_data, original_index, grad_out);
}
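In the else branch, sorted_data holds the indices in ascending order and original_index records where each one came from, so all contributions to the same embedding row become contiguous; the apparent purpose is to let each row be summed in one place instead of by many threads writing to it concurrently. The sketch below is a hypothetical CPU stand-in for that grouped accumulation (AddTakeGradSortedReference is an illustration name, not mshadow's AddTakeGradLargeBatch kernel) and assumes the keys are already sorted.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical CPU stand-in for the sort-based gradient path.
// sorted_keys[k]    : embedding row index, sorted ascending
// original_index[k] : position of that index in the unsorted input
// grad_out          : n x dim, row-major, in the ORIGINAL order
void AddTakeGradSortedReference(std::vector<float>& grad_in, std::size_t dim,
                                const std::vector<int>& sorted_keys,
                                const std::vector<int>& original_index,
                                const std::vector<float>& grad_out) {
  assert(sorted_keys.size() == original_index.size());
  std::size_t k = 0;
  while (k < sorted_keys.size()) {
    // Find the run [k, end) of entries that share the same key.
    std::size_t end = k;
    while (end < sorted_keys.size() && sorted_keys[end] == sorted_keys[k]) ++end;
    const std::size_t row = static_cast<std::size_t>(sorted_keys[k]);
    for (std::size_t j = 0; j < dim; ++j) {
      float sum = 0.0f;  // one accumulator per (row, column): no write conflicts
      for (std::size_t r = k; r < end; ++r) {
        sum += grad_out[static_cast<std::size_t>(original_index[r]) * dim + j];
      }
      grad_in[row * dim + j] += sum;
    }
    k = end;
  }
}
```

Because the grouped sum adds the same values in a different order than direct scattered accumulation, the two paths should differ only by floating-point rounding.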

template<typename KDType, typename VDType>
inline void SortByKey(Tensor<gpu, 1, KDType> keys, Tensor<gpu, 1, VDType> values,
                      bool is_ascend) {
  CHECK_EQ(keys.CheckContiguous(), true);
  CHECK_EQ(values.CheckContiguous(), true);
#if CUDA_VERSION >= 7000
  // Stable sort of keys on the tensor's stream; values are permuted alongside.
  cudaStream_t stream = Stream<gpu>::GetStream(keys.stream_);
  thrust::device_ptr<KDType> key_iter = thrust::device_pointer_cast(keys.dptr_);
  thrust::device_ptr<VDType> value_iter = thrust::device_pointer_cast(values.dptr_);
  if (is_ascend) {
    thrust::stable_sort_by_key(
      thrust::cuda::par.on(stream),
      key_iter, key_iter + keys.size(0), value_iter, thrust::less<KDType>());
  } else {
    thrust::stable_sort_by_key(
      thrust::cuda::par.on(stream),
      key_iter, key_iter + keys.size(0), value_iter, thrust::greater<KDType>());
  }
#else
  LOG(FATAL) << "SortByKey is only supported for CUDA version >=7.0!";
#endif
}
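To check what the GPU sort should produce, here is a host-side analogue under stated assumptions: SortByKeyHost is a hypothetical name, it works on std::vector rather than mshadow tensors, and it is not a drop-in replacement for the GPU function. It stably sorts the keys and applies the same permutation to the values; in the embedding backward the keys are the token indices (sorted_data) and the values start as 0..n-1 (original_index).

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Hypothetical host-side analogue of SortByKey: stably sorts `keys` and
// applies the same permutation to `values`, ascending or descending.
template <typename K, typename V>
void SortByKeyHost(std::vector<K>& keys, std::vector<V>& values, bool is_ascend) {
  assert(keys.size() == values.size());
  std::vector<std::size_t> perm(keys.size());
  std::iota(perm.begin(), perm.end(), std::size_t{0});
  // Equal keys keep their original relative order in both directions.
  std::stable_sort(perm.begin(), perm.end(), [&](std::size_t a, std::size_t b) {
    return is_ascend ? keys[a] < keys[b] : keys[b] < keys[a];
  });
  std::vector<K> sorted_keys;
  std::vector<V> sorted_values;
  sorted_keys.reserve(keys.size());
  sorted_values.reserve(values.size());
  for (std::size_t p : perm) {
    sorted_keys.push_back(keys[p]);
    sorted_values.push_back(values[p]);
  }
  keys.swap(sorted_keys);
  values.swap(sorted_values);
}
```

For keys {2, 0, 2, 1} and values {0, 1, 2, 3}, an ascending sort yields keys {0, 1, 2, 2} and values {1, 3, 0, 2}: the two entries with key 2 keep their input order.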

MXNet requires CUDA 7.5.
Try updating your CUDA SDK; you don't need root access for this.

szha commented

This code base has been donated to the Apache MXNet project per #373, and repo is deprecated. Future development and issue tracking should continue in Apache MXNet.