Paddle/paddle/fluid/imperative/tracer.cc

// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/imperative/tracer.h"
#include <unordered_set>
#include <utility>
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace imperative {
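
// Run the op's registered grad op maker (if any) over the forward op desc.
// Ops without a grad op maker yield an empty vector of grad op descs.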
static std::vector<std::unique_ptr<framework::OpDesc>> CreateGradOpDescs(
    const framework::OpInfo& op_info, const framework::OpDesc& op_desc,
    const std::unordered_set<std::string>& no_grad_set,
    const std::vector<framework::BlockDesc*>& grad_sub_block,
    std::unordered_map<std::string, std::string>* grad_to_var) {
  if (op_info.grad_op_maker_) {
    return op_info.grad_op_maker_(op_desc, no_grad_set, grad_to_var,
                                  grad_sub_block);
  } else {
    return {};
  }
}
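
// Overwrite the stop_gradient flag of every output var with generate_grad.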
static void PassStopGradient(const NameVarBaseMap& outs, bool generate_grad) {
  for (const auto& name_pair : outs) {
    for (const auto& vb : name_pair.second) {
      VLOG(6) << "Set output: " << vb->Name() << "'s OverridedStopGradient as "
              << generate_grad;
      vb->InnerSetOverridedStopGradient(generate_grad);
    }
  }
}
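
// Trace a single op in imperative mode: run it eagerly, optionally record
// it into the ProgramDesc, and trace its backward ops when any input
// requires grad.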
void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
                     const NameVarBaseMap& outs, framework::AttributeMap attrs,
                     const platform::Place& place, bool trace_backward) {
  platform::RecordEvent event(type);
  VLOG(1) << "Trace Op: " << type;
  size_t op_id = GenerateUniqueId();
  auto op = OpBase::Create(op_id, type, ins, outs, std::move(attrs), place);
  op->Run(ins, outs);

  if (enable_program_desc_tracing_) {
    VLOG(5) << "Trace op " << type << " into ProgramDesc";
    program_desc_tracer_->InsertOp(type, ins, outs, op->Attrs());
  }

  if (ComputeRequiredGrad(ins, outs, trace_backward)) {
    TraceBackward(op, framework::OpDesc(op->Type(), op->InputNameMap(),
                                        op->OutputNameMap(), op->Attrs()),
                  ins, outs);
  } else {
    VLOG(3) << "No Grad to track for Op: " << type;
  }
}
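
// Usage sketch (illustrative only; assumes x, y and out are
// std::shared_ptr<VarBase> instances created elsewhere):
//   NameVarBaseMap ins = {{"X", {x}}, {"Y", {y}}};
//   NameVarBaseMap outs = {{"Out", {out}}};
//   tracer.TraceOp("mul", ins, outs, framework::AttributeMap(),
//                  platform::CPUPlace(), /*trace_backward=*/true);

// Grad is required iff backward tracing is requested and at least one
// input does not stop gradient; all outputs then inherit that setting.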
bool Tracer::ComputeRequiredGrad(const NameVarBaseMap& ins,
                                 const NameVarBaseMap& outs,
                                 bool trace_backward) {
  if (!trace_backward) return false;

  for (const auto& name_pair : ins) {
    for (const auto& var_base : name_pair.second) {
      if (!var_base->OverridedStopGradient()) {
        VLOG(6) << "Found input: " << var_base->Name()
                << " whose GeneratedGrad is True";
        // At least one input requires grad, so mark every output as
        // requiring grad too (OverridedStopGradient() is false here)
        PassStopGradient(outs, var_base->OverridedStopGradient());
        return true;
      }
    }
  }
  return false;
}
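
// Build the backward graph for fwd_op: create its grad op descs via the
// registered GradOpMaker, bind grad inputs/outputs to the matching
// VarBases, and record pending-op dependencies for the engine.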
void Tracer::TraceBackward(const std::shared_ptr<OpBase>& fwd_op,
                           const framework::OpDesc& fwd_op_desc,
                           const NameVarBaseMap& ins,
                           const NameVarBaseMap& outs) {
  // grad_to_var maps framework::GradVarName(in_var_name/out_var_name) back
  // to in_var_name/out_var_name
  std::unordered_map<std::string, std::string> grad_to_var;

  // Get grad op descs using fwd_op_desc
  std::vector<std::unique_ptr<framework::OpDesc>> grad_op_descs =
      CreateGradOpDescs(fwd_op->Info(), fwd_op_desc, {}, {}, &grad_to_var);

  // Create grad ops using grad_op_descs
  size_t grad_op_num = grad_op_descs.size();
  VLOG(3) << "Create " << grad_op_num << " grad op desc(s) for op "
          << fwd_op->Type();

  if (grad_op_num == 0) {
    return;
  }
  // Build a map of var_name -> std::shared_ptr<VarBase>*, so that we can
  // find the matching var for the names in the grad op descs
  std::unordered_map<std::string, const std::shared_ptr<VarBase>*> name_to_var;
  for (auto& pair : ins) {
    for (auto& var : pair.second) {
      auto& var_ptr = name_to_var[var->Name()];
      PADDLE_ENFORCE_EQ(var_ptr == nullptr || var_ptr->get() == var.get(),
                        true,
                        "There are different variables with the same name %s",
                        var->Name());
      var_ptr = &var;
    }
  }

  for (auto& pair : outs) {
    for (auto& var : pair.second) {
      auto& var_ptr = name_to_var[var->Name()];
      PADDLE_ENFORCE_EQ(var_ptr == nullptr || var_ptr->get() == var.get(),
                        true,
                        "There are different variables with the same name %s",
                        var->Name());
      var_ptr = &var;
    }
  }
  // Build backward ins and outs
  for (size_t i = 0; i < grad_op_num; i++) {
    // Step 1: create the grad op and register it with the engine.
    // The trace id decides the order of gradient sum in sorted-sum mode.
    size_t trace_id = fwd_op->id();
    std::shared_ptr<OpBase> grad_op =
        OpBase::Create(trace_id, *grad_op_descs[i], fwd_op->place());

    // Passing the shared_ptr lets the engine manage the grad op's lifetime
    engine_->InsertOp(grad_op.get(), grad_op);
    std::unordered_set<OpBase*> visited_preceding_ops;

    // Step 2: prepare grad_in vars and bind them to grad_op; set each
    // input's grad op to the current grad_op
    for (const auto& grad_ins : grad_op_descs[i]->Inputs()) {
      if (grad_ins.second.empty()) continue;
      auto& bwd_in = (*grad_op->GetMutableInsMap())[grad_ins.first];
      bwd_in.reserve(grad_ins.second.size());
      for (auto& grad_in_var_name : grad_ins.second) {
        auto iter = grad_to_var.find(grad_in_var_name);
        if (iter != grad_to_var.end()) {
          // If it is a grad var, find its corresponding forward var
          auto& fwd_var_name = iter->second;
          auto fwd_var_iter = name_to_var.find(fwd_var_name);
          PADDLE_ENFORCE_EQ(fwd_var_iter != name_to_var.end(), true,
                            "Cannot find forward variable named %s",
                            fwd_var_name);
          const auto& tmp = (*(fwd_var_iter->second))->GradVarBase();
          PADDLE_ENFORCE_NOT_NULL(
              tmp.get(),
              "Grad of %s should not be NULL when tracing backward input "
              "of %s",
              (*(fwd_var_iter->second))->Name(), grad_op->Type());
          // Resize grad_in's tensor to the forward var's dims for grad
          // dependency computation
          auto* tensor = tmp->MutableVar()->GetMutable<framework::LoDTensor>();
          tensor->Resize((*(fwd_var_iter->second))
                             ->Var()
                             .Get<framework::LoDTensor>()
                             .dims());
          // Register grad_op as a consumer of grad_in
          tmp->AddGradOps(grad_op);
          VLOG(3) << "Add Grad Op " << grad_op->Type() << " for: "
                  << (*(fwd_var_iter->second))->GradVarBase()->Name();
          // Add the grad var to the engine's grad var set
          engine_->InsertGradVar(tmp.get());
          VLOG(3) << "Add Grad: " << tmp->Name() << " into Engine";
          bwd_in.emplace_back((*(fwd_var_iter->second))->GradVarBase());
        } else {
          // If it is a forward var, just add it
          auto fwd_var_iter = name_to_var.find(grad_in_var_name);
          PADDLE_ENFORCE_EQ(fwd_var_iter != name_to_var.end(), true,
                            "Cannot find forward variable named %s",
                            grad_in_var_name);
          bwd_in.emplace_back(*(fwd_var_iter->second));
        }
        VLOG(3) << "Set backward input " << grad_ins.first << " of "
                << grad_op->Type() << " to be "
                << (bwd_in.back() ? bwd_in.back()->Name() : "nullptr");
      }
    }
    // Step 3: prepare grad_out vars and use their grad ops to set the
    // current grad_op's preceding ops
    for (auto& grad_outs : grad_op_descs[i]->Outputs()) {
      if (grad_outs.second.empty()) continue;
      auto& bwd_out = (*grad_op->GetMutableOutsMap())[grad_outs.first];
      bwd_out.reserve(grad_outs.second.size());
      for (auto& grad_out_var_name : grad_outs.second) {
        auto iter = grad_to_var.find(grad_out_var_name);
        PADDLE_ENFORCE_EQ(iter != grad_to_var.end(), true,
                          "Cannot find output of input grad %s in op %s",
                          grad_out_var_name, fwd_op->Type());
        auto fwd_var_iter = name_to_var.find(iter->second);
        PADDLE_ENFORCE_EQ(fwd_var_iter != name_to_var.end(), true,
                          "Cannot find forward variable named %s",
                          iter->second);
        const auto& tmp = (*(fwd_var_iter->second))->GradVarBase();
        PADDLE_ENFORCE_NOT_NULL(
            tmp.get(), "Grad output of %s in op %s should not be NULL",
            (*(fwd_var_iter->second))->Name(), grad_op->Type());
        if ((!tmp->OverridedStopGradient()) || (grad_outs.second.size() > 1)) {
          VLOG(3) << "Set backward output " << grad_outs.first << " of "
                  << grad_op->Type() << " to be " << tmp->Name()
                  << ". Its OverridedStopGradient is False";
          bwd_out.emplace_back(tmp);
          auto grad_pending_ops =
              (*(fwd_var_iter->second))->GradVarBase()->GradOps();
          if (VLOG_IS_ON(3) && !grad_pending_ops.empty()) {
            VLOG(3) << "Add grad_pending ops of "
                    << (*(fwd_var_iter->second))->GradVarBase()->Name()
                    << ". Its grad_pending ops are:";
            for (const auto& op : grad_pending_ops) {
              VLOG(3) << op->Type();
            }
          }
          auto grad_name = (*(fwd_var_iter->second))->GradVarBase()->Name();
          if (!grad_pending_ops.empty()) {
            for (const auto& op : grad_pending_ops) {
              PADDLE_ENFORCE_NOT_NULL(
                  op, "grad_pending op of variable %s should not be nullptr",
                  grad_name);
              if (visited_preceding_ops.count(op) == 0) {
                visited_preceding_ops.insert(op);
                grad_op->InsertGradPendingOps(op);
              }
            }
          } else {
            VLOG(5) << "Hit leaf VarBase "
                    << (*(fwd_var_iter->second))->GradVarBase()->Name();
          }
        } else {
          VLOG(3) << "Skip backward output " << grad_outs.first << " of "
                  << grad_op->Type() << " named " << tmp->Name()
                  << ", since its OverridedStopGradient is True";
        }
      }
    }
    // Sort grad pending ops to keep gradient accumulation order consistent
    // with the static graph
    grad_op->SortGradPendingOps();
  }
}
} // namespace imperative
} // namespace paddle