@@ -91,66 +91,68 @@ static void CallPythonFunc(py::object *callable,
   }
 }
 
-class PyFuncOpShapeInference : public framework::InferShapeBase {
+class PyFuncOpVarTypInference : public framework::VarTypeInference {
  public:
-  void operator()(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE(!ctx->IsRuntime(),
-                   "Infer shape cannot be called in runtime.");
+  void operator()(const framework::OpDesc &op,
+                  framework::BlockDesc *block) const override {
+    auto &outs = op.Outputs();
+    bool has_out = (outs.count("Out") > 0 && !outs.at("Out").empty());
+
+    auto &ins = op.Inputs();
+    bool has_in = (ins.count("X") > 0 && !ins.at("X").empty());
+
     /**
      * X or Out can be empty, so that py_func can be more flexible
      * to support Python functions with no input or no output
      */
-    PADDLE_ENFORCE(ctx->HasInputs("X") || ctx->HasOutputs("Out"),
-                   "Input(X) or Output(Out) must exist");
-    PADDLE_ENFORCE_GE(ctx->Attrs().Get<int>(kForwardPythonCallableId), 0,
+    PADDLE_ENFORCE(has_in || has_out, "Input(X) or Output(Out) must exist");
+
+    PADDLE_ENFORCE_GE(boost::get<int>(op.GetAttr(kForwardPythonCallableId)), 0,
                       "Function id cannot be less than 0");
 
+    if (!has_out) return;
+
     /**
      * Traverse all outputs, check if name of any output ends with @GRAD.
      * If found, set its shape, dtype, lod_level, type to be the same as
      * the corresponding forward variable
      *
      * Why not get input dims from InferShapeContext?
      * Because some variables in forward inputs/outputs may not be needed
      * in backward. Those variables are not inside InferShapeContext.
      *
      * InferShape would be only called in compile time. During runtime,
      * the shapes of outputs should be guaranteed by user-defined Python
      * functions.
      */
-    auto *op = boost::get<const framework::OpDesc *>(ctx->GetOp());
-    auto *block = op->Block();
     const std::string kGradVarSuffix = framework::kGradVarSuffix;
-    auto out_vars = ctx->GetOutputVarPtrs("Out");
-    for (auto &out_var : out_vars) {
-      auto *out_var_desc = boost::get<framework::VarDesc *>(out_var);
-      if (out_var_desc == nullptr) {
-        continue;
-      }
-      auto out_name = out_var_desc->Name();
-      if (out_name == framework::kEmptyVarName ||
-          out_name.size() < kGradVarSuffix.size()) {
+    auto &out_var_names = outs.at("Out");
+    for (auto &out_var_name : out_var_names) {
+      if (out_var_name == framework::kEmptyVarName ||
+          out_var_name.size() < kGradVarSuffix.size()) {
         continue;
       }
-      size_t len = out_name.size() - kGradVarSuffix.size();
-      if (out_name.substr(len) == kGradVarSuffix) {
-        auto fwd_var_name = out_name.substr(0, len);
-        auto *in_var_desc = block->FindVarRecursive(fwd_var_name);
-        PADDLE_ENFORCE_NOT_NULL(in_var_desc, "Forward variable %s not found",
+      size_t len = out_var_name.size() - kGradVarSuffix.size();
+      if (out_var_name.substr(len) == kGradVarSuffix) {
+        auto fwd_var_name = out_var_name.substr(0, len);
+        auto *out_var_desc = block->FindVarRecursive(out_var_name);
+        auto *fwd_var_desc = block->FindVarRecursive(fwd_var_name);
+        PADDLE_ENFORCE_NOT_NULL(out_var_desc, "Backward variable %s not found",
+                                out_var_name);
+        PADDLE_ENFORCE_NOT_NULL(fwd_var_desc, "Forward variable %s not found",
                                 fwd_var_name);
-        VLOG(10) << "Infer shape of Output(" << out_name << ") as Input("
-                 << in_var_desc->Name() << ")";
-        out_var_desc->SetShape(in_var_desc->GetShape());
-        out_var_desc->SetDataType(in_var_desc->GetDataType());
-        out_var_desc->SetLoDLevel(in_var_desc->GetLoDLevel());
-        out_var_desc->SetType(in_var_desc->GetType());
+        VLOG(10) << "Infer var_desc of Output(" << out_var_name << ") as Input("
+                 << fwd_var_name << ")";
+        out_var_desc->SetShape(fwd_var_desc->GetShape());
+        out_var_desc->SetDataType(fwd_var_desc->GetDataType());
+        out_var_desc->SetLoDLevel(fwd_var_desc->GetLoDLevel());
+        out_var_desc->SetType(fwd_var_desc->GetType());
       }
     }
   }
 };
 
+class PyFuncOpShapeInference : public framework::InferShapeBase {
+ public:
+  void operator()(framework::InferShapeContext *ctx) const override {
+    PADDLE_ENFORCE(!ctx->IsRuntime(),
+                   "Infer shape cannot be called in runtime.");
+  }
+};
+
 class PyFuncOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
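
Aside (not part of the patch): the loop above recovers a forward variable's
name by stripping framework::kGradVarSuffix ("@GRAD", per the comment) from a
grad output's name, then copies shape, dtype, lod_level and type from the
forward VarDesc onto the grad VarDesc. A minimal standalone C++ sketch of that
suffix check follows; the helper name ForwardNameOf is hypothetical, since the
patch inlines the logic:

    #include <iostream>
    #include <string>

    static const std::string kGradVarSuffix = "@GRAD";

    // Returns the forward variable name if `name` ends with "@GRAD",
    // otherwise an empty string (such outputs are skipped by the loop).
    std::string ForwardNameOf(const std::string &name) {
      if (name.size() < kGradVarSuffix.size()) return "";
      size_t len = name.size() - kGradVarSuffix.size();
      return name.substr(len) == kGradVarSuffix ? name.substr(0, len) : "";
    }

    int main() {
      std::cout << ForwardNameOf("x@GRAD") << "\n";  // prints "x"
      std::cout << ForwardNameOf("x") << "\n";       // empty: not a grad output
    }

Because this copying happens at compile time, the slimmed-down
PyFuncOpShapeInference only has to assert it is never invoked at runtime.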
@@ -307,4 +309,5 @@ class PyFuncOp : public framework::OperatorBase {
 namespace ops = paddle::operators;
 
 REGISTER_OPERATOR(py_func, ops::PyFuncOp, ops::PyFuncOpMaker,
-                  ops::PyFuncOpShapeInference, ops::PyFuncOpGradDescMaker);
+                  ops::PyFuncOpVarTypInference, ops::PyFuncOpShapeInference,
+                  ops::PyFuncOpGradDescMaker);
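
Aside (not part of the patch): the registration now also attaches
PyFuncOpVarTypInference, whose has_in/has_out checks guard against slots that
are absent or present but empty. A standalone sketch of that presence test,
assuming VariableNameMap is std::map<std::string, std::vector<std::string>>
(matching how op.Inputs()/op.Outputs() are used above); HasNonEmptySlot is a
hypothetical helper name:

    #include <iostream>
    #include <map>
    #include <string>
    #include <vector>

    using VariableNameMap = std::map<std::string, std::vector<std::string>>;

    // True only if the slot exists and names at least one variable,
    // mirroring (outs.count("Out") > 0 && !outs.at("Out").empty()).
    bool HasNonEmptySlot(const VariableNameMap &m, const std::string &slot) {
      auto it = m.find(slot);
      return it != m.end() && !it->second.empty();
    }

    int main() {
      VariableNameMap outs{{"Out", {}}};                  // slot present but empty
      std::cout << HasNonEmptySlot(outs, "Out") << "\n";  // 0: py_func may omit outputs
      outs["Out"].push_back("y@GRAD");
      std::cout << HasNonEmptySlot(outs, "Out") << "\n";  // 1
    }

This is why py_func can wrap Python functions with no input or no output: an
empty slot simply short-circuits the var-type inference.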