Add reduce sparse tensor feature. (#14757)
parent: c83d5b7a16
commit: f1fb64b17f
paddle/fluid/operators/distributed/collective_client.cc
@@ -0,0 +1,59 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <condition_variable>  // NOLINT
#include <string>
#include "gflags/gflags.h"

#include "paddle/fluid/operators/distributed/collective_client.h"

DECLARE_int32(rpc_deadline);

namespace paddle {
namespace operators {
namespace distributed {
std::once_flag CollectiveClient::init_flag_;
std::unique_ptr<CollectiveClient> CollectiveClient::client_(nullptr);

bool CollectiveClient::Gather(const std::vector<RemoteVar>& remote_vars,
                              std::vector<const framework::SelectedRows*>* dst,
                              const platform::DeviceContext& ctx,
                              framework::Scope* scope, int64_t time_out) {
  // Phase 1: create a local slot for every remote variable and issue the
  // asynchronous fetches.
  for (const auto& r : remote_vars) {
    VLOG(50) << "begin gather from ep:" << r.String();
    scope->Var(r.var_name_)->GetMutable<framework::SelectedRows>();
    rpc_client_->AsyncGetMonomerVariable(r.ep_, ctx, *scope, r.var_name_,
                                         time_out);
  }

  rpc_client_->Wait();

  // Phase 2: collect the results in rank order and notify each server's
  // barrier that this trainer has finished with the variable.
  for (const auto& r : remote_vars) {
    auto select_rows =
        scope->FindVar(r.var_name_)->GetMutable<framework::SelectedRows>();
    dst->push_back(select_rows);

    VLOG(4) << "gather from ep:" << r.String()
            << ", select_rows:" << GetSelectedRowsInfo(*select_rows);

    rpc_client_->AsyncGetMonomerBarrier(r.ep_, r.var_name_);
  }

  rpc_client_->Wait();
  return true;
}

}  // namespace distributed
}  // namespace operators
}  // namespace paddle
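For orientation, here is a minimal caller-side sketch of Gather as declared in the header below. The endpoint is hypothetical, and dev_ctx is assumed to be a DeviceContext the caller already holds:

// Sketch only: fetch "var1" from one remote trainer (endpoint hypothetical).
namespace distributed = paddle::operators::distributed;

distributed::RemoteVar v;
v.ep_ = "127.0.0.1:7164";
v.var_name_ = "var1";

paddle::framework::Scope scope;
std::vector<const paddle::framework::SelectedRows*> dst;

auto* client = distributed::CollectiveClient::GetInstance();
client->Gather({v}, &dst, dev_ctx, &scope);
// dst preserves the rank order of the input; dst[0] is the SelectedRows
// fetched from v.ep_.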
paddle/fluid/operators/distributed/collective_client.h
@@ -0,0 +1,93 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <condition_variable>  // NOLINT
#include <sstream>
#include <string>
#include <vector>
#include "gflags/gflags.h"

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/distributed/request_handler.h"

DECLARE_int32(rpc_deadline);

namespace paddle {
namespace operators {
namespace distributed {

// Formats the height, rows, and dims of a SelectedRows for logging. The
// leading ", " lets the result be appended directly after a variable name.
inline std::string GetSelectedRowsInfo(const framework::SelectedRows& slr) {
  std::stringstream ss;
  ss << ", height:" << slr.height() << ", rows:[";
  for (size_t i = 0; i < slr.rows().size(); i++) {
    if (i != slr.rows().size() - 1) {
      ss << slr.rows()[i] << ",";
    } else {
      ss << slr.rows()[i];
    }
  }
  ss << "], dims:" << slr.value().dims();
  return ss.str();
}
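As an illustration (not part of the diff): for a SelectedRows with height 1000, rows {0, 1, 2}, and a 3x5 value — the shape the test file below builds — the output should read roughly as follows; the exact dims text depends on DDim's stream operator:

// Hypothetical log line; "slr" is a framework::SelectedRows as above.
VLOG(4) << "var1" << GetSelectedRowsInfo(slr);
// => var1, height:1000, rows:[0,1,2], dims:3, 5   (approximate rendering)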

struct RemoteVar {
  std::string ep_;
  std::string var_name_;
  int trainer_id_{0};

  std::string String() const {
    std::stringstream ss;
    ss << "ep:" << ep_ << ", var_name:" << var_name_
       << ", trainer_id:" << trainer_id_;

    return ss.str();
  }
};
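A small sketch of RemoteVar in use (endpoint hypothetical); String() is what Gather writes to the log:

distributed::RemoteVar var;
var.ep_ = "127.0.0.1:7164";  // hypothetical endpoint
var.var_name_ = "var1";
// trainer_id_ defaults to 0.
VLOG(4) << var.String();  // ep:127.0.0.1:7164, var_name:var1, trainer_id:0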

class CollectiveClient {
 public:
  CollectiveClient() {
    rpc_client_.reset(new RPCCLIENT_T());
    rpc_client_->InitImpl();
  }
  virtual ~CollectiveClient() {}

  // NOTE: results are pushed into *dst in the rank order of remote_vars.
  bool Gather(const std::vector<RemoteVar>& remote_vars,
              std::vector<const framework::SelectedRows*>* dst,
              const platform::DeviceContext& ctx, framework::Scope* scope,
              int64_t time_out = FLAGS_rpc_deadline);

  static CollectiveClient* GetInstance() {
    std::call_once(init_flag_, [&]() {
      if (client_.get() == nullptr) {
        client_.reset(new CollectiveClient());
      }
    });
    return client_.get();
  }

 private:
  std::unique_ptr<RPCClient> rpc_client_;

  static std::once_flag init_flag_;
  static std::unique_ptr<CollectiveClient> client_;
};
}  // namespace distributed
}  // namespace operators
}  // namespace paddle
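Because GetInstance goes through std::call_once, concurrent callers all observe the same client; a trivial sketch of the guarantee:

auto* a = distributed::CollectiveClient::GetInstance();
auto* b = distributed::CollectiveClient::GetInstance();
// a == b always holds: one client per process, constructed exactly once.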
paddle/fluid/operators/distributed/collective_server.cc
@@ -0,0 +1,74 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <stdio.h>   // for removing the port file
#include <unistd.h>  // for sleep()
#include <csignal>
#include <cstdlib>
#include <fstream>
#include <thread>  // NOLINT
#include <vector>

#include "paddle/fluid/operators/distributed/collective_server.h"

DEFINE_int32(collective_get_thread_num, 5, "number of threads for rpc get");

namespace paddle {
namespace operators {
namespace distributed {

std::once_flag CollectiveServer::init_flag_;
std::shared_ptr<CollectiveServer> CollectiveServer::collective_server_(nullptr);

CollectiveServer::CollectiveServer(const std::string& end_point, int fan_in) {
  VLOG(1) << "Create collective server:" << end_point << ", fan_in:" << fan_in;
  rpc_server_.reset(new RPCSERVER_T(end_point, fan_in));
}

void CollectiveServer::Stop() {
  rpc_server_->ShutDown();
  server_thread_->join();
  loop_thread_->join();
}

void CollectiveServer::StartServer() {
  get_monomer_handler_.reset(new GetMonomerHandler());
  get_monomer_handler_->SetRPCServer(rpc_server_.get());

  get_barrier_handler_.reset(new GetMonomerBarrierHandler());
  get_barrier_handler_->SetRPCServer(rpc_server_.get());

  rpc_server_->RegisterRPC(distributed::kRequestGetMonomerVariable,
                           get_monomer_handler_.get(),
                           FLAGS_collective_get_thread_num);
  rpc_server_->RegisterRPC(distributed::kRequestGetMonomerBarrier,
                           get_barrier_handler_.get(), 1);

  server_thread_.reset(new std::thread([&]() { rpc_server_->StartServer(); }));
  rpc_server_->WaitServerReady();

  // Watchdog thread: poll once a second and exit once the RPC server has
  // been shut down.
  loop_thread_.reset(new std::thread([&]() {
    while (true) {
      if (rpc_server_->IsExit()) {
        LOG(WARNING) << "RPC server is exiting; loop thread breaks.";
        break;
      }
      sleep(1);
    }
    VLOG(1) << "CollectiveServer loop_thread end";
  }));
}

}  // namespace distributed
}  // namespace operators
}  // namespace paddle
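Putting the pieces together, a server-side lifecycle sketch; every call below also appears in the test file at the end of this diff, and scope/dev_ctx are assumed to exist:

auto* server = distributed::CollectiveServer::GetInstance("127.0.0.1:7164",
                                                          /*fan_in=*/2);
auto rpc_server = server->GetRPCServer();
rpc_server->RegisterVar("var1", distributed::kRequestGetMonomerVariable,
                        scope, dev_ctx);
// ... trainers call CollectiveClient::Gather against this endpoint ...
rpc_server->WaitVarBarrier("var1");  // blocks until fan_in barrier RPCs arrive
rpc_server->ClearRegisteredVars();
server->Stop();  // shuts down the RPC server and joins both threads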
paddle/fluid/operators/distributed/collective_server.h
@@ -0,0 +1,110 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <map>
#include <set>
#include <string>
#include <thread>  // NOLINT
#include <utility>
#include <vector>

#include "gflags/gflags.h"

#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed/rpc_server.h"

namespace paddle {
namespace operators {
namespace distributed {

class CollectiveServer;

// Looks up a registered variable in the scope and hands it back to the
// requesting trainer.
class GetMonomerHandler final : public RequestHandler {
 public:
  GetMonomerHandler() : RequestHandler(true) {}
  virtual ~GetMonomerHandler() {}
  bool Handle(const std::string& var_name, framework::Scope* scope,
              framework::Variable* var, framework::Variable** outvar,
              const int trainer_id, const std::string& out_var_name = "",
              const std::string& table_name = "") override {
    VLOG(50) << "GetMonomerHandler recv " << var_name;

    *outvar = scope->FindVar(var_name);
    PADDLE_ENFORCE(*outvar != nullptr, "%s not found", var_name);

    return true;
  }
};

// Records one barrier arrival for var_name; the server-side
// WaitVarBarrier(var_name) unblocks once fan_in arrivals are counted.
class GetMonomerBarrierHandler final : public RequestHandler {
 public:
  GetMonomerBarrierHandler() : RequestHandler(true) {}
  virtual ~GetMonomerBarrierHandler() {}
  bool Handle(const std::string& var_name, framework::Scope* scope,
              framework::Variable* var, framework::Variable** outvar,
              const int trainer_id, const std::string& out_var_name = "",
              const std::string& table_name = "") override {
    VLOG(50) << "GetMonomerBarrierHandler recv " << var_name;

    rpc_server_->IncreaseVarBarrier(var_name);

    return true;
  }
};

class CollectiveServer final {
 public:
  explicit CollectiveServer(const std::string& end_point, int fan_in);

  virtual ~CollectiveServer() {}

  void StartServer();

  static CollectiveServer* GetInstance(const std::string& end_point,
                                       int fan_in) {
    std::call_once(init_flag_, [&]() {
      if (collective_server_.get() == nullptr) {
        collective_server_.reset(new CollectiveServer(end_point, fan_in));
        collective_server_->StartServer();
      }
    });

    return collective_server_.get();
  }

  std::shared_ptr<RPCServer> GetRPCServer() { return rpc_server_; }

  void Stop();

 private:
  std::unique_ptr<GetMonomerHandler> get_monomer_handler_;
  std::unique_ptr<GetMonomerBarrierHandler> get_barrier_handler_;

  std::shared_ptr<distributed::RPCServer> rpc_server_;
  std::shared_ptr<std::thread> server_thread_;
  std::shared_ptr<std::thread> loop_thread_;

  bool ready_{false};

  static std::once_flag init_flag_;
  static std::shared_ptr<CollectiveServer> collective_server_;
};

}  // namespace distributed
}  // namespace operators
}  // namespace paddle
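Together, the two handlers implement a simple two-phase protocol per gathered variable, summarized below (a sketch inferred from the client and server code above):

// trainer (CollectiveClient)            server (CollectiveServer)
// ---------------------------           -------------------------
// AsyncGetMonomerVariable  ---------->  GetMonomerHandler::Handle
//                          <----------  serialized SelectedRows
// AsyncGetMonomerBarrier   ---------->  GetMonomerBarrierHandler::Handle
//                                       (IncreaseVarBarrier; server-side
//                                        WaitVarBarrier returns at fan_in)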
paddle/fluid/operators/distributed/collective_server_test.cc
@@ -0,0 +1,115 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <unistd.h>
#include <string>
#include <thread>  // NOLINT

#include "gtest/gtest.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"

#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/distributed/collective_client.h"
#include "paddle/fluid/operators/distributed/collective_server.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/math/math_function.h"

namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace distributed = paddle::operators::distributed;

std::unique_ptr<distributed::CollectiveServer> StartServer(
    const std::string& ep, int fan_in, framework::Scope* scope,
    platform::DeviceContext* dev_ctx) {
  distributed::CollectiveServer* server =
      distributed::CollectiveServer::GetInstance(ep, fan_in);

  auto rpc_server = server->GetRPCServer();
  rpc_server->RegisterVar("var1", distributed::kRequestGetMonomerVariable,
                          scope, dev_ctx);

  std::cout << "StartServer return" << std::endl;
  return std::unique_ptr<distributed::CollectiveServer>(server);
}

std::unique_ptr<framework::Scope> GenerateVars(platform::Place place) {
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  auto& ctx = *pool.Get(place);

  framework::Scope* scope = new framework::Scope();
  framework::Variable* var = scope->Var("var1");
  auto* slr = var->GetMutable<framework::SelectedRows>();
  slr->set_height(1000);

  auto* tensor = slr->mutable_value();
  auto* rows = slr->mutable_rows();

  tensor->Resize(framework::make_ddim({3, 5}));
  tensor->mutable_data<float>(place);

  paddle::operators::math::set_constant(ctx, tensor, 32.7);
  for (int i = 0; i < 3; ++i) rows->push_back(i);

  std::cout << "src:" << distributed::GetSelectedRowsInfo(*slr);

  return std::unique_ptr<framework::Scope>(scope);
}

void Gather(const std::vector<distributed::RemoteVar>& vars,
            platform::DeviceContext* dev_ctx) {
  distributed::CollectiveClient* client =
      distributed::CollectiveClient::GetInstance();

  framework::Scope* scope = new framework::Scope();
  framework::Variable* var = scope->Var("var1");
  var->GetMutable<framework::SelectedRows>();

  std::vector<const framework::SelectedRows*> dst;
  client->Gather(vars, &dst, *dev_ctx, scope);
  std::cout << "dst:" << distributed::GetSelectedRowsInfo(*dst[0]);
}

TEST(PREFETCH, GPU) {
  platform::CUDAPlace place;
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  auto& ctx = *pool.Get(place);

  std::string ep = "127.0.0.1:7164";
  auto scope = GenerateVars(place);

  auto* v1 = scope->FindVar("var1");
  std::cout << "var1:" << v1 << std::endl;

  auto server = StartServer(ep, 2, scope.get(), &ctx);
  auto rpc_server = server->GetRPCServer();

  distributed::RemoteVar var;
  var.ep_ = ep;
  var.var_name_ = "var1";
  var.trainer_id_ = 0;

  // fan_in is 2, so gather twice: each Gather sends one barrier RPC, and
  // WaitVarBarrier("var1") below returns only after both arrive.
  std::vector<distributed::RemoteVar> vars{var};
  Gather(vars, &ctx);
  Gather(vars, &ctx);

  std::cout << "begin WaitVarBarrier" << std::endl;
  rpc_server->WaitVarBarrier("var1");
  rpc_server->ClearRegisteredVars();
  server->Stop();

  // release() relinquishes ownership without deleting, so the scope and
  // the singleton server are deliberately leaked at test exit.
  scope.release();
  server.release();
}
(Some files in this diff are not shown because too many files changed.)