Add deprecated function

pull/12033/head
l00591931 5 years ago
parent cc58feebee
commit edbe3bfd3b

@ -15,12 +15,13 @@
"""Providing decorators.""" """Providing decorators."""
def deprecated(version, substitute): def deprecated(version, substitute, use_substitute_name=False):
"""deprecated warning """deprecated warning
Args: Args:
version (str): version that the operator or function is deprecated. version (str): version that the operator or function is deprecated.
substitute (str): the substitute name for deprecated operator or function. substitute (str): the substitute name for deprecated operator or function.
use_substitute_name (bool): flag for whether to use substitute name for deprecated operator or function
""" """
def decorate(func): def decorate(func):
@ -29,6 +30,8 @@ def deprecated(version, substitute):
name = cls.__name__ if cls else func.__name__ name = cls.__name__ if cls else func.__name__
print(f"WARNING: '{name}' is deprecated from version {version} and will be removed in a future version, " print(f"WARNING: '{name}' is deprecated from version {version} and will be removed in a future version, "
f"use '{substitute}' instead.") f"use '{substitute}' instead.")
if cls and use_substitute_name:
cls.substitute_name = substitute
ret = func(*args, **kwargs) ret = func(*args, **kwargs)
return ret return ret

@ -33,6 +33,7 @@ from .. import signature as sig
from ..._checkparam import Rel from ..._checkparam import Rel
from ..._checkparam import Validator as validator from ..._checkparam import Validator as validator
from ...common import dtype as mstype from ...common import dtype as mstype
from ...common._decorator import deprecated
from ...common.parameter import Parameter from ...common.parameter import Parameter
from ...common.tensor import Tensor from ...common.tensor import Tensor
@ -820,10 +821,29 @@ class Gather(PrimitiveWithCheck):
validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name) validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name)
def GatherV2(): class GatherV2(PrimitiveWithCheck):
"""Warning: This will be changed later""" """
logger.warning("WARN_DEPRECATED: The usage of GatherV2 is deprecated. Please use Gather.") Same as operator Gather. GatherV2 will be deprecated in the future.
return Gather() Please use Gather instead.
"""
#deprecate_new_name = "Gather"
@deprecated("1.1", "Gather", True)
@prim_attr_register
def __init__(self):
"""Initialize index_select"""
self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output'])
self.add_prim_attr("dynamic_shape_depends", [2])
def __check__(self, params, indices, axis):
validator.check_subclass("params", params['dtype'], mstype.tensor, self.name)
validator.check_tensor_dtype_valid("indices", indices['dtype'], mstype.int_type, self.name)
validator.check_subclass("axis", axis['dtype'], [mstype.number], self.name)
axis_v = axis['value']
validator.check_value_type('axis', axis_v, [int], self.name)
rank = len(params['shape'])
validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, "axis", self.name)
class SparseGatherV2(Gather): class SparseGatherV2(Gather):
""" """

@ -18,13 +18,13 @@
import copy import copy
import numpy as np import numpy as np
from mindspore import log as logger
from ... import context from ... import context
from .. import signature as sig from .. import signature as sig
from ..._checkparam import Validator as validator from ..._checkparam import Validator as validator
from ..._checkparam import Rel from ..._checkparam import Rel
from ...common import dtype as mstype from ...common import dtype as mstype
from ...common.tensor import Tensor from ...common.tensor import Tensor
from ...common._decorator import deprecated
from .._utils import get_broadcast_shape from .._utils import get_broadcast_shape
from ..primitive import PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op from ..primitive import PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op
@ -161,10 +161,28 @@ class Add(_MathBinaryOp):
return Tensor(out) return Tensor(out)
return None return None
def TensorAdd():
"""Warning: This will be changed later""" class TensorAdd(_MathBinaryOp):
logger.warning("WARN_DEPRECATED: The usage of TensorAdd is deprecated. Please use Add.") """
return Add() Same as operator Add. TensorAdd will be deprecated in the future.
Please use Add instead.
"""
#deprecate_new_name = "Add"
@deprecated("1.1", "Add", True)
@prim_attr_register
def __init__(self):
_MathBinaryOp.__init__(self)
def infer_value(self, x, y):
if x is not None and y is not None:
x = x.asnumpy()
y = y.asnumpy()
out = x + y
out = np.array(out, x.dtype)
return Tensor(out)
return None
class AssignAdd(PrimitiveWithInfer): class AssignAdd(PrimitiveWithInfer):
""" """

@ -466,10 +466,13 @@ def prim_attr_register(fn):
""" """
def deco(self, *args, **kwargs): def deco(self, *args, **kwargs):
class_name = self.__class__.__name__
if hasattr(self.__class__, "substitute_name"):
class_name = self.__class__.substitute_name
if isinstance(self, PrimitiveWithInfer): if isinstance(self, PrimitiveWithInfer):
PrimitiveWithInfer.__init__(self, self.__class__.__name__) PrimitiveWithInfer.__init__(self, class_name)
elif isinstance(self, PrimitiveWithCheck): elif isinstance(self, PrimitiveWithCheck):
PrimitiveWithCheck.__init__(self, self.__class__.__name__) PrimitiveWithCheck.__init__(self, class_name)
else: else:
Primitive.__init__(self, self.__class__.__name__) Primitive.__init__(self, self.__class__.__name__)
bound_args = inspect.signature(fn).bind(self, *args, **kwargs) bound_args = inspect.signature(fn).bind(self, *args, **kwargs)

Loading…
Cancel
Save