|
|
|
@ -160,6 +160,17 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def set_gradient_clip(clip, param_list=None, program=None):
|
|
|
|
|
"""
|
|
|
|
|
To specify parameters that require gradient clip.
|
|
|
|
|
Args:
|
|
|
|
|
clip(BaseGradientClipAttr): An instance of some derived class of BaseGradientClipAttr,
|
|
|
|
|
which describes the type and detailed attributes of required gradient clip.
|
|
|
|
|
param_list(list, None by default): Parameters that require gradient clip.
|
|
|
|
|
It can be a list of parameter or a list of parameter's name.
|
|
|
|
|
When it's None, all parameters in the program will be included.
|
|
|
|
|
program(Program, None by default): The program where parameters are.
|
|
|
|
|
Will be the default main program when assigned with None.
|
|
|
|
|
"""
|
|
|
|
|
if not isinstance(clip, BaseGradientClipAttr):
|
|
|
|
|
raise TypeError(
|
|
|
|
|
"'clip' should be an instance of BaseGradientClipAttr's derived class"
|
|
|
|
|