@@ -115,6 +115,7 @@ framework::Variable* CreateVariable(const std::string& name,
     varname = string::Sprintf("%s@%d", varname, id);
   }
-  LOG(ERROR) << "creating var " << varname;
+  VLOG(3) << "creating var " << varname;
   framework::Variable* var = scope->Var(varname);
   framework::LoDTensor* tensor = var->GetMutable<framework::LoDTensor>();
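
The change above swaps LOG(ERROR) for VLOG(3), so the variable-creation message no longer prints by default: under glog (Paddle's logging backend), ERROR is always emitted, while VLOG(n) only fires when verbose logging is enabled at level n or higher. A minimal standalone sketch of that gating, assuming nothing beyond glog itself:

    #include <glog/logging.h>

    int main(int argc, char* argv[]) {
      google::InitGoogleLogging(argv[0]);
      LOG(ERROR) << "always emitted";  // ERROR severity is never gated
      VLOG(3) << "emitted only when verbosity >= 3";  // enable with --v=3 or GLOG_v=3
      return 0;
    }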
@@ -130,13 +131,22 @@ framework::LoDTensor& VarBase::Grad() {
 }

 void VarBase::ApplyGrad(framework::Scope* scope, Variable* grad) {
+  PADDLE_ENFORCE(grad->IsInitialized(), "grad %s must be initialized",
+                 var_desc_->Name());
+  PADDLE_ENFORCE(grad->Get<framework::LoDTensor>().IsInitialized(),
+                 "variable %s has NO gradient, please set stop_gradient to it",
+                 var_desc_->Name());
+  VLOG(3) << "apply var grad " << var_desc_->Name() << " "
+          << grad->Get<framework::LoDTensor>().data<float>()[0];
   if (!grads_) {
     grads_ =
         CreateVariable(string::Sprintf("%s@IGrad", var_desc_->Name()),
                        var_->Get<framework::LoDTensor>().dims(), 0.0, scope);
   }
   AddTo(grad, grads_);
+  VLOG(3) << "grad_ after apply var grad " << var_desc_->Name() << " "
+          << grads_->Get<framework::LoDTensor>().data<float>()[0];
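
VarBase::ApplyGrad accumulates rather than overwrites: the first incoming gradient lazily creates a zero-filled grads_ buffer (CreateVariable with val 0.0), and AddTo then sums each arriving gradient into it. A toy sketch of the same accumulate pattern in plain C++; std::vector stands in for LoDTensor, and GradSlot/Apply are illustrative names, not Paddle API:

    #include <cassert>
    #include <cstddef>
    #include <vector>

    struct GradSlot {
      std::vector<float> sum;  // plays the role of grads_

      void Apply(const std::vector<float>& grad) {
        if (sum.empty()) {
          sum.assign(grad.size(), 0.0f);  // lazy zero-init, like CreateVariable(..., 0.0, ...)
        }
        assert(sum.size() == grad.size());
        for (size_t i = 0; i < grad.size(); ++i) {
          sum[i] += grad[i];  // element-wise accumulate, like AddTo
        }
      }
    };

    int main() {
      GradSlot slot;
      slot.Apply({1.0f, 2.0f});
      slot.Apply({0.5f, 0.5f});  // slot.sum is now {1.5, 2.5}
      return 0;
    }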
@@ -153,8 +163,9 @@ std::vector<Variable*> OpBase::ApplyGrad(framework::Scope* scope) {
       // grad op inputs can be forward inputs, so not in grad_to_var.
       continue;
     }
-    VLOG(3) << "op grad in var " << grad_invar;
-    block_->FindRecursiveOrCreateVar(grad_invar);
+    VLOG(3) << "op grad input var " << grad_invar;
+    framework::VarDesc& grad_invar_desc =
+        block_->FindRecursiveOrCreateVar(grad_invar);
     framework::Variable* var = scope->Var(grad_invar);
     const std::string& invar = grad_to_var_->at(grad_invar);
     for (VarBase* varbase : *output_vars_) {
@@ -165,21 +176,33 @@ std::vector<Variable*> OpBase::ApplyGrad(framework::Scope* scope) {
         break;
       }
     }
+    grad_invar_desc.SetShape(
+        framework::vectorize(var->Get<framework::LoDTensor>().dims()));
+    VLOG(3)
+        << "set op grad var desc's shape size "
+        << framework::vectorize(var->Get<framework::LoDTensor>().dims()).size();
   }

+  LOG(ERROR) << "grad_op_desc_" << grad_op_desc_->Proto()->DebugString();

   for (const std::string& outvar : grad_op_desc_->OutputArgumentNames()) {
-    VLOG(3) << "grad outvar " << outvar;
+    VLOG(3) << "op grad output var " << outvar;
     block_->FindRecursiveOrCreateVar(outvar);
     framework::Variable* var = scope->Var(outvar);
     if (!var->IsInitialized()) {
+      VLOG(3) << "init op grad output var " << outvar;
       framework::VarDesc* var_desc = block_->FindVar(outvar);
       if (var_desc->GetType() == framework::proto::VarType::LOD_TENSOR) {
         var->GetMutable<framework::LoDTensor>();
         // framework::Tensor* tensor = var->GetMutable<framework::LoDTensor>();
         // tensor->mutable_data(platform::CPUPlace());
       } else {
         LOG(ERROR) << "tracer doesn't support yet";
       }
     }
+    VLOG(3) << "op grad output var " << outvar << " is inited";
   }

+  grad_op_desc_->InferShape(*block_);
+  grad_op_desc_->InferVarType(block_);

   std::unique_ptr<framework::OperatorBase> opbase =
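
A note on the input loop earlier in this function: grad op inputs that are reused forward inputs have no entry in grad_to_var_, which is why the code guards with find() and continues instead of calling at() directly, since at() throws on a missing key. A standalone sketch of that guard; std::map and all names here illustrate the lookup pattern, not Paddle's exact types:

    #include <iostream>
    #include <map>
    #include <string>

    int main() {
      std::map<std::string, std::string> grad_to_var{{"x@GRAD", "x"}};
      for (const std::string& key : {std::string("x@GRAD"), std::string("y")}) {
        auto it = grad_to_var.find(key);
        if (it == grad_to_var.end()) {
          std::cout << key << ": forward input, skip\n";  // mirrors the `continue` above
          continue;
        }
        std::cout << key << " -> " << it->second << "\n";  // safe; at(key) would also work here
      }
      return 0;
    }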
@@ -194,11 +217,15 @@ std::vector<Variable*> OpBase::ApplyGrad(framework::Scope* scope) {
     VarBase* origin_var = (*input_vars_)[i];
     for (const std::string& outvar : grad_op_desc_->OutputArgumentNames()) {
       Variable* var = scope->FindVar(outvar);
-      std::string orig_var = grad_to_var_->at(outvar);
-      if (origin_var->var_desc_->Name() != orig_var) {
+      if (var->IsInitialized()) {
+        VLOG(3) << "get grad op output var " << outvar;
+      }
+      std::string orig_var_name = grad_to_var_->at(outvar);
+      if (origin_var->var_desc_->Name() != orig_var_name ||
+          origin_var->stop_gradient_) {
         continue;
       }
-      VLOG(3) << "apply grad " << outvar << " with origin " << orig_var;
+      VLOG(3) << "apply grad " << outvar << " with origin " << orig_var_name;
       origin_var->ApplyGrad(scope, var);
       found = true;
       ret.push_back(var);
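
The origin_var->stop_gradient_ test added above short-circuits accumulation: a grad op output whose origin variable is marked stop_gradient is skipped, so nothing is ever summed into its grads_. A toy illustration of the filter; Var and its fields are hypothetical stand-ins for VarBase:

    #include <iostream>
    #include <string>
    #include <vector>

    struct Var {
      std::string name;
      bool stop_gradient;
    };

    int main() {
      std::vector<Var> origins{{"weight", false}, {"label", true}};
      for (const Var& v : origins) {
        if (v.stop_gradient) continue;  // same short-circuit as in OpBase::ApplyGrad
        std::cout << "apply grad to " << v.name << "\n";
      }
      return 0;
    }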