|
|
|
@ -13,10 +13,8 @@
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
# ============================================================================
|
|
|
|
|
"""math"""
|
|
|
|
|
import math
|
|
|
|
|
import numpy as np
|
|
|
|
|
from mindspore.ops import operations as P
|
|
|
|
|
from mindspore.ops.operations import _inner_ops as inner
|
|
|
|
|
from mindspore.common.tensor import Tensor
|
|
|
|
|
from mindspore.common._decorator import deprecated
|
|
|
|
|
from mindspore.ops.primitive import constexpr
|
|
|
|
@ -25,7 +23,6 @@ from ..cell import Cell
|
|
|
|
|
from ...common import dtype as mstype
|
|
|
|
|
from ..._checkparam import Validator as validator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
__all__ = ['ReduceLogSumExp',
|
|
|
|
|
'Range',
|
|
|
|
|
'LGamma',
|
|
|
|
@ -140,37 +137,15 @@ class Range(Cell):
|
|
|
|
|
|
|
|
|
|
def __init__(self, start, limit=None, delta=1):
|
|
|
|
|
super(Range, self).__init__()
|
|
|
|
|
validator.check_value_type("start", start, [int, float], self.cls_name)
|
|
|
|
|
validator.check_value_type("delta", delta, [int, float], self.cls_name)
|
|
|
|
|
if delta == 0:
|
|
|
|
|
raise ValueError("The input of `delta` can not be equal to zero.")
|
|
|
|
|
if limit is not None:
|
|
|
|
|
validator.check_value_type("limit", limit, [int, float], self.cls_name)
|
|
|
|
|
if isinstance(start, int) and isinstance(limit, int) and isinstance(delta, int):
|
|
|
|
|
self.dtype = mstype.int32
|
|
|
|
|
else:
|
|
|
|
|
self.dtype = mstype.float32
|
|
|
|
|
else:
|
|
|
|
|
if isinstance(start, int) and isinstance(delta, int):
|
|
|
|
|
self.dtype = mstype.int32
|
|
|
|
|
else:
|
|
|
|
|
self.dtype = mstype.float32
|
|
|
|
|
if isinstance(start, int):
|
|
|
|
|
start = float(start)
|
|
|
|
|
if isinstance(limit, int):
|
|
|
|
|
limit = float(limit)
|
|
|
|
|
if isinstance(delta, int):
|
|
|
|
|
delta = float(delta)
|
|
|
|
|
self.range_x = inner.Range(start, limit, delta)
|
|
|
|
|
if limit is None:
|
|
|
|
|
length_input = math.ceil(start / delta)
|
|
|
|
|
data = np.arange(start, limit, delta)
|
|
|
|
|
if data.dtype == np.float:
|
|
|
|
|
self.ms_dtype = mstype.float32
|
|
|
|
|
else:
|
|
|
|
|
length_input = math.ceil((limit - start) / delta)
|
|
|
|
|
self.input_tensor = Tensor(list(range(length_input)), self.dtype)
|
|
|
|
|
self.ms_dtype = mstype.int32
|
|
|
|
|
self.result_tensor = Tensor(data, dtype=self.ms_dtype)
|
|
|
|
|
|
|
|
|
|
def construct(self):
|
|
|
|
|
range_out = self.range_x(self.input_tensor)
|
|
|
|
|
return range_out
|
|
|
|
|
return self.result_tensor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LGamma(Cell):
|
|
|
|
|