@@ -14,6 +14,10 @@
 
 #include "paddle/fluid/imperative/tracer.h"
 
+#include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/platform/device_context.h"
+#include "paddle/fluid/platform/enforce.h"
+
 namespace paddle {
 namespace imperative {
@@ -31,16 +35,38 @@ void CreateGradOp(const framework::OpDesc& op_desc,
   *grad_op_desc = grad_op_descs[0].release();
 }
 
-void InitVar(framework::Variable* var, framework::Variable* grad_var) {
+void InitVar(framework::Variable* var, framework::Variable* grad_var,
+             platform::DeviceContext* dev_ctx) {
+  PADDLE_ENFORCE_NOT_NULL(dev_ctx,
+                          "Could not get valid device from forward op");
   auto& var_t = var->Get<framework::LoDTensor>();
-  float* data =
-      grad_var->GetMutable<framework::LoDTensor>()->mutable_data<float>(
-          var_t.dims(), platform::CPUPlace());
-  std::fill(data, data + var_t.numel(), 0.0);
+  grad_var->GetMutable<framework::LoDTensor>()->mutable_data<float>(
+      var_t.dims(), dev_ctx->GetPlace());
+  operators::math::set_constant(
+      *dev_ctx, grad_var->GetMutable<framework::LoDTensor>(), 0.0);
+}
+
+platform::Place GetExpectedPlace(platform::Place place, VarBasePtrMap inputs) {
+  platform::Place result = place;
+  for (auto it : inputs) {
+    for (VarBase* var : it.second) {
+      platform::Place tmp_place =
+          var->var_->Get<framework::LoDTensor>().place();
+      if (!platform::is_same_place(tmp_place, result)) {
+        PADDLE_THROW(
+            "Input variables should be kept on the same place: expected %s, "
+            "but got %s for input %s",
+            result, tmp_place, it.first);
+      }
+    }
+  }
+
+  return result;
 }
 
 void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
                    const VarBasePtrMap& outputs, framework::BlockDesc* block,
+                   const platform::Place expected_place,
                    const bool stop_gradient) {
   std::map<std::string, VarBase*> vars;
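
Note: the InitVar change is what makes gradient initialization device-aware. The old code always allocated the gradient buffer on the CPU and zero-filled it with std::fill over raw host memory, which breaks once the forward op has run on a GPU. Allocating at dev_ctx->GetPlace() and zero-filling through operators::math::set_constant keeps the gradient on the same device as the forward tensor, because set_constant dispatches on the device context. A minimal sketch of the pattern in isolation (assumes a Paddle source tree; the helper name ZeroGradLike is illustrative, not part of this change):

    #include "paddle/fluid/framework/lod_tensor.h"
    #include "paddle/fluid/operators/math/math_function.h"
    #include "paddle/fluid/platform/device_context.h"

    void ZeroGradLike(const paddle::framework::LoDTensor& src,
                      paddle::framework::LoDTensor* grad,
                      paddle::platform::DeviceContext* dev_ctx) {
      // Allocate the gradient on whatever place the forward op ran on,
      // instead of hard-coding platform::CPUPlace().
      grad->mutable_data<float>(src.dims(), dev_ctx->GetPlace());
      // set_constant picks the CPU or CUDA implementation from the
      // device context; std::fill would only ever touch host memory.
      paddle::operators::math::set_constant(*dev_ctx, grad, 0.0);
    }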
@@ -105,10 +131,11 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
   PADDLE_ENFORCE_NOT_NULL(op_kernel, "only support op with kernel");
 
   framework::Scope scope;
-  platform::CPUPlace place;
-  PreparedOp p = PreparedOp::Prepare(ctx, *op_kernel, place);
-  p.op.RuntimeInferShape(scope, place, ctx);
-  p.func(framework::ExecutionContext(p.op, scope, *p.dev_ctx, p.ctx));
+  op->place_ = GetExpectedPlace(expected_place, inputs);
+  PreparedOp prepared_op = PreparedOp::Prepare(ctx, *op_kernel, op->place_);
+  prepared_op.op.RuntimeInferShape(scope, op->place_, ctx);
+  prepared_op.func(framework::ExecutionContext(
+      prepared_op.op, scope, *prepared_op.dev_ctx, prepared_op.ctx));
 
   if (!stop_gradient) {
     framework::OpDesc* grad_op_desc;
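
Note: Trace no longer pins kernels to the CPU. The expected place comes in as a parameter, GetExpectedPlace checks every input against it, and the result is stored on the op (op->place_) so the backward pass reuses the same device. Mixed-place inputs fail fast with PADDLE_THROW rather than silently running the kernel somewhere unexpected. A rough sketch of that contract, with simplified stand-ins for platform::Place and VarBasePtrMap (nothing below is Paddle API):

    #include <map>
    #include <stdexcept>
    #include <string>
    #include <vector>

    struct Place { bool is_gpu; int device; };

    bool is_same_place(const Place& a, const Place& b) {
      return a.is_gpu == b.is_gpu && a.device == b.device;
    }

    // Mirrors GetExpectedPlace: every input must already live on the
    // expected place; otherwise the trace aborts.
    Place GetExpectedPlace(Place expected,
                           const std::map<std::string, std::vector<Place>>& inputs) {
      for (const auto& in : inputs) {
        for (const Place& p : in.second) {
          if (!is_same_place(p, expected)) {
            throw std::runtime_error("input '" + in.first +
                                     "' is not on the expected place");
          }
        }
      }
      return expected;
    }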
@@ -131,7 +158,8 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
     } else {
       VarBase* var = vars[var_it->second];
       if (!var->grads_->var_->IsInitialized()) {
-        InitVar(var->var_, var->grads_->var_);
+        InitVar(var->var_, var->grads_->var_,
+                prepared_op.GetDeviceContext());
       }
       // Douts.
       grad_in_vars.push_back(var->grads_->var_);
@@ -144,10 +172,13 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
     for (const std::string& grad_outvar : it.second) {
       block->FindRecursiveOrCreateVar(grad_outvar);
       auto var_it = grad_to_var->find(grad_outvar);
-      PADDLE_ENFORCE(var_it != grad_to_var->end());
+      PADDLE_ENFORCE(var_it != grad_to_var->end(),
+                     "Could not find the grad op output var; should this "
+                     "operator %s's stop gradient be True?",
+                     op_desc->Type());
       VarBase* var = vars[var_it->second];
       if (!var->grads_->var_->IsInitialized()) {
-        InitVar(var->var_, var->grads_->var_);
+        InitVar(var->var_, var->grads_->var_, prepared_op.GetDeviceContext());
       }
       grad_out_vars.push_back(var->grads_->var_);
     }
@@ -189,16 +220,23 @@ std::vector<VarBase*> Tracer::PyTrace(OpBase* op,
     for (VarBase* out : outputs) {
       grad_input_vars.push_back(out->var_);
     }
 
+    platform::CPUPlace place;
     for (VarBase* out : outputs) {
       grad_input_vars.push_back(out->grads_->var_);
       if (!grad_input_vars.back()->IsInitialized()) {
-        InitVar(out->var_, grad_input_vars.back());
+        // TODO(minqiyang): Add GPU support for PyLayer, only support CPU now
+        InitVar(out->var_, grad_input_vars.back(),
+                platform::DeviceContextPool::Instance().Get(place));
       }
     }
 
     for (const VarBase* inp : inputs) {
       grad_output_vars.push_back(inp->grads_->var_);
       if (!grad_output_vars.back()->IsInitialized()) {
-        InitVar(inp->var_, grad_output_vars.back());
+        // TODO(minqiyang): Add GPU support for PyLayer, only support CPU now
+        InitVar(inp->var_, grad_output_vars.back(),
+                platform::DeviceContextPool::Instance().Get(place));
       }
     }
   }
 }
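
Note: PyTrace is the one spot that stays CPU-only, as the TODO says: PyLayer gradients are always initialized through the pooled CPU device context, even when the rest of the trace runs on a GPU. DeviceContextPool is the standard way to reach a context for a given place; a minimal sketch (assumes a Paddle source tree; the helper name GetCpuContext is illustrative):

    #include "paddle/fluid/platform/device_context.h"

    paddle::platform::DeviceContext* GetCpuContext() {
      paddle::platform::CPUPlace place;
      // The pool keeps one long-lived context per place; the returned
      // pointer is borrowed, so the caller must not delete it.
      return paddle::platform::DeviceContextPool::Instance().Get(place);
    }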