|
|
|
@ -58,6 +58,7 @@ inline void StridedNumelCopyWithAxis(const platform::DeviceContext& ctx,
|
|
|
|
|
int64_t before = dst_stride_numel[0] / dst_stride_numel[axis];
|
|
|
|
|
int64_t src_after = src_stride_numel[axis];
|
|
|
|
|
int64_t dst_after = dst_stride_numel[axis];
|
|
|
|
|
int64_t copy_size = std::min(src_after, dst_after);
|
|
|
|
|
auto place = ctx.GetPlace();
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(src_stride_numel.size(), dst_stride_numel.size(),
|
|
|
|
@ -82,14 +83,14 @@ inline void StridedNumelCopyWithAxis(const platform::DeviceContext& ctx,
|
|
|
|
|
if (platform::is_cpu_place(place)) {
|
|
|
|
|
auto& cpu_place = boost::get<platform::CPUPlace>(place);
|
|
|
|
|
memory::Copy(cpu_place, dst + i * dst_after, cpu_place,
|
|
|
|
|
src + i * src_after, sizeof(T) * src_after);
|
|
|
|
|
src + i * src_after, sizeof(T) * copy_size);
|
|
|
|
|
} else {
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
auto& gpu_place = boost::get<platform::CUDAPlace>(place);
|
|
|
|
|
auto& cuda_ctx =
|
|
|
|
|
reinterpret_cast<const platform::CUDADeviceContext&>(ctx);
|
|
|
|
|
memory::Copy(gpu_place, dst + i * dst_after, gpu_place,
|
|
|
|
|
src + i * src_after, sizeof(T) * src_after,
|
|
|
|
|
src + i * src_after, sizeof(T) * copy_size,
|
|
|
|
|
cuda_ctx.stream());
|
|
|
|
|
#else
|
|
|
|
|
PADDLE_THROW("Paddle is not compiled with GPU");
|
|
|
|
|