|
|
|
@ -13,7 +13,7 @@
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
import functools
|
|
|
|
|
import layers
|
|
|
|
|
from framework import Variable
|
|
|
|
|
import framework
|
|
|
|
|
from . import core
|
|
|
|
|
|
|
|
|
|
__all__ = [
|
|
|
|
@ -128,8 +128,8 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def check_init(cls):
|
|
|
|
|
if not (isinstance(cls.global_norm_var, Variable) and
|
|
|
|
|
isinstance(cls.clip_norm_var, Variable)):
|
|
|
|
|
if not (isinstance(cls.global_norm_var, framework.Variable) and
|
|
|
|
|
isinstance(cls.clip_norm_var, framework.Variable)):
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"Class 'GradientClipByGlobalNorm' has not been properly initialized. \
|
|
|
|
|
Please call GradientClipByGlobalNorm.init() first.")
|
|
|
|
@ -158,6 +158,23 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
|
|
|
|
|
return param, new_grad
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def gradient_clip_by_global_norm(clip_norm, param_list=None, program=None):
    """Attach a ``GradientClipByGlobalNorm`` attribute to a set of parameters.

    Args:
        clip_norm: Value forwarded unchanged to
            ``GradientClipByGlobalNorm.init``.
        param_list: The parameters to mark. Either ``framework.Parameter``
            objects, or parameter names (``basestring``) that are looked up
            in block 0 of ``program``. When ``None``, every parameter
            returned by ``program.block(0).all_parameters()`` is used.
        program: The program whose block 0 is consulted. When ``None``,
            ``framework.default_main_program()`` is used.

    Raises:
        TypeError: If, after name resolution, any element of ``param_list``
            is not a ``framework.Parameter``.
    """
    target_program = (framework.default_main_program()
                      if program is None else program)

    params = param_list
    if params is None:
        params = target_program.block(0).all_parameters()

    # Names are only resolved when *every* element is a string; a mixed
    # collection is left untouched and rejected by the type check below.
    if not any(not isinstance(item, basestring) for item in params):
        params = [target_program.block(0).var(item) for item in params]

    if any(not isinstance(item, framework.Parameter) for item in params):
        raise TypeError(
            "'param_list' should be a list of Parameter or basestring(parameter's name)."
        )

    # Class-level initialisation happens once, before any parameter is
    # tagged with its own clip-attribute instance.
    GradientClipByGlobalNorm.init(clip_norm)
    for target in params:
        target.gradient_clip_attr = GradientClipByGlobalNorm()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def append_gradient_clip_ops(param_grad):
|
|
|
|
|
context = dict()
|
|
|
|
|
create_op_callbacks = []
|
|
|
|
|