|
|
|
@ -43,24 +43,31 @@ void CreateGradOp(const framework::OpDesc& op_desc,
|
|
|
|
|
|
|
|
|
|
class Tracer {
|
|
|
|
|
public:
|
|
|
|
|
Tracer() {}
|
|
|
|
|
explicit Tracer(framework::BlockDesc* root_block) : root_block_(root_block) {
|
|
|
|
|
root_scope_ = new framework::Scope();
|
|
|
|
|
scopes_[root_block_] = root_scope_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
virtual ~Tracer() { delete root_scope_; }
|
|
|
|
|
|
|
|
|
|
void Trace(OpBase* op, const std::vector<VarBase*>& inputs,
|
|
|
|
|
const std::vector<VarBase*>& outputs) {
|
|
|
|
|
const std::vector<VarBase*>& outputs,
|
|
|
|
|
framework::BlockDesc* block) {
|
|
|
|
|
framework::Scope* scope = GetScope(block);
|
|
|
|
|
framework::OpDesc* op_desc = op->op_desc_;
|
|
|
|
|
LOG(ERROR) << "tracer tracing " << op_desc->Type();
|
|
|
|
|
op_desc->InferShape(*block_);
|
|
|
|
|
op_desc->InferVarType(block_);
|
|
|
|
|
op_desc->InferShape(*block);
|
|
|
|
|
op_desc->InferVarType(block);
|
|
|
|
|
std::unique_ptr<framework::OperatorBase> op_base =
|
|
|
|
|
framework::OpRegistry::CreateOp(*op_desc);
|
|
|
|
|
|
|
|
|
|
*op->input_vars_ = inputs;
|
|
|
|
|
for (VarBase* input : inputs) {
|
|
|
|
|
const std::string vname = input->var_desc_->Name();
|
|
|
|
|
framework::Variable* var = scope_->Var(vname);
|
|
|
|
|
framework::Variable* var = scope->Var(vname);
|
|
|
|
|
input->var_ = var;
|
|
|
|
|
if (!var->IsInitialized()) {
|
|
|
|
|
framework::VarDesc* var_desc = block_->FindVar(vname);
|
|
|
|
|
framework::VarDesc* var_desc = block->FindVar(vname);
|
|
|
|
|
if (var_desc->GetType() == framework::proto::VarType::LOD_TENSOR) {
|
|
|
|
|
var->GetMutable<framework::LoDTensor>();
|
|
|
|
|
} else {
|
|
|
|
@ -78,9 +85,9 @@ class Tracer {
|
|
|
|
|
*op->output_vars_ = outputs;
|
|
|
|
|
for (size_t i = 0; i < outputs.size(); ++i) {
|
|
|
|
|
const std::string vname = outputs[i]->var_desc_->Name();
|
|
|
|
|
framework::Variable* var = scope_->Var(vname);
|
|
|
|
|
framework::Variable* var = scope->Var(vname);
|
|
|
|
|
if (!var->IsInitialized()) {
|
|
|
|
|
framework::VarDesc* var_desc = block_->FindVar(vname);
|
|
|
|
|
framework::VarDesc* var_desc = block->FindVar(vname);
|
|
|
|
|
if (var_desc->GetType() == framework::proto::VarType::LOD_TENSOR) {
|
|
|
|
|
var->GetMutable<framework::LoDTensor>();
|
|
|
|
|
} else {
|
|
|
|
@ -91,28 +98,30 @@ class Tracer {
|
|
|
|
|
outputs[i]->pre_op_ = op;
|
|
|
|
|
outputs[i]->pre_op_out_idx_ = i;
|
|
|
|
|
}
|
|
|
|
|
op_base->Run(*scope_, platform::CPUPlace());
|
|
|
|
|
|
|
|
|
|
op_base->Run(*scope, platform::CPUPlace());
|
|
|
|
|
framework::OpDesc* grad_op_desc;
|
|
|
|
|
auto grad_to_var = new std::unordered_map<std::string, std::string>();
|
|
|
|
|
CreateGradOp(*op_desc, {}, {block_}, &grad_op_desc, grad_to_var);
|
|
|
|
|
CreateGradOp(*op_desc, {}, {block}, &grad_op_desc, grad_to_var);
|
|
|
|
|
op->grad_op_desc_ = grad_op_desc;
|
|
|
|
|
op->grad_to_var_ = grad_to_var;
|
|
|
|
|
op->block_ = block_;
|
|
|
|
|
op->block_ = block;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SetScope(framework::Scope* scope) { scope_ = scope; }
|
|
|
|
|
|
|
|
|
|
void SetBlock(framework::BlockDesc* block) { block_ = block; }
|
|
|
|
|
|
|
|
|
|
framework::Scope* Scope() const { return scope_; }
|
|
|
|
|
|
|
|
|
|
framework::BlockDesc* Block() const { return block_; }
|
|
|
|
|
framework::Scope* GetScope(framework::BlockDesc* block) {
|
|
|
|
|
if (scopes_.find(block) != scopes_.end()) {
|
|
|
|
|
return scopes_.at(block);
|
|
|
|
|
}
|
|
|
|
|
framework::BlockDesc* parent_block = block->ParentBlock();
|
|
|
|
|
PADDLE_ENFORCE(scopes_.find(parent_block) != scopes_.end());
|
|
|
|
|
framework::Scope* scope = &scopes_[parent_block]->NewScope();
|
|
|
|
|
scopes_[block] = scope;
|
|
|
|
|
return scope;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
framework::BlockDesc* block_;
|
|
|
|
|
framework::Scope* scope_;
|
|
|
|
|
std::vector<Runnable*> runnables_;
|
|
|
|
|
std::map<framework::BlockDesc*, framework::Scope*> scopes_;
|
|
|
|
|
framework::BlockDesc* root_block_;
|
|
|
|
|
framework::Scope* root_scope_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace imperative
|
|
|
|
|