|
|
|
@ -672,6 +672,8 @@ class CumSum(PrimitiveWithInfer):
|
|
|
|
|
def __infer__(self, x, axis):
|
|
|
|
|
cls_name = self.name
|
|
|
|
|
x_shp = x['shape']
|
|
|
|
|
if axis['value'] is None:
|
|
|
|
|
raise ValueError(f"For {self.name}, axis must be const.")
|
|
|
|
|
validator.check_value_type('axis', axis['value'], [int], cls_name)
|
|
|
|
|
valid_types = [mstype.uint8, mstype.int8, mstype.int32, mstype.float16, mstype.float32]
|
|
|
|
|
validator.check_tensor_type_same({'x': x['dtype']}, valid_types, cls_name)
|
|
|
|
@ -679,10 +681,6 @@ class CumSum(PrimitiveWithInfer):
|
|
|
|
|
'dtype': x['dtype'],
|
|
|
|
|
'value': None}
|
|
|
|
|
|
|
|
|
|
def infer_value(self, x, axis):
|
|
|
|
|
if axis is None:
|
|
|
|
|
raise ValueError(f"For {self.name}, axis must be const.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AddN(PrimitiveWithInfer):
|
|
|
|
|
"""
|
|
|
|
|