!13648 Add float64 support to cumsum

From: @peilin-wang
Reviewed-by: @liangchenghui,@wuxuejian
Signed-off-by: @liangchenghui
pull/13648/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 5c6f0ed2f4

@ -972,7 +972,7 @@ class CumSum(PrimitiveWithInfer):
if axis['value'] is None: if axis['value'] is None:
raise ValueError(f"For {self.name}, axis must be const.") raise ValueError(f"For {self.name}, axis must be const.")
validator.check_value_type('axis', axis['value'], [int], cls_name) validator.check_value_type('axis', axis['value'], [int], cls_name)
valid_dtypes = [mstype.uint8, mstype.int8, mstype.int32, mstype.float16, mstype.float32] valid_dtypes = [mstype.uint8, mstype.int8, mstype.int32, mstype.float16, mstype.float32, mstype.float64]
validator.check_tensor_dtype_valid('x', x['dtype'], valid_dtypes, cls_name) validator.check_tensor_dtype_valid('x', x['dtype'], valid_dtypes, cls_name)
return {'shape': x_shp, return {'shape': x_shp,
'dtype': x['dtype'], 'dtype': x['dtype'],

Loading…
Cancel
Save