unified interface for CPU and CUDA
lgarithm opened this issue · 3 comments
lgarithm commented
Provide a unified interface for CPU and CUDA.
lgarithm commented
WIP: stdml/stdtensor#65
lgarithm commented
example.hpp
#include <ttl/tensor>
namespace ttl
{
using internal::cuda_memory;
using internal::host_memory;
struct google_tpu;
struct graphcore_ipu;
namespace internal
{
template <typename R> class basic_allocator<R, graphcore_ipu>
{
public:
R *operator()(size_t count)
{ // FIXME: call ipu API
return basic_allocator<R, host_memory>()(count);
}
};
template <typename R> class basic_deallocator<R, graphcore_ipu>
{
public:
void operator()(R *data)
{ // FIXME: call ipu API
basic_deallocator<R, host_memory>()(data);
}
};
} // namespace internal
} // namespace ttl
// Short global aliases for the device tags used as the last template argument
// of ttl tensor types in this example.
using cpu = ttl::host_memory;
using gpu = ttl::cuda_memory;
// NOTE(review): graphcore_ipu and google_tpu are only forward-declared in the
// visible source; google_tpu has no allocator specialization here, so `tpu`
// presumably cannot back an owning tensor yet -- confirm before using it.
using ipu = ttl::graphcore_ipu;
using tpu = ttl::google_tpu;
// Stateless function object that shows how one call operator can be overloaded
// on the device tag of its rank-1 tensor arguments.
//
// All overloads take an output reference `y` and a read-only view `x` whose
// element type R and device tag must match. Function templates cannot be
// partially specialized, so the device-specific variants below are separate
// overloads; overload resolution prefers them over the generic one when the
// device tag matches.
class example
{
  public:
    // Generic overload: accepts any device tag D.
    template <typename R, typename D>
    void operator()(const ttl::tensor_ref<R, 1, D> &y,
                    const ttl::tensor_view<R, 1, D> &x) const;

    // Overload for cpu (host memory) tensors. The device tag is spelled out
    // explicitly so the declaration matches the out-of-class definition and the
    // gpu overload, instead of relying on the default template argument.
    template <typename R>
    void operator()(const ttl::tensor_ref<R, 1, cpu> &y,
                    const ttl::tensor_view<R, 1, cpu> &x) const;

    // Overload for gpu (CUDA memory) tensors.
    template <typename R>
    void operator()(const ttl::tensor_ref<R, 1, gpu> &y,
                    const ttl::tensor_view<R, 1, gpu> &x) const;
};
example.cpp
#include <cstdio>
#include <example.hpp>
#include <ttl/tensor>
// Generic call operator: handles every device tag D.
// The tensor arguments are not inspected; the body only reports which overload
// was selected, so the parameters are left unnamed.
template <typename R, typename D>
void example::operator()(const ttl::tensor_ref<R, 1, D> & /* y */,
                         const ttl::tensor_view<R, 1, D> & /* x */) const
{
    const char *const variant = "default";
    std::printf("%s :: %s\n", __func__, variant);
}
// Call operator for tensors tagged with cpu.
// The tensor arguments are not inspected; the body only reports which overload
// was selected, so the parameters are left unnamed.
template <typename R>
void example::operator()(const ttl::tensor_ref<R, 1, cpu> & /* y */,
                         const ttl::tensor_view<R, 1, cpu> & /* x */) const
{
    const char *const variant = "cpu";
    std::printf("%s :: %s\n", __func__, variant);
}
// Call operator for tensors tagged with gpu.
// The tensor arguments are not inspected; the body only reports which overload
// was selected, so the parameters are left unnamed.
template <typename R>
void example::operator()(const ttl::tensor_ref<R, 1, gpu> & /* y */,
                         const ttl::tensor_view<R, 1, gpu> & /* x */) const
{
    const char *const variant = "gpu";
    std::printf("%s :: %s\n", __func__, variant);
}
int main()
{
int n = 10;
{
ttl::tensor<int, 1> x(n);
ttl::tensor<int, 1> y(n);
example()(ref(y), view(x));
}
{
ttl::tensor<int, 1, ipu> x(n);
ttl::tensor<int, 1, ipu> y(n);
example()(ref(y), view(x));
}
return 0;
}