Enable generating code for a given subgraph. (#21126)

* Enable generating code for a given subgraph.

* Support sorting the subgraph.

* Remove the rearrangement of expressions because we use the sorted subgraph directly.

* Enable generating code for a subgraph which is composed of grad ops.

* Use expression information to check the accuracy in unittest.

* Separate load and store from computation expressions.
test=develop

* Improve the loading statements in generated codes.
test=develop

* Remove unused arguments from formal list.
test=develop
revert-21172-masked_select_api
Yiqun Liu 6 years ago committed by GitHub
parent 3ff5cc2d5e
commit 6b1e1f0dda
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,7 +1,7 @@
cc_library(code_generator SRCS operation.cc code_generator.cc code_generator_helper.cc DEPS graph) cc_library(code_generator SRCS operation.cc code_generator.cc code_generator_helper.cc DEPS graph)
if(NOT APPLE AND NOT WIN32) if(NOT APPLE AND NOT WIN32)
if(WITH_GPU) if(WITH_GPU)
cc_test(test_code_generator SRCS code_generator_tester.cc DEPS code_generator device_code lod_tensor) cc_test(test_code_generator SRCS code_generator_tester.cc DEPS code_generator device_code lod_tensor graph_viz_pass)
endif() endif()
endif() endif()

@ -13,9 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/ir/fusion_group/code_generator.h" #include "paddle/fluid/framework/ir/fusion_group/code_generator.h"
#include <set>
#include <sstream> #include <sstream>
#include <unordered_set>
#include "paddle/fluid/framework/ir/fusion_group/code_generator_helper.h" #include "paddle/fluid/framework/ir/fusion_group/code_generator_helper.h"
#include "paddle/fluid/framework/ir/fusion_group/operation.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
@ -30,69 +31,205 @@ CodeGenerator::CodeGenerator() {
code_templates_[0] = elementwise_t; code_templates_[0] = elementwise_t;
} }
// Generates the kernel source for a subgraph: lower its sorted op nodes
// into OperationExpressions, then emit code under the subgraph's function
// name via the expression-based Generate overload.
std::string CodeGenerator::Generate(SubGraph* subgraph) {
  auto exprs = ConvertToExpressions(subgraph);
  return Generate(subgraph->func_name, exprs);
}
// Lowers every op node of the (topologically sorted) subgraph into an
// OperationExpression. Var names are first mapped to integer ids by
// EncodeVarNodes; each op's inputs/outputs are then translated to those
// ids in the fixed order given by OperationMap, so forward and backward
// expressions line up positionally. Non-op nodes are skipped.
std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
SubGraph* subgraph) {
std::unordered_map<std::string, int> var_ids = EncodeVarNodes(subgraph);
std::vector<OperationExpression> expressions;
for (auto* node : subgraph->SortedNodes()) {
if (node && node->IsOp() && node->Op()) {
auto* op = node->Op();
// Input ids should be set in fixed order, like:
// - x, y in forward operations
// - x, y, out, out@GRAD in backward operations
std::vector<int> input_ids;
std::vector<std::string> input_names =
OperationMap::Instance().Get(op->Type()).input_names;
for (auto& name : input_names) {
// TODO(liuyiqun): support duplicated input.
if (op->Input(name).size() >= 1U) {
// Some input vars are not used in grad ops, such as
// "elementwise_add_grad", where "X", "Y" and "Out" are not used.
PADDLE_ENFORCE_NE(var_ids.find(op->Input(name)[0]), var_ids.end(),
"Input(%s) of operation %s should be set.", name,
op->Type());
input_ids.push_back(var_ids[op->Input(name)[0]]);
} else {
// -1 is the sentinel for an absent/unused input slot; GetRHS
// enforces that referenced ids are >= 0 before substitution.
input_ids.push_back(-1);
}
}
// Output ids should be set in fixed order, like:
// - dx, dy in backward operations
std::vector<int> output_ids;
std::vector<std::string> output_names =
OperationMap::Instance().Get(op->Type()).output_names;
for (auto& name : output_names) {
// Every declared output must exist exactly once and be encoded.
PADDLE_ENFORCE_EQ(op->Output(name).size(), 1U,
"Output(%s) of operation %s should be set.", name,
op->Type());
PADDLE_ENFORCE_NE(var_ids.find(op->Output(name)[0]), var_ids.end(),
"Output(%s) of operation %s should be set.", name,
op->Type());
output_ids.push_back(var_ids[op->Output(name)[0]]);
}
expressions.push_back(
OperationExpression(node->Name(), input_ids, output_ids));
}
}
return expressions;
}
// In order to get the right result of expression, we need to calculate and // In order to get the right result of expression, we need to calculate and
// store the expression as suffix Expressions using vector. // store the expression as suffix Expressions using vector.
std::string CodeGenerator::GenerateCode( std::string CodeGenerator::Generate(
std::string func_name, std::vector<OperationExpression> expressions) { std::string func_name, std::vector<OperationExpression> expressions) {
// Check whether all expressions are elementwise operations. // TODO(liuyiqun): Check whether all expressions are elementwise operations.
std::string dtype = "float";
std::set<int> input_ids = DistilInputIds(expressions);
std::set<int> output_ids = DistilOutputIds(expressions);
TemplateVariable template_var; TemplateVariable template_var;
template_var.Add("func_name", func_name); template_var.Add("func_name", func_name);
template_var.Add("parameters", EmitParameters(expressions, "float")); template_var.Add("parameters", EmitParameters(input_ids, output_ids, dtype));
template_var.Add("compute_body", EmitComputeBody(expressions)); template_var.Add("compute_body",
EmitComputeBody(expressions, input_ids, output_ids, dtype));
return predefined_cuda_functions + code_templates_[0].Format(template_var); return predefined_cuda_functions + code_templates_[0].Format(template_var);
} }
// we get the parameter list code for the expression information std::set<int> CodeGenerator::DistilInputIds(
std::string CodeGenerator::EmitParameters( const std::vector<OperationExpression>& expressions) {
std::vector<OperationExpression> expressions, std::string dtype) {
std::set<int> input_ids; std::set<int> input_ids;
std::set<int> output_ids; // Use std::set to remove the reptead id and get a ordered list.
// Remove the reptead id and get a ordered list.
for (size_t i = 0; i < expressions.size(); i++) { for (size_t i = 0; i < expressions.size(); i++) {
for (auto id : expressions[i].GetInputIds()) { for (auto id : expressions[i].GetInputIds()) {
input_ids.insert(id); if (id >= 0) {
input_ids.insert(id);
}
} }
}
return input_ids;
}
std::set<int> CodeGenerator::DistilOutputIds(
const std::vector<OperationExpression>& expressions) {
std::set<int> output_ids;
// Use std::set to remove the repeated ids and get an ordered list.
for (size_t i = 0; i < expressions.size(); i++) {
for (auto id : expressions[i].GetOutputIds()) { for (auto id : expressions[i].GetOutputIds()) {
output_ids.insert(id); output_ids.insert(id);
} }
} }
return output_ids;
}
// we get the parameter list code for the expression information
std::string CodeGenerator::EmitParameters(const std::set<int>& input_ids,
const std::set<int>& output_ids,
std::string dtype) {
std::stringstream ret;
ret << "int N, ";
// If a id is in the input and output list at the same time, then remove it // If a id is in the input and output list at the same time, then remove it
// from the input list. // from the input list.
for (auto iter = input_ids.begin(); iter != input_ids.end();) { for (auto id : input_ids) {
if (output_ids.find(*iter) != output_ids.end()) { if (output_ids.find(id) == output_ids.end()) {
input_ids.erase(iter++); ret << dtype << "* " << ArgName(id) << ", ";
} else {
iter++;
} }
} }
std::stringstream ret; size_t index = 0;
ret << "int N, "; for (auto id : output_ids) {
for (auto iter = input_ids.begin(); iter != input_ids.end(); iter++) { ret << dtype << "* " << ArgName(id);
ret << dtype << "* " << VarName(*iter) << ", "; if (index != output_ids.size() - 1) {
}
size_t count_index = 0;
for (auto iter = output_ids.begin(); iter != output_ids.end(); iter++) {
ret << dtype << "* " << VarName(*iter);
if (count_index != output_ids.size() - 1) {
ret << ", "; ret << ", ";
} }
count_index++; index++;
} }
return ret.str(); return ret.str();
} }
std::string CodeGenerator::EmitComputeBody( std::string CodeGenerator::EmitComputeBody(
std::vector<OperationExpression> expressions) { const std::vector<OperationExpression>& expressions,
// get the right experssion code using suffix expression const std::set<int>& input_ids, const std::set<int>& output_ids,
std::stringstream ret; std::string dtype) {
std::ostringstream compute;
std::unordered_set<int> used;
for (size_t i = 0; i < expressions.size(); i++) { for (size_t i = 0; i < expressions.size(); i++) {
ret << expressions[i].GetExpression(); VLOG(3) << DebugString(expressions[i]);
compute << expressions[i].GetExpression(dtype, &used);
} }
return ret.str();
// Load input to temporal variables.
std::ostringstream load;
for (auto id : input_ids) {
if (output_ids.find(id) == output_ids.end() &&
used.find(id) != used.end()) {
load << dtype << " " << TmpName(id) << " = " << ArgName(id) << "[idx];";
}
}
// Store temporal variables to memory.
std::ostringstream store;
for (auto id : output_ids) {
store << ArgName(id) << "[idx] = " << TmpName(id) << ";";
}
return load.str() + compute.str() + store.str();
}
// Assigns a unique integer id to every var name in the subgraph:
// input vars are numbered first (starting at 0), output vars after them.
// Internal (neither-input-nor-output) var nodes are not supported and
// trigger an enforce failure. Returns the name -> id map used by
// ConvertToExpressions.
std::unordered_map<std::string, int> CodeGenerator::EncodeVarNodes(
SubGraph* subgraph) {
const auto& input_var_nodes = subgraph->GetInputVarNodes();
const auto& output_var_nodes = subgraph->GetOutputVarNodes();
int id = 0;
std::unordered_map<std::string, int> var_ids;
// Numbering input vars. The find() guard keeps the first id when the
// same name appears among the inputs more than once.
for (auto* in : input_var_nodes) {
VLOG(3) << "Encoding input names:" << in->Name() << ", id:" << id;
if (var_ids.find(in->Name()) == var_ids.end()) {
var_ids[in->Name()] = id++;
}
}
// Despite the old "numbering" wording, this loop assigns no ids: it only
// verifies that every var node is either an input or an output of the
// subgraph, rejecting subgraphs with internal var nodes.
for (auto* node : subgraph->SortedNodes()) {
if (node && node->IsVar() && node->Var()) {
bool is_found = false;
for (auto* in : input_var_nodes) {
if (node == in) {
is_found = true;
break;
}
}
if (is_found) {
continue;
}
for (auto* out : output_var_nodes) {
if (node == out) {
is_found = true;
break;
}
}
PADDLE_ENFORCE_EQ(
is_found, true,
"Subgraph with internal var nodes (%s) is not supported yet.",
node->Name());
}
}
// Encoding output vars, continuing the id sequence after the inputs.
for (auto* out : output_var_nodes) {
VLOG(3) << "Ecoding output names:" << out->Name() << ", id:" << id;
if (var_ids.find(out->Name()) == var_ids.end()) {
var_ids[out->Name()] = id++;
}
}
return var_ids;
} }
} // namespace fusion_group } // namespace fusion_group

