Add a pass to replace dropout_op with scale_op when is_test is true (#19297)
* Add simplify_with_basic_ops_pass to replace dropout_op with scale_op when is_test is true. test=develop
* Delete dropout_op directly when upscale_in_train is true. test=develop
* Improve the debug string, adding the printing of op_desc information.
* Fix the case when dropout's input x is reused as the next op's output.
* Add the pass to inference. test=develop
* Change the log level. test=develop
* Add a unittest for the inplace case.
* Add a comment to explain the pass.
* Apply the pass for CPU inference. test=develop
* Fix a typo. test=develop
* Add the check of AttrType. test=develop
parent
e169538886
commit
fcec365d29
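As a quick orientation before the diff: a minimal sketch of how the new pass can be fetched and run on a Graph, mirroring the unit test added in this commit. RunSimplifyPassExample is a hypothetical helper, not part of the commit, and a standalone target would also need USE_PASS(simplify_with_basic_ops_pass) as in the test file below.

#include <memory>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"

// Build a tiny test-mode program with one dropout op, then run the pass.
void RunSimplifyPassExample() {
  using namespace paddle::framework::ir;

  Layers layers;
  auto* x = layers.data("x");
  layers.dropout(x, 0.5f, "downgrade_in_infer");  // Layers sets is_test=true.

  std::unique_ptr<Graph> graph(new Graph(layers.main_program()));
  auto pass = PassRegistry::Instance().Get("simplify_with_basic_ops_pass");
  graph.reset(pass->Apply(graph.release()));
  // The dropout op has been replaced by a scale op with scale = 1.0 - 0.5.
}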
paddle/fluid/framework/ir/pass_tester_helper.h
@@ -0,0 +1,206 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <memory>
#include <sstream>
#include <string>
#include "paddle/fluid/framework/op_proto_maker.h"

namespace paddle {
namespace framework {
namespace ir {

struct Layers {
 public:
  const ProgramDesc& main_program() { return program_; }

  VarDesc* data(std::string name) { return lod_tensor(name); }

  VarDesc* mul(VarDesc* x, VarDesc* y, VarDesc* out = nullptr) {
    return binary_op("mul", x, y, out);
  }

  VarDesc* elementwise_add(VarDesc* x, VarDesc* y, VarDesc* out = nullptr) {
    return binary_op("elementwise_add", x, y, out);
  }

  VarDesc* dropout(VarDesc* x, float dropout_prob,
                   std::string dropout_implementation) {
    VarDesc* out = lod_tensor(unique_name());
    OpDesc* op = program_.MutableBlock(0)->AppendOp();
    op->SetType("dropout");
    op->SetInput("X", {x->Name()});
    op->SetOutput("Out", {out->Name()});
    op->SetAttr("is_test", true);
    op->SetAttr("dropout_prob", dropout_prob);
    op->SetAttr("dropout_implementation", dropout_implementation);
    op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
                static_cast<int>(OpRole::kForward));
    return out;
  }

 private:
  VarDesc* lod_tensor(std::string name) {
    auto* var = program_.MutableBlock(0)->Var(name);
    var->SetType(proto::VarType::LOD_TENSOR);
    return var;
  }

  VarDesc* binary_op(std::string type, VarDesc* x, VarDesc* y,
                     VarDesc* out = nullptr) {
    if (!out) {
      out = lod_tensor(unique_name());
    }
    OpDesc* op = program_.MutableBlock(0)->AppendOp();
    op->SetType(type);
    op->SetInput("X", {x->Name()});
    op->SetInput("Y", {y->Name()});
    op->SetOutput("Out", {out->Name()});
    op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
                static_cast<int>(OpRole::kForward));
    return out;
  }

  std::string unique_name() { return "tmp_" + std::to_string(idx_++); }

 private:
  ProgramDesc program_;
  int idx_{0};
};

static std::string DebugString(OpDesc* op) {
  std::ostringstream os;
  os << "Op(" << op->Type() << "), inputs:{";
  bool is_first = true;
  for (auto& name : op->InputNames()) {
    if (!is_first) {
      os << ", ";
    }
    os << name << "[";
    bool is_first_var_name = true;
    for (auto& var_name : op->Input(name)) {
      if (!is_first_var_name) {
        os << ", ";
      }
      os << var_name;
      is_first_var_name = false;
    }
    os << "]";
    is_first = false;
  }

  os << "}, outputs:{";
  is_first = true;
  for (auto& name : op->OutputNames()) {
    if (!is_first) {
      os << ", ";
    }
    os << name << "[";
    bool is_first_var_name = true;
    for (auto& var_name : op->Output(name)) {
      if (!is_first_var_name) {
        os << ", ";
      }
      os << var_name;
      is_first_var_name = false;
    }
    os << "]";
    is_first = false;
  }
  os << "}";
  return os.str();
}

static std::string DebugString(Node* node) {
  std::ostringstream os;
  if (node->IsOp() && node->Op()) {
    OpDesc* op = node->Op();
    os << "Node(" << DebugString(op) << "), inputs:{";
    bool is_first = true;
    for (auto* in : node->inputs) {
      if (!is_first) {
        os << ", ";
      }
      os << in->Name();
      is_first = false;
    }
    os << "}, outputs:{";
    is_first = true;
    for (auto* out : node->outputs) {
      if (!is_first) {
        os << ", ";
      }
      os << out->Name();
      is_first = false;
    }
    os << "}.";
  } else if (node->IsVar() && node->Var()) {
    os << "Node(" << node->Name() << "), inputs:{";
    bool is_first = true;
    for (auto* in : node->inputs) {
      if (!is_first) {
        os << ", ";
      }
      if (in->IsOp() && in->Op()) {
        os << in->Op()->Type();
      }
      is_first = false;
    }
    os << "}, outputs:{";
    is_first = true;
    for (auto* out : node->outputs) {
      if (!is_first) {
        os << ", ";
      }
      if (out->IsOp() && out->Op()) {
        os << out->Op()->Type();
      }
      is_first = false;
    }
    os << "}";
  }
  return os.str();
}

static std::string DebugString(const std::unique_ptr<Graph>& graph) {
  std::ostringstream os;
  os << "Graph: {\n";
  for (auto* node : graph->Nodes()) {
    if (node->IsOp() && node->Op()) {
      os << "  ";
    } else if (node->IsVar() && node->Var()) {
      os << "    ";
    }
    os << DebugString(node) << "\n";
  }
  os << "}\n";
  return os.str();
}

static int GetNumOpNodes(const std::unique_ptr<Graph>& graph,
                         std::string op_type) {
  int num_nodes = 0;
  for (auto* node : graph->Nodes()) {
    if (node->IsOp() && node->Op() && node->Op()->Type() == op_type) {
      num_nodes++;
    }
  }
  return num_nodes;
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle
paddle/fluid/framework/ir/simplify_with_basic_ops_pass.cc
@@ -0,0 +1,202 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/ir/simplify_with_basic_ops_pass.h"

#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"

namespace paddle {
namespace framework {
namespace ir {

/*
 * This pass simplifies the graph. It may:
 * - replace a complicated op with basic ops
 * - remove unnecessary ops
 *
 * The current implementation supports:
 * - removing dropout_op (upscale_in_train), or
 * - replacing dropout_op with scale_op (downgrade_in_infer),
 * when is_test is true.
 */
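// For example, with is_test = true and dropout_prob = 0.3:
// - downgrade_in_infer: dropout(x) is rewritten as scale(x) with
//   scale = 1.0f - 0.3f = 0.7f, matching dropout's inference behavior.
// - upscale_in_train: dropout(x) is the identity at inference time, so the
//   op is removed and its input is wired directly to the consumers.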
void SimplifyWithBasicOpsPass::ApplyImpl(Graph* graph) const {
  VLOG(3) << "Simplify the Graph with basic ops.";
  std::unordered_set<const Node*> del_node_set;
  for (Node* n : graph->Nodes()) {
    if (n->IsOp() && n->Op()) {
      if (n->Op()->Type() == "dropout") {
        SimplifyDropout(graph, n, &del_node_set);
      }
    }
  }

  GraphSafeRemoveNodes(graph, del_node_set);
}

bool SimplifyWithBasicOpsPass::SimplifyDropout(
    Graph* graph, Node* n,
    std::unordered_set<const Node*>* del_node_set) const {
  OpDesc* dropout_op_desc = n->Op();
  bool is_test = false;
  // In the model used in test_analyzer_bert, the AttrType of dropout_op's
  // is_test attribute is INT, so both BOOLEAN and INT are handled here.
  if (dropout_op_desc->HasAttr("is_test")) {
    if (dropout_op_desc->GetAttrType("is_test") == proto::AttrType::BOOLEAN) {
      is_test = boost::get<bool>(dropout_op_desc->GetAttr("is_test"));
    } else if (dropout_op_desc->GetAttrType("is_test") ==
               proto::AttrType::INT) {
      is_test = boost::get<int>(dropout_op_desc->GetAttr("is_test")) == 0
                    ? false
                    : true;
    }
  }

  if (!is_test) {
    return false;
  }

  Node* dropout_x = GetInputVar(n, dropout_op_desc->Input("X")[0]);
  Node* dropout_out = GetOutputVar(n, dropout_op_desc->Output("Out")[0]);

  bool upscale_in_train = false;
  // The AttrType of the dropout_implementation attribute was once BOOLEAN,
  // but is now STRING, so both types are handled here.
  if (dropout_op_desc->HasAttr("dropout_implementation")) {
    if (dropout_op_desc->GetAttrType("dropout_implementation") ==
        proto::AttrType::BOOLEAN) {
      upscale_in_train =
          boost::get<bool>(dropout_op_desc->GetAttr("dropout_implementation"));
    } else if (dropout_op_desc->GetAttrType("dropout_implementation") ==
               proto::AttrType::STRING) {
      upscale_in_train = boost::get<std::string>(dropout_op_desc->GetAttr(
                             "dropout_implementation")) == "upscale_in_train";
    }
  }

  if (upscale_in_train) {
    // dropout_op can be deleted:
    //   dropout_x -> dropout_op -> dropout_out -> next_op -> next_out
    //                            |
    //                           \|/
    //   dropout_x -> next_op -> next_out
    // Check whether dropout_x is reused as some next_op's output.
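    // If it is (the inplace case covered by the unit test), wiring
    // dropout_out's consumers directly to dropout_x would make next_op read
    // and write the same variable node. So the pass first creates a renamed
    // copy of dropout_x, redirects dropout_x's producer and its other
    // consumers to the copy, and uses the copy as next_op's new input.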
    bool dropout_x_is_reused_as_output = false;
    for (auto* next_op : dropout_out->outputs) {
      for (auto* next_out : next_op->outputs) {
        if (next_out == dropout_x ||
            next_out->Var()->Name() == dropout_x->Var()->Name()) {
          dropout_x_is_reused_as_output = true;
          break;
        }
      }
      if (dropout_x_is_reused_as_output) {
        break;
      }
    }
    if (dropout_x_is_reused_as_output) {
      VarDesc new_var_desc(*dropout_x->Var());
      new_var_desc.SetName("simplify_with_basic_ops_" + dropout_x->Name());
      auto* new_var_node = graph->CreateVarNode(&new_var_desc);
      for (auto* out_op : dropout_x->outputs) {
        if (out_op != n) {
          ReplaceInputVar(out_op, dropout_x, new_var_node);
        }
      }
      for (auto* in_op : dropout_x->inputs) {
        ReplaceOutputVar(in_op, dropout_x, new_var_node);
      }
      dropout_x = new_var_node;
    }
    for (auto* next_op : dropout_out->outputs) {
      ReplaceInputVar(next_op, dropout_out, dropout_x);
    }

    del_node_set->insert(dropout_out);
  } else {
    // Replace the dropout_op with a scale_op:
    //   dropout_x -> dropout_op -> dropout_out -> next_op -> next_out
    //                            |
    //                           \|/
    //   dropout_x -> scale_op -> dropout_out -> next_op -> next_out
    float scale =
        1.0f - boost::get<float>(dropout_op_desc->GetAttr("dropout_prob"));

    framework::OpDesc new_op_desc;
    new_op_desc.SetType("scale");
    new_op_desc.SetInput("X", {dropout_x->Name()});
    new_op_desc.SetOutput("Out", {dropout_out->Name()});
    new_op_desc.SetAttr("scale", scale);
    new_op_desc.SetAttr("bias", static_cast<float>(0));
    new_op_desc.SetAttr("bias_after_scale", true);

    auto* scale_op_node = graph->CreateOpNode(&new_op_desc);
    IR_NODE_LINK_TO(dropout_x, scale_op_node);
    IR_NODE_LINK_TO(scale_op_node, dropout_out);
  }

  del_node_set->insert(n);
  return true;
}

Node* SimplifyWithBasicOpsPass::GetInputVar(Node* n,
                                            const std::string& name) const {
  for (auto* in : n->inputs) {
    if (in->Name() == name) {
      return in;
    }
  }
  return nullptr;
}

Node* SimplifyWithBasicOpsPass::GetOutputVar(Node* n,
                                             const std::string& name) const {
  for (auto* out : n->outputs) {
    if (out->Name() == name) {
      return out;
    }
  }
  return nullptr;
}

void SimplifyWithBasicOpsPass::ReplaceInputVar(Node* op, Node* old_var,
                                               Node* new_var) const {
  if (op->IsOp() && op->Op()) {
    new_var->outputs.push_back(op);
    for (size_t i = 0; i < op->inputs.size(); ++i) {
      if (op->inputs[i] == old_var) {
        op->inputs[i] = new_var;
        op->Op()->RenameInput(old_var->Name(), new_var->Name());
      }
    }
  }
}

void SimplifyWithBasicOpsPass::ReplaceOutputVar(Node* op, Node* old_var,
                                                Node* new_var) const {
  if (op->IsOp() && op->Op()) {
    new_var->inputs.push_back(op);
    for (size_t i = 0; i < op->outputs.size(); ++i) {
      if (op->outputs[i] == old_var) {
        op->outputs[i] = new_var;
        op->Op()->RenameOutput(old_var->Name(), new_var->Name());
      }
    }
  }
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(simplify_with_basic_ops_pass,
              paddle::framework::ir::SimplifyWithBasicOpsPass);
paddle/fluid/framework/ir/simplify_with_basic_ops_pass.h
@@ -0,0 +1,42 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>
#include <unordered_set>
#include "paddle/fluid/framework/ir/pass.h"

namespace paddle {
namespace framework {
namespace ir {

class SimplifyWithBasicOpsPass : public Pass {
 protected:
  void ApplyImpl(Graph* graph) const override;

 private:
  bool SimplifyDropout(Graph* graph, Node* n,
                       std::unordered_set<const Node*>* del_node_set) const;

  Node* GetInputVar(Node* n, const std::string& name) const;
  Node* GetOutputVar(Node* n, const std::string& name) const;

  void ReplaceInputVar(Node* op, Node* old_var, Node* new_var) const;
  void ReplaceOutputVar(Node* op, Node* old_var, Node* new_var) const;
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
paddle/fluid/framework/ir/simplify_with_basic_ops_pass_test.cc
@@ -0,0 +1,78 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/ir/simplify_with_basic_ops_pass.h"

#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/op_proto_maker.h"

namespace paddle {
namespace framework {
namespace ir {

TEST(SimplifyWithBasicOpsPass, dropout) {
  for (std::string dropout_implementation :
       {"downgrade_in_infer", "upscale_in_train"}) {
    for (auto inplace : {false, true}) {
      if (dropout_implementation == "downgrade_in_infer" && inplace == true) {
        continue;
      }

      LOG(INFO) << "dropout_implementation: " << dropout_implementation
                << ", inplace: " << inplace;
      Layers layers;
      // (x, y) -> mul -> tmp_0
      // (tmp_0) -> dropout -> (tmp_1)
      // (tmp_1, z) -> elementwise_add -> (tmp_2)
      // or
      // (tmp_1, z) -> elementwise_add -> (tmp_0)
      auto* x = layers.data("x");
      auto* y = layers.data("y");
      auto* z = layers.data("z");
      auto* mul_out = layers.mul(x, y);
      auto* dropout_out =
          layers.dropout(mul_out, 0.5f, dropout_implementation);
      if (inplace) {
        layers.elementwise_add(dropout_out, z, mul_out);
      } else {
        layers.elementwise_add(dropout_out, z);
      }

      std::unique_ptr<Graph> graph(new Graph(layers.main_program()));
      auto pass = PassRegistry::Instance().Get("simplify_with_basic_ops_pass");
      int num_dropout_nodes_before = GetNumOpNodes(graph, "dropout");
      int num_scale_nodes_before = GetNumOpNodes(graph, "scale");
      VLOG(3) << DebugString(graph);

      graph.reset(pass->Apply(graph.release()));
      int num_dropout_nodes_after = GetNumOpNodes(graph, "dropout");
      int num_scale_nodes_after = GetNumOpNodes(graph, "scale");
      VLOG(3) << DebugString(graph);

      PADDLE_ENFORCE_EQ(num_dropout_nodes_after, 0UL);
      if (dropout_implementation == "downgrade_in_infer") {
        PADDLE_ENFORCE_EQ(num_dropout_nodes_before,
                          num_scale_nodes_after - num_scale_nodes_before);
      } else {
        PADDLE_ENFORCE_EQ(num_scale_nodes_after - num_scale_nodes_before, 0UL);
      }
    }
  }
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

USE_PASS(simplify_with_basic_ops_pass);
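Since the commit message also notes "Apply the pass for CPU inference", here is a hedged sketch of opting in explicitly from the C++ inference API. AnalysisConfig, SetModel, DisableGpu, and pass_builder()->AppendPass(...) are assumed from the Paddle inference API of this era, and the pass may already be enabled by default after this change, so treat this as an illustration rather than a required step.

#include <memory>
#include <string>
#include "paddle/fluid/inference/api/paddle_inference_api.h"

// Sketch: make sure simplify_with_basic_ops_pass runs in the CPU analysis
// pipeline before the predictor is created.
std::unique_ptr<paddle::PaddlePredictor> CreateCpuPredictor(
    const std::string& model_dir) {
  paddle::AnalysisConfig config;
  config.SetModel(model_dir);
  config.DisableGpu();
  // Hypothetical explicit opt-in; possibly redundant after this commit.
  config.pass_builder()->AppendPass("simplify_with_basic_ops_pass");
  return paddle::CreatePaddlePredictor(config);
}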