|
|
|
@ -161,13 +161,14 @@ class Layer {
|
|
|
|
|
|
|
|
|
|
static void CallPythonFunc(py::object* callable,
|
|
|
|
|
const std::vector<framework::LoDTensor>& ins,
|
|
|
|
|
std::vector<framework::LoDTensor*>* outs) {
|
|
|
|
|
std::vector<VarBase*>* outs) {
|
|
|
|
|
py::gil_scoped_acquire guard;
|
|
|
|
|
py::tuple in_args(ins.size());
|
|
|
|
|
for (size_t i = 0; i < ins.size(); ++i) {
|
|
|
|
|
in_args[i] = ins[i].IsInitialized() ? py::cast(ins[i]) : py::cast(nullptr);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// TODO(panyx0718): Who owns the returned LoDTensor.
|
|
|
|
|
auto ret = (*callable)(in_args);
|
|
|
|
|
auto ret_tuple = py::cast<py::tuple>(ret);
|
|
|
|
|
size_t ret_num = py::len(ret_tuple);
|
|
|
|
@ -176,7 +177,11 @@ static void CallPythonFunc(py::object* callable,
|
|
|
|
|
auto* py_out_tensor = py::cast<framework::LoDTensor*>(ret_tuple[i]);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(py_out_tensor,
|
|
|
|
|
"Output tensor %d should not be nullptr", i);
|
|
|
|
|
outs->push_back(py_out_tensor);
|
|
|
|
|
VarBase* var = new VarBase();
|
|
|
|
|
auto* tensor = var->var_->GetMutable<framework::LoDTensor>();
|
|
|
|
|
tensor->ShareDataWith(*py_out_tensor);
|
|
|
|
|
tensor->set_lod(py_out_tensor->lod());
|
|
|
|
|
outs->push_back(var);
|
|
|
|
|
} catch (py::cast_error&) {
|
|
|
|
|
PADDLE_THROW("The %d-th output must be LoDTensor", i);
|
|
|
|
|
}
|
|
|
|
@ -187,18 +192,16 @@ class PyLayer {
|
|
|
|
|
public:
|
|
|
|
|
virtual ~PyLayer() {}
|
|
|
|
|
|
|
|
|
|
static std::vector<VarBase> Apply(py::object* callable,
|
|
|
|
|
const std::vector<VarBase>& inputs) {
|
|
|
|
|
std::vector<VarBase> outputs;
|
|
|
|
|
static std::vector<VarBase*> Apply(py::object* callable,
|
|
|
|
|
const std::vector<VarBase>& inputs) {
|
|
|
|
|
std::vector<framework::LoDTensor> tensor_inputs;
|
|
|
|
|
std::vector<framework::LoDTensor*> tensor_outputs;
|
|
|
|
|
std::vector<VarBase*> ret;
|
|
|
|
|
|
|
|
|
|
for (const VarBase& in : inputs) {
|
|
|
|
|
tensor_inputs.push_back(in.var_->Get<framework::LoDTensor>());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
CallPythonFunc(callable, tensor_inputs, &tensor_outputs);
|
|
|
|
|
return outputs;
|
|
|
|
|
CallPythonFunc(callable, tensor_inputs, &ret);
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|