Merge pull request #15027 from shippingwang/shufflechannel
Add Shuffle Channel Operatorinference-pre-release-gpu
commit
88bd7e1a61
@ -0,0 +1,113 @@
|
||||
/*Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/shuffle_channel_op.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
class ShuffleChannelOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
PADDLE_ENFORCE(ctx->HasInput("X"),
|
||||
"Input(X) of ShuffleChannelOp should not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
||||
"Output(Out) of ShuffleChannelOp should not be null.");
|
||||
|
||||
auto input_dims = ctx->GetInputDim("X");
|
||||
PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW.");
|
||||
|
||||
ctx->SetOutputDim("Out", input_dims);
|
||||
}
|
||||
|
||||
protected:
|
||||
framework::OpKernelType GetExpectedKernelType(
|
||||
const framework::ExecutionContext& ctx) const override {
|
||||
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
|
||||
ctx.device_context());
|
||||
}
|
||||
};
|
||||
|
||||
class ShuffleChannelOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
void Make() override {
|
||||
AddInput("X",
|
||||
"(Tensor, default Tensor<float>), "
|
||||
"the input feature data of ShuffleChannelOp, the layout is NCHW.");
|
||||
AddOutput("Out",
|
||||
"(Tensor, default Tensor<float>), the output of "
|
||||
"ShuffleChannelOp. The layout is NCHW.");
|
||||
AddAttr<int>("group", "the number of groups.")
|
||||
.SetDefault(1)
|
||||
.AddCustomChecker([](const int& group) {
|
||||
PADDLE_ENFORCE_GE(group, 1, "group should be larger than 0.");
|
||||
});
|
||||
|
||||
AddComment(R"DOC(
|
||||
Shuffle Channel operator
|
||||
This opearator shuffles the channels of input x.
|
||||
It divide the input channels in each group into several subgroups,
|
||||
and obtain a new order by selecting element from every subgroup one by one.
|
||||
|
||||
Shuffle channel operation makes it possible to build more powerful structures
|
||||
with multiple group convolutional layers.
|
||||
please get more information from the following paper:
|
||||
https://arxiv.org/pdf/1707.01083.pdf
|
||||
)DOC");
|
||||
}
|
||||
};
|
||||
|
||||
class ShuffleChannelGradOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
|
||||
"Input(Out@Grad) should not be null");
|
||||
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
|
||||
"Output(X@Grad) should not be null");
|
||||
|
||||
auto input_dims = ctx->GetInputDim("X");
|
||||
PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW.");
|
||||
|
||||
ctx->SetOutputDim(framework::GradVarName("X"), input_dims);
|
||||
}
|
||||
|
||||
protected:
|
||||
framework::OpKernelType GetExpectedKernelType(
|
||||
const framework::ExecutionContext& ctx) const override {
|
||||
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
|
||||
ctx.device_context());
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OPERATOR(shuffle_channel, ops::ShuffleChannelOp,
|
||||
ops::ShuffleChannelOpMaker,
|
||||
paddle::framework::DefaultGradOpDescMaker<true>);
|
||||
|
||||
REGISTER_OPERATOR(shuffle_channel_grad, ops::ShuffleChannelGradOp);
|
||||
|
||||
REGISTER_OP_CPU_KERNEL(
|
||||
shuffle_channel,
|
||||
ops::ShuffleChannelOpKernel<paddle::platform::CPUDeviceContext, float>,
|
||||
ops::ShuffleChannelOpKernel<paddle::platform::CPUDeviceContext, double>);
|
||||
|
||||
REGISTER_OP_CPU_KERNEL(
|
||||
shuffle_channel_grad,
|
||||
ops::ShuffleChannelGradOpKernel<paddle::platform::CPUDeviceContext, float>,
|
||||
ops::ShuffleChannelGradOpKernel<paddle::platform::CPUDeviceContext,
|
||||
double>);
|
@ -0,0 +1,125 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/shuffle_channel_op.h"
|
||||
#include "paddle/fluid/platform/cuda_primitives.h"
|
||||
#include "paddle/fluid/platform/gpu_info.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using Tensor = framework::Tensor;
|
||||
static constexpr int kNumCUDAThreads = 512;
|
||||
static constexpr int kNumMaximumNumBlocks = 4096;
|
||||
|
||||
static inline int NumBlocks(const int N) {
|
||||
return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
|
||||
kNumMaximumNumBlocks);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void ShuffleChannel(const int nthreads, const int feature_map_size,
|
||||
T* output, const T* input, int group_row,
|
||||
int group_column, int len) {
|
||||
int index = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int offset = blockDim.x * gridDim.x;
|
||||
for (size_t ii = index; ii < nthreads; ii += offset) {
|
||||
const int n = index / group_row / group_column / len;
|
||||
const int i = (index / group_column / len) % group_row;
|
||||
const int j = index / len % group_column;
|
||||
const int k = index - (n * feature_map_size + (i * group_column + j) * len);
|
||||
T* p_o = output + n * feature_map_size + (j * group_row + i) * len;
|
||||
p_o[k] = input[index];
|
||||
}
|
||||
}
|
||||
template <typename DeviceContext, typename T>
|
||||
class ShuffleChannelOpCUDAKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||
auto* input = ctx.Input<framework::Tensor>("X");
|
||||
auto* output = ctx.Output<framework::Tensor>("Out");
|
||||
int group = ctx.Attr<int>("group");
|
||||
|
||||
auto input_dims = input->dims();
|
||||
auto num = input_dims[0];
|
||||
auto channel = input_dims[1];
|
||||
auto height = input_dims[2];
|
||||
auto weight = input_dims[3];
|
||||
|
||||
auto feature_map_size = channel * height * weight;
|
||||
auto sp_sz = height * weight;
|
||||
int group_row = group;
|
||||
int group_column = channel / group_row;
|
||||
// count is the product of NCHW same as numel()
|
||||
int count = num * group_column * group_row * sp_sz;
|
||||
|
||||
int blocks = NumBlocks(output->numel());
|
||||
int threads = kNumCUDAThreads;
|
||||
|
||||
const T* input_data = input->data<T>();
|
||||
T* output_data = output->mutable_data<T>(ctx.GetPlace());
|
||||
|
||||
ShuffleChannel<
|
||||
T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
|
||||
count, feature_map_size, output_data, input_data, group_row,
|
||||
group_column, sp_sz);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
class ShuffleChannelGradOpCUDAKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||
auto* input = ctx.Input<framework::Tensor>("X");
|
||||
int group = ctx.Attr<int>("group");
|
||||
|
||||
auto input_dims = input->dims();
|
||||
auto num = input_dims[0];
|
||||
auto channel = input_dims[1];
|
||||
auto height = input_dims[2];
|
||||
auto weight = input_dims[3];
|
||||
auto feature_map_size = channel * height * weight;
|
||||
auto sp_sz = height * weight;
|
||||
|
||||
int group_row = group;
|
||||
int group_column = channel / group_row;
|
||||
auto* output_grad =
|
||||
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
|
||||
auto* input_grad =
|
||||
ctx.Output<framework::Tensor>(framework::GradVarName("X"));
|
||||
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
|
||||
const T* output_grad_data = output_grad->data<T>();
|
||||
|
||||
int blocks = NumBlocks(output_grad->numel());
|
||||
int threads = kNumCUDAThreads;
|
||||
int count = num * group_column * group_row * sp_sz;
|
||||
|
||||
ShuffleChannel<
|
||||
T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
|
||||
count, feature_map_size, input_grad_data, output_grad_data, group_row,
|
||||
group_column, sp_sz);
|
||||
}
|
||||
};
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OP_CUDA_KERNEL(
|
||||
shuffle_channel,
|
||||
ops::ShuffleChannelOpCUDAKernel<paddle::platform::CUDADeviceContext, float>,
|
||||
ops::ShuffleChannelOpCUDAKernel<paddle::platform::CUDADeviceContext,
|
||||
double>);
|
||||
REGISTER_OP_CUDA_KERNEL(
|
||||
shuffle_channel_grad,
|
||||
ops::ShuffleChannelGradOpCUDAKernel<paddle::platform::CUDADeviceContext,
|
||||
float>,
|
||||
ops::ShuffleChannelGradOpCUDAKernel<paddle::platform::CUDADeviceContext,
|
||||
double>);
|
@ -0,0 +1,95 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/operators/math/math_function.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
class ShuffleChannelOpKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||
auto* input = ctx.Input<framework::Tensor>("X");
|
||||
auto* output = ctx.Output<framework::Tensor>("Out");
|
||||
int group = ctx.Attr<int>("group");
|
||||
|
||||
auto input_dims = input->dims();
|
||||
auto num = input_dims[0];
|
||||
auto channel = input_dims[1];
|
||||
auto height = input_dims[2];
|
||||
auto weight = input_dims[3];
|
||||
|
||||
auto feature_map_size = channel * height * weight;
|
||||
auto sp_sz = height * weight;
|
||||
int group_row = group;
|
||||
int group_column = channel / group_row;
|
||||
|
||||
const T* input_data = input->data<T>();
|
||||
T* output_data = output->mutable_data<T>(ctx.GetPlace());
|
||||
for (int n = 0; n < num; ++n) {
|
||||
for (int i = 0; i < group_row; ++i) {
|
||||
for (int j = 0; j < group_column; ++j) {
|
||||
const T* p_i = input_data + n * feature_map_size +
|
||||
(i * group_column + j) * sp_sz;
|
||||
T* p_o =
|
||||
output_data + n * feature_map_size + (j * group_row + i) * sp_sz;
|
||||
memcpy(p_o, p_i, sizeof(int) * sp_sz);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
class ShuffleChannelGradOpKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||
auto* input = ctx.Input<framework::Tensor>("X");
|
||||
int group = ctx.Attr<int>("group");
|
||||
|
||||
auto input_dims = input->dims();
|
||||
auto num = input_dims[0];
|
||||
auto channel = input_dims[1];
|
||||
auto height = input_dims[2];
|
||||
auto weight = input_dims[3];
|
||||
auto feature_map_size = channel * height * weight;
|
||||
auto sp_sz = height * weight;
|
||||
|
||||
int group_row = group;
|
||||
int group_column = channel / group_row;
|
||||
|
||||
auto* output_grad =
|
||||
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
|
||||
auto* input_grad =
|
||||
ctx.Output<framework::Tensor>(framework::GradVarName("X"));
|
||||
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
|
||||
const T* output_grad_data = output_grad->data<T>();
|
||||
for (int n = 0; n < num; ++n) {
|
||||
for (int i = 0; i < group_row; ++i) {
|
||||
for (int j = 0; j < group_column; ++j) {
|
||||
const T* p_i = output_grad_data + n * feature_map_size +
|
||||
(i * group_column + j) * sp_sz;
|
||||
T* p_o = input_grad_data + n * feature_map_size +
|
||||
(j * group_row + i) * sp_sz;
|
||||
memcpy(p_o, p_i, sizeof(int) * sp_sz);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,52 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
import sys
|
||||
import math
|
||||
from op_test import OpTest
|
||||
import paddle.fluid.core as core
|
||||
|
||||
|
||||
class TestShuffleChannelOp(OpTest):
|
||||
def setUp(self):
|
||||
self.op_type = "shuffle_channel"
|
||||
self.batch_size = 10
|
||||
self.input_channels = 16
|
||||
self.layer_h = 4
|
||||
self.layer_w = 4
|
||||
self.group = 4
|
||||
self.x = np.random.random(
|
||||
(self.batch_size, self.input_channels, self.layer_h,
|
||||
self.layer_w)).astype('float32')
|
||||
self.inputs = {'X': self.x}
|
||||
self.attrs = {'group': self.group}
|
||||
n, c, h, w = self.x.shape
|
||||
input_reshaped = np.reshape(self.x,
|
||||
(-1, self.group, c // self.group, h, w))
|
||||
input_transposed = np.transpose(input_reshaped, (0, 2, 1, 3, 4))
|
||||
self.outputs = {'Out': np.reshape(input_transposed, (-1, c, h, w))}
|
||||
|
||||
def test_check_output(self):
|
||||
self.check_output()
|
||||
|
||||
def test_check_grad(self):
|
||||
self.check_grad(['X'], 'Out')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue