|
|
|
@ -12,14 +12,18 @@
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
|
|
import copy
|
|
|
|
|
|
|
|
|
|
import functools
|
|
|
|
|
import layers
|
|
|
|
|
import framework
|
|
|
|
|
from . import core
|
|
|
|
|
|
|
|
|
|
__all__ = [
|
|
|
|
|
'GradientClipByValue',
|
|
|
|
|
'ErrorClipByValue',
|
|
|
|
|
'GradientClipByValue',
|
|
|
|
|
'GradientClipByNorm',
|
|
|
|
|
'GradientClipByGlobalNorm',
|
|
|
|
|
'append_gradient_clip_ops',
|
|
|
|
|
'error_clip_callback',
|
|
|
|
|
]
|
|
|
|
@ -155,10 +159,11 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
|
|
|
|
|
return param, new_grad
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def gradient_clip_by_global_norm(clip_norm,
|
|
|
|
|
param_list=None,
|
|
|
|
|
group_name="default_group",
|
|
|
|
|
program=None):
|
|
|
|
|
def set_gradient_clip(clip, param_list=None, program=None):
|
|
|
|
|
if not isinstance(clip, BaseGradientClipAttr):
|
|
|
|
|
raise TypeError(
|
|
|
|
|
"'clip' should be an instance of BaseGradientClipAttr's derived class"
|
|
|
|
|
)
|
|
|
|
|
if program is None:
|
|
|
|
|
program = framework.default_main_program()
|
|
|
|
|
if param_list is None:
|
|
|
|
@ -171,8 +176,7 @@ def gradient_clip_by_global_norm(clip_norm,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
for param in param_list:
|
|
|
|
|
param.gradient_clip_attr = GradientClipByGlobalNorm(clip_norm,
|
|
|
|
|
group_name)
|
|
|
|
|
param.gradient_clip_attr = copy.deepcopy(clip)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def append_gradient_clip_ops(param_grad):
|
|
|
|
|