|
|
@ -4841,7 +4841,7 @@ def split(input, num_or_sections, dim=-1, name=None):
|
|
|
|
|
|
|
|
|
|
|
|
if isinstance(dim, Variable):
|
|
|
|
if isinstance(dim, Variable):
|
|
|
|
dim = dim.numpy()
|
|
|
|
dim = dim.numpy()
|
|
|
|
dim = dim[0]
|
|
|
|
dim = dim.item(0)
|
|
|
|
dim = (len(input.shape) + dim) if dim < 0 else dim
|
|
|
|
dim = (len(input.shape) + dim) if dim < 0 else dim
|
|
|
|
attrs += ('axis', dim)
|
|
|
|
attrs += ('axis', dim)
|
|
|
|
|
|
|
|
|
|
|
@ -5885,7 +5885,7 @@ def one_hot(input, depth, allow_out_of_range=False):
|
|
|
|
depth = depth.numpy()
|
|
|
|
depth = depth.numpy()
|
|
|
|
assert depth.shape == (
|
|
|
|
assert depth.shape == (
|
|
|
|
1, ), "depth of type Variable should have shape [1]"
|
|
|
|
1, ), "depth of type Variable should have shape [1]"
|
|
|
|
depth = depth[0]
|
|
|
|
depth = depth.item(0)
|
|
|
|
out = core.ops.one_hot(input, 'depth', depth, 'allow_out_of_range',
|
|
|
|
out = core.ops.one_hot(input, 'depth', depth, 'allow_out_of_range',
|
|
|
|
allow_out_of_range)
|
|
|
|
allow_out_of_range)
|
|
|
|
out.stop_gradient = True
|
|
|
|
out.stop_gradient = True
|
|
|
@ -6067,7 +6067,7 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None):
|
|
|
|
)
|
|
|
|
)
|
|
|
|
if isinstance(shape, (list, tuple)):
|
|
|
|
if isinstance(shape, (list, tuple)):
|
|
|
|
shape = [
|
|
|
|
shape = [
|
|
|
|
item.numpy()[0] if isinstance(item, Variable) else item
|
|
|
|
item.numpy().item(0) if isinstance(item, Variable) else item
|
|
|
|
for item in shape
|
|
|
|
for item in shape
|
|
|
|
]
|
|
|
|
]
|
|
|
|
out, _ = core.ops.reshape2(x, 'shape', shape)
|
|
|
|
out, _ = core.ops.reshape2(x, 'shape', shape)
|
|
|
@ -10195,7 +10195,7 @@ def expand(x, expand_times, name=None):
|
|
|
|
if in_dygraph_mode():
|
|
|
|
if in_dygraph_mode():
|
|
|
|
if isinstance(expand_times, (list, tuple)):
|
|
|
|
if isinstance(expand_times, (list, tuple)):
|
|
|
|
expand_times = [
|
|
|
|
expand_times = [
|
|
|
|
item.numpy()[0] if isinstance(item, Variable) else item
|
|
|
|
item.numpy().item(0) if isinstance(item, Variable) else item
|
|
|
|
for item in expand_times
|
|
|
|
for item in expand_times
|
|
|
|
]
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
@ -10806,11 +10806,11 @@ def slice(input, axes, starts, ends):
|
|
|
|
if isinstance(starts, (list, tuple)) and isinstance(ends,
|
|
|
|
if isinstance(starts, (list, tuple)) and isinstance(ends,
|
|
|
|
(list, tuple)):
|
|
|
|
(list, tuple)):
|
|
|
|
starts = [
|
|
|
|
starts = [
|
|
|
|
item.numpy()[0] if isinstance(item, Variable) else item
|
|
|
|
item.numpy().item(0) if isinstance(item, Variable) else item
|
|
|
|
for item in starts
|
|
|
|
for item in starts
|
|
|
|
]
|
|
|
|
]
|
|
|
|
ends = [
|
|
|
|
ends = [
|
|
|
|
item.numpy()[0] if isinstance(item, Variable) else item
|
|
|
|
item.numpy().item(0) if isinstance(item, Variable) else item
|
|
|
|
for item in ends
|
|
|
|
for item in ends
|
|
|
|
]
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|