You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
171 lines
5.0 KiB
171 lines
5.0 KiB
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from __future__ import print_function
|
|
|
|
import unittest
|
|
import numpy as np
|
|
from op_test import OpTest
|
|
import paddle.fluid.core as core
|
|
from paddle.fluid.op import Operator
|
|
import paddle.fluid.layers as layers
|
|
import paddle.fluid as fluid
|
|
import random
|
|
import six
|
|
|
|
|
|
def create_tdm_tree():
    """Return the fixed 26-row, 5-column tree table used by these tests.

    The tests in this file read columns 3 and 4 of a row as the two child
    ids of that node and treat a non-zero value in column 0 as the leaf
    marker. Columns 1 and 2 are not read by the tests here (presumably a
    layer id and a parent id -- confirm against the tdm_child op docs).

    A fresh list of lists is built on every call.
    """
    return [
        # Rows 0-12: column 0 is zero and at least one child id is set.
        [0, 0, 0, 1, 2],
        [0, 1, 0, 3, 4],
        [0, 1, 0, 5, 6],
        [0, 2, 1, 7, 8],
        [0, 2, 1, 9, 10],
        [0, 2, 2, 11, 12],
        [0, 2, 2, 13, 0],
        [0, 3, 3, 14, 15],
        [0, 3, 3, 16, 17],
        [0, 3, 4, 18, 19],
        [0, 3, 4, 20, 21],
        [0, 3, 5, 22, 23],
        [0, 3, 5, 24, 25],
        # Rows 13-25: both child columns are zero.
        [12, 3, 6, 0, 0],
        [0, 4, 7, 0, 0],
        [1, 4, 7, 0, 0],
        [2, 4, 8, 0, 0],
        [3, 4, 8, 0, 0],
        [4, 4, 9, 0, 0],
        [5, 4, 9, 0, 0],
        [6, 4, 10, 0, 0],
        [7, 4, 10, 0, 0],
        [8, 4, 11, 0, 0],
        [9, 4, 11, 0, 0],
        [10, 4, 12, 0, 0],
        [11, 4, 12, 0, 0],
    ]
|
|
|
|
|
|
class TestTDMChildOp(OpTest):
    """Compare the ``tdm_child`` op with a NumPy reference on the toy tree."""

    def setUp(self):
        """Prepare random node ids and the expected ``Child``/``LeafMask``."""
        self.__class__.op_type = "tdm_child"
        self.config()
        tree_table = np.array(create_tdm_tree()).astype(self.info_type)

        node_ids = np.random.randint(
            low=0, high=26, size=self.x_shape).astype(self.x_type)

        # Columns 3 and 4 of a node's row are its two children; node id 0
        # always yields (0, 0) regardless of what row 0 stores.
        row_children = tree_table[node_ids][..., 3:5]
        is_real_node = node_ids[..., np.newaxis] != 0
        expected_children = np.where(is_real_node, row_children, 0).astype(
            self.info_type)

        # A child is flagged 1 when column 0 of its own row is non-zero.
        expected_mask = (tree_table[expected_children, 0] != 0).astype(
            self.info_type)

        self.attrs = {'child_nums': 2}
        self.inputs = {'X': node_ids, 'TreeInfo': tree_table}
        self.outputs = {
            'Child': np.reshape(expected_children, self.child_shape),
            'LeafMask': np.reshape(expected_mask, self.child_shape),
        }

    def config(self):
        """Set the input/output shapes and dtypes; subclasses override this."""
        self.x_shape = (10, 20)
        self.child_shape = (10, 20, 2)
        self.x_type = 'int32'
        self.info_type = 'int32'

    def test_check_output(self):
        """Run the op and compare its outputs with ``self.outputs``."""
        self.check_output()
|
|
|
|
|
|
class TestCase1(TestTDMChildOp):
    """Variant with int32 node ids and an int64 tree table."""

    def config(self):
        """Use int32 for X and int64 for TreeInfo."""
        self.x_shape = (10, 20)
        self.child_shape = self.x_shape + (2,)
        self.x_type, self.info_type = 'int32', 'int64'
|
|
|
|
|
|
class TestCase2(TestTDMChildOp):
    """Variant with int64 node ids and an int64 tree table."""

    def config(self):
        """Use int64 for both X and TreeInfo."""
        self.x_shape = (10, 20)
        self.child_shape = self.x_shape + (2,)
        self.x_type, self.info_type = 'int64', 'int64'
|
|
|
|
|
|
class TestCase3(TestTDMChildOp):
    """Variant with int64 node ids and an int32 tree table."""

    def config(self):
        """Use int64 for X and int32 for TreeInfo."""
        self.x_shape = (10, 20)
        self.child_shape = self.x_shape + (2,)
        self.x_type, self.info_type = 'int64', 'int32'
|
|
|
|
|
|
class TestCase4(TestTDMChildOp):
    """Variant with a larger (100, 20) batch of int32 node ids."""

    def config(self):
        """Use a (100, 20) input with int32 for both X and TreeInfo."""
        self.x_shape = (100, 20)
        self.child_shape = self.x_shape + (2,)
        self.x_type, self.info_type = 'int32', 'int32'
|
|
|
|
|
|
class TestTDMChildShape(unittest.TestCase):
    """Smoke-test the ``fluid.contrib.layers.tdm_child`` layer on CPU."""

    def test_shape(self):
        """Build the layer over the toy tree and push one batch through it."""
        node_input = fluid.layers.data(
            name='x', shape=[1], dtype='int32', lod_level=1)
        tree_table = np.array(create_tdm_tree()).astype('int32')
        tree_attr = fluid.ParamAttr(
            initializer=fluid.initializer.NumpyArrayInitializer(tree_table))

        child, leaf_mask = fluid.contrib.layers.tdm_child(
            x=node_input, node_nums=26, child_nums=2, param_attr=tree_attr)

        exe = fluid.Executor(place=fluid.CPUPlace())
        exe.run(fluid.default_startup_program())

        # Node ids 1..12 as a (12, 1) int32 column. Nothing is fetched, so
        # this only checks that the program builds and runs without error.
        batch = np.arange(1, 13).reshape(-1, 1).astype('int32')
        exe.run(feed={'x': batch})
|
|
|
|
|
|
# Run every test case in this module when the file is executed directly.
if __name__ == "__main__":
    unittest.main()
|