|
|
|
@ -27,7 +27,10 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/scope.h"
|
|
|
|
|
#include "paddle/fluid/framework/variable.h"
|
|
|
|
|
#include "paddle/fluid/operators/distributed/distributed.h"
|
|
|
|
|
#include "paddle/fluid/operators/distributed/rpc_client.h"
|
|
|
|
|
#include "paddle/fluid/operators/distributed/rpc_common.h"
|
|
|
|
|
#include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/math_function.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/selected_rows_functor.h"
|
|
|
|
|
#include "paddle/fluid/platform/device_context.h"
|
|
|
|
@ -268,7 +271,7 @@ class Communicator {
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
using SparseIdsMap =
|
|
|
|
|
std::unordered_map<std::string, std::unordered_set<int64_t>>;
|
|
|
|
|
std::unordered_map<std::string, std::vector<std::unordered_set<int64_t>>>;
|
|
|
|
|
|
|
|
|
|
class AsyncCommunicator : public Communicator {
|
|
|
|
|
public:
|
|
|
|
@ -348,15 +351,18 @@ class GeoSgdCommunicator : public Communicator {
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
void SendThread();
|
|
|
|
|
void RecvAll();
|
|
|
|
|
std::unordered_set<int64_t> SparseIdsMerge(
|
|
|
|
|
const std::vector<SparseIdsMap>& ids_send_vec,
|
|
|
|
|
const std::string& var_name);
|
|
|
|
|
const std::string& var_name, const std::string& splited_var_name);
|
|
|
|
|
|
|
|
|
|
void SendUpdateDenseVars(const std::string& var_name);
|
|
|
|
|
void SendUpdateSparseVars(const std::string& var_name,
|
|
|
|
|
const std::string& splited_var_name,
|
|
|
|
|
const std::unordered_set<int64_t>& ids_table);
|
|
|
|
|
void RecvUpdateVars(const std::string& var_name);
|
|
|
|
|
|
|
|
|
|
void RecvUpdateDenseVars(const std::string& var_name);
|
|
|
|
|
void RecvUpdateSparseVars(const std::string& var_name,
|
|
|
|
|
const std::string& splited_var_name);
|
|
|
|
|
|
|
|
|
|
void GeoSgdDenseParamInit(framework::Scope* scope_x,
|
|
|
|
|
framework::Scope* scope_y,
|
|
|
|
@ -366,6 +372,14 @@ class GeoSgdCommunicator : public Communicator {
|
|
|
|
|
framework::Scope* scope_y,
|
|
|
|
|
const std::string var_name);
|
|
|
|
|
|
|
|
|
|
void RpcSend(const std::string& origin_var_name,
|
|
|
|
|
const std::string& splited_var_name,
|
|
|
|
|
const size_t& splited_var_index);
|
|
|
|
|
|
|
|
|
|
void RpcRecv(const std::string& origin_var_name,
|
|
|
|
|
const std::string& splited_var_name,
|
|
|
|
|
const size_t& splited_var_index);
|
|
|
|
|
|
|
|
|
|
const std::string VarToDeltaVar(const std::string var_name) {
|
|
|
|
|
std::string delta_name = var_name;
|
|
|
|
|
const std::string send_name = delta_name.append(".delta");
|
|
|
|
@ -379,6 +393,20 @@ class GeoSgdCommunicator : public Communicator {
|
|
|
|
|
return param_name;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t GetSplitedVarIndex(const std::string var_name,
|
|
|
|
|
const std::string splited_var_name) {
|
|
|
|
|
size_t index = 0;
|
|
|
|
|
for (size_t i = 0;
|
|
|
|
|
i < send_varname_to_ctx_[var_name].splited_var_names.size(); i++) {
|
|
|
|
|
if (send_varname_to_ctx_[var_name].splited_var_names[i] ==
|
|
|
|
|
splited_var_name) {
|
|
|
|
|
index = i;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return index;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
int trainer_nums_ = 1;
|
|
|
|
|
int geo_need_push_nums_ = 100;
|
|
|
|
@ -390,8 +418,6 @@ class GeoSgdCommunicator : public Communicator {
|
|
|
|
|
std::shared_ptr<Scope> pserver_scope_; // parameter on pserver,gloabl scope
|
|
|
|
|
RpcCtxMap send_varname_to_ctx_;
|
|
|
|
|
RpcCtxMap recv_varname_to_ctx_;
|
|
|
|
|
|
|
|
|
|
std::atomic_uint have_push_{0};
|
|
|
|
|
std::unordered_map<std::string, bool>
|
|
|
|
|
var_list_; // if var is sparse, using selected rows, bool=true
|
|
|
|
|
|
|
|
|
@ -399,9 +425,12 @@ class GeoSgdCommunicator : public Communicator {
|
|
|
|
|
need_push_queue_;
|
|
|
|
|
std::vector<SparseIdsMap> ids_send_vec_;
|
|
|
|
|
|
|
|
|
|
std::unordered_map<std::string, std::vector<int64_t>> absolute_section_;
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<::ThreadPool> send_threadpool_{nullptr};
|
|
|
|
|
std::unique_ptr<::ThreadPool> recv_threadpool_{nullptr};
|
|
|
|
|
std::unique_ptr<std::thread> send_thread_{nullptr};
|
|
|
|
|
|
|
|
|
|
size_t need_thread_nums_{0};
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace distributed
|
|
|
|
|