|
|
|
@ -18,7 +18,7 @@ from six.moves import reduce
|
|
|
|
|
from ..layer_helper import LayerHelper
|
|
|
|
|
from ..param_attr import ParamAttr
|
|
|
|
|
from ..initializer import Initializer
|
|
|
|
|
from ..framework import convert_np_dtype_to_dtype_, in_dygraph_mode, _varbase_creator
|
|
|
|
|
from ..framework import convert_np_dtype_to_dtype_, in_dygraph_mode, _varbase_creator, device_guard
|
|
|
|
|
from ..framework import Variable
|
|
|
|
|
from ..initializer import Constant
|
|
|
|
|
from ..core import VarDesc
|
|
|
|
@ -1394,17 +1394,20 @@ def range(start, end, step, dtype, name=None):
|
|
|
|
|
dtype = convert_np_dtype_to_dtype_(dtype)
|
|
|
|
|
|
|
|
|
|
if not isinstance(start, Variable):
|
|
|
|
|
start = fill_constant([1], dtype, start)
|
|
|
|
|
with device_guard("cpu"):
|
|
|
|
|
start = fill_constant([1], dtype, start)
|
|
|
|
|
elif start.dtype != dtype:
|
|
|
|
|
start = cast(start, dtype)
|
|
|
|
|
|
|
|
|
|
if not isinstance(end, Variable):
|
|
|
|
|
end = fill_constant([1], dtype, end)
|
|
|
|
|
with device_guard("cpu"):
|
|
|
|
|
end = fill_constant([1], dtype, end)
|
|
|
|
|
elif end.dtype != dtype:
|
|
|
|
|
end = cast(end, dtype)
|
|
|
|
|
|
|
|
|
|
if not isinstance(step, Variable):
|
|
|
|
|
step = fill_constant([1], dtype, step)
|
|
|
|
|
with device_guard("cpu"):
|
|
|
|
|
step = fill_constant([1], dtype, step)
|
|
|
|
|
elif step.dtype != dtype:
|
|
|
|
|
step = cast(step, dtype)
|
|
|
|
|
|
|
|
|
|