merge branch, test=develop

revert-16734-refine/test_imperative_transformer
lujun 6 years ago
commit e97ded835a

@ -75,6 +75,7 @@ option(WITH_INFERENCE_API_TEST "Test fluid inference high-level api interface"
option(WITH_SYSTEM_BLAS "Use system blas library" OFF)
option(PY_VERSION "Compile PaddlePaddle with python3 support" ${PY_VERSION})
option(WITH_FAST_MATH "Make use of fast math library, might affect the precision to some extent" ON)
option(WITH_WBAES "Compile PaddlePaddle with WBAES support" ON)
# PY_VERSION
if(NOT PY_VERSION)
@ -148,6 +149,7 @@ include(external/dlpack)
include(external/snappy) # download snappy
include(external/snappystream) # download snappystream
include(external/warpctc) # download, build, install warpctc
include(external/wbaes) # download wbaes
if (NOT WIN32)
# there is no official support for nccl or cupti on Windows

@ -157,3 +157,7 @@ endif(WITH_BRPC_RDMA)
if(ON_INFER)
add_definitions(-DPADDLE_ON_INFERENCE)
endif(ON_INFER)
if(WITH_WBAES)
add_definitions(-DPADDLE_WITH_WBAES)
endif(WITH_WBAES)
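The new PADDLE_WITH_WBAES define lets C++ sources branch at compile time. A minimal standalone sketch of probing the flag (illustrative only, not Paddle code):

// probe_wbaes.cc -- compile with: g++ -DPADDLE_WITH_WBAES probe_wbaes.cc
#include <iostream>

int main() {
#ifdef PADDLE_WITH_WBAES
  std::cout << "built with WBAES support\n";
#else
  std::cout << "built without WBAES support\n";
#endif
  return 0;
}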

@ -0,0 +1,71 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
IF(NOT ${WITH_WBAES})
return()
ENDIF(NOT ${WITH_WBAES})
INCLUDE(ExternalProject)
SET(WBAES_DST_DIR "wbaes")
SET(WBAES_INSTALL_ROOT "${THIRD_PARTY_PATH}/install")
SET(WBAES_INSTALL_DIR ${WBAES_INSTALL_ROOT}/${WBAES_DST_DIR})
SET(WBAES_ROOT ${WBAES_INSTALL_DIR})
SET(WBAES_INC_DIR ${WBAES_ROOT}/include)
SET(WBAES_LIB_DIR ${WBAES_ROOT}/lib)
SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${WBAES_ROOT}/lib")
SET(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE)
IF(APPLE)
SET(WBAES_TAG "v1.0.0" CACHE STRING "" FORCE)
SET(WBAES_URL "http://paddlepaddledeps.bj.bcebos.com/wbaes-sdk.mac.${WBAES_TAG}.tgz" CACHE STRING "" FORCE)
SET(WBAES_LIB ${WBAES_LIB_DIR}/libwbaes.dylib)
SET(WBAES_SHARED_LIB ${WBAES_LIB_DIR}/libwbaes.dylib)
ELSEIF(WIN32)
SET(WBAES_TAG "v1.0.0" CACHE STRING "" FORCE)
SET(WBAES_URL "http://paddlepaddledeps.bj.bcebos.com/wbaes-sdk.windows-x64.${WBAES_TAG}.tgz" CACHE STRING "" FORCE)
SET(WBAES_LIB ${WBAES_LIB_DIR}/libwbaes.lib)
SET(WBAES_SHARED_LIB ${WBAES_LIB_DIR}/libwbaes.dll)
ELSE()
SET(WBAES_TAG "v1.0.2" CACHE STRING "" FORCE)
SET(WBAES_URL "http://paddlepaddledeps.bj.bcebos.com/wbaes-sdk.linux-x86_64.${WBAES_TAG}.tgz" CACHE STRING "" FORCE)
SET(WBAES_LIB ${WBAES_LIB_DIR}/libwbaes.so)
SET(WBAES_SHARED_LIB ${WBAES_LIB_DIR}/libwbaes.so)
ENDIF()
SET(WBAES_PROJECT "extern_wbaes")
MESSAGE(STATUS "WBAES_URL: ${WBAES_URL}, WBAES_LIB: ${WBAES_LIB}")
SET(WBAES_SOURCE_DIR "${THIRD_PARTY_PATH}/wbaes")
SET(WBAES_DOWNLOAD_DIR "${WBAES_SOURCE_DIR}/src/${WBAES_PROJECT}")
ExternalProject_Add(
${WBAES_PROJECT}
${EXTERNAL_PROJECT_LOG_ARGS}
PREFIX ${WBAES_SOURCE_DIR}
URL ${WBAES_URL}
DOWNLOAD_DIR ${WBAES_DOWNLOAD_DIR}
DOWNLOAD_NO_PROGRESS 1
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ""
${CMAKE_COMMAND} -E copy_directory ${WBAES_DOWNLOAD_DIR}/include ${WBAES_INC_DIR} &&
${CMAKE_COMMAND} -E copy_directory ${WBAES_DOWNLOAD_DIR}/lib ${WBAES_LIB_DIR}
)
INCLUDE_DIRECTORIES(${WBAES_INC_DIR})
ADD_LIBRARY(wbaes SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET wbaes PROPERTY IMPORTED_LOCATION ${WBAES_LIB})
SET_PROPERTY(TARGET wbaes PROPERTY IMPORTED_NO_SONAME 1)
ADD_DEPENDENCIES(wbaes ${WBAES_PROJECT})

@ -264,6 +264,14 @@ function(cc_library TARGET_NAME)
list(REMOVE_ITEM cc_library_DEPS warpctc)
add_dependencies(${TARGET_NAME} warpctc)
endif()
# Only depend on libwbaes.so, do not link it (see the dlopen sketch below)
if("${cc_library_DEPS};" MATCHES "wbaes;")
list(REMOVE_ITEM cc_library_DEPS wbaes)
if(NOT "${TARGET_NAME}" MATCHES "dynload_wbaes")
list(APPEND cc_library_DEPS dynload_wbaes)
endif()
add_dependencies(${TARGET_NAME} wbaes)
endif()
# Only depend on libmklml.so, do not link it
if("${cc_library_DEPS};" MATCHES "mklml;")
list(REMOVE_ITEM cc_library_DEPS mklml)
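The wbaes rule above mirrors the mklml handling below it: dependents get a build-order dependency on the shared library but link only a small dynload wrapper that opens it at runtime. A minimal sketch of that runtime-loading pattern, assuming POSIX dlopen; the library path and symbol name are placeholders, not the SDK's documented API:

// dynload_sketch.cc -- compile with: g++ dynload_sketch.cc -ldl
#include <dlfcn.h>
#include <cstdio>

int main() {
  // Open the shared library at runtime instead of linking it at build time.
  void *handle = dlopen("libwbaes.so", RTLD_LAZY);
  if (handle == nullptr) {
    std::fprintf(stderr, "dlopen failed: %s\n", dlerror());
    return 1;
  }
  // Resolve a symbol by name; "wbaes_decrypt" is a hypothetical name.
  void *sym = dlsym(handle, "wbaes_decrypt");
  if (sym == nullptr) {
    std::fprintf(stderr, "dlsym failed: %s\n", dlerror());
  }
  dlclose(handle);
  return 0;
}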

@ -170,6 +170,14 @@ copy(snappystream_lib
DSTS ${dst_dir} ${dst_dir}/lib
DEPS snappystream)
if (WITH_WBAES)
set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/wbaes")
copy(wbaes_lib
SRCS ${WBAES_INC_DIR} ${WBAES_LIB}
DSTS ${dst_dir} ${dst_dir}/lib
DEPS wbaes)
endif ()
set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/zlib")
copy(zlib_lib
SRCS ${ZLIB_INCLUDE_DIR} ${ZLIB_LIBRARIES}

@ -196,7 +196,7 @@ endif()
target_link_libraries(executor while_op_helper executor_gc_helper)
cc_library(parallel_executor SRCS parallel_executor.cc DEPS
threaded_ssa_graph_executor scope_buffered_ssa_graph_executor parallel_ssa_graph_executor
threaded_ssa_graph_executor scope_buffered_ssa_graph_executor parallel_ssa_graph_executor async_ssa_graph_executor
graph build_strategy
fast_threaded_ssa_graph_executor variable_helper)

@ -96,6 +96,12 @@ cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS
cc_library(parallel_ssa_graph_executor SRCS parallel_ssa_graph_executor.cc DEPS threaded_ssa_graph_executor)
set(ASYNC_SSA_GRAPH_EXECUTOR_DEPS threaded_ssa_graph_executor)
if(WITH_DISTRIBUTE)
list(APPEND ASYNC_SSA_GRAPH_EXECUTOR_DEPS communicator)
endif()
cc_library(async_ssa_graph_executor SRCS async_ssa_graph_executor.cc DEPS ${ASYNC_SSA_GRAPH_EXECUTOR_DEPS})
cc_test(broadcast_op_test SRCS broadcast_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
device_context broadcast_op_handle)
cc_test(gather_op_test SRCS gather_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory

@ -0,0 +1,203 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/async_ssa_graph_executor.h"
#include "paddle/fluid/framework/variable_helper.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/fluid/operators/distributed/communicator.h"
#endif
namespace paddle {
namespace framework {
namespace details {
inline void NewTempScopeAndInitVars(const std::vector<VarInfo> &var_infos,
Scope *scope) {
VLOG(3) << "NewTempScopeAndInitVars";
Scope &local_scope = scope->NewScope();
*scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>() =
&local_scope;
for (auto &info : var_infos) {
if (scope->FindVar(info.name_) != nullptr) {
continue;
}
if (info.persistable_) { // persistable vars live in the global scope
InitializeVariable(scope->Var(info.name_), info.type_);
} else {
InitializeVariable(local_scope.Var(info.name_), info.type_);
}
}
}
// Collect RpcContexts from the remote send and recv ops, then initialize the communicator.
void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
#ifdef PADDLE_WITH_DISTRIBUTE
using RpcCtxMap = operators::distributed::RpcCtxMap;
VLOG(3) << "ProcessGraph";
RpcCtxMap send_varname_to_ctx;
RpcCtxMap recv_varname_to_ctx;
for (size_t i = 0; i < graphs.size(); ++i) {
std::vector<ir::Node *> nodes_to_delete;
for (auto &node : graphs[i]->Nodes()) {
VLOG(3) << "node name " << node->Name();
if (node && node->IsOp()) {
if (node->Name() == "send") {
auto send_var_name = node->Op()->Input("X")[0];
auto send_varnames = boost::get<std::vector<std::string>>(
node->Op()->GetNullableAttr("send_varnames"));
auto epmap = boost::get<std::vector<std::string>>(
node->Op()->GetNullableAttr("epmap"));
auto height_section = boost::get<std::vector<int64_t>>(
node->Op()->GetNullableAttr("sections"));
send_varname_to_ctx[send_var_name] =
operators::distributed::RpcContext(send_var_name, send_varnames,
epmap, height_section);
VLOG(3) << "find and init an send op: "
<< send_varname_to_ctx[send_var_name];
} else if (node->Name() == "recv") {
auto recv_var_name = node->Op()->Output("Out")[0];
auto recv_varnames = boost::get<std::vector<std::string>>(
node->Op()->GetNullableAttr("recv_varnames"));
auto epmap = boost::get<std::vector<std::string>>(
node->Op()->GetNullableAttr("epmap"));
recv_varname_to_ctx[recv_var_name] =
operators::distributed::RpcContext(recv_var_name, recv_varnames,
epmap, {});
nodes_to_delete.push_back(node);
VLOG(3) << "find and remove an recv op: "
<< recv_varname_to_ctx[recv_var_name];
}
}
}
}
// Initialize the communicator here if any send op was found.
if (send_varname_to_ctx.size() > 0) {
VLOG(3) << "this is distribute mode, will use communicator";
operators::distributed::Communicator::Init(send_varname_to_ctx,
recv_varname_to_ctx, scope);
operators::distributed::Communicator::GetInstance()->Start();
}
#endif
}
AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, std::vector<ir::Graph *> graphs)
: strategy_(std::move(strategy)),
local_scopes_(std::move(local_scopes)),
pool_(places.size() >= 2 ? new ::ThreadPool(places.size()) : nullptr),
places_(std::move(places)),
graphs_(std::move(graphs)) {
VLOG(3) << "build AsyncSSAGraphExecutor";
PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
// Set the correct thread-pool size for each device.
strategy_.num_threads_ = strategy_.num_threads_ < places_.size()
? 1UL
: strategy_.num_threads_ / places_.size();
VLOG(1) << "set num_threads: " << strategy_.num_threads_
<< " to run the operators of the graph on each device.";
for (size_t i = 0; i < places.size(); ++i) {
executors_.emplace_back(new details::ThreadedSSAGraphExecutor(
strategy_, {local_scopes_[i]}, {places_[i]}, graphs_[i]));
}
for (auto &node : graphs_[0]->Nodes()) {
if (node->IsVar() && !node->IsCtrlVar() && node->Var()) {
var_infos_.emplace_back();
var_infos_.back().name_ = node->Var()->Name();
var_infos_.back().type_ = node->Var()->GetType();
var_infos_.back().persistable_ = node->Var()->Persistable();
}
}
for (auto *scope : local_scopes_) {
NewTempScopeAndInitVars(var_infos_, scope);
}
ProcessGraph(graphs_, local_scopes_[0]);
}
void AsyncSSAGraphExecutor::StartOffPythonTrainLoop() {
VLOG(3) << "StartOffPythonTrainLoop size = " << places_.size();
for (size_t i = 1; i < places_.size(); ++i) {
auto call = [this, i]() -> void {
VLOG(3) << "start off python thread " << i;
try {
while (true) {
executors_[i]->Run({});
}
} catch (...) {
exception_holder_.Catch(std::current_exception());
VLOG(3) << "get exception type = " << exception_holder_.Type();
}
VLOG(3) << "thread " << i << " exited!";
};
run_futures_.emplace_back(pool_->enqueue(std::move(call)));
}
}
void AsyncSSAGraphExecutor::HandleException() {
if (exception_holder_.IsCaught()) {
for (auto &f : run_futures_) {
VLOG(3) << "wait future";
f.wait();
}
VLOG(3) << "caught exception " << exception_holder_.Type()
<< ", rethrow it";
run_futures_.clear();
exception_holder_.ReThrow();
}
}
FeedFetchList AsyncSSAGraphExecutor::Run(
const std::vector<std::string> &fetch_tensors) {
// init once
if (run_futures_.size() == 0 && places_.size() > 1) {
exception_holder_.Clear();
StartOffPythonTrainLoop();
}
if (places_.size() == 1) {
exception_holder_.Clear();
} else {
HandleException();
}
FeedFetchList fetch_data;
fetch_data.reserve(fetch_tensors.size());
try {
fetch_data = executors_[0]->Run(fetch_tensors);
} catch (...) {
exception_holder_.Catch(std::current_exception());
}
HandleException();
FeedFetchList ret;
for (size_t fetch_idx = 0; fetch_idx < fetch_tensors.size(); ++fetch_idx) {
std::vector<const LoDTensor *> lodtensor_ptrs;
lodtensor_ptrs.push_back(&fetch_data.at(fetch_idx));
ret.emplace_back();
ret.back().MergeLoDTensor(lodtensor_ptrs, platform::CPUPlace());
}
return ret;
}
} // namespace details
} // namespace framework
} // namespace paddle

@ -0,0 +1,65 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "ThreadPool.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
namespace paddle {
namespace framework {
namespace details {
struct VarInfo {
std::string name_;
proto::VarType::Type type_;
bool persistable_;
};
class AsyncSSAGraphExecutor : public SSAGraphExecutor {
public:
AsyncSSAGraphExecutor(const ExecutionStrategy &strategy,
const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
std::vector<ir::Graph *> graphs);
~AsyncSSAGraphExecutor() final = default;
const ir::Graph &Graph() const override { return *graphs_[0]; }
FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;
private:
void StartOffPythonTrainLoop();
void HandleException();
private:
ExecutionStrategy strategy_;
std::vector<Scope *> local_scopes_;
std::unique_ptr<::ThreadPool> pool_{nullptr};
std::vector<platform::Place> places_;
std::vector<ir::Graph *> graphs_;
std::vector<std::unique_ptr<details::ThreadedSSAGraphExecutor>> executors_;
ExceptionHolder exception_holder_;
std::vector<std::future<void>> run_futures_;
std::vector<VarInfo> var_infos_;
};
} // namespace details
} // namespace framework
} // namespace paddle

@ -184,8 +184,12 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// Convert graph to run on multi-devices.
void AppendMultiDevPass(const BuildStrategy &strategy) {
ir::Pass *multi_devices_pass = nullptr;
if (strategy.is_distribution_) {
VLOG(10) << "Add dist_multi_devices_pass";
if (strategy_.async_mode_) {
multi_devices_pass = AppendPass("async_multi_devices_pass").get();
} else if (strategy_.is_distribution_) {
VLOG(10)
<< "Add dist_multi_devices_pass, multi device parameter server mode";
multi_devices_pass = AppendPass("dist_multi_devices_pass").get();
} else {
if (strategy.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
@ -234,10 +238,12 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
#else
const bool use_cuda) const {
#endif
VLOG(3) << "apply all passes";
// Create a default one if not finalized by user.
CreatePassesFromStrategy(false);
for (std::shared_ptr<ir::Pass> &pass : pass_builder_->AllPasses()) {
VLOG(3) << "apply " << pass->Type();
if (IsMultiDevPass(pass->Type())) {
pass->Erase(kPlaces);
pass->SetNotOwned<const std::vector<platform::Place>>(kPlaces, &places);
@ -293,6 +299,7 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
graph = pass->Apply(graph);
VLOG(3) << "Finish Apply Pass " << pass->Type();
}
VLOG(3) << "All Passes Applied";
return graph;
}

@ -97,6 +97,7 @@ struct BuildStrategy {
// num_trainers is 1, so the current fields of build_strategy don't tell if
// it's a distributed model.
bool is_distribution_{false};
bool async_mode_{false};
int num_trainers_{1};
int trainer_id_{0};
std::vector<std::string> trainers_endpoints_;

@ -14,6 +14,9 @@
#pragma once
#include <memory>
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/platform/enforce.h"
@ -64,6 +67,21 @@ class ExceptionHolder {
ClearImpl();
}
std::string Type() {
std::lock_guard<std::mutex> lock(mu_);
switch (type_) {
case kNone:
return "None";
case kEnforceNotMet: {
return "EnforceNotMet";
}
case kEOF: {
return "EOF";
}
}
return "unknown";
}
private:
void ClearImpl() {
exception_.reset();

@ -31,6 +31,8 @@ struct ExecutionStrategy {
size_t num_iteration_per_drop_scope_{1};
ExecutorType type_{kDefault};
bool dry_run_{false};
size_t num_iteration_per_run_{1}; // only used with async_ssa_graph_executor
// and a pyreader with a data queue
};
} // namespace details

@ -198,8 +198,22 @@ void MultiDevSSAGraphBuilderBase::ApplyImpl(ir::Graph *graph) const {
static_cast<bool>(boost::get<int>(node->Op()->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName())) &
static_cast<int>(OpRole::kBackward));
// Optimizer ops are already processed in DealWithSpecialOp;
// here we only consider backward ops.
if (!is_bk_op) continue;
/*
* The op that generates the gradient of a parameter carries an
* op_role_var attribute recording the parameter and its gradient, e.g.:
* attrs {
* name: "op_role_var"
* type: STRINGS
* strings: "fc_1.b_0"
* strings: "fc_1.b_0@GRAD"
* }
*/
// Currently, we assume that once a gradient is generated, it can be
// broadcast, and each gradient is only broadcast once.
auto backward_vars =
@ -256,6 +270,8 @@ void MultiDevSSAGraphBuilderBase::InsertScaleLossGradOp(
break;
}
VLOG(3) << "loss_scale: " << loss_scale;
if (loss_scale) {
// TODO(paddle-dev): Why is there no input for this op_handle?
auto loss_grad_name = node->Op()->OutputArgumentNames()[0];
@ -407,7 +423,7 @@ void MultiDevSSAGraphBuilderBase::CreateFusedBroadcastOp(
void MultiDevSSAGraphBuilderBase::CreateComputationalOp(ir::Graph *result,
ir::Node *node,
int dev_id) const {
size_t dev_id) const {
result->Get<GraphOps>(kGraphOps).emplace_back(
new ComputationOpHandle(result->CreateOpNode(node->Op()),
local_scopes_[dev_id], places_[dev_id], dev_id));
@ -494,9 +510,8 @@ void MultiDevSSAGraphBuilderBase::CreateComputationalOps(
}
}
VarHandle *MultiDevSSAGraphBuilderBase::CreateReduceOp(ir::Graph *result,
const std::string &og,
int dst_dev_id) const {
VarHandle *MultiDevSSAGraphBuilderBase::CreateReduceOp(
ir::Graph *result, const std::string &og, size_t dst_dev_id) const {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
result->Get<GraphOps>(kGraphOps).emplace_back(new ReduceOpHandle(
result->CreateEmptyNode("reduce", ir::Node::Type::kOperation),
@ -774,6 +789,8 @@ bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result,
} else if (OpHaveRole(*node, OpRole::kDist)) {
int op_dev_id = CreateDistTrainOp(result, node);
if (node->Op()->Type() == "concat") {
// The inputs (blocks of a parameter) of concat are on different devices;
// the output (the parameter) will be on one device.
auto origin_param_name = node->Op()->OutputArgumentNames()[0];
bcast_var_name_set_[op_dev_id].emplace(origin_param_name);
}
@ -781,6 +798,7 @@ bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result,
} else {
int op_dev_id = GetOpDeviceID(node);
if (op_dev_id != -1) { // This op only runs on one specific device.
// Optimizer ops will be processed here.
CreateComputationalOp(result, node, op_dev_id);
for (ir::Node *n : node->outputs) {
sharded_var_device_.emplace(n->Name(), op_dev_id);
@ -961,6 +979,7 @@ bool DistSSAGraphBuilder::IsEncoded(const std::string &p_name) const {
void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result,
const std::string &p_name,
const std::string &g_name) const {
// Assign gradient collection for each parameter to a device.
size_t cur_device_id = 0;
switch (strategy_.reduce_) {
case BuildStrategy::ReduceStrategy::kReduce:
@ -1049,3 +1068,5 @@ REGISTER_MULTI_DEVICES_PASS(
paddle::framework::details::AllReduceSSAGraphBuilder);
REGISTER_MULTI_DEVICES_PASS(dist_multi_devices_pass,
paddle::framework::details::DistSSAGraphBuilder);
REGISTER_MULTI_DEVICES_PASS(async_multi_devices_pass,
paddle::framework::details::AsyncSSAGraphBuilder);

@ -56,8 +56,8 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
bool UseGPU() const;
bool NeedCollectiveForGrad(const std::string &grad_name,
std::vector<ir::Node *> ops) const;
virtual bool NeedCollectiveForGrad(const std::string &grad_name,
std::vector<ir::Node *> ops) const;
bool IsScaleLossOp(ir::Node *node) const;
@ -70,10 +70,10 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
proto::VarType::Type dtype) const;
VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og,
int dst_dev_id) const;
size_t dst_dev_id) const;
void CreateComputationalOp(ir::Graph *result, ir::Node *node,
int dev_id) const;
size_t dev_id) const;
bool IsSparseGradient(const std::string &og) const;
@ -115,6 +115,35 @@ class AllReduceSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
virtual void InsertPostprocessOps(ir::Graph *result) const {}
};
class AsyncSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
protected:
void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
const std::string &g_name) const override {}
bool NeedCollectiveForGrad(const std::string &grad_name,
std::vector<ir::Node *> ops) const override {
return false;
}
bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const override {
if (node->Op()->Type() == "recv") {
VLOG(1) << "set recv op do_not_run to true";
node->Op()->SetAttr("do_not_run", true);
node->Op()->Flush();
} else if (node->Name() == "lookup_table" || node->Name() == "nce" ||
node->Name() == "hierarchical_sigmoid") {
// In async mode we do not need remote prefetch, because the communicator
// will do the parameter recv asynchronously.
VLOG(1) << "set " << node->Name() << " op remote_prefetch to false";
node->Op()->SetAttr("remote_prefetch", false);
node->Op()->Flush();
}
return false;
}
void InsertPostprocessOps(ir::Graph *result) const override {}
};
class BalanceVarSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
protected:
int GetVarDeviceID(const std::string &varname) const;

@ -16,6 +16,7 @@ limitations under the License. */
#include <string>
#include <tuple>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include <vector>
@ -183,6 +184,10 @@ struct OpInfoFiller<T, kGradOpDescMaker> {
T maker(fwd_op, no_grad_set, grad_to_var, grad_block);
return maker();
};
info->use_default_grad_op_desc_maker_ =
std::is_base_of<DefaultGradOpDescMaker<true>, T>::value ||
std::is_base_of<DefaultGradOpDescMaker<false>, T>::value;
}
};
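Because DefaultGradOpDescMaker is marked final elsewhere in this diff, the std::is_base_of test above fires exactly when T is DefaultGradOpDescMaker<true> or DefaultGradOpDescMaker<false> itself (std::is_base_of<B, B>::value is true for class types). A standalone illustration of the check, with stand-in types:

// trait_check.cc -- compile with: g++ -std=c++11 -c trait_check.cc
#include <type_traits>

template <bool DropEmptyIG>
struct DefaultMaker {};  // stand-in for DefaultGradOpDescMaker<DropEmptyIG>

struct CustomMaker {};  // stand-in for a hand-written grad op maker

static_assert(std::is_base_of<DefaultMaker<true>, DefaultMaker<true>>::value,
              "the default maker is detected");
static_assert(!std::is_base_of<DefaultMaker<true>, CustomMaker>::value,
              "a custom maker is not");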

@ -31,11 +31,23 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
prepare_pool_(1),
pool_(strategy.num_threads_ >= 2 ? new ::ThreadPool(strategy.num_threads_)
: nullptr) {
if (strategy_.num_iteration_per_run_ > 1) {
int read_op_num = 0;
for (auto *node : graph_->Nodes()) {
if (node->IsOp() && node->Name() == "read") {
read_op_num++;
}
}
if (read_op_num == 0) {
LOG(WARNING) << "when num_iteration_per_run_ is larger then 1, the model "
"should use pyreader to feed data!";
}
}
PrepareOpDeps();
CopyOpDeps();
}
FeedFetchList ThreadedSSAGraphExecutor::Run(
inline FeedFetchList ThreadedSSAGraphExecutor::RunImpl(
const std::vector<std::string> &fetch_tensors) {
std::unique_ptr<platform::RecordEvent> event(
new platform::RecordEvent("ThreadedSSAGraphExecutorPrepare"));
@ -84,6 +96,8 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
auto cur_ready_vars = ready_vars->PopAll(1, &timeout);
if (timeout) {
if (exception_holder_.IsCaught()) {
VLOG(3) << "caught exception " << exception_holder_.Type()
<< ", rethrow it";
for (auto &run_op_future : run_op_futures_) {
run_op_future.wait();
}
@ -114,6 +128,14 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
return fetch_data;
}
FeedFetchList ThreadedSSAGraphExecutor::Run(
const std::vector<std::string> &fetch_tensors) {
for (size_t j = 0; j < strategy_.num_iteration_per_run_ - 1; ++j) {
RunImpl({});
}
return RunImpl(fetch_tensors);
}
void ThreadedSSAGraphExecutor::InsertFetchOps(
const std::vector<std::string> &fetch_tensors,
std::vector<FetchOpHandle *> *fetch_ops,

@ -23,7 +23,9 @@
#include <unordered_set>
#include <utility>
#include <vector>
#include "ThreadPool.h" // ThreadPool in thrird party
#include <ThreadPool.h> // ThreadPool in thrird party
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/details/exception_holder.h"
#include "paddle/fluid/framework/details/execution_strategy.h"
@ -59,6 +61,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
~ThreadedSSAGraphExecutor() final = default;
private:
inline FeedFetchList RunImpl(const std::vector<std::string> &fetch_tensors);
void RunOp(const std::shared_ptr<BlockingQueue<VarHandleBase *>> &ready_var_q,
details::OpHandleBase *op);

@ -147,7 +147,7 @@ class SingleGradOpDescMaker : public GradOpDescMakerBase {
public:
using GradOpDescMakerBase::GradOpDescMakerBase;
std::vector<std::unique_ptr<OpDesc>> operator()() const {
std::vector<std::unique_ptr<OpDesc>> operator()() const final {
std::vector<std::unique_ptr<OpDesc>> retv;
retv.emplace_back(this->Apply());
return retv;
@ -158,14 +158,14 @@ class SingleGradOpDescMaker : public GradOpDescMakerBase {
};
template <bool DropEmptyIG = true>
class DefaultGradOpDescMaker : public SingleGradOpDescMaker {
class DefaultGradOpDescMaker final : public SingleGradOpDescMaker {
public:
using SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
virtual std::unique_ptr<OpDesc> Apply() const {
std::unique_ptr<OpDesc> Apply() const final {
auto* grad = new OpDesc();
grad->SetType(this->GradOpType());
grad->SetType(this->ForwardOpType() + "_grad");
for (auto& input_param : this->InputNames()) {
grad->SetInput(input_param, this->Input(input_param));
@ -182,18 +182,12 @@ class DefaultGradOpDescMaker : public SingleGradOpDescMaker {
return std::unique_ptr<OpDesc>(grad);
}
virtual std::string GradOpType() const {
return this->ForwardOpType() + "_grad";
}
};
class EmptyGradOpMaker : public GradOpDescMakerBase {
class EmptyGradOpMaker final : public GradOpDescMakerBase {
public:
using GradOpDescMakerBase::GradOpDescMakerBase;
std::vector<std::unique_ptr<OpDesc>> operator()() const override {
return {};
}
std::vector<std::unique_ptr<OpDesc>> operator()() const final { return {}; }
};
} // namespace framework

@ -13,11 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/pass.h"
#include <memory>
#include <utility>
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace paddle {
namespace framework {
namespace ir {
Graph* Pass::Apply(Graph* graph) const {
PADDLE_ENFORCE(graph, "graph passed to Pass::Apply() cannot be empty.");
for (const std::string& attr : required_pass_attrs_) {

@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_info.h"
#include <set>
#include <string>
#include <vector>
namespace paddle {
namespace framework {
@ -24,5 +27,17 @@ OpInfoMap& OpInfoMap::Instance() {
static OpInfoMap g_op_info_map;
return g_op_info_map;
}
std::vector<std::string> OpInfoMap::GetUseDefaultGradOpDescMakerOps() const {
// Use a std::set so the op names come out sorted.
std::set<std::string> result_ops;
for (auto& pair : map_) {
if (pair.second.use_default_grad_op_desc_maker_) {
result_ops.insert(pair.first);
}
}
return std::vector<std::string>(result_ops.begin(), result_ops.end());
}
} // namespace framework
} // namespace paddle

@ -17,6 +17,7 @@ limitations under the License. */
#include <map>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/no_need_buffer_vars_inference.h"
@ -42,6 +43,10 @@ struct OpInfo {
InferInplaceOpFN infer_inplace_;
InferNoNeedBufferVarsFN infer_no_need_buffer_vars_;
// NOTE(zjl): this flag is added to check whether
// the grad maker is the default one.
bool use_default_grad_op_desc_maker_{false};
bool HasOpProtoAndChecker() const {
return proto_ != nullptr && checker_ != nullptr;
}
@ -105,6 +110,8 @@ class OpInfoMap {
std::unordered_map<std::string, OpInfo>* mutable_map() { return &map_; }
std::vector<std::string> GetUseDefaultGradOpDescMakerOps() const;
private:
OpInfoMap() = default;
std::unordered_map<std::string, OpInfo> map_;

@ -24,6 +24,7 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/details/all_reduce_deps_pass.h"
#include "paddle/fluid/framework/details/async_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/parallel_ssa_graph_executor.h"
@ -218,6 +219,18 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
}
}
std::vector<ir::Graph *> graphs;
if (build_strategy.async_mode_) {
PADDLE_ENFORCE(!member_->use_cuda_,
"gpu mode does not support async_mode_ now!");
graphs.push_back(graph);
for (size_t i = 1; i < places.size(); ++i) {
auto *tmp_graph = new ir::Graph(graph->OriginProgram());
async_graphs_.emplace_back(tmp_graph);
graphs.push_back(tmp_graph);
}
}
// FIXME(Yancey1989): parallel graph mode gets better performance
// in GPU allreduce distributed training. Need an elegant way to
// choose the execution strategy.
@ -294,19 +307,46 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
if (need_broadcast()) {
BCastParamsToDevices(bcast_vars, build_strategy.trainer_id_);
}
// Startup Program has been run. All local scopes have correct parameters.
// Step 2. Convert main_program to SSA form and dependency graph. Also,
// insert ncclOp.
std::vector<ir::Graph *> async_graphs(places.size());
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
graph = build_strategy.Apply(graph, member_->places_, loss_var_name,
member_->local_scopes_, member_->nranks_,
if (build_strategy.async_mode_) {
VLOG(3) << "use local async mode";
graph = build_strategy.Apply(graph, {member_->places_[0]}, loss_var_name,
{member_->local_scopes_[0]}, 1,
member_->use_cuda_, member_->nccl_ctxs_.get());
for (size_t i = 1; i < member_->places_.size(); ++i) {
graphs[i] =
build_strategy.Apply(graphs[i], {member_->places_[i]}, loss_var_name,
{member_->local_scopes_[i]}, 1,
member_->use_cuda_, member_->nccl_ctxs_.get());
async_graphs[i] = graphs[i];
}
} else {
graph = build_strategy.Apply(graph, member_->places_, loss_var_name,
member_->local_scopes_, member_->nranks_,
member_->use_cuda_, member_->nccl_ctxs_.get());
}
#else
graph = build_strategy.Apply(graph, member_->places_, loss_var_name,
member_->local_scopes_, member_->nranks_,
member_->use_cuda_);
if (build_strategy.async_mode_) {
VLOG(3) << "use local async mode";
graph = build_strategy.Apply(graph, {member_->places_[0]}, loss_var_name,
{member_->local_scopes_[0]}, 1,
member_->use_cuda_);
for (size_t i = 1; i < member_->places_.size(); ++i) {
graphs[i] = build_strategy.Apply(
graphs[i], {member_->places_[i]}, loss_var_name,
{member_->local_scopes_[i]}, 1, member_->use_cuda_);
async_graphs[i] = graphs[i];
}
} else {
graph = build_strategy.Apply(graph, member_->places_, loss_var_name,
member_->local_scopes_, member_->nranks_,
member_->use_cuda_);
}
#endif
auto max_memory_size = GetEagerDeletionThreshold();
@ -317,6 +357,8 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
static_cast<size_t>(max_memory_size));
}
async_graphs[0] = graph;
// Step 3. Create vars in each scope. Passes may also create new vars.
// skip control vars and empty vars
std::vector<details::VariableInfo> var_infos;
@ -344,7 +386,12 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
}
}
if (build_strategy.enable_parallel_graph_) {
if (build_strategy.async_mode_) {
VLOG(3) << "use AsyncSSAGraphExecutor";
member_->executor_.reset(new details::AsyncSSAGraphExecutor(
exec_strategy, member_->local_scopes_, member_->places_, async_graphs));
} else if (build_strategy.enable_parallel_graph_) {
VLOG(3) << "use ParallelSSAGraphExecutor";
#ifdef PADDLE_WITH_CUDA
// TODO(Yancey1989): Remove passing in the main_program when
// allreduce_seq_pass doesn't need it as the attr.
@ -356,21 +403,27 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
#endif
} else {
if (exec_strategy.type_ == ExecutionStrategy::kDefault) {
VLOG(3) << "use ThreadedSSAGraphExecutor";
member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, member_->places_, graph));
} else {
VLOG(3) << "use FastThreadedSSAGraphExecutor";
member_->executor_.reset(new details::FastThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, member_->places_, graph));
}
}
member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, std::move(var_infos),
member_->places_, std::move(member_->executor_)));
VLOG(3) << "use ScopeBufferedSSAGraphExecutor";
if (!build_strategy.async_mode_) {
member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, std::move(var_infos),
member_->places_, std::move(member_->executor_)));
}
}
void ParallelExecutor::BCastParamsToDevices(
const std::vector<std::string> &vars, int trainer_id) const {
VLOG(3) << "BCastParamsToDevices";
// The initializing bcast: all vars are bcast from device(0).
for (auto &var : vars) {
framework::Variable *main_var = member_->local_scopes_[0]->FindVar(var);
@ -425,14 +478,22 @@ void ParallelExecutor::BCastParamsToDevices(
auto local_scope = member_->local_scopes_[i];
auto *t = local_scope->Var(var)->GetMutable<LoDTensor>();
// FIXME(zcd): LR_DECAY_COUNTER should not be shared. This is a hot fix.
if (member_->use_all_reduce_ || member_->use_cuda_ ||
var == "@LR_DECAY_COUNTER@") {
auto copy_memory = [&] {
t->Resize(dims);
t->mutable_data(cpu, main_tensor.type());
paddle::framework::TensorCopy(main_tensor, cpu, t);
};
auto share_memory = [&] { t->ShareDataWith(main_tensor); };
// FIXME(zcd): LR_DECAY_COUNTER should not be shared. This is a hot fix.
if (member_->build_strategy_.async_mode_) {
share_memory();
} else if (member_->use_all_reduce_ || member_->use_cuda_ ||
var == "@LR_DECAY_COUNTER@") {
copy_memory();
} else {
t->ShareDataWith(main_tensor);
share_memory();
}
}
}

@ -81,6 +81,7 @@ class ParallelExecutor {
const BuildStrategy &build_strategy) const;
ParallelExecutorPrivate *member_;
std::vector<std::unique_ptr<ir::Graph>> async_graphs_;
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
std::unique_ptr<ncclUniqueId> local_nccl_id_;
#endif

@ -69,6 +69,9 @@ void ReaderBase::Start() {
ReaderBase::~ReaderBase() {}
DecoratedReader::~DecoratedReader() { reader_->Shutdown(); }
DecoratedReader::~DecoratedReader() {
VLOG(1) << "~DecoratedReader";
reader_->Shutdown();
}
} // namespace framework
} // namespace paddle

Some files were not shown because too many files have changed in this diff.