|
|
|
@ -24,7 +24,7 @@ namespace framework {
|
|
|
|
|
namespace ir {
|
|
|
|
|
namespace fusion_group {
|
|
|
|
|
|
|
|
|
|
std::string ExtractDataType(const std::vector<Node*> nodes) {
|
|
|
|
|
std::string ExtractDataType(const std::vector<Node*>& nodes) {
|
|
|
|
|
std::string dtype_str = "float";
|
|
|
|
|
auto data_type = nodes.back()->Var()->GetDataType();
|
|
|
|
|
|
|
|
|
@ -98,6 +98,7 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
|
|
|
|
|
std::vector<int> output_ids;
|
|
|
|
|
std::vector<std::string> output_names =
|
|
|
|
|
OperationMap::Instance().Get(op->Type()).output_names;
|
|
|
|
|
|
|
|
|
|
for (auto& name : output_names) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
op->Output(name).size(), 1U,
|
|
|
|
@ -125,9 +126,10 @@ std::string CodeGenerator::Generate(
|
|
|
|
|
std::string func_name,
|
|
|
|
|
const std::vector<OperationExpression>& expressions) {
|
|
|
|
|
// TODO(liuyiqun): Check whether all expressions are elementwise operations.
|
|
|
|
|
std::set<int> input_ids = DistilInputIds(expressions);
|
|
|
|
|
std::set<int> output_ids = DistilOutputIds(expressions);
|
|
|
|
|
std::unordered_map<int, std::string> dtypes = DistilDtypes(expressions);
|
|
|
|
|
std::set<int> input_ids = std::move(DistilInputIds(expressions));
|
|
|
|
|
std::set<int> output_ids = std::move(DistilOutputIds(expressions));
|
|
|
|
|
std::unordered_map<int, std::string> dtypes =
|
|
|
|
|
std::move(DistilDtypes(expressions));
|
|
|
|
|
TemplateVariable template_var;
|
|
|
|
|
template_var.Add("func_name", func_name);
|
|
|
|
|
template_var.Add("parameters", EmitParameters(input_ids, output_ids, dtypes));
|
|
|
|
@ -211,7 +213,7 @@ std::unordered_map<int, std::string> CodeGenerator::DistilDtypes(
|
|
|
|
|
// we get the parameter list code for the expression information
|
|
|
|
|
std::string CodeGenerator::EmitParameters(
|
|
|
|
|
const std::set<int>& input_ids, const std::set<int>& output_ids,
|
|
|
|
|
std::unordered_map<int, std::string> dtypes) {
|
|
|
|
|
const std::unordered_map<int, std::string>& dtypes) const {
|
|
|
|
|
std::stringstream ret;
|
|
|
|
|
ret << "int N, ";
|
|
|
|
|
|
|
|
|
@ -219,13 +221,13 @@ std::string CodeGenerator::EmitParameters(
|
|
|
|
|
// from the input list.
|
|
|
|
|
for (auto id : input_ids) {
|
|
|
|
|
if (output_ids.find(id) == output_ids.end()) {
|
|
|
|
|
ret << dtypes[id] << "* " << ArgName(id) << ", ";
|
|
|
|
|
ret << dtypes.at(id) << "* " << ArgName(id) << ", ";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t index = 0;
|
|
|
|
|
for (auto id : output_ids) {
|
|
|
|
|
ret << dtypes[id] << "* " << ArgName(id);
|
|
|
|
|
ret << dtypes.at(id) << "* " << ArgName(id);
|
|
|
|
|
if (index != output_ids.size() - 1) {
|
|
|
|
|
ret << ", ";
|
|
|
|
|
}
|
|
|
|
@ -238,7 +240,7 @@ std::string CodeGenerator::EmitParameters(
|
|
|
|
|
std::string CodeGenerator::EmitComputeBody(
|
|
|
|
|
const std::vector<OperationExpression>& expressions,
|
|
|
|
|
const std::set<int>& input_ids, const std::set<int>& output_ids,
|
|
|
|
|
std::unordered_map<int, std::string> dtypes) {
|
|
|
|
|
const std::unordered_map<int, std::string>& dtypes) const {
|
|
|
|
|
std::ostringstream compute;
|
|
|
|
|
std::unordered_set<int> used;
|
|
|
|
|
for (size_t i = 0; i < expressions.size(); i++) {
|
|
|
|
@ -251,7 +253,8 @@ std::string CodeGenerator::EmitComputeBody(
|
|
|
|
|
for (auto id : input_ids) {
|
|
|
|
|
if (output_ids.find(id) == output_ids.end() &&
|
|
|
|
|
used.find(id) != used.end()) {
|
|
|
|
|
load << dtypes[id] << " " << TmpName(id) << " = " << VarName(id) << ";";
|
|
|
|
|
load << dtypes.at(id) << " " << TmpName(id) << " = " << VarName(id)
|
|
|
|
|
<< ";";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// Store temporal variables to memory.
|
|
|
|
|