commit
7d680be5a3
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,187 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
Indicesou may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "cub/cub.cuh"
|
||||
#include "paddle/fluid/framework/data_layout.h"
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/platform/cuda_primitives.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
template <typename T, framework::DataLayout layout, bool HasBias>
|
||||
__global__ void KeAffineChannelCUDA(const T* x, const T* scale, const T* bias,
|
||||
const int C, const int HxW, const int num,
|
||||
T* y) {
|
||||
int gid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int stride = blockDim.x * gridDim.x;
|
||||
for (int i = gid; i < num; i += stride) {
|
||||
const int c = layout == framework::DataLayout::kNCHW ? i / HxW % C : i % C;
|
||||
if (HasBias) {
|
||||
y[i] = scale[c] * x[i] + bias[c];
|
||||
} else {
|
||||
y[i] = scale[c] * x[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
class AffineChannelCUDAKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||
auto* x = ctx.Input<framework::Tensor>("X");
|
||||
auto* scale = ctx.Input<framework::Tensor>("Scale");
|
||||
auto* bias = ctx.Input<framework::Tensor>("Bias");
|
||||
|
||||
auto* y = ctx.Output<framework::Tensor>("Out");
|
||||
y->mutable_data<T>(ctx.GetPlace());
|
||||
|
||||
const framework::DataLayout layout =
|
||||
framework::StringToDataLayout(ctx.Attr<std::string>("data_layout"));
|
||||
auto& dev_ctx = ctx.template device_context<DeviceContext>();
|
||||
|
||||
auto dims = x->dims();
|
||||
const int num = x->numel();
|
||||
int N = dims[0];
|
||||
int C = layout == framework::DataLayout::kNCHW ? dims[1]
|
||||
: dims[dims.size() - 1];
|
||||
int HxW = num / N / C;
|
||||
|
||||
const T* x_d = x->data<T>();
|
||||
const T* scale_d = scale->data<T>();
|
||||
const T* bias_d = bias->data<T>();
|
||||
T* y_d = y->data<T>();
|
||||
|
||||
int block = 1024;
|
||||
int grid = (num + block - 1) / block;
|
||||
if (layout == framework::DataLayout::kNCHW) {
|
||||
KeAffineChannelCUDA<T, framework::DataLayout::kNCHW,
|
||||
true><<<grid, block, 0, dev_ctx.stream()>>>(
|
||||
x_d, scale_d, bias_d, C, HxW, num, y_d);
|
||||
} else {
|
||||
KeAffineChannelCUDA<T, framework::DataLayout::kNHWC,
|
||||
true><<<grid, block, 0, dev_ctx.stream()>>>(
|
||||
x_d, scale_d, bias_d, C, HxW, num, y_d);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int BlockDim, framework::DataLayout layout>
|
||||
__global__ void AffineChannelScaleBiasGradientCUDAKernel(
|
||||
const T* dy, const T* x, const int N, const int C, const int HxW, T* dscale,
|
||||
T* dbias) {
|
||||
const int outer_size = C;
|
||||
const int inner_size = N * HxW;
|
||||
typedef cub::BlockReduce<T, BlockDim> BlockReduce;
|
||||
__shared__ typename BlockReduce::TempStorage ds_storage;
|
||||
__shared__ typename BlockReduce::TempStorage db_storage;
|
||||
|
||||
for (int i = blockIdx.x; i < outer_size; i += gridDim.x) {
|
||||
T ds_sum = 0;
|
||||
T db_sum = 0;
|
||||
for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
|
||||
const int index = layout == framework::DataLayout::kNCHW
|
||||
? (j / HxW * C + i) * HxW + j % HxW
|
||||
: j * outer_size + i;
|
||||
ds_sum += dy[index] * x[index];
|
||||
db_sum += dy[index];
|
||||
}
|
||||
ds_sum = BlockReduce(ds_storage).Reduce(ds_sum, cub::Sum());
|
||||
db_sum = BlockReduce(db_storage).Reduce(db_sum, cub::Sum());
|
||||
if (threadIdx.x == 0) {
|
||||
dscale[i] = ds_sum;
|
||||
dbias[i] = db_sum;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
class AffineChannelGradCUDAKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||
auto* x = ctx.Input<framework::Tensor>("X");
|
||||
auto* scale = ctx.Input<framework::Tensor>("Scale");
|
||||
auto* bias = ctx.Input<framework::Tensor>("Bias");
|
||||
auto* dy = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
|
||||
|
||||
auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
|
||||
auto* dscale =
|
||||
ctx.Output<framework::Tensor>(framework::GradVarName("Scale"));
|
||||
auto* dbias = ctx.Output<framework::Tensor>(framework::GradVarName("Bias"));
|
||||
|
||||
const framework::DataLayout layout =
|
||||
framework::StringToDataLayout(ctx.Attr<std::string>("data_layout"));
|
||||
auto& dev_ctx = ctx.template device_context<DeviceContext>();
|
||||
|
||||
auto dims = x->dims();
|
||||
const int num = x->numel();
|
||||
int N = dims[0];
|
||||
int C = layout == framework::DataLayout::kNCHW ? dims[1]
|
||||
: dims[dims.size() - 1];
|
||||
int HxW = num / N / C;
|
||||
|
||||
const T* x_d = x->data<T>();
|
||||
const T* dy_d = dy->data<T>();
|
||||
const T* s_d = scale->data<T>();
|
||||
|
||||
T* dx_d = dx ? dx->mutable_data<T>(ctx.GetPlace()) : nullptr;
|
||||
T* ds_d = dscale ? dscale->mutable_data<T>(ctx.GetPlace()) : nullptr;
|
||||
T* db_d = dbias ? dbias->mutable_data<T>(ctx.GetPlace()) : nullptr;
|
||||
|
||||
const int block = 1024;
|
||||
int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
|
||||
const int max_blocks = std::max(max_threads / block, 1);
|
||||
int grid1 = (num + block - 1) / block;
|
||||
int grid2 = std::min(C, max_blocks);
|
||||
if (layout == framework::DataLayout::kNCHW) {
|
||||
if (dx) {
|
||||
KeAffineChannelCUDA<T, framework::DataLayout::kNCHW,
|
||||
false><<<grid1, block, 0, dev_ctx.stream()>>>(
|
||||
dy_d, s_d, nullptr, C, HxW, num, dx_d);
|
||||
}
|
||||
if (dscale && dbias) {
|
||||
AffineChannelScaleBiasGradientCUDAKernel<
|
||||
T, block, framework::DataLayout::kNCHW><<<grid2, block, 0,
|
||||
dev_ctx.stream()>>>(
|
||||
dy_d, x_d, N, C, HxW, ds_d, db_d);
|
||||
}
|
||||
} else {
|
||||
if (dx) {
|
||||
KeAffineChannelCUDA<T, framework::DataLayout::kNCHW,
|
||||
false><<<grid1, block, 0, dev_ctx.stream()>>>(
|
||||
dy_d, s_d, nullptr, C, HxW, num, dx_d);
|
||||
}
|
||||
if (dscale && dbias) {
|
||||
AffineChannelScaleBiasGradientCUDAKernel<
|
||||
T, block, framework::DataLayout::kNHWC><<<grid2, block, 0,
|
||||
dev_ctx.stream()>>>(
|
||||
dy_d, x_d, N, C, HxW, ds_d, db_d);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
using CUDA = paddle::platform::CUDADeviceContext;
|
||||
|
||||
REGISTER_OP_CUDA_KERNEL(affine_channel,
|
||||
ops::AffineChannelCUDAKernel<CUDA, float>,
|
||||
ops::AffineChannelCUDAKernel<CUDA, double>);
|
||||
REGISTER_OP_CUDA_KERNEL(affine_channel_grad,
|
||||
ops::AffineChannelGradCUDAKernel<CUDA, float>,
|
||||
ops::AffineChannelGradCUDAKernel<CUDA, double>);
|
@ -0,0 +1,106 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
from op_test import OpTest
|
||||
import paddle.fluid.core as core
|
||||
|
||||
|
||||
def affine_channel(x, scale, bias, layout):
|
||||
C = x.shape[1] if layout == 'NCHW' else x.shape[-1]
|
||||
if len(x.shape) == 4:
|
||||
new_shape = (1, C, 1, 1) if layout == 'NCHW' else (1, 1, 1, C)
|
||||
else:
|
||||
new_shape = (1, C)
|
||||
scale = scale.reshape(new_shape)
|
||||
bias = bias.reshape(new_shape)
|
||||
return x * scale + bias
|
||||
|
||||
|
||||
class TestAffineChannelOp(OpTest):
|
||||
def setUp(self):
|
||||
self.op_type = "affine_channel"
|
||||
self.init_test_case()
|
||||
|
||||
x = np.random.random(self.shape).astype("float32")
|
||||
scale = np.random.random(self.C).astype("float32")
|
||||
bias = np.random.random(self.C).astype("float32")
|
||||
|
||||
y = affine_channel(x, scale, bias, self.layout)
|
||||
|
||||
self.inputs = {'X': x, 'Scale': scale, 'Bias': bias}
|
||||
self.attrs = {'data_layout': self.layout}
|
||||
self.outputs = {'Out': y}
|
||||
|
||||
def test_check_output(self):
|
||||
self.check_output()
|
||||
|
||||
def test_check_grad(self):
|
||||
self.check_grad(['X', 'Scale', 'Bias'], 'Out')
|
||||
|
||||
def test_check_grad_stopgrad_dx(self):
|
||||
self.check_grad(['Scale', 'Bias'], 'Out', no_grad_set=set('X'))
|
||||
|
||||
def test_check_grad_stopgrad_dscale_dbias(self):
|
||||
self.check_grad(['X'], 'Out', no_grad_set=set(['Scale', 'Bias']))
|
||||
|
||||
def init_test_case(self):
|
||||
self.shape = [2, 32, 14, 14]
|
||||
self.C = 32
|
||||
self.layout = 'NCHW'
|
||||
|
||||
|
||||
class TestAffineChannelNHWC(TestAffineChannelOp):
|
||||
def init_test_case(self):
|
||||
self.shape = [2, 14, 14, 32]
|
||||
self.C = 32
|
||||
self.layout = 'NHWC'
|
||||
|
||||
|
||||
class TestAffineChannel2D(TestAffineChannelOp):
|
||||
def init_test_case(self):
|
||||
self.shape = [16, 64]
|
||||
self.C = 64
|
||||
self.layout = 'NCHW'
|
||||
|
||||
|
||||
class TestAffineChannelNCHWLargeShape(TestAffineChannelOp):
|
||||
def init_test_case(self):
|
||||
self.shape = [64, 128, 112, 112]
|
||||
self.C = 128
|
||||
self.layout = 'NCHW'
|
||||
|
||||
# since the gradient check is very slow in large shape, so skip check_grad
|
||||
def test_check_grad(self):
|
||||
pass
|
||||
|
||||
def test_check_grad_stopgrad_dx(self):
|
||||
pass
|
||||
|
||||
def test_check_grad_stopgrad_dscale_dbias(self):
|
||||
pass
|
||||
|
||||
|
||||
class TestAffineChannelNCHWLargeShape(TestAffineChannelNCHWLargeShape):
|
||||
def init_test_case(self):
|
||||
self.shape = [64, 112, 112, 512]
|
||||
self.C = 512
|
||||
self.layout = 'NHWC'
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue