Add batch_norm and layer_norm XPU kernels (#27818)
parent ddcd1b5381
commit c90d35564b
@@ -0,0 +1,167 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef PADDLE_WITH_XPU

#include "paddle/fluid/operators/batch_norm_op.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using DDim = framework::DDim;

template <typename DeviceContext, typename T>
class BatchNormXPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto epsilon = ctx.Attr<float>("epsilon");
    const auto momentum = ctx.Attr<float>("momentum");
    const auto is_test = ctx.Attr<bool>("is_test");
    const auto use_global_stats = ctx.Attr<bool>("use_global_stats");
    const auto trainable_stats = ctx.Attr<bool>("trainable_statistics");
    bool test_mode = is_test && (!trainable_stats);
    bool global_stats = test_mode || use_global_stats;
    const auto& data_layout_str = ctx.Attr<std::string>("data_layout");
    const auto data_layout = framework::StringToDataLayout(data_layout_str);
    PADDLE_ENFORCE_EQ(data_layout, DataLayout::kNCHW,
                      platform::errors::InvalidArgument(
                          "The 'data_layout' attribute must be NCHW. But "
                          "received 'data_layout' is [%s].",
                          data_layout_str));
    const auto* x = ctx.Input<Tensor>("X");
    const auto& x_dims = x->dims();
    PADDLE_ENFORCE_EQ(x_dims.size(), 4,
                      platform::errors::InvalidArgument(
                          "The input tensor X's dimension must equal to 4. But "
                          "received X's shape = [%s], X's dimension = [%d].",
                          x_dims, x_dims.size()));
    const int N = x_dims[0];
    const int C = x_dims[1];
    const int H = x_dims[2];
    const int W = x_dims[3];
    const auto* scale = ctx.Input<Tensor>("Scale");
    const auto* bias = ctx.Input<Tensor>("Bias");
    const auto* x_data = x->data<T>();
    const auto* scale_data = scale->data<T>();
    const auto* bias_data = bias->data<T>();
    auto* y = ctx.Output<Tensor>("Y");
    auto* y_data = y->mutable_data<T>(ctx.GetPlace());
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    if (!global_stats) {
      auto* mean_out = ctx.Output<Tensor>("MeanOut");
      auto* variance_out = ctx.Output<Tensor>("VarianceOut");
      auto* saved_mean = ctx.Output<Tensor>("SavedMean");
      auto* saved_variance = ctx.Output<Tensor>("SavedVariance");
      mean_out->mutable_data<T>(ctx.GetPlace());
      variance_out->mutable_data<T>(ctx.GetPlace());
      saved_mean->mutable_data<T>(ctx.GetPlace());
      saved_variance->mutable_data<T>(ctx.GetPlace());
      auto* mean_out_data = mean_out->data<T>();
      auto* variance_out_data = variance_out->data<T>();
      auto* saved_mean_data = saved_mean->data<T>();
      auto* saved_variance_data = saved_variance->data<T>();
      int r = xpu::batch_norm_train_forward(
          dev_ctx.x_context(), epsilon, momentum, N, C, H, W, x_data, y_data,
          scale_data, bias_data, mean_out_data, variance_out_data,
          saved_mean_data, saved_variance_data);
      PADDLE_ENFORCE_EQ(
          r, XPU_SUCCESS,
          platform::errors::External("XPU API(batch_norm_train_forward) return "
                                     "wrong value[%d], please check whether "
                                     "Baidu Kunlun Card is properly installed.",
                                     r));
    } else {
      const auto* mean = ctx.Input<Tensor>("Mean");
      const auto* variance = ctx.Input<Tensor>("Variance");
      const auto* mean_data = mean->data<T>();
      const auto* variance_data = variance->data<T>();
      int r = xpu::batch_norm_infer_forward(
          dev_ctx.x_context(), epsilon, N, C, H, W, x_data, y_data, scale_data,
          bias_data, mean_data, variance_data);
      PADDLE_ENFORCE_EQ(
          r, XPU_SUCCESS,
          platform::errors::External("XPU API(batch_norm_infer_forward) return "
                                     "wrong value[%d], please check whether "
                                     "Baidu Kunlun Card is properly installed.",
                                     r));
    }
  }
};

template <typename DeviceContext, typename T>
class BatchNormGradXPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto* x = ctx.Input<Tensor>("X");
    const auto* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
    const auto* scale = ctx.Input<Tensor>("Scale");
    const auto* saved_mean = ctx.Input<Tensor>("SavedMean");
    // SavedVariance has been inverted in the forward operator
    const auto* saved_inv_variance = ctx.Input<Tensor>("SavedVariance");
    const auto& data_layout_str = ctx.Attr<std::string>("data_layout");
    const auto data_layout = framework::StringToDataLayout(data_layout_str);
    PADDLE_ENFORCE_EQ(data_layout, DataLayout::kNCHW,
                      platform::errors::InvalidArgument(
                          "The 'data_layout' attribute must be NCHW. But "
                          "received 'data_layout' is [%s].",
                          data_layout_str));
    const auto& x_dims = x->dims();
    PADDLE_ENFORCE_EQ(x_dims.size(), 4,
                      platform::errors::InvalidArgument(
                          "The input tensor X's dimension must equal to 4. But "
                          "received X's shape = [%s], X's dimension = [%d].",
                          x_dims, x_dims.size()));
    const int N = x_dims[0];
    const int C = x_dims[1];
    const int H = x_dims[2];
    const int W = x_dims[3];
    const auto* x_data = x->data<T>();
    const auto* dy_data = dy->data<T>();
    const auto* scale_data = scale->data<T>();
    const auto* saved_mean_data = saved_mean->data<T>();
    const auto* saved_inv_variance_data = saved_inv_variance->data<T>();
    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* dscale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
    auto* dbias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
    auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
    auto* dscale_data = dscale->mutable_data<T>(ctx.GetPlace());
    auto* dbias_data = dbias->mutable_data<T>(ctx.GetPlace());
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    int r = xpu::batch_norm_backward(dev_ctx.x_context(), N, C, H, W, x_data,
                                     dy_data, scale_data, saved_mean_data,
                                     saved_inv_variance_data, dx_data,
                                     dscale_data, dbias_data);
    PADDLE_ENFORCE_EQ(
        r, XPU_SUCCESS,
        platform::errors::External("XPU API(batch_norm_backward) return "
                                   "wrong value[%d], please check whether "
                                   "Baidu Kunlun Card is properly installed.",
                                   r));
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP_XPU_KERNEL(
    batch_norm,
    ops::BatchNormXPUKernel<paddle::platform::XPUDeviceContext, float>);
REGISTER_OP_XPU_KERNEL(
    batch_norm_grad,
    ops::BatchNormGradXPUKernel<paddle::platform::XPUDeviceContext, float>);

#endif  // PADDLE_WITH_XPU
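For reference, the inference branch above hands the running statistics straight to xpu::batch_norm_infer_forward. The snippet below is a minimal NumPy sketch of the per-channel NCHW math that call is expected to reproduce; the helper name and the placement of epsilon inside the square root are assumptions based on the standard batch-norm formulation, not taken from the XPU library headers.

import numpy as np

def ref_batch_norm_infer(x, scale, bias, mean, variance, epsilon=1e-5):
    # x: [N, C, H, W]; scale/bias/mean/variance: [C] (per-channel statistics).
    # Broadcast the channel statistics over N, H, W and normalize.
    c = x.shape[1]
    mean = mean.reshape([1, c, 1, 1])
    var = variance.reshape([1, c, 1, 1])
    scale = scale.reshape([1, c, 1, 1])
    bias = bias.reshape([1, c, 1, 1])
    return scale * (x - mean) / np.sqrt(var + epsilon) + bias

# Example: random NCHW input with per-channel statistics.
x = np.random.uniform(0.1, 1, [2, 3, 4, 5]).astype(np.float32)
scale = np.ones([3], np.float32)
bias = np.zeros([3], np.float32)
mean = x.mean(axis=(0, 2, 3))
var = x.var(axis=(0, 2, 3))
y = ref_batch_norm_infer(x, scale, bias, mean, var)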
@@ -0,0 +1,114 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef PADDLE_WITH_XPU

#include "paddle/fluid/operators/layer_norm_op.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using DDim = framework::DDim;

template <typename DeviceContext, typename T>
class LayerNormXPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
    const auto epsilon = ctx.Attr<float>("epsilon");
    const auto* x = ctx.Input<Tensor>("X");
    const auto& x_dims = x->dims();
    auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
    int left = static_cast<int>(matrix_dim[0]);
    int right = static_cast<int>(matrix_dim[1]);
    const auto* scale = ctx.Input<Tensor>("Scale");
    const auto* bias = ctx.Input<Tensor>("Bias");
    auto* y = ctx.Output<Tensor>("Y");
    auto* mean = ctx.Output<Tensor>("Mean");
    auto* variance = ctx.Output<Tensor>("Variance");
    const auto* x_data = x->data<T>();
    const auto* scale_data = (scale == nullptr ? nullptr : scale->data<T>());
    const auto* bias_data = (bias == nullptr ? nullptr : bias->data<T>());
    auto* y_data = y->mutable_data<T>(ctx.GetPlace());
    auto* mean_data = mean->mutable_data<T>(ctx.GetPlace());
    auto* variance_data = variance->mutable_data<T>(ctx.GetPlace());
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    int r = xpu::layer_norm(dev_ctx.x_context(), left, right, x_data, y_data,
                            scale_data, bias_data, epsilon, mean_data,
                            variance_data, false);
    PADDLE_ENFORCE_EQ(
        r, XPU_SUCCESS,
        platform::errors::External("XPU API(layer_norm) return wrong "
                                   "value[%d], please check whether Baidu "
                                   "Kunlun Card is properly installed.",
                                   r));
  }
};

template <typename DeviceContext, typename T>
class LayerNormGradXPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
    const auto epsilon = ctx.Attr<float>("epsilon");
    const auto* x = ctx.Input<Tensor>("X");
    const auto& x_dims = x->dims();
    auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
    int left = static_cast<int>(matrix_dim[0]);
    int right = static_cast<int>(matrix_dim[1]);
    const auto* mean = ctx.Input<Tensor>("Mean");
    const auto* variance = ctx.Input<Tensor>("Variance");
    const auto* scale = ctx.Input<Tensor>("Scale");
    const auto* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* dscale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
    auto* dbias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
    const auto* x_data = x->data<T>();
    const auto* dy_data = dy->data<T>();
    const auto* mean_data = mean->data<T>();
    const auto* variance_data = variance->data<T>();
    const auto* scale_data = (scale == nullptr ? nullptr : scale->data<T>());
    auto* dscale_data =
        (dscale == nullptr ? nullptr : dscale->mutable_data<T>(ctx.GetPlace()));
    auto* dbias_data =
        (dbias == nullptr ? nullptr : dbias->mutable_data<T>(ctx.GetPlace()));
    auto* dx_data =
        (dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()));
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    int r = xpu::layer_norm_backward(
        dev_ctx.x_context(), left, right, x_data, scale_data, variance_data,
        mean_data, dy_data, dx_data, dscale_data, dbias_data, epsilon);
    PADDLE_ENFORCE_EQ(
        r, XPU_SUCCESS,
        platform::errors::External("XPU API(layer_norm_backward) return wrong "
                                   "value[%d], please check whether Baidu "
                                   "Kunlun Card is properly installed.",
                                   r));
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP_XPU_KERNEL(
    layer_norm,
    ops::LayerNormXPUKernel<paddle::platform::XPUDeviceContext, float>);
REGISTER_OP_XPU_KERNEL(
    layer_norm_grad,
    ops::LayerNormGradXPUKernel<paddle::platform::XPUDeviceContext, float>);

#endif  // PADDLE_WITH_XPU
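The grad kernel above delegates the gradient math to xpu::layer_norm_backward over the flattened [left, right] view. As a cross-check, here is a minimal NumPy sketch of the standard layer-norm gradients; it assumes the saved variance already has epsilon folded in (matching the Python reference used by the unit test later in this diff) and is illustrative only, not the XPU library's implementation.

import numpy as np

def ref_layer_norm_grad(x, dy, scale, mean, variance):
    # x, dy: [left, right]; scale: [right]; mean, variance: [left]
    # variance is assumed to already include epsilon.
    left, right = x.shape
    inv_std = 1.0 / np.sqrt(variance).reshape([left, 1])
    x_hat = (x - mean.reshape([left, 1])) * inv_std

    d_bias = dy.sum(axis=0)             # gradient w.r.t. Bias
    d_scale = (dy * x_hat).sum(axis=0)  # gradient w.r.t. Scale
    d_x_hat = dy * scale.reshape([1, right])
    # Standard layer-norm input gradient: subtract the row-wise mean and the
    # component along x_hat, then rescale by 1/std.
    d_x = inv_std * (d_x_hat
                     - d_x_hat.mean(axis=1, keepdims=True)
                     - x_hat * (d_x_hat * x_hat).mean(axis=1, keepdims=True))
    return d_x, d_scale, d_bias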
File diff suppressed because it is too large
@@ -0,0 +1,111 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import numpy as np
import sys
import unittest
from functools import reduce
sys.path.append("..")
from op_test import OpTest
from operator import mul

paddle.enable_static()


def ref_layer_norm(x, scale, bias, epsilon, begin_norm_axis=1):
    x_shape = x.shape
    left = reduce(mul, x_shape[0:begin_norm_axis], 1)
    right = reduce(mul, x_shape[begin_norm_axis:len(x_shape)], 1)
    x.shape = [left, right]
    mean = np.mean(x, axis=1)
    variance = np.var(x, axis=1) + epsilon
    y = np.divide((x - mean.reshape([left, 1])),
                  (np.sqrt(variance)).reshape([left, 1]))
    if scale is not None:
        y = scale.reshape([1, right]) * y
    if bias is not None:
        y = y + bias.reshape([1, right])
    x.shape, y.shape = x_shape, x_shape
    return y, mean, variance


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestXPULayerNormOp(OpTest):
    def setUp(self):
        self.op_type = "layer_norm"
        self.dtype = np.float32
        self.shape = [2, 3, 4, 5]
        self.epsilon = 1e-05
        self.begin_norm_axis = 1
        self.set_attrs()

        right = reduce(mul, self.shape[self.begin_norm_axis:len(self.shape)], 1)
        np.random.seed(10)
        x_np = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
        scale_np = np.random.uniform(0.1, 1, [right]).astype(self.dtype)
        bias_np = np.random.uniform(0.1, 1, [right]).astype(self.dtype)
        ref_y_np, ref_mean_np, ref_variance_np = ref_layer_norm(
            x_np, scale_np, bias_np, self.epsilon, self.begin_norm_axis)

        self.inputs = {'X': x_np, 'Scale': scale_np, 'Bias': bias_np}
        self.outputs = {
            'Y': ref_y_np,
            'Mean': ref_mean_np,
            'Variance': ref_variance_np
        }
        self.attrs = {'begin_norm_axis': self.begin_norm_axis, 'use_xpu': True}

    def set_attrs(self):
        pass

    def test_check_output(self):
        self.check_output_with_place(paddle.XPUPlace(0), atol=1e-4)

    def test_check_grad(self):
        self.check_grad_with_place(
            paddle.XPUPlace(0), ['X'], 'Y', max_relative_error=0.02)


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestXPULayerNormOpAxis2(TestXPULayerNormOp):
    def set_attrs(self):
        self.begin_norm_axis = 2


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestXPULayerNormOpAxis3(TestXPULayerNormOp):
    def set_attrs(self):
        self.begin_norm_axis = 3


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestXPULayerNormOp2D(TestXPULayerNormOp):
    def set_attrs(self):
        self.shape = [10, 12]


@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestXPULayerNormOp3D(TestXPULayerNormOp):
    def set_attrs(self):
        self.shape = [4, 5, 6]


if __name__ == "__main__":
    unittest.main()
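The axis and shape variants above only override set_attrs, so new configurations are one small subclass away. As an illustration of that pattern, a hypothetical extra case (not part of this commit) appended to the same test module could look like:

@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestXPULayerNormOp5D(TestXPULayerNormOp):
    def set_attrs(self):
        # Hypothetical extra case: 5-D input normalized from axis 2 onwards.
        self.shape = [2, 3, 4, 5, 6]
        self.begin_norm_axis = 2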