Merge branch 'develop' of https://github.com/baidu/Paddle into inference
commit
71261be901
Before Width: | Height: | Size: 58 KiB After Width: | Height: | Size: 56 KiB |
Before Width: | Height: | Size: 50 KiB After Width: | Height: | Size: 49 KiB |
Before Width: | Height: | Size: 32 KiB After Width: | Height: | Size: 30 KiB |
@ -0,0 +1,20 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#define EIGEN_USE_GPU
|
||||
#include "paddle/operators/adam_op.h"
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OP_GPU_KERNEL(adam,
|
||||
ops::AdamOpKernel<paddle::platform::GPUPlace, float>);
|
@ -0,0 +1,82 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include "paddle/framework/eigen.h"
|
||||
#include "paddle/framework/op_registry.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
template <typename Place, typename T>
|
||||
class AdamOpKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||
auto param_out_tensor = ctx.Output<framework::Tensor>("ParamOut");
|
||||
auto moment1_out_tensor = ctx.Output<framework::Tensor>("Moment1Out");
|
||||
auto moment2_out_tensor = ctx.Output<framework::Tensor>("Moment2Out");
|
||||
auto beta1_pow_out_tensor = ctx.Output<framework::Tensor>("Beta1PowOut");
|
||||
auto beta2_pow_out_tensor = ctx.Output<framework::Tensor>("Beta2PowOut");
|
||||
|
||||
param_out_tensor->mutable_data<T>(ctx.GetPlace());
|
||||
moment1_out_tensor->mutable_data<T>(ctx.GetPlace());
|
||||
moment2_out_tensor->mutable_data<T>(ctx.GetPlace());
|
||||
beta1_pow_out_tensor->mutable_data<T>(ctx.GetPlace());
|
||||
beta2_pow_out_tensor->mutable_data<T>(ctx.GetPlace());
|
||||
|
||||
float beta1 = ctx.Attr<float>("beta1");
|
||||
float beta2 = ctx.Attr<float>("beta2");
|
||||
float epsilon = ctx.Attr<float>("epsilon");
|
||||
|
||||
auto param = framework::EigenVector<T>::Flatten(
|
||||
*ctx.Input<framework::Tensor>("Param"));
|
||||
auto grad = framework::EigenVector<T>::Flatten(
|
||||
*ctx.Input<framework::Tensor>("Grad"));
|
||||
auto moment1 = framework::EigenVector<T>::Flatten(
|
||||
*ctx.Input<framework::Tensor>("Moment1"));
|
||||
auto moment2 = framework::EigenVector<T>::Flatten(
|
||||
*ctx.Input<framework::Tensor>("Moment2"));
|
||||
auto lr = framework::EigenVector<T>::Flatten(
|
||||
*ctx.Input<framework::Tensor>("LearningRate"));
|
||||
auto beta1_pow = framework::EigenVector<T>::Flatten(
|
||||
*ctx.Input<framework::Tensor>("Beta1Pow"));
|
||||
auto beta2_pow = framework::EigenVector<T>::Flatten(
|
||||
*ctx.Input<framework::Tensor>("Beta2Pow"));
|
||||
auto param_out = framework::EigenVector<T>::Flatten(*param_out_tensor);
|
||||
auto moment1_out = framework::EigenVector<T>::Flatten(*moment1_out_tensor);
|
||||
auto moment2_out = framework::EigenVector<T>::Flatten(*moment2_out_tensor);
|
||||
auto beta1_pow_out =
|
||||
framework::EigenVector<T>::Flatten(*beta1_pow_out_tensor);
|
||||
auto beta2_pow_out =
|
||||
framework::EigenVector<T>::Flatten(*beta2_pow_out_tensor);
|
||||
auto place = ctx.GetEigenDevice<Place>();
|
||||
|
||||
moment1_out.device(place) = beta1 * moment1 + (1 - beta1) * grad;
|
||||
moment2_out.device(place) = beta2 * moment2 + (1 - beta2) * grad.square();
|
||||
beta1_pow_out.device(place) = beta1_pow * beta1;
|
||||
beta2_pow_out.device(place) = beta2_pow * beta2;
|
||||
// All of these are tensors of 1 element
|
||||
auto lr_t = lr * (1 - beta2_pow_out).sqrt() / (1 - beta1_pow_out);
|
||||
// Eigen does not support automatic broadcast
|
||||
// Get dimensions of moment vector to broadcast lr_t
|
||||
Eigen::DSizes<int, 1> m_dsize(moment1_out_tensor->numel());
|
||||
param_out.device(place) =
|
||||
param -
|
||||
lr_t.broadcast(m_dsize) *
|
||||
(moment1_out / (moment2_out.sqrt() + epsilon));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,186 @@
|
||||
import unittest
|
||||
import numpy as np
|
||||
from op_test import OpTest
|
||||
|
||||
|
||||
class TestAdamOp1(OpTest):
|
||||
def setUp(self):
|
||||
'''Test Adam Op with supplied attributes
|
||||
'''
|
||||
self.op_type = "adam"
|
||||
param = np.random.uniform(-1, 1, (102, 105)).astype("float32")
|
||||
grad = np.random.uniform(-1, 1, (102, 105)).astype("float32")
|
||||
moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32")
|
||||
# The second moment is positive
|
||||
moment2 = np.random.random((102, 105)).astype("float32")
|
||||
|
||||
learning_rate = 0.004
|
||||
beta1 = 0.78
|
||||
beta2 = 0.836
|
||||
epsilon = 1e-4
|
||||
beta1_pow = beta1**10
|
||||
beta2_pow = beta2**10
|
||||
|
||||
self.inputs = {
|
||||
'Param': param,
|
||||
'Grad': grad,
|
||||
'Moment1': moment1,
|
||||
'Moment2': moment2,
|
||||
'LearningRate': np.array([learning_rate]).astype("float32"),
|
||||
'Beta1Pow': np.array([beta1_pow]).astype("float32"),
|
||||
'Beta2Pow': np.array([beta2_pow]).astype("float32")
|
||||
}
|
||||
|
||||
self.attrs = {'epsilon': epsilon, 'beta1': beta1, 'beta2': beta2}
|
||||
|
||||
param_out, moment1_out, moment2_out, beta1_pow_out, \
|
||||
beta2_pow_out = adam_step(self.inputs, self.attrs)
|
||||
|
||||
self.outputs = {
|
||||
'Moment1Out': moment1_out,
|
||||
'Moment2Out': moment2_out,
|
||||
'Beta1PowOut': beta1_pow_out,
|
||||
'Beta2PowOut': beta2_pow_out,
|
||||
'ParamOut': param_out
|
||||
}
|
||||
|
||||
def test_check_output(self):
|
||||
self.check_output()
|
||||
|
||||
|
||||
class TestAdamOp2(OpTest):
|
||||
def setUp(self):
|
||||
'''Test Adam Op with supplied attributes
|
||||
'''
|
||||
self.op_type = "adam"
|
||||
param = np.random.uniform(-1, 1, (102, 105)).astype("float32")
|
||||
grad = np.random.uniform(-1, 1, (102, 105)).astype("float32")
|
||||
moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32")
|
||||
# The second moment is positive
|
||||
moment2 = np.random.random((102, 105)).astype("float32")
|
||||
|
||||
learning_rate = 0.001
|
||||
beta1 = 0.9
|
||||
beta2 = 0.999
|
||||
epsilon = 1e-8
|
||||
beta1_pow = beta1**10
|
||||
beta2_pow = beta2**10
|
||||
|
||||
self.inputs = {
|
||||
'Param': param,
|
||||
'Grad': grad,
|
||||
'Moment1': moment1,
|
||||
'Moment2': moment2,
|
||||
'LearningRate': np.array([learning_rate]).astype("float32"),
|
||||
'Beta1Pow': np.array([beta1_pow]).astype("float32"),
|
||||
'Beta2Pow': np.array([beta2_pow]).astype("float32")
|
||||
}
|
||||
|
||||
attributes = {'epsilon': epsilon, 'beta1': beta1, 'beta2': beta2}
|
||||
|
||||
param_out, moment1_out, moment2_out, beta1_pow_out, \
|
||||
beta2_pow_out = adam_step(self.inputs, attributes)
|
||||
|
||||
self.outputs = {
|
||||
'Moment1Out': moment1_out,
|
||||
'Moment2Out': moment2_out,
|
||||
'Beta1PowOut': beta1_pow_out,
|
||||
'Beta2PowOut': beta2_pow_out,
|
||||
'ParamOut': param_out
|
||||
}
|
||||
|
||||
def test_check_output(self):
|
||||
self.check_output()
|
||||
|
||||
|
||||
class TestAdamOpMultipleSteps(OpTest):
|
||||
def setUp(self):
|
||||
'''Test Adam Operator with supplied attributes
|
||||
'''
|
||||
self.op_type = "adam"
|
||||
self.num_steps = 10
|
||||
|
||||
param = np.random.uniform(-1, 1, (102, 105)).astype("float32")
|
||||
grad = np.random.uniform(-1, 1, (102, 105)).astype("float32")
|
||||
moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32")
|
||||
# The second moment is positive
|
||||
moment2 = np.random.random((102, 105)).astype("float32")
|
||||
|
||||
learning_rate = 0.001
|
||||
beta1 = 0.9
|
||||
beta2 = 0.999
|
||||
epsilon = 1e-8
|
||||
beta1_pow = beta1**10
|
||||
beta2_pow = beta2**10
|
||||
|
||||
self.inputs = {
|
||||
'Param': param,
|
||||
'Grad': grad,
|
||||
'Moment1': moment1,
|
||||
'Moment2': moment2,
|
||||
'LearningRate': np.array([learning_rate]).astype("float32"),
|
||||
'Beta1Pow': np.array([beta1_pow]).astype("float32"),
|
||||
'Beta2Pow': np.array([beta2_pow]).astype("float32")
|
||||
}
|
||||
|
||||
self.attrs = {'epsilon': epsilon, 'beta1': beta1, 'beta2': beta2}
|
||||
|
||||
def test_check_output(self):
|
||||
for _ in range(self.num_steps):
|
||||
param_out, moment1_out, moment2_out, beta1_pow_out, \
|
||||
beta2_pow_out = adam_step(self.inputs, self.attrs)
|
||||
|
||||
self.outputs = {
|
||||
'Moment1Out': moment1_out,
|
||||
'Moment2Out': moment2_out,
|
||||
'Beta1PowOut': beta1_pow_out,
|
||||
'Beta2PowOut': beta2_pow_out,
|
||||
'ParamOut': param_out
|
||||
}
|
||||
|
||||
# Verify output for this step
|
||||
self.check_output()
|
||||
|
||||
# Output of this step becomes input for next step
|
||||
self.inputs['Param'] = param_out
|
||||
self.inputs['Moment1'] = moment1_out
|
||||
self.inputs['Moment2'] = moment2_out
|
||||
self.inputs['Beta1Pow'] = beta1_pow_out
|
||||
self.inputs['Beta2Pow'] = beta2_pow_out
|
||||
|
||||
# Randomize gradient for next step
|
||||
self.inputs['Grad'] = np.random.uniform(
|
||||
-1, 1, (102, 105)).astype("float32")
|
||||
|
||||
|
||||
def adam_step(inputs, attributes):
|
||||
'''
|
||||
Simulate one step of the adam optimizer
|
||||
:param inputs: dict of inputs
|
||||
:param attributes: dict of attributes
|
||||
:return tuple: tuple of output param, moment1, moment2,
|
||||
beta1 power accumulator and beta2 power accumulator
|
||||
'''
|
||||
param = inputs['Param']
|
||||
grad = inputs['Grad']
|
||||
moment1 = inputs['Moment1']
|
||||
moment2 = inputs['Moment2']
|
||||
lr = inputs['LearningRate']
|
||||
beta1_pow = inputs['Beta1Pow']
|
||||
beta2_pow = inputs['Beta2Pow']
|
||||
|
||||
beta1 = attributes['beta1']
|
||||
beta2 = attributes['beta2']
|
||||
epsilon = attributes['epsilon']
|
||||
|
||||
moment1_out = beta1 * moment1 + (1 - beta1) * grad
|
||||
moment2_out = beta2 * moment2 + (1 - beta2) * np.square(grad)
|
||||
beta1_pow_out = beta1_pow * beta1
|
||||
beta2_pow_out = beta2_pow * beta2
|
||||
lr_t = lr * np.sqrt(1 - beta2_pow_out) / (1 - beta1_pow_out)
|
||||
param_out = param - lr_t * (moment1_out / (np.sqrt(moment2_out) + epsilon))
|
||||
return param_out, moment1_out, moment2_out, beta1_pow_out, beta2_pow_out
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
@ -0,0 +1,76 @@
|
||||
import unittest
|
||||
from paddle.v2.framework.framework import Variable, g_program
|
||||
import paddle.v2.framework.core as core
|
||||
|
||||
|
||||
class TestOperator(unittest.TestCase):
|
||||
def test_error_type(self):
|
||||
block = g_program.create_block()
|
||||
try:
|
||||
block.append_op()
|
||||
self.assertFail()
|
||||
except ValueError as v_err:
|
||||
self.assertEqual(
|
||||
v_err.message,
|
||||
"`type` to initilized an Operator can not be None.")
|
||||
try:
|
||||
block.append_op(type="no_such_op")
|
||||
self.assertFail()
|
||||
except AssertionError as a_err:
|
||||
self.assertEqual(a_err.message,
|
||||
"Operator \"no_such_op\" has not been registered.")
|
||||
|
||||
def test_op_desc_creation(self):
|
||||
block = g_program.current_block()
|
||||
mul_x = block.create_var(
|
||||
dtype="float32", shape=[5, 10], lod_level=0, name="mul.x")
|
||||
mul_y = block.create_var(
|
||||
dtype="float32", shape=[10, 8], lod_level=0, name="mul.y")
|
||||
mul_out = block.create_var(
|
||||
dtype="float32", shape=[5, 8], lod_level=0, name="mul.out")
|
||||
mul_op = block.append_op(
|
||||
type="mul",
|
||||
inputs={"X": [mul_x],
|
||||
"Y": mul_y},
|
||||
outputs={"Out": [mul_out]},
|
||||
attrs={"x_num_col_dims": 1})
|
||||
self.assertEqual(mul_op.type, "mul")
|
||||
self.assertEqual(mul_op.input_names, ["X", "Y"])
|
||||
self.assertEqual(mul_op.input("X"), ["mul.x"])
|
||||
self.assertEqual(mul_op.input("Y"), ["mul.y"])
|
||||
self.assertEqual(mul_op.output_names, ["Out"])
|
||||
self.assertEqual(mul_op.output("Out"), ["mul.out"])
|
||||
self.assertEqual(
|
||||
set(mul_op.attr_names), set(["x_num_col_dims", "y_num_col_dims"]))
|
||||
self.assertEqual(mul_op.has_attr("x_num_col_dims"), True)
|
||||
self.assertEqual(mul_op.attr_type("x_num_col_dims"), core.AttrType.INT)
|
||||
self.assertEqual(mul_op.attr("x_num_col_dims"), 1)
|
||||
self.assertEqual(mul_op.has_attr("y_num_col_dims"), True)
|
||||
self.assertEqual(mul_op.attr_type("y_num_col_dims"), core.AttrType.INT)
|
||||
self.assertEqual(mul_op.attr("y_num_col_dims"), 1)
|
||||
self.assertEqual(mul_out.op, mul_op)
|
||||
|
||||
def test_mult_input(self):
|
||||
block = g_program.current_block()
|
||||
sum_x1 = block.create_var(
|
||||
dtype="int", shape=[3, 4], lod_level=0, name="sum.x1")
|
||||
sum_x2 = block.create_var(
|
||||
dtype="int", shape=[3, 4], lod_level=0, name="sum.x2")
|
||||
sum_x3 = block.create_var(
|
||||
dtype="int", shape=[3, 4], lod_level=0, name="sum.x3")
|
||||
sum_out = block.create_var(
|
||||
dtype="int", shape=[3, 4], lod_level=0, name="sum.out")
|
||||
sum_op = block.append_op(
|
||||
type="sum",
|
||||
inputs={"X": [sum_x1, sum_x2, sum_x3]},
|
||||
outputs={"Out": sum_out})
|
||||
self.assertEqual(sum_op.type, "sum")
|
||||
self.assertEqual(sum_op.input_names, ["X"])
|
||||
self.assertEqual(sum_op.input("X"), ["sum.x1", "sum.x2", "sum.x3"])
|
||||
self.assertEqual(sum_op.output_names, ["Out"])
|
||||
self.assertEqual(sum_op.output("Out"), ["sum.out"])
|
||||
self.assertEqual(sum_out.op, sum_op)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue