Feature/buffer_shared_inplace (#17911)

* feature/buffer_shared_inplace, test=develop

* refine code, test=develop

* fix elementwise_add op cpu inplace and sum inplace bug, test=develop

* add unittest and debug log, test=develop

* fix parallel_executor scope bug, polish code, test=develop

* fix sum op, activation op, single_in_place_inference bug, test=develop

* remove kLocalExecScopeName, test=develop

* fix unittest, test=develop

* fix out_var first version bug, test=develop

* follow comments (sum_op), test=develop
Zeng Jinle 6 years ago committed by GitHub
parent 1c10dac4f2
commit d3003a1620
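The recurring change in the hunks below is mechanical: op handles used to recover each device's execution scope at run time by reading the magic kLocalExecScopeName variable out of local_scopes_[i]; after this patch the already-resolved execution scopes are kept in a local_exec_scopes_ vector and indexed directly. A minimal, self-contained sketch of that pattern follows; the Scope and op-handle types are toy stand-ins, and the real SetLocalExecScopes takes a scope map rather than a plain vector (see the op handle tests further down).

#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

// Toy stand-in for framework::Scope; not Paddle's real type.
struct Scope {
  std::unordered_map<std::string, int> vars;  // variable name -> dummy payload
  int *FindVar(const std::string &name) {
    auto it = vars.find(name);
    return it == vars.end() ? nullptr : &it->second;
  }
};

// After this patch an op handle keeps the pre-resolved execution scopes, so
// RunImpl() indexes local_exec_scopes_ instead of doing a per-run
// FindVar(kLocalExecScopeName) lookup on local_scopes_.
class ToyOpHandle {
 public:
  void SetLocalExecScopes(const std::vector<Scope *> &exec_scopes) {
    local_exec_scopes_ = exec_scopes;
  }
  int *FindInExecScope(size_t scope_idx, const std::string &name) {
    return local_exec_scopes_.at(scope_idx)->FindVar(name);
  }

 private:
  std::vector<Scope *> local_exec_scopes_;
};

int main() {
  Scope exec_scope;
  exec_scope.vars["out"] = 42;
  ToyOpHandle op;
  op.SetLocalExecScopes({&exec_scope});
  assert(*op.FindInExecScope(0, "out") == 42);
  return 0;
}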

@ -59,7 +59,9 @@ cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope d
cc_library(eager_deletion_op_handle SRCS eager_deletion_op_handle.cc DEPS lod_tensor selected_rows reference_count_pass_helper)
set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass memory_optimize_pass inplace_op_pass)
cc_library(share_tensor_buffer_op_handle SRCS share_tensor_buffer_op_handle.cc DEPS op_handle_base scope)
set(SSA_GRAPH_EXECUTOR_DEPS graph framework_proto sequential_execution_pass modify_op_lock_and_record_event_pass all_reduce_deps_pass reference_count_pass eager_deletion_pass memory_optimize_pass inplace_op_pass buffer_shared_inplace_op_pass)
if (WITH_GPU)
list(APPEND SSA_GRAPH_EXECUTOR_DEPS reference_count_pass)
endif()

@ -99,10 +99,9 @@ void AllReduceOpHandle::RunImpl() {
std::vector<const LoDTensor *> lod_tensors;
for (size_t i = 0; i < local_scopes_.size(); ++i) {
auto *s = local_scopes_[i];
auto &local_scope = *s->FindVar(kLocalExecScopeName)->Get<Scope *>();
auto &local_scope = local_exec_scopes_[i];
auto &lod_tensor =
local_scope.FindVar(in_var_handles[i]->name())->Get<LoDTensor>();
local_scope->FindVar(in_var_handles[i]->name())->Get<LoDTensor>();
lod_tensors.emplace_back(&lod_tensor);
VLOG(10) << "place:" << i << ", input_name:" << in_var_handles[i]->name()
<< ", out_name:" << out_var_handles[i]->name();
@ -140,9 +139,7 @@ void AllReduceOpHandle::RunImpl() {
PADDLE_THROW("Not compiled with CUDA");
#endif
} else { // Special handle CPU only Operator's gradient. Like CRF
auto &trg = *this->local_scopes_[0]
->FindVar(kLocalExecScopeName)
->Get<Scope *>()
auto &trg = *this->local_exec_scopes_[0]
->FindVar(out_var_handles[0]->name())
->GetMutable<framework::LoDTensor>();
@ -151,10 +148,9 @@ void AllReduceOpHandle::RunImpl() {
VisitDataType(lod_tensors[0]->type(), func);
for (size_t i = 1; i < local_scopes_.size(); ++i) {
auto &scope =
*local_scopes_[i]->FindVar(kLocalExecScopeName)->Get<Scope *>();
auto &scope = local_exec_scopes_[i];
auto &p = places_[i];
auto *var = scope.FindVar(out_var_handles[i]->name());
auto *var = scope->FindVar(out_var_handles[i]->name());
auto *dev_ctx = dev_ctxes_.at(p);
RunAndRecordEvent(p, [&trg, var, dev_ctx, p] {

@ -49,6 +49,9 @@ class AllReduceOpHandle : public OpHandleBase {
protected:
void RunImpl() override;
std::vector<Scope *> GetLocalScopes() override { return local_scopes_; }
std::vector<Scope *> local_scopes_;
#if !(defined(PADDLE_WITH_CUDA) && !defined(_WIN32))

@ -24,22 +24,20 @@ namespace paddle {
namespace framework {
namespace details {
inline void NewTempScopeAndInitVars(const std::vector<VarInfo> &var_infos,
Scope *scope) {
VLOG(3) << "NewTempScopeAndInitVars";
Scope &local_scope = scope->NewScope();
*scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>() =
&local_scope;
inline void InitVarsInScope(const std::vector<VarInfo> &var_infos, Scope *scope,
Scope *local_scope) {
VLOG(3) << "InitVarsInScope";
for (auto &info : var_infos) {
if (scope->FindVar(info.name_) != nullptr) {
continue;
}
if (info.persistable_) { // Persistable
auto *var = scope->FindVar(info.name_);
if (var != nullptr) {
VLOG(2) << info.name_
<< " has been initialized beforehand in global scope, skipped";
continue;
}
InitializeVariable(scope->Var(info.name_), info.type_);
} else {
InitializeVariable(local_scope.Var(info.name_), info.type_);
InitializeVariable(local_scope->Var(info.name_), info.type_);
}
}
}
@ -101,14 +99,17 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
const std::vector<Scope *> &local_exec_scopes,
const std::vector<platform::Place> &places, std::vector<ir::Graph *> graphs)
: strategy_(std::move(strategy)),
local_scopes_(std::move(local_scopes)),
local_exec_scopes_(local_exec_scopes),
pool_(places.size() >= 2 ? new ::ThreadPool(places.size()) : nullptr),
places_(std::move(places)),
graphs_(std::move(graphs)) {
VLOG(3) << "build AsyncSSAGraphExecutor";
PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
PADDLE_ENFORCE_EQ(local_scopes_.size(), local_exec_scopes_.size());
// set the correct size of thread pool to each device.
strategy_.num_threads_ = strategy_.num_threads_ < places_.size()
@ -118,7 +119,8 @@ AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
<< " to run the operators of the graph on each device.";
for (size_t i = 0; i < places.size(); ++i) {
executors_.emplace_back(new details::ThreadedSSAGraphExecutor(
strategy_, {local_scopes_[i]}, {places_[i]}, graphs_[i]));
strategy_, {local_scopes_[i]}, {local_exec_scopes_[i]}, {places_[i]},
graphs_[i]));
}
for (auto &node : graphs_[0]->Nodes()) {
@ -129,8 +131,9 @@ AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
var_infos_.back().persistable_ = node->Var()->Persistable();
}
}
for (auto *scope : local_scopes_) {
NewTempScopeAndInitVars(var_infos_, scope);
for (size_t i = 0; i < local_scopes_.size(); ++i) {
InitVarsInScope(var_infos_, local_scopes_[i], local_exec_scopes_[i]);
}
ProcessGraph(graphs_, local_scopes_[0]);
}
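The new InitVarsInScope keeps the old split: persistable variables live in the shared global scope and are created only once, while everything else is created in the per-device execution scope that is now passed in explicitly. A toy illustration of that split, using simplified stand-in types rather than Paddle's Scope and VarInfo:

#include <iostream>
#include <set>
#include <string>
#include <vector>

// Simplified stand-ins; the real VarInfo/Scope come from the framework.
struct VarInfo {
  std::string name_;
  bool persistable_;
};

struct ToyScope {
  std::set<std::string> vars;
  bool Has(const std::string &n) const { return vars.count(n) > 0; }
  void Create(const std::string &n) { vars.insert(n); }
};

// Persistable vars go into the global scope (skipping ones already created
// there); non-persistable vars go into the per-device execution scope.
void InitVarsInScope(const std::vector<VarInfo> &infos, ToyScope *scope,
                     ToyScope *local_scope) {
  for (const auto &info : infos) {
    if (info.persistable_) {
      if (scope->Has(info.name_)) continue;  // initialized beforehand, skip
      scope->Create(info.name_);
    } else {
      local_scope->Create(info.name_);
    }
  }
}

int main() {
  ToyScope global, exec;
  InitVarsInScope({{"weight", true}, {"tmp", false}}, &global, &exec);
  std::cout << global.Has("weight") << " " << exec.Has("tmp") << std::endl;  // 1 1
}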

@ -36,6 +36,7 @@ class AsyncSSAGraphExecutor : public SSAGraphExecutor {
public:
AsyncSSAGraphExecutor(const ExecutionStrategy &strategy,
const std::vector<Scope *> &local_scopes,
const std::vector<Scope *> &local_exec_scopes,
const std::vector<platform::Place> &places,
std::vector<ir::Graph *> graphs);
~AsyncSSAGraphExecutor() final = default;
@ -50,6 +51,7 @@ class AsyncSSAGraphExecutor : public SSAGraphExecutor {
private:
ExecutionStrategy strategy_;
std::vector<Scope *> local_scopes_;
std::vector<Scope *> local_exec_scopes_;
std::unique_ptr<::ThreadPool> pool_{nullptr};
std::vector<platform::Place> places_;
std::vector<ir::Graph *> graphs_;

@ -40,18 +40,13 @@ void BroadcastOpHandle::RunImpl() {
WaitInputVarGenerated();
std::vector<const Scope *> var_scopes;
for (auto *s : local_scopes_) {
var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get<Scope *>());
}
BroadcastOneVar(*in_var_handle, out_var_handles, var_scopes);
BroadcastOneVar(*in_var_handle, out_var_handles, local_exec_scopes_);
}
void BroadcastOpHandle::BroadcastOneVar(
const VarHandle &in_var_handle,
const std::vector<VarHandle *> &out_var_handles,
const std::vector<const Scope *> &var_scopes) {
const std::vector<Scope *> &var_scopes) {
auto *in_var =
var_scopes.at(in_var_handle.scope_idx())->FindVar(in_var_handle.name());
PADDLE_ENFORCE_NOT_NULL(in_var);
@ -140,10 +135,7 @@ void BroadcastOpHandle::BroadcastOneVar(
void BroadcastOpHandle::InitOutputValue(
const VarHandle &in_var_handle,
const std::vector<VarHandle *> &out_var_handles) const {
std::vector<const Scope *> var_scopes;
for (auto *s : local_scopes_) {
var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get<Scope *>());
}
auto &var_scopes = local_exec_scopes_;
auto *in_var =
var_scopes.at(in_var_handle.scope_idx())->FindVar(in_var_handle.name());

@ -62,9 +62,11 @@ struct BroadcastOpHandle : public OpHandleBase {
protected:
void RunImpl() override;
std::vector<Scope *> GetLocalScopes() override { return local_scopes_; }
void BroadcastOneVar(const VarHandle &in_var_handle,
const std::vector<VarHandle *> &out_var_handles,
const std::vector<const Scope *> &var_scopes);
const std::vector<Scope *> &var_scopes);
std::vector<Scope *> local_scopes_;
std::vector<platform::Place> places_;

@ -14,7 +14,9 @@
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "gtest/gtest.h"
@ -92,14 +94,13 @@ struct TestBroadcastOpHandle {
void InitBroadcastOp(size_t input_scope_idx) {
nodes_.clear();
std::unordered_map<Scope*, Scope*> scope_map;
for (size_t j = 0; j < place_list_.size(); ++j) {
local_scopes_.push_back(&(g_scope_.NewScope()));
Scope& local_scope = local_scopes_.back()->NewScope();
*local_scopes_.back()
->Var(details::kLocalExecScopeName)
->GetMutable<Scope*>() = &local_scope;
local_scope.Var("out");
param_scopes_.emplace_back(&local_scope);
scope_map.emplace(local_scopes_.back(), param_scopes_.back());
}
param_scopes_[input_scope_idx]->Var("input");
@ -122,6 +123,8 @@ struct TestBroadcastOpHandle {
#endif
}
op_handle_->SetLocalExecScopes(scope_map);
nodes_.emplace_back(
ir::CreateNodeForTest("node1", ir::Node::Type::kVariable));
auto* in_var_handle = new VarHandle(nodes_.back().get(), 1, input_scope_idx,

@ -92,16 +92,14 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
AppendPass("fuse_relu_depthwise_conv_pass");
}
// NOTE(dzhwinter): A note for automatical inplace.
// 1. modify program desc passes should put
// before inplace pass.
// 2. manually configured inplace should put
// before inplace_pass
// Add automatically inplace.
if (strategy_.enable_inplace_) {
VLOG(1) << "Add inplace_pass";
AppendPass("inplace_pass");
// TODO(zjl): refactor MemoryOptimizePass to fit
// new strategy, which does not need to set
// var.persistable = True
if (strategy_.use_legacy_memory_optimize_strategy_) {
if (strategy_.enable_inplace_) {
VLOG(5) << "Add inplace_pass";
AppendPass("inplace_pass");
}
}
if (strategy_.fuse_elewise_add_act_ops_) {
@ -160,9 +158,11 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// the de-fact IR, any reuse on Graph is meaningless.
// A side-effect of that, memory optimize cannot forsee the fetched vars
// , so fetchlist should be set persistable before call the Run interface.
if (strategy_.memory_optimize_) {
VLOG(1) << "Add memory_optimize_pass";
AppendPass("memory_optimize_pass");
if (strategy_.use_legacy_memory_optimize_strategy_) {
if (strategy_.memory_optimize_) {
VLOG(5) << "Add memory_optimize_pass";
AppendPass("memory_optimize_pass");
}
}
// runtime_context_cache pass should be the last pass to enable the attr of

@ -114,7 +114,12 @@ struct BuildStrategy {
// it is not appropriate, because kStaleProgramOpDescs will be removed in the
// near future.
bool memory_optimize_{false};
bool enable_inplace_{false};
// Turn on inplace by default.
bool enable_inplace_{true};
// TODO(zjl): Remove this flag when MemoryOptimizePass is refactored
bool use_legacy_memory_optimize_strategy_{false};
// FIXME(zcd): is_distribution_ is a temporary field, because in pserver mode,
// num_trainers is 1, so the current fields of build_strategy doesn't tell if
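Taken together with the pass-builder hunk above, the strategy flags now behave as sketched below: enable_inplace_ defaults to on, but the legacy inplace_pass and memory_optimize_pass are only appended when use_legacy_memory_optimize_strategy_ is set. This is a trimmed, stand-alone illustration of that gating, not the real ParallelExecutorPassBuilder; the fallback to the new buffer_shared_inplace machinery is implied by the rest of the PR rather than shown in these hunks.

#include <iostream>

// Trimmed stand-in for the BuildStrategy fields shown above.
struct ToyBuildStrategy {
  bool memory_optimize_{false};
  bool enable_inplace_{true};  // now on by default
  bool use_legacy_memory_optimize_strategy_{false};
};

// Mirrors the gating in ParallelExecutorPassBuilder: the legacy passes are
// appended only when the legacy strategy is explicitly requested.
void DescribeMemoryPasses(const ToyBuildStrategy &s) {
  if (s.use_legacy_memory_optimize_strategy_) {
    if (s.enable_inplace_) std::cout << "append inplace_pass\n";
    if (s.memory_optimize_) std::cout << "append memory_optimize_pass\n";
  } else {
    // New default path: the buffer_shared_inplace pass added elsewhere in
    // this PR is expected to handle in-place reuse (not appended here).
    std::cout << "legacy memory-optimize passes skipped\n";
  }
}

int main() {
  ToyBuildStrategy s;
  DescribeMemoryPasses(s);  // default: legacy passes skipped
  s.use_legacy_memory_optimize_strategy_ = true;
  DescribeMemoryPasses(s);  // legacy: inplace_pass appended
}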

@ -31,9 +31,7 @@ ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope,
void ComputationOpHandle::RunImpl() {
WaitInputVarGenerated(place_);
auto run_func = [this]() {
op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(), place_);
};
auto run_func = [this]() { op_->Run(*local_exec_scopes_[0], place_); };
if (is_lock_and_record_event_free_) {
run_func();

@ -38,6 +38,8 @@ class ComputationOpHandle : public OpHandleBase {
const Scope *GetScope() const { return scope_; }
Scope *GetScope() { return scope_; }
const platform::Place &GetPlace() const { return place_; }
void SetLockAndRecordEventFree(bool b) { is_lock_and_record_event_free_ = b; }
@ -49,6 +51,8 @@ class ComputationOpHandle : public OpHandleBase {
bool NeedWait(VarHandleBase *in_var) override;
std::vector<Scope *> GetLocalScopes() override { return {scope_}; }
private:
std::unique_ptr<OperatorBase> op_;
Scope *scope_;

@ -17,6 +17,7 @@
#include <utility>
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
@ -30,14 +31,13 @@ namespace framework {
namespace details {
EagerDeletionOpHandle::EagerDeletionOpHandle(
ir::Node *node, const Scope *scope, const platform::Place &place,
const std::unordered_set<std::string> &var_names, GarbageCollector *gc,
ir::AtomicReferenceCountMap *ref_cnts)
ir::Node *node, Scope *scope, const platform::Place &place,
const std::unordered_set<ir::MemOptVarInfo *> &vars, GarbageCollector *gc)
: OpHandleBase(node),
scope_(scope),
var_names_(var_names.begin(), var_names.end()),
gc_(gc),
ref_cnts_(ref_cnts) {
place_(place),
var_infos_(vars.begin(), vars.end()),
gc_(gc) {
#ifdef PADDLE_WITH_CUDA
if (platform::is_gpu_place(place)) {
dev_ctx_ = reinterpret_cast<platform::CUDADeviceContext *>(
@ -50,7 +50,10 @@ EagerDeletionOpHandle::EagerDeletionOpHandle(
}
}
#endif
PADDLE_ENFORCE(!var_names_.empty(), "Var names cannot be empty");
PADDLE_ENFORCE(!vars.empty(), "Var names cannot be empty");
for (auto *var : var_infos_) {
PADDLE_ENFORCE_NOT_NULL(var);
}
}
EagerDeletionOpHandle::~EagerDeletionOpHandle() {
@ -63,30 +66,43 @@ EagerDeletionOpHandle::~EagerDeletionOpHandle() {
#endif
}
void EagerDeletionOpHandle::InitCUDA() {
#ifdef PADDLE_WITH_CUDA
int dev_id =
boost::get<platform::CUDAPlace>(dev_ctxes_.begin()->first).device;
events_[dev_id] = nullptr;
#endif
}
void EagerDeletionOpHandle::CallOnce() {
PADDLE_ENFORCE(vars_.empty(), "vars_ must be initialized here");
Scope *exec_scope = local_exec_scopes_[0];
for (auto *var_info : var_infos_) {
auto *var = exec_scope->FindVar(var_info->Name());
PADDLE_ENFORCE_NOT_NULL(var, "Variable %s should not be nullptr",
var_info->Name());
vars_.emplace_back(var);
}
}
std::string EagerDeletionOpHandle::Name() const { return "eager_deletion"; }
void EagerDeletionOpHandle::RunImpl() {
if (vars_.size() != var_infos_.size()) {
CallOnce();
}
platform::RecordEvent record_event(Name());
Scope *exec_scope = nullptr;
std::deque<std::shared_ptr<memory::Allocation>> garbages;
for (auto &name : var_names_) {
auto it = ref_cnts_->find(name);
// Reference count has not decreased to 0
if (it == ref_cnts_->end() || it->second.fetch_sub(1) != 1) {
for (size_t i = 0; i < var_infos_.size(); ++i) {
auto *var_info = var_infos_[i];
if (var_info->IsSkipped() || !var_info->DecreaseRefCnt()) {
continue;
}
if (!exec_scope) {
exec_scope = scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
}
// Var not found
auto *var = exec_scope->FindVar(name);
if (var == nullptr) {
continue;
}
VLOG(2) << "Erase variable " << var_info->Name() << " on " << place_;
VLOG(2) << "Erase variable " << name;
Variable *var = vars_[i];
if (var->IsType<LoDTensor>()) {
garbages.emplace_back(var->GetMutable<LoDTensor>()->MoveMemoryHolder());
@ -100,7 +116,7 @@ void EagerDeletionOpHandle::RunImpl() {
}
} else {
PADDLE_THROW("Type %s of %s is not supported eager deletion",
framework::ToTypeName(var->Type()), name);
framework::ToTypeName(var->Type()), var_info->Name());
}
}

@ -26,15 +26,18 @@ namespace paddle {
namespace framework {
class Scope;
namespace ir {
class MemOptVarInfo;
} // namespace ir
namespace details {
class EagerDeletionOpHandle : public OpHandleBase {
public:
EagerDeletionOpHandle(ir::Node *node, const Scope *scope,
EagerDeletionOpHandle(ir::Node *node, Scope *scope,
const platform::Place &place,
const std::unordered_set<std::string> &var_names,
GarbageCollector *gc,
ir::AtomicReferenceCountMap *ref_cnts);
const std::unordered_set<ir::MemOptVarInfo *> &vars,
GarbageCollector *gc);
~EagerDeletionOpHandle();
@ -50,13 +53,20 @@ class EagerDeletionOpHandle : public OpHandleBase {
protected:
void RunImpl() override;
void InitCUDA() override;
std::vector<Scope *> GetLocalScopes() override { return {scope_}; }
private:
void ClearGarbages(std::deque<std::shared_ptr<memory::Allocation>> *garbages);
const Scope *scope_;
std::vector<std::string> var_names_;
GarbageCollector *gc_; // not own
ir::AtomicReferenceCountMap *ref_cnts_; // not own
void CallOnce();
Scope *scope_;
platform::Place place_;
std::vector<ir::MemOptVarInfo *> var_infos_; // not own
GarbageCollector *gc_; // not own
std::vector<Variable *> vars_;
#ifdef PADDLE_WITH_CUDA
platform::CUDADeviceContext *dev_ctx_{nullptr};
cudaEvent_t event_{nullptr};
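EagerDeletionOpHandle now works off ir::MemOptVarInfo objects instead of a name set plus an AtomicReferenceCountMap, and caches the resolved Variable pointers once via CallOnce(). The class itself is defined elsewhere in the tree; judging only from the calls made in RunImpl() above (Name(), IsSkipped(), DecreaseRefCnt()), its ref-counting core plausibly looks like the following self-contained sketch:

#include <atomic>
#include <cassert>
#include <cstddef>
#include <string>

// Hedged sketch of the interface RunImpl() relies on; the real
// ir::MemOptVarInfo is not part of this diff and may differ.
class MemOptVarInfoSketch {
 public:
  MemOptVarInfoSketch(std::string name, size_t ref_cnt)
      : name_(std::move(name)), ref_cnt_(ref_cnt) {}

  const std::string &Name() const { return name_; }
  bool IsSkipped() const { return skipped_; }
  void SetSkipped(bool skipped) { skipped_ = skipped; }

  // Returns true only for the caller whose decrement brings the count to
  // zero, i.e. the eager-deletion op that is allowed to free the buffer.
  bool DecreaseRefCnt() { return ref_cnt_.fetch_sub(1) == 1; }

 private:
  std::string name_;
  std::atomic<size_t> ref_cnt_;
  bool skipped_{false};
};

int main() {
  MemOptVarInfoSketch info("fc_0.tmp_0", 2);
  assert(!info.DecreaseRefCnt());  // another consumer still pending
  assert(info.DecreaseRefCnt());   // last consumer: safe to erase the variable
  return 0;
}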

@ -28,9 +28,11 @@ namespace details {
FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
const std::vector<Scope *> &local_exec_scopes,
const std::vector<platform::Place> &places, ir::Graph *graph)
: strategy_(strategy),
local_scopes_(local_scopes),
local_exec_scopes_(local_exec_scopes),
places_(places),
graph_(graph),
fetch_ctxs_(places),
@ -143,7 +145,8 @@ void FastThreadedSSAGraphExecutor::InsertFetchOps(
ir::Node *fetch_node =
graph_->CreateEmptyNode("fetch", ir::Node::Type::kOperation);
auto *op = new FetchOpHandle(fetch_node, fetches, i, &local_scopes_);
auto *op = new FetchOpHandle(fetch_node, fetches, i, &local_scopes_,
&local_exec_scopes_);
fetch_ops->emplace_back(op);
for (auto &p : places_) {

@ -33,6 +33,7 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
public:
FastThreadedSSAGraphExecutor(const ExecutionStrategy &strategy,
const std::vector<Scope *> &local_scopes,
const std::vector<Scope *> &local_exec_scopes,
const std::vector<platform::Place> &places,
ir::Graph *graph);
FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;
@ -43,6 +44,7 @@ class FastThreadedSSAGraphExecutor : public SSAGraphExecutor {
// be destroyed first.
ExecutionStrategy strategy_;
std::vector<Scope *> local_scopes_;
std::vector<Scope *> local_exec_scopes_;
std::vector<platform::Place> places_;
ir::Graph *graph_;

@ -42,9 +42,7 @@ bool FetchBarrierOpHandle::IsMultiDeviceTransfer() {
void FetchBarrierOpHandle::RunImpl() {
WaitInputVarGenerated(place_);
auto run_func = [this]() {
op_->Run(*run_scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(), place_);
};
auto run_func = [this]() { op_->Run(*local_exec_scopes_[0], place_); };
if (is_lock_and_record_event_free_) {
run_func();

@ -44,6 +44,8 @@ struct FetchBarrierOpHandle : public OpHandleBase {
protected:
void RunImpl() override;
std::vector<Scope *> GetLocalScopes() override { return local_scopes_; }
bool NeedWait(VarHandleBase *in_var) override;
private:

@ -22,11 +22,13 @@ namespace framework {
namespace details {
FetchOpHandle::FetchOpHandle(ir::Node *node, FeedFetchList *data, size_t offset,
std::vector<Scope *> *local_scopes)
std::vector<Scope *> *local_scopes,
std::vector<Scope *> *local_exec_scopes)
: OpHandleBase(node),
data_(data),
offset_(offset),
local_scopes_(local_scopes) {}
local_scopes_(local_scopes),
local_exec_scopes_(local_exec_scopes) {}
FetchOpHandle::~FetchOpHandle() {}
@ -49,14 +51,12 @@ void FetchOpHandle::RunImpl() {
tensors_.resize(inputs_.size());
platform::CPUPlace cpu;
auto &scopes = *local_scopes_;
auto &scopes = *local_exec_scopes_;
for (size_t i = 0; i < inputs_.size(); ++i) {
auto *var_handle = static_cast<VarHandle *>(inputs_[i]);
auto &scope = scopes.at(var_handle->scope_idx());
auto *var = scope->FindVar(kLocalExecScopeName)
->Get<Scope *>()
->FindVar(var_handle->name());
auto *var = scope->FindVar(var_handle->name());
PADDLE_ENFORCE_NOT_NULL(var, "Cannot find variable %s in execution scope",
var_handle->name());

@ -29,7 +29,8 @@ namespace details {
struct FetchOpHandle : public OpHandleBase {
public:
FetchOpHandle(ir::Node *node, FeedFetchList *data, size_t offset,
std::vector<Scope *> *local_scopes);
std::vector<Scope *> *local_scopes,
std::vector<Scope *> *local_exec_scopes);
~FetchOpHandle();
@ -44,12 +45,15 @@ struct FetchOpHandle : public OpHandleBase {
protected:
void RunImpl() override;
std::vector<Scope *> GetLocalScopes() override { return *local_scopes_; }
void WaitInputVarGenerated(const platform::Place &place) override;
private:
FeedFetchList *data_;
size_t offset_;
std::vector<Scope *> *local_scopes_;
std::vector<Scope *> *local_exec_scopes_;
std::vector<LoDTensor> tensors_;
};

@ -185,9 +185,7 @@ void FusedAllReduceOpHandle::RunImpl() {
} else {
// Special handle CPU only Operator's gradient. Like CRF
auto grad_name = grads_tensor.at(0).at(0).first;
auto &trg = *this->local_scopes_[0]
->FindVar(kLocalExecScopeName)
->Get<Scope *>()
auto &trg = *this->local_exec_scopes_[0]
->FindVar(grad_name)
->GetMutable<framework::LoDTensor>();
@ -195,9 +193,8 @@ void FusedAllReduceOpHandle::RunImpl() {
ReduceBufferData func(lod_tensor_data, trg.data<void>(), numel);
VisitDataType(trg.type(), func);
for (size_t i = 1; i < local_scopes_.size(); ++i) {
auto &scope =
*local_scopes_[i]->FindVar(kLocalExecScopeName)->Get<Scope *>();
for (size_t i = 1; i < local_exec_scopes_.size(); ++i) {
auto &scope = *local_exec_scopes_[i];
auto &p = places_[i];
auto *var = scope.FindVar(grad_name);
auto *dev_ctx = dev_ctxes_.at(p);
@ -215,8 +212,7 @@ void FusedAllReduceOpHandle::GetGradLoDTensor(
const size_t &scope_idx, const std::vector<VarHandle *> &in_var_handles,
const std::vector<VarHandle *> &out_var_handles,
std::vector<std::pair<std::string, const LoDTensor *>> *grad_tensor) const {
auto *local_scope =
local_scopes_.at(scope_idx)->FindVar(kLocalExecScopeName)->Get<Scope *>();
auto *local_scope = local_exec_scopes_[scope_idx];
size_t place_num = places_.size();
for (size_t j = 0; j < in_var_handles.size(); j += place_num) {

@ -52,6 +52,8 @@ struct FusedAllReduceOpHandle : public OpHandleBase {
protected:
void RunImpl() override;
std::vector<Scope *> GetLocalScopes() override { return local_scopes_; }
private:
std::vector<Scope *> local_scopes_;
#if !(defined(PADDLE_WITH_CUDA) && !defined(_WIN32))

@ -31,11 +31,6 @@ void FusedBroadcastOpHandle::RunImpl() {
WaitInputVarGenerated();
std::vector<const Scope *> var_scopes;
for (auto *s : local_scopes_) {
var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get<Scope *>());
}
size_t place_num = places_.size();
PADDLE_ENFORCE_EQ(in_var_handles.size() * place_num, out_var_handles.size());
@ -44,7 +39,7 @@ void FusedBroadcastOpHandle::RunImpl() {
*in_var_handles[i],
std::vector<VarHandle *>(out_var_handles.begin() + i * place_num,
out_var_handles.begin() + (i + 1) * place_num),
var_scopes);
local_exec_scopes_);
}
}

@ -13,6 +13,8 @@
// limitations under the License.
#include "paddle/fluid/framework/details/fused_broadcast_op_handle.h"
#include <memory>
#include <unordered_map>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/details/broadcast_op_handle_test.h"
@ -27,17 +29,16 @@ struct TestFusedBroadcastOpHandle : TestBroadcastOpHandle {
void InitFusedBroadcastOp(std::vector<size_t> input_scope_idxes) {
nodes_.clear();
// initialize scope and var
std::unordered_map<Scope*, Scope*> scope_map;
for (size_t i = 0; i < place_list_.size(); ++i) {
local_scopes_.push_back(&(g_scope_.NewScope()));
Scope& local_scope = local_scopes_.back()->NewScope();
*local_scopes_.back()
->Var(details::kLocalExecScopeName)
->GetMutable<Scope*>() = &local_scope;
for (size_t j = 0; j < input_scope_idxes.size(); ++j) {
local_scope.Var("out_var" + std::to_string(j));
if (i == j) local_scope.Var("in_var" + std::to_string(j));
}
param_scopes_.emplace_back(&local_scope);
scope_map.emplace(local_scopes_.back(), param_scopes_.back());
}
// create op handle node
@ -60,6 +61,8 @@ struct TestFusedBroadcastOpHandle : TestBroadcastOpHandle {
#endif
}
op_handle_->SetLocalExecScopes(scope_map);
for (size_t i = 0; i < input_scope_idxes.size(); ++i) {
// add input var handle
nodes_.emplace_back(ir::CreateNodeForTest("in_node" + std::to_string(i),

@ -42,10 +42,7 @@ void GatherOpHandle::RunImpl() {
out_var_handle = out_var_handles.front();
}
std::vector<const Scope *> var_scopes;
for (auto *s : local_scopes_) {
var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get<Scope *>());
}
auto &var_scopes = local_exec_scopes_;
auto in_0_handle = in_var_handles[0];
auto pre_in_var =

Some files were not shown because too many files have changed in this diff.