[NPU] add npu kernel for adam (#31644)
* add npu kernel for adam * refine code * disable test * modify atol
parent 795b0f92d3
commit 1e956001ec
@ -0,0 +1,161 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <memory>
#include <string>

#include "paddle/fluid/operators/npu_op_runner.h"
#include "paddle/fluid/operators/optimizers/adam_op.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;

template <typename DeviceContext, typename T>
class AdamNPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto* param_var = ctx.InputVar("Param");
    PADDLE_ENFORCE_EQ(param_var->IsType<framework::LoDTensor>(), true,
                      platform::errors::InvalidArgument(
                          "The Var(%s)'s type should be LoDTensor, "
                          "but the received is %s",
                          ctx.InputNames("Param").front(),
                          framework::ToTypeName(param_var->Type())));
    T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
    auto* param = ctx.Input<LoDTensor>("Param");
    auto* grad_var = ctx.InputVar("Grad");
    PADDLE_ENFORCE_EQ(grad_var->IsType<framework::LoDTensor>(), true,
                      platform::errors::InvalidArgument(
                          "The Grad(%s)'s type should be LoDTensor, "
                          "but the received is %s",
                          ctx.InputNames("Grad").front(),
                          framework::ToTypeName(grad_var->Type())));
auto* grad = ctx.Input<LoDTensor>("Grad");
|
||||||
|
auto* mom1 = ctx.Input<LoDTensor>("Moment1");
|
||||||
|
auto* mom2 = ctx.Input<LoDTensor>("Moment2");
|
||||||
|
auto* lr = ctx.Input<LoDTensor>("LearningRate");
|
||||||
|
|
||||||
|
auto* beta1_pow = ctx.Input<LoDTensor>("Beta1Pow");
|
||||||
|
auto* beta2_pow = ctx.Input<LoDTensor>("Beta2Pow");
|
||||||
|
|
||||||
|
auto* param_out = ctx.Output<LoDTensor>("ParamOut");
|
||||||
|
auto* mom1_out = ctx.Output<LoDTensor>("Moment1Out");
|
||||||
|
auto* mom2_out = ctx.Output<LoDTensor>("Moment2Out");
|
||||||
|
auto* beta1_pow_out = ctx.Output<LoDTensor>("Beta1PowOut");
|
||||||
|
auto* beta2_pow_out = ctx.Output<LoDTensor>("Beta2PowOut");
|
||||||
|
|
||||||
|
param_out->mutable_data<T>(ctx.GetPlace());
|
||||||
|
mom1_out->mutable_data<T>(ctx.GetPlace());
|
||||||
|
mom2_out->mutable_data<T>(ctx.GetPlace());
|
||||||
|
beta1_pow_out->mutable_data<T>(ctx.GetPlace());
|
||||||
|
beta2_pow_out->mutable_data<T>(ctx.GetPlace());
|
||||||
|
|
||||||
|
T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
|
||||||
|
if (ctx.HasInput("Beta1Tensor")) {
|
||||||
|
auto* beta1_tensor = ctx.Input<framework::Tensor>("Beta1Tensor");
|
||||||
|
PADDLE_ENFORCE_EQ(beta1_tensor->numel(), 1,
|
||||||
|
platform::errors::InvalidArgument(
|
||||||
|
"Input(Beta1Tensor) size must be 1, but get %d",
|
||||||
|
beta1_tensor->numel()));
|
||||||
|
beta1 = static_cast<T>(GetAttrFromTensor(beta1_tensor));
|
||||||
|
}
|
||||||
|
T beta2 = static_cast<T>(ctx.Attr<float>("beta2"));
|
||||||
|
if (ctx.HasInput("Beta2Tensor")) {
|
||||||
|
auto* beta2_tensor = ctx.Input<framework::Tensor>("Beta2Tensor");
|
||||||
|
PADDLE_ENFORCE_EQ(beta2_tensor->numel(), 1,
|
||||||
|
platform::errors::InvalidArgument(
|
||||||
|
"Input(Beta2Tensor) size must be 1, but get %d",
|
||||||
|
beta2_tensor->numel()));
|
||||||
|
beta2 = static_cast<T>(GetAttrFromTensor(beta2_tensor));
|
||||||
|
}
|
||||||
|
VLOG(3) << "beta1_pow.numel() : " << beta1_pow->numel()
|
||||||
|
<< "beta2_pow.numel() : " << beta2_pow->numel();
|
||||||
|
VLOG(3) << "param.numel(): " << param->numel();
|
||||||
|
|
||||||
|
PADDLE_ENFORCE_EQ(beta1_pow_out->numel(), 1,
|
||||||
|
platform::errors::InvalidArgument(
|
||||||
|
"beta1 pow output size should be 1, but received "
|
||||||
|
"value is:%d.",
|
||||||
|
beta1_pow_out->numel()));
|
||||||
|
|
||||||
|
PADDLE_ENFORCE_EQ(beta2_pow_out->numel(), 1,
|
||||||
|
platform::errors::InvalidArgument(
|
||||||
|
"beta2 pow output size should be 1, but received "
|
||||||
|
"value is:%d.",
|
||||||
|
beta2_pow_out->numel()));
|
||||||
|
|
||||||
|
// reshape
|
||||||
|
Tensor beta1_tensor(framework::proto::VarType::FP32);
|
||||||
|
beta1_tensor.mutable_data<float>({1}, ctx.GetPlace());
|
||||||
|
TensorFromVector(std::vector<T>{beta1}, ctx.device_context(),
|
||||||
|
&beta1_tensor);
|
||||||
|
Tensor beta2_tensor(framework::proto::VarType::FP32);
|
||||||
|
beta2_tensor.mutable_data<float>({1}, ctx.GetPlace());
|
||||||
|
TensorFromVector(std::vector<T>{beta2}, ctx.device_context(),
|
||||||
|
&beta2_tensor);
|
||||||
|
|
||||||
|
Tensor epsilon_tensor(framework::proto::VarType::FP32);
|
||||||
|
epsilon_tensor.mutable_data<T>({1}, ctx.GetPlace());
|
||||||
|
TensorFromVector(std::vector<T>{epsilon}, ctx.device_context(),
|
||||||
|
&epsilon_tensor);
|
||||||
|
auto stream =
|
||||||
|
ctx.template device_context<paddle::platform::NPUDeviceContext>()
|
||||||
|
.stream();
|
||||||
|
auto runner =
|
||||||
|
NpuOpRunner("ApplyAdamD",
|
||||||
|
{
|
||||||
|
*param, *mom1, *mom2, *beta1_pow, *beta2_pow, *lr,
|
||||||
|
beta1_tensor, beta2_tensor, epsilon_tensor, *grad,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
*param_out, *mom1_out, *mom2_out,
|
||||||
|
},
|
||||||
|
{});
|
||||||
|
runner.Run(stream);
|
||||||
|
|
||||||
|
    // NOTE(zhiqiu): ApplyAdamD updates params in place, so if param and
    // param_out are not the same tensor, we need to do a copy.
    if (param_out->data<T>() != param->data<T>()) {
      ctx.template device_context<paddle::platform::NPUDeviceContext>().Wait();
      framework::TensorCopySync(*param, ctx.GetPlace(), param_out);
    }
    if (mom1_out->data<T>() != mom1->data<T>()) {
      ctx.template device_context<paddle::platform::NPUDeviceContext>().Wait();
      framework::TensorCopySync(*mom1, ctx.GetPlace(), mom1_out);
    }
    if (mom2_out->data<T>() != mom2->data<T>()) {
      ctx.template device_context<paddle::platform::NPUDeviceContext>().Wait();
      framework::TensorCopySync(*mom2, ctx.GetPlace(), mom2_out);
    }
    auto runner_m1 =
        NpuOpRunner("Mul", {*beta1_pow, beta1_tensor}, {*beta1_pow_out}, {});
    runner_m1.Run(stream);
    auto runner_m2 =
        NpuOpRunner("Mul", {*beta2_pow, beta2_tensor}, {*beta2_pow_out}, {});
    runner_m2.Run(stream);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP_NPU_KERNEL(
    adam, ops::AdamNPUKernel<paddle::platform::NPUDeviceContext, float>,
    ops::AdamNPUKernel<paddle::platform::NPUDeviceContext,
                       paddle::platform::float16>);
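For reference, the update the kernel delegates to the ApplyAdamD op, together with the two Mul runners that advance the beta-pow accumulators, corresponds to the standard Adam rule. The NumPy sketch below is illustrative only: the helper name and the exact placement of epsilon are assumptions rather than the CANN op definition, but it is the same rule the unit test below checks via adam_step.

import numpy as np

def adam_reference(param, grad, m1, m2, lr, beta1, beta2, epsilon,
                   beta1_pow, beta2_pow):
    # Moment updates (what ApplyAdamD is expected to compute).
    m1_out = beta1 * m1 + (1.0 - beta1) * grad
    m2_out = beta2 * m2 + (1.0 - beta2) * np.square(grad)
    # Bias correction uses the incoming Beta1Pow / Beta2Pow accumulators.
    lr_t = lr * np.sqrt(1.0 - beta2_pow) / (1.0 - beta1_pow)
    param_out = param - lr_t * m1_out / (np.sqrt(m2_out) + epsilon)
    # The kernel advances the accumulators separately with two Mul ops.
    beta1_pow_out = beta1_pow * beta1
    beta2_pow_out = beta2_pow * beta2
    return param_out, m1_out, m2_out, beta1_pow_out, beta2_pow_out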
@ -0,0 +1,148 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import unittest
import sys
sys.path.append("..")
from op_test import OpTest
import paddle
import paddle.fluid as fluid
from test_adam_op import adam_step

paddle.enable_static()
SEED = 2021


@unittest.skipIf(not paddle.is_compiled_with_npu(),
                 "core is not compiled with NPU")
class TestAdam(OpTest):
    def setUp(self):
        self.set_npu()
        self.place = paddle.NPUPlace(0)
        self.op_type = "adam"
        param = np.random.uniform(-1, 1, (102, 105)).astype("float32")
        grad = np.random.uniform(-1, 1, (102, 105)).astype("float32")
        moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32")
        # The second moment is positive
        moment2 = np.random.random((102, 105)).astype("float32")

        learning_rate = 0.004
        beta1 = 0.78
        beta2 = 0.836
        epsilon = 1e-4
        beta1_pow = beta1**10
        beta2_pow = beta2**10

        self.inputs = {
            'Param': param,
            'Grad': grad,
            'Moment1': moment1,
            'Moment2': moment2,
            'LearningRate': np.array([learning_rate]).astype("float32"),
            'Beta1Pow': np.array([beta1_pow]).astype("float32"),
            'Beta2Pow': np.array([beta2_pow]).astype("float32")
        }

        self.attrs = {'epsilon': epsilon, 'beta1': beta1, 'beta2': beta2}

        param_out, moment1_out, \
            moment2_out = adam_step(self.inputs, self.attrs)

        self.outputs = {
            'Moment1Out': moment1_out,
            'Moment2Out': moment2_out,
            'ParamOut': param_out,
            'Beta1PowOut': np.array([beta1_pow]).astype("float32") * beta1,
            'Beta2PowOut': np.array([beta2_pow]).astype("float32") * beta2
        }

    def set_npu(self):
        self.__class__.use_npu = True

    def init_dtype(self):
        self.dtype = np.float32

    def test_check_output(self):
        self.check_output_with_place(self.place, atol=1e-5, check_dygraph=False)

'''
# TODO(zhiqiu): The following test may bring NPU cards 0-3 down;
# we need to analyze it before re-enabling it.

@unittest.skipIf(not paddle.is_compiled_with_npu(),
                 "core is not compiled with NPU")
class TestNet(unittest.TestCase):
    def _test(self, run_npu=True):
        main_prog = paddle.static.Program()
        startup_prog = paddle.static.Program()
        main_prog.random_seed = SEED
        startup_prog.random_seed = SEED
        np.random.seed(SEED)

        a_np = np.random.random(size=(32, 32)).astype('float32')
        b_np = np.random.random(size=(32, 32)).astype('float32')
        label_np = np.random.randint(2, size=(32, 1)).astype('int64')

        with paddle.static.program_guard(main_prog, startup_prog):
            a = paddle.static.data(name="a", shape=[32, 32], dtype='float32')
            b = paddle.static.data(name="b", shape=[32, 32], dtype='float32')
            label = paddle.static.data(
                name="label", shape=[32, 1], dtype='int64')

            sum = paddle.add(a, b)
            z = paddle.pow(sum, 2.0)

            fc_1 = fluid.layers.fc(input=z, size=128)
            prediction = fluid.layers.fc(input=fc_1, size=2, act='softmax')

            cost = fluid.layers.cross_entropy(input=prediction, label=label)
            loss = fluid.layers.reduce_mean(cost)
            adam = fluid.optimizer.Adam(learning_rate=0.01)
            adam.minimize(loss)

        if run_npu:
            place = paddle.NPUPlace(0)
        else:
            place = paddle.CPUPlace()

        exe = paddle.static.Executor(place)
        exe.run(startup_prog)

        print("Start run on {}".format(place))
        for epoch in range(100):

            pred_res, loss_res = exe.run(
                main_prog,
                feed={"a": a_np,
                      "b": b_np,
                      "label": label_np},
                fetch_list=[prediction, loss])
            if epoch % 10 == 0:
                print("Epoch {} | Prediction[0]: {}, Loss: {}".format(
                    epoch, pred_res[0], loss_res))

        return pred_res, loss_res

    def test_npu(self):
        cpu_pred, cpu_loss = self._test(False)
        npu_pred, npu_loss = self._test(True)

        self.assertTrue(np.allclose(npu_pred, cpu_pred))
        self.assertTrue(np.allclose(npu_loss, cpu_loss))
'''

if __name__ == '__main__':
    unittest.main()