[NPU] Support npu op layer_norm and layer_norm_grad (#31310)
* init commit, add layer_norm npu kernel
* fix typo
* add unittest
* add unittest
* fix bug
* fix bug
* refine ut
parent 45765d6eb6
commit 0310945f5c
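The two kernels in this diff register layer_norm and layer_norm_grad for the NPU place. As a quick orientation, a minimal static-graph sketch that would exercise the new kernels (assuming a Paddle build with NPU support; this snippet is illustrative and not part of the diff) could look like:

import numpy as np
import paddle

paddle.enable_static()

# Build a small static program containing a layer_norm op.
x = paddle.static.data(name='x', shape=[2, 3, 4, 5], dtype='float32')
y = paddle.static.nn.layer_norm(x, begin_norm_axis=1)

# Run it on the NPU device; the registered NPU kernels are dispatched here.
place = paddle.NPUPlace(0)
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
out, = exe.run(feed={'x': np.random.rand(2, 3, 4, 5).astype('float32')},
               fetch_list=[y])
print(out.shape)  # (2, 3, 4, 5)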
@@ -0,0 +1,195 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/layer_norm_op.h"
|
||||
#include "paddle/fluid/operators/npu_op_runner.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using Tensor = framework::Tensor;
|
||||
using DDim = framework::DDim;
|
||||
|
||||
template <typename T>
class LayerNormNPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
    const auto epsilon = ctx.Attr<float>("epsilon");
    const auto* x = ctx.Input<Tensor>("X");
    const auto* scale = ctx.Input<Tensor>("Scale");
    const auto* bias = ctx.Input<Tensor>("Bias");
    auto* y = ctx.Output<Tensor>("Y");
    auto* mean = ctx.Output<Tensor>("Mean");
    auto* variance = ctx.Output<Tensor>("Variance");
    const auto& x_dims = x->dims();
    std::vector<int> axes;
    auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
    int right = static_cast<int>(matrix_dim[1]);

    // The shape of scale and bias should be equal to
    // x.shape[begin_norm_axis:], as required by Ascend.
    for (auto i = begin_norm_axis; i < x_dims.size(); ++i) {
      axes.push_back(x_dims[i]);
    }
    auto place = ctx.GetPlace();
    auto stream =
        ctx.template device_context<paddle::platform::NPUDeviceContext>()
            .stream();

    // When Scale (Bias) is not provided, build a default tensor filled with
    // ones (zeros) so that LayerNorm always runs with explicit scale and bias.
    Tensor default_scale(x->type());
    if (!scale) {
      default_scale.mutable_data<T>(framework::make_ddim(axes), place);
      Tensor value(x->type());
      value.mutable_data<T>({1}, place);
      TensorFromVector(std::vector<T>{static_cast<T>(1.0)},
                       ctx.device_context(), &value);
      auto runner =
          NpuOpRunner("FillD", {value}, {default_scale}, {{"dims", axes}});
      runner.Run(stream);
      scale = &default_scale;
    } else {
      const_cast<Tensor*>(scale)->Resize(framework::make_ddim(axes));
    }

    Tensor default_bias(x->type());
    if (!bias) {
      default_bias.mutable_data<T>(framework::make_ddim(axes), place);
      Tensor value(x->type());
      value.mutable_data<T>({1}, place);
      TensorFromVector(std::vector<T>{static_cast<T>(0)}, ctx.device_context(),
                       &value);
      auto runner =
          NpuOpRunner("FillD", {value}, {default_bias}, {{"dims", axes}});
      runner.Run(stream);
      bias = &default_bias;
    } else {
      const_cast<Tensor*>(bias)->Resize(framework::make_ddim(axes));
    }
    y->mutable_data<T>(ctx.GetPlace());
    mean->mutable_data<T>(ctx.GetPlace());
    variance->mutable_data<T>(ctx.GetPlace());

    auto runner =
        NpuOpRunner("LayerNorm", {*x, *scale, *bias}, {*y, *mean, *variance},
                    {{"begin_norm_axis", begin_norm_axis},
                     {"begin_params_axis", begin_norm_axis},
                     {"epsilon", epsilon}});
    runner.Run(stream);
    // Revert the shape of scale and bias.
    // TODO(zhiqiu): better implementation, use tmp tensor to avoid write input
    // tensor.
    const_cast<Tensor*>(scale)->Resize(framework::make_ddim({right}));
    const_cast<Tensor*>(bias)->Resize(framework::make_ddim({right}));
  }
};

template <typename T>
class LayerNormGradNPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
    const auto* x = ctx.Input<Tensor>("X");
    const auto& x_dims = x->dims();
    const auto* mean = ctx.Input<Tensor>("Mean");
    const auto* variance = ctx.Input<Tensor>("Variance");
    const auto* scale = ctx.Input<Tensor>("Scale");
    const auto* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* dscale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
    auto* dbias = ctx.Output<Tensor>(framework::GradVarName("Bias"));

    auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
    int right = static_cast<int>(matrix_dim[1]);

    std::vector<int> axes;
    for (auto i = begin_norm_axis; i < x_dims.size(); ++i) {
      axes.push_back(x_dims[i]);
    }

    auto place = ctx.GetPlace();
    auto stream =
        ctx.template device_context<paddle::platform::NPUDeviceContext>()
            .stream();

    // No need to compute any gradient, just return.
    if (!dx && !dscale && !dbias) {
      return;
    }

    // The rank of mean should be equal to that of x, as required by Ascend.
    std::vector<int> new_shape;
    for (auto i = 0; i < begin_norm_axis; ++i) {
      new_shape.push_back(x_dims[i]);
    }
    for (auto i = begin_norm_axis; i < x_dims.size(); ++i) {
      new_shape.push_back(1);
    }

    auto mean_dims = mean->dims();
    const_cast<Tensor*>(mean)->Resize(framework::make_ddim({new_shape}));
    const_cast<Tensor*>(variance)->Resize(framework::make_ddim({new_shape}));

    Tensor default_scale(x->type());
    if (!scale) {
      default_scale.mutable_data<T>(framework::make_ddim(axes), place);
      Tensor value(x->type());
      value.mutable_data<T>({1}, place);
      TensorFromVector(std::vector<T>{static_cast<T>(1.0)},
                       ctx.device_context(), &value);
      auto runner =
          NpuOpRunner("FillD", {value}, {default_scale}, {{"dims", axes}});
      runner.Run(stream);
      scale = &default_scale;
    } else {
      const_cast<Tensor*>(scale)->Resize(framework::make_ddim(axes));
    }

    Tensor dx_(dy->type()), dscale_(dy->type()), dbias_(dy->type());
    dx = (dx == nullptr) ? &dx_ : dx;
    dscale = (dscale == nullptr) ? &dscale_ : dscale;
    dbias = (dbias == nullptr) ? &dbias_ : dbias;

    dscale->Resize(framework::make_ddim(axes));
    dscale->mutable_data<T>(ctx.GetPlace());

    dbias->Resize(framework::make_ddim(axes));
    dbias->mutable_data<T>(ctx.GetPlace());

    dx->Resize(x->dims());
    dx->mutable_data<T>(ctx.GetPlace());

    auto runner =
        NpuOpRunner("LayerNormGrad", {*dy, *x, *variance, *mean, *scale},
                    {*dx, *dscale, *dbias}, {});
    runner.Run(stream);

    // Revert the shapes of mean, variance, scale and the parameter gradients.
    const_cast<Tensor*>(mean)->Resize(mean_dims);
    const_cast<Tensor*>(variance)->Resize(mean_dims);
    const_cast<Tensor*>(scale)->Resize(framework::make_ddim({right}));
    dscale->Resize(framework::make_ddim({right}));
    dbias->Resize(framework::make_ddim({right}));
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_NPU_KERNEL(layer_norm, ops::LayerNormNPUKernel<float>,
                       ops::LayerNormNPUKernel<plat::float16>);
REGISTER_OP_NPU_KERNEL(layer_norm_grad, ops::LayerNormGradNPUKernel<float>,
                       ops::LayerNormGradNPUKernel<plat::float16>);
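For reference, the semantics targeted by both kernels normalize over the dimensions at and after begin_norm_axis, i.e. over the right part of the flatten_to_2d view used above, with one Mean/Variance entry per row of that view. A small numpy sketch of the forward computation (illustrative only; the unit test below instead compares against the _reference_layer_norm_naive helper from test_layer_norm_op):

import numpy as np

def layer_norm_ref(x, scale, bias, begin_norm_axis, epsilon=1e-5):
    # Flatten to [left, right], mirroring flatten_to_2d in the kernel.
    left = int(np.prod(x.shape[:begin_norm_axis]))
    right = int(np.prod(x.shape[begin_norm_axis:]))
    x2d = x.reshape(left, right)
    mean = x2d.mean(axis=1, keepdims=True)
    var = x2d.var(axis=1, keepdims=True)
    y = (x2d - mean) / np.sqrt(var + epsilon)
    y = y * scale.reshape(1, right) + bias.reshape(1, right)
    return y.reshape(x.shape), mean.reshape(left), var.reshape(left)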
@@ -0,0 +1,191 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import numpy as np
import unittest
import sys
sys.path.append("..")
from op_test import OpTest
from functools import reduce
from operator import mul
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from test_layer_norm_op import _reference_layer_norm_naive, _reference_layer_norm_grad

paddle.enable_static()

SEED = 2021
EPOCH = 100

from op_test import _set_use_system_allocator

_set_use_system_allocator(False)


@unittest.skipIf(not paddle.is_compiled_with_npu(),
                 "core is not compiled with NPU")
class TestLayerNormOp(unittest.TestCase):
    def setUp(self):
        self.use_cudnn = True
        self.set_npu()
        self.init_dtype()

    def set_npu(self):
        self.__class__.use_npu = True
        self.place = paddle.NPUPlace(0)

    def init_dtype(self):
        self.dtype = np.float32

    def __assert_close(self, tensor, np_array, msg, atol=1e-4):
        self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)

    def check_forward_backward(self,
                               shape,
                               begin_norm_axis,
                               has_scale=True,
                               has_bias=True,
                               y_grad_scale=1.0,
                               use_mkldnn=False):
        def test_with_place(place,
                            shape,
                            begin_norm_axis,
                            use_mkldnn=use_mkldnn):
            # attr
            epsilon = 0.00001
            x_shape = shape
            D = reduce(mul, x_shape[begin_norm_axis:len(x_shape)], 1)
            scale_shape = [D]

            np.random.seed(123)
            x = np.random.random_sample(x_shape).astype(np.float32)
            scale = np.random.random_sample(scale_shape).astype(
                np.float32) if has_scale else None
            bias = np.random.random_sample(scale_shape).astype(
                np.float32) if has_bias else None
            y_grad = (np.random.random_sample(x_shape) *
                      y_grad_scale).astype(np.float32)

            # reference forward & backward
            y, mean, variance = _reference_layer_norm_naive(
                x, scale, bias, epsilon, begin_norm_axis)
            x_grad, scale_grad, bias_grad = _reference_layer_norm_grad(
                x, y_grad, scale, bias, mean, variance, begin_norm_axis)

            var_dict = locals()
            var_dict['y@GRAD'] = y_grad
            var_names = ['x', 'mean', 'variance', 'y', 'y@GRAD']
            if has_scale:
                var_names += ['scale']
            if has_bias:
                var_names += ['bias']
            ground_truth = {name: var_dict[name] for name in var_names}

            program = fluid.Program()
            with fluid.program_guard(program):
                block = program.global_block()
                for name in ground_truth:
                    block.create_var(
                        name=name,
                        dtype='float32',
                        shape=ground_truth[name].shape)
                inputs = {"X": block.var('x')}
                fetch_list = [
                    'y',
                    'mean',
                    'variance',
                    'x@GRAD',
                ]
                if has_scale:
                    inputs["Scale"] = block.var('scale')
                    fetch_list += ['scale@GRAD']
                if has_bias:
                    inputs["Bias"] = block.var('bias')
                    fetch_list += ['bias@GRAD']
                layer_norm_op = block.append_op(
                    type="layer_norm",
                    inputs=inputs,
                    outputs={
                        "Y": block.var('y'),
                        "Mean": block.var('mean'),  # share the same memory
                        "Variance":
                        block.var('variance'),  # share the same memory
                    },
                    attrs={
                        "epsilon": epsilon,
                        "begin_norm_axis": begin_norm_axis,
                        "use_mkldnn": use_mkldnn
                    })
                # generate backward op_desc
                grad_op_desc_list, op_grad_to_var = core.get_grad_op_desc(
                    layer_norm_op.desc, set(), [])
                grad_op_desc = grad_op_desc_list[0]
                new_op_desc = block.desc.append_op()
                new_op_desc.copy_from(grad_op_desc)
                for var_name in grad_op_desc.output_arg_names():
                    block.desc.var(var_name.encode("ascii"))
                grad_op_desc.infer_var_type(block.desc)
                grad_op_desc.infer_shape(block.desc)
                for arg in grad_op_desc.output_arg_names():
                    grad_var = block.desc.find_var(arg.encode("ascii"))
                    grad_var.set_dtype(core.VarDesc.VarType.FP32)

                program._sync_with_cpp()
                exe = fluid.Executor(place)
                out = exe.run(program,
                              feed={
                                  name: var_dict[name]
                                  for name in ['x', 'scale', 'bias', 'y@GRAD']
                              },
                              fetch_list=fetch_list)
                self.__assert_close(y, out[0], "y")
                self.__assert_close(mean, out[1], "mean")
                self.__assert_close(variance, out[2], "variance", 1e-3)
                self.__assert_close(x_grad, out[3], "x_grad", 1e-2)
                if has_scale:
                    self.__assert_close(scale_grad,
                                        out[fetch_list.index('scale@GRAD')],
                                        "scale_grad", 1e-3)
                if has_bias:
                    self.__assert_close(bias_grad,
                                        out[fetch_list.index('bias@GRAD')],
                                        "bias_grad")

        test_with_place(self.place, shape, begin_norm_axis)

    def test_check_forward_backward_with_scale_and_bias(self):
        self.check_forward_backward(shape=[2, 3, 4, 5], begin_norm_axis=1)
        self.check_forward_backward(
            shape=[2, 3, 4, 5],
            begin_norm_axis=1,
            has_scale=False,
            has_bias=True)
        self.check_forward_backward(
            shape=[2, 3, 4, 5],
            begin_norm_axis=1,
            has_scale=True,
            has_bias=False)
        self.check_forward_backward(
            shape=[2, 3, 4, 5],
            begin_norm_axis=1,
            has_scale=False,
            has_bias=False)
        self.check_forward_backward(shape=[2, 3, 4, 5], begin_norm_axis=3)


if __name__ == '__main__':
    unittest.main()