@@ -42,7 +42,11 @@ size_t AppendPythonCallableObjectAndReturnId(const py::object &py_obj) {
 // Returning py::object would cause reference count increasing
 // but without GIL, reference count in Python may not be safe
 static py::object *GetPythonCallableObject(size_t i) {
-  PADDLE_ENFORCE_LT(i, g_py_callables.size(), "Invalid python callable id");
+  PADDLE_ENFORCE_LT(
+      i, g_py_callables.size(),
+      platform::errors::InvalidArgument(
+          "Invalid python callable id %d, which should be less than %d.", i,
+          g_py_callables.size()));
   return &g_py_callables[i];
 }
@@ -71,10 +75,27 @@ static void CallPythonFunc(py::object *callable,
     // Python function has no return values or returns None
     // In this case, ret_num = 1 && ret[0] == None && out_num should be 0
     // Otherwise, ret_num must be equal to out_num
-    PADDLE_ENFORCE(
-        ret_num == 1 && out_num == 0 &&
-            py::cast<framework::LoDTensor *>(ret_tuple[0]) == nullptr,
-        "Output number not match. Expected %d, actual %d", out_num, ret_num);
+    PADDLE_ENFORCE_EQ(ret_num == 1, true,
+                      platform::errors::InvalidArgument(
+                          "Python function has no return values or returns "
+                          "None. In this case, ret_num = 1 && ret[0] == None "
+                          "&& out_num should be 0. But ret_num is %d",
+                          ret_num));
+
+    PADDLE_ENFORCE_EQ(
+        out_num == 0, true,
+        platform::errors::InvalidArgument(
+            "Python function has no return values or returns None. In "
+            "this case, ret_num = 1 && ret[0] == None && out_num should "
+            "be 0. But out_num is %d",
+            out_num));
+
+    PADDLE_ENFORCE_EQ(
+        py::cast<framework::LoDTensor *>(ret_tuple[0]) == nullptr, true,
+        platform::errors::InvalidArgument(
+            "Python function has no return values or returns None. In "
+            "this case, ret_num = 1 && ret[0] == None && out_num should "
+            "be 0. But ret[0] is not None"));
   }
 
   for (size_t i = 0; i < out_num; ++i) {
@@ -85,7 +106,8 @@ static void CallPythonFunc(py::object *callable,
     try {
       auto *py_out_tensor = py::cast<framework::LoDTensor *>(ret_tuple[i]);
       PADDLE_ENFORCE_NOT_NULL(py_out_tensor,
-                              "Output tensor %d should not be nullptr", i);
+                              platform::errors::InvalidArgument(
+                                  "Output tensor %d should not be nullptr", i));
       out->set_lod(py_out_tensor->lod());
       out->ShareDataWith(*py_out_tensor);
     } catch (py::cast_error &) {
@@ -105,10 +127,17 @@ class PyFuncOpVarTypeInference : public framework::VarTypeInference {
      * X or Out can be empty, so that py_func can be more flexible
      * to support Python functions with no input or no output
      */
-    PADDLE_ENFORCE(has_in || has_out, "Input(X) or Output(Out) must exist");
-
-    PADDLE_ENFORCE_GE(boost::get<int>(ctx->GetAttr(kForwardPythonCallableId)),
-                      0, "Function id cannot be less than 0");
+    PADDLE_ENFORCE_EQ(
+        has_in || has_out, true,
+        platform::errors::InvalidArgument("Input(X) or Output(Out) must exist, "
+                                          "but has_in is %d, has_out is %d.",
+                                          has_in, has_out));
+
+    PADDLE_ENFORCE_GE(
+        boost::get<int>(ctx->GetAttr(kForwardPythonCallableId)), 0,
+        platform::errors::InvalidArgument(
+            "Function id cannot be less than 0, but received value is %d.",
+            boost::get<int>(ctx->GetAttr(kForwardPythonCallableId))));
 
     if (!has_out) return;
 
@@ -128,10 +157,12 @@ class PyFuncOpVarTypeInference : public framework::VarTypeInference {
       size_t len = out_var_name.size() - kGradVarSuffix.size();
       if (out_var_name.substr(len) == kGradVarSuffix) {
         auto fwd_var_name = out_var_name.substr(0, len);
-        PADDLE_ENFORCE(ctx->HasVar(out_var_name),
-                       "Backward variable %s not found", out_var_name);
-        PADDLE_ENFORCE(ctx->HasVar(fwd_var_name),
-                       "Backward variable %s not found", fwd_var_name);
+        PADDLE_ENFORCE_EQ(ctx->HasVar(out_var_name), true,
+                          platform::errors::InvalidArgument(
+                              "Backward variable %s not found", out_var_name));
+        PADDLE_ENFORCE_EQ(ctx->HasVar(fwd_var_name), true,
+                          platform::errors::InvalidArgument(
+                              "Backward variable %s not found", fwd_var_name));
         VLOG(10) << "Infer var_desc of Output(" << out_var_name << ") as Input("
                  << fwd_var_name << ")";
@@ -147,8 +178,9 @@ class PyFuncOpVarTypeInference : public framework::VarTypeInference {
 class PyFuncOpShapeInference : public framework::InferShapeBase {
  public:
   void operator()(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE(!ctx->IsRuntime(),
-                   "Infer shape cannot be called in runtime.");
+    PADDLE_ENFORCE_EQ(!ctx->IsRuntime(), true,
+                      platform::errors::InvalidArgument(
+                          "Infer shape cannot be called in runtime."));
   }
 };
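
For reference, every hunk above applies the same pattern: legacy PADDLE_ENFORCE(cond, "message", args...) calls are rewritten as checked macros (PADDLE_ENFORCE_EQ / _LT / _GE / _NOT_NULL) whose error argument is built with platform::errors::InvalidArgument, so a failed check reports the offending values. A minimal sketch of the before/after usage, assuming the macros and error builders from paddle/fluid/platform/enforce.h and the surrounding paddle::operators namespace (CheckCallableId below is a hypothetical helper added only for illustration):

    #include "paddle/fluid/platform/enforce.h"

    namespace paddle {
    namespace operators {

    // Hypothetical helper mirroring the check in GetPythonCallableObject above.
    void CheckCallableId(size_t i, size_t num_callables) {
      // Legacy form (removed in this diff): message text only, no values.
      //   PADDLE_ENFORCE_LT(i, num_callables, "Invalid python callable id");

      // Upgraded form: structured error carrying both the id and the bound.
      PADDLE_ENFORCE_LT(
          i, num_callables,
          platform::errors::InvalidArgument(
              "Invalid python callable id %d, which should be less than %d.", i,
              num_callables));
    }

    }  // namespace operators
    }  // namespace paddle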