@@ -48,29 +48,6 @@ namespace framework {
using Tensor = framework::Tensor;
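// ExposeOperator attaches the members every operator wrapper shares. pybind11's
// py::class_ exposes the wrapped C++ type as the member typedef `type`, so
// &ClassType::type::InferShape resolves to e.g. &OperatorBase::InferShape when
// the template is instantiated with the "Operator" binding further below.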
template <typename ClassType>
void ExposeOperator(ClassType &m) {
  m.def("infer_shape", &ClassType::type::InferShape)
      .def("run", &ClassType::type::Run)
      .def("type",
           [](const typename ClassType::type &op) -> std::string {
             return op.Type();
           })
      .def("outputs",
           [](const typename ClassType::type &op)
               -> std::map<std::string, std::vector<std::string>> {
             return op.Outputs();
           })
      .def("inputs",
           [](const typename ClassType::type &op) { return op.Inputs(); })
      .def("__str__", &ClassType::type::DebugString)
      .def("no_intermediate_outputs",
           [](const typename ClassType::type &op) {
             return op.OutputVars(false);
           })
      .def("support_gpu", &ClassType::type::SupportGPU);
}

static size_t UniqueIntegerGenerator() {
  static std::atomic<size_t> generator;
  return generator.fetch_add(1);
@@ -207,75 +184,69 @@ All parameter, weight, gradient are variables in Paddle.
      .def(py::init<>())
      .def("__str__", string::to_string<const platform::CPUPlace &>);

  py::class_<OperatorBase, std::shared_ptr<OperatorBase>> operator_base(
      m, "Operator");

  operator_base.def_static("create", [](py::bytes protobin) {
    OpDesc desc;
    PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
                   "Cannot parse user input to OpDesc");
    PADDLE_ENFORCE(desc.IsInitialized(),
                   "User OpDesc is not initialized, reason %s",
                   desc.InitializationErrorString());
    return OpRegistry::CreateOp(desc);
  });

  operator_base.def("backward",
                    [](const OperatorBase &forwardOp,
                       const std::unordered_set<std::string> &no_grad_vars) {
                      return Backward(forwardOp, no_grad_vars);
                    });

  ExposeOperator(operator_base);

  py::class_<operators::NetOp, std::shared_ptr<operators::NetOp>> net(m, "Net");

  net.def_static("create",
                 []() -> std::shared_ptr<operators::NetOp> {
                   auto retv = std::make_shared<operators::NetOp>();
                   retv->SetType("plain_net");
                   return retv;
                 })
      .def("add_op", &operators::NetOp::AddOp)
      .def("add_op",
           [](operators::NetOp &self,
              const std::shared_ptr<operators::NetOp> &net) -> void {
             self.AddOp(std::static_pointer_cast<OperatorBase>(net));
           })
      .def("add_op",
           [](operators::NetOp &self,
              const std::shared_ptr<operators::RecurrentOp> &rnn) -> void {
             self.AddOp(std::static_pointer_cast<OperatorBase>(rnn));
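  // The bindings from here on are the unique_ptr-based replacements for the
  // shared_ptr registrations above: with no explicit holder type, pybind11
  // defaults to std::unique_ptr, and Backward(...).release() hands the raw
  // gradient-network pointer straight to Python, which then owns it.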
  py::class_<OperatorBase>(m, "Operator")
      .def_static("create",
                  [](py::bytes protobin) {
                    OpDesc desc;
                    PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
                                   "Cannot parse user input to OpDesc");
                    PADDLE_ENFORCE(desc.IsInitialized(),
                                   "User OpDesc is not initialized, reason %s",
                                   desc.InitializationErrorString());
                    return OpRegistry::CreateOp(desc);
                  })
      .def("backward",
           [](const OperatorBase &forwardOp,
              const std::unordered_set<std::string> &no_grad_vars) {
             return Backward(forwardOp, no_grad_vars).release();
           })
      .def("infer_shape", &OperatorBase::InferShape)
      .def("run", &OperatorBase::Run)
      .def("type",
           [](const OperatorBase &op) -> std::string { return op.Type(); })
      .def("outputs",
           [](const OperatorBase &op)
               -> std::map<std::string, std::vector<std::string>> {
             return op.Outputs();
           })
      .def("inputs", [](const OperatorBase &op) { return op.Inputs(); })
      .def("__str__", &OperatorBase::DebugString)
      .def("no_intermediate_outputs",
           [](const OperatorBase &op) { return op.OutputVars(false); })
      .def("support_gpu", &OperatorBase::SupportGPU);

  py::class_<operators::NetOp, OperatorBase>(m, "Net")
      .def_static("create",
                  []() -> operators::NetOp * {
                    auto *retv = new operators::NetOp;
                    retv->SetType("plain_net");
                    return retv;
                  })
      .def("add_op", [](operators::NetOp &self,
                        const OperatorBase &op) { self.AddOp(op); })
      .def("complete_add_op", &operators::NetOp::CompleteAddOp)
      .def("complete_add_op", [](std::shared_ptr<operators::NetOp> &self) {
        self->CompleteAddOp();
      });

  ExposeOperator(net);

  // recurrent_op
  py::class_<operators::RecurrentOp, std::shared_ptr<operators::RecurrentOp>>
      rnn(m, "RecurrentOp");

  rnn.def_static(
      "create",
      [](py::bytes protobin) -> std::shared_ptr<operators::RecurrentOp> {
        OpDesc desc;
        PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
                       "Cannot parse user input to OpDesc");
        PADDLE_ENFORCE(desc.IsInitialized(),
                       "User OpDesc is not initialized, reason %s",
                       desc.InitializationErrorString());
        auto rnn_op = OpRegistry::CreateOp(desc);
        return std::dynamic_pointer_cast<operators::RecurrentOp>(rnn_op);
      })
      .def("set_stepnet",
           [](operators::RecurrentOp &self,
              const std::shared_ptr<operators::NetOp> &net) -> void {
             self.set_stepnet(net);
           });
  ExposeOperator(rnn);
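  // Same change for RecurrentOp: the shared_ptr binding above downcasts the
  // result of OpRegistry::CreateOp with std::dynamic_pointer_cast, while the
  // replacement below releases the unique_ptr and static_casts the raw pointer
  // before handing it to Python.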
  py::class_<operators::RecurrentOp, OperatorBase>(m, "RecurrentOp")
      .def_static(
          "create",
          [](py::bytes protobin) -> operators::RecurrentOp * {
            OpDesc desc;
            PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
                           "Cannot parse user input to OpDesc");
            PADDLE_ENFORCE(desc.IsInitialized(),
                           "User OpDesc is not initialized, reason %s",
                           desc.InitializationErrorString());
            auto rnn_op = OpRegistry::CreateOp(desc);
            return static_cast<operators::RecurrentOp *>(rnn_op.release());
          })
      .def("set_stepnet", [](operators::RecurrentOp &self,
                             const operators::NetOp &net) -> void {
        self.set_stepnet(net.Clone());
      });

  m.def("unique_integer", UniqueIntegerGenerator);
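For readers less familiar with the pybind11 idiom used throughout this patch, the short sketch below shows the ownership pattern in isolation; the ToyOp type and toy_core module name are invented for illustration and are not part of Paddle. A py::class_ declared without a holder argument defaults to std::unique_ptr, so a def_static factory can transfer ownership to Python simply by returning a raw pointer, which is what the release() calls above rely on.

#include <memory>
#include <string>

#include <pybind11/pybind11.h>

namespace py = pybind11;

// Stand-in for OperatorBase, only here to keep the sketch self-contained.
struct ToyOp {
  std::string type_{"toy"};
  const std::string &Type() const { return type_; }
};

PYBIND11_MODULE(toy_core, m) {
  // No holder argument, so pybind11 manages ToyOp through std::unique_ptr.
  py::class_<ToyOp>(m, "Operator")
      .def_static("create",
                  []() {
                    auto op = std::make_unique<ToyOp>();
                    return op.release();  // Python takes ownership here.
                  })
      .def("type", [](const ToyOp &op) { return op.Type(); });
}

From Python, toy_core.Operator.create() then returns an object whose lifetime is managed by the interpreter, mirroring how the Operator, Net, and RecurrentOp factories above hand their results to Python.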