|
|
|
@ -58,10 +58,10 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
|
|
|
|
|
for (auto it : op->input_vars_) {
|
|
|
|
|
auto& invars = invars_map[it.first];
|
|
|
|
|
for (VarBase* inp : it.second) {
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(inp->var_.get(), "op %s input %s nullptr",
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(inp->var_, "op %s input %s nullptr",
|
|
|
|
|
op->op_desc_->Type(), inp->var_desc_->Name());
|
|
|
|
|
|
|
|
|
|
invars.push_back(inp->var_.get());
|
|
|
|
|
invars.push_back(inp->var_);
|
|
|
|
|
vars[inp->var_desc_->Name()] = inp;
|
|
|
|
|
if (inp->pre_op_) {
|
|
|
|
|
op->pre_ops_[it.first].push_back(inp->pre_op_);
|
|
|
|
@ -80,7 +80,7 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
|
|
|
|
|
const std::vector<VarBase*>& outputs = it.second;
|
|
|
|
|
for (size_t i = 0; i < outputs.size(); ++i) {
|
|
|
|
|
VarBase* out = outputs[i];
|
|
|
|
|
outvars.push_back(out->var_.get());
|
|
|
|
|
outvars.push_back(out->var_);
|
|
|
|
|
vars[out->var_desc_->Name()] = out;
|
|
|
|
|
|
|
|
|
|
framework::VarDesc* var_desc = block->FindVar(out->var_desc_->Name());
|
|
|
|
@ -127,13 +127,13 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
|
|
|
|
|
if (var_it == grad_to_var->end()) {
|
|
|
|
|
auto fwd_var_it = vars.find(grad_invar);
|
|
|
|
|
PADDLE_ENFORCE(fwd_var_it != vars.end());
|
|
|
|
|
grad_in_vars.push_back(fwd_var_it->second->var_.get());
|
|
|
|
|
grad_in_vars.push_back(fwd_var_it->second->var_);
|
|
|
|
|
} else {
|
|
|
|
|
VarBase* var = vars[var_it->second];
|
|
|
|
|
if (!var->grads_->var_->IsInitialized()) {
|
|
|
|
|
InitVar(var->var_.get(), var->grads_->var_.get());
|
|
|
|
|
InitVar(var->var_, var->grads_->var_);
|
|
|
|
|
}
|
|
|
|
|
grad_in_vars.push_back(var->grads_->var_.get());
|
|
|
|
|
grad_in_vars.push_back(var->grads_->var_);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -146,9 +146,9 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
|
|
|
|
|
PADDLE_ENFORCE(var_it != grad_to_var->end());
|
|
|
|
|
VarBase* var = vars[var_it->second];
|
|
|
|
|
if (!var->grads_->var_->IsInitialized()) {
|
|
|
|
|
InitVar(var->var_.get(), var->grads_->var_.get());
|
|
|
|
|
InitVar(var->var_, var->grads_->var_);
|
|
|
|
|
}
|
|
|
|
|
grad_out_vars.push_back(var->grads_->var_.get());
|
|
|
|
|
grad_out_vars.push_back(var->grads_->var_);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|