@@ -13,19 +13,16 @@
 # limitations under the License.
 
 from __future__ import print_function
+import logging
 
 from . import framework
 from .framework import in_dygraph_mode, _varbase_creator
 from . import core
-import logging
 
 __all__ = ['L1Decay', 'L2Decay', 'L1DecayRegularizer', 'L2DecayRegularizer']
 
 
-def _create_regularization_of_grad(param,
-                                   grad,
-                                   regularization=None,
-                                   _repeat_regularizer=None):
+def _create_regularization_of_grad(param, grad, regularization=None):
     """ Create and add backward regularization Operators
 
     Function helper of append_regularization_ops.
@@ -35,8 +32,6 @@ def _create_regularization_of_grad(param,
         return grad
     regularization_term = None
     if param.regularizer is not None:
-        if regularization is not None:
-            _repeat_regularizer.append(param.name)
         # Add variable for regularization term in grad block
         regularization_term = param.regularizer(param, grad, grad.block)
     elif regularization is not None:
@@ -89,25 +84,25 @@ def append_regularization_ops(parameters_and_grads, regularization=None):
         Exception: Unknown regularization type
     """
     params_and_grads = []
-    _repeat_regularizer = []
     if in_dygraph_mode():
         for param, grad in parameters_and_grads:
-            new_grad = _create_regularization_of_grad(
-                param, grad, regularization, _repeat_regularizer)
+            new_grad = _create_regularization_of_grad(param, grad,
+                                                      regularization)
             params_and_grads.append((param, new_grad))
     else:
+        repeate_regularizer = False
         with framework.name_scope('regularization'):
             for param, grad in parameters_and_grads:
+                if not repeate_regularizer and param.regularizer is not None and regularization is not None:
+                    repeate_regularizer = True
+                    logging.info(
+                        "If regularizer of a Parameter has been set by 'fluid.ParamAttr' or 'fluid.WeightNormParamAttr' already. "
+                        "The Regularization[%s] in Optimizer will not take effect, and it will only be applied to other Parameters!"
+                        % regularization.__str__())
                 with param.block.program._optimized_guard([param, grad]):
-                    new_grad = _create_regularization_of_grad(
-                        param, grad, regularization, _repeat_regularizer)
+                    new_grad = _create_regularization_of_grad(param, grad,
+                                                              regularization)
                     params_and_grads.append((param, new_grad))
-    if len(_repeat_regularizer) > 0:
-        param_name_strlist = ", ".join(_repeat_regularizer)
-        logging.info(
-            "Regularization of [%s] have been set by ParamAttr or WeightNormParamAttr already. "
-            "So, the Regularization of Optimizer will not take effect for these parameters!"
-            % param_name_strlist)
     return params_and_grads
 
 
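Both versions of the message describe the same precedence rule: a regularizer set on a parameter through `ParamAttr` overrides the optimizer-level `regularization` argument, which then applies only to the remaining parameters. A minimal sketch of the situation that triggers the INFO log, assuming the `paddle.fluid` 1.x static-graph API (the layer and parameter names here are illustrative, not from the patch):

```python
import paddle.fluid as fluid

x = fluid.data(name='x', shape=[None, 13], dtype='float32')
y = fluid.data(name='y', shape=[None, 1], dtype='float32')

# Per-parameter regularizer set through ParamAttr; this one wins for fc_w.
w_attr = fluid.ParamAttr(
    name='fc_w', regularizer=fluid.regularizer.L1Decay(0.005))
pred = fluid.layers.fc(input=x, size=1, param_attr=w_attr)
loss = fluid.layers.mean(fluid.layers.square_error_cost(pred, y))

# Optimizer-level regularization: applied only to parameters that did not
# set their own (e.g. the fc bias). append_regularization_ops logs the
# INFO message because fc_w's own regularizer takes precedence.
sgd = fluid.optimizer.SGD(
    learning_rate=0.01,
    regularization=fluid.regularizer.L2Decay(regularization_coeff=0.001))
sgd.minimize(loss)
```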