Reset output vars' pre_op pointers when the op is destructed

revert-15953-remove_default_stream_task_1
minqiyang 6 years ago
parent cb85ee987b
commit ac88c62a5b

@ -158,9 +158,10 @@ class Autograd {
for (auto it : candidate->pre_ops_) {
for (OpBase* pre_op : it.second) {
if (!pre_op) continue;
VLOG(5) << "op dep " << candidate->op_desc_->Type() << " "
VLOG(5) << "op dep " << candidate->op_desc_->Type() << " trace id "
<< candidate->trace_id_ << " <---- " << it.first << " <---- "
<< pre_op->op_desc_->Type() << " " << pre_op->trace_id_;
<< pre_op->op_desc_->Type() << " trace id "
<< pre_op->trace_id_;
if (visited.find(pre_op) == visited.end()) {
visited.insert(pre_op);
queue.push_back(pre_op);

@ -119,23 +119,32 @@ class VarBase {
var_(var),
grads_(grad),
block_(nullptr),
persistable_(false),
stop_gradient_(stop_gradient),
pre_op_(nullptr),
pre_op_out_name_(),
pre_op_out_idx_(-1) {}
public:
virtual ~VarBase() {
if (block_) {
// LOG(ERROR) << "remove var " << name_;
if (block_ && !persistable_) {
block_->RemoveVar(name_);
}
if (var_) {
delete var_;
var_ = nullptr;
}
if (grads_) {
delete grads_;
grads_ = nullptr;
}
pre_op_ = nullptr;
pre_op_out_idx_ = -1;
}
inline OpBase* PreOp() const { return pre_op_; }
@ -148,6 +157,14 @@ class VarBase {
void RunBackward();
inline void ResetPreOp(OpBase* op) {
if (op == pre_op_) {
// clear pre_op info when op equals var's pre_op
pre_op_ = nullptr;
pre_op_out_idx_ = -1;
}
}
void TrackPreOp(OpBase* pre_op, const std::string& pre_op_out_name,
int pre_op_out_idx, bool pre_op_stop_gradient) {
pre_op_ = pre_op;
@ -188,6 +205,7 @@ class VarBase {
VarBase* grads_;
framework::BlockDesc* block_;
bool persistable_;
private:
bool stop_gradient_;
@ -210,13 +228,22 @@ class PYBIND11_HIDDEN OpBase {
backward_hooks_() {}
virtual ~OpBase() {
for (framework::OpDesc* desc : grad_op_descs_) {
delete desc;
// reset all output vars' pre op
for (auto iter : output_vars_) {
for (VarBase* var : iter.second) {
var->ResetPreOp(this);
}
}
// remove op desc from block desc
if (block_) {
block_->RemoveOpInternal(op_desc_);
}
// release resources
for (framework::OpDesc* desc : grad_op_descs_) {
delete desc;
}
}
std::map<std::string, std::vector<VarBase*>> ApplyGrad();

@ -76,7 +76,8 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
std::map<std::string, VarBase*> vars;
framework::OpDesc* op_desc = op->op_desc_;
VLOG(3) << "tracer tracing " << op_desc->Type();
VLOG(3) << "tracer tracing " << op_desc->Type() << " trace id "
<< op->trace_id_;
op_desc->InferShape(*block);
op_desc->InferVarType(block);
@ -99,11 +100,13 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
if (inp->PreOp() && !inp->IsStopGradient()) {
op->pre_ops_[it.first].push_back(inp->PreOp());
op->pre_ops_out_idx_[it.first].push_back(inp->PreOpOutIdx());
VLOG(3) << "add pre op " << inp->PreOp()->op_desc_->Type();
} else {
op->pre_ops_[it.first].push_back(nullptr);
}
VLOG(3) << "input vname " << inp->var_desc_->Name() << " "
<< inp->var_->IsInitialized();
<< inp->var_->IsInitialized() << " stop_gradient "
<< inp->IsStopGradient();
}
}

@ -180,6 +180,12 @@ PYBIND11_MODULE(core, m) {
self.block_ = block;
},
py::return_value_policy::reference)
.def_property(
"persistable",
[](const imperative::VarBase &self) { return self.persistable_; },
[](imperative::VarBase &self, const bool persistable) {
self.persistable_ = persistable;
})
.def_property(
"desc",
[](const imperative::VarBase &self) { return self.var_desc_; },

@ -386,6 +386,7 @@ class Variable(object):
self._ivar.desc = self.desc
self._ivar.block = block.desc
self._ivar.name = name
self._ivar.persistable = persistable
if persistable:
self.block.vars[name] = self
else:

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save