|
|
|
@ -126,12 +126,19 @@ class VarBase {
|
|
|
|
|
: var_desc_(nullptr),
|
|
|
|
|
var_(var),
|
|
|
|
|
grads_(grad),
|
|
|
|
|
block_(nullptr),
|
|
|
|
|
stop_gradient_(stop_gradient),
|
|
|
|
|
pre_op_(nullptr),
|
|
|
|
|
pre_op_out_idx_(-1) {}
|
|
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
virtual ~VarBase() {
|
|
|
|
|
LOG(ERROR) << "remove var " << name_;
|
|
|
|
|
|
|
|
|
|
if (block_) {
|
|
|
|
|
block_->RemoveVar(name_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (var_) {
|
|
|
|
|
delete var_;
|
|
|
|
|
}
|
|
|
|
@ -189,11 +196,14 @@ class VarBase {
|
|
|
|
|
framework::Variable* var_;
|
|
|
|
|
VarBase* grads_;
|
|
|
|
|
|
|
|
|
|
framework::BlockDesc* block_;
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
bool stop_gradient_;
|
|
|
|
|
OpBase* pre_op_;
|
|
|
|
|
std::string pre_op_out_name_;
|
|
|
|
|
int pre_op_out_idx_;
|
|
|
|
|
std::string name_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
/* The wrapper for OpDesc which holds a OpDesc and a OpDesc of its
|
|
|
|
@ -212,6 +222,12 @@ class OpBase {
|
|
|
|
|
for (framework::OpDesc* desc : grad_op_descs_) {
|
|
|
|
|
delete desc;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
LOG(ERROR) << "remove op " << op_desc_->Type() << " id " << trace_id_;
|
|
|
|
|
|
|
|
|
|
if (block_) {
|
|
|
|
|
block_->RemoveOp(trace_id_, trace_id_ + 1);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::map<std::string, std::vector<VarBase*>> ApplyGrad();
|
|
|
|
|