@@ -43,9 +43,12 @@ static py::object *GetPythonCallableObject(size_t i) {
   return &g_py_callables[i];
 }
 
-static std::string PythonObjectToString(const py::object &py_callable) {
+static std::string PythonFuncDebugString(const py::object &py_callable) {
   py::gil_scoped_acquire guard;
-  return py::str(*py_callable);
+  std::string wrapper_func_str = py::str(py_callable);
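+  // py_callable is presumably a wrapper created on the Python side; its
+  // "_func" attribute is assumed to hold the wrapped user-defined function.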
+  auto inner_func = py_callable.attr("_func");
+  std::string inner_func_str = py::str(inner_func);
+  return inner_func_str + " wrapped by " + wrapper_func_str;
 }
 
 static void CallPythonFunc(py::object *callable,
@@ -93,15 +96,29 @@ class PyFuncOpShapeInference : public framework::InferShapeBase {
   void operator()(framework::InferShapeContext *ctx) const override {
     PADDLE_ENFORCE(!ctx->IsRuntime(),
                    "Infer shape cannot be called in runtime.");
+
+    /**
+     * Either X or Out may be empty, so that py_func can support
+     * Python functions with no input or no output.
+     */
     PADDLE_ENFORCE(ctx->HasInputs("X") || ctx->HasOutputs("Out"),
                    "Input(X) or Output(Out) must exist");
     PADDLE_ENFORCE_GE(ctx->Attrs().Get<int>(kForwardPythonCallableId), 0,
                       "Function id cannot be less than 0");
 
-    // Transverse all outputs
-    // If name of any output ends with @GRAD,
-    // set its shape, dtype, lod_level, type to be the same as
-    // the correponding forward variable
+    /**
+     * Traverse all outputs, and check whether the name of any output
+     * ends with @GRAD. If so, set its shape, dtype, lod_level and type
+     * to be the same as those of the corresponding forward variable.
+     *
+     * Why not get the input dims from InferShapeContext?
+     * Because some variables in the forward inputs/outputs may not be
+     * needed in backward, and those variables are not inside
+     * InferShapeContext.
+     *
+     * InferShape is only called at compile time. During runtime, the
+     * shapes of the outputs should be guaranteed by the user-defined
+     * Python functions.
+     */
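+    // ctx->GetOp() returns a variant; at compile time it holds an OpDesc
+    // pointer, hence the boost::get below.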
     auto *op = boost::get<const framework::OpDesc *>(ctx->GetOp());
     auto *block = op->Block();
     const std::string kGradVarSuffix = framework::kGradVarSuffix;
@@ -113,7 +130,7 @@ class PyFuncOpShapeInference : public framework::InferShapeBase {
       }
       auto out_name = out_var_desc->Name();
       if (out_name == framework::kEmptyVarName ||
-          out_name.size() <= kGradVarSuffix.size()) {
+          out_name.size() < kGradVarSuffix.size()) {
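+        // Skip empty names and names shorter than the @GRAD suffix, which
+        // cannot be gradient variables.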
         continue;
       }
 
@@ -152,7 +169,28 @@ class PyFuncOpMaker : public framework::OpProtoAndCheckerMaker {
   }
 };
 
+/**
+ * There are several benefits when the backward op of a py_func op is
+ * itself a py_func op:
+ *
+ * - Less code is needed, since the backward code is almost the same
+ *   as the forward code.
+ *
+ * - Higher-order derivatives are supported, so that py_func is
+ *   differentiable to any order.
+ */
 class PyFuncOpGradDescMaker : public framework::GradOpDescMakerBase {
  private:
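+  // Joins the given names with single spaces; used only to build the VLOG
+  // messages below.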
+  static std::string DebugString(const std::vector<std::string> &strs) {
+    if (strs.empty()) return "";
+    std::string ret = strs[0];
+    for (size_t i = 1; i < strs.size(); ++i) {
+      ret += " ";
+      ret += strs[i];
+    }
+    return ret;
+  }
+
  public:
   using framework::GradOpDescMakerBase::GradOpDescMakerBase;
 
@@ -207,21 +245,8 @@ class PyFuncOpGradDescMaker : public framework::GradOpDescMakerBase {
     // But in Python side, if IG is not needed, users can just return None
     auto bwd_outs = InputGrad("X", false);
 
-    if (VLOG_IS_ON(10)) {
-      std::string in_str = "PyFunc Grad Input: ";
-      for (auto &in : bwd_ins) {
-        in_str += in;
-        in_str += " ";
-      }
-      VLOG(10) << in_str;
-
-      std::string out_str = "PyFunc Grad Output: ";
-      for (auto &out : bwd_outs) {
-        out_str += out;
-        out_str += " ";
-      }
-      VLOG(10) << out_str;
-    }
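+    // Replaces the manual VLOG_IS_ON(10) block above; glog's VLOG(10) only
+    // evaluates the streamed DebugString() calls when verbosity 10 is enabled.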
+    VLOG(10) << "PyFunc Grad Input: " << DebugString(bwd_ins);
+    VLOG(10) << "PyFunc Grad Output: " << DebugString(bwd_outs);
 
     grad_op->SetInput("X", bwd_ins);
     grad_op->SetOutput("Out", bwd_outs);
@@ -245,6 +270,7 @@ class PyFuncOp : public framework::OperatorBase {
     std::vector<framework::LoDTensor> inputs(in_arg_names.size());
     for (size_t i = 0; i < in_arg_names.size(); ++i) {
       auto in_var = scope.FindVar(in_arg_names[i]);
+      // When the py_func op runs in the backward pass, in_var may be null
       if (in_var == nullptr) {
         continue;
       }
@@ -263,15 +289,14 @@ class PyFuncOp : public framework::OperatorBase {
     std::vector<framework::LoDTensor *> outputs(out_arg_names.size());
     for (size_t i = 0; i < out_arg_names.size(); ++i) {
       auto *out_var = scope.FindVar(out_arg_names[i]);
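+      // An output variable may likewise be absent; store nullptr in that case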
-      auto *out_tensor =
+      outputs[i] =
           out_var ? out_var->GetMutable<framework::LoDTensor>() : nullptr;
-      outputs[i] = out_tensor;
     }
 
     auto callable_id = static_cast<size_t>(Attr<int>(kForwardPythonCallableId));
     auto *py_callable = GetPythonCallableObject(callable_id);
-    VLOG(10) << "Call py_func_op with id " << callable_id << ": "
-             << PythonObjectToString(*py_callable);
+    VLOG(10) << "Call Python function with id " << callable_id << ": "
+             << PythonFuncDebugString(*py_callable);
     CallPythonFunc(py_callable, inputs, &outputs);
   }
 };