enhance regularizer.py

shanyi15-patch-2
chengduoZH 7 years ago
parent 0d49b92140
commit 74523c41f1

@@ -13,6 +13,7 @@
 # limitations under the License.

 import framework
+from . import core

 __all__ = [
     'append_regularization_ops',
@@ -46,9 +47,9 @@ def append_regularization_ops(parameters_and_grads, regularization=None):
         regularization_term = None
         if param.regularizer is not None:
             # Add variable for regularization term in grad block
-            regularization_term = param.regularizer(param, grad.block)
+            regularization_term = param.regularizer(param, grad, grad.block)
         elif regularization is not None:
-            regularization_term = regularization(param, grad.block)
+            regularization_term = regularization(param, grad, grad.block)

         # If no gradient or no regularization specified,
         # then we don't need to do anything
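With this change, append_regularization_ops hands the gradient itself to the regularizer callable, so a regularizer can check grad.type and emit a sparse decay term for sparse gradients. For context, a minimal usage sketch (hypothetical names: params_grads stands for the optimizer's list of (param, grad) pairs, and the function is assumed to return the augmented list):

    # Apply a global L2 penalty to every parameter that does not carry
    # its own param.regularizer; parameters that do carry one keep it,
    # per the branch above.
    params_grads = append_regularization_ops(
        params_grads,
        regularization=L2DecayRegularizer(regularization_coeff=1e-4))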
@@ -82,7 +83,7 @@ class WeightDecayRegularizer(object):
     def __init__(self):
         pass

-    def __call__(self, param, block):
+    def __call__(self, param, grad, block):
         """Add corresponding weight decay operations to the network
         """
         raise NotImplementedError()
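Concrete subclasses must now implement the three-argument __call__. As a sketch of the contract (hypothetical class, assuming this module's framework/core imports and the 'scale' op that the built-in regularizers use):

    class ConstantDecayRegularizer(WeightDecayRegularizer):
        # Hypothetical example: decay term = coeff * param, dense path only.
        def __init__(self, coeff=0.0):
            super(ConstantDecayRegularizer, self).__init__()
            self._coeff = coeff

        def __call__(self, param, grad, block):
            # A full implementation would also branch on
            # grad.type == core.VarDesc.VarType.SELECTED_ROWS,
            # as the built-in L1/L2 regularizers below do.
            decay = block.create_var(
                dtype="float32", shape=param.shape, lod_level=param.lod_level)
            block.append_op(
                type='scale',
                inputs={"X": param},
                outputs={"Out": decay},
                attrs={"scale": self._coeff})
            return decay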
@@ -102,7 +103,7 @@ class L2DecayRegularizer(WeightDecayRegularizer):
         super(L2DecayRegularizer, self).__init__()
         self._regularization_coeff = regularization_coeff

-    def __call__(self, param, block):
+    def __call__(self, param, grad, block):
         """Add L2 weight decay ops to network

         Adds L2 weight decay ops.
@@ -117,8 +118,23 @@ class L2DecayRegularizer(WeightDecayRegularizer):
         """
         assert isinstance(param, framework.Parameter)
         assert isinstance(block, framework.Block)

         decay = block.create_var(
             dtype="float32", shape=param.shape, lod_level=param.lod_level)

+        if grad.type == core.VarDesc.VarType.SELECTED_ROWS:
+            decay = block.create_var(
+                dtype="float32",
+                shape=param.shape,
+                type=core.VarDesc.VarType.SELECTED_ROWS)
+            block.append_op(
+                type='lookup_table',
+                inputs={'W': param,
+                        'Ids': grad},
+                outputs={'Out': decay},
+                attrs={'is_sparse': True})
+            param = decay
+
         # Append Op to calculate decay
         block.append_op(
             type='scale',
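When the gradient is a SELECTED_ROWS variable (as produced by sparse lookups), decaying the whole dense parameter would turn a sparse update dense. The new branch instead gathers only the parameter rows named by the gradient's ids and scales those. A rough NumPy analogue of what the emitted ops compute at runtime (illustrative only; grad_ids stands for the row ids carried by the sparse gradient):

    import numpy as np

    def sparse_l2_decay(param, grad_ids, coeff):
        # 'lookup_table' with is_sparse=True: gather the touched rows.
        gathered = param[grad_ids]
        # 'scale': the L2 decay term for just those rows.
        return coeff * gathered  # one decay row per id in grad_ids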
@@ -141,7 +157,7 @@ class L1DecayRegularizer(WeightDecayRegularizer):
         super(L1DecayRegularizer, self).__init__()
         self._regularization_coeff = regularization_coeff

-    def __call__(self, param, block):
+    def __call__(self, param, grad, block):
         """Add L1 weight decay ops to network

         Adds L1 weight decay ops.
@@ -158,6 +174,20 @@ class L1DecayRegularizer(WeightDecayRegularizer):
         assert isinstance(block, framework.Block)

         decay = block.create_var(
             dtype="float32", shape=param.shape, lod_level=param.lod_level)

+        if grad.type == core.VarDesc.VarType.SELECTED_ROWS:
+            # Gather the rows addressed by the sparse gradient's ids, as in
+            # the L2 case, so the 'sign' op runs on the gathered rows only.
+            decay = block.create_var(
+                dtype="float32",
+                shape=param.shape,
+                type=core.VarDesc.VarType.SELECTED_ROWS)
+            block.append_op(
+                type='lookup_table',
+                inputs={'W': param,
+                        'Ids': grad},
+                outputs={'Out': decay},
+                attrs={'is_sparse': True})
+            param = decay
+
         # Append sign op
         block.append_op(
             type='sign', inputs={"X": param}, outputs={"Out": decay})
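The L1 branch mirrors the L2 one: after gathering, the 'sign' op replaces each gathered row with its elementwise sign and the subsequent 'scale' op multiplies by the coefficient, so the decay term for a sparse gradient is coeff * sign(param[ids]) on the touched rows only. In the same NumPy shorthand as above (illustrative, with grad_ids again hypothetical):

    import numpy as np

    def sparse_l1_decay(param, grad_ids, coeff):
        # 'lookup_table' gathers, 'sign' takes the elementwise sign,
        # 'scale' multiplies by the regularization coefficient.
        return coeff * np.sign(param[grad_ids])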
