add layer norm to Layers, add transformer test in imperative mode (#16092)

* add layer norm to Layers, add transformer prepare encoding

* little change

* finish encoder part

* add decoder part

* finish model part

* add test case and part of data feed

* add transformer test

* add to_parameter, add remove in set_attr

* test=develop, fix pos encoding bug, create_parameter with stantard name

* test=develop, rm dropout test in imperative

* test=develop, fix cpu error

* test=develop, fix minize bug

* test=develop, fix one hot not stop gradient

* test=develop, fix one hot not stop gradient

* test=develop, refine parameter name

* test=develop, fix transformer test in imperative mode

* test=develop, fix transformer test in imperative mode

* test=develop, fix boost and mkl download error

* test=develop, fix boost and mkl download error

* test=develop, fix ci and refine code

* test=develop, fix ci and refine code
move-code
Jiabin Yang 7 years ago committed by GitHub
parent 0fff666f14
commit f735102eab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -315,6 +315,9 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
for (size_t i = 0; i < outputs.size(); ++i) { for (size_t i = 0; i < outputs.size(); ++i) {
framework::Variable* grad = outputs[i]->var_; framework::Variable* grad = outputs[i]->var_;
framework::Variable* orig_grad = origin_outputs[i]->var_; framework::Variable* orig_grad = origin_outputs[i]->var_;
VLOG(3) << "AddTo Called with orig_grad is: "
<< origin_outputs[i]->name_ << " Grad to be added is "
<< outputs[i]->name_;
AddTo(grad, orig_grad, place_); AddTo(grad, orig_grad, place_);
delete grad; delete grad;
} }

@ -277,6 +277,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
VarBase* var = current_vars_map[var_it->second]; VarBase* var = current_vars_map[var_it->second];
InitGrad(var, prepared_op.GetDeviceContext()); InitGrad(var, prepared_op.GetDeviceContext());
grad_out_vars.push_back(var->grads_); grad_out_vars.push_back(var->grads_);
VLOG(3) << "grads output var name: " << var->name_;
} }
} }
} }

@ -15,7 +15,7 @@
WMT14 dataset. WMT14 dataset.
The original WMT14 dataset is too large and a small set of data for set is The original WMT14 dataset is too large and a small set of data for set is
provided. This module will download dataset from provided. This module will download dataset from
http://paddlepaddle.cdn.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz and http://paddlepaddle.bj.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz and
parse training set and test set into paddle reader creators. parse training set and test set into paddle reader creators.
""" """

@ -44,7 +44,7 @@ def guard(place=None):
yield yield
def to_variable(value, block=None): def to_variable(value, block=None, name=None):
if isinstance(value, np.ndarray): if isinstance(value, np.ndarray):
assert enabled(), "to_variable could only be called in imperative mode" assert enabled(), "to_variable could only be called in imperative mode"
@ -53,7 +53,7 @@ def to_variable(value, block=None):
py_var = framework.Variable( py_var = framework.Variable(
block, block,
type=core.VarDesc.VarType.LOD_TENSOR, type=core.VarDesc.VarType.LOD_TENSOR,
name=None, name=name,
shape=value.shape, shape=value.shape,
dtype=value.dtype) dtype=value.dtype)
var = py_var._ivar.value() var = py_var._ivar.value()

@ -105,6 +105,7 @@ class LayerObjectHelper(LayerHelperBase):
Returns dtype of the input Returns dtype of the input
""" """
inputs_in = inputs_in if (inputs_in is not None) else []
inputs = self._multiple_input(inputs_in) inputs = self._multiple_input(inputs_in)
dtype = None dtype = None
for each in inputs: for each in inputs:

@ -17,10 +17,12 @@ import contextlib
import sys import sys
import numpy as np import numpy as np
import collections import collections
import six
from .. import unique_name from .. import unique_name
from paddle.fluid import core from paddle.fluid import core
from .layer_object_helper import LayerObjectHelper from .layer_object_helper import LayerObjectHelper
from paddle.fluid import framework from paddle.fluid import framework
from ..param_attr import ParamAttr
__all__ = ['Layer', 'PyLayer'] __all__ = ['Layer', 'PyLayer']
@ -72,6 +74,10 @@ class Layer(core.Layer):
Returns created parameter Variable. Returns created parameter Variable.
""" """
if isinstance(attr, ParamAttr) and (attr.name is not None):
attr.name = ".".join([self._full_name, attr.name])
elif isinstance(attr, six.string_types):
attr = ".".join([self._full_name, attr])
return self._helper.create_parameter(attr, shape, dtype, is_bias, return self._helper.create_parameter(attr, shape, dtype, is_bias,
default_initializer) default_initializer)
@ -164,6 +170,7 @@ class Layer(core.Layer):
the sublayer passed in. the sublayer passed in.
""" """
assert isinstance(sublayer, core.Layer) assert isinstance(sublayer, core.Layer)
self._sub_layers[name] = sublayer self._sub_layers[name] = sublayer
return sublayer return sublayer

