You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Paddle/paddle/fluid/framework/fleet/gloo_wrapper.h

245 lines
6.8 KiB

/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#if defined _WIN32 || defined __APPLE__
#else
#define _LINUX
#endif
#ifdef _LINUX
#include <sys/types.h>
#include <unistd.h>
#endif
#include <chrono>
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#ifdef PADDLE_WITH_GLOO
#include <gloo/allgather.h>
#include <gloo/allreduce.h>
#include <gloo/barrier.h>
#include <gloo/rendezvous/context.h>
#include <gloo/rendezvous/file_store.h>
#include <gloo/rendezvous/http_store.h>
#include <gloo/rendezvous/prefix_store.h>
#include <gloo/rendezvous/store.h>
#include <gloo/transport/tcp/device.h>
#endif
#include "paddle/fluid/framework/variable_helper.h"
// Forward declarations of the gloo types this header mentions, so the
// declarations below can name them even when the gloo headers are not
// included (i.e. when PADDLE_WITH_GLOO is not defined).
namespace gloo {
class Context;
}  // namespace gloo

namespace gloo {
namespace transport {
class Device;
}  // namespace transport
}  // namespace gloo
namespace gloo {
namespace rendezvous {
// Rendezvous key/value store for gloo. The class name and the
// TmpPath/ObjectPath helpers suggest keys are stored as objects under path_
// on HDFS. NOTE(review): the implementation is out of line; confirm there.
//
// With PADDLE_WITH_GLOO defined it derives from gloo::rendezvous::Store.
// Otherwise it is a standalone class with the same members. That is why
// set/get/wait are declared `virtual` rather than `override`: the base class
// only exists in one of the two configurations.
#ifdef PADDLE_WITH_GLOO
class HdfsStore : public gloo::rendezvous::Store {
#else
class HdfsStore {
#endif
 public:  // NOLINT
  // `path` is stored for the store's root (presumably assigned to path_ by
  // the out-of-line constructor; confirm).
  explicit HdfsStore(const std::string& path);
  virtual ~HdfsStore() {}

  // Key/value interface; these implement gloo::rendezvous::Store when
  // PADDLE_WITH_GLOO is defined.
  virtual void set(const std::string& key, const std::vector<char>& data);
  virtual std::vector<char> get(const std::string& key);
  virtual void wait(const std::vector<std::string>& keys);
  virtual void wait(const std::vector<std::string>& keys,
                    const std::chrono::milliseconds& timeout);
  virtual void SetTimeoutSeconds(int timeout_seconds);

  // Path helpers. Presumably they map a key name to an encoded name, a
  // temporary path and the final object path; TODO confirm against the .cc.
  std::string EncodeName(const std::string& name);
  std::string TmpPath(const std::string& name);
  std::string ObjectPath(const std::string& name);

  // Checks `keys`, writing per-key status into *keys_check_status.
  // NOTE(review): exact semantics of the bool return (e.g. "all keys
  // present") are defined out of line; confirm.
  bool Check(const std::vector<std::string>& keys,
             std::vector<bool>* keys_check_status);

  void SetRank(int rank) { self_rank_ = rank; }

  // Data members are public, as in the original interface. The in-class
  // defaults guarantee none of them is ever indeterminate, even if a
  // constructor path forgets to assign one.
  std::string path_;
  int wait_sleep_ms_ = 0;
  std::chrono::seconds wait_timeout_ = std::chrono::seconds(0);
  int retry_times_ = 0;
  int self_rank_ = 0;
};
#ifdef PADDLE_WITH_GLOO
// gloo::rendezvous::Context with its own connectFullMesh. The base class
// waits for and reads peer entries from the store one by one, which is slow
// for large world sizes, especially with HdfsStore. The replacement is
// defined out of line and presumably uses several threads (see thread_num_);
// confirm against the .cc.
class ParallelConnectContext : public gloo::rendezvous::Context {
 public:
  ParallelConnectContext(int rank, int size, int base = 2)
      : gloo::rendezvous::Context(rank, size, base) {}
  ~ParallelConnectContext() override = default;

  // Establishes the full-mesh connection through `store` using `dev`.
  void connectFullMesh(Store& store,                               // NOLINT
                       std::shared_ptr<transport::Device>& dev);  // NOLINT

 protected:
  // Thread count; presumably consumed by connectFullMesh (confirm in .cc).
  int thread_num_{6};
};
#endif
} // namespace rendezvous
} // namespace gloo
namespace paddle {
namespace framework {
// Which rendezvous store GlooWrapper uses; selected by SetHdfsStore /
// SetHttpStore.
enum GlooStoreType { HDFS = 0, HTTP = 1 };
// Wrapper around a gloo communication context plus the configuration used
// to build it (rank, size, network interface, prefix and rendezvous store).
//
// Every collective below CHECKs that is_initialized_ is true. When Paddle is
// built without PADDLE_WITH_GLOO, the collectives only log a warning:
// AllReduce returns sendbuf.size() value-initialized elements and AllGather
// returns size_ value-initialized elements.
class GlooWrapper {
 public:
  // Returns a shared instance held in a function-local static; it is created
  // on the first call.
  static std::shared_ptr<GlooWrapper> GetInstance() {
    static auto s_instance = std::make_shared<GlooWrapper>();
    return s_instance;
  }

  GlooWrapper() {}
  virtual ~GlooWrapper() {}

  // Defined out of line. NOTE(review): presumably builds context_ from the
  // settings below and sets is_initialized_; confirm in the implementation.
  void Init();

  // Stores the init and run timeouts, both given in seconds.
  void SetTimeoutSeconds(int init_seconds, int run_seconds) {
    init_timeout_ = std::chrono::seconds(init_seconds);
    run_timeout_ = std::chrono::seconds(run_seconds);
  }

  int Rank() const { return rank_; }
  int Size() const { return size_; }
  void SetRank(int rank) { rank_ = rank; }
  void SetSize(int size) { size_ = size; }
  void SetIface(const std::string& iface) { iface_ = iface; }
  void SetPrefix(const std::string& prefix) { prefix_ = prefix; }

  // Selects the HDFS rendezvous store and records its path, fs name and ugi.
  void SetHdfsStore(const std::string& path, const std::string& fs_name,
                    const std::string& fs_ugi) {
    store_type_ = GlooStoreType::HDFS;
    hdfs_path_ = path;
    hdfs_name_ = fs_name;
    hdfs_ugi_ = fs_ugi;
  }

  // Selects the HTTP rendezvous store and records its ip, port and scope.
  void SetHttpStore(const std::string& ip, int port, const std::string& scope) {
    store_type_ = GlooStoreType::HTTP;
    http_ip_ = ip;
    http_port_ = port;
    http_scope_ = scope;
  }

  // Runs a gloo barrier on context_ (only a warning without gloo).
  void Barrier() {
    CHECK_EQ(is_initialized_, true);
#ifdef PADDLE_WITH_GLOO
    gloo::BarrierOptions opts(context_);
    gloo::barrier(opts);
#else
    LOG(WARNING) << "Barrier does nothing when WITH_GLOO=OFF";
#endif
  }

  bool IsInitialized() const { return is_initialized_; }

#ifdef PADDLE_WITH_GLOO
  std::shared_ptr<gloo::Context> GetContext() const { return context_; }
#endif

  // Element-wise all-reduce of `sendbuf` across ranks.
  //
  // `mode` is one of "sum", "max" or "min"; any other value throws
  // InvalidArgument (gloo builds only). `sendbuf` is taken by non-const
  // reference because gloo's setInput takes a non-const pointer.
  // Returns a new vector of the same length holding the reduced values.
  template <typename T>
  std::vector<T> AllReduce(std::vector<T>& sendbuf,  // NOLINT
                           const std::string& mode = "sum") {  // NOLINT
    CHECK_EQ(is_initialized_, true);
    // Sized from sendbuf, so input and output lengths always agree.
    std::vector<T> recvbuf(sendbuf.size(), T());
#ifdef PADDLE_WITH_GLOO
    gloo::AllreduceOptions opts(context_);
    opts.setInput(sendbuf.data(), sendbuf.size());
    opts.setOutput(recvbuf.data(), recvbuf.size());
    // The casts select the exact reducer signature expected by
    // setReduceFunction.
    using ReduceFn = void (*)(void*, const void*, const void*, size_t);
    if (mode == "sum") {
      opts.setReduceFunction(static_cast<ReduceFn>(&gloo::sum<T>));
    } else if (mode == "max") {
      opts.setReduceFunction(static_cast<ReduceFn>(&gloo::max<T>));
    } else if (mode == "min") {
      opts.setReduceFunction(static_cast<ReduceFn>(&gloo::min<T>));
    } else {
      PADDLE_THROW(paddle::platform::errors::InvalidArgument(
          "AllReduce mode not known: " + mode));
    }
    gloo::allreduce(opts);
#else
    LOG(WARNING) << "AllReduce does nothing when WITH_GLOO=OFF";
#endif
    return recvbuf;
  }

  // Gathers one T from each of the size_ ranks into a vector of size_
  // elements. `input` is non-const because gloo's setInput takes a
  // non-const pointer.
  template <typename T>
  std::vector<T> AllGather(T& input) {  // NOLINT
    CHECK_EQ(is_initialized_, true);
    std::vector<T> ret(size_, T());
#ifdef PADDLE_WITH_GLOO
    gloo::AllgatherOptions opts(context_);
    opts.setInput(&input, 1);
    opts.setOutput(ret.data(), size_);
    gloo::allgather(opts);
#else
    LOG(WARNING) << "AllGather does nothing when WITH_GLOO=OFF";
#endif
    // Return the local directly: std::move here would suppress NRVO.
    return ret;
  }

 protected:
  bool is_initialized_ = false;
#ifdef PADDLE_WITH_GLOO
  std::shared_ptr<gloo::Context> context_ = nullptr;
#endif
  int rank_ = 0;
  int size_ = 0;
  std::chrono::seconds init_timeout_ = std::chrono::seconds(9999999);
  std::chrono::seconds run_timeout_ = std::chrono::seconds(9999999);
  std::string iface_ = "lo";
  std::string prefix_;
  GlooStoreType store_type_ = GlooStoreType::HDFS;
  // configs for hdfs store
  std::string hdfs_path_;
  std::string hdfs_name_;
  std::string hdfs_ugi_;
  // configs for http store
  std::string http_ip_;
  int http_port_ = 0;
  std::string http_scope_;
};
} // namespace framework
} // namespace paddle