|
|
|
@ -33,6 +33,7 @@ from .. import signature as sig
|
|
|
|
from ..._checkparam import Rel
|
|
|
|
from ..._checkparam import Rel
|
|
|
|
from ..._checkparam import Validator as validator
|
|
|
|
from ..._checkparam import Validator as validator
|
|
|
|
from ...common import dtype as mstype
|
|
|
|
from ...common import dtype as mstype
|
|
|
|
|
|
|
|
from ...common._decorator import deprecated
|
|
|
|
from ...common.parameter import Parameter
|
|
|
|
from ...common.parameter import Parameter
|
|
|
|
from ...common.tensor import Tensor
|
|
|
|
from ...common.tensor import Tensor
|
|
|
|
|
|
|
|
|
|
|
|
@ -820,10 +821,29 @@ class Gather(PrimitiveWithCheck):
|
|
|
|
validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name)
|
|
|
|
validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def GatherV2():
|
|
|
|
class GatherV2(PrimitiveWithCheck):
|
|
|
|
"""Warning: This will be changed later"""
|
|
|
|
"""
|
|
|
|
logger.warning("WARN_DEPRECATED: The usage of GatherV2 is deprecated. Please use Gather.")
|
|
|
|
Same as operator Gather. GatherV2 will be deprecated in the future.
|
|
|
|
return Gather()
|
|
|
|
Please use Gather instead.
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
#deprecate_new_name = "Gather"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@deprecated("1.1", "Gather", True)
|
|
|
|
|
|
|
|
@prim_attr_register
|
|
|
|
|
|
|
|
def __init__(self):
|
|
|
|
|
|
|
|
"""Initialize index_select"""
|
|
|
|
|
|
|
|
self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
|
|
|
|
|
|
|
|
self.add_prim_attr("dynamic_shape_depends", [2])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __check__(self, params, indices, axis):
|
|
|
|
|
|
|
|
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
|
|
|
|
|
|
|
|
validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name)
|
|
|
|
|
|
|
|
validator.check_subclass("axis", axis['dtype'], [mstype.number], self.name)
|
|
|
|
|
|
|
|
axis_v = axis['value']
|
|
|
|
|
|
|
|
validator.check_value_type('axis', axis_v, [int], self.name)
|
|
|
|
|
|
|
|
rank = len(params['shape'])
|
|
|
|
|
|
|
|
validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SparseGatherV2(Gather):
|
|
|
|
class SparseGatherV2(Gather):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|