|
|
|
@ -17,6 +17,9 @@
|
|
|
|
|
#include <map>
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include "pybind11/pybind11.h"
|
|
|
|
|
|
|
|
|
|
#include "Python.h"
|
|
|
|
|
#include "paddle/fluid/framework/op_desc.h"
|
|
|
|
|
#include "paddle/fluid/framework/operator.h"
|
|
|
|
|
#include "paddle/fluid/framework/var_desc.h"
|
|
|
|
@ -25,6 +28,8 @@
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace imperative {
|
|
|
|
|
|
|
|
|
|
namespace py = ::pybind11;
|
|
|
|
|
|
|
|
|
|
class PreparedOp {
|
|
|
|
|
public:
|
|
|
|
|
PreparedOp(const framework::OperatorBase& op,
|
|
|
|
@ -152,10 +157,48 @@ class Layer {
|
|
|
|
|
std::vector<VarBase> vars;
|
|
|
|
|
return vars;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
virtual std::vector<VarBase> Backward(const std::vector<VarBase>& inputs) {
|
|
|
|
|
std::vector<VarBase> vars;
|
|
|
|
|
return vars;
|
|
|
|
|
static void CallPythonFunc(py::object* callable,
|
|
|
|
|
const std::vector<framework::LoDTensor>& ins,
|
|
|
|
|
std::vector<framework::LoDTensor*>* outs) {
|
|
|
|
|
py::gil_scoped_acquire guard;
|
|
|
|
|
py::tuple in_args(ins.size());
|
|
|
|
|
for (size_t i = 0; i < ins.size(); ++i) {
|
|
|
|
|
in_args[i] = ins[i].IsInitialized() ? py::cast(ins[i]) : py::cast(nullptr);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto ret = (*callable)(in_args);
|
|
|
|
|
auto ret_tuple = py::cast<py::tuple>(ret);
|
|
|
|
|
size_t ret_num = py::len(ret_tuple);
|
|
|
|
|
for (size_t i = 0; i < ret_num; ++i) {
|
|
|
|
|
try {
|
|
|
|
|
auto* py_out_tensor = py::cast<framework::LoDTensor*>(ret_tuple[i]);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(py_out_tensor,
|
|
|
|
|
"Output tensor %d should not be nullptr", i);
|
|
|
|
|
outs->push_back(py_out_tensor);
|
|
|
|
|
} catch (py::cast_error&) {
|
|
|
|
|
PADDLE_THROW("The %d-th output must be LoDTensor", i);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
class PyLayer {
|
|
|
|
|
public:
|
|
|
|
|
virtual ~PyLayer() {}
|
|
|
|
|
|
|
|
|
|
static std::vector<VarBase> Apply(py::object* callable,
|
|
|
|
|
const std::vector<VarBase>& inputs) {
|
|
|
|
|
std::vector<VarBase> outputs;
|
|
|
|
|
std::vector<framework::LoDTensor> tensor_inputs;
|
|
|
|
|
std::vector<framework::LoDTensor*> tensor_outputs;
|
|
|
|
|
|
|
|
|
|
for (const VarBase& in : inputs) {
|
|
|
|
|
tensor_inputs.push_back(in.var_->Get<framework::LoDTensor>());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
CallPythonFunc(callable, tensor_inputs, &tensor_outputs);
|
|
|
|
|
return outputs;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|