Make layers.reshape/expand/slice in dygraph support var inputs. (#22920)

* Make layers.reshape/expand/slice in dygraph support var inputs.
Make transpose support size 0.
test=develop

* Update layers.expand and layers.slice to support var inputs.
test=develop
revert-22778-infer_var_type
Guo Sheng 5 years ago committed by GitHub
parent a3b02e44dd
commit da803415cf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -5948,20 +5948,13 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None):
warnings.warn( warnings.warn(
"Inplace on reshape is not allowed and will be discarded in dygraph mode currently." "Inplace on reshape is not allowed and will be discarded in dygraph mode currently."
) )
attrs = {}
if isinstance(shape, (list, tuple)): if isinstance(shape, (list, tuple)):
if utils._contain_var(shape): shape = [
raise TypeError( item.numpy()[0] if isinstance(item, Variable) else item
"The type of 'shape' in reshape must be list[int] or tuple(int) in Dygraph mode, but " for item in shape
"received %s, which contains Variable." % type(shape)) ]
attrs['shape'] = shape out, _ = core.ops.reshape2(x, 'shape', shape)
else: return dygraph_utils._append_activation_in_dygraph(out, act)
raise TypeError(
"The type of 'shape' in reshape must be list[int] or tuple(int) in Dygraph mode, but "
"received %s." % type(shape))
out, _ = core.ops.reshape2(x, 'shape', shape)
return dygraph_utils._append_activation_in_dygraph(out, act)
check_variable_and_dtype( check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'], 'reshape') x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'], 'reshape')
@ -9770,16 +9763,12 @@ def expand(x, expand_times, name=None):
""" """
if in_dygraph_mode(): if in_dygraph_mode():
if isinstance(expand_times, (list, tuple)): if isinstance(expand_times, (list, tuple)):
if utils._contain_var(expand_times): expand_times = [
raise TypeError( item.numpy()[0] if isinstance(item, Variable) else item
"The type of 'expand_times' in expand must be list[int] or tuple(int) in Dygraph mode, but " for item in expand_times
"received %s, which contains Variable." % type(shape)) ]
else:
raise TypeError(
"The type of 'expand_times' in expand must be list[int] or tuple(int) in Dygraph mode, but "
"received %s." % type(shape))
return core.ops.expand(x, 'expand_times', expand_times) return core.ops.expand(x, 'expand_times', expand_times)
inputs = {"X": [x]} inputs = {"X": [x]}
attrs = {} attrs = {}
@ -10318,28 +10307,19 @@ def slice(input, axes, starts, ends):
""" """
if in_dygraph_mode(): if in_dygraph_mode():
infer_flags = list(1 for i in range(len(axes))) infer_flags = list(1 for i in range(len(axes)))
if isinstance(starts, (list, tuple)): if isinstance(starts, (list, tuple)) and isinstance(ends,
if utils._contain_var(starts): (list, tuple)):
raise TypeError( starts = [
"The type of 'starts' in slice must be list[int] or tuple(int) in Dygraph mode, but " item.numpy()[0] if isinstance(item, Variable) else item
"received %s, which contains Variable." % type(shape)) for item in starts
else: ]
raise TypeError( ends = [
"The type of 'starts' in slice must be list[int] or tuple(int) in Dygraph mode, but " item.numpy()[0] if isinstance(item, Variable) else item
"received %s." % type(shape)) for item in ends
]
if isinstance(ends, (list, tuple)):
if utils._contain_var(ends):
raise TypeError(
"The type of 'ends' in slice must be list[int] or tuple(int) in Dygraph mode, but "
"received %s, which contains Variable." % type(shape))
else:
raise TypeError(
"The type of 'ends' in slice must be list[int] or tuple(int) in Dygraph mode, but "
"received %s." % type(shape))
return core.ops.slice(input, 'axes', axes, 'starts', starts, 'ends', return core.ops.slice(input, 'axes', axes, 'starts', starts, 'ends',
ends, 'infer_flags', infer_flags) ends, 'infer_flags', infer_flags)
if not isinstance(starts, (list, tuple, Variable)): if not isinstance(starts, (list, tuple, Variable)):
raise ValueError( raise ValueError(

@ -725,7 +725,7 @@ class BeamSearchDecoder(Decoder):
data type is same as `x`. data type is same as `x`.
""" """
# TODO: avoid fake shape in compile-time like tile_beam_merge_with_batch # TODO: avoid fake shape in compile-time like tile_beam_merge_with_batch
return nn.reshape(x, shape=(-1, self.beam_size) + x.shape[1:]) return nn.reshape(x, shape=[-1, self.beam_size] + list(x.shape[1:]))
def _merge_batch_beams(self, x): def _merge_batch_beams(self, x):
""" """
@ -741,7 +741,7 @@ class BeamSearchDecoder(Decoder):
data type is same as `x`. data type is same as `x`.
""" """
# TODO: avoid fake shape in compile-time like tile_beam_merge_with_batch # TODO: avoid fake shape in compile-time like tile_beam_merge_with_batch
return nn.reshape(x, shape=(-1, ) + x.shape[2:]) return nn.reshape(x, shape=[-1] + list(x.shape[2:]))
def _expand_to_beam_size(self, x): def _expand_to_beam_size(self, x):
""" """

Loading…
Cancel
Save