@ -20,10 +20,12 @@ from .. import core
from ..layers import utils from ..layers import utils
from . import layers from . import layers
from ..framework import Variable, OpProtoHolder from ..framework import Variable, OpProtoHolder
from ..layers import layer_function_generator
from ..param_attr import ParamAttr from ..param_attr import ParamAttr
from ..initializer import Normal, Constant from ..initializer import Normal, Constant
__all__ = [
__all__ = ['Conv2D', 'Pool2D', 'FC', 'BatchNorm', 'Embedding', 'GRUUnit'] 'Conv2D', 'Pool2D', 'FC', 'BatchNorm', 'Embedding', 'GRUUnit', 'LayerNorm'
]
class Conv2D(layers.Layer): class Conv2D(layers.Layer):
@ -438,7 +440,6 @@ class Embedding(layers.Layer):
self._size = size self._size = size
self._is_sparse = is_sparse self._is_sparse = is_sparse
self._is_distributed = is_distributed self._is_distributed = is_distributed
self._padding_idx = -1 if padding_idx is None else padding_idx if padding_idx >= 0 else ( self._padding_idx = -1 if padding_idx is None else padding_idx if padding_idx >= 0 else (
size[0] + padding_idx) size[0] + padding_idx)
@ -471,6 +472,131 @@ class Embedding(layers.Layer):
return out return out
class LayerNorm(layers.Layer):
def __init__(self,
name_scope,
scale=True,
shift=True,
begin_norm_axis=1,
epsilon=1e-05,
param_attr=None,
bias_attr=None,
act=None):
"""
${comment}
The formula is as follows:
.. math::
\\mu & = \\frac{1}{H}\\sum_{i=1}^{H} a_i
\\sigma & = \\sqrt{\\frac{1}{H}\sum_{i=1}^{H}(a_i - \\mu)^2}
h & = f(\\frac{g}{\\sigma}(a - \\mu) + b)
* :math:`a`: the vector representation of the summed inputs to the neurons
in that layer.
* :math:`H`: the number of hidden units in a layers
* :math:`g`: the trainable scale parameter.
* :math:`b`: the trainable bias parameter.
Args:
input(Variable): The input tensor variable.
scale(bool): Whether to learn the adaptive gain :math:`g` after
normalization. Default True.
shift(bool): Whether to learn the adaptive bias :math:`b` after
normalization. Default True.
begin_norm_axis(int): The normalization will be performed along
dimensions from :attr:`begin_norm_axis` to :attr:`rank(input)`.
Default 1.
epsilon(float): The small value added to the variance to prevent
division by zero. Default 1e-05.
param_attr(ParamAttr|None): The parameter attribute for the learnable
gain :math:`g`. If :attr:`scale` is False, :attr:`param_attr` is
omitted. If :attr:`scale` is True and :attr:`param_attr` is None,
a default :code:`ParamAttr` would be added as scale. The
:attr:`param_attr` is initialized as 1 if it is added. Default None.
bias_attr(ParamAttr|None): The parameter attribute for the learnable
bias :math:`b`. If :attr:`shift` is False, :attr:`bias_attr` is
omitted. If :attr:`shift` is True and :attr:`param_attr` is None,
a default :code:`ParamAttr` would be added as bias. The
:attr:`bias_attr` is initialized as 0 if it is added. Default None.
act(str): Activation to be applied to the output of layer normalizaiton.
Default None.
Returns:
${y_comment}
Examples:
>>> data = fluid.layers.data(name='data', shape=[3, 32, 32],
>>> dtype='float32')
>>> x = fluid.layers.layer_norm(input=data, begin_norm_axis=1)
"""
super(LayerNorm, self).__init__(name_scope)
self._scale = scale
self._shift = shift
self._begin_norm_axis = begin_norm_axis
self._epsilon = epsilon
self._param_attr = param_attr
self._bias_attr = bias_attr
self._act = act
def _build_once(self, input):
self._dtype = self._helper.input_dtype(input)
input_shape = input.shape
param_shape = [
reduce(lambda x, y: x * y, input_shape[self._begin_norm_axis:])
]
if self._scale:
self._scale_w = self.create_parameter(
attr=self._param_attr,
shape=param_shape,
dtype=self._dtype,
default_initializer=Constant(1.0))
if self._shift:
assert self._bias_attr is not False
self._bias_w = self.create_parameter(
attr=self._bias_attr,
shape=param_shape,
dtype=self._dtype,
is_bias=True)
def forward(self, input):
inputs = dict()
inputs['X'] = input
if self._scale:
inputs['Scale'] = self._scale_w
if self._shift:
inputs['Bias'] = self._bias_w
# create output
mean_out = self._helper.create_variable_for_type_inference(
dtype=self._dtype, stop_gradient=True)
variance_out = self._helper.create_variable_for_type_inference(
dtype=self._dtype, stop_gradient=True)
layer_norm_out = self._helper.create_variable_for_type_inference(
self._dtype)
self._helper.append_op(
type="layer_norm",
inputs=inputs,
outputs={
"Y": layer_norm_out,
"Mean": mean_out,
"Variance": variance_out,
},
attrs={
"epsilon": self._epsilon,
"begin_norm_axis": self._begin_norm_axis
})
return self._helper.append_activation(layer_norm_out)
class GRUUnit(layers.Layer): class GRUUnit(layers.Layer):
""" """
**GRU unit layer** **GRU unit layer**

