|
|
@ -593,7 +593,6 @@ class DynamicShape(Primitive):
|
|
|
|
"""init Shape"""
|
|
|
|
"""init Shape"""
|
|
|
|
self.init_prim_io_names(inputs=['tensor'], outputs=['output'])
|
|
|
|
self.init_prim_io_names(inputs=['tensor'], outputs=['output'])
|
|
|
|
self.add_prim_attr('is_dynamic_shape', True)
|
|
|
|
self.add_prim_attr('is_dynamic_shape', True)
|
|
|
|
self.add_prim_attr("dynamic_shape_depends", [0])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Squeeze(PrimitiveWithInfer):
|
|
|
|
class Squeeze(PrimitiveWithInfer):
|
|
|
@ -811,7 +810,6 @@ class Gather(PrimitiveWithCheck):
|
|
|
|
def __init__(self):
|
|
|
|
def __init__(self):
|
|
|
|
"""Initialize index_select"""
|
|
|
|
"""Initialize index_select"""
|
|
|
|
self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
|
|
|
|
self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
|
|
|
|
self.add_prim_attr("dynamic_shape_depends", [2])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __check__(self, params, indices, axis):
|
|
|
|
def __check__(self, params, indices, axis):
|
|
|
|
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
|
|
|
|
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
|
|
|
@ -836,7 +834,6 @@ class GatherV2(PrimitiveWithCheck):
|
|
|
|
def __init__(self):
|
|
|
|
def __init__(self):
|
|
|
|
"""Initialize index_select"""
|
|
|
|
"""Initialize index_select"""
|
|
|
|
self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
|
|
|
|
self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
|
|
|
|
self.add_prim_attr("dynamic_shape_depends", [2])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __check__(self, params, indices, axis):
|
|
|
|
def __check__(self, params, indices, axis):
|
|
|
|
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
|
|
|
|
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
|
|
|
@ -1987,7 +1984,6 @@ class UnsortedSegmentMin(PrimitiveWithCheck):
|
|
|
|
def __init__(self):
|
|
|
|
def __init__(self):
|
|
|
|
"""Initialize UnsortedSegmentMin"""
|
|
|
|
"""Initialize UnsortedSegmentMin"""
|
|
|
|
self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y'])
|
|
|
|
self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y'])
|
|
|
|
self.add_prim_attr("dynamic_shape_depends", [2])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __check__(self, x, segment_ids, num_segments):
|
|
|
|
def __check__(self, x, segment_ids, num_segments):
|
|
|
|
x_shape = x['shape']
|
|
|
|
x_shape = x['shape']
|
|
|
@ -2043,7 +2039,6 @@ class UnsortedSegmentMax(PrimitiveWithCheck):
|
|
|
|
def __init__(self):
|
|
|
|
def __init__(self):
|
|
|
|
"""Initialize UnsortedSegmentMax"""
|
|
|
|
"""Initialize UnsortedSegmentMax"""
|
|
|
|
self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y'])
|
|
|
|
self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y'])
|
|
|
|
self.add_prim_attr("dynamic_shape_depends", [2])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __check__(self, x, segment_ids, num_segments):
|
|
|
|
def __check__(self, x, segment_ids, num_segments):
|
|
|
|
x_shape = x['shape']
|
|
|
|
x_shape = x['shape']
|
|
|
@ -4980,10 +4975,6 @@ class Range(PrimitiveWithCheck):
|
|
|
|
self.maxlen = maxlen
|
|
|
|
self.maxlen = maxlen
|
|
|
|
self.add_prim_attr('maxlen', maxlen)
|
|
|
|
self.add_prim_attr('maxlen', maxlen)
|
|
|
|
|
|
|
|
|
|
|
|
self.add_prim_attr("dynamic_shape_depends", [0])
|
|
|
|
|
|
|
|
self.add_prim_attr("dynamic_shape_depends", [1])
|
|
|
|
|
|
|
|
self.add_prim_attr("dynamic_shape_depends", [2])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def check_shape(self, start_shape, limit_shape, delta_shape):
|
|
|
|
def check_shape(self, start_shape, limit_shape, delta_shape):
|
|
|
|
validator.check("start_shape", len(start_shape), "", 0, Rel.EQ, self.name)
|
|
|
|
validator.check("start_shape", len(start_shape), "", 0, Rel.EQ, self.name)
|
|
|
|
validator.check("limit_shape", len(limit_shape), "", 0, Rel.EQ, self.name)
|
|
|
|
validator.check("limit_shape", len(limit_shape), "", 0, Rel.EQ, self.name)
|
|
|
|