|
|
|
@ -82,6 +82,7 @@ class PreparedOp {
|
|
|
|
|
framework::OperatorWithKernel::OpKernelFunc func;
|
|
|
|
|
platform::DeviceContext* dev_ctx;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class OpBase;
|
|
|
|
|
|
|
|
|
|
class VarBase {
|
|
|
|
@ -128,7 +129,11 @@ class VarBase {
|
|
|
|
|
|
|
|
|
|
class OpBase {
|
|
|
|
|
public:
|
|
|
|
|
OpBase() : op_desc_(nullptr), grad_op_desc_(nullptr) {}
|
|
|
|
|
OpBase()
|
|
|
|
|
: op_desc_(nullptr),
|
|
|
|
|
grad_op_desc_(nullptr),
|
|
|
|
|
forward_id_(-1),
|
|
|
|
|
backward_id_(-1) {}
|
|
|
|
|
|
|
|
|
|
virtual ~OpBase() {
|
|
|
|
|
if (grad_op_desc_) delete grad_op_desc_;
|
|
|
|
@ -139,6 +144,9 @@ class OpBase {
|
|
|
|
|
framework::OpDesc* op_desc_;
|
|
|
|
|
framework::OpDesc* grad_op_desc_;
|
|
|
|
|
|
|
|
|
|
int forward_id_;
|
|
|
|
|
int backward_id_;
|
|
|
|
|
|
|
|
|
|
std::map<std::string, std::vector<VarBase*>> input_vars_;
|
|
|
|
|
std::map<std::string, std::vector<VarBase*>> output_vars_;
|
|
|
|
|
std::map<std::string, std::vector<OpBase*>> pre_ops_;
|
|
|
|
@ -159,7 +167,7 @@ class Layer {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
static void CallPythonFunc(py::object* callable,
|
|
|
|
|
static void CallPythonFunc(const py::object& callable,
|
|
|
|
|
const std::vector<framework::LoDTensor>& ins,
|
|
|
|
|
std::vector<VarBase*>* outs) {
|
|
|
|
|
py::gil_scoped_acquire guard;
|
|
|
|
@ -169,7 +177,7 @@ static void CallPythonFunc(py::object* callable,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// TODO(panyx0718): Who owns the returned LoDTensor.
|
|
|
|
|
auto ret = (*callable)(in_args);
|
|
|
|
|
auto ret = callable(in_args);
|
|
|
|
|
auto ret_tuple = py::cast<py::tuple>(ret);
|
|
|
|
|
size_t ret_num = py::len(ret_tuple);
|
|
|
|
|
for (size_t i = 0; i < ret_num; ++i) {
|
|
|
|
@ -192,17 +200,10 @@ class PyLayer {
|
|
|
|
|
public:
|
|
|
|
|
virtual ~PyLayer() {}
|
|
|
|
|
|
|
|
|
|
static std::vector<VarBase*> Apply(py::object* callable,
|
|
|
|
|
const std::vector<VarBase>& inputs) {
|
|
|
|
|
std::vector<framework::LoDTensor> tensor_inputs;
|
|
|
|
|
std::vector<VarBase*> ret;
|
|
|
|
|
static void RegisterFunc(int func_id, const py::object& py_func);
|
|
|
|
|
|
|
|
|
|
for (const VarBase& in : inputs) {
|
|
|
|
|
tensor_inputs.push_back(in.var_->Get<framework::LoDTensor>());
|
|
|
|
|
}
|
|
|
|
|
CallPythonFunc(callable, tensor_inputs, &ret);
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
static std::vector<VarBase*> Apply(int func_id,
|
|
|
|
|
const std::vector<VarBase>& inputs);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace imperative
|
|
|
|
|