!11153 Update nn.Range for GPU backend.

From: @liu_xiao_93
Reviewed-by: @liangchenghui,@wuxuejian
Signed-off-by: @liangchenghui
pull/11153/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit fac93540a0

@ -15,6 +15,7 @@
"""math"""
import math
import numpy as np
import mindspore.context as context
from mindspore.ops import operations as P
from mindspore.ops.operations import _inner_ops as inner
from mindspore.common.tensor import Tensor
@ -116,7 +117,7 @@ class Range(Cell):
Tensor, the dtype is int if the dtype of `start`, `limit` and `delta` all are int. Otherwise, dtype is float.
Supported Platforms:
``Ascend`` ``CPU``
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> net = nn.Range(1, 8, 2)
@ -127,6 +128,7 @@ class Range(Cell):
def __init__(self, start, limit=None, delta=1):
super(Range, self).__init__()
self.is_gpu = context.get_context("device_target") == "GPU"
validator.check_value_type("start", start, [int, float], self.cls_name)
validator.check_value_type("delta", delta, [int, float], self.cls_name)
if delta == 0:
@ -155,8 +157,17 @@ class Range(Cell):
length_input = math.ceil((limit - start) / delta)
self.input_tensor = Tensor(list(range(length_input)), self.dtype)
if self.is_gpu:
self.start = Tensor(start, self.dtype)
self.limit = Tensor(limit, self.dtype)
self.delta = Tensor(delta, self.dtype)
self.range_gpu = P.Range(length_input)
def construct(self):
range_out = self.range_x(self.input_tensor)
if self.is_gpu:
range_out = self.range_gpu(self.start, self.limit, self.delta)
else:
range_out = self.range_x(self.input_tensor)
return range_out

@ -4761,7 +4761,7 @@ class Range(PrimitiveWithCheck):
[0, 4, 8]
Supported Platforms:
``Ascend`` ``GPU``
``GPU``
"""
@prim_attr_register

Loading…
Cancel
Save