From bbfab3ed7c0adb1967262d73e86b3fafcb02605b Mon Sep 17 00:00:00 2001
From: liangzelang
Date: Sat, 20 Jun 2020 09:57:45 +0800
Subject: [PATCH] fix some type cast bugs

---
 mindspore/nn/optim/adam.py         | 6 +++---
 mindspore/ops/operations/nn_ops.py | 4 ++--
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/mindspore/nn/optim/adam.py b/mindspore/nn/optim/adam.py
index 92cab56a05..0a9a00cda6 100755
--- a/mindspore/nn/optim/adam.py
+++ b/mindspore/nn/optim/adam.py
@@ -72,9 +72,9 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, grad
         update_with_lr = op_mul(lr, update)
         next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))

-        next_v = F.depend(next_v, F.assign(param, next_param))
-        next_v = F.depend(next_v, F.assign(m, next_m))
-        next_v = F.depend(next_v, F.assign(v, next_v))
+        next_v = F.depend(next_v, F.assign(param, op_cast(next_param, mstype.float16)))
+        next_v = F.depend(next_v, F.assign(m, op_cast(next_m, mstype.float16)))
+        next_v = F.depend(next_v, F.assign(v, op_cast(next_v, mstype.float16)))
         return next_v

diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index d65ac0a276..a5c1684fce 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -1544,9 +1544,9 @@ class ApplyMomentum(PrimitiveWithInfer):
         ('accumulation', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE,
          sig_dtype.T),
         ('learning_rate', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE,
-         sig_dtype.T),
+         sig_dtype.T1),
         ('gradient', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
-        ('momentum', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
+        ('momentum', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T2)
     )

     @prim_attr_register
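
The adam.py hunk fixes a dtype mismatch: the Adam update is computed in
float32 (param_fp32) for accuracy, while the parameter and the m/v state
are held in float16 under mixed precision, so the float32 results must be
cast back to float16 before F.assign writes them into the float16 slots.
Below is a minimal sketch of that pattern, assuming MindSpore's public
P.Cast and P.Assign primitives of this era; the tensors and values are
illustrative, not code from the patch:

    import numpy as np
    import mindspore.common.dtype as mstype
    from mindspore import Parameter, Tensor
    from mindspore.ops import operations as P

    op_cast = P.Cast()
    op_assign = P.Assign()

    # float16 parameter, as held by a mixed-precision network
    param = Parameter(Tensor(np.ones(3), mstype.float16), name="w")
    # float32 result of the update math (hypothetical value)
    next_param = Tensor(np.full(3, 0.5), mstype.float32)

    # op_assign(param, next_param) would write float32 data into a
    # float16 parameter; casting first matches the dtypes, as the
    # patch does:
    op_assign(param, op_cast(next_param, mstype.float16))

The nn_ops.py hunk relaxes ApplyMomentum's dtype signature: with
sig_dtype.T shared by every input, learning_rate and momentum were forced
into the same dtype group as the variable and gradient. Giving them their
own groups (T1, T2) allows, for example, a float32 scalar learning rate to
drive float16 weights without a signature conflict.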