Merge pull request #9975 from reyoung/feature/VarHandleCtor

Using constructor for VarHandle
wangkuiyi-patch-2
Yu Yang 7 years ago committed by GitHub
commit 9b60d0decb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -77,14 +77,9 @@ struct TestBroadcastOpHandle {
local_scopes_[input_scope_idx]->Var("input"); local_scopes_[input_scope_idx]->Var("input");
op_handle_.reset(new BroadcastOpHandle(local_scopes_, gpu_list_)); op_handle_.reset(new BroadcastOpHandle(local_scopes_, gpu_list_));
auto* in_var_handle =
vars_.emplace_back(new VarHandle()); new VarHandle(1, input_scope_idx, "input", gpu_list_[input_scope_idx]);
VarHandle* in_var_handle = static_cast<VarHandle*>(vars_.back().get()); vars_.emplace_back(in_var_handle);
in_var_handle->place_ = gpu_list_[input_scope_idx];
in_var_handle->name_ = "input";
in_var_handle->version_ = 1;
in_var_handle->scope_idx_ = input_scope_idx;
in_var_handle->generated_op_ = nullptr;
op_handle_->AddInput(in_var_handle); op_handle_->AddInput(in_var_handle);
// add dummy var // add dummy var
@ -96,12 +91,8 @@ struct TestBroadcastOpHandle {
for (size_t j = 0; j < gpu_list_.size(); ++j) { for (size_t j = 0; j < gpu_list_.size(); ++j) {
op_handle_->dev_ctxes_[gpu_list_[j]] = ctxs_[j].get(); op_handle_->dev_ctxes_[gpu_list_[j]] = ctxs_[j].get();
vars_.emplace_back(new VarHandle()); VarHandle* out_var_handle = new VarHandle(2, j, "out", gpu_list_[j]);
VarHandle* out_var_handle = static_cast<VarHandle*>(vars_.back().get()); vars_.emplace_back(out_var_handle);
out_var_handle->place_ = gpu_list_[j];
out_var_handle->name_ = "out";
out_var_handle->version_ = 2;
out_var_handle->scope_idx_ = j;
op_handle_->AddOutput(out_var_handle); op_handle_->AddOutput(out_var_handle);
} }

@ -79,13 +79,8 @@ struct TestGatherOpHandle {
// add input // add input
for (size_t j = 0; j < gpu_list_.size(); ++j) { for (size_t j = 0; j < gpu_list_.size(); ++j) {
op_handle_->dev_ctxes_[gpu_list_[j]] = ctxs_[j].get(); op_handle_->dev_ctxes_[gpu_list_[j]] = ctxs_[j].get();
vars_.emplace_back(new VarHandle()); auto* in_var_handle = new VarHandle(1, j, "input", gpu_list_[j]);
VarHandle* in_var_handle = static_cast<VarHandle*>(vars_.back().get()); vars_.emplace_back(in_var_handle);
in_var_handle->place_ = gpu_list_[j];
in_var_handle->name_ = "input";
in_var_handle->version_ = 1;
in_var_handle->scope_idx_ = j;
in_var_handle->generated_op_ = nullptr;
op_handle_->AddInput(in_var_handle); op_handle_->AddInput(in_var_handle);
} }
@ -97,12 +92,9 @@ struct TestGatherOpHandle {
op_handle_->AddInput(in_dummy_var_handle); op_handle_->AddInput(in_dummy_var_handle);
// add output // add output
vars_.emplace_back(new VarHandle()); auto* out_var_handle =
VarHandle* out_var_handle = static_cast<VarHandle*>(vars_.back().get()); new VarHandle(2, input_scope_idx, "out", gpu_list_[input_scope_idx]);
out_var_handle->place_ = gpu_list_[input_scope_idx]; vars_.emplace_back(out_var_handle);
out_var_handle->name_ = "out";
out_var_handle->version_ = 2;
out_var_handle->scope_idx_ = input_scope_idx;
op_handle_->AddOutput(out_var_handle); op_handle_->AddOutput(out_var_handle);
// add dummy var // add dummy var

@ -177,13 +177,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
auto &prev_grad = vars[vars.size() - 1]; auto &prev_grad = vars[vars.size() - 1];
op_handle->AddInput(prev_grad.get()); op_handle->AddInput(prev_grad.get());
vars.emplace_back(new VarHandle); auto var = new VarHandle(vars.size() - 1, i, og, p);
auto &var = vars.back(); vars.emplace_back(var);
var->place_ = p; op_handle->AddOutput(var);
var->name_ = og;
var->version_ = vars.size() - 1;
op_handle->AddOutput(var.get());
} }
#else #else
PADDLE_ENFORCE("Not implemented"); PADDLE_ENFORCE("Not implemented");

@ -54,13 +54,8 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
auto &var_holder = var_holders[each_var_name]; auto &var_holder = var_holders[each_var_name];
VarHandle *var = nullptr; VarHandle *var = nullptr;
if (var_holder.empty()) { if (var_holder.empty()) {
var_holder.emplace_back(new VarHandle); var = new VarHandle(0, place_offset, each_var_name, place);
auto &init_var = var_holder[0]; var_holder.emplace_back(var);
init_var->place_ = place;
init_var->name_ = each_var_name;
init_var->generated_op_ = nullptr;
init_var->version_ = 0;
var = init_var.get();
} else { } else {
var = var_holder.rbegin()->get(); var = var_holder.rbegin()->get();
} }
@ -73,12 +68,9 @@ void SSAGraphBuilder::CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,
size_t place_offset) { size_t place_offset) {
auto &vars = graph->vars_[place_offset][each_var_name]; auto &vars = graph->vars_[place_offset][each_var_name];
size_t version = vars.size(); size_t version = vars.size();
vars.emplace_back(new VarHandle()); auto var = new VarHandle(version, place_offset, each_var_name, place);
auto &var = vars.back(); vars.emplace_back(var);
var->version_ = version; op_handle->AddOutput(var);
var->name_ = each_var_name;
var->place_ = place;
op_handle->AddOutput(var.get());
} }
template <typename Callback> template <typename Callback>

@ -16,6 +16,7 @@
#include <sstream> #include <sstream>
#include <string> #include <string>
#include <unordered_set> #include <unordered_set>
#include <utility>
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
@ -33,10 +34,10 @@ struct VarHandleBase {
// The operator who generate this variable. nullptr if the variable // The operator who generate this variable. nullptr if the variable
// is a root node. // is a root node.
OpHandleBase *generated_op_; OpHandleBase* generated_op_{nullptr};
// Operators which depend on this variable ready. // Operators which depend on this variable ready.
std::unordered_set<OpHandleBase *> pending_ops_; std::unordered_set<OpHandleBase*> pending_ops_;
}; };
// VarHandle is actually a single version of Runtime Variable. // VarHandle is actually a single version of Runtime Variable.
@ -47,6 +48,13 @@ struct VarHandleBase {
struct VarHandle : public VarHandleBase { struct VarHandle : public VarHandleBase {
std::string DebugString() const override; std::string DebugString() const override;
VarHandle(size_t version, size_t scope_index, std::string name,
platform::Place place)
: version_(version),
scope_idx_(scope_index),
name_(std::move(name)),
place_(std::move(place)) {}
// version field currently is not used, however, just store the version to // version field currently is not used, however, just store the version to
// debug easily. // debug easily.
size_t version_; size_t version_;

Loading…
Cancel
Save