|
|
@ -16,6 +16,7 @@ limitations under the License. */
|
|
|
|
#include <fstream>
|
|
|
|
#include <fstream>
|
|
|
|
#include <vector>
|
|
|
|
#include <vector>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#include "paddle/framework/backward.h"
|
|
|
|
#include "paddle/framework/net.h"
|
|
|
|
#include "paddle/framework/net.h"
|
|
|
|
#include "paddle/framework/op_registry.h"
|
|
|
|
#include "paddle/framework/op_registry.h"
|
|
|
|
#include "paddle/framework/operator.h"
|
|
|
|
#include "paddle/framework/operator.h"
|
|
|
@ -45,6 +46,10 @@ template <typename ClassType>
|
|
|
|
void ExposeOperator(ClassType& m) {
|
|
|
|
void ExposeOperator(ClassType& m) {
|
|
|
|
m.def("infer_shape", &ClassType::type::InferShape)
|
|
|
|
m.def("infer_shape", &ClassType::type::InferShape)
|
|
|
|
.def("run", &ClassType::type::Run)
|
|
|
|
.def("run", &ClassType::type::Run)
|
|
|
|
|
|
|
|
.def("type",
|
|
|
|
|
|
|
|
[](const typename ClassType::type& op) -> std::string {
|
|
|
|
|
|
|
|
return op.type_;
|
|
|
|
|
|
|
|
})
|
|
|
|
.def("outputs",
|
|
|
|
.def("outputs",
|
|
|
|
[](const typename ClassType::type& op) -> std::vector<std::string> {
|
|
|
|
[](const typename ClassType::type& op) -> std::vector<std::string> {
|
|
|
|
return op.outputs_;
|
|
|
|
return op.outputs_;
|
|
|
@ -192,6 +197,13 @@ All parameter, weight, gradient are variables in Paddle.
|
|
|
|
desc.InitializationErrorString());
|
|
|
|
desc.InitializationErrorString());
|
|
|
|
return pd::OpRegistry::CreateOp(desc);
|
|
|
|
return pd::OpRegistry::CreateOp(desc);
|
|
|
|
});
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
operator_base.def("backward",
|
|
|
|
|
|
|
|
[](const pd::OperatorBase& forwardOp,
|
|
|
|
|
|
|
|
const std::unordered_set<std::string>& no_grad_vars) {
|
|
|
|
|
|
|
|
return pd::Backward(forwardOp, no_grad_vars);
|
|
|
|
|
|
|
|
});
|
|
|
|
|
|
|
|
|
|
|
|
ExposeOperator(operator_base);
|
|
|
|
ExposeOperator(operator_base);
|
|
|
|
|
|
|
|
|
|
|
|
py::class_<pd::NetOp, std::shared_ptr<pd::NetOp>> net(m, "Net");
|
|
|
|
py::class_<pd::NetOp, std::shared_ptr<pd::NetOp>> net(m, "Net");
|
|
|
|