|
|
|
@ -116,8 +116,19 @@ const char* OUTPUT_INITIALIZER_TEMPLATE_WITH_NULL_LIST = R"(
|
|
|
|
|
const char* ARG_OUT_NUM = R"(%sNum)";
|
|
|
|
|
const char* ARG_OUT_NUM_TYPE = R"(size_t )";
|
|
|
|
|
|
|
|
|
|
const char* VAR_TYPE = R"(std::shared_ptr<imperative::VarBase>)";
|
|
|
|
|
const char* VAR_LIST_TYPE = R"(std::vector<std::shared_ptr<imperative::VarBase>>)";
|
|
|
|
|
const char* IN_VAR_TYPE = R"(py::handle)";
|
|
|
|
|
const char* IN_VAR_LIST_TYPE = R"(py::handle)";
|
|
|
|
|
|
|
|
|
|
const char* OUT_VAR_TYPE = R"(std::shared_ptr<imperative::VarBase>)";
|
|
|
|
|
const char* OUT_VAR_LIST_TYPE = R"(std::vector<std::shared_ptr<imperative::VarBase>>)";
|
|
|
|
|
|
|
|
|
|
const char* CAST_VAR_TEMPLATE = R"(
|
|
|
|
|
auto %s = CastPyHandleToVarBase("%s", "%s", %d, %s);)";
|
|
|
|
|
|
|
|
|
|
const char* CAST_VAR_LIST_TEMPLATE = R"(
|
|
|
|
|
auto %s = CastPyHandleToVarBaseList("%s", "%s", %d, %s);)";
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const char* ARG_TEMPLATE = R"(const %s& %s)";
|
|
|
|
|
|
|
|
|
|
const char* RETURN_TUPLE_TYPE = R"(std::tuple<%s>)";
|
|
|
|
@ -133,6 +144,7 @@ const char* OP_FUNCTION_TEMPLATE =
|
|
|
|
|
R"(
|
|
|
|
|
%s %s(%s)
|
|
|
|
|
{
|
|
|
|
|
%s
|
|
|
|
|
framework::AttributeMap attrs;
|
|
|
|
|
ConstructAttrMapFromPyArgs(&attrs, args);
|
|
|
|
|
{
|
|
|
|
@ -164,6 +176,10 @@ static inline bool FindPassingOutsMap(const std::string& op_type,
|
|
|
|
|
return op_passing_outs_map[op_type].count(out_name);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static inline std::string TempName(const std::string& name) {
|
|
|
|
|
return name + '_';
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static std::tuple<std::vector<std::string>, std::vector<std::string>>
|
|
|
|
|
GenerateOpFunctions(const std::string& module_name) {
|
|
|
|
|
auto& op_info_map = paddle::framework::OpInfoMap::Instance().map();
|
|
|
|
@ -187,16 +203,24 @@ GenerateOpFunctions(const std::string& module_name) {
|
|
|
|
|
std::string ins_initializer = "{";
|
|
|
|
|
std::string ins_initializer_with_null = "";
|
|
|
|
|
std::string py_arg = "";
|
|
|
|
|
int arg_idx = 0;
|
|
|
|
|
std::string ins_cast_str = "";
|
|
|
|
|
for (auto& input : op_proto->inputs()) {
|
|
|
|
|
auto& in_name = input.name();
|
|
|
|
|
// skip those dispensable inputs, like ResidualData in conv2d
|
|
|
|
|
if (input.dispensable() && !FindInsMap(op_type, in_name)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
const auto in_type = input.duplicable() ? VAR_LIST_TYPE : VAR_TYPE;
|
|
|
|
|
auto input_arg = paddle::string::Sprintf(ARG_TEMPLATE, in_type, in_name);
|
|
|
|
|
const auto in_type = input.duplicable() ? IN_VAR_LIST_TYPE : IN_VAR_TYPE;
|
|
|
|
|
auto input_arg =
|
|
|
|
|
paddle::string::Sprintf(ARG_TEMPLATE, in_type, TempName(in_name));
|
|
|
|
|
input_args += input_arg;
|
|
|
|
|
input_args += ",";
|
|
|
|
|
const auto in_cast_type =
|
|
|
|
|
input.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE;
|
|
|
|
|
ins_cast_str +=
|
|
|
|
|
paddle::string::Sprintf(in_cast_type, in_name, op_type, in_name,
|
|
|
|
|
arg_idx++, TempName(in_name));
|
|
|
|
|
|
|
|
|
|
if (input.dispensable()) {
|
|
|
|
|
const auto in_template = input.duplicable()
|
|
|
|
@ -235,7 +259,8 @@ GenerateOpFunctions(const std::string& module_name) {
|
|
|
|
|
if (output.dispensable() && !FindOutsMap(op_type, out_name)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
const auto out_type = output.duplicable() ? VAR_LIST_TYPE : VAR_TYPE;
|
|
|
|
|
const auto out_type =
|
|
|
|
|
output.duplicable() ? OUT_VAR_LIST_TYPE : OUT_VAR_TYPE;
|
|
|
|
|
const auto return_template =
|
|
|
|
|
output.duplicable() ? RETURN_LIST_TEMPLATE : RETURN_TEMPLATE;
|
|
|
|
|
if (FindPassingOutsMap(op_type, out_name)) {
|
|
|
|
@ -309,7 +334,7 @@ GenerateOpFunctions(const std::string& module_name) {
|
|
|
|
|
// generate op funtcion body
|
|
|
|
|
auto op_function_str = paddle::string::Sprintf(
|
|
|
|
|
OP_FUNCTION_TEMPLATE, return_type, func_name, function_args,
|
|
|
|
|
outs_initializer, ins_initializer,
|
|
|
|
|
ins_cast_str, outs_initializer, ins_initializer,
|
|
|
|
|
ins_initializer_with_null + outs_initializer_with_null, op_type,
|
|
|
|
|
return_str);
|
|
|
|
|
|
|
|
|
|