@@ -97,6 +97,13 @@ void AddTo(Variable* src, Variable* dst, platform::Place place) {
   boost::apply_visitor(func, place);
 }
 
+void ZeroGrads(VarBase* vb, const platform::Place& place) {
+  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+  auto* dev_ctx = pool.Get(place);
+  auto grad_t = vb->var_->GetMutable<framework::LoDTensor>();
+  operators::math::set_constant(*dev_ctx, grad_t, 0.0);
+}
+
 void AddGradBySort(BackwardSumMap* bck_map, VarBase* target) {
   PADDLE_ENFORCE(bck_map->find(target) != bck_map->end(),
                  "Can't find %s in backward grad map", target->Name());
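Note on the new ZeroGrads helper above: it looks up the device context for `place` and overwrites the gradient tensor with zeros. Below is a minimal self-contained sketch of the same zero-fill-on-demand pattern, with a plain buffer standing in for framework::LoDTensor; the Buffer and ZeroFill names are illustrative, not Paddle API.

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

// Stand-in for a lazily allocated gradient tensor.
struct Buffer {
  std::vector<float> data;
  bool initialized = false;
};

// Mirrors ZeroGrads: make sure storage exists, then fill it with zeros
// (InitBuffer + operators::math::set_constant(..., 0.0) analogue).
void ZeroFill(Buffer* buf, std::size_t size) {
  if (!buf->initialized) {
    buf->data.resize(size);
    buf->initialized = true;
  }
  std::fill(buf->data.begin(), buf->data.end(), 0.0f);
}

int main() {
  Buffer grad;
  ZeroFill(&grad, 4);
  std::cout << grad.data[0] << std::endl;  // prints 0
  return 0;
}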
@@ -110,9 +117,9 @@ void AddGradBySort(BackwardSumMap* bck_map, VarBase* target) {
   for (auto& var_pair : current.second) {
     Variable* origin_grad = target->var_.get();
     Variable* grad_to_add = var_pair.second->var_.get();
-    VLOG(2) << "add origin_grad: " << target->Name();
-    VLOG(2) << "added grad: " << var_pair.second->Name()
-            << " trace id is: " << var_pair.first;
+    VLOG(10) << "add origin_grad: " << target->Name();
+    VLOG(10) << "added grad: " << var_pair.second->Name()
+             << " trace id is: " << var_pair.first;
     AddTo(grad_to_add, origin_grad, current.first);
     delete var_pair.second;
     var_pair.second = nullptr;
@@ -127,7 +134,7 @@ class Autograd {
     if (var->IsStopGradient()) {
       return;
     }
-    VLOG(3) << "start autograd";
+    VLOG(2) << "start autograd";
     BackwardSumMap bck_map;
     GradientRef grad_ref;
     std::deque<OpBase*> ready;
@@ -195,7 +202,7 @@ class Autograd {
       for (auto it : candidate->pre_ops_) {
         for (OpBase* pre_op : it.second) {
           if (!pre_op) continue;
-          VLOG(2) << "op dep " << candidate->Type() << " trace id "
+          VLOG(9) << "op dep " << candidate->Type() << " trace id "
                   << candidate->trace_id_ << " <---- " << it.first << " <---- "
                   << pre_op->Type() << " trace id " << pre_op->trace_id_;
           if (visited.find(pre_op) == visited.end()) {
@@ -267,9 +274,11 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad(
     for (const auto& it : grad_output_variable_map) {
       auto& outputs = tmp_grad_outputs[k][it.first];
       outputs.reserve(it.second.size());
-      for (size_t i = 0; i < it.second.size(); ++i) {
-        VarBase* origin_grad_var_base = it.second[i];
-
+      for (VarBase* origin_grad_var_base : it.second) {
+        if (!origin_grad_var_base->IsInitialize()) {
+          origin_grad_var_base->InitBuffer();
+          ZeroGrads(origin_grad_var_base, place_);
+        }
         // Allocate a new variable
         VarBase* tmp_grad_var_base = new VarBase(
             string::Sprintf("%s@IGrad", origin_grad_var_base->Name()),
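The hunk above switches to a range-for and guards every origin gradient with an initialize-and-zero step before a fresh `@IGrad` temporary is allocated for the newly computed gradient. A condensed model of that guard follows, under the assumption that IsInitialize/InitBuffer just track whether the underlying buffer was ever allocated; GradVar, EnsureZeroed, and MakeTmpGrad are hypothetical names.

#include <cstddef>
#include <memory>
#include <string>
#include <vector>

struct GradVar {
  std::string name;
  std::vector<float> data;
  bool initialized = false;
};

// Guard from the loop body: allocate and zero a never-written gradient
// so later accumulation starts from a well-defined value.
void EnsureZeroed(GradVar* v, std::size_t size) {
  if (!v->initialized) {
    v->data.assign(size, 0.0f);
    v->initialized = true;
  }
}

// Analogue of string::Sprintf("%s@IGrad", ...): a fresh temporary that
// receives this op's contribution before it is summed into the original.
std::unique_ptr<GradVar> MakeTmpGrad(const GradVar& origin) {
  auto tmp = std::make_unique<GradVar>();
  tmp->name = origin.name + "@IGrad";
  return tmp;
}

int main() {
  GradVar g{"fc_0.w_0@GRAD", {}, false};
  EnsureZeroed(&g, 8);
  auto tmp = MakeTmpGrad(g);  // named "fc_0.w_0@GRAD@IGrad"
  return 0;
}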
@@ -304,11 +313,15 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad(
     for (const auto& it : grad_input_vars_[k]) {
       auto& grad_invars = grad_invars_map[it.first];
       grad_invars.reserve(it.second.size());
-      for (const VarBase* grad_inp : it.second) {
+      for (VarBase* grad_inp : it.second) {
         PADDLE_ENFORCE_NOT_NULL(grad_inp->var_, "op %s input %s nullptr",
                                 grad_op_desc->Type(), grad_inp->Name());
-
-        grad_invars.emplace_back(grad_inp->var_.get());
+        if (!grad_inp->IsInitialize()) {
+          grad_inp->InitBuffer();
+          ZeroGrads(grad_inp, place_);
+        }
+        const VarBase* const_grad_inp = grad_inp;
+        grad_invars.emplace_back(const_grad_inp->var_.get());
       }
     }
 
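The loop variable above loses its const qualifier only because the body may now mutate the input (InitBuffer/ZeroGrads) before handing it to grad_invars; the extra const_grad_inp alias restores the read-only view for the emplace. A small model of that mutate-then-expose-const pattern; Var, Consume, and Prepare are illustrative names.

#include <iostream>
#include <vector>

struct Var {
  std::vector<float> buf;
};

// Read-only sink, playing the role of grad_invars.
void Consume(const Var* v) { std::cout << v->buf.size() << std::endl; }

void Prepare(std::vector<Var*>* inputs) {
  for (Var* v : *inputs) {      // non-const: may need initialization
    if (v->buf.empty()) {
      v->buf.assign(1, 0.0f);   // InitBuffer + ZeroGrads analogue
    }
    const Var* cv = v;          // back to a const view afterwards
    Consume(cv);
  }
}

int main() {
  Var a;
  std::vector<Var*> inputs{&a};
  Prepare(&inputs);
  return 0;
}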
@@ -343,22 +356,23 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad(
         // track outputs used by sum
         if (bck_stratedy.sorted_sum_gradient_) {
 #ifndef PADDLE_WITH_CUDA
-          VLOG(2) << "origin_outputs is : " << origin_outputs[i]->Name() << " ";
-          VLOG(2) << origin_outputs[i]
-                        ->var_->GetMutable<framework::LoDTensor>()
-                        ->data<float>()[0];
-          VLOG(2) << "outputs is : " << outputs[i]->Name() << " ";
-          VLOG(2) << outputs[i]
-                        ->var_->GetMutable<framework::LoDTensor>()
-                        ->data<float>()[0];
+          VLOG(10) << "origin_outputs is : " << origin_outputs[i]->Name()
+                   << " ";
+          VLOG(10) << origin_outputs[i]
+                          ->var_->GetMutable<framework::LoDTensor>()
+                          ->data<float>()[0];
+          VLOG(10) << "outputs is : " << outputs[i]->Name() << " ";
+          VLOG(10) << outputs[i]
+                          ->var_->GetMutable<framework::LoDTensor>()
+                          ->data<float>()[0];
 #endif
           if (bck_map->find(origin_outputs[i]) != bck_map->end()) {
-            VLOG(2) << "add sub grad to " << origin_outputs[i]->Name();
+            VLOG(10) << "add sub grad to " << origin_outputs[i]->Name();
             bck_map->at(origin_outputs[i])
                 .second.emplace_back(
                     std::pair<int, VarBase*>(this->trace_id_, outputs[i]));
           } else {
-            VLOG(2) << "insert new map for " << origin_outputs[i]->Name();
+            VLOG(10) << "insert new map for " << origin_outputs[i]->Name();
             std::pair<platform::Place, std::vector<std::pair<int, VarBase*>>>
                 tmp(place_, {std::make_pair(this->trace_id_, outputs[i])});
             bck_map->insert(std::make_pair(origin_outputs[i], tmp));
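For reference, the sorted_sum_gradient_ branch above fills a map from each original output gradient to its place plus every pending partial gradient, keyed by the trace id of the producing op. A simplified stand-in for that bookkeeping; Place and VarBase are skeleton types here, and Track is a hypothetical name for the branch's insert-or-append logic.

#include <map>
#include <string>
#include <utility>
#include <vector>

struct Place {};
struct VarBase {
  std::string name;
};

// Same shape as Paddle's BackwardSumMap: origin grad -> (place, pending
// partial grads tagged with the trace id of the op that produced them).
using BackwardSumMap =
    std::map<VarBase*,
             std::pair<Place, std::vector<std::pair<int, VarBase*>>>>;

void Track(BackwardSumMap* bck_map, VarBase* origin, int trace_id,
           VarBase* partial, const Place& place) {
  auto it = bck_map->find(origin);
  if (it != bck_map->end()) {
    // Seen before: append one more pending partial gradient.
    it->second.second.emplace_back(trace_id, partial);
  } else {
    // First contribution: start a fresh entry for this output.
    bck_map->emplace(
        origin, std::make_pair(place, std::vector<std::pair<int, VarBase*>>{
                                          {trace_id, partial}}));
  }
}

int main() {
  BackwardSumMap bck_map;
  VarBase origin{"x@GRAD"}, partial{"x@GRAD@IGrad"};
  Track(&bck_map, &origin, 7, &partial, Place{});
  return 0;
}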
@@ -370,19 +384,19 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad(
           PADDLE_ENFORCE(grad_ref->at(origin_outputs[i]) >= 1,
                          "Backward error when calculate grad reference");
           if (grad_ref->at(origin_outputs[i]) > 1) {
-            VLOG(2) << "remove ref for " << origin_outputs[i]->Name();
+            VLOG(10) << "remove ref for " << origin_outputs[i]->Name();
             grad_ref->at(origin_outputs[i])--;
           } else {
-            VLOG(2) << "Add grad for: " << origin_outputs[i]->Name();
+            VLOG(10) << "Add grad for: " << origin_outputs[i]->Name();
             AddGradBySort(bck_map, origin_outputs[i]);
             grad_ref->at(origin_outputs[i])--;
           }
         } else {
           framework::Variable* grad = outputs[i]->var_.get();
           framework::Variable* orig_grad = origin_outputs[i]->var_.get();
-          VLOG(2) << "AddTo Called with orig_grad is: "
-                  << origin_outputs[i]->name_ << " Grad to be added is "
-                  << outputs[i]->name_;
+          VLOG(10) << "AddTo Called with orig_grad is: "
+                   << origin_outputs[i]->name_ << " Grad to be added is "
+                   << outputs[i]->name_;
           AddTo(grad, orig_grad, place_);
           delete outputs[i];
         }
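The branch above is reference counting in disguise: each op that contributes to a gradient decrements its count, and only the last contributor triggers the summation (AddGradBySort presumably orders the pending partials by trace id before adding them). A condensed, equivalent formulation of the decrement-then-flush decision; ShouldFlush is a hypothetical name, and Paddle writes the decrement in both branches instead.

#include <cassert>
#include <map>

struct VarBase {};  // opaque key, as in GradientRef

using GradientRef = std::map<VarBase*, int>;

// Equivalent to the diff's "if (count > 1) --count; else flush, --count":
// consume one contribution and report whether this was the last one.
bool ShouldFlush(GradientRef* grad_ref, VarBase* var) {
  assert(grad_ref->at(var) >= 1 &&
         "Backward error when calculate grad reference");
  --grad_ref->at(var);
  return grad_ref->at(var) == 0;  // last contributor sums everything up
}

int main() {
  VarBase v;
  GradientRef grad_ref{{&v, 2}};
  bool first = ShouldFlush(&grad_ref, &v);  // false: one more pending
  bool last = ShouldFlush(&grad_ref, &v);   // true: time to AddGradBySort
  return (last && !first) ? 0 : 1;
}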
@@ -413,6 +427,7 @@ void VarBase::RunBackward(const detail::BackwardStrategy& bck_stratedy) {
   if (!pre_op_) return;
   platform::RecordEvent record_event("Imperative Backward");
   VLOG(3) << "start backward";
+  grads_->InitBuffer();
   auto grads_t = grads_->var_->GetMutable<framework::LoDTensor>();
   operators::math::set_constant(
       *(platform::DeviceContextPool::Instance().Get(