polish_tril_triu_docstring and add dygraph (#24055)

* Update creation.py
revert-22778-infer_var_type
WuHaobo 5 years ago committed by GitHub
parent d31a174f51
commit 79eaac55ba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -134,6 +134,14 @@ class TestTrilTriuOpAPI(unittest.TestCase):
self.assertTrue(np.allclose(tril_out, np.tril(data)))
self.assertTrue(np.allclose(triu_out, np.triu(data)))
def test_api_with_dygraph(self):
with fluid.dygraph.guard():
data = np.random.random([1, 9, 9, 4]).astype('float32')
x = fluid.dygraph.to_variable(data)
tril_out, triu_out = tensor.tril(x).numpy(), tensor.triu(x).numpy()
self.assertTrue(np.allclose(tril_out, np.tril(data)))
self.assertTrue(np.allclose(triu_out, np.triu(data)))
if __name__ == '__main__':
unittest.main()

@ -696,8 +696,6 @@ def tril(input, diagonal=0, name=None):
# [ 5, 6, 0, 0],
# [ 9, 10, 11, 0]])
.. code-block:: python
# example 2, positive diagonal value
tril = tensor.tril(x, diagonal=2)
tril_out, = exe.run(fluid.default_main_program(), feed={"x": data},
@ -706,8 +704,6 @@ def tril(input, diagonal=0, name=None):
# [ 5, 6, 7, 8],
# [ 9, 10, 11, 12]])
.. code-block:: python
# example 3, negative diagonal value
tril = tensor.tril(x, diagonal=-1)
tril_out, = exe.run(fluid.default_main_program(), feed={"x": data},
@ -716,7 +712,10 @@ def tril(input, diagonal=0, name=None):
# [ 5, 0, 0, 0],
# [ 9, 10, 0, 0]])
"""
"""
if in_dygraph_mode():
op = getattr(core.ops, 'tril_triu')
return op(input, 'diagonal', diagonal, "lower", True)
return _tril_triu_op(LayerHelper('tril', **locals()))
@ -771,8 +770,6 @@ def triu(input, diagonal=0, name=None):
# [ 0, 6, 7, 8],
# [ 0, 0, 11, 12]])
.. code-block:: python
# example 2, positive diagonal value
triu = tensor.triu(x, diagonal=2)
triu_out, = exe.run(fluid.default_main_program(), feed={"x": data},
@ -781,8 +778,6 @@ def triu(input, diagonal=0, name=None):
# [0, 0, 0, 8],
# [0, 0, 0, 0]])
.. code-block:: python
# example 3, negative diagonal value
triu = tensor.triu(x, diagonal=-1)
triu_out, = exe.run(fluid.default_main_program(), feed={"x": data},
@ -792,6 +787,9 @@ def triu(input, diagonal=0, name=None):
# [ 0, 10, 11, 12]])
"""
if in_dygraph_mode():
op = getattr(core.ops, 'tril_triu')
return op(input, 'diagonal', diagonal, "lower", False)
return _tril_triu_op(LayerHelper('triu', **locals()))

Loading…
Cancel
Save