|
|
|
@ -17,6 +17,7 @@ from __future__ import print_function
|
|
|
|
|
import unittest
|
|
|
|
|
import numpy as np
|
|
|
|
|
from op_test import OpTest
|
|
|
|
|
import paddle
|
|
|
|
|
import paddle.fluid as fluid
|
|
|
|
|
from paddle.fluid.framework import program_guard, Program
|
|
|
|
|
|
|
|
|
@ -52,6 +53,7 @@ class TestGatherTreeOp(OpTest):
|
|
|
|
|
|
|
|
|
|
class TestGatherTreeOpAPI(unittest.TestCase):
|
|
|
|
|
def test_case(self):
|
|
|
|
|
paddle.enable_static()
|
|
|
|
|
ids = fluid.layers.data(
|
|
|
|
|
name='ids', shape=[5, 2, 2], dtype='int64', append_batch_size=False)
|
|
|
|
|
parents = fluid.layers.data(
|
|
|
|
@ -60,10 +62,19 @@ class TestGatherTreeOpAPI(unittest.TestCase):
|
|
|
|
|
dtype='int64',
|
|
|
|
|
append_batch_size=False)
|
|
|
|
|
final_sequences = fluid.layers.gather_tree(ids, parents)
|
|
|
|
|
paddle.disable_static()
|
|
|
|
|
|
|
|
|
|
def test_case2(self):
|
|
|
|
|
ids = paddle.to_tensor(
|
|
|
|
|
[[[2, 2], [6, 1]], [[3, 9], [6, 1]], [[0, 1], [9, 0]]])
|
|
|
|
|
parents = paddle.to_tensor(
|
|
|
|
|
[[[0, 0], [1, 1]], [[1, 0], [1, 0]], [[0, 0], [0, 1]]])
|
|
|
|
|
final_sequences = paddle.nn.functional.gather_tree(ids, parents)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestGatherTreeOpError(unittest.TestCase):
|
|
|
|
|
def test_errors(self):
|
|
|
|
|
paddle.enable_static()
|
|
|
|
|
with program_guard(Program(), Program()):
|
|
|
|
|
ids = fluid.layers.data(
|
|
|
|
|
name='ids',
|
|
|
|
@ -111,6 +122,7 @@ class TestGatherTreeOpError(unittest.TestCase):
|
|
|
|
|
fluid.layers.gather_tree(ids, bad_parents)
|
|
|
|
|
|
|
|
|
|
self.assertRaises(TypeError, test_type_parents)
|
|
|
|
|
paddle.disable_static()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|