|
|
|
@ -24,6 +24,21 @@ namespace framework {
|
|
|
|
|
namespace ir {
|
|
|
|
|
namespace fusion_group {
|
|
|
|
|
|
|
|
|
|
std::string ExtractDataType(const std::vector<Node*> nodes) {
|
|
|
|
|
std::string dtype_str = "float";
|
|
|
|
|
auto data_type = nodes.back()->Var()->GetDataType();
|
|
|
|
|
|
|
|
|
|
if (data_type == proto::VarType::FP32) {
|
|
|
|
|
dtype_str = "float";
|
|
|
|
|
} else if (data_type == proto::VarType::FP64) {
|
|
|
|
|
dtype_str = "double";
|
|
|
|
|
} else if (data_type == proto::VarType::FP16) {
|
|
|
|
|
dtype_str = "float16";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return dtype_str;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
CodeGenerator::CodeGenerator() {
|
|
|
|
|
// Only support elementwise operations now.
|
|
|
|
|
code_templates_.resize(1);
|
|
|
|
@ -34,8 +49,7 @@ CodeGenerator::CodeGenerator() {
|
|
|
|
|
|
|
|
|
|
std::string CodeGenerator::Generate(SubGraph* subgraph) {
|
|
|
|
|
std::vector<OperationExpression> expressions = ConvertToExpressions(subgraph);
|
|
|
|
|
return Generate(subgraph->GetFuncName(), subgraph->GetDataType(),
|
|
|
|
|
expressions);
|
|
|
|
|
return Generate(subgraph->GetFuncName(), expressions);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static bool HasInput(Node* n, std::string name) {
|
|
|
|
@ -95,8 +109,11 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
|
|
|
|
|
"Output(%s) of operation %s is not set.", name, op->Type()));
|
|
|
|
|
output_ids.push_back(var_ids[op->Output(name)[0]]);
|
|
|
|
|
}
|
|
|
|
|
expressions.push_back(
|
|
|
|
|
OperationExpression(node->Name(), input_ids, output_ids));
|
|
|
|
|
|
|
|
|
|
std::string lhs_type = ExtractDataType(node->outputs);
|
|
|
|
|
std::string rhs_type = ExtractDataType(node->inputs);
|
|
|
|
|
expressions.emplace_back(OperationExpression(
|
|
|
|
|
node->Name(), input_ids, output_ids, rhs_type, lhs_type));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return expressions;
|
|
|
|
@ -105,25 +122,32 @@ std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
|
|
|
|
|
// In order to get the right result of expression, we need to calculate and
|
|
|
|
|
// store the expression as suffix Expressions using vector.
|
|
|
|
|
std::string CodeGenerator::Generate(
|
|
|
|
|
std::string func_name, std::string dtype,
|
|
|
|
|
std::string func_name,
|
|
|
|
|
const std::vector<OperationExpression>& expressions) {
|
|
|
|
|
// TODO(liuyiqun): Check whether all expressions are elementwise operations.
|
|
|
|
|
std::set<int> input_ids = DistilInputIds(expressions);
|
|
|
|
|
std::set<int> output_ids = DistilOutputIds(expressions);
|
|
|
|
|
|
|
|
|
|
std::unordered_map<int, std::string> dtypes = DistilDtypes(expressions);
|
|
|
|
|
TemplateVariable template_var;
|
|
|
|
|
template_var.Add("func_name", func_name);
|
|
|
|
|
template_var.Add("parameters", EmitParameters(input_ids, output_ids, dtype));
|
|
|
|
|
template_var.Add("parameters", EmitParameters(input_ids, output_ids, dtypes));
|
|
|
|
|
template_var.Add("compute_body",
|
|
|
|
|
EmitComputeBody(expressions, input_ids, output_ids, dtype));
|
|
|
|
|
EmitComputeBody(expressions, input_ids, output_ids, dtypes));
|
|
|
|
|
|
|
|
|
|
std::string predefined_cuda_functions;
|
|
|
|
|
if (dtype == "float") {
|
|
|
|
|
predefined_cuda_functions = predefined_cuda_functions_fp32;
|
|
|
|
|
} else if (dtype == "double") {
|
|
|
|
|
predefined_cuda_functions = predefined_cuda_functions_fp64;
|
|
|
|
|
} else if (dtype == "float16") {
|
|
|
|
|
predefined_cuda_functions = predefined_cuda_functions_fp16;
|
|
|
|
|
std::set<std::string> all_dtype;
|
|
|
|
|
for (const auto& type : dtypes) {
|
|
|
|
|
all_dtype.insert(type.second);
|
|
|
|
|
}
|
|
|
|
|
std::string predefined_cuda_functions = "";
|
|
|
|
|
if (all_dtype.find("float") != all_dtype.end() &&
|
|
|
|
|
all_dtype.find("float16") == all_dtype.end()) {
|
|
|
|
|
predefined_cuda_functions += predefined_cuda_functions_fp32;
|
|
|
|
|
}
|
|
|
|
|
if (all_dtype.find("double") != all_dtype.end()) {
|
|
|
|
|
predefined_cuda_functions += predefined_cuda_functions_fp64;
|
|
|
|
|
}
|
|
|
|
|
if (all_dtype.find("float16") != all_dtype.end()) {
|
|
|
|
|
predefined_cuda_functions += predefined_cuda_functions_fp16;
|
|
|
|
|
}
|
|
|
|
|
return predefined_cuda_functions + code_templates_[0].Format(template_var);
|
|
|
|
|
}
|
|
|
|
@ -154,10 +178,40 @@ std::set<int> CodeGenerator::DistilOutputIds(
|
|
|
|
|
return output_ids;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::unordered_map<int, std::string> CodeGenerator::DistilDtypes(
|
|
|
|
|
const std::vector<OperationExpression>& expressions) {
|
|
|
|
|
std::unordered_map<int, std::string> dtypes;
|
|
|
|
|
for (const auto& expression : expressions) {
|
|
|
|
|
for (auto id : expression.GetInputIds()) {
|
|
|
|
|
auto dtype = expression.GetRHSType();
|
|
|
|
|
if (dtypes.find(id) == dtypes.end()) {
|
|
|
|
|
dtypes[id] = dtype;
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
dtypes[id], dtype,
|
|
|
|
|
platform::errors::PreconditionNotMet(
|
|
|
|
|
"In fusion group, Same Node id must have same date type"));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
for (auto id : expression.GetOutputIds()) {
|
|
|
|
|
auto dtype = expression.GetLHSType();
|
|
|
|
|
if (dtypes.find(id) == dtypes.end()) {
|
|
|
|
|
dtypes[id] = dtype;
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
dtypes[id], dtype,
|
|
|
|
|
platform::errors::PreconditionNotMet(
|
|
|
|
|
"In fusion group, Same Node id must have same date type"));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return dtypes;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// we get the parameter list code for the expression information
|
|
|
|
|
std::string CodeGenerator::EmitParameters(const std::set<int>& input_ids,
|
|
|
|
|
const std::set<int>& output_ids,
|
|
|
|
|
std::string dtype) {
|
|
|
|
|
std::string CodeGenerator::EmitParameters(
|
|
|
|
|
const std::set<int>& input_ids, const std::set<int>& output_ids,
|
|
|
|
|
std::unordered_map<int, std::string> dtypes) {
|
|
|
|
|
std::stringstream ret;
|
|
|
|
|
ret << "int N, ";
|
|
|
|
|
|
|
|
|
@ -165,13 +219,13 @@ std::string CodeGenerator::EmitParameters(const std::set<int>& input_ids,
|
|
|
|
|
// from the input list.
|
|
|
|
|
for (auto id : input_ids) {
|
|
|
|
|
if (output_ids.find(id) == output_ids.end()) {
|
|
|
|
|
ret << dtype << "* " << ArgName(id) << ", ";
|
|
|
|
|
ret << dtypes[id] << "* " << ArgName(id) << ", ";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t index = 0;
|
|
|
|
|
for (auto id : output_ids) {
|
|
|
|
|
ret << dtype << "* " << ArgName(id);
|
|
|
|
|
ret << dtypes[id] << "* " << ArgName(id);
|
|
|
|
|
if (index != output_ids.size() - 1) {
|
|
|
|
|
ret << ", ";
|
|
|
|
|
}
|
|
|
|
@ -184,13 +238,12 @@ std::string CodeGenerator::EmitParameters(const std::set<int>& input_ids,
|
|
|
|
|
std::string CodeGenerator::EmitComputeBody(
|
|
|
|
|
const std::vector<OperationExpression>& expressions,
|
|
|
|
|
const std::set<int>& input_ids, const std::set<int>& output_ids,
|
|
|
|
|
std::string dtype) {
|
|
|
|
|
std::unordered_map<int, std::string> dtypes) {
|
|
|
|
|
std::ostringstream compute;
|
|
|
|
|
std::unordered_set<int> used;
|
|
|
|
|
std::string compute_dtype = (dtype == "float16") ? "float" : dtype;
|
|
|
|
|
for (size_t i = 0; i < expressions.size(); i++) {
|
|
|
|
|
VLOG(3) << DebugString(expressions[i]);
|
|
|
|
|
compute << expressions[i].GetExpression(compute_dtype, &used);
|
|
|
|
|
compute << expressions[i].GetExpression(&used);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Load input to temporal variables.
|
|
|
|
@ -198,23 +251,13 @@ std::string CodeGenerator::EmitComputeBody(
|
|
|
|
|
for (auto id : input_ids) {
|
|
|
|
|
if (output_ids.find(id) == output_ids.end() &&
|
|
|
|
|
used.find(id) != used.end()) {
|
|
|
|
|
if (dtype == "float16") {
|
|
|
|
|
load << "float " << TmpName(id) << " = __half2float(" << ArgName(id)
|
|
|
|
|
<< "[idx]);";
|
|
|
|
|
} else {
|
|
|
|
|
load << dtype << " " << TmpName(id) << " = " << ArgName(id) << "[idx];";
|
|
|
|
|
}
|
|
|
|
|
load << dtypes[id] << " " << TmpName(id) << " = " << VarName(id) << ";";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Store temporal variables to memory.
|
|
|
|
|
std::ostringstream store;
|
|
|
|
|
for (auto id : output_ids) {
|
|
|
|
|
if (dtype == "float16") {
|
|
|
|
|
store << ArgName(id) << "[idx] = __float2half(" << TmpName(id) << ");";
|
|
|
|
|
} else {
|
|
|
|
|
store << ArgName(id) << "[idx] = " << TmpName(id) << ";";
|
|
|
|
|
}
|
|
|
|
|
store << VarName(id) << " = " << TmpName(id) << ";";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return load.str() + compute.str() + store.str();
|
|
|
|
|