|
|
|
@ -18,7 +18,14 @@ limitations under the License. */
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include "glog/logging.h"
|
|
|
|
|
#include "paddle/fluid/framework/block_desc.h"
|
|
|
|
|
#include "paddle/fluid/framework/ir/fusion_group/operation.h"
|
|
|
|
|
#include "paddle/fluid/framework/op_call_stack.h"
|
|
|
|
|
#include "paddle/fluid/framework/op_proto_maker.h"
|
|
|
|
|
#include "paddle/fluid/framework/operator.h"
|
|
|
|
|
#include "paddle/fluid/framework/program_desc.h"
|
|
|
|
|
#include "paddle/fluid/framework/shape_inference.h"
|
|
|
|
|
#include "paddle/fluid/framework/var_type_inference.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
@ -33,7 +40,7 @@ static T StringTo(const std::string& str) {
|
|
|
|
|
return value;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static std::string ExpandMultivariateTemplate(const std::string rhs,
|
|
|
|
|
static std::string ExpandMultivariateTemplate(const std::string& rhs,
|
|
|
|
|
const size_t input_size) {
|
|
|
|
|
int start_pos = rhs.find("[", 0);
|
|
|
|
|
int end_pos = rhs.find("]", 0);
|
|
|
|
@ -50,6 +57,66 @@ static std::string ExpandMultivariateTemplate(const std::string rhs,
|
|
|
|
|
return sum_rhs;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static std::string RefineTemplateWithAttr(const std::string& op_type,
|
|
|
|
|
const std::string& exp_definition,
|
|
|
|
|
const AttributeMap& attrs) {
|
|
|
|
|
std::string ret;
|
|
|
|
|
// here str_cvt convert string to number in some attr
|
|
|
|
|
// for example in fill_constant str_value
|
|
|
|
|
std::stringstream str_cvt;
|
|
|
|
|
auto IsNumber = [exp_definition]() -> bool {
|
|
|
|
|
return exp_definition.find_first_not_of("0123456789") == std::string::npos;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
if (!IsNumber()) {
|
|
|
|
|
// Get attr with different type, Now we only support the simple attr
|
|
|
|
|
// condition
|
|
|
|
|
std::string attr_name, default_value;
|
|
|
|
|
if (exp_definition.find("=") != std::string::npos) {
|
|
|
|
|
attr_name = exp_definition.substr(0, exp_definition.find("="));
|
|
|
|
|
default_value = exp_definition.substr(exp_definition.rfind("=") + 1,
|
|
|
|
|
exp_definition.length() - 1);
|
|
|
|
|
ret = default_value;
|
|
|
|
|
} else {
|
|
|
|
|
attr_name = exp_definition;
|
|
|
|
|
}
|
|
|
|
|
auto it = attrs.find(attr_name);
|
|
|
|
|
if (it == attrs.end()) {
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
Attribute attr = it->second;
|
|
|
|
|
proto::AttrType attr_type =
|
|
|
|
|
static_cast<proto::AttrType>(it->second.which() - 1);
|
|
|
|
|
if (attr_type == proto::AttrType::BOOLEAN) {
|
|
|
|
|
bool result = boost::get<bool>(attr);
|
|
|
|
|
if (result) {
|
|
|
|
|
ret = "true";
|
|
|
|
|
} else {
|
|
|
|
|
ret = "false";
|
|
|
|
|
}
|
|
|
|
|
} else if (attr_type == proto::AttrType::INT) {
|
|
|
|
|
int result = boost::get<int>(attr);
|
|
|
|
|
str_cvt << result;
|
|
|
|
|
ret = str_cvt.str();
|
|
|
|
|
} else if (attr_type == proto::AttrType::LONG) {
|
|
|
|
|
int64_t result = boost::get<int64_t>(attr);
|
|
|
|
|
str_cvt << result;
|
|
|
|
|
ret = str_cvt.str();
|
|
|
|
|
} else if (attr_type == proto::AttrType::FLOAT) {
|
|
|
|
|
float result = boost::get<float>(attr);
|
|
|
|
|
str_cvt << result;
|
|
|
|
|
ret = str_cvt.str();
|
|
|
|
|
} else if (attr_type == proto::AttrType::STRING) {
|
|
|
|
|
std::string result = boost::get<std::string>(attr);
|
|
|
|
|
ret = result;
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
ret = exp_definition;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// In order to avoid multiple __half2float function calls, we do this
|
|
|
|
|
// optimization
|
|
|
|
|
static std::string OptimzeFP16RHS(std::unordered_set<int>* used,
|
|
|
|
@ -74,7 +141,6 @@ std::string OperationExpression::GetRHS(std::unordered_set<int>* used,
|
|
|
|
|
size_t input_size = input_ids_.size();
|
|
|
|
|
rhs = ExpandMultivariateTemplate(rhs, input_size);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < rhs.size(); i++) {
|
|
|
|
|
size_t pos = i;
|
|
|
|
|
if (rhs[pos] == '$' && rhs[pos + 1] == '{') {
|
|
|
|
@ -83,28 +149,36 @@ std::string OperationExpression::GetRHS(std::unordered_set<int>* used,
|
|
|
|
|
length++;
|
|
|
|
|
}
|
|
|
|
|
std::string index_str = rhs.substr(pos + 2, length);
|
|
|
|
|
int index = StringTo<int>(index_str);
|
|
|
|
|
PADDLE_ENFORCE_LT(
|
|
|
|
|
index, input_ids_.size(),
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Only %d inputs are provided, but need %d for operation < %s >.",
|
|
|
|
|
input_ids_.size(), index + 1, op_type_));
|
|
|
|
|
PADDLE_ENFORCE_GE(
|
|
|
|
|
input_ids_[index], 0,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Expected %d-th input id > 0 for operation < %s >. Received %d.",
|
|
|
|
|
index, op_type_, input_ids_[index]));
|
|
|
|
|
// TODO(wangchaochaohu): Here fp16 convert to float to do comupte, we need
|
|
|
|
|
// to add general fp16 compute later.
|
|
|
|
|
std::string var_name;
|
|
|
|
|
if (rhs_type_ == "float16") {
|
|
|
|
|
half2fp32_statement->append(OptimzeFP16RHS(used, index, input_ids_));
|
|
|
|
|
var_name = "half2fp32_" + TmpName(input_ids_[index]);
|
|
|
|
|
std::string refine_str =
|
|
|
|
|
RefineTemplateWithAttr(op_type_, index_str, attr_);
|
|
|
|
|
if (index_str == refine_str) {
|
|
|
|
|
int index = StringTo<int>(index_str);
|
|
|
|
|
PADDLE_ENFORCE_LT(index, input_ids_.size(),
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Only %d inputs are provided, but need %d for "
|
|
|
|
|
"operation < %s >.",
|
|
|
|
|
input_ids_.size(), index + 1, op_type_));
|
|
|
|
|
PADDLE_ENFORCE_GE(input_ids_[index], 0,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Expected %d-th input id > 0 for operation < %s "
|
|
|
|
|
">. Received %d.",
|
|
|
|
|
index, op_type_, input_ids_[index]));
|
|
|
|
|
// TODO(wangchaochaohu): Here fp16 convert to float to do comupte, we
|
|
|
|
|
// need
|
|
|
|
|
// to add general fp16 compute later.
|
|
|
|
|
std::string var_name;
|
|
|
|
|
if (rhs_type_ == "float16") {
|
|
|
|
|
half2fp32_statement->append(OptimzeFP16RHS(used, index, input_ids_));
|
|
|
|
|
var_name = "half2fp32_" + TmpName(input_ids_[index]);
|
|
|
|
|
} else {
|
|
|
|
|
var_name = TmpName(input_ids_[index]);
|
|
|
|
|
}
|
|
|
|
|
rhs.replace(pos, length + 3, var_name);
|
|
|
|
|
used->insert(input_ids_[index]);
|
|
|
|
|
} else {
|
|
|
|
|
var_name = TmpName(input_ids_[index]);
|
|
|
|
|
std::string var_name = refine_str;
|
|
|
|
|
rhs.replace(pos, length + 3, var_name);
|
|
|
|
|
}
|
|
|
|
|
rhs.replace(pos, length + 3, var_name);
|
|
|
|
|
used->insert(input_ids_[index]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return rhs;
|
|
|
|
|