update cross op parameters for API 2.0

* update cross op parameters
fix_copy_if_different
LutaoChu 5 years ago committed by GitHub
parent 1ab60544f2
commit bbe8f7bdcd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -79,7 +79,7 @@ class TestCrossAPI(unittest.TestCase):
with program_guard(Program(), Program()):
x = fluid.layers.data(name='x', shape=[-1, 3])
y = fluid.layers.data(name='y', shape=[-1, 3])
z = paddle.cross(x, y, dim=1)
z = paddle.cross(x, y, axis=1)
exe = fluid.Executor(fluid.CPUPlace())
res, = exe.run(feed={'x': self.data_x,
'y': self.data_y},
@ -103,6 +103,14 @@ class TestCrossAPI(unittest.TestCase):
[-1.0, -1.0, -1.0]])
self.assertTrue(np.allclose(expect_out, np.array(res)))
# case 3:
with program_guard(Program(), Program()):
x = fluid.data(name="x", shape=[-1, 3], dtype="float32")
y = fluid.data(name='y', shape=[-1, 3], dtype='float32')
y_1 = paddle.cross(x, y, name='result')
self.assertEqual(('result' in y_1.name), True)
def test_dygraph_api(self):
self.input_data()
# case 1:
@ -119,7 +127,7 @@ class TestCrossAPI(unittest.TestCase):
with fluid.dygraph.guard():
x = fluid.dygraph.to_variable(self.data_x)
y = fluid.dygraph.to_variable(self.data_y)
z = paddle.cross(x, y, dim=1)
z = paddle.cross(x, y, axis=1)
np_z = z.numpy()
expect_out = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0],
[0.0, 0.0, 0.0]])

@ -583,66 +583,69 @@ def t(input, name=None):
return out
def cross(input, other, dim=None):
def cross(x, y, axis=None, name=None):
"""
:alias_main: paddle.cross
:alias: paddle.cross,paddle.tensor.cross,paddle.tensor.linalg.cross
Returns the cross product of vectors in dimension `dim` of the `input` and `other` tensor.
Inputs must have the same shape, and the size of their dim-th dimension should be equla to 3.
If `dim` is not given, it defaults to the first dimension found with the size 3.
Computes the cross product between two tensors along an axis.
Inputs must have the same shape, and the length of their axes should be equal to 3.
If `axis` is not given, it defaults to the first axis found with the length 3.
Args:
input (Variable): The first input tensor variable.
other (Variable): The second input tensor variable.
dim (int): The dimension to take the cross-product in.
x (Variable): The first input tensor variable.
y (Variable): The second input tensor variable.
axis (int, optional): The axis along which to compute the cross product. It defaults to the first axis found with the length 3.
name (str, optional): The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`
Returns:
Variable: A Tensor with same data type as `input`.
Variable: A Tensor with same data type as `x`.
Examples:
.. code-block:: python
import paddle
import paddle.fluid as fluid
from paddle.imperative import to_variable
import numpy as np
paddle.enable_imperative()
data_x = np.array([[1.0, 1.0, 1.0],
[2.0, 2.0, 2.0],
[3.0, 3.0, 3.0]])
data_y = np.array([[1.0, 1.0, 1.0],
[1.0, 1.0, 1.0],
[1.0, 1.0, 1.0]])
with fluid.dygraph.guard():
x = fluid.dygraph.to_variable(data_x)
y = fluid.dygraph.to_variable(data_y)
out_z1 = paddle.cross(x, y)
print(out_z1.numpy())
#[[-1. -1. -1.]
# [ 2. 2. 2.]
# [-1. -1. -1.]]
out_z2 = paddle.cross(x, y, dim=1)
print(out_z2.numpy())
#[[0. 0. 0.]
# [0. 0. 0.]
# [0. 0. 0.]]
x = to_variable(data_x)
y = to_variable(data_y)
z1 = paddle.cross(x, y)
print(z1.numpy())
# [[-1. -1. -1.]
# [ 2. 2. 2.]
# [-1. -1. -1.]]
z2 = paddle.cross(x, y, axis=1)
print(z2.numpy())
# [[0. 0. 0.]
# [0. 0. 0.]
# [0. 0. 0.]]
"""
helper = LayerHelper("cross", **locals())
if in_dygraph_mode():
if dim:
return core.ops.cross(input, other, 'dim', dim)
if axis:
return core.ops.cross(x, y, 'dim', axis)
else:
return core.ops.cross(input, other)
return core.ops.cross(x, y)
out = helper.create_variable_for_type_inference(input.dtype)
helper = LayerHelper("cross", **locals())
out = helper.create_variable_for_type_inference(x.dtype)
attrs = dict()
if dim:
attrs['dim'] = dim
attrs['dim'] = axis
helper.append_op(
type='cross',
inputs={'X': input,
'Y': other},
inputs={'X': x,
'Y': y},
outputs={'Out': out},
attrs=attrs)
return out

Loading…
Cancel
Save