You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
272 lines
8.7 KiB
272 lines
8.7 KiB
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License. */
|
|
|
|
#pragma once
|
|
#include <atomic>
|
|
#include <ctime>
|
|
#include <map>
|
|
#include <memory>
|
|
#include <random>
|
|
#include <string>
|
|
#include <unordered_map>
|
|
#include <unordered_set>
|
|
#include <vector>
|
|
#include "brpc/channel.h"
|
|
#include "brpc/controller.h"
|
|
#include "brpc/server.h"
|
|
#include "paddle/fluid/distributed/service/brpc_utils.h"
|
|
#include "paddle/fluid/distributed/service/sendrecv.pb.h"
|
|
#include "paddle/fluid/framework/executor.h"
|
|
#include "paddle/fluid/framework/program_desc.h"
|
|
#include "paddle/fluid/framework/scope.h"
|
|
#include "paddle/fluid/framework/tensor.h"
|
|
#include "paddle/fluid/framework/variable_helper.h"
|
|
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
|
|
#include "paddle/fluid/platform/profiler.h"
|
|
|
|
namespace google {
|
|
namespace protobuf {
|
|
class Closure;
|
|
class RpcController;
|
|
} // namespace protobuf
|
|
} // namespace google
|
|
namespace paddle {
|
|
namespace framework {
|
|
class Executor;
|
|
class ProgramDesc;
|
|
} // namespace framework
|
|
namespace platform {
|
|
class DeviceContext;
|
|
} // namespace platform
|
|
} // namespace paddle
|
|
|
|
DECLARE_double(eager_delete_tensor_gb);
|
|
namespace paddle {
|
|
namespace distributed {
|
|
|
|
using MultiVarMsg = ::paddle::distributed::MultiVariableMessage;
|
|
using VarMsg = ::paddle::distributed::VariableMessage;
|
|
|
|
class HeterService;
|
|
|
|
typedef int32_t (HeterService::*serviceHandlerFunc)(
|
|
const PsRequestMessage& request, PsResponseMessage& response,
|
|
brpc::Controller* cntl);
|
|
|
|
typedef std::function<void(void*)> HeterRpcCallbackFunc;
|
|
typedef std::function<int(const MultiVarMsg*, MultiVarMsg*, brpc::Controller*)>
|
|
HeterServiceHandler;
|
|
|
|
class HeterService : public ::paddle::distributed::PsService {
|
|
public:
|
|
HeterService() {
|
|
_service_handler_map[PS_STOP_SERVER] = &HeterService::stop_heter_worker;
|
|
_service_handler_map[PS_START_PROFILER] = &HeterService::start_profiler;
|
|
_service_handler_map[PS_STOP_PROFILER] = &HeterService::stop_profiler;
|
|
}
|
|
|
|
virtual ~HeterService() {}
|
|
|
|
virtual void service(::google::protobuf::RpcController* controller,
|
|
const PsRequestMessage* request,
|
|
PsResponseMessage* response,
|
|
::google::protobuf::Closure* done) {
|
|
brpc::ClosureGuard done_guard(done);
|
|
std::string log_label("ReceiveCmd-");
|
|
|
|
response->set_err_code(0);
|
|
response->set_err_msg("");
|
|
brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
|
|
auto itr = _service_handler_map.find(request->cmd_id());
|
|
if (itr == _service_handler_map.end()) {
|
|
std::string err_msg(
|
|
"undefined cmd_id, should match PsCmdID in ps.proto, cmd_id:");
|
|
err_msg.append(std::to_string(request->cmd_id()));
|
|
return;
|
|
}
|
|
serviceHandlerFunc handler_func = itr->second;
|
|
int service_ret = (this->*handler_func)(*request, *response, cntl);
|
|
if (service_ret != 0) {
|
|
response->set_err_code(service_ret);
|
|
response->set_err_msg("server internal error");
|
|
}
|
|
}
|
|
|
|
void SendAndRecvVariable(::google::protobuf::RpcController* controller,
|
|
const MultiVarMsg* request, MultiVarMsg* response,
|
|
::google::protobuf::Closure* done) {
|
|
brpc::ClosureGuard done_guard(done);
|
|
std::string message_name = request->message_name();
|
|
auto itr = handler_map_.find(message_name);
|
|
brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
|
|
PADDLE_ENFORCE_NE(
|
|
itr, handler_map_.end(),
|
|
platform::errors::InvalidArgument(
|
|
"HeterService::SendAndRecvVariable Get illegal message_name: %s "
|
|
"which is not in HeterService::handler_map_",
|
|
message_name));
|
|
itr->second(request, response, cntl);
|
|
}
|
|
|
|
void RegisterServiceHandler(std::string message_name,
|
|
HeterServiceHandler func) {
|
|
handler_map_[message_name] = func;
|
|
}
|
|
|
|
void SetEndpoint(const std::string& end_point) { endpoint_ = end_point; }
|
|
void SetFanin(const int& fan_in) { fan_in_ = fan_in; }
|
|
bool IsExit() { return is_exit_; }
|
|
|
|
private:
|
|
int32_t stop_profiler(const PsRequestMessage& request,
|
|
PsResponseMessage& response, brpc::Controller* cntl);
|
|
|
|
int32_t start_profiler(const PsRequestMessage& request,
|
|
PsResponseMessage& response, brpc::Controller* cntl);
|
|
|
|
int32_t stop_heter_worker(const PsRequestMessage& request,
|
|
PsResponseMessage& response,
|
|
brpc::Controller* cntl);
|
|
|
|
private:
|
|
std::string endpoint_;
|
|
std::unordered_map<std::string, HeterServiceHandler> handler_map_;
|
|
std::unordered_map<int32_t, serviceHandlerFunc> _service_handler_map;
|
|
std::unordered_set<int> stop_cpu_worker_set_;
|
|
int fan_in_;
|
|
bool is_exit_ = false;
|
|
};
|
|
|
|
class HeterServer {
|
|
public:
|
|
virtual ~HeterServer() {}
|
|
|
|
void Stop() {
|
|
VLOG(3) << "HeterServer Stop()";
|
|
std::unique_lock<std::mutex> lock(mutex_);
|
|
stoped_ = true;
|
|
cv_.notify_all();
|
|
server_.Stop(1000);
|
|
server_.Join();
|
|
}
|
|
|
|
bool IsExit() { return service_.IsExit(); }
|
|
|
|
HeterServer() {}
|
|
|
|
void RegisterServiceHandler(std::string message_name,
|
|
HeterServiceHandler func);
|
|
|
|
void StartHeterService();
|
|
|
|
void SetEndPoint(std::string& endpoint);
|
|
void SetFanin(int& fan_in);
|
|
|
|
// HeterWrapper singleton
|
|
static std::shared_ptr<HeterServer> GetInstance() {
|
|
if (NULL == s_instance_) {
|
|
s_instance_.reset(new HeterServer());
|
|
}
|
|
return s_instance_;
|
|
}
|
|
|
|
void WaitServerReady();
|
|
|
|
private:
|
|
static std::shared_ptr<HeterServer> s_instance_;
|
|
mutable std::mutex mutex_;
|
|
std::condition_variable cv_;
|
|
std::condition_variable condition_ready_;
|
|
bool stoped_ = false;
|
|
std::string endpoint_;
|
|
|
|
protected:
|
|
brpc::Server server_;
|
|
HeterService service_;
|
|
DISABLE_COPY_AND_ASSIGN(HeterServer);
|
|
std::mutex mutex_ready_;
|
|
|
|
int ready_;
|
|
};
|
|
|
|
class HeterRequestHandler {
|
|
public:
|
|
HeterRequestHandler()
|
|
: dev_ctx_(nullptr),
|
|
executor_(nullptr),
|
|
scope_(nullptr),
|
|
program_(nullptr) {}
|
|
|
|
virtual ~HeterRequestHandler() {}
|
|
|
|
void SetScope(framework::Scope* scope) { scope_ = scope; }
|
|
void SetDevCtx(const platform::DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; }
|
|
void SetProgram(framework::ProgramDesc* program) { program_ = program; }
|
|
void SetExecutor(framework::Executor* executor) { executor_ = executor; }
|
|
|
|
void SetGradToPreparedCtx(
|
|
std::unordered_map<
|
|
std::string, std::shared_ptr<framework::ExecutorPrepareContext>>* g) {
|
|
message_to_prepared_ctx_ = g;
|
|
}
|
|
|
|
virtual int Handle(const MultiVarMsg* request, MultiVarMsg* response,
|
|
brpc::Controller* cntl) = 0;
|
|
|
|
protected:
|
|
const platform::DeviceContext* dev_ctx_;
|
|
framework::Executor* executor_;
|
|
framework::Scope* scope_;
|
|
framework::ProgramDesc* program_;
|
|
|
|
std::unordered_map<std::string,
|
|
std::shared_ptr<framework::ExecutorPrepareContext>>*
|
|
message_to_prepared_ctx_;
|
|
};
|
|
|
|
class RequestSendAndRecvHandler final : public HeterRequestHandler {
|
|
public:
|
|
RequestSendAndRecvHandler() {}
|
|
virtual ~RequestSendAndRecvHandler() {}
|
|
int Handle(const MultiVarMsg* request, MultiVarMsg* response,
|
|
brpc::Controller* cntl) override {
|
|
platform::RecordEvent record_event("RequestSendAndRecvHandler->Handle");
|
|
FLAGS_eager_delete_tensor_gb = -1;
|
|
auto& local_scope = scope_->NewScope();
|
|
auto message_name = request->message_name();
|
|
auto& request_io_buffer = cntl->request_attachment();
|
|
distributed::DeserializeFromMultiVarMsgAndIOBuf(
|
|
*request, &request_io_buffer, *dev_ctx_, &local_scope);
|
|
executor_->RunPreparedContext(
|
|
(*message_to_prepared_ctx_)[message_name].get(), &local_scope, false);
|
|
|
|
auto response_var_nums = request->recv_var_names_size();
|
|
std::vector<std::string> response_var_names(response_var_nums),
|
|
empty_var_names{};
|
|
|
|
for (int var_idx = 0; var_idx < response_var_nums; ++var_idx) {
|
|
response_var_names[var_idx] = request->recv_var_names(var_idx);
|
|
}
|
|
auto& response_io_buffer = cntl->response_attachment();
|
|
distributed::SerializeToMultiVarMsgAndIOBuf(
|
|
message_name, response_var_names, empty_var_names, *dev_ctx_,
|
|
&local_scope, response, &response_io_buffer);
|
|
scope_->DeleteScope(&local_scope);
|
|
return 0;
|
|
}
|
|
};
|
|
|
|
} // end namespace distributed
|
|
} // end namespace paddle
|