From 7a58431c0aa532901da1caa1e50538244ce0f03b Mon Sep 17 00:00:00 2001
From: zhang wenhui
Date: Wed, 14 Oct 2020 11:13:29 +0800
Subject: [PATCH] fix norm api doc, test=develop (#27652)

* fix norm api doc, test=develop

* fix error message, test=develop

* fix api norm, test=develop

* add adagrad, test=develop

* fix bug, test=develop

* fix bug, test=develop

* add spetral_norm, test=develop

* fix adagrad, test=develop

* merge , test=develop
---
 paddle/fluid/operators/batch_norm_op.cc          |   9 +-
 .../distributed_ops/fl_listen_and_serv_op.cc     |   9 +-
 python/paddle/fluid/layers/nn.py                 |   7 +-
 .../tests/unittests/test_adagrad_op_v2.py        |  41 ++++++
 .../fluid/tests/unittests/test_layers.py         |   6 +-
 python/paddle/nn/__init__.py                     |   1 -
 python/paddle/nn/layer/__init__.py               |   2 +-
 python/paddle/nn/layer/norm.py                   |  18 +--
 python/paddle/optimizer/__init__.py              |   3 +-
 python/paddle/optimizer/adagrad.py               | 136 ++++++++++++++++++
 10 files changed, 203 insertions(+), 29 deletions(-)
 create mode 100644 python/paddle/fluid/tests/unittests/test_adagrad_op_v2.py
 create mode 100644 python/paddle/optimizer/adagrad.py

diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc
index 7a88403aa9..370ba8619f 100644
--- a/paddle/fluid/operators/batch_norm_op.cc
+++ b/paddle/fluid/operators/batch_norm_op.cc
@@ -381,7 +381,8 @@ class BatchNormKernel
         break;
       }
       default:
-        PADDLE_THROW("Unknown storage order: %s", data_layout_str);
+        PADDLE_THROW(platform::errors::InvalidArgument(
+            "Unknown storage order: %s", data_layout_str));
     }

     // if MomentumTensor is set, use MomentumTensor value, momentum
@@ -446,7 +447,8 @@ class BatchNormKernel
         break;
       }
       default:
-        PADDLE_THROW("Unknown storage order: %d", data_layout);
+        PADDLE_THROW(platform::errors::InvalidArgument(
+            "Unknown storage order: %d", data_layout));
     }
   }
 };
@@ -799,7 +801,8 @@ class BatchNormGradKernel
         break;
       }
       default:
-        PADDLE_THROW("Unknown storage order: %s", data_layout_str);
+        PADDLE_THROW(platform::errors::InvalidArgument(
+            "Unknown storage order: %s", data_layout_str));
     }
   }
 };
diff --git a/paddle/fluid/operators/distributed_ops/fl_listen_and_serv_op.cc b/paddle/fluid/operators/distributed_ops/fl_listen_and_serv_op.cc
index 80b322fbe6..2e54bb3961 100644
--- a/paddle/fluid/operators/distributed_ops/fl_listen_and_serv_op.cc
+++ b/paddle/fluid/operators/distributed_ops/fl_listen_and_serv_op.cc
@@ -108,7 +108,8 @@ void FlListenAndServOp::RunSyncLoop(framework::Executor *executor,
   auto optimize_blocks =
       Attr<std::vector<framework::BlockDesc *>>(kOptimizeBlocks);
   PADDLE_ENFORCE_GE(num_blocks, 2,
-                    "server program should have at least 2 blocks");
+                    platform::errors::InvalidArgument(
+                        "server program should have at least 2 blocks"));

   // Prepare all the server block
   std::vector<framework::BlockDesc *> optimize_blocks_list;
@@ -192,7 +193,8 @@ void FlListenAndServOp::RunImpl(const framework::Scope &scope,
   auto fan_in = Attr<int>("Fanin");
   auto inputs = Inputs("X");

-  PADDLE_ENFORCE_EQ(!rpc_service_, true, "rpc_service_ must null");
+  PADDLE_ENFORCE_EQ(!rpc_service_, true, platform::errors::InvalidArgument(
+                                             "rpc_service_ must null"));
   std::string endpoint = Attr<std::string>("endpoint");

   VLOG(4) << "sync_mode:" << sync_mode << ", fan_in:" << fan_in
@@ -215,7 +217,8 @@ void FlListenAndServOp::RunImpl(const framework::Scope &scope,
       Attr<std::vector<framework::BlockDesc *>>(kOptimizeBlocks);
   PADDLE_ENFORCE_GE(
       optimize_blocks.size(), 1,
-      "optimize blocks should be 1 at least on the pserver side.");
+      platform::errors::InvalidArgument(
+          "optimize blocks should be 1 at least on the pserver side."));

   auto *program = optimize_blocks[0]->Program();
   framework::Executor executor(dev_place);
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 83e282920d..91cce50f3f 100755
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -3674,10 +3674,11 @@ def spectral_norm(weight, dim=0, power_iters=1, eps=1e-12, name=None):
     Examples:
        .. code-block:: python

-            import paddle.fluid as fluid
+            import paddle

-            weight = fluid.data(name='weight', shape=[2, 8, 32, 32], dtype='float32')
-            x = fluid.layers.spectral_norm(weight=weight, dim=1, power_iters=2)
+            paddle.enable_static()
+            weight = paddle.data(name='weight', shape=[2, 8, 32, 32], dtype='float32')
+            x = paddle.static.nn.spectral_norm(weight=weight, dim=1, power_iters=2)
     """
     helper = LayerHelper('spectral_norm', **locals())
     check_variable_and_dtype(weight, 'weight', ['float32', 'float64'],
diff --git a/python/paddle/fluid/tests/unittests/test_adagrad_op_v2.py b/python/paddle/fluid/tests/unittests/test_adagrad_op_v2.py
new file mode 100644
index 0000000000..0ccd42aa67
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_adagrad_op_v2.py
@@ -0,0 +1,41 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import paddle
+import paddle.fluid.core as core
+from paddle.fluid.op import Operator
+from op_test import OpTest
+import math
+
+
+class TestAdagradOpV2(unittest.TestCase):
+    def test_v20_coverage(self):
+        paddle.disable_static()
+        inp = paddle.rand(shape=[10, 10])
+        linear = paddle.nn.Linear(10, 10)
+        out = linear(inp)
+        loss = paddle.mean(out)
+        adagrad = paddle.optimizer.Adagrad(
+            learning_rate=0.1, parameters=linear.parameters())
+        out.backward()
+        adagrad.step()
+        adagrad.clear_grad()
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py
index e0ec676f1b..cb1a5a6bdf 100644
--- a/python/paddle/fluid/tests/unittests/test_layers.py
+++ b/python/paddle/fluid/tests/unittests/test_layers.py
@@ -1369,7 +1369,7 @@ class TestLayer(LayerTest):
             dy_rlt_value = dy_ret.numpy()

         with self.dynamic_graph():
-            instanceNorm = paddle.nn.InstanceNorm(num_channels=shape[1])
+            instanceNorm = nn.InstanceNorm(num_channels=shape[1])
             dy_ret = instanceNorm(base.to_variable(input))
             dy_rlt_value2 = dy_ret.numpy()

@@ -1380,7 +1380,7 @@ class TestLayer(LayerTest):
         with self.static_graph():
             # the input of InstanceNorm must be Variable.
             def test_Variable():
-                instanceNorm = paddle.nn.InstanceNorm(num_channels=shape[1])
+                instanceNorm = nn.InstanceNorm(num_channels=shape[1])
                 ret1 = instanceNorm(input)

             self.assertRaises(TypeError, test_Variable)
@@ -1388,7 +1388,7 @@ class TestLayer(LayerTest):
             # the input dtype of InstanceNorm must be float32 or float64
             def test_type():
                 input = np.random.random(shape).astype('int32')
-                instanceNorm = paddle.nn.InstanceNorm(num_channels=shape[1])
+                instanceNorm = nn.InstanceNorm(num_channels=shape[1])
                 ret2 = instanceNorm(input)

             self.assertRaises(TypeError, test_type)
diff --git a/python/paddle/nn/__init__.py b/python/paddle/nn/__init__.py
index b506b52ec9..b1f3737805 100644
--- a/python/paddle/nn/__init__.py
+++ b/python/paddle/nn/__init__.py
@@ -139,7 +139,6 @@ from .layer.norm import SyncBatchNorm #DEFINE_ALIAS
 from .layer.norm import GroupNorm #DEFINE_ALIAS
 from .layer.norm import LayerNorm #DEFINE_ALIAS
 from .layer.norm import SpectralNorm #DEFINE_ALIAS
-from .layer.norm import InstanceNorm #DEFINE_ALIAS
 from .layer.norm import InstanceNorm1d #DEFINE_ALIAS
 from .layer.norm import InstanceNorm2d #DEFINE_ALIAS
 from .layer.norm import InstanceNorm3d #DEFINE_ALIAS
diff --git a/python/paddle/nn/layer/__init__.py b/python/paddle/nn/layer/__init__.py
index 8a234e779e..afd2cc3a23 100644
--- a/python/paddle/nn/layer/__init__.py
+++ b/python/paddle/nn/layer/__init__.py
@@ -102,7 +102,7 @@ from .norm import SyncBatchNorm #DEFINE_ALIAS
 from .norm import GroupNorm #DEFINE_ALIAS
 from .norm import LayerNorm #DEFINE_ALIAS
 from .norm import SpectralNorm #DEFINE_ALIAS
-from .norm import InstanceNorm #DEFINE_ALIAS
+#from .norm import InstanceNorm #DEFINE_ALIAS
 from .norm import LocalResponseNorm #DEFINE_ALIAS
 # from .rnn import RNNCell #DEFINE_ALIAS
 # from .rnn import GRUCell #DEFINE_ALIAS
diff --git a/python/paddle/nn/layer/norm.py b/python/paddle/nn/layer/norm.py
index 50f7904c41..ecc89b6c1a 100644
--- a/python/paddle/nn/layer/norm.py
+++ b/python/paddle/nn/layer/norm.py
@@ -28,7 +28,7 @@
 # TODO: define normalization api
 import six

-from ...fluid.dygraph.nn import InstanceNorm
+#from ...fluid.dygraph.nn import InstanceNorm

 from ...fluid.dygraph import BatchNorm #DEFINE_ALIAS
 #from ...fluid.dygraph import GroupNorm #DEFINE_ALIAS
@@ -54,19 +54,9 @@ from ...fluid.dygraph.base import no_grad

 from .. import functional as F

 __all__ = [
-    'BatchNorm',
-    'GroupNorm',
-    'LayerNorm',
-    'SpectralNorm',
-    'InstanceNorm',
-    'BatchNorm1d',
-    'BatchNorm2d',
-    'BatchNorm3d',
-    'InstanceNorm1d',
-    'InstanceNorm2d',
-    'InstanceNorm3d',
-    'SyncBatchNorm',
-    'LocalResponseNorm',
+    'BatchNorm', 'GroupNorm', 'LayerNorm', 'SpectralNorm', 'BatchNorm1d',
+    'BatchNorm2d', 'BatchNorm3d', 'InstanceNorm1d', 'InstanceNorm2d',
+    'InstanceNorm3d', 'SyncBatchNorm', 'LocalResponseNorm'
 ]

diff --git a/python/paddle/optimizer/__init__.py b/python/paddle/optimizer/__init__.py
index d041cb85d5..1ca52a806d 100644
--- a/python/paddle/optimizer/__init__.py
+++ b/python/paddle/optimizer/__init__.py
@@ -20,11 +20,12 @@ __all__ = [
 ]


-from ..fluid.optimizer import Momentum, Adagrad, Dpsgd, DecayedAdagrad, Ftrl,\
+from ..fluid.optimizer import Momentum, Dpsgd, DecayedAdagrad, Ftrl,\
             AdagradOptimizer, DpsgdOptimizer, DecayedAdagradOptimizer, \
             FtrlOptimizer, AdadeltaOptimizer

 from .optimizer import Optimizer
+from .adagrad import Adagrad
 from .adam import Adam
 from .adamw import AdamW
 from .adamax import Adamax
diff --git a/python/paddle/optimizer/adagrad.py b/python/paddle/optimizer/adagrad.py
new file mode 100644
index 0000000000..ed55ebd0bf
--- /dev/null
+++ b/python/paddle/optimizer/adagrad.py
@@ -0,0 +1,136 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .optimizer import Optimizer
+from ..fluid import core
+from ..fluid import framework
+from ..fluid.framework import Variable
+
+__all__ = ["Adagrad"]
+
+
+class Adagrad(Optimizer):
+    """
+    The Adaptive Gradient optimizer (Adagrad for short) uses an optimization described
+    in the paper: `Adaptive Subgradient Methods for Online Learning and
+    Stochastic Optimization <http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf>`_.
+
+    The update rule for parameter ``param_out`` with gradient ``grad`` is:
+
+    .. math::
+
+        moment\_out &= moment + grad * grad
+
+        param\_out &= param - \\frac{learning\_rate * grad}{\sqrt{moment\_out} + \epsilon}
+
+    The original paper does not have the ``epsilon`` attribute. It is added here
+    in our implementation, as also proposed in `Per-parameter adaptive learning rate
+    methods <http://cs231n.github.io/neural-networks-3/#ada>`_,
+    for numerical stability to avoid the division-by-zero error.
+
+    Args:
+        learning_rate (float|Tensor): The learning rate used to update ``Parameter``.
+            It can be a float value or a ``Tensor`` with a float type.
+        epsilon (float, optional): A small float value for numerical stability.
+            The default value is 1e-06.
+        parameters (list, optional): List of ``Tensor`` to update to minimize ``loss``. \
+            This parameter is required in dygraph mode. \
+            The default value is None in static mode, at this time all parameters will be updated.
+        weight_decay (float|WeightDecayRegularizer, optional): The strategy of regularization. \
+            It can be a float value as the coefficient of L2 regularization or \
+            :ref:`api_fluid_regularizer_L1Decay`, :ref:`api_fluid_regularizer_L2Decay`.
+            If a parameter has already set a regularizer using :ref:`api_fluid_ParamAttr`, \
+            the regularization setting here in the optimizer will be ignored for this parameter. \
+            Otherwise, the regularization setting here in the optimizer will take effect. \
+            Default None, meaning there is no regularization.
+        grad_clip (GradientClipBase, optional): Gradient clipping strategy, it's an instance of
+            some derived class of ``GradientClipBase`` . There are three clipping strategies,
+            ClipGradByGlobalNorm, ClipGradByNorm and ClipGradByValue. Default None,
+            meaning there is no gradient clipping.
+        name (str, optional): Normally there is no need for the user to set this property.
+            For more information, please refer to :ref:`api_guide_Name`.
+            The default value is None.
+        initial_accumulator_value (float, optional): Initial value for the moment accumulator.
+            The default value is 0.0.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            import numpy as np
+
+            paddle.disable_static()
+            inp = paddle.rand(shape=[10, 10])
+            linear = paddle.nn.Linear(10, 10)
+            out = linear(inp)
+            loss = paddle.mean(out)
+            adagrad = paddle.optimizer.Adagrad(learning_rate=0.1,
+                    parameters=linear.parameters())
+            out.backward()
+            adagrad.step()
+            adagrad.clear_grad()
+
+    """
+    _moment_acc_str = "moment"
+
+    def __init__(self,
+                 learning_rate,
+                 epsilon=1.0e-6,
+                 parameters=None,
+                 weight_decay=None,
+                 grad_clip=None,
+                 name=None,
+                 initial_accumulator_value=0.0):
+        assert learning_rate is not None
+        assert epsilon is not None
+        super(Adagrad, self).__init__(
+            learning_rate=learning_rate,
+            parameters=parameters,
+            weight_decay=weight_decay,
+            grad_clip=grad_clip,
+            name=name)
+        self.type = "adagrad"
+        self._epsilon = epsilon
+        self.initial_accumulator_value = initial_accumulator_value
+
+    def _create_accumulators(self, block, parameters):
+        assert isinstance(block, framework.Block)
+
+        for p in parameters:
+            self._add_accumulator(
+                self._moment_acc_str,
+                p,
+                fill_value=self.initial_accumulator_value)
+
+    def _append_optimize_op(self, block, param_and_grad):
+        assert isinstance(block, framework.Block)
+
+        moment_acc = self._get_accumulator(self._moment_acc_str,
+                                           param_and_grad[0])
+        # Create the adagrad optimizer op
+        adagrad_op = block.append_op(
+            type=self.type,
+            inputs={
+                "Param": param_and_grad[0],
+                "Grad": param_and_grad[1],
+                "Moment": moment_acc,
+                "LearningRate": self._create_param_lr(param_and_grad)
+            },
+            outputs={"ParamOut": param_and_grad[0],
+                     "MomentOut": moment_acc},
+            attrs={"epsilon": self._epsilon},
+            stop_gradient=True)
+
+        return adagrad_op
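For readers of this patch, the update rule documented in the new ``Adagrad`` docstring can be checked with a few lines of plain NumPy. This is a minimal illustrative sketch, not part of the patch itself; the function name ``adagrad_update`` and the toy tensors are hypothetical.

.. code-block:: python

    import numpy as np

    def adagrad_update(param, grad, moment, lr=0.1, epsilon=1e-6):
        # moment_out = moment + grad^2
        # param_out  = param - lr * grad / (sqrt(moment_out) + epsilon)
        moment_out = moment + grad * grad
        param_out = param - lr * grad / (np.sqrt(moment_out) + epsilon)
        return param_out, moment_out

    # Toy usage: the accumulated squared gradients shrink the
    # effective step size of each coordinate over time.
    param = np.zeros(3)
    moment = np.zeros(3)
    grad = np.array([0.5, -1.0, 2.0])
    for _ in range(3):
        param, moment = adagrad_update(param, grad, moment)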