add collective communication library in fleet (#22211)
* add collective communication library in fleet to replace mpi * test=developrevert-22710-feature/integrated_ps_api
parent
05ee05e248
commit
e3a457d34b
@ -0,0 +1,60 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
INCLUDE(ExternalProject)
|
||||
|
||||
SET(GLOO_PROJECT "extern_gloo")
|
||||
IF((NOT DEFINED GLOO_VER) OR (NOT DEFINED GLOO_URL))
|
||||
MESSAGE(STATUS "use pre defined download url")
|
||||
SET(GLOO_VER "master" CACHE STRING "" FORCE)
|
||||
SET(GLOO_NAME "gloo" CACHE STRING "" FORCE)
|
||||
SET(GLOO_URL "https://pslib.bj.bcebos.com/gloo.tar.gz" CACHE STRING "" FORCE)
|
||||
ENDIF()
|
||||
MESSAGE(STATUS "GLOO_NAME: ${GLOO_NAME}, GLOO_URL: ${GLOO_URL}")
|
||||
SET(GLOO_SOURCE_DIR "${THIRD_PARTY_PATH}/gloo")
|
||||
SET(GLOO_DOWNLOAD_DIR "${GLOO_SOURCE_DIR}/src/${GLOO_PROJECT}")
|
||||
SET(GLOO_DST_DIR "gloo")
|
||||
SET(GLOO_INSTALL_ROOT "${THIRD_PARTY_PATH}/install")
|
||||
SET(GLOO_INSTALL_DIR ${GLOO_INSTALL_ROOT}/${GLOO_DST_DIR})
|
||||
SET(GLOO_ROOT ${GLOO_INSTALL_DIR})
|
||||
SET(GLOO_INC_DIR ${GLOO_ROOT}/include)
|
||||
SET(GLOO_LIB_DIR ${GLOO_ROOT}/lib)
|
||||
SET(GLOO_LIB ${GLOO_LIB_DIR}/libgloo.a)
|
||||
#SET(GLOO_IOMP_LIB ${GLOO_LIB_DIR}/libiomp5.so) #todo what is this
|
||||
SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${GLOO_ROOT}/lib")
|
||||
|
||||
INCLUDE_DIRECTORIES(${GLOO_INC_DIR})
|
||||
|
||||
FILE(WRITE ${GLOO_DOWNLOAD_DIR}/CMakeLists.txt
|
||||
"PROJECT(GLOO)\n"
|
||||
"cmake_minimum_required(VERSION 3.0)\n"
|
||||
"install(DIRECTORY ${GLOO_NAME}/include ${GLOO_NAME}/lib \n"
|
||||
" DESTINATION ${GLOO_DST_DIR})\n")
|
||||
|
||||
ExternalProject_Add(
|
||||
${GLOO_PROJECT}
|
||||
${EXTERNAL_PROJECT_LOG_ARGS}
|
||||
PREFIX ${GLOO_SOURCE_DIR}
|
||||
DOWNLOAD_DIR ${GLOO_DOWNLOAD_DIR}
|
||||
DOWNLOAD_COMMAND wget --no-check-certificate ${GLOO_URL} -c -q -O ${GLOO_NAME}.tar.gz
|
||||
&& tar zxvf ${GLOO_NAME}.tar.gz
|
||||
DOWNLOAD_NO_PROGRESS 1
|
||||
UPDATE_COMMAND ""
|
||||
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${GLOO_INSTALL_ROOT}
|
||||
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${GLOO_INSTALL_ROOT}
|
||||
)
|
||||
|
||||
ADD_LIBRARY(gloo SHARED IMPORTED GLOBAL)
|
||||
SET_PROPERTY(TARGET gloo PROPERTY IMPORTED_LOCATION ${GLOO_LIB})
|
||||
ADD_DEPENDENCIES(gloo ${GLOO_PROJECT})
|
@ -0,0 +1,166 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/io/fs.h"
|
||||
#include "paddle/fluid/platform/errors.h"
|
||||
|
||||
namespace gloo {
|
||||
namespace rendezvous {
|
||||
|
||||
HdfsStore::HdfsStore(const std::string& path) {
|
||||
path_ = path;
|
||||
wait_sleep_ms_ = 3000;
|
||||
wait_timeout_ = std::chrono::seconds(999999999);
|
||||
}
|
||||
|
||||
void HdfsStore::set(const std::string& key, const std::vector<char>& data) {
|
||||
#ifdef PADDLE_WITH_GLOO
|
||||
auto tmp = TmpPath(key);
|
||||
auto path = ObjectPath(key);
|
||||
bool is_exists = paddle::framework::fs_exists(path);
|
||||
if (is_exists) {
|
||||
LOG(WARNING) << "path exists, will be removed: " << path;
|
||||
paddle::framework::fs_remove(path);
|
||||
}
|
||||
int err_no = 0;
|
||||
std::shared_ptr<FILE> fp = paddle::framework::fs_open_write(tmp, &err_no, "");
|
||||
size_t write_count = fwrite_unlocked(data.data(), 1, data.size(), fp.get());
|
||||
VLOG(3) << "HdfsStore::set write_count=" << write_count << " key " << key;
|
||||
fp.reset();
|
||||
paddle::framework::fs_mv(tmp, path);
|
||||
#endif
|
||||
}
|
||||
|
||||
std::vector<char> HdfsStore::get(const std::string& key) {
|
||||
auto path = ObjectPath(key);
|
||||
std::vector<char> result;
|
||||
#ifdef PADDLE_WITH_GLOO
|
||||
// block until key is set
|
||||
wait({key});
|
||||
bool is_exists = paddle::framework::fs_exists(path);
|
||||
PADDLE_ENFORCE_EQ(is_exists, true,
|
||||
paddle::platform::errors::NotFound(
|
||||
"HdfsStore::get, path not exists: " + path));
|
||||
int err_no = 0;
|
||||
std::shared_ptr<FILE> fp = paddle::framework::fs_open_read(path, &err_no, "");
|
||||
char buffer = '\0';
|
||||
size_t read_count = 0;
|
||||
while (fread(&buffer, 1, 1, fp.get()) == 1) {
|
||||
++read_count;
|
||||
result.push_back(buffer);
|
||||
}
|
||||
VLOG(3) << "HdfsStore::get read_count " << read_count;
|
||||
#endif
|
||||
return result;
|
||||
}
|
||||
|
||||
void HdfsStore::wait(const std::vector<std::string>& keys) {
|
||||
#ifdef PADDLE_WITH_GLOO
|
||||
wait(keys, wait_timeout_); // NOLINT
|
||||
#endif
|
||||
}
|
||||
|
||||
void HdfsStore::wait(const std::vector<std::string>& keys,
|
||||
const std::chrono::milliseconds&) { // NOLINT
|
||||
#ifdef PADDLE_WITH_GLOO
|
||||
auto start = std::chrono::steady_clock::now();
|
||||
while (!Check(keys)) {
|
||||
auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
|
||||
std::chrono::steady_clock::now() - start);
|
||||
if (wait_timeout_ != gloo::kNoTimeout && elapsed > wait_timeout_) {
|
||||
PADDLE_ENFORCE_EQ(0, 1, paddle::platform::errors::ExecutionTimeout(
|
||||
"HdfsStore::wait, Wait timeout for key(s): " +
|
||||
::gloo::MakeString(keys)));
|
||||
}
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(wait_sleep_ms_));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
std::string HdfsStore::EncodeName(const std::string& name) {
|
||||
thread_local std::hash<std::string> hash_func;
|
||||
return std::to_string(hash_func(name));
|
||||
}
|
||||
|
||||
std::string HdfsStore::TmpPath(const std::string& name) {
|
||||
return path_ + "/" + EncodeName(name) + "_tmp";
|
||||
}
|
||||
|
||||
std::string HdfsStore::ObjectPath(const std::string& name) {
|
||||
return path_ + "/" + EncodeName(name);
|
||||
}
|
||||
|
||||
bool HdfsStore::Check(const std::vector<std::string>& keys) {
|
||||
#ifdef PADDLE_WITH_GLOO
|
||||
std::vector<std::string> paths;
|
||||
for (const auto& key : keys) {
|
||||
paths.push_back(ObjectPath(key));
|
||||
}
|
||||
for (const auto& path : paths) {
|
||||
bool is_exists = paddle::framework::fs_exists(path);
|
||||
VLOG(3) << "HdfsStore::Check " << is_exists << " path " << path;
|
||||
if (!is_exists) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace rendezvous
|
||||
} // namespace gloo
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
|
||||
void GlooWrapper::Init(int rank, int size, const std::string& path,
|
||||
const std::string& fs_name, const std::string& fs_ugi,
|
||||
const std::string& iface, const std::string& prefix) {
|
||||
if (is_initialized_) {
|
||||
return;
|
||||
}
|
||||
rank_ = rank;
|
||||
size_ = size;
|
||||
std::string cmd = std::string("hadoop fs");
|
||||
cmd += " -D fs.default.name=" + fs_name;
|
||||
cmd += " -D hadoop.job.ugi=" + fs_ugi;
|
||||
paddle::framework::hdfs_set_command(cmd);
|
||||
#ifdef PADDLE_WITH_GLOO
|
||||
gloo::transport::tcp::attr attr;
|
||||
attr.iface = iface;
|
||||
auto file_store = gloo::rendezvous::HdfsStore(path);
|
||||
auto prefix_store = gloo::rendezvous::PrefixStore(prefix, file_store);
|
||||
auto dev = gloo::transport::tcp::CreateDevice(attr);
|
||||
auto context = std::make_shared<gloo::rendezvous::Context>(rank, size);
|
||||
context->setTimeout(file_store.wait_timeout_);
|
||||
context->connectFullMesh(prefix_store, dev);
|
||||
context_ = std::move(context);
|
||||
#endif
|
||||
is_initialized_ = true;
|
||||
}
|
||||
|
||||
template void GlooWrapper::AllReduce<int64_t>(
|
||||
std::vector<int64_t>& sendbuf, // NOLINT
|
||||
std::vector<int64_t>& recvbuf, // NOLINT
|
||||
const std::string& mode);
|
||||
template void GlooWrapper::AllReduce<double>(
|
||||
std::vector<double>& sendbuf, // NOLINT
|
||||
std::vector<double>& recvbuf, // NOLINT
|
||||
const std::string& mode);
|
||||
template std::vector<int64_t> GlooWrapper::AllGather<int64_t>(
|
||||
int64_t& input); // NOLINT
|
||||
template std::vector<double> GlooWrapper::AllGather<double>(
|
||||
double& input); // NOLINT
|
||||
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,161 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#if defined _WIN32 || defined __APPLE__
|
||||
#else
|
||||
#define _LINUX
|
||||
#endif
|
||||
|
||||
#ifdef _LINUX
|
||||
#include <sys/types.h>
|
||||
#include <unistd.h>
|
||||
#endif
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#ifdef PADDLE_WITH_GLOO
|
||||
#include <gloo/allgather.h>
|
||||
#include <gloo/allreduce.h>
|
||||
#include <gloo/barrier.h>
|
||||
#include <gloo/rendezvous/context.h>
|
||||
#include <gloo/rendezvous/file_store.h>
|
||||
#include <gloo/rendezvous/prefix_store.h>
|
||||
#include <gloo/rendezvous/store.h>
|
||||
#include <gloo/transport/tcp/device.h>
|
||||
#endif
|
||||
#include "paddle/fluid/framework/variable_helper.h"
|
||||
|
||||
namespace gloo {
|
||||
namespace rendezvous {
|
||||
|
||||
#ifdef PADDLE_WITH_GLOO
|
||||
class HdfsStore : public gloo::rendezvous::Store {
|
||||
#else
|
||||
class HdfsStore {
|
||||
#endif
|
||||
public: // NOLINT
|
||||
explicit HdfsStore(const std::string& path);
|
||||
|
||||
virtual ~HdfsStore() {}
|
||||
|
||||
virtual void set(const std::string& key, const std::vector<char>& data);
|
||||
|
||||
virtual std::vector<char> get(const std::string& key);
|
||||
|
||||
virtual void wait(const std::vector<std::string>& keys);
|
||||
|
||||
virtual void wait(const std::vector<std::string>& keys,
|
||||
const std::chrono::milliseconds& timeout);
|
||||
|
||||
std::string EncodeName(const std::string& name);
|
||||
|
||||
std::string TmpPath(const std::string& name);
|
||||
|
||||
std::string ObjectPath(const std::string& name);
|
||||
|
||||
bool Check(const std::vector<std::string>& keys);
|
||||
|
||||
std::string path_;
|
||||
int wait_sleep_ms_;
|
||||
std::chrono::seconds wait_timeout_;
|
||||
};
|
||||
|
||||
} // namespace rendezvous
|
||||
} // namespace gloo
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
|
||||
class GlooWrapper {
|
||||
public:
|
||||
GlooWrapper() {}
|
||||
|
||||
virtual ~GlooWrapper() {}
|
||||
|
||||
void Init(int rank, int size, const std::string& path,
|
||||
const std::string& fs_name, const std::string& fs_ugi,
|
||||
const std::string& iface, const std::string& prefix);
|
||||
|
||||
int Rank() {
|
||||
CHECK_EQ(is_initialized_, true);
|
||||
return rank_;
|
||||
}
|
||||
|
||||
int Size() {
|
||||
CHECK_EQ(is_initialized_, true);
|
||||
return size_;
|
||||
}
|
||||
|
||||
void Barrier() {
|
||||
CHECK_EQ(is_initialized_, true);
|
||||
#ifdef PADDLE_WITH_GLOO
|
||||
gloo::BarrierOptions opts(context_);
|
||||
gloo::barrier(opts);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void AllReduce(std::vector<T>& sendbuf, std::vector<T>& recvbuf, // NOLINT
|
||||
const std::string& mode = "sum") {
|
||||
CHECK_EQ(is_initialized_, true);
|
||||
CHECK_EQ(sendbuf.size() == recvbuf.size(), true);
|
||||
#ifdef PADDLE_WITH_GLOO
|
||||
gloo::AllreduceOptions opts(context_);
|
||||
opts.setInput(sendbuf.data(), sendbuf.size());
|
||||
opts.setOutput(recvbuf.data(), recvbuf.size());
|
||||
if (mode == "sum") {
|
||||
opts.setReduceFunction(
|
||||
static_cast<void (*)(void*, const void*, const void*, size_t)>(
|
||||
&gloo::sum<T>));
|
||||
} else if (mode == "max") {
|
||||
opts.setReduceFunction(
|
||||
static_cast<void (*)(void*, const void*, const void*, size_t)>(
|
||||
&gloo::max<T>));
|
||||
} else if (mode == "min") {
|
||||
opts.setReduceFunction(
|
||||
static_cast<void (*)(void*, const void*, const void*, size_t)>(
|
||||
&gloo::min<T>));
|
||||
} else {
|
||||
PADDLE_ENFORCE_EQ(0, 1, paddle::platform::errors::InvalidArgument(
|
||||
"AllReduce mode not known: " + mode));
|
||||
}
|
||||
gloo::allreduce(opts);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::vector<T> AllGather(T& input) { // NOLINT
|
||||
CHECK_EQ(is_initialized_, true);
|
||||
std::vector<T> ret(size_, T());
|
||||
#ifdef PADDLE_WITH_GLOO
|
||||
gloo::AllgatherOptions opts(context_);
|
||||
opts.setInput(&input, 1);
|
||||
opts.setOutput(ret.data(), size_);
|
||||
gloo::allgather(opts);
|
||||
#endif
|
||||
return std::move(ret);
|
||||
}
|
||||
|
||||
protected:
|
||||
bool is_initialized_ = false;
|
||||
#ifdef PADDLE_WITH_GLOO
|
||||
std::shared_ptr<gloo::Context> context_ = nullptr;
|
||||
#endif
|
||||
int rank_ = 0;
|
||||
int size_ = 0;
|
||||
};
|
||||
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,66 @@
|
||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <fstream>
|
||||
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
|
||||
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
|
||||
#include "paddle/fluid/framework/io/fs.h"
|
||||
|
||||
#if defined _WIN32 || defined __APPLE__
|
||||
#else
|
||||
#define _LINUX
|
||||
#endif
|
||||
|
||||
TEST(TEST_GLOO, store_1) {
|
||||
#ifdef _LINUX
|
||||
#ifdef PADDLE_WITH_GLOO
|
||||
#else
|
||||
auto store = gloo::rendezvous::HdfsStore("./test_gllo_store");
|
||||
store.set("1", std::vector<char>{'t', 'e', 's', 't'});
|
||||
store.get("1");
|
||||
try {
|
||||
store.get("2");
|
||||
} catch (...) {
|
||||
VLOG(3) << "catch expected error of not found";
|
||||
}
|
||||
store.wait(std::vector<std::string>{"test"});
|
||||
store.wait(std::vector<std::string>{"test"}, std::chrono::milliseconds(0));
|
||||
store.EncodeName("1");
|
||||
store.TmpPath("1");
|
||||
store.ObjectPath("1");
|
||||
store.Check(std::vector<std::string>{"test"});
|
||||
|
||||
auto gw = paddle::framework::GlooWrapper();
|
||||
gw.Init(0, 1, "", "", "", "", "");
|
||||
gw.Init(0, 1, "", "", "", "", "");
|
||||
gw.Rank();
|
||||
gw.Size();
|
||||
gw.Barrier();
|
||||
std::vector<double> input;
|
||||
std::vector<double> output;
|
||||
gw.AllReduce(input, output);
|
||||
int64_t t;
|
||||
gw.AllGather(t);
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
TEST(TEST_FLEET, fleet_1) {
|
||||
auto fleet = paddle::framework::FleetWrapper::GetInstance();
|
||||
#ifdef PADDLE_WITH_PSLIB
|
||||
#else
|
||||
fleet->RunServer("", 0);
|
||||
#endif
|
||||
}
|
@ -1,2 +1,4 @@
|
||||
cc_library(fs SRCS fs.cc DEPS string_helper glog boost)
|
||||
cc_library(shell SRCS shell.cc DEPS string_helper glog)
|
||||
|
||||
cc_test(test_fs SRCS test_fs.cc DEPS fs shell)
|
||||
|
@ -0,0 +1,47 @@
|
||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <fstream>
|
||||
#include "paddle/fluid/framework/io/fs.h"
|
||||
|
||||
#if defined _WIN32 || defined __APPLE__
|
||||
#else
|
||||
#define _LINUX
|
||||
#endif
|
||||
|
||||
TEST(FS, mv) {
|
||||
#ifdef _LINUX
|
||||
std::ofstream out("src.txt");
|
||||
out.close();
|
||||
paddle::framework::fs_mv("src.txt", "dest.txt");
|
||||
paddle::framework::hdfs_mv("", "");
|
||||
paddle::framework::localfs_mv("", "");
|
||||
try {
|
||||
paddle::framework::hdfs_mv("afs:/none", "afs:/none");
|
||||
} catch (...) {
|
||||
VLOG(3) << "test hdfs_mv, catch expected errors of unknown path";
|
||||
}
|
||||
try {
|
||||
paddle::framework::fs_mv("afs:/none", "afs:/none");
|
||||
} catch (...) {
|
||||
VLOG(3) << "test hdfs_mv, catch expected errors of unknown path";
|
||||
}
|
||||
try {
|
||||
paddle::framework::hdfs_mv("unknown:/none", "unknown:/none");
|
||||
} catch (...) {
|
||||
VLOG(3) << "test hdfs_mv, catch expected errors of unknown prefix";
|
||||
}
|
||||
#endif
|
||||
}
|
@ -0,0 +1,48 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
#include <fcntl.h>
|
||||
|
||||
#ifdef _POSIX_C_SOURCE
|
||||
#undef _POSIX_C_SOURCE
|
||||
#endif
|
||||
|
||||
#ifdef _XOPEN_SOURCE
|
||||
#undef _XOPEN_SOURCE
|
||||
#endif
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
|
||||
#include "paddle/fluid/framework/scope.h"
|
||||
#include "paddle/fluid/platform/place.h"
|
||||
#include "paddle/fluid/pybind/gloo_wrapper_py.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace paddle {
|
||||
namespace pybind {
|
||||
void BindGlooWrapper(py::module* m) {
|
||||
py::class_<framework::GlooWrapper>(*m, "Gloo")
|
||||
.def(py::init())
|
||||
.def("init", &framework::GlooWrapper::Init)
|
||||
.def("rank", &framework::GlooWrapper::Rank)
|
||||
.def("size", &framework::GlooWrapper::Size)
|
||||
.def("barrier", &framework::GlooWrapper::Barrier)
|
||||
.def("all_reduce", &framework::GlooWrapper::AllReduce<int64_t>)
|
||||
.def("all_reduce", &framework::GlooWrapper::AllReduce<double>)
|
||||
.def("all_gather", &framework::GlooWrapper::AllGather<int64_t>)
|
||||
.def("all_gather", &framework::GlooWrapper::AllGather<double>)
|
||||
.def("Allreduce", &framework::GlooWrapper::AllReduce<int64_t>)
|
||||
.def("Allreduce", &framework::GlooWrapper::AllReduce<double>);
|
||||
} // end BindGlooWrapper
|
||||
} // end namespace pybind
|
||||
} // end namespace paddle
|
@ -0,0 +1,28 @@
|
||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace paddle {
|
||||
namespace pybind {
|
||||
|
||||
void BindGlooWrapper(py::module* m);
|
||||
|
||||
} // namespace pybind
|
||||
} // namespace paddle
|
Loading…
Reference in new issue