parent
4955c97ee8
commit
a2e9af5663
@ -0,0 +1,116 @@
|
||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/tdm_child_op.h"
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/operators/math/sampler.h"
|
||||
#include "paddle/fluid/platform/enforce.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
// Declares the proto of the tdm_child operator: its inputs, outputs,
// attributes and the user-facing documentation string.
class TDMChildOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "X(Tensor), dtype support int32/int64, X variable is the "
             "node id of TDM-Tree");
    AddInput(
        "TreeInfo",
        "TreeInfo(Tensor), dtype support int32/int64, it stores the node "
        "information in the following format: item_id(shape=1), "
        "layer_id(shape=1), parent_id(shape=1), child_id(shape=child_nums)");
    // The whole description has to be ONE string argument: AddAttr takes
    // (name, comment, generated), so a second, separate literal would be
    // converted to `generated = true` instead of extending the comment.
    AddAttr<int>("child_nums",
                 "child_nums(int) "
                 "The child nums of one node, if the node hasn't enough child, "
                 "it should padding 0 until child nums equal to child_nums");
    AddOutput("Child",
              "Return the children's node_id of input node, "
              "if input don't have child, return 0");
    AddOutput("LeafMask",
              "LeafMask has the same shape with Child. "
              "If child is leaf node, LeafMask value = 1, else = 0");
    // The default 2 is the value the attribute's own comment names INT32.
    AddAttr<int>("dtype",
                 "(int, default INT32) "
                 "Output data type.")
        .SetDefault(2);
    AddComment(R"DOC(
**Tdm Child**
According to the input node_id on the given tree, return the corresponding child node_id and
whether child is a leaf node by LeafMask.)DOC");
  }
};
|
||||
|
||||
class TDMChildOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
|
||||
platform::errors::InvalidArgument(
|
||||
"Inputs(X) of TdmChild should not be null."));
|
||||
PADDLE_ENFORCE_EQ(ctx->HasInput("TreeInfo"), true,
|
||||
platform::errors::InvalidArgument(
|
||||
"Inputs(TreeInfo) of TdmChild should not be null."));
|
||||
|
||||
int child_nums = ctx->Attrs().Get<int>("child_nums");
|
||||
PADDLE_ENFORCE_GT(
|
||||
child_nums, 0,
|
||||
platform::errors::InvalidArgument(
|
||||
"ValueError: The value of the 'child_nums' must greater than 0. "
|
||||
"But received child_nums value = %d, ",
|
||||
child_nums));
|
||||
|
||||
auto info_dims = ctx->GetInputDim("TreeInfo");
|
||||
auto input_dims = ctx->GetInputDim("X");
|
||||
|
||||
PADDLE_ENFORCE_EQ(
|
||||
info_dims.size(), 2,
|
||||
platform::errors::InvalidArgument(
|
||||
"ShapeError: The dimensions of the 'tree info' must be 2. "
|
||||
"But received tree info's dimensions = %d, "
|
||||
"tree info's shape = [%s].",
|
||||
info_dims.size(), info_dims));
|
||||
|
||||
auto output_dims = framework::vectorize(input_dims);
|
||||
output_dims.push_back(child_nums);
|
||||
ctx->SetOutputDim("Child", framework::make_ddim(output_dims));
|
||||
ctx->SetOutputDim("LeafMask", framework::make_ddim(output_dims));
|
||||
|
||||
if (ctx->GetOutputsVarType("Child")[0] ==
|
||||
framework::proto::VarType::LOD_TENSOR) {
|
||||
ctx->ShareLoD("X", /*->*/ "Child");
|
||||
ctx->ShareLoD("X", /*->*/ "LeafMask");
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
framework::OpKernelType GetExpectedKernelType(
|
||||
const framework::ExecutionContext& ctx) const override {
|
||||
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
|
||||
return framework::OpKernelType(data_type, ctx.device_context());
|
||||
}
|
||||
};
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
|
||||
REGISTER_OPERATOR(
|
||||
tdm_child, ops::TDMChildOp, ops::TDMChildOpMaker,
|
||||
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
|
||||
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
|
||||
REGISTER_OP_CPU_KERNEL(
|
||||
tdm_child, ops::TDMChildKernel<paddle::platform::CPUPlace, float>,
|
||||
ops::TDMChildKernel<paddle::platform::CPUPlace, double>,
|
||||
ops::TDMChildKernel<paddle::platform::CPUPlace, int>,
|
||||
ops::TDMChildKernel<paddle::platform::CPUPlace, int64_t>);
|
@ -0,0 +1,179 @@
|
||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <gflags/gflags.h>
|
||||
#include <cmath>
|
||||
#include <fstream>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/mixed_vector.h"
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using Tensor = framework::Tensor;
|
||||
using LoDTensor = framework::LoDTensor;
|
||||
using DDim = framework::DDim;
|
||||
using LoD = framework::LoD;
|
||||
|
||||
template <typename T, typename InfoT = int, typename OutT = int>
|
||||
void TDMChildInner(const framework::ExecutionContext &context,
|
||||
const LoDTensor &input, const LoDTensor &tree_info,
|
||||
LoDTensor *child, LoDTensor *mask) {
|
||||
auto child_nums = context.Attr<int>("child_nums");
|
||||
auto info_dims = tree_info.dims();
|
||||
int node_nums = info_dims[0];
|
||||
int length = info_dims[1];
|
||||
|
||||
int input_ids_num = input.numel();
|
||||
VLOG(4) << "TDM child op: input numel -> " << input_ids_num;
|
||||
|
||||
std::vector<OutT> child_vec{};
|
||||
std::vector<OutT> item_mask_vec{};
|
||||
|
||||
auto *input_data = input.data<T>();
|
||||
auto *tree_info_data = tree_info.data<InfoT>();
|
||||
|
||||
// TreeInfo: node_id : item_id; layer_id; ancestor_id; child_id
|
||||
for (int input_ids = 0; input_ids < input_ids_num; ++input_ids) {
|
||||
PADDLE_ENFORCE_LT(
|
||||
input_data[input_ids], node_nums,
|
||||
platform::errors::InvalidArgument(
|
||||
"input id of OP(fluid.contrib.layers.tdm_child) "
|
||||
"expected >= 0 and < %ld, but got %ld. Please check input "
|
||||
"value.",
|
||||
node_nums, input_data[input_ids]));
|
||||
PADDLE_ENFORCE_LE(
|
||||
0, input_data[input_ids],
|
||||
platform::errors::InvalidArgument(
|
||||
"input id of OP(fluid.contrib.layers.tdm_child) "
|
||||
"expected >= 0 and < %ld, but got %ld. Please check input "
|
||||
"value.",
|
||||
node_nums, input_data[input_ids]));
|
||||
|
||||
bool has_child =
|
||||
(input_data[input_ids] == 0 ||
|
||||
tree_info_data[static_cast<int>(input_data[input_ids]) * length + 3] ==
|
||||
0)
|
||||
? false
|
||||
: true;
|
||||
|
||||
if (has_child) {
|
||||
for (int child_ids = 0; child_ids < child_nums; ++child_ids) {
|
||||
OutT child_id = static_cast<OutT>(
|
||||
tree_info_data[static_cast<int>(input_data[input_ids]) * length +
|
||||
3 + child_ids]);
|
||||
child_vec.push_back(child_id);
|
||||
OutT child_is_item = static_cast<OutT>(
|
||||
tree_info_data[static_cast<int>(child_id) * length] == 0 ? 0 : 1);
|
||||
item_mask_vec.push_back(child_is_item);
|
||||
}
|
||||
} else {
|
||||
for (int child_ids = 0; child_ids < child_nums; ++child_ids) {
|
||||
child_vec.push_back(0);
|
||||
item_mask_vec.push_back(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int output_nums = child_vec.size();
|
||||
auto *child_data = child->mutable_data<OutT>(context.GetPlace());
|
||||
auto *leaf_mask_data = mask->mutable_data<OutT>(context.GetPlace());
|
||||
|
||||
memcpy(child_data, &child_vec[0], sizeof(OutT) * output_nums);
|
||||
memcpy(leaf_mask_data, &item_mask_vec[0], sizeof(OutT) * output_nums);
|
||||
}
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
class TDMChildKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext &ctx) const override {
|
||||
auto *input_var = ctx.InputVar("X");
|
||||
auto *tree_info_var = ctx.InputVar("TreeInfo");
|
||||
|
||||
auto &input_tensor = input_var->Get<LoDTensor>();
|
||||
const auto &input_type = input_tensor.type();
|
||||
bool input_type_match = input_type == framework::proto::VarType::INT32 ||
|
||||
input_type == framework::proto::VarType::INT64;
|
||||
PADDLE_ENFORCE_EQ(input_type_match, true,
|
||||
platform::errors::InvalidArgument(
|
||||
"Input(X) holds the wrong type, it holds %s, but "
|
||||
"desires to be %s or %s",
|
||||
paddle::framework::DataTypeToString(input_type),
|
||||
paddle::framework::DataTypeToString(
|
||||
framework::proto::VarType::INT32),
|
||||
paddle::framework::DataTypeToString(
|
||||
framework::proto::VarType::INT64)));
|
||||
|
||||
auto &tree_info_tensor = tree_info_var->Get<LoDTensor>();
|
||||
const auto &info_type = tree_info_tensor.type();
|
||||
bool info_type_match = info_type == framework::proto::VarType::INT32 ||
|
||||
info_type == framework::proto::VarType::INT64;
|
||||
PADDLE_ENFORCE_EQ(
|
||||
info_type_match, true,
|
||||
platform::errors::InvalidArgument(
|
||||
"Input(TreeInfo) holds the wrong type, it holds %s, but "
|
||||
"desires to be %s or %s",
|
||||
paddle::framework::DataTypeToString(info_type),
|
||||
paddle::framework::DataTypeToString(
|
||||
framework::proto::VarType::INT32),
|
||||
paddle::framework::DataTypeToString(
|
||||
framework::proto::VarType::INT64)));
|
||||
|
||||
auto *child_var = ctx.OutputVar("Child");
|
||||
auto *leaf_mask_var = ctx.OutputVar("LeafMask");
|
||||
auto *child_tensor = child_var->GetMutable<framework::LoDTensor>();
|
||||
auto *leaf_mask_tensor = leaf_mask_var->GetMutable<framework::LoDTensor>();
|
||||
|
||||
auto output_type =
|
||||
static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype"));
|
||||
bool out_type_match = output_type == framework::proto::VarType::INT32 ||
|
||||
output_type == framework::proto::VarType::INT64;
|
||||
PADDLE_ENFORCE_EQ(out_type_match, true,
|
||||
platform::errors::InvalidArgument(
|
||||
"Ouput(Child) & Output(LeafMask) holds the wrong "
|
||||
"type, it holds %s, but "
|
||||
"desires to be %s or %s",
|
||||
paddle::framework::DataTypeToString(output_type),
|
||||
paddle::framework::DataTypeToString(
|
||||
framework::proto::VarType::INT32),
|
||||
paddle::framework::DataTypeToString(
|
||||
framework::proto::VarType::INT64)));
|
||||
|
||||
if (info_type == framework::proto::VarType::INT32 &&
|
||||
output_type == framework::proto::VarType::INT32) {
|
||||
TDMChildInner<T, int, int>(ctx, input_tensor, tree_info_tensor,
|
||||
child_tensor, leaf_mask_tensor);
|
||||
} else if (info_type == framework::proto::VarType::INT64 &&
|
||||
output_type == framework::proto::VarType::INT32) {
|
||||
TDMChildInner<T, int64_t, int>(ctx, input_tensor, tree_info_tensor,
|
||||
child_tensor, leaf_mask_tensor);
|
||||
} else if (info_type == framework::proto::VarType::INT32 &&
|
||||
output_type == framework::proto::VarType::INT64) {
|
||||
TDMChildInner<T, int, int64_t>(ctx, input_tensor, tree_info_tensor,
|
||||
child_tensor, leaf_mask_tensor);
|
||||
} else if (info_type == framework::proto::VarType::INT64 &&
|
||||
output_type == framework::proto::VarType::INT64) {
|
||||
TDMChildInner<T, int64_t, int64_t>(ctx, input_tensor, tree_info_tensor,
|
||||
child_tensor, leaf_mask_tensor);
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,170 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
from op_test import OpTest
|
||||
import paddle.fluid.core as core
|
||||
from paddle.fluid.op import Operator
|
||||
import paddle.fluid.layers as layers
|
||||
import paddle.fluid as fluid
|
||||
import random
|
||||
import six
|
||||
|
||||
|
||||
def create_tdm_tree():
    """Build the TreeInfo rows of the 26-node tree used by these tests.

    Every row is ``[item_id, layer_id, parent_id, child_0, child_1]``; a new
    list is returned on each call.
    """
    # Rows 0..12: item_id is 0 and the last two columns are the children.
    upper_rows = [
        [0, 0, 0, 1, 2],
        [0, 1, 0, 3, 4],
        [0, 1, 0, 5, 6],
        [0, 2, 1, 7, 8],
        [0, 2, 1, 9, 10],
        [0, 2, 2, 11, 12],
        [0, 2, 2, 13, 0],
        [0, 3, 3, 14, 15],
        [0, 3, 3, 16, 17],
        [0, 3, 4, 18, 19],
        [0, 3, 4, 20, 21],
        [0, 3, 5, 22, 23],
        [0, 3, 5, 24, 25],
    ]
    # Row 13: the only childless row in layer 3.
    layer3_leaf = [[12, 3, 6, 0, 0]]
    # Rows 14..25: layer 4, items 0..11, two consecutive items per parent
    # (parents 7..12).
    layer4_leaves = [[item, 4, 7 + item // 2, 0, 0] for item in range(12)]
    return upper_rows + layer3_leaf + layer4_leaves
|
||||
|
||||
|
||||
class TestTDMChildOp(OpTest):
    """Checks tdm_child against a pure-Python lookup on the fixture tree.

    Subclasses override ``config`` to change the input shape and dtypes.
    """

    def setUp(self):
        self.__class__.op_type = "tdm_child"
        self.config()
        tree_info = create_tdm_tree()
        tree_info_np = np.array(tree_info).astype(self.info_type)

        x_np = np.random.randint(
            low=0, high=26, size=self.x_shape).astype(self.x_type)

        # Reference result: node 0 yields two zeros, any other node yields
        # columns 3 and 4 of its row; a child is flagged when the item_id
        # (column 0) of the child's own row is non-zero.
        expected_children = []
        expected_mask = []
        for row in x_np:
            for node in row:
                if node == 0:
                    kids = [0, 0]
                else:
                    kids = [tree_info[node][3], tree_info[node][4]]
                expected_children.extend(kids)
                expected_mask.extend(
                    [int(tree_info[kid][0] != 0) for kid in kids])

        child = np.array(expected_children).astype(self.info_type).reshape(
            self.child_shape)
        leaf_mask = np.array(expected_mask).astype(self.info_type).reshape(
            self.child_shape)

        self.attrs = {'child_nums': 2}
        self.inputs = {'X': x_np, 'TreeInfo': tree_info_np}
        self.outputs = {'Child': child, 'LeafMask': leaf_mask}

    def config(self):
        """Default case: int32 ids, int32 TreeInfo, a (10, 20) batch."""
        self.x_type = 'int32'
        self.info_type = 'int32'
        self.x_shape = (10, 20)
        self.child_shape = (10, 20, 2)

    def test_check_output(self):
        self.check_output()
|
||||
|
||||
|
||||
class TestCase1(TestTDMChildOp):
    def config(self):
        """int32 node ids with an int64 TreeInfo."""
        self.x_type, self.info_type = 'int32', 'int64'
        self.x_shape = (10, 20)
        self.child_shape = self.x_shape + (2, )
|
||||
|
||||
|
||||
class TestCase2(TestTDMChildOp):
    def config(self):
        """int64 node ids with an int64 TreeInfo."""
        self.x_type, self.info_type = 'int64', 'int64'
        self.x_shape = (10, 20)
        self.child_shape = self.x_shape + (2, )
|
||||
|
||||
|
||||
class TestCase3(TestTDMChildOp):
    def config(self):
        """int64 node ids with an int32 TreeInfo."""
        self.x_type, self.info_type = 'int64', 'int32'
        self.x_shape = (10, 20)
        self.child_shape = self.x_shape + (2, )
|
||||
|
||||
|
||||
class TestCase4(TestTDMChildOp):
    def config(self):
        """Larger batch: (100, 20) int32 ids with an int32 TreeInfo."""
        self.x_type, self.info_type = 'int32', 'int32'
        self.x_shape = (100, 20)
        self.child_shape = self.x_shape + (2, )
|
||||
|
||||
|
||||
class TestTDMChildShape(unittest.TestCase):
    """Smoke test: builds the tdm_child layer and runs it once on CPU."""

    def test_shape(self):
        tree_info_np = np.array(create_tdm_tree()).astype('int32')
        # The tree table is handed to the layer through its parameter
        # initializer.
        tree_param = fluid.ParamAttr(
            initializer=fluid.initializer.NumpyArrayInitializer(tree_info_np))

        node_ids = fluid.layers.data(
            name='x', shape=[1], dtype='int32', lod_level=1)
        child, leaf_mask = fluid.contrib.layers.tdm_child(
            x=node_ids, node_nums=26, child_nums=2, param_attr=tree_param)

        exe = fluid.Executor(place=fluid.CPUPlace())
        exe.run(fluid.default_startup_program())

        # Node ids 1..12 as a column of int32 values.
        feed = {'x': np.arange(1, 13).reshape(12, 1).astype('int32')}
        exe.run(feed=feed)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in new issue