|
|
|
@ -149,11 +149,20 @@ class StackKernel : public framework::OpKernel<T> {
|
|
|
|
|
for (auto i = axis; i < dim.size(); ++i) post *= dim[i];
|
|
|
|
|
|
|
|
|
|
#ifdef __NVCC__
|
|
|
|
|
int total_num = pre * n * post;
|
|
|
|
|
auto &dev_ctx = ctx.template device_context<DeviceContext>();
|
|
|
|
|
|
|
|
|
|
thrust::device_vector<const T *> device_x_vec(x_datas);
|
|
|
|
|
auto x_data_arr = device_x_vec.data().get();
|
|
|
|
|
|
|
|
|
|
StackFunctorForRange(dev_ctx, x_data_arr, y_data, total_num, n, post);
|
|
|
|
|
|
|
|
|
|
// Wait() must be called because device_x_vec may be destructed before
|
|
|
|
|
// kernel ends
|
|
|
|
|
dev_ctx.Wait();
|
|
|
|
|
#else
|
|
|
|
|
auto x_data_arr = x_datas.data();
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
size_t x_offset = 0;
|
|
|
|
|
size_t y_offset = 0;
|
|
|
|
|
for (int i = 0; i < pre; i++) {
|
|
|
|
@ -164,10 +173,6 @@ class StackKernel : public framework::OpKernel<T> {
|
|
|
|
|
}
|
|
|
|
|
x_offset += post;
|
|
|
|
|
}
|
|
|
|
|
#ifdef __NVCC__
|
|
|
|
|
// Wait() must be called because device_x_vec may be destructed before
|
|
|
|
|
// kernel ends
|
|
|
|
|
dev_ctx.Wait();
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|