@ -14,9 +14,12 @@ limitations under the License. */
#pragma once #pragma once
#include <set>
#include <string> #include <string>
#include <unordered_map>
#include <vector> #include <vector>
#include "paddle/fluid/framework/ir/fusion_group/code_generator_helper.h" #include "paddle/fluid/framework/ir/fusion_group/code_generator_helper.h"
#include "paddle/fluid/framework/ir/fusion_group/subgraph.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
@ -27,18 +30,31 @@ class CodeGenerator {
public: public:
CodeGenerator(); CodeGenerator();
std::string GenerateCode(std::string func_name, std::string Generate(std::string func_name,
std::vector<OperationExpression> expressions); std::vector<OperationExpression> expressions);
// TODO(wangchao): add a more general interface std::string Generate(SubGraph* subgraph);
// std::string Generate(const std::string name, const SubGraph& subgraph);
std::vector<OperationExpression> ConvertToExpressions(SubGraph* subgraph);
private: private:
std::set<int> DistilInputIds(
const std::vector<OperationExpression>& expressions);
std::set<int> DistilOutputIds(
const std::vector<OperationExpression>& expressions);
// we get the parameter list code for the expression information // we get the parameter list code for the expression information
std::string EmitParameters(std::vector<OperationExpression> expressions, std::string EmitParameters(const std::set<int>& input_ids,
const std::set<int>& output_ids,
std::string dtype); std::string dtype);
std::string EmitComputeBody(std::vector<OperationExpression> expressions); std::string EmitComputeBody(
const std::vector<OperationExpression>& expressions,
const std::set<int>& input_ids, const std::set<int>& output_ids,
std::string dtype);
// Encode all var nodes in the subgraph with an unique number.
std::unordered_map<std::string, int> EncodeVarNodes(SubGraph* subgraph);
private: private:
std::vector<CodeTemplate> code_templates_; std::vector<CodeTemplate> code_templates_;

@ -33,8 +33,9 @@ static T StringTo(const std::string& str) {
return value; return value;
} }
std::string OperationExpression::GetRHS(size_t i) { std::string OperationExpression::GetRHS(std::unordered_set<int>* used,
auto rhs = OperationMap::Instance().Get(op_).exprs[i]; size_t i) const {
auto rhs = OperationMap::Instance().Get(op_type_).exprs[i];
for (size_t i = 0; i < rhs.size(); i++) { for (size_t i = 0; i < rhs.size(); i++) {
size_t pos = i; size_t pos = i;
if (rhs[pos] == '$' && rhs[pos + 1] == '{') { if (rhs[pos] == '$' && rhs[pos + 1] == '{') {
@ -47,29 +48,33 @@ std::string OperationExpression::GetRHS(size_t i) {
PADDLE_ENFORCE_LT(index, input_ids_.size(), PADDLE_ENFORCE_LT(index, input_ids_.size(),
"Only %d inputs are provided, but need %d.", "Only %d inputs are provided, but need %d.",
input_ids_.size(), index + 1); input_ids_.size(), index + 1);
rhs.replace(pos, length + 3, VarName(input_ids_[index]) + R"([idx])"); PADDLE_ENFORCE_GE(input_ids_[index], 0,
"Input id should be no less than 0.");
rhs.replace(pos, length + 3, TmpName(input_ids_[index]));
used->insert(input_ids_[index]);
} }
} }
return rhs; return rhs;
} }
std::string OperationExpression::GetLHS(size_t i) { std::string OperationExpression::GetLHS(size_t i) const {
std::stringstream ret; std::stringstream ret;
ret << VarName(output_ids_[i]) << R"([idx])"; ret << TmpName(output_ids_[i]);
return ret.str(); return ret.str();
} }
bool OperationExpression::IsSupport() { bool OperationExpression::IsSupport() const {
return OperationMap::Instance().Has(op_); return OperationMap::Instance().Has(op_type_);
} }
// we Traverse the graph and get the group , all input id and output id is // we Traverse the graph and get the group , all input id and output id is
// unique for the node which belong the group // unique for the node which belong the group
std::string OperationExpression::GetExpression() { std::string OperationExpression::GetExpression(
std::string dtype, std::unordered_set<int>* used) const {
std::stringstream ret; std::stringstream ret;
if (IsSupport()) { if (IsSupport()) {
for (size_t i = 0; i < output_ids_.size(); ++i) { for (size_t i = 0; i < output_ids_.size(); ++i) {
ret << GetLHS(i) << " = " << GetRHS(i) << ";"; ret << dtype << " " << GetLHS(i) << " = " << GetRHS(used, i) << ";";
} }
} }

@ -14,10 +14,10 @@ limitations under the License. */
#pragma once #pragma once
#include <set>
#include <sstream> #include <sstream>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <unordered_set>
#include <vector> #include <vector>
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
@ -27,28 +27,36 @@ namespace framework {
namespace ir { namespace ir {
namespace fusion_group { namespace fusion_group {
static std::string VarName(int index) { return "var" + std::to_string(index); } static inline std::string ArgName(int index) {
return "arg" + std::to_string(index);
}
static inline std::string TmpName(int index) {
return "tmp" + std::to_string(index);
}
class OperationExpression { class OperationExpression {
public: public:
explicit OperationExpression(std::string op, std::vector<int> input_ids, explicit OperationExpression(std::string op_type, std::vector<int> input_ids,
std::vector<int> output_ids) std::vector<int> output_ids)
: op_(op), input_ids_(input_ids), output_ids_(output_ids) {} : op_type_(op_type), input_ids_(input_ids), output_ids_(output_ids) {}
std::vector<int> GetInputIds() { return input_ids_; } std::string GetOpType() const { return op_type_; }
std::vector<int> GetOutputIds() { return output_ids_; } std::vector<int> GetInputIds() const { return input_ids_; }
std::vector<int> GetOutputIds() const { return output_ids_; }
// Check whether this operation type is supported in OperationMap. // Check whether this operation type is supported in OperationMap.
bool IsSupport(); bool IsSupport() const;
std::string GetExpression(); std::string GetExpression(std::string dtype,
std::unordered_set<int>* used) const;
private:
// TODO(wangchao): make offset more flexible we add stride and basic offset // TODO(wangchao): make offset more flexible we add stride and basic offset
std::string GetRHS(size_t i = 0); std::string GetRHS(std::unordered_set<int>* used, size_t i = 0) const;
std::string GetLHS(size_t i = 0); std::string GetLHS(size_t i = 0) const;
private: private:
std::string op_; std::string op_type_;
std::vector<int> input_ids_; std::vector<int> input_ids_;
std::vector<int> output_ids_; std::vector<int> output_ids_;
}; };
@ -58,6 +66,7 @@ class TemplateVariable {
void Add(std::string identifier, std::string expression) { void Add(std::string identifier, std::string expression) {
strings_[identifier] = expression; strings_[identifier] = expression;
} }
void Remove(std::string identifier, std::string expression) { void Remove(std::string identifier, std::string expression) {
for (auto it = strings_.begin(); it != strings_.end();) { for (auto it = strings_.begin(); it != strings_.end();) {
if (it->first == identifier) { if (it->first == identifier) {
@ -155,7 +164,6 @@ __device__ double real_max(double x, double y) { return ::fmax(x, y); }
)"; )";
static const char elementwise_cuda_template[] = R"( static const char elementwise_cuda_template[] = R"(
extern "C" __global__ void $func_name($parameters) { extern "C" __global__ void $func_name($parameters) {
for(int idx = blockIdx.x * blockDim.x + threadIdx.x; for(int idx = blockIdx.x * blockDim.x + threadIdx.x;
idx < N; idx < N;
@ -165,6 +173,28 @@ extern "C" __global__ void $func_name($parameters) {
} }
)"; )";
// Builds a one-line, human-readable summary of an OperationExpression,
// e.g. "Op(elementwise_add), inputs:{0,1}, outputs:{2}", for VLOG output.
static std::string DebugString(const OperationExpression& expr) {
  std::stringstream ret;
  ret << "Op(" << expr.GetOpType() << "), inputs:{";
  // Cache the id vectors once: GetInputIds()/GetOutputIds() return by
  // value, so the original's expr.GetInputIds()[i] inside the loop copied
  // the whole vector on every iteration.
  auto input_ids = expr.GetInputIds();
  for (size_t i = 0; i < input_ids.size(); ++i) {
    if (i != 0) {
      ret << ",";
    }
    ret << input_ids[i];
  }
  ret << "}, outputs:{";
  auto output_ids = expr.GetOutputIds();
  for (size_t i = 0; i < output_ids.size(); ++i) {
    if (i != 0) {
      ret << ",";
    }
    ret << output_ids[i];
  }
  ret << "}";
  return ret.str();
}
} // namespace fusion_group } // namespace fusion_group
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework

@ -108,13 +108,6 @@ bool ElementwiseGroupDetector::IsOutputOfElementwiseOp(Node* n) {
return false; return false;
} }
void ElementwiseGroupDetector::Insert(Node* n) {
if (subgraph_.nodes_set.find(n) == subgraph_.nodes_set.end()) {
VLOG(5) << "Insert " << n->Name() << " to subgraph " << name_;
subgraph_.nodes_set.insert(n);
}
}
int ElementwiseGroupDetector::Search(Node* n, std::vector<Node*> except_nodes) { int ElementwiseGroupDetector::Search(Node* n, std::vector<Node*> except_nodes) {
std::unordered_set<Node*> except_nodes_set; std::unordered_set<Node*> except_nodes_set;
for (size_t i = 0; i < except_nodes.size(); ++i) { for (size_t i = 0; i < except_nodes.size(); ++i) {
@ -123,16 +116,16 @@ int ElementwiseGroupDetector::Search(Node* n, std::vector<Node*> except_nodes) {
int num_operations = 0; int num_operations = 0;
if (IsElementwiseOp(n)) { if (IsElementwiseOp(n)) {
Insert(n); subgraph_.Insert(n);
num_operations += 1; num_operations += 1;
for (auto* var : n->inputs) { for (auto* var : n->inputs) {
Insert(var); subgraph_.Insert(var);
if (except_nodes_set.find(var) == except_nodes_set.end()) { if (except_nodes_set.find(var) == except_nodes_set.end()) {
num_operations += Search(var, {n}); num_operations += Search(var, {n});
} }
} }
for (auto* var : n->outputs) { for (auto* var : n->outputs) {
Insert(var); subgraph_.Insert(var);
if (except_nodes_set.find(var) == except_nodes_set.end()) { if (except_nodes_set.find(var) == except_nodes_set.end()) {
num_operations += Search(var, {n}); num_operations += Search(var, {n});
} }
@ -157,7 +150,7 @@ int ElementwiseGroupDetector::Search(Node* n, std::vector<Node*> except_nodes) {
int ElementwiseGroupDetector::operator()(Node* n) { int ElementwiseGroupDetector::operator()(Node* n) {
if (!IsOutputOfElementwiseOp(n) && IsInputOfElementwiseOp(n, "X")) { if (!IsOutputOfElementwiseOp(n) && IsInputOfElementwiseOp(n, "X")) {
name_ = n->Name(); name_ = n->Name();
Insert(n); subgraph_.Insert(n);
num_operations_ = Search(n, n->inputs); num_operations_ = Search(n, n->inputs);
VLOG(4) << "Detect elementwise subgraph begin with " << name_ << ", " VLOG(4) << "Detect elementwise subgraph begin with " << name_ << ", "
<< num_operations_ << " operations, " << GetSubgraph().GetNumNodes() << num_operations_ << " operations, " << GetSubgraph().GetNumNodes()

@ -36,7 +36,6 @@ class ElementwiseGroupDetector {
bool IsInputOfElementwiseOp(Node* n, std::string name = ""); bool IsInputOfElementwiseOp(Node* n, std::string name = "");
bool IsOutputOfElementwiseOp(Node* n); bool IsOutputOfElementwiseOp(Node* n);
void Insert(Node* n);
int Search(Node* n, std::vector<Node*> except_nodes = {}); int Search(Node* n, std::vector<Node*> except_nodes = {});
private: private:

@ -36,7 +36,7 @@ int FusionGroupPass::DetectFusionGroup(Graph* graph, int type) const {
for (Node* n : all_nodes) { for (Node* n : all_nodes) {
bool is_found = false; bool is_found = false;
for (auto& subgraph : subgraphs) { for (auto& subgraph : subgraphs) {
if (subgraph.nodes_set.find(n) != subgraph.nodes_set.end()) { if (subgraph.Has(n)) {
is_found = true; is_found = true;
break; break;
} }
@ -61,15 +61,17 @@ int FusionGroupPass::DetectFusionGroup(Graph* graph, int type) const {
// TODO(liuyiqun): check whether there are intersection between subgraphs // TODO(liuyiqun): check whether there are intersection between subgraphs
for (size_t i = 0; i < subgraphs.size(); ++i) { for (size_t i = 0; i < subgraphs.size(); ++i) {
InsertFusionGroupOp(graph, subgraphs[i]); InsertFusionGroupOp(graph, &subgraphs[i]);
} }
return subgraphs.size(); return subgraphs.size();
} }
void FusionGroupPass::InsertFusionGroupOp( void FusionGroupPass::InsertFusionGroupOp(
Graph* graph, const fusion_group::SubGraph& subgraph) const { Graph* graph, fusion_group::SubGraph* subgraph) const {
std::vector<Node*> input_vars_of_subgraph = subgraph.GetInputVarNodes(); const std::vector<Node*>& input_vars_of_subgraph =
std::vector<Node*> output_vars_of_subgraph = subgraph.GetOutputVarNodes(); subgraph->GetInputVarNodes();
const std::vector<Node*>& output_vars_of_subgraph =
subgraph->GetOutputVarNodes();
std::unordered_set<Node*> external_nodes; std::unordered_set<Node*> external_nodes;
OpDesc op_desc; OpDesc op_desc;
@ -88,8 +90,8 @@ void FusionGroupPass::InsertFusionGroupOp(
external_nodes.insert(n); external_nodes.insert(n);
} }
op_desc.SetOutput("Outs", output_names); op_desc.SetOutput("Outs", output_names);
op_desc.SetAttr("type", subgraph.type); op_desc.SetAttr("type", subgraph->type);
op_desc.SetAttr("func_name", subgraph.func_name); op_desc.SetAttr("func_name", subgraph->func_name);
auto fusion_group_node = graph->CreateOpNode(&op_desc); auto fusion_group_node = graph->CreateOpNode(&op_desc);
for (auto* in : input_vars_of_subgraph) { for (auto* in : input_vars_of_subgraph) {
@ -100,7 +102,7 @@ void FusionGroupPass::InsertFusionGroupOp(
} }
std::unordered_set<const Node*> internal_nodes; std::unordered_set<const Node*> internal_nodes;
for (auto* n : subgraph.nodes_set) { for (auto* n : subgraph->Nodes()) {
if (external_nodes.find(n) == external_nodes.end()) { if (external_nodes.find(n) == external_nodes.end()) {
internal_nodes.insert(n); internal_nodes.insert(n);
} }

@ -30,7 +30,7 @@ class FusionGroupPass : public Pass {
private: private:
int DetectFusionGroup(Graph* graph, int type = 0) const; int DetectFusionGroup(Graph* graph, int type = 0) const;
void InsertFusionGroupOp(Graph* graph, void InsertFusionGroupOp(Graph* graph,
const fusion_group::SubGraph& subgraph) const; fusion_group::SubGraph* subgraph) const;
const std::string name_scope_{"fusion_group"}; const std::string name_scope_{"fusion_group"};
}; };

@ -22,6 +22,14 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
// Debug helper: runs graph_viz_pass on the graph, dumping it as a .dot
// file at graph_viz_path.
//
// Takes the unique_ptr by reference. The original by-value parameter was
// unusable: unique_ptr is move-only, so an lvalue call did not compile,
// and a std::move'd call would destroy the caller's graph when the local
// parameter went out of scope on return.
void VisualizeGraph(std::unique_ptr<Graph>& graph, std::string graph_viz_path) {
  // Insert a graph_viz_pass to transform the graph to a .dot file.
  // It can be used for debug.
  auto graph_viz_pass = PassRegistry::Instance().Get("graph_viz_pass");
  // NOTE(review): assumes the pass takes ownership of the new'd string,
  // per the Pass::Set convention — confirm against the Pass API.
  graph_viz_pass->Set("graph_viz_path", new std::string(graph_viz_path));
  graph.reset(graph_viz_pass->Apply(graph.release()));
}
TEST(FusionGroupPass, elementwise_list) { TEST(FusionGroupPass, elementwise_list) {
fusion_group::OperationMap::Init(); fusion_group::OperationMap::Init();
@ -46,29 +54,17 @@ TEST(FusionGroupPass, elementwise_list) {
layers.elementwise_add(tmp_2, w); layers.elementwise_add(tmp_2, w);
std::unique_ptr<Graph> graph(new Graph(layers.main_program())); std::unique_ptr<Graph> graph(new Graph(layers.main_program()));
// VisualizeGraph(graph, "00_elementwise_list.dot");
// The following codes is to insert a graph_viz_pass to transform the graph to
// a .dot file. It is used for debug.
// auto graph_viz_pass = PassRegistry::Instance().Get("graph_viz_pass");
// graph_viz_pass->Set("graph_viz_path", new
// std::string("00_elementwise_list.dot"));
// graph.reset(graph_viz_pass->Apply(graph.release()));
auto fusion_group_pass = PassRegistry::Instance().Get("fusion_group_pass"); auto fusion_group_pass = PassRegistry::Instance().Get("fusion_group_pass");
VLOG(3) << DebugString(graph); VLOG(3) << DebugString(graph);
graph.reset(fusion_group_pass->Apply(graph.release())); graph.reset(fusion_group_pass->Apply(graph.release()));
// VisualizeGraph(graph, "01_elementwise_list.fusion_group.dot");
int num_fusion_group_ops = GetNumOpNodes(graph, "fusion_group"); int num_fusion_group_ops = GetNumOpNodes(graph, "fusion_group");
VLOG(3) << DebugString(graph); VLOG(3) << DebugString(graph);
PADDLE_ENFORCE_EQ(num_fusion_group_ops, 1); PADDLE_ENFORCE_EQ(num_fusion_group_ops, 1);
// The following codes is to insert a graph_viz_pass to transform the graph to
// a .dot file. It is used for debug.
// auto graph_viz_pass = PassRegistry::Instance().Get("graph_viz_pass");
// graph_viz_pass->Set("graph_viz_path", new
// std::string("01_elementwise_list.fusion_group.dot"));
// graph.reset(graph_viz_pass->Apply(graph.release()));
} }
TEST(FusionGroupPass, elementwise_tree) { TEST(FusionGroupPass, elementwise_tree) {
@ -128,29 +124,17 @@ TEST(FusionGroupPass, elementwise_tree) {
layers.mul(tmp_6, tmp_9); layers.mul(tmp_6, tmp_9);
std::unique_ptr<Graph> graph(new Graph(layers.main_program())); std::unique_ptr<Graph> graph(new Graph(layers.main_program()));
// VisualizeGraph(graph, "00_elementwise_tree.dot");
// The following codes is to insert a graph_viz_pass to transform the graph to
// a .dot file. It is used for debug.
// auto graph_viz_pass = PassRegistry::Instance().Get("graph_viz_pass");
// graph_viz_pass->Set("graph_viz_path", new
// std::string("00_elementwise_tree.dot"));
// graph.reset(graph_viz_pass->Apply(graph.release()));
auto fusion_group_pass = PassRegistry::Instance().Get("fusion_group_pass"); auto fusion_group_pass = PassRegistry::Instance().Get("fusion_group_pass");
LOG(INFO) << DebugString(graph); VLOG(3) << DebugString(graph);
graph.reset(fusion_group_pass->Apply(graph.release())); graph.reset(fusion_group_pass->Apply(graph.release()));
// VisualizeGraph(graph, "01_elementwise_tree.fusion_group.dot");
int num_fusion_group_ops = GetNumOpNodes(graph, "fusion_group"); int num_fusion_group_ops = GetNumOpNodes(graph, "fusion_group");
LOG(INFO) << DebugString(graph); VLOG(3) << DebugString(graph);
PADDLE_ENFORCE_EQ(num_fusion_group_ops, 2); PADDLE_ENFORCE_EQ(num_fusion_group_ops, 2);
// The following codes is to insert a graph_viz_pass to transform the graph to
// a .dot file. It is used for debug.
// auto graph_viz_pass = PassRegistry::Instance().Get("graph_viz_pass");
// graph_viz_pass->Set("graph_viz_path", new
// std::string("01_elementwise_tree.fusion_group.dot"));
// graph.reset(graph_viz_pass->Apply(graph.release()));
} }
} // namespace ir } // namespace ir

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/ir/fusion_group/operation.h" #include "paddle/fluid/framework/ir/fusion_group/operation.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
@ -38,15 +39,30 @@ std::unordered_set<std::string> OperationMap::Find(int type, int num_operands) {
} }
void OperationMap::Insert(int type, int num_operands, std::string op_type, void OperationMap::Insert(int type, int num_operands, std::string op_type,
std::string expr, std::string expr, std::vector<std::string> grad_exprs,
std::vector<std::string> grad_exprs) { std::vector<std::string> input_names,
Operation op(type, num_operands, op_type, {expr}); std::vector<std::string> output_names) {
Operation op(type, num_operands, op_type, {expr}, input_names, output_names);
PADDLE_ENFORCE_EQ(op.IsValid(), true, "Operation %s is invalid.", op_type); PADDLE_ENFORCE_EQ(op.IsValid(), true, "Operation %s is invalid.", op_type);
operations_[op_type] = op; operations_[op_type] = op;
if (grad_exprs.size() > 0U) { if (grad_exprs.size() > 0U) {
std::string grad_op_type = op_type + "_grad"; std::string grad_op_type = op_type + "_grad";
Operation grad_op(type, num_operands, grad_op_type, grad_exprs); // grad_inputs = inputs + outputs + grad of outputs
std::vector<std::string> grad_input_names = input_names;
for (auto name : output_names) {
grad_input_names.push_back(name);
}
for (auto name : output_names) {
grad_input_names.push_back(GradVarName(name));
}
// grad_output = grad of inputs
std::vector<std::string> grad_output_names;
for (auto name : input_names) {
grad_output_names.push_back(GradVarName(name));
}
Operation grad_op(type, num_operands, grad_op_type, grad_exprs,
grad_input_names, grad_output_names);
PADDLE_ENFORCE_EQ(grad_op.IsValid(), true, "Operation %s is invalid.", PADDLE_ENFORCE_EQ(grad_op.IsValid(), true, "Operation %s is invalid.",
grad_op_type); grad_op_type);
operations_[grad_op_type] = grad_op; operations_[grad_op_type] = grad_op;
@ -54,59 +70,65 @@ void OperationMap::Insert(int type, int num_operands, std::string op_type,
} }
void OperationMap::InsertUnaryElementwiseOperations() { void OperationMap::InsertUnaryElementwiseOperations() {
int type = 0;
int num_oprands = 1;
// For unary elementwise operations: // For unary elementwise operations:
// ${0} - x // ${0} - x
// ${1} - out // ${1} - out
// ${2} - dout // ${2} - dout
auto insert_handler = [&](std::string op_type, std::string expr,
std::vector<std::string> grad_exprs) {
int type = 0;
int num_oprands = 1;
Insert(type, num_oprands, op_type, expr, grad_exprs, {"X"}, {"Out"});
};
// relu: // relu:
// out = f(x) = x > 0 ? x : 0 // out = f(x) = x > 0 ? x : 0
// dx = dout * (out > 0 ? 1 : 0) = dout * (x > 0 ? 1 : 0) // dx = dout * (out > 0 ? 1 : 0) = dout * (x > 0 ? 1 : 0)
Insert(type, num_oprands, "relu", "real_max(${0}, 0)", insert_handler("relu", "real_max(${0}, 0)", {"${0} > 0 ? ${2} : 0"});
{"${0} > 0 ? ${2} : 0"});
// sigmoid: // sigmoid:
// out = f(x) = 1.0 / (1.0 + exp(-x)) // out = f(x) = 1.0 / (1.0 + exp(-x))
// dx = dout * out * (1 - out) // dx = dout * out * (1 - out)
Insert(type, num_oprands, "sigmoid", "1.0 / (1.0 + real_exp(- ${0}))", insert_handler("sigmoid", "1.0 / (1.0 + real_exp(- ${0}))",
{"${2} * ${1} * (1.0 - ${1})"}); {"${2} * ${1} * (1.0 - ${1})"});
// tanh: // tanh:
// out = f(x) = 2.0 / (1.0 + exp(-2.0 * x)) - 1.0; // out = f(x) = 2.0 / (1.0 + exp(-2.0 * x)) - 1.0;
// dx = dout * (1 - out * out) // dx = dout * (1 - out * out)
Insert(type, num_oprands, "tanh", "2.0 / (1.0 + real_exp(-2.0 * ${0})) - 1.0", insert_handler("tanh", "2.0 / (1.0 + real_exp(-2.0 * ${0})) - 1.0",
{"${2} * (1.0 - ${1} * ${1})"}); {"${2} * (1.0 - ${1} * ${1})"});
} }
void OperationMap::InsertBinaryElementwiseOperations() { void OperationMap::InsertBinaryElementwiseOperations() {
int type = 0;
int num_oprands = 2;
// For binary elementwise oprations: // For binary elementwise oprations:
// ${0} - x // ${0} - x
// ${1} - y // ${1} - y
// ${2} - out // ${2} - out
// ${3} - dout // ${3} - dout
auto insert_handler = [&](std::string op_type, std::string expr,
std::vector<std::string> grad_exprs) {
int type = 0;
int num_oprands = 2;
Insert(type, num_oprands, op_type, expr, grad_exprs, {"X", "Y"}, {"Out"});
};
// elementwise_add: // elementwise_add:
// out = x + y // out = x + y
// dx = dout * 1 // dx = dout * 1
// dy = dout * 1 // dy = dout * 1
Insert(type, num_oprands, "elementwise_add", "${0} + ${1}", {"${3}", "${3}"}); insert_handler("elementwise_add", "${0} + ${1}", {"${3}", "${3}"});
// elementwise_sub: // elementwise_sub:
// out = x - y // out = x - y
// dx = dout * 1 // dx = dout * 1
// dy = dout * (-1) // dy = dout * (-1)
Insert(type, num_oprands, "elementwise_sub", "${0} - ${1}", insert_handler("elementwise_sub", "${0} - ${1}", {"${3}", "- ${3}"});
{"${3}", "- ${3}"});
// elementwise_mul: // elementwise_mul:
// out = x * y // out = x * y
// dx = dout * y // dx = dout * y
// dy = dout * x // dy = dout * x
Insert(type, num_oprands, "elementwise_mul", "${0} * ${1}", insert_handler("elementwise_mul", "${0} * ${1}",
{"${3} * ${1}", "${3} * ${0}"}); {"${3} * ${1}", "${3} * ${0}"});
Insert(type, num_oprands, "elementwise_div", "${0} / ${1}", {}); insert_handler("elementwise_div", "${0} / ${1}", {});
Insert(type, num_oprands, "elementwise_min", "real_min(${0}, ${1})", {}); insert_handler("elementwise_min", "real_min(${0}, ${1})", {});
Insert(type, num_oprands, "elementwise_max", "real_max(${0}, ${1})", {}); insert_handler("elementwise_max", "real_max(${0}, ${1})", {});
} }
} // namespace fusion_group } // namespace fusion_group

@ -26,20 +26,32 @@ namespace ir {
namespace fusion_group { namespace fusion_group {
struct Operation { struct Operation {
Operation() {} Operation() = default;
Operation(int t, int n, std::string o, std::vector<std::string> e) Operation(int t, int n, std::string o, std::vector<std::string> e,
: type(t), num_operands(n), op_type(o), exprs(e) {} std::vector<std::string> i_n, std::vector<std::string> o_n)
: type(t),
num_operands(n),
op_type(o),
exprs(e),
input_names(i_n),
output_names(o_n) {}
bool IsGradOp() { bool IsGradOp() {
std::string suffix = "_grad"; std::string suffix = "_grad";
return op_type.rfind(suffix) == (op_type.length() - suffix.length()); size_t pos = op_type.rfind(suffix);
return pos != std::string::npos &&
pos == (op_type.length() - suffix.length());
} }
bool IsValid() { bool IsValid() {
if (!IsGradOp() && exprs.size() != 1U) { if (!IsGradOp() && exprs.size() != 1U) {
// When it is a forward operation, it should hold only one expression (for
// only one output).
return false; return false;
} }
if (IsGradOp() && exprs.size() != static_cast<size_t>(num_operands)) { if (IsGradOp() && exprs.size() != static_cast<size_t>(num_operands)) {
// When it is a backward opertion, it should hold a expression for each
// operand.
return false; return false;
} }
return true; return true;
@ -49,6 +61,8 @@ struct Operation {
int num_operands; int num_operands;
std::string op_type; std::string op_type;
std::vector<std::string> exprs; std::vector<std::string> exprs;
std::vector<std::string> input_names;
std::vector<std::string> output_names;
}; };
class OperationMap { class OperationMap {
@ -83,7 +97,9 @@ class OperationMap {
private: private:
void Insert(int type, int num_operands, std::string op_type, std::string expr, void Insert(int type, int num_operands, std::string op_type, std::string expr,
std::vector<std::string> grad_exprs); std::vector<std::string> grad_exprs,
std::vector<std::string> input_names,
std::vector<std::string> output_names);
void InsertUnaryElementwiseOperations(); void InsertUnaryElementwiseOperations();
void InsertBinaryElementwiseOperations(); void InsertBinaryElementwiseOperations();

File diff suppressed because it is too large Load Diff

@ -19,7 +19,10 @@ limitations under the License. */
#include <string> #include <string>
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/op_proto_maker.h" #include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
@ -267,6 +270,47 @@ struct Layers {
return outs; return outs;
} }
void backward() {
BlockDesc* block = program_.MutableBlock(0);
std::vector<OpDesc*> forward_ops = block->AllOps();
for (int i = forward_ops.size() - 1; i >= 0; --i) {
OpDesc* op = forward_ops[i];
OpDesc* grad_op = block->AppendOp();
grad_op->SetType(op->Type() + "_grad");
// All op's inputs are grad_op's input.
for (auto name : op->InputNames()) {
grad_op->SetInput(name, op->Input(name));
}
// All op's outputs are grad_op's input.
for (auto name : op->OutputNames()) {
grad_op->SetInput(name, op->Output(name));
}
// All op's outputs grad are grad_op's input.
for (auto name : op->OutputNames()) {
std::vector<std::string> grad_var_names;
for (auto var_name : op->Output(name)) {
VarDesc* var = block->FindVar(var_name);
VarDesc* grad_var =
lod_tensor(GradVarName(var_name), var->GetShape(), false);
grad_var_names.push_back(grad_var->Name());
}
grad_op->SetInput(GradVarName(name), grad_var_names);
}
// All op's inputs grad are grad_op's output.
for (auto name : op->InputNames()) {
std::vector<std::string> grad_var_names;
for (auto var_name : op->Input(name)) {
VarDesc* var = block->FindVar(var_name);
VarDesc* grad_var =
lod_tensor(GradVarName(var_name), var->GetShape(), false);
grad_var_names.push_back(grad_var->Name());
}
grad_op->SetOutput(GradVarName(name), grad_var_names);
}
// TODO(liuyiqun): attrs
}
}
private: private:
VarDesc* lod_tensor(std::string name, std::vector<int64_t> shape = {}, VarDesc* lod_tensor(std::string name, std::vector<int64_t> shape = {},
bool is_persistable = false) { bool is_persistable = false) {
@ -412,7 +456,7 @@ static std::string DebugString(Node* node) {
return os.str(); return os.str();
} }
static std::string DebugString(const std::unordered_set<Node*>& nodes) { static std::string DebugString(const std::vector<Node*>& nodes) {
std::ostringstream os; std::ostringstream os;
for (auto* node : nodes) { for (auto* node : nodes) {
if (node->IsOp() && node->Op()) { if (node->IsOp() && node->Op()) {
@ -425,6 +469,14 @@ static std::string DebugString(const std::unordered_set<Node*>& nodes) {
return os.str(); return os.str();
} }
// Overload for an unordered_set of nodes: copies the pointers into a
// vector and delegates to the vector overload.
static std::string DebugString(const std::unordered_set<Node*>& nodes) {
  // Range-construct the vector (one sized allocation) instead of an
  // element-by-element push_back loop.
  std::vector<Node*> vec(nodes.begin(), nodes.end());
  return DebugString(vec);
}
static std::string DebugString(const std::unique_ptr<Graph>& graph) { static std::string DebugString(const std::unique_ptr<Graph>& graph) {
std::ostringstream os; std::ostringstream os;
os << "Graph: {\n" << DebugString(graph->Nodes()) << "}\n"; os << "Graph: {\n" << DebugString(graph->Nodes()) << "}\n";

Loading…
Cancel
Save