|
|
|
@ -81,6 +81,7 @@ class ShuffleChannelGradOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
auto* input = ctx.Input<framework::Tensor>("X");
|
|
|
|
|
int group = ctx.Attr<int>("group");
|
|
|
|
|
|
|
|
|
|
auto input_dims = input->dims();
|
|
|
|
|
auto num = input_dims[0];
|
|
|
|
|
auto channel = input_dims[1];
|
|
|
|
@ -101,6 +102,7 @@ class ShuffleChannelGradOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
int blocks = NumBlocks(output_grad->numel());
|
|
|
|
|
int threads = kNumCUDAThreads;
|
|
|
|
|
int count = num * group_column * group_row * sp_sz;
|
|
|
|
|
|
|
|
|
|
ShuffleChannel<
|
|
|
|
|
T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
|
|
|
|
|
count, feature_map_size, input_grad_data, output_grad_data, group_row,
|
|
|
|
|