@@ -35,6 +35,9 @@ size_t AppendPythonCallableObjectAndReturnId(const py::object &py_obj) {
   return g_py_callables.size() - 1;
 }
 
+// Return py::object* instead of py::object.
+// Returning py::object would increase the reference count of the callable,
+// but without the GIL held, touching a Python reference count is not safe.
 static py::object *GetPythonCallableObject(size_t i) {
   PADDLE_ENFORCE_LT(i, g_py_callables.size(), "Invalid python callable id");
   return &g_py_callables[i];
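
The comment block above is the heart of this hunk. A minimal standalone sketch (plain pybind11, not Paddle code; the registry and function names are mine) of the distinction it describes:

```cpp
#include <pybind11/pybind11.h>

#include <vector>

namespace py = pybind11;

static std::vector<py::object> g_registry;

// Returning by value copies the handle: py::object's copy constructor
// calls Py_INCREF, which is only legal while the GIL is held.
py::object GetByValue(size_t i) { return g_registry[i]; }

// Returning a pointer into the registry touches no reference count,
// so the lookup itself is safe even on a thread that has released the GIL.
py::object *GetByPointer(size_t i) { return &g_registry[i]; }
```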
@@ -47,7 +50,7 @@ static std::string PythonObjectToString(const py::object &py_callable) {
 
 static void CallPythonFunc(py::object *callable,
                            const std::vector<framework::LoDTensor> &ins,
-                           std::vector<framework::LoDTensor *> *out) {
+                           std::vector<framework::LoDTensor *> *outs) {
   py::gil_scoped_acquire guard;
   py::tuple in_args(ins.size());
   for (size_t i = 0; i < ins.size(); ++i) {
@@ -57,8 +60,8 @@ static void CallPythonFunc(py::object *callable,
   auto ret = (*callable)(*in_args);
   auto ret_tuple = py::cast<py::tuple>(ret);
   size_t ret_num = py::len(ret_tuple);
-  size_t out_num = out->size();
-  if (ret_num != out_num) {
+  size_t out_num = outs->size();
+  if (UNLIKELY(ret_num != out_num)) {
     // Python function has no return values or returns None
     // In this case, ret_num = 1 && ret[0] == None && out_num should be 0
     // Otherwise, ret_num must be equal to out_num
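
The convention in the three context comments above is checkable in one expression. A sketch under the same pybind11 types (the helper name is mine):

```cpp
#include <pybind11/pybind11.h>

namespace py = pybind11;

// A Python function with no return statement actually returns None, so the
// wrapped result arrives as a one-element tuple holding None. That is only
// consistent with an op that declared zero outputs.
bool IsVoidReturn(const py::tuple &ret_tuple, size_t out_num) {
  return py::len(ret_tuple) == 1 && ret_tuple[0].is_none() && out_num == 0;
}
```

UNLIKELY is Paddle's branch-prediction hint (presumably expanding to __builtin_expect on GCC-compatible compilers); the mismatch path is the error path, so marking it cold is a sensible micro-optimization.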
@@ -69,17 +72,18 @@ static void CallPythonFunc(py::object *callable,
   }
 
   for (size_t i = 0; i < out_num; ++i) {
-    if ((*out)[i] == nullptr) {
+    auto *out = (*outs)[i];
+    if (out == nullptr) {
       continue;
     }
     try {
-      auto *out_tensor = py::cast<framework::LoDTensor *>(ret_tuple[i]);
-      PADDLE_ENFORCE_NOT_NULL(out_tensor,
+      auto *py_out_tensor = py::cast<framework::LoDTensor *>(ret_tuple[i]);
+      PADDLE_ENFORCE_NOT_NULL(py_out_tensor,
                               "Output tensor %d should not be nullptr", i);
-      (*out)[i]->set_lod(out_tensor->lod());
-      (*out)[i]->ShareDataWith(*out_tensor);
+      out->set_lod(py_out_tensor->lod());
+      out->ShareDataWith(*py_out_tensor);
     } catch (py::cast_error &) {
-      PADDLE_THROW("Output %d is not LoDTensor", i);
+      PADDLE_THROW("The %d-th output must be LoDTensor", i);
     }
   }
 }
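
After the rename the loop reads cleanly: take the i-th requested output, skip unused slots, cast the Python result back to a tensor, then share LoD and data. A self-contained sketch of the same pattern; `Tensor` is a hypothetical stand-in for framework::LoDTensor (which Paddle binds via pybind11 elsewhere), and the throws stand in for PADDLE_ENFORCE/PADDLE_THROW:

```cpp
#include <pybind11/pybind11.h>

#include <stdexcept>
#include <string>
#include <vector>

namespace py = pybind11;

struct Tensor {};  // stand-in; must be bound with py::class_ to be castable

void UnpackOutputs(const py::tuple &ret_tuple, std::vector<Tensor *> *outs) {
  for (size_t i = 0; i < outs->size(); ++i) {
    auto *out = (*outs)[i];
    if (out == nullptr) {
      continue;  // this output slot was not requested by the caller
    }
    try {
      // py::cast_error is thrown here if Python returned the wrong type
      auto *py_out = py::cast<Tensor *>(ret_tuple[i]);
      if (py_out == nullptr) {
        throw std::runtime_error("output " + std::to_string(i) + " is null");
      }
      // the real op now copies LoD metadata and shares the buffer:
      //   out->set_lod(py_out->lod());
      //   out->ShareDataWith(*py_out);
    } catch (py::cast_error &) {
      throw std::runtime_error("output " + std::to_string(i) +
                               " is not a tensor");
    }
  }
}
```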
@@ -94,6 +98,10 @@ class PyFuncOpShapeInference : public framework::InferShapeBase {
     PADDLE_ENFORCE_GE(ctx->Attrs().Get<int>(kForwardPythonCallableId), 0,
                       "Function id cannot be less than 0");
 
+    // Traverse all outputs.
+    // If the name of any output ends with @GRAD,
+    // set its shape, dtype, lod_level and type to be the same as
+    // the corresponding forward variable.
     auto *op = boost::get<const framework::OpDesc *>(ctx->GetOp());
     auto *block = op->Block();
     const std::string kGradVarSuffix = framework::kGradVarSuffix;
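
The traversal relies on Paddle's naming convention for gradient variables: the gradient of x is named x@GRAD (framework::kGradVarSuffix). A small sketch of the suffix check this loop performs; the helper is mine:

```cpp
#include <string>

// If `name` ends with "@GRAD" (mirroring framework::kGradVarSuffix),
// strip the suffix to recover the forward variable's name.
inline bool GetForwardName(const std::string &name, std::string *fwd_name) {
  const std::string kSuffix = "@GRAD";
  if (name.size() <= kSuffix.size() ||
      name.compare(name.size() - kSuffix.size(), kSuffix.size(), kSuffix) != 0) {
    return false;
  }
  *fwd_name = name.substr(0, name.size() - kSuffix.size());
  return true;
}

// e.g. GetForwardName("fc_0.w_0@GRAD", &s) -> true, with s == "fc_0.w_0"
```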
@@ -115,7 +123,7 @@ class PyFuncOpShapeInference : public framework::InferShapeBase {
       auto *in_var_desc = block->FindVarRecursive(fwd_var_name);
       PADDLE_ENFORCE_NOT_NULL(in_var_desc, "Forward variable %s not found",
                               fwd_var_name);
-      VLOG(10) << "Infer shape of Out(" << out_name << ") as Input("
+      VLOG(10) << "Infer shape of Output(" << out_name << ") as Input("
                << in_var_desc->Name() << ")";
       out_var_desc->SetShape(in_var_desc->GetShape());
       out_var_desc->SetDataType(in_var_desc->GetDataType());
@@ -135,7 +143,7 @@ class PyFuncOpMaker : public framework::OpProtoAndCheckerMaker {
                  "Index of registered forward Python function.")
         .SetDefault(0);
     AddAttr<int>(kBackwardPythonCallableId,
-                 "Index of registered backward Python function")
+                 "Index of registered backward Python function.")
         .SetDefault(-1);
     AddAttr<std::vector<std::string>>(kPyFuncBackwardSkipVars,
                                       "Unused forward in/out in backward op")
@@ -170,8 +178,7 @@ class PyFuncOpGradDescMaker : public framework::GradOpDescMakerBase {
     auto fwd_outs = Output("Out");
 
     // Because of memory reuse, some forward inputs/outputs may not be needed
-    // in backward part
-    // Just skip these vars
+    // in backward part. Skipping these vars helps to save memory
     auto &backward_skip_var_list = boost::get<std::vector<std::string>>(
         fwd_attrs.at(kPyFuncBackwardSkipVars));
     std::unordered_set<std::string> backward_skip_var_set(