lamb_op_xpu;test=kunlun (#31012)

Squashed commit messages:
* lamb_op_xpu;test=kunlun
* modify lamb_op_xpu.cc;test=kunlun
* delete atol lamb_op_xpu; test=kunlun
* update xpu.cmake;test=kunlun
* test_error 1e-5,lamb_op_xpu;test=kunlun
* error1e-5,lamb_op_xpu,test=kunlun
* delete atol lamb_xpu;test=kunlun
* modify atol,lamb_op_xpy;test=kunlun
* lamb_op_xpu, XPUOptest;test=kunlun
* lamb_op_xpu,modify xpu_cmake;test=kunlun
* lamb_op_xpu,modify xpucmake;test=kunlun
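For reference, the parameter update this PR implements on XPU, written out as equations; it mirrors the numpy reference lamb_step in the new unit test below (theta is the parameter, g the gradient, eta the learning rate, lambda the weight_decay attribute):

    m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t
    v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2
    \hat{m}_t = m_t / (1 - \beta_1^t), \qquad \hat{v}_t = v_t / (1 - \beta_2^t)
    r_t = \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon) + \lambda \theta_{t-1}
    \eta_t = \eta \, \lVert \theta_{t-1} \rVert / \lVert r_t \rVert \ \text{ if both norms are positive, else } \eta_t = 1
    \theta_t = \theta_{t-1} - \eta_t \, r_t

The Beta1Pow/Beta2Pow accumulators are simply multiplied by beta1 and beta2 each step, as in the Beta1PowOut/Beta2PowOut outputs below.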
paddle/fluid/operators/optimizers/lamb_op_xpu.cc (new file)
@@ -0,0 +1,125 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/optimizers/lamb_op.h"
#include "gflags/gflags.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

#ifdef PADDLE_WITH_XPU
template <typename DeviceContext, typename T>
class LambOpXPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    using paddle::framework::LoDTensor;
    const auto* param_var = ctx.InputVar("Param");
    PADDLE_ENFORCE_EQ(param_var->IsType<framework::LoDTensor>(), true,
                      platform::errors::InvalidArgument(
                          "The Var(%s)'s type should be LoDTensor, "
                          "but the received is %s",
                          ctx.InputNames("Param").front(),
                          framework::ToTypeName(param_var->Type())));

    using paddle::framework::LoDTensor;

    // inputs
    T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
    T weight_decay = static_cast<T>(ctx.Attr<float>("weight_decay"));
    T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
    T beta2 = static_cast<T>(ctx.Attr<float>("beta2"));
    auto& param = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Param"), "Input",
                                  "Param", "Lamb");
    auto* grad_var = ctx.InputVar("Grad");
    auto& mom1 = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Moment1"), "Input",
                                 "Moment1", "Lamb");
    auto& mom2 = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Moment2"), "Input",
                                 "Moment2", "Lamb");
    auto& lr = GET_DATA_SAFELY(ctx.Input<LoDTensor>("LearningRate"), "Input",
                               "LearningRate", "Lamb");

    auto& beta1_pow = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Beta1Pow"), "Input",
                                      "Beta1Pow", "Lamb");
    auto& beta2_pow = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Beta2Pow"), "Input",
                                      "Beta2Pow", "Lamb");

    auto& param_out = GET_DATA_SAFELY(ctx.Output<LoDTensor>("ParamOut"),
                                      "Output", "ParamOut", "Lamb");
    auto& mom1_out = GET_DATA_SAFELY(ctx.Output<LoDTensor>("Moment1Out"),
                                     "Output", "Moment1Out", "Lamb");
    auto& mom2_out = GET_DATA_SAFELY(ctx.Output<LoDTensor>("Moment2Out"),
                                     "Output", "Moment2Out", "Lamb");
    auto& beta1_pow_out = GET_DATA_SAFELY(ctx.Output<LoDTensor>("Beta1PowOut"),
                                          "Output", "Beta1PowOut", "Lamb");
    auto& beta2_pow_out = GET_DATA_SAFELY(ctx.Output<LoDTensor>("Beta2PowOut"),
                                          "Output", "Beta2PowOut", "Lamb");
    auto& dev_ctx = ctx.template device_context<DeviceContext>();

    if (grad_var->IsType<framework::LoDTensor>()) {
      auto& grad = *ctx.Input<LoDTensor>("Grad");
      int r = xpu::lamb(dev_ctx.x_context(), grad.template data<T>(),
                        mom1.template data<T>(), mom2.template data<T>(),
                        param.template data<T>(), beta1_pow.template data<T>(),
                        beta2_pow.template data<T>(), beta1, beta2, epsilon,
                        weight_decay, lr.template data<T>(),
                        mom1_out.template mutable_data<T>(ctx.GetPlace()),
                        mom2_out.template mutable_data<T>(ctx.GetPlace()),
                        param_out.template mutable_data<T>(ctx.GetPlace()),
                        beta1_pow_out.template mutable_data<T>(ctx.GetPlace()),
                        beta2_pow_out.template mutable_data<T>(ctx.GetPlace()),
                        param.numel());

      if (r == xpu::Error_t::INVALID_PARAM) {
        PADDLE_ENFORCE_EQ(
            r, xpu::Error_t::SUCCESS,
            platform::errors::InvalidArgument(
                "XPU kernel error of LambOp, error message: INVALID_PARAM, "
                "please check your input & output."));
      } else if (r == xpu::Error_t::RUNTIME_ERROR) {
        PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
                          platform::errors::Unavailable(
                              "XPU kernel error of LambOp, error message: "
                              "RUNTIME_ERROR, please check whether Baidu "
                              "Kunlun Card is properly installed."));
      } else if (r == xpu::Error_t::NO_ENOUGH_WORKSPACE) {
        PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
                          platform::errors::ResourceExhausted(
                              "XPU kernel error of LambOp, error "
                              "message: NO_ENOUGH_WORKSPACE, XPU "
                              "has no enough memory."));
      } else {
        PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
                          platform::errors::ResourceExhausted(
                              "XPU kernel error of LambOp, error "
                              "message: OTHER "
                              "XPU API returns error code: %d.",
                              r));
      }
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Variable type not supported by lamb_op. Expect LoDTensor, "
          "but got %s",
          framework::ToTypeName(param_var->Type())));
    }
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(
    lamb, ops::LambOpXPUKernel<paddle::platform::XPUDeviceContext, float>);
#endif
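For context, a minimal sketch of how the registered kernel would be reached from the Python side. This is not part of the diff; it assumes a Paddle build with XPU (Kunlun) support and uses the public 2.x APIs paddle.set_device and paddle.optimizer.Lamb:

import paddle

# Illustrative only: run one LAMB step on a Kunlun card, assuming Paddle was
# built with WITH_XPU=ON and an XPU device is visible.
paddle.set_device("xpu")

linear = paddle.nn.Linear(10, 10)
opt = paddle.optimizer.Lamb(
    learning_rate=0.001,
    lamb_weight_decay=0.01,   # corresponds to the kernel's weight_decay attribute
    beta1=0.9,
    beta2=0.999,
    epsilon=1e-6,
    parameters=linear.parameters())

loss = linear(paddle.rand([4, 10])).mean()
loss.backward()
opt.step()        # should dispatch to the LambOpXPUKernel registered above
opt.clear_grad()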
python/paddle/fluid/tests/unittests/xpu/test_lamb_op_xpu.py (new file)
@@ -0,0 +1,121 @@
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import sys
sys.path.append("..")
import unittest
import numpy as np
from op_test_xpu import XPUOpTest
from paddle.fluid import core
from paddle.fluid.op import Operator
import paddle.fluid as fluid
import paddle


class TestLambOp1(XPUOpTest):
    def set_attrs(self):
        self.attrs = {
            'epsilon': 1e-6,
            'beta1': 0.9,
            'beta2': 0.999,
            'weight_decay': 0.01
        }

    def setUp(self):
        '''Test Lamb Op with supplied attributes
        '''
        self.op_type = "lamb"
        param = np.random.uniform(-1, 1, 5000).astype("float32")
        grad = np.random.uniform(-1, 1, 5000).astype("float32")
        moment1 = np.random.uniform(-1, 1, 5000).astype("float32")
        moment2 = np.random.random(5000).astype("float32")

        self.set_attrs()
        learning_rate = 0.001
        beta1_pow = self.attrs['beta1']
        beta2_pow = self.attrs['beta2']

        self.inputs = {
            'Param': param,
            'Grad': grad,
            'Moment1': moment1,
            'Moment2': moment2,
            'LearningRate': np.array([learning_rate]).astype("float32"),
            'Beta1Pow': np.array([beta1_pow]).astype("float32"),
            'Beta2Pow': np.array([beta2_pow]).astype("float32")
        }

        param_out, moment1_out, moment2_out, \
            beta1_pow_out, beta2_pow_out = lamb_step(self.inputs, self.attrs)

        self.outputs = {
            'Moment1Out': moment1_out,
            'Moment2Out': moment2_out,
            'ParamOut': param_out,
            'Beta1PowOut': beta1_pow_out,
            'Beta2PowOut': beta2_pow_out
        }

    def test_check_output(self):
        self.check_output_with_place(paddle.XPUPlace(0))


def lamb_step(inputs, attributes):
    '''
    Simulate one step of the lamb optimizer
    :param inputs: dict of inputs
    :param attributes: dict of attributes
    :return tuple: tuple of output param, moment1, moment2,
    beta1 power accumulator and beta2 power accumulator
    '''
    param = inputs['Param']
    grad = inputs['Grad']
    moment1 = inputs['Moment1']
    moment2 = inputs['Moment2']
    lr = inputs['LearningRate']
    beta1_pow = inputs['Beta1Pow']
    beta2_pow = inputs['Beta2Pow']

    beta1 = attributes['beta1']
    beta2 = attributes['beta2']
    epsilon = attributes['epsilon']
    weight_decay = attributes['weight_decay']

    moment1_out = beta1 * moment1 + (1 - beta1) * grad
    moment2_out = beta2 * moment2 + (1 - beta2) * np.square(grad)

    moment1_unbiased = moment1_out / (1 - beta1_pow)
    moment2_unbiased = moment2_out / (1 - beta2_pow)

    r_1 = np.linalg.norm(param)
    r_2 = np.linalg.norm(moment1_unbiased / (np.sqrt(moment2_unbiased) + epsilon
                                             ) + weight_decay * param)
    if r_1 > 0.0 and r_2 > 0.0:
        lr_t = lr * r_1 / r_2
    else:
        lr_t = 1.0

    param_out = param - lr_t * (moment1_unbiased / (
        np.sqrt(moment2_unbiased) + epsilon) + weight_decay * param)

    beta1_pow_out = beta1_pow * beta1
    beta2_pow_out = beta2_pow * beta2

    return param_out, moment1_out, moment2_out, beta1_pow_out, beta2_pow_out


if __name__ == "__main__":
    paddle.enable_static()
    unittest.main()
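As a quick sanity check of the numpy reference outside the OpTest harness, one LAMB step can be run by hand. This snippet is illustrative only and assumes lamb_step from the test above is in scope (for example, in an interactive session after loading the module); the input values are arbitrary:

import numpy as np

# Tiny hand-run of the reference step defined above.
inputs = {
    'Param': np.array([0.5, -0.5], dtype="float32"),
    'Grad': np.array([0.1, 0.2], dtype="float32"),
    'Moment1': np.zeros(2, dtype="float32"),
    'Moment2': np.zeros(2, dtype="float32"),
    'LearningRate': np.array([0.001], dtype="float32"),
    'Beta1Pow': np.array([0.9], dtype="float32"),
    'Beta2Pow': np.array([0.999], dtype="float32"),
}
attrs = {'beta1': 0.9, 'beta2': 0.999, 'epsilon': 1e-6, 'weight_decay': 0.01}

param_out, m1_out, m2_out, b1p_out, b2p_out = lamb_step(inputs, attrs)
print(param_out)  # parameter after one step; moments and pow accumulators are also returned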