|
|
|
@ -23,45 +23,46 @@
|
|
|
|
|
#include "paddle/fluid/pybind/pybind.h"
|
|
|
|
|
#include "paddle/fluid/string/string_helper.h"
|
|
|
|
|
|
|
|
|
|
// clang-format off
|
|
|
|
|
const char* OUT_INITIALIZER_TEMPLATE =
|
|
|
|
|
R"({"%s", {std::shared_ptr<imperative::VarBase>(new imperative::VarBase(tracer->GenerateUniqueName()))}})";
|
|
|
|
|
|
|
|
|
|
const char* OP_FUNCTION_TEMPLATE =
|
|
|
|
|
R"([](const imperative::NameVarBaseMap& ins, const framework::AttributeMap& attrs,
|
|
|
|
|
imperative::NameVarBaseMap outs, const std::map<std::string, size_t>& out_nums)
|
|
|
|
|
{
|
|
|
|
|
auto tracer = imperative::GetCurrentTracer();
|
|
|
|
|
if (outs.size() == 0) {
|
|
|
|
|
if (out_nums.size() == 0) {
|
|
|
|
|
imperative::NameVarBaseMap outs_ = %s;
|
|
|
|
|
outs = std::move(outs_);
|
|
|
|
|
} else {
|
|
|
|
|
for (auto &pair : out_nums) {
|
|
|
|
|
for (size_t i = 0; i < pair.second; i ++) {
|
|
|
|
|
auto var_base_name = tracer->GenerateUniqueName();
|
|
|
|
|
auto out = new imperative::VarBase(var_base_name);
|
|
|
|
|
outs[pair.first].emplace_back(std::shared_ptr<imperative::VarBase>(out));
|
|
|
|
|
}
|
|
|
|
|
R"(
|
|
|
|
|
inline imperative::NameVarBaseMap %s(const imperative::NameVarBaseMap& ins, const framework::AttributeMap& attrs,
|
|
|
|
|
imperative::NameVarBaseMap outs, const std::map<std::string, size_t>& out_nums)
|
|
|
|
|
{
|
|
|
|
|
auto tracer = imperative::GetCurrentTracer();
|
|
|
|
|
if (outs.size() == 0) {
|
|
|
|
|
if (out_nums.size() == 0) {
|
|
|
|
|
imperative::NameVarBaseMap outs_ = %s;
|
|
|
|
|
outs = std::move(outs_);
|
|
|
|
|
} else {
|
|
|
|
|
for (auto &pair : out_nums) {
|
|
|
|
|
for (size_t i = 0; i < pair.second; i ++) {
|
|
|
|
|
auto var_base_name = tracer->GenerateUniqueName();
|
|
|
|
|
outs[pair.first].emplace_back(new imperative::VarBase(var_base_name));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
{
|
|
|
|
|
py::gil_scoped_release release;
|
|
|
|
|
tracer->TraceOp("%s", std::move(ins), std::move(outs), std::move(attrs));
|
|
|
|
|
return outs;
|
|
|
|
|
}
|
|
|
|
|
}, py::arg("ins"), py::arg("attrs")=framework::AttributeMap(),
|
|
|
|
|
py::arg("outs")=imperative::NameVarBaseMap(),
|
|
|
|
|
py::arg("out_nums")=std::map<std::string, size_t>())";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
tracer->TraceOp("%s", std::move(ins), std::move(outs), std::move(attrs));
|
|
|
|
|
return outs;
|
|
|
|
|
})";
|
|
|
|
|
|
|
|
|
|
const char* PYBIND_ITEM_TEMPLATE = R"( %s.def("%s", %s);)";
|
|
|
|
|
const char* PYBIND_ITEM_TEMPLATE =
|
|
|
|
|
R"(
|
|
|
|
|
%s.def("%s", &%s, py::arg("ins"), py::arg("attrs")=framework::AttributeMap(), py::arg("outs")=imperative::NameVarBaseMap(),
|
|
|
|
|
py::arg("out_nums")=std::map<std::string, size_t>(), py::call_guard<py::gil_scoped_release>());)";
|
|
|
|
|
|
|
|
|
|
static std::vector<std::string> GenerateOpFunctions(
|
|
|
|
|
const std::string& module_name) {
|
|
|
|
|
// clang-format on
|
|
|
|
|
|
|
|
|
|
static std::tuple<std::vector<std::string>, std::vector<std::string>>
|
|
|
|
|
GenerateOpFunctions(const std::string& module_name) {
|
|
|
|
|
auto& op_info_map = paddle::framework::OpInfoMap::Instance().map();
|
|
|
|
|
|
|
|
|
|
std::vector<std::string> op_function_list;
|
|
|
|
|
std::vector<std::string> op_function_list, bind_function_list;
|
|
|
|
|
for (auto& pair : op_info_map) {
|
|
|
|
|
auto& op_info = pair.second;
|
|
|
|
|
auto op_proto = op_info.proto_;
|
|
|
|
@ -85,18 +86,21 @@ static std::vector<std::string> GenerateOpFunctions(
|
|
|
|
|
}
|
|
|
|
|
outs_initializer += "}";
|
|
|
|
|
|
|
|
|
|
std::string func_name = "imperative_" + op_type;
|
|
|
|
|
|
|
|
|
|
// generate op funtcion body
|
|
|
|
|
auto op_function_str = paddle::string::Sprintf(OP_FUNCTION_TEMPLATE,
|
|
|
|
|
outs_initializer, op_type);
|
|
|
|
|
auto op_function_str = paddle::string::Sprintf(
|
|
|
|
|
OP_FUNCTION_TEMPLATE, func_name, outs_initializer, op_type);
|
|
|
|
|
|
|
|
|
|
// generate pybind item
|
|
|
|
|
auto pybind_op_function = paddle::string::Sprintf(
|
|
|
|
|
PYBIND_ITEM_TEMPLATE, module_name.c_str(), op_type, op_function_str);
|
|
|
|
|
pybind_op_function += "\n";
|
|
|
|
|
op_function_list.emplace_back(std::move(pybind_op_function));
|
|
|
|
|
auto bind_function_str = paddle::string::Sprintf(
|
|
|
|
|
PYBIND_ITEM_TEMPLATE, module_name, op_type, func_name);
|
|
|
|
|
|
|
|
|
|
op_function_list.emplace_back(std::move(op_function_str));
|
|
|
|
|
bind_function_list.emplace_back(std::move(bind_function_str));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return op_function_list;
|
|
|
|
|
return std::make_tuple(op_function_list, bind_function_list);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int main(int argc, char* argv[]) {
|
|
|
|
@ -115,19 +119,21 @@ int main(int argc, char* argv[]) {
|
|
|
|
|
out << "#include " + header + "\n";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// all op functions
|
|
|
|
|
auto op_funcs = GenerateOpFunctions("m");
|
|
|
|
|
|
|
|
|
|
out << "namespace py = pybind11;"
|
|
|
|
|
<< "\n";
|
|
|
|
|
out << "namespace paddle {\n"
|
|
|
|
|
<< "namespace pybind {\n"
|
|
|
|
|
<< "\n"
|
|
|
|
|
<< "inline void BindOpFunctions(pybind11::module *module) {\n"
|
|
|
|
|
<< " auto m = module->def_submodule(\"ops\");\n\n";
|
|
|
|
|
|
|
|
|
|
// all op functions
|
|
|
|
|
auto op_funcs = GenerateOpFunctions("m");
|
|
|
|
|
<< "namespace pybind {\n";
|
|
|
|
|
out << paddle::string::join_strings(std::get<0>(op_funcs), '\n');
|
|
|
|
|
out << "\n\n";
|
|
|
|
|
|
|
|
|
|
out << paddle::string::join_strings(op_funcs, '\n');
|
|
|
|
|
out << "inline void BindOpFunctions(pybind11::module *module) {\n"
|
|
|
|
|
<< " auto m = module->def_submodule(\"ops\");\n\n";
|
|
|
|
|
|
|
|
|
|
out << paddle::string::join_strings(std::get<1>(op_funcs), '\n');
|
|
|
|
|
out << "\n";
|
|
|
|
|
out << "}\n\n"
|
|
|
|
|
<< "} // namespace pybind\n"
|
|
|
|
|
<< "} // namespace paddle\n";
|
|
|
|
|