!1986 fixed validator for CumSum

Merge pull request !1986 from jiangjinsheng/issue_fix2
pull/1986/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit af85b2cebf

@ -1001,15 +1001,16 @@ def get_bprop_bessel_i1e(self):
reciprocal = P.Reciprocal() reciprocal = P.Reciprocal()
cast = P.Cast() cast = P.Cast()
dtype = P.DType() dtype = P.DType()
abs_ops = P.Abs()
def bprop(x, out, dout): def bprop(x, out, dout):
zeros = zeros_like(x) zeros = zeros_like(x)
np_eps = const_utils.get_np_eps(dtype(x)) np_eps = const_utils.get_np_eps(dtype(x))
eps = cast(np_eps, dtype(x)) eps = cast(np_eps, dtype(x))
x_is_valid = less(eps, x) x_is_valid = less(eps, abs_ops(x))
x_safe = select(x_is_valid, x, eps + zeros) x_safe = select(x_is_valid, x, eps + zeros)
tmp = bessel_i0e(x_safe) - out * (sign(x) + reciprocal(x_safe)) tmp = bessel_i0e(x_safe) - out * (sign(x_safe) + reciprocal(x_safe))
dx = select(x_is_valid, tmp, 0.5 + zeros) dx = select(x_is_valid, tmp, cast(0.5, dtype(x)) + zeros) * dout
return (dx,) return (dx,)
return bprop return bprop

@ -672,6 +672,8 @@ class CumSum(PrimitiveWithInfer):
def __infer__(self, x, axis): def __infer__(self, x, axis):
cls_name = self.name cls_name = self.name
x_shp = x['shape'] x_shp = x['shape']
if axis['value'] is None:
raise ValueError(f"For {self.name}, axis must be const.")
validator.check_value_type('axis', axis['value'], [int], cls_name) validator.check_value_type('axis', axis['value'], [int], cls_name)
valid_types = [mstype.uint8, mstype.int8, mstype.int32, mstype.float16, mstype.float32] valid_types = [mstype.uint8, mstype.int8, mstype.int32, mstype.float16, mstype.float32]
validator.check_tensor_type_same({'x': x['dtype']}, valid_types, cls_name) validator.check_tensor_type_same({'x': x['dtype']}, valid_types, cls_name)
@ -679,10 +681,6 @@ class CumSum(PrimitiveWithInfer):
'dtype': x['dtype'], 'dtype': x['dtype'],
'value': None} 'value': None}
def infer_value(self, x, axis):
if axis is None:
raise ValueError(f"For {self.name}, axis must be const.")
class AddN(PrimitiveWithInfer): class AddN(PrimitiveWithInfer):
""" """

@ -1767,9 +1767,6 @@ class ApplyRMSProp(PrimitiveWithInfer):
def infer_value(self, var, mean_square, moment, learning_rate, grad, decay, momentum, epsilon): def infer_value(self, var, mean_square, moment, learning_rate, grad, decay, momentum, epsilon):
if decay is None or momentum is None or epsilon is None: if decay is None or momentum is None or epsilon is None:
raise ValueError(f"For {self.name}, decay, momentum, epsilon must be const.") raise ValueError(f"For {self.name}, decay, momentum, epsilon must be const.")
if not self.is_ge and self.is_d:
return None, None, None
return None
class ApplyCenteredRMSProp(PrimitiveWithInfer): class ApplyCenteredRMSProp(PrimitiveWithInfer):

Loading…
Cancel
Save