Add batch_norm and layer_norm XPU kernels (#27818)

my_2.0rc
hong19860320 4 years ago committed by GitHub
parent ddcd1b5381
commit c90d35564b

@@ -0,0 +1,167 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/batch_norm_op.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using DDim = framework::DDim;
template <typename DeviceContext, typename T>
class BatchNormXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto epsilon = ctx.Attr<float>("epsilon");
const auto momentum = ctx.Attr<float>("momentum");
const auto is_test = ctx.Attr<bool>("is_test");
const auto use_global_stats = ctx.Attr<bool>("use_global_stats");
const auto trainable_stats = ctx.Attr<bool>("trainable_statistics");
bool test_mode = is_test && (!trainable_stats);
bool global_stats = test_mode || use_global_stats;
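// In test mode, or when use_global_stats is set, the kernel consumes the
// running mean/variance inputs instead of computing batch statistics.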
const auto& data_layout_str = ctx.Attr<std::string>("data_layout");
const auto data_layout = framework::StringToDataLayout(data_layout_str);
PADDLE_ENFORCE_EQ(data_layout, DataLayout::kNCHW,
platform::errors::InvalidArgument(
"The 'data_layout' attribute must be NCHW. But "
"recevived 'data_layout' is [%s].",
data_layout_str));
const auto* x = ctx.Input<Tensor>("X");
const auto& x_dims = x->dims();
PADDLE_ENFORCE_EQ(x_dims.size(), 4,
platform::errors::InvalidArgument(
"The input tensor X's dimension must equal to 4. But "
"received X's shape = [%s], X's dimension = [%d].",
x_dims, x_dims.size()));
const int N = x_dims[0];
const int C = x_dims[1];
const int H = x_dims[2];
const int W = x_dims[3];
const auto* scale = ctx.Input<Tensor>("Scale");
const auto* bias = ctx.Input<Tensor>("Bias");
const auto* x_data = x->data<T>();
const auto* scale_data = scale->data<T>();
const auto* bias_data = bias->data<T>();
auto* y = ctx.Output<Tensor>("Y");
auto* y_data = y->mutable_data<T>(ctx.GetPlace());
auto& dev_ctx = ctx.template device_context<DeviceContext>();
if (!global_stats) {
auto* mean_out = ctx.Output<Tensor>("MeanOut");
auto* variance_out = ctx.Output<Tensor>("VarianceOut");
auto* saved_mean = ctx.Output<Tensor>("SavedMean");
auto* saved_variance = ctx.Output<Tensor>("SavedVariance");
mean_out->mutable_data<T>(ctx.GetPlace());
variance_out->mutable_data<T>(ctx.GetPlace());
saved_mean->mutable_data<T>(ctx.GetPlace());
saved_variance->mutable_data<T>(ctx.GetPlace());
auto* mean_out_data = mean_out->data<T>();
auto* variance_out_data = variance_out->data<T>();
auto* saved_mean_data = saved_mean->data<T>();
auto* saved_variance_data = saved_variance->data<T>();
int r = xpu::batch_norm_train_forward(
dev_ctx.x_context(), epsilon, momentum, N, C, H, W, x_data, y_data,
scale_data, bias_data, mean_out_data, variance_out_data,
saved_mean_data, saved_variance_data);
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External("XPU API(batch_norm_train_forward) return "
"wrong value[%d], please check whether "
"Baidu Kunlun Card is properly installed.",
r));
} else {
const auto* mean = ctx.Input<Tensor>("Mean");
const auto* variance = ctx.Input<Tensor>("Variance");
const auto* mean_data = mean->data<T>();
const auto* variance_data = variance->data<T>();
int r = xpu::batch_norm_infer_forward(
dev_ctx.x_context(), epsilon, N, C, H, W, x_data, y_data, scale_data,
bias_data, mean_data, variance_data);
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External("XPU API(batch_norm_infer_forward) return "
"wrong value[%d], please check whether "
"Baidu Kunlun Card is properly installed.",
r));
}
}
};
template <typename DeviceContext, typename T>
class BatchNormGradXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* x = ctx.Input<Tensor>("X");
const auto* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
const auto* scale = ctx.Input<Tensor>("Scale");
const auto* saved_mean = ctx.Input<Tensor>("SavedMean");
// SavedVariance was already inverted in the forward operator.
const auto* saved_inv_variance = ctx.Input<Tensor>("SavedVariance");
const auto& data_layout_str = ctx.Attr<std::string>("data_layout");
const auto data_layout = framework::StringToDataLayout(data_layout_str);
PADDLE_ENFORCE_EQ(data_layout, DataLayout::kNCHW,
platform::errors::InvalidArgument(
"The 'data_layout' attribute must be NCHW. But "
"recevived 'data_layout' is [%s].",
data_layout_str));
const auto& x_dims = x->dims();
PADDLE_ENFORCE_EQ(x_dims.size(), 4,
platform::errors::InvalidArgument(
"The input tensor X's dimension must equal to 4. But "
"received X's shape = [%s], X's dimension = [%d].",
x_dims, x_dims.size()));
const int N = x_dims[0];
const int C = x_dims[1];
const int H = x_dims[2];
const int W = x_dims[3];
const auto* x_data = x->data<T>();
const auto* dy_data = dy->data<T>();
const auto* scale_data = scale->data<T>();
const auto* saved_mean_data = saved_mean->data<T>();
const auto* saved_inv_variance_data = saved_inv_variance->data<T>();
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dscale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto* dbias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
auto* dscale_data = dscale->mutable_data<T>(ctx.GetPlace());
auto* dbias_data = dbias->mutable_data<T>(ctx.GetPlace());
auto& dev_ctx = ctx.template device_context<DeviceContext>();
int r = xpu::batch_norm_backward(dev_ctx.x_context(), N, C, H, W, x_data,
dy_data, scale_data, saved_mean_data,
saved_inv_variance_data, dx_data,
dscale_data, dbias_data);
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External("XPU API(batch_norm_infer_forward) return "
"wrong value[%d], please check whether "
"Baidu Kunlun Card is properly installed.",
r));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(
batch_norm,
ops::BatchNormXPUKernel<paddle::platform::XPUDeviceContext, float>);
REGISTER_OP_XPU_KERNEL(
batch_norm_grad,
ops::BatchNormGradXPUKernel<paddle::platform::XPUDeviceContext, float>);
#endif // PADDLE_WITH_XPU
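
The registered forward kernel can be exercised end-to-end from Python. Below
is a minimal sketch (not part of this commit) using the Paddle 2.0
static-graph API, assuming a build with PADDLE_WITH_XPU and an attached
Kunlun card:

import numpy as np
import paddle

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    # NCHW input, the only layout the XPU kernel accepts.
    x = paddle.static.data(name='x', shape=[2, 3, 4, 5], dtype='float32')
    y = paddle.static.nn.batch_norm(x, is_test=True)

exe = paddle.static.Executor(paddle.XPUPlace(0))
exe.run(startup_prog)
x_np = np.random.rand(2, 3, 4, 5).astype('float32')
out, = exe.run(main_prog, feed={'x': x_np}, fetch_list=[y])
print(out.shape)  # (2, 3, 4, 5)

With is_test=True the op takes the batch_norm_infer_forward path; dropping
it exercises batch_norm_train_forward instead.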

@@ -0,0 +1,114 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/layer_norm_op.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using DDim = framework::DDim;
template <typename DeviceContext, typename T>
class LayerNormXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
const auto epsilon = ctx.Attr<float>("epsilon");
const auto* x = ctx.Input<Tensor>("X");
const auto& x_dims = x->dims();
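// flatten_to_2d views x as [left, right]: dims before begin_norm_axis form
// the rows, and the trailing dims are normalized within each row.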
auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
int left = static_cast<int>(matrix_dim[0]);
int right = static_cast<int>(matrix_dim[1]);
const auto* scale = ctx.Input<Tensor>("Scale");
const auto* bias = ctx.Input<Tensor>("Bias");
auto* y = ctx.Output<Tensor>("Y");
auto* mean = ctx.Output<Tensor>("Mean");
auto* variance = ctx.Output<Tensor>("Variance");
const auto* x_data = x->data<T>();
const auto* scale_data = (scale == nullptr ? nullptr : scale->data<T>());
const auto* bias_data = (bias == nullptr ? nullptr : bias->data<T>());
auto* y_data = y->mutable_data<T>(ctx.GetPlace());
auto* mean_data = mean->mutable_data<T>(ctx.GetPlace());
auto* variance_data = variance->mutable_data<T>(ctx.GetPlace());
auto& dev_ctx = ctx.template device_context<DeviceContext>();
int r = xpu::layer_norm(dev_ctx.x_context(), left, right, x_data, y_data,
scale_data, bias_data, epsilon, mean_data,
variance_data, false);
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External("XPU API(layer_norm) return wrong "
"value[%d], please check whether Baidu "
"Kunlun Card is properly installed.",
r));
}
};
template <typename DeviceContext, typename T>
class LayerNormGradXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
const auto epsilon = ctx.Attr<float>("epsilon");
const auto* x = ctx.Input<Tensor>("X");
const auto& x_dims = x->dims();
auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
int left = static_cast<int>(matrix_dim[0]);
int right = static_cast<int>(matrix_dim[1]);
const auto* mean = ctx.Input<Tensor>("Mean");
const auto* variance = ctx.Input<Tensor>("Variance");
const auto* scale = ctx.Input<Tensor>("Scale");
const auto* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dscale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto* dbias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
const auto* x_data = x->data<T>();
const auto* dy_data = dy->data<T>();
const auto* mean_data = mean->data<T>();
const auto* variance_data = variance->data<T>();
const auto* scale_data = (scale == nullptr ? nullptr : scale->data<T>());
auto* dscale_data =
(dscale == nullptr ? nullptr : dscale->mutable_data<T>(ctx.GetPlace()));
auto* dbias_data =
(dbias == nullptr ? nullptr : dbias->mutable_data<T>(ctx.GetPlace()));
auto* dx_data =
(dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()));
auto& dev_ctx = ctx.template device_context<DeviceContext>();
int r = xpu::layer_norm_backward(
dev_ctx.x_context(), left, right, x_data, scale_data, variance_data,
mean_data, dy_data, dx_data, dscale_data, dbias_data, epsilon);
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External("XPU API(layer_norm_backward) return wrong "
"value[%d], please check whether Baidu "
"Kunlun Card is properly installed.",
r));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(
layer_norm,
ops::LayerNormXPUKernel<paddle::platform::XPUDeviceContext, float>);
REGISTER_OP_XPU_KERNEL(
layer_norm_grad,
ops::LayerNormGradXPUKernel<paddle::platform::XPUDeviceContext, float>);
#endif // PADDLE_WITH_XPU
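
As with batch_norm above, a minimal sketch (not part of this commit) that
drives the layer_norm kernel through the static-graph API, under the same
XPU-build assumptions:

import numpy as np
import paddle

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name='x', shape=[2, 3, 4, 5], dtype='float32')
    # begin_norm_axis=1 flattens x to [2, 60]; each row is normalized.
    y = paddle.static.nn.layer_norm(x, begin_norm_axis=1)

exe = paddle.static.Executor(paddle.XPUPlace(0))
exe.run(startup_prog)
x_np = np.random.rand(2, 3, 4, 5).astype('float32')
out, = exe.run(main_prog, feed={'x': x_np}, fetch_list=[y])
print(out.mean(), out.std())  # roughly 0 and 1 after normalization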

@@ -0,0 +1,111 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
from functools import reduce
from operator import mul

import numpy as np
import paddle

sys.path.append("..")
from op_test import OpTest

paddle.enable_static()

def ref_layer_norm(x, scale, bias, epsilon, begin_norm_axis=1):
x_shape = x.shape
left = reduce(mul, x_shape[0:begin_norm_axis], 1)
right = reduce(mul, x_shape[begin_norm_axis:len(x_shape)], 1)
x.shape = [left, right]
mean = np.mean(x, axis=1)
variance = np.var(x, axis=1) + epsilon
y = np.divide((x - mean.reshape([left, 1])),
(np.sqrt(variance)).reshape([left, 1]))
if scale is not None:
y = scale.reshape([1, right]) * y
if bias is not None:
y = y + bias.reshape([1, right])
x.shape, y.shape = x_shape, x_shape
return y, mean, variance
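# Shape intuition for the reference above: with x of shape [2, 3, 4, 5] and
# begin_norm_axis=1, the flattened view is [2, 60], so mean and variance
# have 2 elements each while y keeps x's original shape.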
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestXPULayerNormOp(OpTest):
def setUp(self):
self.op_type = "layer_norm"
self.dtype = np.float32
self.shape = [2, 3, 4, 5]
self.epsilon = 1e-05
self.begin_norm_axis = 1
self.set_attrs()
right = reduce(mul, self.shape[self.begin_norm_axis:len(self.shape)], 1)
np.random.seed(10)
x_np = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
scale_np = np.random.uniform(0.1, 1, [right]).astype(self.dtype)
bias_np = np.random.uniform(0.1, 1, [right]).astype(self.dtype)
ref_y_np, ref_mean_np, ref_variance_np = ref_layer_norm(
x_np, scale_np, bias_np, self.epsilon, self.begin_norm_axis)
self.inputs = {'X': x_np, 'Scale': scale_np, 'Bias': bias_np}
self.outputs = {
'Y': ref_y_np,
'Mean': ref_mean_np,
'Variance': ref_variance_np
}
self.attrs = {'begin_norm_axis': self.begin_norm_axis, 'use_xpu': True}
def set_attrs(self):
pass
def test_check_output(self):
self.check_output_with_place(paddle.XPUPlace(0), atol=1e-4)
def test_check_grad(self):
self.check_grad_with_place(
paddle.XPUPlace(0), ['X'], 'Y', max_relative_error=0.02)
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestXPULayerNormOpAxis2(TestXPULayerNormOp):
def set_attrs(self):
self.begin_norm_axis = 2
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestXPULayerNormOpAxis3(TestXPULayerNormOp):
def set_attrs(self):
self.begin_norm_axis = 3
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestXPULayerNormOp2D(TestXPULayerNormOp):
def set_attrs(self):
self.shape = [10, 12]
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestXPULayerNormOp3D(TestXPULayerNormOp):
def set_attrs(self):
self.shape = [4, 5, 6]
if __name__ == "__main__":
unittest.main()
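
check_grad_with_place compares the operator's analytic gradient against a
numeric estimate. Here is a rough sketch of that idea using ref_layer_norm
above (numeric_grad is a hypothetical helper, not the real OpTest
machinery):

def numeric_grad(f, x, eps=1e-3):
    # Central-difference estimate of d f(x) / d x for a scalar-valued f.
    grad = np.zeros_like(x)
    it = np.nditer(x, flags=['multi_index'])
    while not it.finished:
        idx = it.multi_index
        orig = x[idx]
        x[idx] = orig + eps
        f_hi = f(x)
        x[idx] = orig - eps
        f_lo = f(x)
        x[idx] = orig
        grad[idx] = (f_hi - f_lo) / (2.0 * eps)
        it.iternext()
    return grad

np.random.seed(10)
x = np.random.uniform(0.1, 1, [4, 5]).astype(np.float64)
scale = np.ones([5], dtype=np.float64)
bias = np.zeros([5], dtype=np.float64)
loss = lambda t: ref_layer_norm(t, scale, bias, 1e-5)[0].sum()
print(numeric_grad(loss, x).shape)  # (4, 5)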