|
|
@ -13,15 +13,16 @@ See the License for the specific language governing permissions and
|
|
|
|
limitations under the License. */
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
|
|
#include <Python.h>
|
|
|
|
#include <Python.h>
|
|
|
|
#include <paddle/framework/op_registry.h>
|
|
|
|
|
|
|
|
#include <paddle/framework/operator.h>
|
|
|
|
|
|
|
|
#include <paddle/framework/scope.h>
|
|
|
|
|
|
|
|
#include <paddle/pybind/tensor_bind.h>
|
|
|
|
|
|
|
|
#include <pybind11/numpy.h>
|
|
|
|
|
|
|
|
#include <pybind11/pybind11.h>
|
|
|
|
|
|
|
|
#include <pybind11/stl.h>
|
|
|
|
|
|
|
|
#include <fstream>
|
|
|
|
#include <fstream>
|
|
|
|
#include <vector>
|
|
|
|
#include <vector>
|
|
|
|
|
|
|
|
#include "paddle/framework/net.h"
|
|
|
|
|
|
|
|
#include "paddle/framework/op_registry.h"
|
|
|
|
|
|
|
|
#include "paddle/framework/operator.h"
|
|
|
|
|
|
|
|
#include "paddle/framework/scope.h"
|
|
|
|
|
|
|
|
#include "paddle/pybind/tensor_bind.h"
|
|
|
|
|
|
|
|
#include "pybind11/numpy.h"
|
|
|
|
|
|
|
|
#include "pybind11/pybind11.h"
|
|
|
|
|
|
|
|
#include "pybind11/stl.h"
|
|
|
|
|
|
|
|
|
|
|
|
namespace py = pybind11;
|
|
|
|
namespace py = pybind11;
|
|
|
|
namespace pd = paddle::framework;
|
|
|
|
namespace pd = paddle::framework;
|
|
|
@ -29,6 +30,17 @@ namespace pd = paddle::framework;
|
|
|
|
USE_OP(add_two);
|
|
|
|
USE_OP(add_two);
|
|
|
|
USE_OP_WITHOUT_KERNEL(fc);
|
|
|
|
USE_OP_WITHOUT_KERNEL(fc);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <typename ClassType>
|
|
|
|
|
|
|
|
void ExposeOperator(ClassType& m) {
|
|
|
|
|
|
|
|
m.def("infer_shape", &ClassType::type::InferShape)
|
|
|
|
|
|
|
|
.def("run", &ClassType::type::Run)
|
|
|
|
|
|
|
|
.def("outputs",
|
|
|
|
|
|
|
|
[](const typename ClassType::type& op) -> std::vector<std::string> {
|
|
|
|
|
|
|
|
return op.outputs_;
|
|
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
.def("__str__", &ClassType::type::DebugString);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
PYBIND11_PLUGIN(core) {
|
|
|
|
PYBIND11_PLUGIN(core) {
|
|
|
|
py::module m("core", "C++ core of Paddle Paddle");
|
|
|
|
py::module m("core", "C++ core of Paddle Paddle");
|
|
|
|
|
|
|
|
|
|
|
@ -107,21 +119,36 @@ All parameter, weight, gradient are variables in Paddle.
|
|
|
|
return new paddle::platform::CPUDeviceContext();
|
|
|
|
return new paddle::platform::CPUDeviceContext();
|
|
|
|
});
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
|
|
py::class_<pd::OperatorBase, pd::OperatorPtr>(m, "Operator")
|
|
|
|
py::class_<pd::OperatorBase, pd::OperatorPtr> operator_base(m, "Operator");
|
|
|
|
.def("__str__", &pd::OperatorBase::DebugString)
|
|
|
|
|
|
|
|
.def_static("create",
|
|
|
|
operator_base.def_static("create", [](py::bytes protobin) -> pd::OperatorPtr {
|
|
|
|
[](py::bytes protobin) {
|
|
|
|
pd::OpDesc desc;
|
|
|
|
pd::OpDesc desc;
|
|
|
|
PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
|
|
|
|
PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
|
|
|
|
"Cannot parse user input to OpDesc");
|
|
|
|
"Cannot parse user input to OpDesc");
|
|
|
|
PADDLE_ENFORCE(desc.IsInitialized(),
|
|
|
|
PADDLE_ENFORCE(desc.IsInitialized(),
|
|
|
|
"User OpDesc is not initialized, reason %s",
|
|
|
|
"User OpDesc is not initialized, reason %s",
|
|
|
|
desc.InitializationErrorString());
|
|
|
|
desc.InitializationErrorString());
|
|
|
|
return pd::OpRegistry::CreateOp(desc);
|
|
|
|
return pd::OpRegistry::CreateOp(desc);
|
|
|
|
});
|
|
|
|
})
|
|
|
|
ExposeOperator(operator_base);
|
|
|
|
.def("infer_shape", &pd::OperatorBase::InferShape)
|
|
|
|
|
|
|
|
.def("run", &pd::OperatorBase::Run)
|
|
|
|
using PlainNetPtr = std::shared_ptr<pd::PlainNet>;
|
|
|
|
.def("outputs", [](const pd::OperatorPtr& op) { return op->outputs_; });
|
|
|
|
py::class_<pd::PlainNet, PlainNetPtr> net(m, "Net");
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
net.def_static("create",
|
|
|
|
|
|
|
|
[]() -> std::shared_ptr<pd::PlainNet> {
|
|
|
|
|
|
|
|
auto retv = std::make_shared<pd::PlainNet>();
|
|
|
|
|
|
|
|
retv->type_ = "naive_net";
|
|
|
|
|
|
|
|
return retv;
|
|
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
.def("add_op", &pd::PlainNet::AddOp)
|
|
|
|
|
|
|
|
.def("add_op",
|
|
|
|
|
|
|
|
[](PlainNetPtr& self, const PlainNetPtr& net) -> void {
|
|
|
|
|
|
|
|
self->AddOp(std::static_pointer_cast<pd::OperatorBase>(net));
|
|
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
.def("complete_add_op", &pd::PlainNet::CompleteAddOp)
|
|
|
|
|
|
|
|
.def("complete_add_op", [](PlainNetPtr& self) { self->CompleteAddOp(); });
|
|
|
|
|
|
|
|
ExposeOperator(net);
|
|
|
|
|
|
|
|
|
|
|
|
return m.ptr();
|
|
|
|
return m.ptr();
|
|
|
|
}
|
|
|
|
}
|
|
|
|