Fusion: seqpool_cvm_concat (#18471)
* add fusion_seqpool_cvm_concat test=develop
* simplify pass, test=develop
* fix code style, test=develop
parent 768059b3a0
commit ee2f296ef8
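For reference, the fused op computes a per-slot sequence pool, applies the CVM log transform to the first two lanes of each pooled row, and concatenates the results along axis 1. A minimal NumPy sketch of these semantics (the helper name is hypothetical; SUM pooling and use_cvm=True are assumed, mirroring the CPU kernel below):

import numpy as np

def fusion_seqpool_cvm_concat_ref(xs, seq_lens_per_x):
    # xs[i]: float32 array [total_rows_i, w]; seq_lens_per_x[i]: sequence lengths.
    outs = []
    for x, lens in zip(xs, seq_lens_per_x):
        offsets = np.cumsum([0] + list(lens))
        # Sum-pool each sequence down to one row of width w.
        pooled = np.stack([x[offsets[j]:offsets[j + 1]].sum(axis=0)
                           for j in range(len(lens))])
        # CVM transform on the first two lanes (show, click).
        pooled[:, 0] = np.log(pooled[:, 0] + 1)
        pooled[:, 1] = np.log(pooled[:, 1] + 1) - pooled[:, 0]
        outs.append(pooled)
    return np.concatenate(outs, axis=1)  # shape [batch, n * w]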
@ -0,0 +1,153 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#include "paddle/fluid/framework/ir/seqpool_cvm_concat_fuse_pass.h"
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"

namespace paddle {
namespace framework {
namespace ir {

namespace {
static PDNode* BuildCVMConcatPattern(PDPattern* pattern) {
  // Accept only a concat op whose first input variable is produced by a
  // cvm op; the per-input checks happen in the second pattern below.
  auto cvm_behind_x = [](Node* x) -> bool {
    if (!x || x->inputs.empty()) return false;
    Node* adj = x->inputs[0];
    if (!adj || !adj->IsVar() || adj->inputs.empty()) return false;
    Node* alt = adj->inputs[0];
    return alt && alt->IsOp() && alt->Op()->Type() == "cvm";
  };
  auto* concat_op_node = pattern->NewNode("concat_op")
                             ->assert_is_op("concat")
                             ->assert_op_attr<int>("axis", 1)
                             ->assert_more(cvm_behind_x);
  return concat_op_node;
}

static void GetConcatNodes(ir::Graph* graph, std::vector<Node*>* concat_nodes) {
  GraphPatternDetector gpd;
  auto* pattern = gpd.mutable_pattern();
  auto concat_op_node = BuildCVMConcatPattern(pattern);
  GraphPatternDetector::handle_t handler = [&](
      const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) {
    Node* concat_op = subgraph.at(concat_op_node);
    concat_nodes->push_back(concat_op);
  };
  gpd(graph, handler);
}
}  // anonymous namespace

void SeqPoolCVMConcatFusePass::ApplyImpl(ir::Graph* graph) const {
  FusePassBase::Init("seqpool_cvm_concat_fuse", graph);
  std::vector<Node*> concat_nodes;
  GetConcatNodes(graph, &concat_nodes);

  int count = 0;
  for (auto* concat_node : concat_nodes) {
    GraphPatternDetector gpd;
    auto* pattern = gpd.mutable_pattern();
    auto concat_before_x = [=](Node* x) -> bool {
      return x && !x->outputs.empty() && x->outputs[0] == concat_node;
    };
    PDNode* seqpool_in_var_node =
        pattern->NewNode("seqpool_in_var")
            ->assert_is_only_input_of_op("sequence_pool");
    PDNode* seqpool_op_node =
        pattern->NewNode("seqpool_op")
            ->assert_is_op("sequence_pool")
            ->assert_op_attr<std::string>("pooltype", "SUM");
    PDNode* seqpool_out_var_node =
        pattern->NewNode("seqpool_out_var")
            ->assert_is_op_nth_output("sequence_pool", "Out", 0)
            ->assert_is_op_nth_input("cvm", "X", 0);
    PDNode* seqpool_idx_out_var_node =
        pattern->NewNode("seqpool_idx_out_var")
            ->assert_is_op_nth_output("sequence_pool", "MaxIndex", 0);
    PDNode* cvm_op_node =
        pattern->NewNode("cvm_op")->assert_is_op("cvm")->assert_op_attr<bool>(
            "use_cvm", true);
    PDNode* cvm_out_var_node = pattern->NewNode("cvm_op_out_var")
                                   ->assert_is_op_nth_output("cvm", "Y", 0)
                                   ->assert_more(concat_before_x);
    PDNode* cvm_cvm_in_var_node = pattern->NewNode("cvm_cvm_in_var")
                                      ->assert_is_op_nth_input("cvm", "CVM", 0);

    seqpool_op_node->LinksFrom({seqpool_in_var_node})
        .LinksTo({seqpool_out_var_node, seqpool_idx_out_var_node});
    seqpool_out_var_node->LinksFrom({seqpool_op_node}).LinksTo({cvm_op_node});
    cvm_op_node->LinksTo({cvm_out_var_node})
        .LinksFrom({cvm_cvm_in_var_node, seqpool_out_var_node});
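    // Each handler match below records which concat input traces back to
    // which raw sequence_pool input, captures the shared CVM tensor, and
    // marks the intermediate nodes for removal.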

    std::unordered_map<std::string, Node*> ins_to_concat;
    std::vector<Node*> subgraph_ins;
    std::vector<std::string> subgraph_ins_name;
    std::unordered_set<const Node*> marked_nodes;

    Node* cvm_input_of_cvm = nullptr;
    Node* concat_out_var = concat_node->outputs[0];

    GraphPatternDetector::handle_t handler = [&](
        const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) {
      Node* seqpool_in_var = subgraph.at(seqpool_in_var_node);
      Node* seqpool_op = subgraph.at(seqpool_op_node);
      Node* seqpool_out_var = subgraph.at(seqpool_out_var_node);
      Node* seqpool_idx_out_var = subgraph.at(seqpool_idx_out_var_node);
      Node* cvm_op = subgraph.at(cvm_op_node);
      Node* cvm_out_var = subgraph.at(cvm_out_var_node);
      cvm_input_of_cvm = subgraph.at(cvm_cvm_in_var_node);
      marked_nodes.insert({seqpool_op, seqpool_out_var, seqpool_idx_out_var,
                           cvm_op, cvm_out_var, concat_node});
      ins_to_concat[cvm_out_var->Name()] = seqpool_in_var;
    };
    gpd(graph, handler);

    if (!ins_to_concat.empty()) {
      for (const auto* in : concat_node->inputs) {
        subgraph_ins.push_back(ins_to_concat.at(in->Name()));
        subgraph_ins_name.push_back(ins_to_concat.at(in->Name())->Name());
      }

      // Create New OpDesc
      OpDesc op_desc;
      op_desc.SetType("fusion_seqpool_cvm_concat");
      op_desc.SetInput("X", subgraph_ins_name);
      op_desc.SetInput("CVM", {cvm_input_of_cvm->Name()});
      op_desc.SetAttr("pooltype", std::string("SUM"));
      op_desc.SetAttr("use_cvm", true);
      op_desc.SetAttr("axis", concat_node->Op()->GetAttr("axis"));
      op_desc.SetOutput("Out", {concat_out_var->Name()});
      auto* op = graph->CreateOpNode(&op_desc);

      for (size_t i = 0; i < subgraph_ins.size(); ++i) {
        IR_NODE_LINK_TO(subgraph_ins[i], op);
      }
      IR_NODE_LINK_TO(cvm_input_of_cvm, op);
      IR_NODE_LINK_TO(op, concat_out_var);

      GraphSafeRemoveNodes(graph, marked_nodes);
      count++;
    }
  }
  AddStatis(count);
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(seqpool_cvm_concat_fuse_pass,
              paddle::framework::ir::SeqPoolCVMConcatFusePass);
@ -0,0 +1,54 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#pragma once

#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"

namespace paddle {
namespace framework {
namespace ir {

/**
 * Fuse SequencePool (SUM pooltype only, for now), CVM and Concat;
 *
 * Before fuse:
 *      |          |              |
 *   seq_pool,  seq_pool,  ...  seq_pool
 *      |          |              |
 *     cvm        cvm            cvm
 *       \         |    ...     /
 *              concat
 *                |
 * After fuse:
 *       \         |            /
 *       FusionSeqPoolCVMConcat
 *                |
 */
class SeqPoolCVMConcatFusePass : public FusePassBase {
 public:
  virtual ~SeqPoolCVMConcatFusePass() {}

 protected:
  void ApplyImpl(ir::Graph* graph) const override;

  const std::string name_scope_{"seqpool_cvm_concat_fuse"};
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
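A minimal Python sketch of the kind of network this pass targets, assuming the fluid layer wrappers sequence_pool, continuous_value_model and concat map onto the sequence_pool, cvm and concat ops matched above (names and shapes here are illustrative only):

import paddle.fluid as fluid

def build_fusable_net(num_slots=3, dim=11):
    # One shared (show, click) tensor feeds every cvm op.
    cvm_in = fluid.layers.data(name='cvm', shape=[2], dtype='float32')
    outs = []
    for i in range(num_slots):
        x = fluid.layers.data(
            name='x_%d' % i, shape=[dim], dtype='float32', lod_level=1)
        pooled = fluid.layers.sequence_pool(input=x, pool_type='sum')
        outs.append(fluid.layers.continuous_value_model(pooled, cvm_in, True))
    return fluid.layers.concat(outs, axis=1)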
@ -0,0 +1,239 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/seqpool_cvm_concat_fuse_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/op_proto_maker.h"

namespace paddle {
namespace framework {
namespace ir {

void SetOp(ProgramDesc* prog, const std::string& type,
           const std::vector<std::string>& inputs,
           const std::vector<std::string>& outputs) {
  auto* op = prog->MutableBlock(0)->AppendOp();
  op->SetType(type);
  if (type == "sequence_pool") {
    op->SetInput("X", {inputs[0]});
    std::string pooltype = "SUM";
    op->SetAttr("pooltype", pooltype);
    // Note: outputs[0] is MaxIndex and outputs[1] is Out.
    op->SetOutput("MaxIndex", {outputs[0]});
    op->SetOutput("Out", {outputs[1]});
  } else if (type == "concat") {
    op->SetInput("X", inputs);
    op->SetAttr("axis", 1);
    op->SetOutput("Out", {outputs[0]});
  } else if (type == "cvm") {
    op->SetInput("X", {inputs[0]});
    op->SetInput("CVM", {inputs[1]});
    op->SetOutput("Y", {outputs[0]});
    op->SetAttr("use_cvm", true);
  } else {
    op->SetInput("X", inputs);
    op->SetOutput("Out", outputs);
  }
  op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
              static_cast<int>(OpRole::kForward));
}

int CountOpType(const ir::Graph* graph,
                const std::string& op_type = "fusion_seqpool_cvm_concat") {
  int count = 0;
  for (auto* node : graph->Nodes()) {
    if (node->IsOp() && node->Op()->Type() == op_type) {
      ++count;
    }
  }
  return count;
}

std::unique_ptr<ir::Graph> GetNumNodesOfBeforeAfter(
    std::unique_ptr<ir::Graph> graph, int* before, int* after,
    const std::string& pass_type = "seqpool_cvm_concat_fuse_pass") {
  auto pass = PassRegistry::Instance().Get(pass_type);
  *before = graph->Nodes().size();
  graph.reset(pass->Apply(graph.release()));
  *after = graph->Nodes().size();
  return graph;
}

/*
 * Before fuse:
 *
 *       a          b          c
 *       |          |          |
 *      op1        op2        op3
 *     /   \      /   \      /   \
 *    d     e  n f     g  n h     i  n
 *          | /        | /        | /
 *         op4        op5        op6
 *          |          |          |
 *          j          k          l
 *           \         |         /
 *                  concat
 *                    |
 *                    m
 *
 * Type of op1, op2 and op3 is sequence_pool, with "SUM" pooltype attr.
 * Type of op4, op5 and op6 is cvm, with use_cvm set to true.
 *
 * After fuse:
 *    a      b      c      n
 *     \     |      |     /
 *    fusion_seqpool_cvm_concat
 *               |
 *               m
 */
TEST(SeqPoolCVMConcatFusePass, basic) {
  ProgramDesc prog;
  for (auto& v :
       std::vector<std::string>({"a", "b", "c", "d", "e", "f", "g", "h", "i",
                                 "j", "k", "l", "m", "n"})) {
    auto* var = prog.MutableBlock(0)->Var(v);
    var->SetType(proto::VarType::LOD_TENSOR);
  }

  SetOp(&prog, "sequence_pool", std::vector<std::string>({"a"}),
        std::vector<std::string>({"d", "e"}));
  SetOp(&prog, "sequence_pool", std::vector<std::string>({"b"}),
        std::vector<std::string>({"f", "g"}));
  SetOp(&prog, "sequence_pool", std::vector<std::string>({"c"}),
        std::vector<std::string>({"h", "i"}));
  SetOp(&prog, "cvm", std::vector<std::string>({"e", "n"}),
        std::vector<std::string>({"j"}));
  SetOp(&prog, "cvm", std::vector<std::string>({"g", "n"}),
        std::vector<std::string>({"k"}));
  SetOp(&prog, "cvm", std::vector<std::string>({"i", "n"}),
        std::vector<std::string>({"l"}));
  SetOp(&prog, "concat", std::vector<std::string>({"j", "k", "l"}),
        std::vector<std::string>({"m"}));

  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
  int before, after;
  graph = GetNumNodesOfBeforeAfter(std::move(graph), &before, &after);
  // Remove 16 Nodes: op1, op2, op3, op4, op5, op6, d, e, f, g, h, i, j, k, l,
  // concat_op
  // Add 1 Node: fusion_seqpool_cvm_concat
  EXPECT_EQ(after, before - 15);
  EXPECT_EQ(CountOpType(graph.get()), 1);
}

/*
 * Before fuse:
 *        a                b
 *        |              /   \
 *       op1   k       op2    k   op3
 *      /   \  /      /   \   /    \
 *     c     d       e     f        g
 *           |             |
 *          op4           op5
 *           |             |
 *           h             i
 *            \           /
 *               concat
 *                  |
 *                  j
 *
 * Type of op1 and op2 is sequence_pool, with "SUM" pooltype attr.
 * Type of op4 and op5 is cvm, with use_cvm set to true.
 *
 * After fuse:
 *        a    k    b
 *         \   |   / \
 *  fusion_seqpool_cvm_concat   op3
 *             |                 |
 *             j                 g
 */
TEST(SeqPoolCVMConcatFusePass, advanced) {
  ProgramDesc prog;
  for (auto& v : std::vector<std::string>(
           {"a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k"})) {
    auto* var = prog.MutableBlock(0)->Var(v);
    var->SetType(proto::VarType::LOD_TENSOR);
  }

  SetOp(&prog, "sequence_pool", std::vector<std::string>({"a"}),
        std::vector<std::string>({"c", "d"}));
  SetOp(&prog, "sequence_pool", std::vector<std::string>({"b"}),
        std::vector<std::string>({"e", "f"}));
  SetOp(&prog, "op3", std::vector<std::string>({"b"}),
        std::vector<std::string>({"g"}));
  SetOp(&prog, "cvm", std::vector<std::string>({"d", "k"}),
        std::vector<std::string>({"h"}));
  SetOp(&prog, "cvm", std::vector<std::string>({"f", "k"}),
        std::vector<std::string>({"i"}));
  SetOp(&prog, "concat", std::vector<std::string>({"h", "i"}),
        std::vector<std::string>({"j"}));

  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
  int before, after;
  graph = GetNumNodesOfBeforeAfter(std::move(graph), &before, &after);
  // Remove 11 Nodes: op1, op2, op4, op5, c, d, e, f, h, i, concat_op
  // Add 1 Node: fusion_seqpool_cvm_concat
  EXPECT_EQ(after, before - 10);
  EXPECT_EQ(CountOpType(graph.get()), 1);
}

ProgramDesc BuildProgramDesc(int num_inputs_of_concat) {
  ProgramDesc prog;
  auto new_var = [&](const std::string& name) {
    auto* var = prog.MutableBlock(0)->Var(name);
    var->SetType(proto::VarType::LOD_TENSOR);
  };
  std::vector<std::string> concat_inputs;
  new_var("cvm_in");
  for (int i = 0; i < num_inputs_of_concat; ++i) {
    std::string seqpool_prefix = "seqpool_op_" + std::to_string(i);
    new_var(seqpool_prefix + "in");
    new_var(seqpool_prefix + "out");
    new_var(seqpool_prefix + "out_unused");
    SetOp(&prog, "sequence_pool",
          std::vector<std::string>({seqpool_prefix + "in"}),
          std::vector<std::string>(
              {seqpool_prefix + "out_unused", seqpool_prefix + "out"}));

    std::string cvm_prefix = "cvm_op_" + std::to_string(i);
    new_var(cvm_prefix + "out");
    SetOp(&prog, "cvm",
          std::vector<std::string>({seqpool_prefix + "out", "cvm_in"}),
          std::vector<std::string>({cvm_prefix + "out"}));

    concat_inputs.push_back(cvm_prefix + "out");
  }
  SetOp(&prog, "concat", concat_inputs,
        std::vector<std::string>({"concat_out"}));
  return prog;
}

// test more inputs of concat
TEST(SeqPoolCVMConcatFusePass, more_inputs) {
  for (int num : {1, 2, 10}) {
    ProgramDesc prog = BuildProgramDesc(num);
    std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
    int before, after;
    graph = GetNumNodesOfBeforeAfter(std::move(graph), &before, &after);
    // Remove Nodes: n * (seqpool_op, seqpool_out, out_unused, cvm_op,
    // cvm_out), and concat_op
    // Add Node: fusion_seqpool_cvm_concat op
    EXPECT_EQ(after, before - num * 5);
    EXPECT_EQ(CountOpType(graph.get()), 1);
  }
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

USE_PASS(seqpool_cvm_concat_fuse_pass);
@ -0,0 +1,148 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#include "paddle/fluid/operators/fused/fusion_seqpool_cvm_concat_op.h"
#include <string>
#include <vector>
#include "paddle/fluid/operators/jit/kernels.h"

namespace paddle {
namespace operators {

void FusionSeqPoolCVMConcatOp::InferShape(
    framework::InferShapeContext* ctx) const {
  PADDLE_ENFORCE_GE(
      ctx->Inputs("X").size(), 1UL,
      "Inputs(X) of FusionSeqPoolCVMConcatOp should not be empty.");
  PADDLE_ENFORCE(ctx->HasOutput("Out"),
                 "Output(Out) of FusionSeqPoolCVMConcatOp should not be null.");
  int axis = ctx->Attrs().Get<int>("axis");
  PADDLE_ENFORCE_EQ(
      axis, 1, "FusionSeqPoolCVMConcatOp only supports concat axis=1 for now.");
  bool use_cvm = ctx->Attrs().Get<bool>("use_cvm");
  PADDLE_ENFORCE_EQ(
      use_cvm, true,
      "FusionSeqPoolCVMConcatOp only supports use_cvm=true for now.");

  auto ins_dims = ctx->GetInputsDim("X");
  const size_t n = ins_dims.size();
  PADDLE_ENFORCE_GT(n, 0UL, "Input tensor count should be greater than 0.");
  if (n == 1) {
    LOG(WARNING) << "Only one input, fusion may waste memory.";
  }

  // The output height has to be determined in Compute, since the input lod
  // is not accessible here.
  PADDLE_ENFORCE_EQ(ins_dims[0].size(), 2,
                    "The dims size of the first input should be 2.");
  ctx->SetOutputDim("Out", {-1, ins_dims[0][axis] * static_cast<int>(n)});
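  // With axis == 1, ins_dims[0][axis] is the common input width; the CPU
  // kernel later enforces that every input shares this width.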
}

framework::OpKernelType FusionSeqPoolCVMConcatOp::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
  return framework::OpKernelType(
      framework::GetDataTypeOfVar(ctx.MultiInputVar("X")[0]), ctx.GetPlace());
}

void FusionSeqPoolCVMConcatOpMaker::Make() {
  AddInput("X", "(LoDTensor) Input tensors of this operator.").AsDuplicable();
  AddInput("CVM",
           "(Tensor) A 2-D Tensor with shape [N x 2], where N is the batch "
           "size; the 2 columns are show and click.");
  AddOutput("Out", "(LoDTensor) Output tensor of the concat operator.");
  AddAttr<std::string>("pooltype",
                       "(string, default 'SUM') The pooling type of "
                       "SequencePoolOp; only AVERAGE, SUM and SQRT are "
                       "supported.")
      .SetDefault("SUM")
      .InEnum({"AVERAGE", "SUM", "SQRT"});
  AddAttr<bool>("use_cvm", "(bool) Whether to apply the CVM transform.")
      .SetDefault(true);
  AddAttr<int>("axis",
               "The axis along which the input tensors will be concatenated. "
               "Only axis=1 is supported for now.")
      .SetDefault(1);
  AddComment(R"DOC(
Fusion of Sequence Pool (with sum, average or sqrt pooltype), CVM and Concat operators.
)DOC");
}

template <typename T>
class FusionSeqPoolCVMConcatKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto ins = ctx.MultiInput<LoDTensor>("X");
    auto* out = ctx.Output<LoDTensor>("Out");
    std::string pooltype = ctx.Attr<std::string>("pooltype");
    auto x0_lod = ins[0]->lod();
    auto x0_dims = ins[0]->dims();
    auto y_dims = out->dims();
    size_t bs = x0_lod[0].size() - 1;
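    // x0_lod[0] holds the sequence offsets of the first input, so bs above
    // is its number of sequences, i.e. the batch size of the pooled output.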
    out->Resize({static_cast<int64_t>(bs), y_dims[1]});
    framework::LoD y_lod(1);
    y_lod[0].resize(bs + 1);
    for (size_t i = 0; i <= bs; ++i) {
      y_lod[0][i] = i;
    }
    out->set_lod(y_lod);
    auto place = ctx.GetPlace();
    T* y_data = out->mutable_data<T>(place);

    int w = ins[0]->numel() / x0_dims[0];
    PADDLE_ENFORCE_EQ(y_dims[1] % w, 0,
                      "The output dims[1] should be divisible by w.");
    jit::seq_pool_attr_t attr(w, jit::SeqPoolType::kSum);
    if (pooltype == "AVERAGE") {
      attr.type = jit::SeqPoolType::kAvg;
    } else if (pooltype == "SQRT") {
      attr.type = jit::SeqPoolType::kSqrt;
    }
    auto seqpool =
        jit::KernelFuncs<jit::SeqPoolTuple<T>, platform::CPUPlace>::Cache().At(
            attr);
    size_t n = ins.size();
    size_t dst_step_size = n * w;
    for (size_t i = 0; i < n; ++i) {
      auto x_dims = ins[i]->dims();
      auto x_lod = ins[i]->lod()[0];
      const T* src = ins[i]->data<T>();
      T* dst = y_data + i * w;
      PADDLE_ENFORCE_EQ(static_cast<int>(ins[i]->numel() / x_dims[0]), w,
                        "Width of all inputs should be equal.");
      PADDLE_ENFORCE_EQ(x_lod.size(), bs + 1,
                        "Batch size of all inputs should be equal.");
      for (size_t j = 0; j < bs; ++j) {
        attr.h = static_cast<int>(x_lod[j + 1] - x_lod[j]);
        seqpool(src, dst, &attr);

        // Currently only use_cvm=true is supported.
        dst[0] = log(dst[0] + 1);
        dst[1] = log(dst[1] + 1) - dst[0];
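        // dst[0] now holds log(show + 1), so dst[1] ends up as
        // log(click + 1) - log(show + 1), matching the standalone cvm op.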

        dst += dst_step_size;
        src += attr.h * attr.w;
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(fusion_seqpool_cvm_concat, ops::FusionSeqPoolCVMConcatOp,
                  ops::FusionSeqPoolCVMConcatOpMaker,
                  paddle::framework::EmptyGradOpMaker);

REGISTER_OP_CPU_KERNEL(fusion_seqpool_cvm_concat,
                       ops::FusionSeqPoolCVMConcatKernel<float>,
                       ops::FusionSeqPoolCVMConcatKernel<double>);
@ -0,0 +1,41 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#pragma once
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;

class FusionSeqPoolCVMConcatOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override;

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
};

class FusionSeqPoolCVMConcatOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override;
};

}  // namespace operators
}  // namespace paddle
@ -0,0 +1,125 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
from op_test import OpTest
from test_reorder_lod_tensor import convert_to_offset
from test_seq_pool import compute_seqpool_sum, compute_seqpool_avg, compute_seqpool_sqrt
from test_cvm_op import cvm_compute


class TestFusionSeqPoolCVMConcatOp(OpTest):
    def setUp(self):
        self.w = 11
        self.use_cvm = True
        self.lods = [[[2, 3, 5]], [[1, 5, 2]]]
        self.set_conf()
        self.set_pooltype()
        self.op_type = 'fusion_seqpool_cvm_concat'
        self.axis = 1
        bs = len(self.lods[0][0])
        inputs = []
        outs = []
        # The cvm variable is not actually used.
        cvm = np.array([[0.6, 0.4]]).astype("float32")
        i = 0
        for lod in self.lods:
            assert bs == len(lod[0]), 'All lod sizes should be equal'
            x = np.random.uniform(0.1, 1,
                                  [sum(lod[0]), self.w]).astype('float32')
            offset = convert_to_offset(lod)
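            # convert_to_offset turns the length-based lod into cumulative
            # offsets (e.g. lengths [2, 3, 5] become offsets [0, 2, 5, 10]).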
            out = np.zeros((bs, self.w)).astype('float32')
            if self.pooltype == "SUM":
                compute_seqpool_sum(x, offset, out)
                out = cvm_compute(out, self.w, self.use_cvm)
            elif self.pooltype == "AVERAGE":
                compute_seqpool_avg(x, offset, out)
                out = cvm_compute(out, self.w, self.use_cvm)
            elif self.pooltype == "SQRT":
                compute_seqpool_sqrt(x, offset, out)
                out = cvm_compute(out, self.w, self.use_cvm)
            else:
                raise Exception("Unsupported pool type!")
            inputs.append(('x_{0}'.format(i), (x, lod)))
            outs.append(out)
            i = i + 1

        self.inputs = {'X': inputs, "CVM": cvm}
        self.outputs = {'Out': np.concatenate(outs, axis=self.axis)}
        self.attrs = {
            'pooltype': self.pooltype,
            'axis': self.axis,
        }

    def set_pooltype(self):
        self.pooltype = "SUM"

    def set_conf(self):
        pass

    def test_check_output(self):
        self.check_output()


class TestFusionSeqPoolCVMConcatOpCase1(TestFusionSeqPoolCVMConcatOp):
    def set_conf(self):
        self.lods = [[[1]]]


class TestFusionSeqPoolCVMConcatOpCase2(TestFusionSeqPoolCVMConcatOp):
    def set_conf(self):
        self.lods = [[[1]], [[1]], [[1]]]


class TestFusionSeqPoolCVMConcatOpCase3(TestFusionSeqPoolCVMConcatOp):
    def set_conf(self):
        self.lods = [[[1, 3, 4, 6]]]
        self.w = 10


class TestFusionSeqPoolCVMConcatOpCase4(TestFusionSeqPoolCVMConcatOp):
    def set_conf(self):
        self.lods = [[[2, 13, 4]], [[1, 1, 1]], [[5, 3, 1]], [[9, 10, 3]]]
        self.w = 3


# test avg pool and sqrt
def create_test_avg_sqrt_class(parent):
    class TestSeqPoolAvgCase(parent):
        def set_pooltype(self):
            self.pooltype = "AVERAGE"

    class TestSeqPoolSqrtCase(parent):
        def set_pooltype(self):
            self.pooltype = "SQRT"

    cls_name_avg = "{0}_{1}".format(parent.__name__, "avg")
    cls_name_sqrt = "{0}_{1}".format(parent.__name__, "sqrt")
    TestSeqPoolAvgCase.__name__ = cls_name_avg
    TestSeqPoolSqrtCase.__name__ = cls_name_sqrt
    globals()[cls_name_avg] = TestSeqPoolAvgCase
    globals()[cls_name_sqrt] = TestSeqPoolSqrtCase


create_test_avg_sqrt_class(TestFusionSeqPoolCVMConcatOp)
create_test_avg_sqrt_class(TestFusionSeqPoolCVMConcatOpCase1)
create_test_avg_sqrt_class(TestFusionSeqPoolCVMConcatOpCase2)
create_test_avg_sqrt_class(TestFusionSeqPoolCVMConcatOpCase3)
create_test_avg_sqrt_class(TestFusionSeqPoolCVMConcatOpCase4)

if __name__ == '__main__':
    unittest.main()