You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
213 lines
7.3 KiB
213 lines
7.3 KiB
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
#pragma once
|
|
|
|
#include <array>
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <future>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "brpc/channel.h"
#include "brpc/controller.h"
#include "brpc/server.h"
#include "paddle/fluid/distributed/service/ps_client.h"
|
|
|
|
namespace paddle {
|
|
namespace distributed {
|
|
|
|
class DownpourPsClientService : public PsService {
|
|
public:
|
|
DownpourPsClientService() {}
|
|
virtual ~DownpourPsClientService() {}
|
|
|
|
virtual int32_t configure(PSClient *client, size_t rank_id) {
|
|
_client = client;
|
|
_rank = rank_id;
|
|
return 0;
|
|
}
|
|
virtual void service(::google::protobuf::RpcController *controller,
|
|
const ::paddle::PsRequestMessage *request,
|
|
::paddle::PsResponseMessage *response,
|
|
::google::protobuf::Closure *done) override;
|
|
|
|
protected:
|
|
size_t _rank;
|
|
PSClient *_client;
|
|
};
|
|
|
|
class DownpourBrpcClosure : public PSClientClosure {
|
|
public:
|
|
DownpourBrpcClosure(size_t num, PSClientCallBack callback)
|
|
: PSClientClosure(callback) {
|
|
_waiting_num = num;
|
|
|
|
_cntls.resize(num);
|
|
_requests.resize(num);
|
|
_responses.resize(num);
|
|
for (size_t i = 0; i < num; ++i) {
|
|
_cntls[i].reset(new brpc::Controller());
|
|
}
|
|
}
|
|
virtual ~DownpourBrpcClosure() {}
|
|
virtual void Run() override {
|
|
if (_waiting_num.fetch_sub(1) == 1) {
|
|
_callback(this);
|
|
delete this;
|
|
}
|
|
}
|
|
PsRequestMessage *request(size_t i) { return &_requests[i]; }
|
|
PsResponseMessage *response(size_t i) { return &_responses[i]; }
|
|
brpc::Controller *cntl(size_t i) { return _cntls[i].get(); }
|
|
int check_response(size_t request_idx, int cmd_id);
|
|
int check_save_response(size_t request_idx, int cmd_id);
|
|
std::string get_response(size_t request_idx, int cmd_id);
|
|
|
|
private:
|
|
std::atomic<int32_t> _waiting_num;
|
|
std::vector<PsRequestMessage> _requests;
|
|
std::vector<PsResponseMessage> _responses;
|
|
std::vector<std::shared_ptr<brpc::Controller>> _cntls;
|
|
};
|
|
|
|
// Deleter for memory obtained from `new T[]`, for smart pointers that would
// otherwise call plain `delete` (e.g. std::shared_ptr<T>).
// std::default_delete<T[]> provides the same behaviour.
template <class T>
struct array_deleter {
  // Takes the pointer by value rather than by `T *&`, so rvalue pointers such
  // as `new T[n]` or `nullptr` are accepted too. Deleters must not throw.
  void operator()(T *x) const noexcept { delete[] x; }
};
|
|
|
|
class BrpcPsClient : public PSClient {
|
|
public:
|
|
BrpcPsClient() {}
|
|
virtual ~BrpcPsClient() {
|
|
// _running = false;
|
|
// try {
|
|
// _async_push_dense_thread.join();
|
|
// _async_push_sparse_thread.join();
|
|
//} catch (...) {
|
|
//}
|
|
}
|
|
virtual int32_t create_client2client_connection(
|
|
int pserver_timeout_ms, int pserver_connect_timeout_ms, int max_retry);
|
|
virtual std::future<int32_t> shrink(uint32_t table_id) override;
|
|
virtual std::future<int32_t> load(const std::string &epoch,
|
|
const std::string &mode) override;
|
|
virtual std::future<int32_t> load(uint32_t table_id, const std::string &epoch,
|
|
const std::string &mode) override;
|
|
|
|
virtual std::future<int32_t> save(const std::string &epoch,
|
|
const std::string &mode) override;
|
|
|
|
virtual std::future<int32_t> save(uint32_t table_id, const std::string &epoch,
|
|
const std::string &mode) override;
|
|
|
|
virtual std::future<int32_t> clear() override;
|
|
|
|
virtual std::future<int32_t> clear(uint32_t table_id) override;
|
|
|
|
virtual std::future<int32_t> stop_server() override;
|
|
|
|
virtual std::future<int32_t> start_profiler() override;
|
|
virtual std::future<int32_t> stop_profiler() override;
|
|
|
|
virtual void finalize_worker() override;
|
|
|
|
virtual std::future<int32_t> pull_dense(Region *regions, size_t region_num,
|
|
size_t table_id);
|
|
|
|
virtual std::future<int32_t> push_dense_param(const Region *regions,
|
|
size_t region_num,
|
|
size_t table_id);
|
|
|
|
virtual std::future<int32_t> pull_sparse(float **select_values,
|
|
size_t table_id,
|
|
const uint64_t *keys, size_t num);
|
|
|
|
virtual std::future<int32_t> print_table_stat(uint32_t table_id);
|
|
|
|
virtual std::future<int32_t> barrier(size_t table_id, uint32_t barrier_type);
|
|
|
|
virtual std::future<int32_t> pull_geo_param(size_t table_id,
|
|
std::vector<float> *values,
|
|
std::vector<uint64_t> *keys,
|
|
int pserver_idx);
|
|
|
|
virtual std::future<int32_t> flush();
|
|
|
|
virtual std::future<int32_t> send_client2client_msg(
|
|
int msg_type, int to_client_id, const std::string &msg) override;
|
|
|
|
private:
|
|
virtual int32_t initialize() override;
|
|
|
|
inline uint32_t dense_dim_per_shard(uint32_t dense_dim_total,
|
|
uint32_t shard_num) {
|
|
return dense_dim_total / shard_num + 1;
|
|
}
|
|
|
|
std::future<int32_t> send_cmd(uint32_t table_id, int cmd_id,
|
|
const std::vector<std::string> ¶m);
|
|
|
|
std::future<int32_t> send_save_cmd(uint32_t table_id, int cmd_id,
|
|
const std::vector<std::string> ¶m);
|
|
|
|
inline brpc::Channel *get_sparse_channel(size_t server_id) {
|
|
return _server_channels[server_id][0].get();
|
|
}
|
|
inline brpc::Channel *get_dense_channel(size_t server_id) {
|
|
return _server_channels[server_id][1].get();
|
|
}
|
|
inline brpc::Channel *get_cmd_channel(size_t server_id) {
|
|
return _server_channels[server_id][2].get();
|
|
}
|
|
|
|
bool _running = false;
|
|
bool _flushing = false;
|
|
std::atomic<uint32_t> _async_call_num; //异步请求计数
|
|
|
|
std::vector<std::shared_ptr<brpc::Channel>>
|
|
_client_channels; // client2client
|
|
std::vector<std::array<std::shared_ptr<brpc::Channel>, 3>>
|
|
_server_channels; // client2server
|
|
virtual std::future<int32_t> push_dense_raw_gradient(
|
|
int table_id, float *total_send_data, size_t total_send_data_size,
|
|
void *done) override;
|
|
|
|
virtual std::future<int32_t> push_sparse_raw_gradient(
|
|
size_t table_id, const uint64_t *keys, const float **update_values,
|
|
size_t num, void *done) override;
|
|
|
|
virtual std::future<int32_t> push_sparse_raw_gradient_partial(
|
|
size_t table_id, const uint64_t *keys, const float **update_values,
|
|
uint32_t num, void *done, int pserver_idx) override;
|
|
|
|
virtual std::future<int32_t> push_sparse_param(size_t table_id,
|
|
const uint64_t *keys,
|
|
const float **update_values,
|
|
size_t num,
|
|
void *done) override;
|
|
|
|
virtual size_t get_server_nums() { return _server_channels.size(); }
|
|
|
|
private:
|
|
int32_t start_client_service();
|
|
|
|
float _mae = 0;
|
|
float _mse = 0;
|
|
uint16_t _push_times = 0;
|
|
brpc::Server _server;
|
|
DownpourPsClientService _service;
|
|
std::atomic_uint grad_num_{0};
|
|
};
|
|
} // namespace distributed
|
|
} // namespace paddle
|