|
|
|
@ -71,6 +71,8 @@ static bool HasInput(Node* n, std::string name) {
|
|
|
|
|
std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
|
|
|
|
|
SubGraph* subgraph) {
|
|
|
|
|
std::unordered_map<std::string, int> var_ids = EncodeVarNodes(subgraph);
|
|
|
|
|
std::vector<Node*> intermediate_out_nodes =
|
|
|
|
|
subgraph->GetIntermediateOutVarNodes();
|
|
|
|
|
std::vector<OperationExpression> expressions;
|
|
|
|
|
for (auto* node : subgraph->SortedNodes()) {
|
|
|
|
|
if (node && node->IsOp() && node->Op()) {
|
|
|
|
@ -81,7 +83,8 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
|
|
|
|
|
// - X, Y in forward operations
|
|
|
|
|
// - X, Y, Out, out@GRAD in backward operations
|
|
|
|
|
std::vector<int> input_ids;
|
|
|
|
|
auto operation = OperationMap::Instance().Get(op->Type());
|
|
|
|
|
std::string op_name = op->Type();
|
|
|
|
|
auto operation = OperationMap::Instance().Get(op_name);
|
|
|
|
|
std::vector<std::string> input_names = operation.input_names;
|
|
|
|
|
|
|
|
|
|
for (auto& name : input_names) {
|
|
|
|
@ -105,6 +108,7 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
|
|
|
|
|
std::vector<int> output_ids;
|
|
|
|
|
std::vector<std::string> output_names =
|
|
|
|
|
OperationMap::Instance().Get(op->Type()).output_names;
|
|
|
|
|
std::unordered_map<int, bool> intermediate_state;
|
|
|
|
|
|
|
|
|
|
for (auto& name : output_names) {
|
|
|
|
|
PADDLE_ENFORCE_NE(
|
|
|
|
@ -112,12 +116,21 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Output(%s) of operation %s is not set.", name, op->Type()));
|
|
|
|
|
output_ids.push_back(var_ids[op->Output(name)[0]]);
|
|
|
|
|
bool enable_intermediate = false;
|
|
|
|
|
for (auto* n : intermediate_out_nodes) {
|
|
|
|
|
if (n->Name() == op->Output(name)[0]) {
|
|
|
|
|
enable_intermediate = true;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
intermediate_state[var_ids[op->Output(name)[0]]] = enable_intermediate;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string lhs_type = ExtractDataType(node->outputs);
|
|
|
|
|
std::string rhs_type = ExtractDataType(node->inputs);
|
|
|
|
|
auto expression = OperationExpression(node->Name(), input_ids, output_ids,
|
|
|
|
|
rhs_type, lhs_type);
|
|
|
|
|
auto expression =
|
|
|
|
|
OperationExpression(node->Name(), input_ids, output_ids, rhs_type,
|
|
|
|
|
lhs_type, intermediate_state);
|
|
|
|
|
expression.SetAttr(attr);
|
|
|
|
|
expressions.push_back(expression);
|
|
|
|
|
}
|
|
|
|
@ -133,13 +146,17 @@ std::string CodeGenerator::Generate(
|
|
|
|
|
// TODO(liuyiqun): Check whether all expressions are elementwise operations.
|
|
|
|
|
std::set<int> input_ids = std::move(DistilInputIds(expressions));
|
|
|
|
|
std::set<int> output_ids = std::move(DistilOutputIds(expressions));
|
|
|
|
|
std::set<int> intermediate_ids =
|
|
|
|
|
std::move(DistilIntermediateIds(expressions));
|
|
|
|
|
std::unordered_map<int, std::string> dtypes =
|
|
|
|
|
std::move(DistilDtypes(expressions));
|
|
|
|
|
TemplateVariable template_var;
|
|
|
|
|
template_var.Add("func_name", func_name);
|
|
|
|
|
template_var.Add("parameters", EmitParameters(input_ids, output_ids, dtypes));
|
|
|
|
|
template_var.Add("parameters", EmitParameters(input_ids, output_ids,
|
|
|
|
|
intermediate_ids, dtypes));
|
|
|
|
|
template_var.Add("compute_body",
|
|
|
|
|
EmitComputeBody(expressions, input_ids, output_ids, dtypes));
|
|
|
|
|
EmitComputeBody(expressions, input_ids, output_ids,
|
|
|
|
|
intermediate_ids, dtypes));
|
|
|
|
|
|
|
|
|
|
std::set<std::string> all_dtype;
|
|
|
|
|
for (const auto& type : dtypes) {
|
|
|
|
@ -185,6 +202,19 @@ std::set<int> CodeGenerator::DistilOutputIds(
|
|
|
|
|
return output_ids;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::set<int> CodeGenerator::DistilIntermediateIds(
|
|
|
|
|
const std::vector<OperationExpression>& expressions) {
|
|
|
|
|
std::set<int> intermediate_ids;
|
|
|
|
|
// Use std::set to remove the reptead id and get a ordered list.
|
|
|
|
|
for (size_t i = 0; i < expressions.size(); i++) {
|
|
|
|
|
for (auto id : expressions[i].GetOutputIds()) {
|
|
|
|
|
auto intermediate_state = expressions[i].GetIntermediateState();
|
|
|
|
|
if (intermediate_state[id]) intermediate_ids.insert(id);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return intermediate_ids;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::unordered_map<int, std::string> CodeGenerator::DistilDtypes(
|
|
|
|
|
const std::vector<OperationExpression>& expressions) {
|
|
|
|
|
std::unordered_map<int, std::string> dtypes;
|
|
|
|
@ -218,6 +248,7 @@ std::unordered_map<int, std::string> CodeGenerator::DistilDtypes(
|
|
|
|
|
// we get the parameter list code for the expression information
|
|
|
|
|
std::string CodeGenerator::EmitParameters(
|
|
|
|
|
const std::set<int>& input_ids, const std::set<int>& output_ids,
|
|
|
|
|
const std::set<int>& intermediate_ids,
|
|
|
|
|
const std::unordered_map<int, std::string>& dtypes) const {
|
|
|
|
|
std::stringstream ret;
|
|
|
|
|
ret << "int N, ";
|
|
|
|
@ -226,25 +257,28 @@ std::string CodeGenerator::EmitParameters(
|
|
|
|
|
// from the input list.
|
|
|
|
|
for (auto id : input_ids) {
|
|
|
|
|
if (output_ids.find(id) == output_ids.end()) {
|
|
|
|
|
ret << dtypes.at(id) << "* " << ArgName(id) << ", ";
|
|
|
|
|
ret << "const " << dtypes.at(id) << "* __restrict__ " << ArgName(id)
|
|
|
|
|
<< ", ";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t index = 0;
|
|
|
|
|
for (auto id : output_ids) {
|
|
|
|
|
if (intermediate_ids.find(id) == intermediate_ids.end()) {
|
|
|
|
|
ret << dtypes.at(id) << "* " << ArgName(id);
|
|
|
|
|
if (index != output_ids.size() - 1) {
|
|
|
|
|
ret << ", ";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
index++;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return ret.str();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string CodeGenerator::EmitComputeBody(
|
|
|
|
|
const std::vector<OperationExpression>& expressions,
|
|
|
|
|
const std::set<int>& input_ids, const std::set<int>& output_ids,
|
|
|
|
|
const std::set<int>& intermediate_ids,
|
|
|
|
|
const std::unordered_map<int, std::string>& dtypes) const {
|
|
|
|
|
std::ostringstream compute;
|
|
|
|
|
std::unordered_set<int> used;
|
|
|
|
@ -258,15 +292,18 @@ std::string CodeGenerator::EmitComputeBody(
|
|
|
|
|
for (auto id : input_ids) {
|
|
|
|
|
if (output_ids.find(id) == output_ids.end() &&
|
|
|
|
|
used.find(id) != used.end()) {
|
|
|
|
|
load << dtypes.at(id) << " " << TmpName(id) << " = " << VarName(id)
|
|
|
|
|
load << dtypes.at(id) << " " << TmpName(id) << " = "
|
|
|
|
|
<< "__ldg(&" << VarName(id) << ")"
|
|
|
|
|
<< ";";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// Store temporary variables to memory.
|
|
|
|
|
std::ostringstream store;
|
|
|
|
|
for (auto id : output_ids) {
|
|
|
|
|
if (intermediate_ids.find(id) == intermediate_ids.end()) {
|
|
|
|
|
store << VarName(id) << " = " << TmpName(id) << ";";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return load.str() + compute.str() + store.str();
|
|
|
|
|
}
|
|
|
|
@ -285,32 +322,7 @@ std::unordered_map<std::string, int> CodeGenerator::EncodeVarNodes(
|
|
|
|
|
var_ids[in->Name()] = id++;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// Numbering internal vars.
|
|
|
|
|
for (auto* node : subgraph->SortedNodes()) {
|
|
|
|
|
if (node && node->IsVar() && node->Var()) {
|
|
|
|
|
bool is_found = false;
|
|
|
|
|
for (auto* in : input_var_nodes) {
|
|
|
|
|
if (node == in) {
|
|
|
|
|
is_found = true;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (is_found) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
for (auto* out : output_var_nodes) {
|
|
|
|
|
if (node == out) {
|
|
|
|
|
is_found = true;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
is_found, true,
|
|
|
|
|
platform::errors::Unimplemented(
|
|
|
|
|
"Subgraph with internal var nodes (%s) is not supported yet.",
|
|
|
|
|
node->Name()));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Encoding output vars.
|
|
|
|
|
for (auto* out : output_var_nodes) {
|
|
|
|
|
VLOG(3) << "Ecoding output names:" << out->Name() << ", id:" << id;
|
|
|
|
|