# Page residue (not part of the module): "You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long."
# File: Paddle/python/paddle/fluid/tests/unittests/test_tdm_child_op.py
#
# 171 lines
# 5.0 KiB

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid.core as core
from paddle.fluid.op import Operator
import paddle.fluid.layers as layers
import paddle.fluid as fluid
import random
import six
def create_tdm_tree():
    """Create the TDM tree info table used by the tdm_child tests.

    The row index is the node id.  The tests in this file read the
    columns of each row as follows:

    * column 0 -- a non-zero value marks the node as a leaf (looks like
      an item id; confirm against the tdm_child op documentation);
    * columns 1 and 2 -- not read by these tests (presumably the layer id
      and the parent node id; confirm);
    * columns 3 and 4 -- ids of the node's two children, with ``0``
      used where there is no child.

    Returns:
        A newly built nested list with 26 rows of 5 integers each.
    """
    tree_info = [
        [0, 0, 0, 1, 2],
        [0, 1, 0, 3, 4],
        [0, 1, 0, 5, 6],
        [0, 2, 1, 7, 8],
        [0, 2, 1, 9, 10],
        [0, 2, 2, 11, 12],
        [0, 2, 2, 13, 0],
        [0, 3, 3, 14, 15],
        [0, 3, 3, 16, 17],
        [0, 3, 4, 18, 19],
        [0, 3, 4, 20, 21],
        [0, 3, 5, 22, 23],
        [0, 3, 5, 24, 25],
        [12, 3, 6, 0, 0],
        [0, 4, 7, 0, 0],
        [1, 4, 7, 0, 0],
        [2, 4, 8, 0, 0],
        [3, 4, 8, 0, 0],
        [4, 4, 9, 0, 0],
        [5, 4, 9, 0, 0],
        [6, 4, 10, 0, 0],
        [7, 4, 10, 0, 0],
        [8, 4, 11, 0, 0],
        [9, 4, 11, 0, 0],
        [10, 4, 12, 0, 0],
        [11, 4, 12, 0, 0],
    ]
    return tree_info
class TestTDMChildOp(OpTest):
    """Check the ``tdm_child`` op against a pure-Python reference lookup."""

    def setUp(self):
        """Build random node ids and the expected ``Child``/``LeafMask``.

        The expected outputs are computed by looking every input node id up
        in the table returned by ``create_tdm_tree``.
        """
        self.__class__.op_type = "tdm_child"
        self.config()
        child_nums = 2
        tree_info = create_tdm_tree()
        tree_info_np = np.array(tree_info).astype(self.info_type)

        # Draw node ids from the whole table; the exclusive upper bound
        # follows the table size instead of repeating it as a literal.
        x_np = np.random.randint(
            low=0, high=len(tree_info), size=self.x_shape).astype(self.x_type)

        children_res = []
        leaf_mask_res = []
        # flatten() walks the ids in C order, i.e. the same order as the
        # nested per-row loops it replaces, and works for any input rank.
        for node in x_np.flatten():
            if node != 0:
                # Child ids are stored from column 3 onwards.
                children = tree_info[node][3:3 + child_nums]
            else:
                # Node id 0 yields all-zero children even though row 0 of
                # the table lists children -- presumably 0 is the padding
                # id; confirm against the op.
                children = [0] * child_nums
            children_res += children
            # A child counts as a leaf when column 0 of its row is non-zero.
            leaf_mask_res += [int(tree_info[c][0] != 0) for c in children]

        children_res_np = np.array(children_res).astype(self.info_type)
        leaf_mask_res_np = np.array(leaf_mask_res).astype(self.info_type)

        self.attrs = {'child_nums': child_nums}
        self.inputs = {'X': x_np, 'TreeInfo': tree_info_np}
        self.outputs = {
            'Child': np.reshape(children_res_np, self.child_shape),
            'LeafMask': np.reshape(leaf_mask_res_np, self.child_shape),
        }

    def config(self):
        """Set the test shapes and dtypes; subclasses override this."""
        self.x_shape = (10, 20)
        self.child_shape = (10, 20, 2)
        self.x_type = 'int32'
        self.info_type = 'int32'

    def test_check_output(self):
        """Delegate to the inherited ``check_output``."""
        self.check_output()
class TestCase1(TestTDMChildOp):
    """Variant with int32 node ids and an int64 tree-info table."""

    def config(self):
        """Use shape (10, 20) with ``x`` as int32 and tree info as int64."""
        self.x_type, self.info_type = 'int32', 'int64'
        self.x_shape, self.child_shape = (10, 20), (10, 20, 2)
class TestCase2(TestTDMChildOp):
    """Variant with int64 node ids and an int64 tree-info table."""

    def config(self):
        """Use shape (10, 20) with both ``x`` and tree info as int64."""
        self.x_type, self.info_type = 'int64', 'int64'
        self.x_shape, self.child_shape = (10, 20), (10, 20, 2)
class TestCase3(TestTDMChildOp):
    """Variant with int64 node ids and an int32 tree-info table."""

    def config(self):
        """Use shape (10, 20) with ``x`` as int64 and tree info as int32."""
        self.x_type, self.info_type = 'int64', 'int32'
        self.x_shape, self.child_shape = (10, 20), (10, 20, 2)
class TestCase4(TestTDMChildOp):
    """Variant with a larger input: 100 rows of 20 int32 node ids."""

    def config(self):
        """Use shape (100, 20) with both ``x`` and tree info as int32."""
        self.x_type, self.info_type = 'int32', 'int32'
        self.x_shape, self.child_shape = (100, 20), (100, 20, 2)
class TestTDMChildShape(unittest.TestCase):
    """Smoke test for the ``fluid.contrib.layers.tdm_child`` layer API."""

    def test_shape(self):
        """Build a tdm_child layer and run it once on node ids 1..12.

        The test makes no assertions; it passes as long as building and
        executing the program does not raise.
        """
        x = fluid.layers.data(name='x', shape=[1], dtype='int32', lod_level=1)

        # The tree table is handed to the layer through its parameter
        # initializer.
        tree_table = np.array(create_tdm_tree()).astype('int32')
        table_init = fluid.initializer.NumpyArrayInitializer(tree_table)
        child, leaf_mask = fluid.contrib.layers.tdm_child(
            x=x,
            node_nums=26,
            child_nums=2,
            param_attr=fluid.ParamAttr(initializer=table_init))

        exe = fluid.Executor(place=fluid.CPUPlace())
        exe.run(fluid.default_startup_program())

        # One node id per row: [[1], [2], ..., [12]].
        node_ids = np.array([[i] for i in range(1, 13)]).astype('int32')
        exe.run(feed={'x': node_ids})
# Run every test in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()