|
|
|
@ -150,12 +150,7 @@ class StackKernel : public framework::OpKernel<T> {
|
|
|
|
|
int total_num = pre * n * post;
|
|
|
|
|
|
|
|
|
|
auto &dev_ctx = ctx.template device_context<DeviceContext>();
|
|
|
|
|
constexpr auto kMaxThreshold = 16;
|
|
|
|
|
if (std::is_same<DeviceContext, platform::CPUDeviceContext>::value ||
|
|
|
|
|
n > kMaxThreshold) {
|
|
|
|
|
#ifdef __NVCC__
|
|
|
|
|
VLOG(10) << "Stack more than " << kMaxThreshold
|
|
|
|
|
<< " tensors on GPU may be slow.";
|
|
|
|
|
thrust::device_vector<const T *> device_x_vec(x_datas);
|
|
|
|
|
auto x_data_arr = device_x_vec.data().get();
|
|
|
|
|
#else
|
|
|
|
@ -168,14 +163,6 @@ class StackKernel : public framework::OpKernel<T> {
|
|
|
|
|
dev_ctx.Wait();
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
#ifdef __NVCC__
|
|
|
|
|
else { // NOLINT
|
|
|
|
|
framework::Array<const T *, kMaxThreshold> x_data_arr;
|
|
|
|
|
for (int i = 0; i < n; ++i) x_data_arr[i] = x_datas[i];
|
|
|
|
|
StackFunctorForRange(dev_ctx, x_data_arr, y_data, total_num, n, post);
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class StackOpGrad : public framework::OperatorWithKernel {
|
|
|
|
@ -244,34 +231,19 @@ class StackGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
int post = total_num / (n * pre);
|
|
|
|
|
|
|
|
|
|
auto &dev_ctx = ctx.template device_context<DeviceContext>();
|
|
|
|
|
constexpr auto kMaxThreshold = 16;
|
|
|
|
|
if (std::is_same<DeviceContext, platform::CPUDeviceContext>::value ||
|
|
|
|
|
n > kMaxThreshold) {
|
|
|
|
|
#ifdef __NVCC__
|
|
|
|
|
VLOG(10) << "Stack more than " << kMaxThreshold
|
|
|
|
|
<< " tensors on GPU may be slow.";
|
|
|
|
|
thrust::device_vector<T *> device_dx_vec(dx_datas);
|
|
|
|
|
auto dx_data_arr = device_dx_vec.data().get();
|
|
|
|
|
#else
|
|
|
|
|
auto dx_data_arr = dx_datas.data();
|
|
|
|
|
#endif
|
|
|
|
|
StackGradFunctorForRange(dev_ctx, dx_data_arr, dy_data, total_num, n,
|
|
|
|
|
post);
|
|
|
|
|
StackGradFunctorForRange(dev_ctx, dx_data_arr, dy_data, total_num, n, post);
|
|
|
|
|
#ifdef __NVCC__
|
|
|
|
|
// Wait() must be called because device_dx_vec may be destructed before
|
|
|
|
|
// kernel ends
|
|
|
|
|
dev_ctx.Wait();
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
#ifdef __NVCC__
|
|
|
|
|
else { // NOLINT
|
|
|
|
|
framework::Array<T *, kMaxThreshold> dx_data_arr;
|
|
|
|
|
for (int i = 0; i < n; ++i) dx_data_arr[i] = dx_datas[i];
|
|
|
|
|
StackGradFunctorForRange(dev_ctx, dx_data_arr, dy_data, total_num, n,
|
|
|
|
|
post);
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
|