@ -268,11 +268,9 @@ class LayerHelperBase(object):
""" """
# Deepcopy the attr so that parameters can be shared in program # Deepcopy the attr so that parameters can be shared in program
attr = copy.deepcopy(attr) attr = copy.deepcopy(attr)
if attr is None: attr = ParamAttr._to_attr(attr)
attr = ParamAttr._to_attr(attr)
if not attr: if not attr:
return None return None
assert isinstance(attr, ParamAttr) assert isinstance(attr, ParamAttr)
suffix = 'b' if is_bias else 'w' suffix = 'b' if is_bias else 'w'
if attr.name is None: if attr.name is None:

@ -6206,7 +6206,8 @@ def one_hot(input, depth):
type="one_hot", type="one_hot",
inputs={'X': input}, inputs={'X': input},
attrs={'depth': depth}, attrs={'depth': depth},
outputs={'Out': one_hot_out}) outputs={'Out': one_hot_out},
stop_gradient=True)
return one_hot_out return one_hot_out

@ -397,13 +397,14 @@ class Optimizer(object):
for param in parameters: for param in parameters:
if not param.trainable: if not param.trainable:
continue continue
# create gradient variable if param._ivar._grad_ivar() is not None:
grad_var = Variable( # create gradient variable
block=loss.block, grad_var = Variable(
name=param._ivar._grad_name(), block=loss.block,
stop_gradient=True, name=param._ivar._grad_name(),
ivar=param._ivar._grad_ivar()) stop_gradient=True,
params_grads.append((param, grad_var)) ivar=param._ivar._grad_ivar())
params_grads.append((param, grad_var))
with program_guard(framework.default_main_program(), with program_guard(framework.default_main_program(),
framework.default_startup_program()): framework.default_startup_program()):
optimize_ops = self._create_optimization_pass(params_grads) optimize_ops = self._create_optimization_pass(params_grads)

@ -70,6 +70,34 @@ class LayerTest(unittest.TestCase):
class TestLayer(LayerTest): class TestLayer(LayerTest):
def test_layer_norm(self):
inp = np.ones([3, 32, 32], dtype='float32')
with self.static_graph():
t = layers.data(
name='data',
shape=[3, 32, 32],
dtype='float32',
append_batch_size=False)
ret = layers.layer_norm(t)
static_ret = self.get_static_graph_result(
feed={'data': inp}, fetch_list=[ret])[0]
with self.static_graph():
t = layers.data(
name='data',
shape=[3, 32, 32],
dtype='float32',
append_batch_size=False)
lm = nn.LayerNorm('layer_norm')
ret = lm(t)
static_ret2 = self.get_static_graph_result(
feed={'data': inp}, fetch_list=[ret])[0]
with self.dynamic_graph():
lm = nn.LayerNorm('layer_norm')
dy_ret = lm(base.to_variable(inp))
self.assertTrue(np.allclose(static_ret, static_ret2))
self.assertTrue(np.allclose(dy_ret._numpy(), static_ret2))
def test_relu(self): def test_relu(self):
with self.static_graph(): with self.static_graph():
t = layers.data(name='t', shape=[3, 3], dtype='float32') t = layers.data(name='t', shape=[3, 3], dtype='float32')

Loading…
Cancel
Save