Move sync_mode device ctx from grpc server (#10881)
parent
5870a6b486
commit
4fb7cc7f5e
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,127 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <time.h>
|
||||
|
||||
#include <functional>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "paddle/fluid/framework/data_type.h"
|
||||
#include "paddle/fluid/framework/executor.h"
|
||||
#include "paddle/fluid/framework/lod_tensor.h"
|
||||
#include "paddle/fluid/framework/program_desc.h"
|
||||
#include "paddle/fluid/framework/scope.h"
|
||||
#include "paddle/fluid/framework/selected_rows.h"
|
||||
#include "paddle/fluid/framework/var_type.h"
|
||||
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace detail {
|
||||
|
||||
// Identifiers of the RPC methods. The request handlers pass these names to
// RPCServer::WaitCond() and RPCServer::IncreaseBatchBarrier().
constexpr char kRequestGet[] = "RequestGet";
constexpr char kRequestPrefetch[] = "RequestPrefetch";
constexpr char kRequestSend[] = "RequestSend";

// Forward declaration: this header only stores a pointer to the server.
class RPCServer;
|
||||
|
||||
class RequestHandler {
|
||||
public:
|
||||
explicit RequestHandler(bool sync_mode)
|
||||
: sync_mode_(sync_mode),
|
||||
dev_ctx_(nullptr),
|
||||
executor_(nullptr),
|
||||
scope_(nullptr),
|
||||
program_(nullptr),
|
||||
rpc_server_(nullptr) {}
|
||||
|
||||
virtual ~RequestHandler() {}
|
||||
|
||||
// Set attributes.
|
||||
void SetScope(framework::Scope* scope) { scope_ = scope; }
|
||||
void SetDevCtx(const platform::DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; }
|
||||
void SetProgram(framework::ProgramDesc* program) { program_ = program; }
|
||||
void SetExecutor(framework::Executor* executor) { executor_ = executor; }
|
||||
void SetPrefetchPreparedCtx(
|
||||
std::unique_ptr<framework::ExecutorPrepareContext> prepared) {
|
||||
prefetch_ctx_.reset(prepared.release());
|
||||
}
|
||||
|
||||
// Used for async.
|
||||
void SetGradToPreparedCtx(
|
||||
std::unordered_map<
|
||||
std::string, std::shared_ptr<framework::ExecutorPrepareContext>>* g) {
|
||||
grad_to_prepared_ctx_ = g;
|
||||
}
|
||||
|
||||
void SetRPCServer(RPCServer* rpc_server) { rpc_server_ = rpc_server; }
|
||||
|
||||
// Get attributes.
|
||||
bool sync_mode() { return sync_mode_; }
|
||||
framework::Scope* scope() { return scope_; }
|
||||
const platform::DeviceContext* dev_ctx() { return dev_ctx_; }
|
||||
framework::ExecutorPrepareContext* prefetch_ctx() {
|
||||
return prefetch_ctx_.get();
|
||||
}
|
||||
framework::ProgramDesc* program() { return program_; }
|
||||
framework::Executor* executor() { return executor_; }
|
||||
std::vector<framework::Variable*>& sparse_vars() { return sparse_vars_; }
|
||||
|
||||
// This function processes user's rpc request.
|
||||
// The implemention is in request_handler_impl.
|
||||
// example:
|
||||
// std::string varname = request_.varname();
|
||||
//
|
||||
// auto scope = request_handler_->scope();
|
||||
// auto invar = scope->FindVar(varname);
|
||||
// framework::Variable* outvar = nullptr;
|
||||
//
|
||||
// request_handler_->Handle(varname, scope, invar, &outvar);
|
||||
// if (outvar) {
|
||||
// SerializeToByteBuffer(varname, outvar,
|
||||
// *request_handler_->dev_ctx(), &reply_);
|
||||
// }
|
||||
virtual bool Handle(const std::string& varname, framework::Scope* scope,
|
||||
framework::Variable* var,
|
||||
framework::Variable** outvar) = 0;
|
||||
|
||||
protected:
|
||||
const bool sync_mode_;
|
||||
|
||||
const platform::DeviceContext* dev_ctx_;
|
||||
framework::Executor* executor_;
|
||||
framework::Scope* scope_;
|
||||
framework::ProgramDesc* program_;
|
||||
std::unique_ptr<framework::ExecutorPrepareContext> prefetch_ctx_;
|
||||
|
||||
// Used for async.
|
||||
std::unordered_map<std::string,
|
||||
std::shared_ptr<framework::ExecutorPrepareContext>>*
|
||||
grad_to_prepared_ctx_;
|
||||
|
||||
// Record received sparse variables, so that
|
||||
// we could reset those after execute optimize program
|
||||
std::vector<framework::Variable*> sparse_vars_;
|
||||
RPCServer* rpc_server_;
|
||||
|
||||
std::mutex sparse_var_mutex_;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,115 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "paddle/fluid/framework/blocking_queue.h"
|
||||
#include "paddle/fluid/framework/data_type.h"
|
||||
#include "paddle/fluid/framework/lod_tensor.h"
|
||||
#include "paddle/fluid/framework/scope.h"
|
||||
#include "paddle/fluid/framework/selected_rows.h"
|
||||
#include "paddle/fluid/operators/detail/request_handler_impl.h"
|
||||
#include "paddle/fluid/operators/detail/rpc_server.h"
|
||||
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
|
||||
#include "paddle/fluid/operators/detail/variable_response.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace detail {
|
||||
|
||||
bool RequestSendHandler::Handle(const std::string& varname,
|
||||
framework::Scope* scope,
|
||||
framework::Variable* invar,
|
||||
framework::Variable** outvar) {
|
||||
VLOG(4) << "RequestSendHandler:" << varname;
|
||||
|
||||
// Async
|
||||
if (!sync_mode_) {
|
||||
try {
|
||||
executor_->RunPreparedContext((*grad_to_prepared_ctx_)[varname].get(),
|
||||
scope);
|
||||
} catch (std::exception& e) {
|
||||
LOG(ERROR) << "async: run sub program error " << e.what();
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Sync
|
||||
if (varname == BATCH_BARRIER_MESSAGE) {
|
||||
VLOG(3) << "sync: recv batch barrier message";
|
||||
rpc_server_->IncreaseBatchBarrier(kRequestSend);
|
||||
} else {
|
||||
VLOG(3) << "sync: received var_name: " << varname;
|
||||
if (sync_mode_) {
|
||||
rpc_server_->WaitCond(kRequestSend);
|
||||
}
|
||||
|
||||
if (invar == nullptr) {
|
||||
LOG(ERROR) << "sync: Can not find server side var: " << varname;
|
||||
PADDLE_THROW("sync: Can not find server side var");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (invar->IsType<framework::SelectedRows>()) {
|
||||
std::unique_lock<std::mutex> lock(sparse_var_mutex_);
|
||||
sparse_vars_.push_back(invar);
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool RequestGetHandler::Handle(const std::string& varname,
|
||||
framework::Scope* scope,
|
||||
framework::Variable* invar,
|
||||
framework::Variable** outvar) {
|
||||
VLOG(4) << "RequestGetHandler:" << varname;
|
||||
|
||||
if (varname != FETCH_BARRIER_MESSAGE) {
|
||||
if (sync_mode_) {
|
||||
rpc_server_->WaitCond(kRequestGet);
|
||||
}
|
||||
*outvar = scope_->FindVar(varname);
|
||||
return true;
|
||||
}
|
||||
|
||||
// FETCH_BARRIER_MESSAGE
|
||||
if (sync_mode_) {
|
||||
VLOG(3) << "sync: recv fetch barrier message";
|
||||
rpc_server_->IncreaseBatchBarrier(kRequestGet);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool RequestPrefetchHandler::Handle(const std::string& varname,
|
||||
framework::Scope* scope,
|
||||
framework::Variable* invar,
|
||||
framework::Variable** outvar) {
|
||||
VLOG(4) << "RequestPrefetchHandler " << varname;
|
||||
|
||||
auto var_desc = program_->Block(0).FindVar(varname);
|
||||
*outvar = scope->FindVar(varname);
|
||||
InitializeVariable(*outvar, var_desc->GetType());
|
||||
executor_->RunPreparedContext(prefetch_ctx_.get(), scope);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,64 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <time.h>
|
||||
|
||||
#include <functional>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "paddle/fluid/framework/data_type.h"
|
||||
#include "paddle/fluid/framework/executor.h"
|
||||
#include "paddle/fluid/framework/lod_tensor.h"
|
||||
#include "paddle/fluid/framework/program_desc.h"
|
||||
#include "paddle/fluid/framework/scope.h"
|
||||
#include "paddle/fluid/framework/selected_rows.h"
|
||||
#include "paddle/fluid/framework/var_type.h"
|
||||
#include "paddle/fluid/operators/detail/request_handler.h"
|
||||
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace detail {
|
||||
|
||||
class RequestSendHandler final : public RequestHandler {
|
||||
public:
|
||||
explicit RequestSendHandler(bool sync_mode) : RequestHandler(sync_mode) {}
|
||||
virtual ~RequestSendHandler() {}
|
||||
bool Handle(const std::string& varname, framework::Scope* scope,
|
||||
framework::Variable* var, framework::Variable** outvar) override;
|
||||
};
|
||||
|
||||
class RequestGetHandler final : public RequestHandler {
|
||||
public:
|
||||
explicit RequestGetHandler(bool sync_mode) : RequestHandler(sync_mode) {}
|
||||
virtual ~RequestGetHandler() {}
|
||||
bool Handle(const std::string& varname, framework::Scope* scope,
|
||||
framework::Variable* var, framework::Variable** outvar) override;
|
||||
};
|
||||
|
||||
class RequestPrefetchHandler final : public RequestHandler {
|
||||
public:
|
||||
explicit RequestPrefetchHandler(bool sync_mode) : RequestHandler(sync_mode) {}
|
||||
virtual ~RequestPrefetchHandler() {}
|
||||
bool Handle(const std::string& varname, framework::Scope* scope,
|
||||
framework::Variable* var, framework::Variable** outvar) override;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,113 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
#include <string>
|
||||
|
||||
#include "paddle/fluid/operators/detail/rpc_server.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace detail {
|
||||
|
||||
void RPCServer::ShutDown() {
|
||||
LOG(INFO) << "RPCServer ShutDown ";
|
||||
ShutDownImpl();
|
||||
|
||||
exit_flag_ = true;
|
||||
barrier_cond_.notify_all();
|
||||
rpc_cond_.notify_all();
|
||||
}
|
||||
|
||||
void RPCServer::SavePort() const {
|
||||
auto file_path = string::Sprintf("/tmp/paddle.%d.port", ::getpid());
|
||||
std::ofstream port_file;
|
||||
port_file.open(file_path);
|
||||
port_file << selected_port_;
|
||||
port_file.close();
|
||||
VLOG(4) << "selected port written to " << file_path;
|
||||
}
|
||||
|
||||
void RPCServer::WaitBarrier(const std::string& rpc_name) {
|
||||
std::unique_lock<std::mutex> lock(this->mutex_);
|
||||
barrier_cond_.wait(lock, [=] {
|
||||
return (barrier_counter_[rpc_name] >= client_num_ || exit_flag_.load());
|
||||
});
|
||||
|
||||
VLOG(3) << "batch_barrier_:" << barrier_counter_[rpc_name];
|
||||
}
|
||||
|
||||
void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) {
|
||||
VLOG(3) << "RPCServer begin IncreaseBatchBarrier " << rpc_name;
|
||||
int b = 0;
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
b = ++barrier_counter_[rpc_name];
|
||||
}
|
||||
|
||||
VLOG(3) << "RPCServer IncreaseBatchBarrier " << rpc_name
|
||||
<< ", barrier_count:" << b << ", fan_in" << client_num_;
|
||||
|
||||
if (b >= client_num_) {
|
||||
barrier_cond_.notify_all();
|
||||
}
|
||||
}
|
||||
|
||||
void RPCServer::ResetBarrierCounter() {
|
||||
VLOG(3) << "RPCServer ResetBarrierCounter ";
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
for (auto& t : barrier_counter_) {
|
||||
t.second = 0;
|
||||
}
|
||||
}
|
||||
|
||||
void RPCServer::RegisterRPC(const std::string& rpc_name,
|
||||
RequestHandler* handler, int thread_num) {
|
||||
rpc_call_map_[rpc_name] = handler;
|
||||
rpc_thread_num_[rpc_name] = thread_num;
|
||||
|
||||
static int cond = -1;
|
||||
rpc_cond_map_[rpc_name] = ++cond;
|
||||
VLOG(4) << "RegisterRPC rpc_name:" << rpc_name << ", handler:" << handler
|
||||
<< ", cond:" << rpc_cond_map_[rpc_name];
|
||||
}
|
||||
|
||||
void RPCServer::SetCond(const std::string& rpc_name) {
|
||||
VLOG(3) << "RPCServer SetCond " << rpc_name;
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
cur_cond_ = rpc_cond_map_[rpc_name];
|
||||
}
|
||||
|
||||
rpc_cond_.notify_all();
|
||||
}
|
||||
|
||||
void RPCServer::WaitCond(const std::string& rpc_name) {
|
||||
VLOG(3) << "RPCServer WaitCond " << rpc_name;
|
||||
int cond = 0;
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
cond = rpc_cond_map_[rpc_name];
|
||||
}
|
||||
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
rpc_cond_.wait(
|
||||
lock, [=] { return (cur_cond_.load() == cond || exit_flag_.load()); });
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,91 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <thread> // NOLINT
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/operators/detail/request_handler.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace detail {
|
||||
|
||||
class RPCServer {
|
||||
public:
|
||||
explicit RPCServer(const std::string& address, int client_num)
|
||||
: cur_cond_(0),
|
||||
bind_address_(address),
|
||||
exit_flag_(false),
|
||||
selected_port_(0),
|
||||
client_num_(client_num) {}
|
||||
|
||||
virtual ~RPCServer() {}
|
||||
virtual void StartServer() = 0;
|
||||
virtual void WaitServerReady() = 0;
|
||||
|
||||
void ShutDown();
|
||||
|
||||
bool IsExit() { return exit_flag_.load(); }
|
||||
|
||||
int GetSelectedPort() const { return selected_port_; }
|
||||
void SavePort() const;
|
||||
|
||||
// RegisterRPC, register the rpc method name to a handler
|
||||
// class, and auto generate a condition id for this call
|
||||
// to be used for the barrier.
|
||||
void RegisterRPC(const std::string& rpc_name, RequestHandler* handler,
|
||||
int thread_num = 5);
|
||||
|
||||
// Wait util all the clients have reached the barrier for one
|
||||
// rpc method. This function should be called in the
|
||||
// RequestHandler if you want to run the server/client in a
|
||||
// synchronous mode.
|
||||
void WaitBarrier(const std::string& rpc_name);
|
||||
|
||||
void SetCond(const std::string& rpc_name);
|
||||
void WaitCond(const std::string& rpc_name);
|
||||
void IncreaseBatchBarrier(const std::string rpc_name);
|
||||
void ResetBarrierCounter();
|
||||
|
||||
protected:
|
||||
virtual void ShutDownImpl() = 0;
|
||||
|
||||
private:
|
||||
std::mutex mutex_;
|
||||
std::unordered_map<std::string, int> barrier_counter_;
|
||||
std::condition_variable barrier_cond_;
|
||||
|
||||
std::unordered_map<std::string, int> rpc_cond_map_;
|
||||
std::atomic<int> cur_cond_;
|
||||
std::condition_variable rpc_cond_;
|
||||
|
||||
protected:
|
||||
std::string bind_address_;
|
||||
std::atomic<int> exit_flag_;
|
||||
int selected_port_;
|
||||
|
||||
const int client_num_;
|
||||
|
||||
std::unordered_map<std::string, RequestHandler*> rpc_call_map_;
|
||||
std::unordered_map<std::string, int> rpc_thread_num_;
|
||||
friend class RequestHandler;
|
||||
};
|
||||
|
||||
}; // namespace detail
|
||||
}; // namespace operators
|
||||
}; // namespace paddle
|
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue