Follow comments and refine the python wrapper of reshape_op

helinwang-patch-1
guosheng 7 years ago
parent 454b0a96be
commit d4bb2ca71f

@ -3361,7 +3361,9 @@ def reshape(x, shape, act=None, inplace=True, name=None):
Examples:
.. code-block:: python
data = fluid.layers.data(name='data', shape=[2, 4, 6], dtype='float32')
data = fluid.layers.data(
name='data', shape=[2, 4, 6], dtype='float32'
)
reshaped = fluid.layers.reshape(
x=data, shape=[-1, 0, 3, 2], act='tanh', inplace=True
)
@ -3371,6 +3373,21 @@ def reshape(x, shape, act=None, inplace=True, name=None):
if not (isinstance(shape, list) or isinstance(shape, tuple)):
raise ValueError("Input shape must be a python lsit or tuple.")
# Validate the shape
unk_dim_idx = -1
for dim_idx, dim_size in enumerate(shape):
if dim_size == -1:
assert unk_dim_idx == -1, (
"Only one dimension in shape can be unknown.")
unk_dim_idx = dim_idx
elif dim_size == 0:
assert dim_idx < len(x.shape), (
"The indice of 0s in shape can not exceed Rank(X).")
else:
assert dim_size > 0, (
"Each dimension size given in shape must not be negtive "
"except one unknown dimension.")
helper = LayerHelper("reshape", **locals())
reshaped = helper.create_tmp_variable(dtype=x.dtype)
helper.append_op(

Loading…
Cancel
Save