add reshape in paddle.complex (#24176)
* add reshape in paddle.complex, test=develop * fix typos in paddle.complex.kron's comment, fix unittest, test=developrevert-24314-dev/fix_err_msg
parent
e72832adb4
commit
e146e79e66
@ -0,0 +1,106 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from paddle.common_ops_import import *
|
||||
from ..helper import is_complex, is_real, complex_variable_exists
|
||||
from ...fluid.framework import ComplexVariable
|
||||
from ...fluid import layers
|
||||
|
||||
__all__ = ['reshape', ]
|
||||
|
||||
|
||||
def reshape(x, shape, inplace=False, name=None):
|
||||
"""
|
||||
To change the shape of ``x`` without changing its data.
|
||||
|
||||
There are some tricks when specifying the target shape.
|
||||
|
||||
1. -1 means the value of this dimension is inferred from the total element
|
||||
number of x and remaining dimensions. Thus one and only one dimension can
|
||||
be set -1.
|
||||
|
||||
2. 0 means the actual dimension value is going to be copied from the
|
||||
corresponding dimension of x. The index of 0s in shape can not exceed
|
||||
the dimension of x.
|
||||
|
||||
Here are some examples to explain it.
|
||||
|
||||
1. Given a 3-D tensor x with a shape [2, 4, 6], and the target shape
|
||||
is [6, 8], the reshape operator will transform x into a 2-D tensor with
|
||||
shape [6, 8] and leaving x's data unchanged.
|
||||
|
||||
2. Given a 3-D tensor x with a shape [2, 4, 6], and the target shape
|
||||
specified is [2, 3, -1, 2], the reshape operator will transform x into a
|
||||
4-D tensor with shape [2, 3, 4, 2] and leaving x's data unchanged. In this
|
||||
case, one dimension of the target shape is set to -1, the value of this
|
||||
dimension is inferred from the total element number of x and remaining
|
||||
dimensions.
|
||||
|
||||
3. Given a 3-D tensor x with a shape [2, 4, 6], and the target shape
|
||||
is [-1, 0, 3, 2], the reshape operator will transform x into a 4-D tensor
|
||||
with shape [2, 4, 3, 2] and leaving x's data unchanged. In this case,
|
||||
besides -1, 0 means the actual dimension value is going to be copied from
|
||||
the corresponding dimension of x.
|
||||
|
||||
Args:
|
||||
x(ComplexVariable): the input. A ``Tensor`` or ``LoDTensor`` , data
|
||||
type: ``complex64`` or ``complex128``.
|
||||
shape(list|tuple|Variable): target shape. At most one dimension of
|
||||
the target shape can be -1. If ``shape`` is a list or tuple, the
|
||||
elements of it should be integers or Tensors with shape [1] and
|
||||
data type ``int32``. If ``shape`` is an Variable, it should be
|
||||
an 1-D Tensor of data type ``int32``.
|
||||
inplace(bool, optional): If ``inplace`` is True, the output of
|
||||
``reshape`` is the same ComplexVariable as the input. Otherwise,
|
||||
the input and output of ``reshape`` are different
|
||||
ComplexVariables. Defaults to False. Note that if ``x``is more
|
||||
than one OPs' input, ``inplace`` must be False.
|
||||
name(str, optional): The default value is None. Normally there is no
|
||||
need for user to set this property. For more information, please
|
||||
refer to :ref:`api_guide_Name` .
|
||||
|
||||
Returns:
|
||||
Variable: A ``Tensor`` or ``LoDTensor``. The data type is same as ``x``. It is a new ComplexVariable if ``inplace`` is ``False``, otherwise it is ``x``.
|
||||
|
||||
Raises:
|
||||
ValueError: If more than one elements of ``shape`` is -1.
|
||||
ValueError: If the element of ``shape`` is 0, the corresponding dimension should be less than or equal to the dimension of ``x``.
|
||||
ValueError: If the elements in ``shape`` is negative except -1.
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
|
||||
import paddle.fluid as fluid
|
||||
import paddle.complex as cpx
|
||||
import paddle.fluid.dygraph as dg
|
||||
import numpy as np
|
||||
|
||||
x_np = np.random.randn(2, 3, 4) + 1j * np.random.randn(2, 3, 4)
|
||||
|
||||
place = fluid.CPUPlace()
|
||||
with dg.guard(place):
|
||||
x_var = dg.to_variable(x_np)
|
||||
y_var = cpx.reshape(x_var, (2, -1))
|
||||
y_np = y_var.numpy()
|
||||
print(y_np.shape)
|
||||
# (2, 12)
|
||||
"""
|
||||
complex_variable_exists([x], "reshape")
|
||||
if inplace:
|
||||
x.real = fluid.layers.reshape(x.real, shape, inplace=inplace, name=name)
|
||||
x.imag = fluid.layers.reshape(x.imag, shape, inplace=inplace, name=name)
|
||||
return x
|
||||
out_real = fluid.layers.reshape(x.real, shape, inplace=inplace, name=name)
|
||||
out_imag = fluid.layers.reshape(x.imag, shape, inplace=inplace, name=name)
|
||||
return ComplexVariable(out_real, out_imag)
|
@ -0,0 +1,51 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import paddle.fluid as fluid
|
||||
import paddle.complex as cpx
|
||||
import paddle.fluid.dygraph as dg
|
||||
import numpy as np
|
||||
import unittest
|
||||
|
||||
|
||||
class TestComplexReshape(unittest.TestCase):
|
||||
def test_case1(self):
|
||||
x_np = np.random.randn(2, 3, 4) + 1j * np.random.randn(2, 3, 4)
|
||||
shape = (2, -1)
|
||||
|
||||
place = fluid.CPUPlace()
|
||||
with dg.guard(place):
|
||||
x_var = dg.to_variable(x_np)
|
||||
y_var = cpx.reshape(x_var, shape)
|
||||
y_np = y_var.numpy()
|
||||
|
||||
np.testing.assert_allclose(np.reshape(x_np, shape), y_np)
|
||||
|
||||
def test_case2(self):
|
||||
x_np = np.random.randn(2, 3, 4) + 1j * np.random.randn(2, 3, 4)
|
||||
shape = (0, -1)
|
||||
shape_ = (2, 12)
|
||||
|
||||
place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
|
||||
) else fluid.CPUPlace()
|
||||
with dg.guard(place):
|
||||
x_var = dg.to_variable(x_np)
|
||||
y_var = cpx.reshape(x_var, shape, inplace=True)
|
||||
y_np = y_var.numpy()
|
||||
|
||||
np.testing.assert_allclose(np.reshape(x_np, shape_), y_np)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in new issue