|
|
|
@ -119,23 +119,32 @@ class VarBase {
|
|
|
|
|
var_(var),
|
|
|
|
|
grads_(grad),
|
|
|
|
|
block_(nullptr),
|
|
|
|
|
persistable_(false),
|
|
|
|
|
stop_gradient_(stop_gradient),
|
|
|
|
|
pre_op_(nullptr),
|
|
|
|
|
pre_op_out_name_(),
|
|
|
|
|
pre_op_out_idx_(-1) {}
|
|
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
virtual ~VarBase() {
|
|
|
|
|
if (block_) {
|
|
|
|
|
// LOG(ERROR) << "remove var " << name_;
|
|
|
|
|
|
|
|
|
|
if (block_ && !persistable_) {
|
|
|
|
|
block_->RemoveVar(name_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (var_) {
|
|
|
|
|
delete var_;
|
|
|
|
|
var_ = nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (grads_) {
|
|
|
|
|
delete grads_;
|
|
|
|
|
grads_ = nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pre_op_ = nullptr;
|
|
|
|
|
pre_op_out_idx_ = -1;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
inline OpBase* PreOp() const { return pre_op_; }
|
|
|
|
@ -148,6 +157,14 @@ class VarBase {
|
|
|
|
|
|
|
|
|
|
void RunBackward();
|
|
|
|
|
|
|
|
|
|
inline void ResetPreOp(OpBase* op) {
|
|
|
|
|
if (op == pre_op_) {
|
|
|
|
|
// clear pre_op info when op equals to var's pre_op
|
|
|
|
|
pre_op_ = nullptr;
|
|
|
|
|
pre_op_out_idx_ = -1;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void TrackPreOp(OpBase* pre_op, const std::string& pre_op_out_name,
|
|
|
|
|
int pre_op_out_idx, bool pre_op_stop_gradient) {
|
|
|
|
|
pre_op_ = pre_op;
|
|
|
|
@ -188,6 +205,7 @@ class VarBase {
|
|
|
|
|
VarBase* grads_;
|
|
|
|
|
|
|
|
|
|
framework::BlockDesc* block_;
|
|
|
|
|
bool persistable_;
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
bool stop_gradient_;
|
|
|
|
@ -210,13 +228,22 @@ class PYBIND11_HIDDEN OpBase {
|
|
|
|
|
backward_hooks_() {}
|
|
|
|
|
|
|
|
|
|
virtual ~OpBase() {
|
|
|
|
|
for (framework::OpDesc* desc : grad_op_descs_) {
|
|
|
|
|
delete desc;
|
|
|
|
|
// reset all output vars' pre op
|
|
|
|
|
for (auto iter : output_vars_) {
|
|
|
|
|
for (VarBase* var : iter.second) {
|
|
|
|
|
var->ResetPreOp(this);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// remove op desc from block desc
|
|
|
|
|
if (block_) {
|
|
|
|
|
block_->RemoveOpInternal(op_desc_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// release resource
|
|
|
|
|
for (framework::OpDesc* desc : grad_op_descs_) {
|
|
|
|
|
delete desc;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::map<std::string, std::vector<VarBase*>> ApplyGrad();
|
|
|
|
|