|
|
|
@ -27,6 +27,20 @@
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace imperative {
|
|
|
|
|
|
|
|
|
|
void CreateGradOp(const framework::OpDesc& op_desc,
|
|
|
|
|
const std::unordered_set<std::string>& no_grad_set,
|
|
|
|
|
const std::vector<framework::BlockDesc*>& grad_sub_block,
|
|
|
|
|
framework::OpDesc** grad_op_desc,
|
|
|
|
|
std::unordered_map<std::string, std::string>* grad_to_var) {
|
|
|
|
|
std::vector<std::unique_ptr<framework::OpDesc>> grad_op_descs =
|
|
|
|
|
framework::OpInfoMap::Instance()
|
|
|
|
|
.Get(op_desc.Type())
|
|
|
|
|
.GradOpMaker()(op_desc, no_grad_set, grad_to_var, grad_sub_block);
|
|
|
|
|
PADDLE_ENFORCE(grad_op_descs.size() == 1, "Only support 1 grad op now.");
|
|
|
|
|
// TODO(panyx0718): Leak?
|
|
|
|
|
*grad_op_desc = grad_op_descs[0].release();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
class Tracer {
|
|
|
|
|
public:
|
|
|
|
|
Tracer() {}
|
|
|
|
@ -44,6 +58,7 @@ class Tracer {
|
|
|
|
|
for (VarBase* input : inputs) {
|
|
|
|
|
const std::string vname = input->var_desc_->Name();
|
|
|
|
|
framework::Variable* var = scope_->Var(vname);
|
|
|
|
|
input->var_ = var;
|
|
|
|
|
if (!var->IsInitialized()) {
|
|
|
|
|
framework::VarDesc* var_desc = block_->FindVar(vname);
|
|
|
|
|
if (var_desc->GetType() == framework::proto::VarType::LOD_TENSOR) {
|
|
|
|
@ -52,11 +67,17 @@ class Tracer {
|
|
|
|
|
LOG(ERROR) << "tracer doesn't support yet";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (input->pre_op_) {
|
|
|
|
|
op->pre_ops_->push_back(input->pre_op_);
|
|
|
|
|
op->pre_ops_out_idx_->push_back(input->pre_op_out_idx_);
|
|
|
|
|
} else {
|
|
|
|
|
op->pre_ops_->push_back(nullptr);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
*op->output_vars_ = outputs;
|
|
|
|
|
for (auto output : outputs) {
|
|
|
|
|
const std::string vname = output->var_desc_->Name();
|
|
|
|
|
for (size_t i = 0; i < outputs.size(); ++i) {
|
|
|
|
|
const std::string vname = outputs[i]->var_desc_->Name();
|
|
|
|
|
framework::Variable* var = scope_->Var(vname);
|
|
|
|
|
if (!var->IsInitialized()) {
|
|
|
|
|
framework::VarDesc* var_desc = block_->FindVar(vname);
|
|
|
|
@ -66,9 +87,18 @@ class Tracer {
|
|
|
|
|
LOG(ERROR) << "tracer doesn't support yet";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
output->pre_op_ = op;
|
|
|
|
|
outputs[i]->var_ = var;
|
|
|
|
|
outputs[i]->pre_op_ = op;
|
|
|
|
|
outputs[i]->pre_op_out_idx_ = i;
|
|
|
|
|
}
|
|
|
|
|
op_base->Run(*scope_, platform::CPUPlace());
|
|
|
|
|
|
|
|
|
|
framework::OpDesc* grad_op_desc;
|
|
|
|
|
auto grad_to_var = new std::unordered_map<std::string, std::string>();
|
|
|
|
|
CreateGradOp(*op_desc, {}, {block_}, &grad_op_desc, grad_to_var);
|
|
|
|
|
op->grad_op_desc_ = grad_op_desc;
|
|
|
|
|
op->grad_to_var_ = grad_to_var;
|
|
|
|
|
op->block_ = block_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SetScope(framework::Scope* scope) { scope_ = scope; }
|
|
|
|
|