|
|
|
@ -68,11 +68,35 @@ static bool HasInput(Node* n, std::string name) {
|
|
|
|
|
return input_names_set.find(name) != input_names_set.end();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Searches the input nodes of `n` for the one whose Name() equals `name`.
//
// `n` must be a non-null operator node with a valid Op(); otherwise an
// InvalidArgument error is raised through PADDLE_ENFORCE_EQ.
// Returns the first matching input node, or nullptr when no input matches.
static Node* GetInputVar(Node* n, const std::string& name) {
  PADDLE_ENFORCE_EQ(n && n->IsOp() && n->Op(), true,
                    platform::errors::InvalidArgument(
                        "Expected node %p to be an operator node.", n));
  Node* matched_input = nullptr;
  for (auto iter = n->inputs.begin(); iter != n->inputs.end(); ++iter) {
    if ((*iter)->Name() == name) {
      // Stop at the first hit, in iteration order of n->inputs.
      matched_input = *iter;
      break;
    }
  }
  return matched_input;
}
|
|
|
|
|
|
|
|
|
|
// Searches the output nodes of `n` for the one whose Name() equals `name`.
//
// `n` must be a non-null operator node with a valid Op(); otherwise an
// InvalidArgument error is raised through PADDLE_ENFORCE_EQ.
// Returns the first matching output node, or nullptr when no output matches.
static Node* GetOutputVar(Node* n, const std::string& name) {
  PADDLE_ENFORCE_EQ(n && n->IsOp() && n->Op(), true,
                    platform::errors::InvalidArgument(
                        "Expected node %p to be an operator node.", n));
  Node* matched_output = nullptr;
  for (auto iter = n->outputs.begin(); iter != n->outputs.end(); ++iter) {
    if ((*iter)->Name() == name) {
      // Stop at the first hit, in iteration order of n->outputs.
      matched_output = *iter;
      break;
    }
  }
  return matched_output;
}
|
|
|
|
|
|
|
|
|
|
std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
|
|
|
|
|
SubGraph* subgraph) {
|
|
|
|
|
std::unordered_map<std::string, int> var_ids = EncodeVarNodes(subgraph);
|
|
|
|
|
std::vector<Node*> intermediate_out_nodes =
|
|
|
|
|
subgraph->GetIntermediateOutVarNodes();
|
|
|
|
|
std::unordered_map<Node*, int> var_ids = EncodeVarNodes(subgraph);
|
|
|
|
|
std::unordered_set<Node*> intermediate_out_vars_set =
|
|
|
|
|
subgraph->GetIntermediateOutVarNodesSet();
|
|
|
|
|
std::vector<OperationExpression> expressions;
|
|
|
|
|
for (auto* node : subgraph->SortedNodes()) {
|
|
|
|
|
if (node && node->IsOp() && node->Op()) {
|
|
|
|
@ -92,11 +116,12 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
|
|
|
|
|
// "elementwise_add_grad", where "X", "Y" and "Out" are not used.
|
|
|
|
|
if ((HasInput(node, name) && op->Input(name).size() >= 1U)) {
|
|
|
|
|
for (size_t i = 0; i < op->Input(name).size(); i++) {
|
|
|
|
|
Node* input_var = GetInputVar(node, op->Input(name)[i]);
|
|
|
|
|
PADDLE_ENFORCE_NE(
|
|
|
|
|
var_ids.find(op->Input(name)[i]), var_ids.end(),
|
|
|
|
|
var_ids.find(input_var), var_ids.end(),
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(%s) of operation %s is not set.", name, op->Type()));
|
|
|
|
|
input_ids.push_back(var_ids[op->Input(name)[i]]);
|
|
|
|
|
input_ids.push_back(var_ids[input_var]);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
input_ids.push_back(-1);
|
|
|
|
@ -106,31 +131,29 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
|
|
|
|
|
// Output ids should be set in fixed order, like:
|
|
|
|
|
// - dx, dy in backward operations
|
|
|
|
|
std::vector<int> output_ids;
|
|
|
|
|
std::vector<int> intermediate_output_ids;
|
|
|
|
|
std::vector<std::string> output_names =
|
|
|
|
|
OperationMap::Instance().Get(op->Type()).output_names;
|
|
|
|
|
std::unordered_map<int, bool> intermediate_state;
|
|
|
|
|
|
|
|
|
|
for (auto& name : output_names) {
|
|
|
|
|
Node* output_var = GetOutputVar(node, op->Output(name)[0]);
|
|
|
|
|
PADDLE_ENFORCE_NE(
|
|
|
|
|
var_ids.find(op->Output(name)[0]), var_ids.end(),
|
|
|
|
|
var_ids.find(output_var), var_ids.end(),
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Output(%s) of operation %s is not set.", name, op->Type()));
|
|
|
|
|
output_ids.push_back(var_ids[op->Output(name)[0]]);
|
|
|
|
|
bool enable_intermediate = false;
|
|
|
|
|
for (auto* n : intermediate_out_nodes) {
|
|
|
|
|
if (n->Name() == op->Output(name)[0]) {
|
|
|
|
|
enable_intermediate = true;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
output_ids.push_back(var_ids[output_var]);
|
|
|
|
|
if (!subgraph->SaveIntermediateOut() &&
|
|
|
|
|
intermediate_out_vars_set.find(output_var) !=
|
|
|
|
|
intermediate_out_vars_set.end()) {
|
|
|
|
|
intermediate_output_ids.push_back(var_ids[output_var]);
|
|
|
|
|
}
|
|
|
|
|
intermediate_state[var_ids[op->Output(name)[0]]] = enable_intermediate;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string lhs_type = ExtractDataType(node->outputs);
|
|
|
|
|
std::string rhs_type = ExtractDataType(node->inputs);
|
|
|
|
|
auto expression =
|
|
|
|
|
OperationExpression(node->Name(), input_ids, output_ids, rhs_type,
|
|
|
|
|
lhs_type, intermediate_state);
|
|
|
|
|
lhs_type, intermediate_output_ids);
|
|
|
|
|
expression.SetAttr(attr);
|
|
|
|
|
expressions.push_back(expression);
|
|
|
|
|
}
|
|
|
|
@ -146,17 +169,18 @@ std::string CodeGenerator::Generate(
|
|
|
|
|
// TODO(liuyiqun): Check whether all expressions are elementwise operations.
|
|
|
|
|
std::set<int> input_ids = std::move(DistilInputIds(expressions));
|
|
|
|
|
std::set<int> output_ids = std::move(DistilOutputIds(expressions));
|
|
|
|
|
std::set<int> intermediate_ids =
|
|
|
|
|
std::set<int> intermediate_output_ids =
|
|
|
|
|
std::move(DistilIntermediateIds(expressions));
|
|
|
|
|
std::unordered_map<int, std::string> dtypes =
|
|
|
|
|
std::move(DistilDtypes(expressions));
|
|
|
|
|
TemplateVariable template_var;
|
|
|
|
|
template_var.Add("func_name", func_name);
|
|
|
|
|
template_var.Add("parameters", EmitParameters(input_ids, output_ids,
|
|
|
|
|
intermediate_ids, dtypes));
|
|
|
|
|
template_var.Add(
|
|
|
|
|
"parameters",
|
|
|
|
|
EmitParameters(input_ids, output_ids, intermediate_output_ids, dtypes));
|
|
|
|
|
template_var.Add("compute_body",
|
|
|
|
|
EmitComputeBody(expressions, input_ids, output_ids,
|
|
|
|
|
intermediate_ids, dtypes));
|
|
|
|
|
intermediate_output_ids, dtypes));
|
|
|
|
|
|
|
|
|
|
std::set<std::string> all_dtype;
|
|
|
|
|
for (const auto& type : dtypes) {
|
|
|
|
@ -204,18 +228,14 @@ std::set<int> CodeGenerator::DistilOutputIds(
|
|
|
|
|
|
|
|
|
|
std::set<int> CodeGenerator::DistilIntermediateIds(
|
|
|
|
|
const std::vector<OperationExpression>& expressions) {
|
|
|
|
|
std::set<int> intermediate_ids;
|
|
|
|
|
std::set<int> intermediate_output_ids;
|
|
|
|
|
// Use std::set to remove the reptead id and get a ordered list.
|
|
|
|
|
for (size_t i = 0; i < expressions.size(); i++) {
|
|
|
|
|
for (auto id : expressions[i].GetOutputIds()) {
|
|
|
|
|
auto intermediate_state = expressions[i].GetIntermediateState();
|
|
|
|
|
if (intermediate_state.find(id) != intermediate_state.end() &&
|
|
|
|
|
intermediate_state[id]) {
|
|
|
|
|
intermediate_ids.insert(id);
|
|
|
|
|
}
|
|
|
|
|
for (auto id : expressions[i].GetIntermediateOutputIds()) {
|
|
|
|
|
intermediate_output_ids.insert(id);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return intermediate_ids;
|
|
|
|
|
return intermediate_output_ids;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::unordered_map<int, std::string> CodeGenerator::DistilDtypes(
|
|
|
|
@ -316,26 +336,29 @@ std::string CodeGenerator::EmitComputeBody(
|
|
|
|
|
return load.str() + compute.str() + store.str();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::unordered_map<std::string, int> CodeGenerator::EncodeVarNodes(
|
|
|
|
|
std::unordered_map<Node*, int> CodeGenerator::EncodeVarNodes(
|
|
|
|
|
SubGraph* subgraph) {
|
|
|
|
|
const auto& input_var_nodes = subgraph->GetInputVarNodes();
|
|
|
|
|
const auto& output_var_nodes = subgraph->GetOutputVarNodes();
|
|
|
|
|
// Encode all var nodes, including intermediate output var nodes.
|
|
|
|
|
const auto& output_var_nodes = subgraph->GetOutputVarNodes(true);
|
|
|
|
|
|
|
|
|
|
int id = 0;
|
|
|
|
|
std::unordered_map<std::string, int> var_ids;
|
|
|
|
|
std::unordered_map<Node*, int> var_ids;
|
|
|
|
|
// Numbering input vars.
|
|
|
|
|
for (auto* in : input_var_nodes) {
|
|
|
|
|
VLOG(3) << "Encoding input names:" << in->Name() << ", id:" << id;
|
|
|
|
|
if (var_ids.find(in->Name()) == var_ids.end()) {
|
|
|
|
|
var_ids[in->Name()] = id++;
|
|
|
|
|
VLOG(3) << "Encoding input names:" << in->Name() << "(" << in
|
|
|
|
|
<< "), id:" << id;
|
|
|
|
|
if (var_ids.find(in) == var_ids.end()) {
|
|
|
|
|
var_ids[in] = id++;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Encoding output vars.
|
|
|
|
|
for (auto* out : output_var_nodes) {
|
|
|
|
|
VLOG(3) << "Ecoding output names:" << out->Name() << ", id:" << id;
|
|
|
|
|
if (var_ids.find(out->Name()) == var_ids.end()) {
|
|
|
|
|
var_ids[out->Name()] = id++;
|
|
|
|
|
VLOG(3) << "Ecoding output names:" << out->Name() << "(" << out
|
|
|
|
|
<< "), id:" << id;
|
|
|
|
|
if (var_ids.find(out) == var_ids.end()) {
|
|
|
|
|
var_ids[out] = id++;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return var_ids;
|
|
|
|
|