@@ -18,11 +18,8 @@ limitations under the License. */
 #include "paddle/framework/backward.h"
 #include "paddle/framework/op_registry.h"
 #include "paddle/framework/operator.h"
 #include "paddle/framework/scope.h"
 #include "paddle/framework/tensor_py.h"
 #include "paddle/operators/net_op.h"
-#include "paddle/operators/type_alias.h"
 #include "paddle/platform/enforce.h"
 #include "paddle/platform/place.h"
 #include "pybind11/numpy.h"
@@ -45,6 +42,9 @@ USE_OP_WITHOUT_KERNEL(recurrent_op);
 USE_OP(uniform_random);

 namespace paddle {
 namespace framework {

 using Tensor = framework::Tensor;

 template <typename ClassType>
 void ExposeOperator(ClassType &m) {
   m.def("infer_shape", &ClassType::type::InferShape)
@@ -150,8 +150,8 @@ All parameter, weight, gradient are variables in Paddle.
            [](Variable &self) -> Tensor * { return self.GetMutable<Tensor>(); },
            py::return_value_policy::reference)
       .def("get_net",
-           [](Variable &self) -> ops::NetOp * {
-             return self.GetMutable<ops::NetOp>();
+           [](Variable &self) -> operators::NetOp * {
+             return self.GetMutable<operators::NetOp>();
            },
            py::return_value_policy::reference);
@@ -230,23 +230,24 @@ All parameter, weight, gradient are variables in Paddle.
   ExposeOperator(operator_base);

-  py::class_<ops::NetOp, std::shared_ptr<ops::NetOp>> net(m, "Net");
+  py::class_<operators::NetOp, std::shared_ptr<operators::NetOp>> net(m, "Net");

   net.def_static("create",
-                 []() -> std::shared_ptr<ops::NetOp> {
-                   auto retv = std::make_shared<ops::NetOp>();
+                 []() -> std::shared_ptr<operators::NetOp> {
+                   auto retv = std::make_shared<operators::NetOp>();
                    retv->type_ = "plain_net";
                    return retv;
                  })
-      .def("add_op", &ops::NetOp::AddOp)
-      .def(
-          "add_op",
-          [](ops::NetOp &self, const std::shared_ptr<ops::NetOp> &net) -> void {
-            self.AddOp(std::static_pointer_cast<OperatorBase>(net));
-          })
-      .def("complete_add_op", &ops::NetOp::CompleteAddOp)
-      .def("complete_add_op",
-           [](std::shared_ptr<ops::NetOp> &self) { self->CompleteAddOp(); });
+      .def("add_op", &operators::NetOp::AddOp)
+      .def("add_op",
+           [](operators::NetOp &self,
+              const std::shared_ptr<operators::NetOp> &net) -> void {
+             self.AddOp(std::static_pointer_cast<OperatorBase>(net));
+           })
+      .def("complete_add_op", &operators::NetOp::CompleteAddOp)
+      .def("complete_add_op", [](std::shared_ptr<operators::NetOp> &self) {
+        self->CompleteAddOp();
+      });

   ExposeOperator(net);
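
Note on why this rename is purely mechanical: the `ops::` spelling relied on a namespace alias that paddle/operators/type_alias.h is assumed to have provided, so once the alias is gone every use must name paddle::operators explicitly. A minimal standalone sketch of the mechanics (the forward declaration and alias below are illustrative stand-ins, not the real header contents):

#include <type_traits>

// Illustrative sketch only: assumes type_alias.h supplied an alias along
// these lines; the rename in the diff just spells the namespace out.
namespace paddle {
namespace operators {
class NetOp;  // stand-in forward declaration for the real NetOp
}  // namespace operators

namespace framework {
namespace ops = paddle::operators;  // the alias being retired
using Before = ops::NetOp;          // old spelling: ops::NetOp
using After = operators::NetOp;     // new spelling: operators::NetOp
static_assert(std::is_same<Before, After>::value,
              "both spellings name the same type");
}  // namespace framework
}  // namespace paddle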