stdml/stdtensor

Thrust adaptor

Opened this issue · 1 comment

template <template <typename, ttl::rank_t, typename> class T, typename R,
          ttl::rank_t r, typename S>
thrust::device_ptr<R> begin(const T<R, r, S> &tensor)
{
    // Wrap the tensor's first element as a Thrust device pointer so the tensor
    // can be passed directly to Thrust algorithms as the start of a range.
    //
    // NOTE(review): the return type is a mutable device_ptr<R> even though the
    // tensor is taken by const reference, so constness is cast away here;
    // writing through the result mutates the tensor's storage.
    // Assumes tensor.data() points to device memory -- TODO confirm for every
    // tensor template this overload can match.
    return thrust::device_pointer_cast(const_cast<R *>(tensor.data()));
}

template <template <typename, ttl::rank_t, typename> class T, typename R,
          ttl::rank_t r, typename S>
thrust::device_ptr<R> end(const T<R, r, S> &tensor)
{
    // Counterpart of begin(): wraps tensor.data_end() (presumably one past the
    // last element -- TODO confirm against the tensor API) as a Thrust device
    // pointer, closing the [begin, end) range used by Thrust algorithms.
    //
    // NOTE(review): as in begin(), constness is cast away to produce a mutable
    // device_ptr<R> from a const tensor reference.
    return thrust::device_pointer_cast(const_cast<R *>(tensor.data_end()));
}
// Reverse direction of the begin()/end() adaptors above: construct a
// thrust::device_vector<R> with n elements, then unwrap its device_ptr
// (y.data()) into a plain R* via thrust::raw_pointer_cast -- presumably so a
// tensor view can be built over Thrust-managed storage.
// NOTE(review): fragment copied from the issue; R and n are assumed to be
// defined by the surrounding code. p does not own the memory and is only valid
// while y is alive and not resized.
thrust::device_vector<R> y(n);

R* p = thrust::raw_pointer_cast(y.data());