commit ad7c1a934f

@@ -0,0 +1,71 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

IF(NOT ${WITH_WBAES})
    return()
ENDIF(NOT ${WITH_WBAES})

INCLUDE(ExternalProject)
SET(WBAES_DST_DIR "wbaes")
SET(WBAES_INSTALL_ROOT "${THIRD_PARTY_PATH}/install")
SET(WBAES_INSTALL_DIR ${WBAES_INSTALL_ROOT}/${WBAES_DST_DIR})
SET(WBAES_ROOT ${WBAES_INSTALL_DIR})
SET(WBAES_INC_DIR ${WBAES_ROOT}/include)
SET(WBAES_LIB_DIR ${WBAES_ROOT}/lib)

SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${WBAES_ROOT}/lib")
SET(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE)

IF(APPLE)
    SET(WBAES_TAG "v1.0.0" CACHE STRING "" FORCE)
    SET(WBAES_URL "http://paddlepaddledeps.bj.bcebos.com/wbaes-sdk.mac.${WBAES_TAG}.tgz" CACHE STRING "" FORCE)
    SET(WBAES_LIB ${WBAES_LIB_DIR}/libwbaes.dylib)
    SET(WBAES_SHARED_LIB ${WBAES_LIB_DIR}/libwbaes.dylib)
ELSEIF(WIN32)
    SET(WBAES_TAG "v1.0.0" CACHE STRING "" FORCE)
    SET(WBAES_URL "http://paddlepaddledeps.bj.bcebos.com/wbaes-sdk.windows-x64.${WBAES_TAG}.tgz" CACHE STRING "" FORCE)
    SET(WBAES_LIB ${WBAES_LIB_DIR}/libwbaes.lib)
    SET(WBAES_SHARED_LIB ${WBAES_LIB_DIR}/libwbaes.dll)
ELSE()
    SET(WBAES_TAG "v1.0.2" CACHE STRING "" FORCE)
    SET(WBAES_URL "http://paddlepaddledeps.bj.bcebos.com/wbaes-sdk.linux-x86_64.${WBAES_TAG}.tgz" CACHE STRING "" FORCE)
    SET(WBAES_LIB ${WBAES_LIB_DIR}/libwbaes.so)
    SET(WBAES_SHARED_LIB ${WBAES_LIB_DIR}/libwbaes.so)
ENDIF()

SET(WBAES_PROJECT "extern_wbaes")
MESSAGE(STATUS "WBAES_URL: ${WBAES_URL}, WBAES_LIB: ${WBAES_LIB}")
SET(WBAES_SOURCE_DIR "${THIRD_PARTY_PATH}/wbaes")
SET(WBAES_DOWNLOAD_DIR "${WBAES_SOURCE_DIR}/src/${WBAES_PROJECT}")

ExternalProject_Add(
    ${WBAES_PROJECT}
    ${EXTERNAL_PROJECT_LOG_ARGS}
    PREFIX                ${WBAES_SOURCE_DIR}
    URL                   ${WBAES_URL}
    DOWNLOAD_DIR          ${WBAES_DOWNLOAD_DIR}
    DOWNLOAD_NO_PROGRESS  1
    CONFIGURE_COMMAND     ""
    BUILD_COMMAND         ""
    INSTALL_COMMAND       ""
        ${CMAKE_COMMAND} -E copy_directory ${WBAES_DOWNLOAD_DIR}/include ${WBAES_INC_DIR} &&
        ${CMAKE_COMMAND} -E copy_directory ${WBAES_DOWNLOAD_DIR}/lib ${WBAES_LIB_DIR}
)

INCLUDE_DIRECTORIES(${WBAES_INC_DIR})

ADD_LIBRARY(wbaes SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET wbaes PROPERTY IMPORTED_LOCATION ${WBAES_LIB})
SET_PROPERTY(TARGET wbaes PROPERTY IMPORTED_NO_SONAME 1)
ADD_DEPENDENCIES(wbaes ${WBAES_PROJECT})
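For context, a minimal sketch of how a consumer target would link against the imported `wbaes` target defined above; the target name `my_tool` and its source file are hypothetical, not part of this commit:

ADD_EXECUTABLE(my_tool main.cc)
# Linking the IMPORTED GLOBAL target resolves to ${WBAES_LIB} and, via
# ADD_DEPENDENCIES above, forces the extern_wbaes download step to run first.
TARGET_LINK_LIBRARIES(my_tool wbaes)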
(3 file diffs suppressed because they are too large)
@@ -0,0 +1,150 @@

/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#pragma once

#include <fstream>
#include <memory>
#include <mutex>   // NOLINT
#include <string>
#include <thread>  // NOLINT
#include <utility>
#include <vector>

#include "paddle/fluid/framework/data_feed.h"

namespace paddle {
namespace framework {

// Dataset is an abstract class, which defines the user interfaces.
// Example usage:
//   Dataset* dataset = DatasetFactory::CreateDataset("InMemoryDataset");
//   dataset->SetFileList(std::vector<std::string>{"a.txt", "b.txt"});
//   dataset->SetThreadNum(1);
//   dataset->CreateReaders();
//   dataset->SetDataFeedDesc(your_data_feed_desc);
//   dataset->LoadIntoMemory();
//   dataset->SetTrainerNum(2);
//   dataset->GlobalShuffle();
class Dataset {
 public:
  Dataset() {}
  virtual ~Dataset() {}
  // set file list
  virtual void SetFileList(const std::vector<std::string>& filelist) = 0;
  // set number of reader threads
  virtual void SetThreadNum(int thread_num) = 0;
  // set number of worker trainers
  virtual void SetTrainerNum(int trainer_num) = 0;
  // set fs name and ugi
  virtual void SetHdfsConfig(const std::string& fs_name,
                             const std::string& fs_ugi) = 0;
  // set data feed desc, which contains:
  // data feed name, batch size, slots
  virtual void SetDataFeedDesc(const std::string& data_feed_desc_str) = 0;
  // get file list
  virtual const std::vector<std::string>& GetFileList() = 0;
  // get thread num
  virtual int GetThreadNum() = 0;
  // get trainer num
  virtual int GetTrainerNum() = 0;
  // get hdfs config
  virtual std::pair<std::string, std::string> GetHdfsConfig() = 0;
  // get data feed desc
  virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() = 0;
  // get readers; the number of readers depends on both the thread num
  // and the filelist size
  virtual std::vector<std::shared_ptr<paddle::framework::DataFeed>>&
  GetReaders() = 0;
  // register message handler between workers
  virtual void RegisterClientToClientMsgHandler() = 0;
  // load all data into memory
  virtual void LoadIntoMemory() = 0;
  // release all memory data
  virtual void ReleaseMemory() = 0;
  // shuffle data locally
  virtual void LocalShuffle() = 0;
  // shuffle data globally across trainers
  virtual void GlobalShuffle() = 0;
  // create readers
  virtual void CreateReaders() = 0;
  // destroy readers
  virtual void DestroyReaders() = 0;

 protected:
  virtual int ReceiveFromClient(int msg_type, int client_id,
                                const std::string& msg) = 0;
};

// DatasetImpl is the implementation of Dataset;
// it holds the in-memory data if the user calls load_into_memory
template <typename T>
class DatasetImpl : public Dataset {
 public:
  DatasetImpl();
  virtual ~DatasetImpl() {}

  virtual void SetFileList(const std::vector<std::string>& filelist);
  virtual void SetThreadNum(int thread_num);
  virtual void SetTrainerNum(int trainer_num);
  virtual void SetHdfsConfig(const std::string& fs_name,
                             const std::string& fs_ugi);
  virtual void SetDataFeedDesc(const std::string& data_feed_desc_str);

  virtual const std::vector<std::string>& GetFileList() { return filelist_; }
  virtual int GetThreadNum() { return thread_num_; }
  virtual int GetTrainerNum() { return trainer_num_; }
  virtual std::pair<std::string, std::string> GetHdfsConfig() {
    return std::make_pair(fs_name_, fs_ugi_);
  }
  virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() {
    return data_feed_desc_;
  }
  virtual std::vector<std::shared_ptr<paddle::framework::DataFeed>>&
  GetReaders();

  virtual void RegisterClientToClientMsgHandler();
  virtual void LoadIntoMemory();
  virtual void ReleaseMemory();
  virtual void LocalShuffle();
  virtual void GlobalShuffle();
  virtual void CreateReaders();
  virtual void DestroyReaders();

 protected:
  virtual int ReceiveFromClient(int msg_type, int client_id,
                                const std::string& msg);
  std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers_;
  std::vector<T> memory_data_;
  std::mutex mutex_for_update_memory_data_;
  int thread_num_;
  paddle::framework::DataFeedDesc data_feed_desc_;
  int trainer_num_;
  std::vector<std::string> filelist_;
  size_t file_idx_;
  std::mutex mutex_for_pick_file_;
  std::string fs_name_;
  std::string fs_ugi_;
  unsigned int rand_seed;
};

// use std::vector<MultiSlotType> as data type
class MultiSlotDataset : public DatasetImpl<std::vector<MultiSlotType>> {
 public:
  MultiSlotDataset() {}
  virtual ~MultiSlotDataset() {}
};

}  // end namespace framework
}  // end namespace paddle
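A minimal usage sketch of the Dataset interface above, following the call order given in the header's own example comment. It assumes a valid DataFeedDesc protobuf string is available; `desc_str` is a hypothetical placeholder and error handling is omitted:

#include <memory>
#include <string>
#include "paddle/fluid/framework/dataset_factory.h"

void ExampleDatasetUsage(const std::string& desc_str) {
  using paddle::framework::Dataset;
  using paddle::framework::DatasetFactory;
  // The factory returns a std::shared_ptr<Dataset> by registered class name.
  std::shared_ptr<Dataset> dataset =
      DatasetFactory::CreateDataset("MultiSlotDataset");
  dataset->SetFileList({"a.txt", "b.txt"});  // input shards
  dataset->SetThreadNum(1);                  // one reader thread
  dataset->CreateReaders();
  dataset->SetDataFeedDesc(desc_str);        // batch size, slots, ...
  dataset->LoadIntoMemory();
  dataset->SetTrainerNum(2);                 // two distributed workers
  dataset->GlobalShuffle();                  // shuffle across trainers
}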
@@ -0,0 +1,66 @@

/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/dataset_factory.h"
#include <memory>
#include <string>
#include <unordered_map>

#include "paddle/fluid/framework/data_set.h"

namespace paddle {
namespace framework {
typedef std::shared_ptr<Dataset> (*CreateDatasetFunction)();
typedef std::unordered_map<std::string, CreateDatasetFunction> datasetMap;
datasetMap g_dataset_map;

#define REGISTER_DATASET_CLASS(dataset_class)                   \
  namespace {                                                   \
  std::shared_ptr<Dataset> Creator_##dataset_class() {          \
    return std::shared_ptr<Dataset>(new dataset_class);         \
  }                                                             \
  class __Registerer_##dataset_class {                          \
   public:                                                      \
    __Registerer_##dataset_class() {                            \
      g_dataset_map[#dataset_class] = &Creator_##dataset_class; \
    }                                                           \
  };                                                            \
  __Registerer_##dataset_class g_registerer_##dataset_class;    \
  }  // namespace

std::string DatasetFactory::DatasetTypeList() {
  std::string dataset_types;
  for (auto iter = g_dataset_map.begin(); iter != g_dataset_map.end(); ++iter) {
    if (iter != g_dataset_map.begin()) {
      dataset_types += ", ";
    }
    dataset_types += iter->first;
  }
  return dataset_types;
}

std::shared_ptr<Dataset> DatasetFactory::CreateDataset(
    std::string dataset_class) {
  if (g_dataset_map.count(dataset_class) < 1) {
    LOG(WARNING) << "Your Dataset " << dataset_class
                 << " is not supported currently";
    LOG(WARNING) << "Supported Datasets: " << DatasetTypeList();
    exit(-1);
  }
  return g_dataset_map[dataset_class]();
}

REGISTER_DATASET_CLASS(MultiSlotDataset);
}  // namespace framework
}  // namespace paddle
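REGISTER_DATASET_CLASS implements a self-registering factory: the macro expands to a file-local static object whose constructor inserts a creator function into `g_dataset_map` before `main` runs, so no central registration list is needed. A hedged sketch of how an additional dataset type would be registered; `MyDataset` is hypothetical and, like the macro itself, must live inside paddle::framework where `g_dataset_map` is visible:

// Hypothetical example: a new dataset type registered the same way as
// MultiSlotDataset above.
class MyDataset : public DatasetImpl<std::vector<MultiSlotType>> {
 public:
  MyDataset() {}
  virtual ~MyDataset() {}
};
REGISTER_DATASET_CLASS(MyDataset);

// Later, anywhere in the framework:
//   auto ds = DatasetFactory::CreateDataset("MyDataset");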
@@ -0,0 +1,29 @@

/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <memory>
#include <string>
#include "paddle/fluid/framework/data_set.h"

namespace paddle {
namespace framework {
class DatasetFactory {
 public:
  static std::string DatasetTypeList();
  static std::shared_ptr<Dataset> CreateDataset(std::string dataset_class);
};
}  // namespace framework
}  // namespace paddle
@@ -0,0 +1,203 @@

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/details/async_ssa_graph_executor.h"

#include "paddle/fluid/framework/variable_helper.h"

#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/fluid/operators/distributed/communicator.h"
#endif

namespace paddle {
namespace framework {
namespace details {

inline void NewTempScopeAndInitVars(const std::vector<VarInfo> &var_infos,
                                    Scope *scope) {
  VLOG(3) << "NewTempScopeAndInitVars";
  Scope &local_scope = scope->NewScope();
  *scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>() =
      &local_scope;

  for (auto &info : var_infos) {
    if (scope->FindVar(info.name_) != nullptr) {
      continue;
    }

    if (info.persistable_) {  // Persistable
      InitializeVariable(scope->Var(info.name_), info.type_);
    } else {
      InitializeVariable(local_scope.Var(info.name_), info.type_);
    }
  }
}

// get RpcContext and remote send and recv ops
void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
#ifdef PADDLE_WITH_DISTRIBUTE
  using RpcCtxMap = operators::distributed::RpcCtxMap;
  VLOG(3) << "ProcessGraph";
  RpcCtxMap send_varname_to_ctx;
  RpcCtxMap recv_varname_to_ctx;
  for (size_t i = 0; i < graphs.size(); ++i) {
    std::vector<ir::Node *> nodes_to_delete;
    for (auto &node : graphs[i]->Nodes()) {
      VLOG(3) << "node name " << node->Name();
      if (node && node->IsOp()) {
        if (node->Name() == "send") {
          auto send_var_name = node->Op()->Input("X")[0];
          auto send_varnames = boost::get<std::vector<std::string>>(
              node->Op()->GetNullableAttr("send_varnames"));
          auto epmap = boost::get<std::vector<std::string>>(
              node->Op()->GetNullableAttr("epmap"));
          auto height_section = boost::get<std::vector<int64_t>>(
              node->Op()->GetNullableAttr("sections"));
          send_varname_to_ctx[send_var_name] =
              operators::distributed::RpcContext(send_var_name, send_varnames,
                                                 epmap, height_section);
          VLOG(3) << "find and init a send op: "
                  << send_varname_to_ctx[send_var_name];
        } else if (node->Name() == "recv") {
          auto recv_var_name = node->Op()->Output("Out")[0];
          auto recv_varnames = boost::get<std::vector<std::string>>(
              node->Op()->GetNullableAttr("recv_varnames"));
          auto epmap = boost::get<std::vector<std::string>>(
              node->Op()->GetNullableAttr("epmap"));
          recv_varname_to_ctx[recv_var_name] =
              operators::distributed::RpcContext(recv_var_name, recv_varnames,
                                                 epmap, {});
          nodes_to_delete.push_back(node);
          VLOG(3) << "find and remove a recv op: "
                  << recv_varname_to_ctx[recv_var_name];
        }
      }
    }
  }
  // init communicator here
  if (send_varname_to_ctx.size() > 0) {
    VLOG(3) << "this is distributed mode, will use communicator";
    operators::distributed::Communicator::Init(send_varname_to_ctx,
                                               recv_varname_to_ctx, scope);
    operators::distributed::Communicator::GetInstance()->Start();
  }
#endif
}

AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
    const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
    const std::vector<platform::Place> &places, std::vector<ir::Graph *> graphs)
    : strategy_(std::move(strategy)),
      local_scopes_(std::move(local_scopes)),
      pool_(places.size() >= 2 ? new ::ThreadPool(places.size()) : nullptr),
      places_(std::move(places)),
      graphs_(std::move(graphs)) {
  VLOG(3) << "build AsyncSSAGraphExecutor";
  PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());

  // set the correct size of thread pool for each device.
  strategy_.num_threads_ = strategy_.num_threads_ < places_.size()
                               ? 1UL
                               : strategy_.num_threads_ / places_.size();
  VLOG(1) << "set num_threads: " << strategy_.num_threads_
          << " to run the operators of the graph on each device.";
  for (size_t i = 0; i < places.size(); ++i) {
    executors_.emplace_back(new details::ThreadedSSAGraphExecutor(
        strategy_, {local_scopes_[i]}, {places_[i]}, graphs_[i]));
  }

  for (auto &node : graphs_[0]->Nodes()) {
    if (node->IsVar() && !node->IsCtrlVar() && node->Var()) {
      var_infos_.emplace_back();
      var_infos_.back().name_ = node->Var()->Name();
      var_infos_.back().type_ = node->Var()->GetType();
      var_infos_.back().persistable_ = node->Var()->Persistable();
    }
  }
  for (auto *scope : local_scopes_) {
    NewTempScopeAndInitVars(var_infos_, scope);
  }
  ProcessGraph(graphs_, local_scopes_[0]);
}

void AsyncSSAGraphExecutor::StartOffPythonTrainLoop() {
  VLOG(3) << "StartOffPythonTrainLoop size = " << places_.size();
  for (size_t i = 1; i < places_.size(); ++i) {
    auto call = [this, i]() -> void {
      VLOG(3) << "start off python thread " << i;
      try {
        while (true) {
          executors_[i]->Run({});
        }
      } catch (...) {
        exception_holder_.Catch(std::current_exception());
        VLOG(3) << "get exception type = " << exception_holder_.Type();
      }
      VLOG(3) << "thread " << i << " exited!";
    };
    run_futures_.emplace_back(pool_->enqueue(std::move(call)));
  }
}

void AsyncSSAGraphExecutor::HandleException() {
  if (exception_holder_.IsCaught()) {
    for (auto &f : run_futures_) {
      VLOG(3) << "wait future";
      f.wait();
    }
    VLOG(3) << "caught exception " << exception_holder_.Type()
            << ", rethrow it";
    run_futures_.clear();
    exception_holder_.ReThrow();
  }
}

FeedFetchList AsyncSSAGraphExecutor::Run(
    const std::vector<std::string> &fetch_tensors) {
  // init once
  if (run_futures_.size() == 0 && places_.size() > 1) {
    exception_holder_.Clear();
    StartOffPythonTrainLoop();
  }

  if (places_.size() == 1) {
    exception_holder_.Clear();
  } else {
    HandleException();
  }

  FeedFetchList fetch_data;
  fetch_data.reserve(fetch_tensors.size());

  try {
    fetch_data = executors_[0]->Run(fetch_tensors);
  } catch (...) {
    exception_holder_.Catch(std::current_exception());
  }

  HandleException();

  FeedFetchList ret;
  for (size_t fetch_idx = 0; fetch_idx < fetch_tensors.size(); ++fetch_idx) {
    std::vector<const LoDTensor *> lodtensor_ptrs;
    lodtensor_ptrs.push_back(&fetch_data.at(fetch_idx));
    ret.emplace_back();
    ret.back().MergeLoDTensor(lodtensor_ptrs, platform::CPUPlace());
  }
  return ret;
}

}  // namespace details
}  // namespace framework
}  // namespace paddle
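The executor forwards exceptions from its detached worker loops back to the calling thread: each worker Catch()es everything into a shared holder, and Run()/HandleException() rethrow on the driver side. A self-contained sketch of that pattern; this is a simplified stand-in, not Paddle's actual ExceptionHolder:

#include <exception>
#include <mutex>

class SimpleExceptionHolder {
 public:
  // Called from worker threads; keeps only the first exception seen.
  void Catch(std::exception_ptr p) {
    std::lock_guard<std::mutex> lock(mu_);
    if (!eptr_) eptr_ = p;
  }
  bool IsCaught() {
    std::lock_guard<std::mutex> lock(mu_);
    return static_cast<bool>(eptr_);
  }
  // Called from the driver thread, as HandleException() does above.
  void ReThrow() {
    std::exception_ptr p;
    {
      std::lock_guard<std::mutex> lock(mu_);
      p = eptr_;
    }
    if (p) std::rethrow_exception(p);
  }

 private:
  std::mutex mu_;
  std::exception_ptr eptr_;
};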
@@ -0,0 +1,65 @@

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "ThreadPool.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"

namespace paddle {
namespace framework {
namespace details {

struct VarInfo {
  std::string name_;
  proto::VarType::Type type_;
  bool persistable_;
};

class AsyncSSAGraphExecutor : public SSAGraphExecutor {
 public:
  AsyncSSAGraphExecutor(const ExecutionStrategy &strategy,
                        const std::vector<Scope *> &local_scopes,
                        const std::vector<platform::Place> &places,
                        std::vector<ir::Graph *> graphs);
  ~AsyncSSAGraphExecutor() final = default;
  const ir::Graph &Graph() const override { return *graphs_[0]; }

  FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;

 private:
  void StartOffPythonTrainLoop();
  void HandleException();

 private:
  ExecutionStrategy strategy_;
  std::vector<Scope *> local_scopes_;
  std::unique_ptr<::ThreadPool> pool_{nullptr};
  std::vector<platform::Place> places_;
  std::vector<ir::Graph *> graphs_;

  std::vector<std::unique_ptr<details::ThreadedSSAGraphExecutor>> executors_;
  ExceptionHolder exception_holder_;
  std::vector<std::future<void>> run_futures_;
  std::vector<VarInfo> var_infos_;
};

}  // namespace details
}  // namespace framework
}  // namespace paddle
Some files were not shown because too many files have changed in this diff.