add + - * / @ [] operator to ComplexVariable (#28217)
* add + - * / @ [] operator to ComplexVariable, also add unittest * fix circular reference bug * fit for py2.7 * remove reverse oprators which not supported nowrevert-28284-dev/pybind_version
parent
a98c69b6c6
commit
6cebd71454
@ -0,0 +1,96 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
import paddle
|
||||
import numpy as np
|
||||
import paddle.fluid as fluid
|
||||
import paddle.fluid.dygraph as dg
|
||||
|
||||
|
||||
class TestComplexGetitemLayer(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self._places = [fluid.CPUPlace()]
|
||||
if fluid.core.is_compiled_with_cuda():
|
||||
self._places.append(fluid.CUDAPlace(0))
|
||||
|
||||
def test_case1(self):
|
||||
x_np = np.random.randn(2, 3, 4) + 1j * np.random.randn(2, 3, 4)
|
||||
x_np_slice = x_np[0]
|
||||
|
||||
for place in self._places:
|
||||
with dg.guard(place):
|
||||
x_var = dg.to_variable(x_np)
|
||||
x_var_slice = x_var[0]
|
||||
|
||||
np.testing.assert_allclose(x_var_slice.numpy(), x_np_slice)
|
||||
|
||||
def test_case2(self):
|
||||
x_np = np.random.randn(2, 3, 4) + 1j * np.random.randn(2, 3, 4)
|
||||
x_np_slice = x_np[0][1]
|
||||
|
||||
for place in self._places:
|
||||
with dg.guard(place):
|
||||
x_var = dg.to_variable(x_np)
|
||||
x_var_slice = x_var[0][1]
|
||||
|
||||
np.testing.assert_allclose(x_var_slice.numpy(), x_np_slice)
|
||||
|
||||
def test_case3(self):
|
||||
x_np = np.random.randn(2, 3, 4) + 1j * np.random.randn(2, 3, 4)
|
||||
x_np_slice = x_np[0][1][2]
|
||||
|
||||
for place in self._places:
|
||||
with dg.guard(place):
|
||||
x_var = dg.to_variable(x_np)
|
||||
x_var_slice = x_var[0][1][2]
|
||||
|
||||
np.testing.assert_allclose(x_var_slice.numpy(), x_np_slice)
|
||||
|
||||
def test_case4(self):
|
||||
x_np = np.random.randn(2, 3, 4) + 1j * np.random.randn(2, 3, 4)
|
||||
x_np_slice = x_np[0][1][0:3]
|
||||
|
||||
for place in self._places:
|
||||
with dg.guard(place):
|
||||
x_var = dg.to_variable(x_np)
|
||||
x_var_slice = x_var[0][1][0:3]
|
||||
|
||||
np.testing.assert_allclose(x_var_slice.numpy(), x_np_slice)
|
||||
|
||||
def test_case5(self):
|
||||
x_np = np.random.randn(2, 3, 4) + 1j * np.random.randn(2, 3, 4)
|
||||
x_np_slice = x_np[0][1][0:4:2]
|
||||
|
||||
for place in self._places:
|
||||
with dg.guard(place):
|
||||
x_var = dg.to_variable(x_np)
|
||||
x_var_slice = x_var[0][1][0:4:2]
|
||||
|
||||
np.testing.assert_allclose(x_var_slice.numpy(), x_np_slice)
|
||||
|
||||
def test_case6(self):
|
||||
x_np = np.random.randn(2, 3, 4) + 1j * np.random.randn(2, 3, 4)
|
||||
x_np_slice = x_np[0][1:3][0:4:2]
|
||||
|
||||
for place in self._places:
|
||||
with dg.guard(place):
|
||||
x_var = dg.to_variable(x_np)
|
||||
x_var_slice = x_var[0][1:3][0:4:2]
|
||||
|
||||
np.testing.assert_allclose(x_var_slice.numpy(), x_np_slice)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
@ -0,0 +1,53 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import division
|
||||
from ...fluid import framework
|
||||
from . import tensor
|
||||
|
||||
|
||||
def monkey_patch_math_complex():
|
||||
# complexVariable do not support scaler type now, so here not contains
|
||||
# reverse methods, such as "__radd__", "__rsub__", "__rmul__", "__rdiv__",
|
||||
# "__rtruediv__", "__rmatmul__".
|
||||
complex_methods = [
|
||||
('__add__', _binary_creator_('__add__', "elementwise_add", False)),
|
||||
('__sub__', _binary_creator_('__sub__', "elementwise_sub", False)),
|
||||
('__mul__', _binary_creator_('__mul__', "elementwise_mul", False)),
|
||||
('__div__', _binary_creator_('__div__', "elementwise_div", False)),
|
||||
('__truediv__', _binary_creator_('__truediv__', "elementwise_div",
|
||||
False)),
|
||||
('__matmul__', _binary_creator_('__matmul__', "matmul", False)),
|
||||
]
|
||||
|
||||
for method in complex_methods:
|
||||
method_name = method[0]
|
||||
method_impl = method[1]
|
||||
if method_impl:
|
||||
setattr(framework.ComplexVariable, method_name, method_impl)
|
||||
|
||||
for method in tensor.__all__:
|
||||
method_impl = getattr(tensor, method)
|
||||
if method_impl:
|
||||
setattr(framework.ComplexVariable, method, method_impl)
|
||||
|
||||
|
||||
# for binary operator such as elementwise
|
||||
def _binary_creator_(method_name, op_type, reverse=False):
|
||||
def __impl__(self, other_var):
|
||||
math_op = getattr(tensor, op_type)
|
||||
return math_op(self, other_var)
|
||||
|
||||
__impl__.__name__ = method_name
|
||||
return __impl__
|
Loading…
Reference in new issue