|
|
|
@ -6148,8 +6148,12 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None):
|
|
|
|
|
item.numpy().item(0) if isinstance(item, Variable) else item
|
|
|
|
|
for item in shape
|
|
|
|
|
]
|
|
|
|
|
out, _ = core.ops.reshape2(x, 'shape', shape)
|
|
|
|
|
return dygraph_utils._append_activation_in_dygraph(out, act)
|
|
|
|
|
out, _ = core.ops.reshape2(x, None, 'shape', shape)
|
|
|
|
|
elif isinstance(shape, Variable):
|
|
|
|
|
shape.stop_gradient = True
|
|
|
|
|
out, _ = core.ops.reshape2(x, shape)
|
|
|
|
|
|
|
|
|
|
return dygraph_utils._append_activation_in_dygraph(out, act)
|
|
|
|
|
|
|
|
|
|
check_variable_and_dtype(
|
|
|
|
|
x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64',
|
|
|
|
@ -10315,13 +10319,19 @@ def expand(x, expand_times, name=None):
|
|
|
|
|
# the shape of expanded_2 is [48, 56].
|
|
|
|
|
"""
|
|
|
|
|
if in_dygraph_mode():
|
|
|
|
|
attrs = ()
|
|
|
|
|
expand_times_tensor = None
|
|
|
|
|
if isinstance(expand_times, (list, tuple)):
|
|
|
|
|
expand_times = [
|
|
|
|
|
item.numpy().item(0) if isinstance(item, Variable) else item
|
|
|
|
|
for item in expand_times
|
|
|
|
|
]
|
|
|
|
|
attrs += ('expand_times', expand_times)
|
|
|
|
|
elif isinstance(expand_times, Variable):
|
|
|
|
|
expand_times_tensor = expand_times
|
|
|
|
|
expand_times_tensor.stop_gradient = True
|
|
|
|
|
|
|
|
|
|
return core.ops.expand(x, 'expand_times', expand_times)
|
|
|
|
|
return core.ops.expand(x, expand_times_tensor, *attrs)
|
|
|
|
|
|
|
|
|
|
inputs = {"X": [x]}
|
|
|
|
|
attrs = {}
|
|
|
|
@ -10925,20 +10935,35 @@ def slice(input, axes, starts, ends):
|
|
|
|
|
# sliced_2 is input[0:3, 0:2, 2:4].
|
|
|
|
|
"""
|
|
|
|
|
if in_dygraph_mode():
|
|
|
|
|
attrs = ()
|
|
|
|
|
starts_tensor = None
|
|
|
|
|
ends_tensor = None
|
|
|
|
|
infer_flags = list(1 for i in range(len(axes)))
|
|
|
|
|
if isinstance(starts, (list, tuple)) and isinstance(ends,
|
|
|
|
|
(list, tuple)):
|
|
|
|
|
|
|
|
|
|
if isinstance(starts, (list, tuple)):
|
|
|
|
|
starts = [
|
|
|
|
|
item.numpy().item(0) if isinstance(item, Variable) else item
|
|
|
|
|
for item in starts
|
|
|
|
|
]
|
|
|
|
|
attrs += ('starts', starts)
|
|
|
|
|
elif isinstance(starts, Variable):
|
|
|
|
|
starts_tensor = starts
|
|
|
|
|
starts.stop_gradient = True
|
|
|
|
|
infer_flags = list(-1 for i in range(len(axes)))
|
|
|
|
|
|
|
|
|
|
if isinstance(ends, (list, tuple)):
|
|
|
|
|
ends = [
|
|
|
|
|
item.numpy().item(0) if isinstance(item, Variable) else item
|
|
|
|
|
for item in ends
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
return core.ops.slice(input, 'axes', axes, 'starts', starts, 'ends',
|
|
|
|
|
ends, 'infer_flags', infer_flags)
|
|
|
|
|
attrs += ('ends', ends)
|
|
|
|
|
elif isinstance(ends, Variable):
|
|
|
|
|
ends_tensor = ends
|
|
|
|
|
ends_tensor.stop_gradient = True
|
|
|
|
|
infer_flags = list(-1 for i in range(len(axes)))
|
|
|
|
|
|
|
|
|
|
return core.ops.slice(input, starts_tensor, ends_tensor, 'axes', axes,
|
|
|
|
|
'infer_flags', infer_flags, *attrs)
|
|
|
|
|
|
|
|
|
|
if not isinstance(starts, (list, tuple, Variable)):
|
|
|
|
|
raise ValueError(
|
|
|
|
|