add heter ps mode (#25682)
* add heter ps mode
* code style test=develop
* add with_pslib test=develop
* unitest test=develop
* code style test=develop
* code style test=develop
* code style test=develop
* code style test=develop
* code style test=develop
* code style test=develop
* code style test=develop
* code style test=develop
* test monitor test=develop
* prepare trainer test=develop
* code style test=develop

parent c8d0d1419b
commit 0cb60c700d
File diff suppressed because it is too large
@@ -0,0 +1,123 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <atomic>
#include <ctime>
#include <map>
#include <memory>
#include <random>
#include <string>
#include <unordered_map>
#include <vector>

#ifdef PADDLE_WITH_PSLIB
#include "paddle/fluid/framework/heter_service.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/macros.h"  // for DISABLE_COPY_AND_ASSIGN

namespace paddle {
namespace framework {

class HeterCpuWorker;

typedef std::function<void(void*)> HeterRpcCallbackFunc;

class OnHeterRpcDone : public google::protobuf::Closure {
 public:
  OnHeterRpcDone(HeterRpcCallbackFunc func) : handler_(func) {}
  virtual ~OnHeterRpcDone() {}
  void Run() {
    std::unique_ptr<OnHeterRpcDone> self_guard(this);
    handler_(this);
  }

  HeterRpcCallbackFunc handler_;
  HeterResponse response;
  brpc::Controller cntl;
};

class HeterWrapper {
 public:
  virtual ~HeterWrapper() {
    server_.Stop(1000);
    server_.Join();
  }

  HeterWrapper() {}

  static void HeterRpcCallBack(HeterResponse* response, brpc::Controller* cntl,
                               HeterCpuWorker* worker,
                               std::shared_ptr<HeterTask> task);

  void CreateClient2XpuConnection();

  void RegisterServiceHandler(int cmd, HeterServiceHandler func);

  void StartXpuService(const std::string& ip, uint32_t port);

  void CallRemoteXpu(std::shared_ptr<HeterTask> task, HeterCpuWorker* worker,
                     int mpi_rank, std::vector<std::string>& send_vars);

  void CallRemoteXpuSync(std::shared_ptr<HeterTask> task,
                         HeterCpuWorker* worker, int mpi_rank,
                         std::vector<std::string>& send_vars);

  void StopXpuService(int num);

  void EndPass(Scope* scope, int num);

  void SerializeToReq(const std::string& varname, Scope* scope,
                      VariableMessage* req_var);

  framework::proto::VarType::Type ToVarType(VariableMessage::Type type);

#ifdef PADDLE_WITH_CUDA
  void DeSerializeToTensor(Scope* scope, const VariableMessage& req_var,
                           platform::Place place,
                           cudaStream_t stream = nullptr);
#else
  void DeSerializeToTensor(Scope* scope, const VariableMessage& req_var,
                           platform::Place place);
#endif
  // HeterWrapper singleton
  static std::shared_ptr<HeterWrapper> GetInstance() {
    if (NULL == s_instance_) {
      s_instance_.reset(new paddle::framework::HeterWrapper());
    }
    return s_instance_;
  }

  std::vector<std::string>& GetXpuList() { return xpu_list_; }

  void SetXpuList(const std::vector<std::string>& xpu_list);

 private:
  static std::shared_ptr<HeterWrapper> s_instance_;

 protected:
  std::vector<std::shared_ptr<brpc::Channel>> xpu_channels_;
  brpc::Server server_;
  HeterXpuService service_;
  static bool is_initialized_;
  DISABLE_COPY_AND_ASSIGN(HeterWrapper);
  std::vector<std::string> xpu_list_;
};

}  // end namespace framework
}  // end namespace paddle
#endif
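For orientation, here is a minimal sketch of how the CPU-trainer side might drive this wrapper, derived only from the declarations above; the endpoint strings and the call sequence are illustrative assumptions, since the implementation file's diff is suppressed.

// Hypothetical usage sketch (assumes PADDLE_WITH_PSLIB and that the
// suppressed heter_wrapper.cc defines the methods declared above).
#include "paddle/fluid/framework/fleet/heter_wrapper.h"

void RunCpuTrainerSide() {
  auto heter = paddle::framework::HeterWrapper::GetInstance();
  // Illustrative endpoints; real ones would come from the fleet/job config.
  heter->SetXpuList({"127.0.0.1:8500", "127.0.0.1:8501"});
  // Opens one brpc channel per endpoint in xpu_list_.
  heter->CreateClient2XpuConnection();
  // During training, HeterTasks are shipped via CallRemoteXpu(...);
  // at shutdown the remote services are stopped.
  heter->StopXpuService(0);  // argument semantics assumed from StopXpuService(int num)
}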
File diff suppressed because it is too large
@@ -0,0 +1,69 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
syntax = "proto2";
package paddle.framework;
option cc_generic_services = true;

// It can be: LoDTensor, SelectedRows or NCCL_ID
enum VarType {
  LOD_TENSOR = 0;
  SELECTED_ROWS = 1;
  NCCL_ID = 2;
}

// VariableMessage is a serialized paddle variable message.
// NOTICE(gongwb): don't modify this proto if you are not
// familiar with how we serialize in sendrecvop_utils.h
// and deserialize it in variable_response.h.
message VariableMessage {
  enum Type {
    // Pod Types
    BOOL = 0;
    INT16 = 1;
    INT32 = 2;
    INT64 = 3;
    FP16 = 4;
    FP32 = 5;
    FP64 = 6;
  }

  message LodData { repeated int64 lod_data = 1; }
  optional string varname = 1;
  // TODO(Yancey1989): reference framework::proto::VarDesc::VarType
  optional VarType type = 2;
  // bool persistable is not needed for sending.
  // tensor info:
  optional Type data_type = 3;
  repeated int64 dims = 4;

  // lod details:
  optional int64 lod_level = 5;
  repeated LodData lod = 6;
  // selected_rows height, aka. original dim0
  optional int64 slr_height = 7;
  // tensor data
  optional bytes data = 8;
}
message HeterRequest {
  required int32 cmd = 1;
  optional int32 cur_batch = 2;
  repeated VariableMessage vars = 3;
};

message HeterResponse {
  // optional VariableMessage vars = 1;
  repeated VariableMessage vars = 1;
};

service HeterService { rpc service(HeterRequest) returns (HeterResponse); };
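As a sketch of what a client of this proto does, the protoc-generated C++ API would be used roughly as follows; the generated header name, command id, and variable contents are assumptions for illustration only.

#include "paddle/fluid/framework/heter_service.pb.h"  // assumed generated header name

paddle::framework::HeterRequest BuildIllustrativeRequest() {
  paddle::framework::HeterRequest req;
  req.set_cmd(1);         // required field; the value here is illustrative
  req.set_cur_batch(32);  // illustrative batch size
  paddle::framework::VariableMessage* var = req.add_vars();
  var->set_varname("embedding_grad");  // made-up variable name
  var->set_type(paddle::framework::LOD_TENSOR);
  var->set_data_type(paddle::framework::VariableMessage::FP32);
  var->add_dims(32);
  var->add_dims(64);
  // var->set_data(...) would carry the raw tensor bytes, and
  // var->add_lod() the LoD levels for a LoDTensor.
  return req;
}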
File diff suppressed because it is too large
@@ -0,0 +1,50 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <fcntl.h>

#ifdef _POSIX_C_SOURCE
#undef _POSIX_C_SOURCE
#endif

#ifdef _XOPEN_SOURCE
#undef _XOPEN_SOURCE
#endif

#include <string>
#include <vector>

#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/text_format.h"
#include "paddle/fluid/framework/fleet/heter_wrapper.h"
#include "paddle/fluid/pybind/heter_wrapper_py.h"

namespace py = pybind11;

namespace paddle {
namespace pybind {
#ifdef PADDLE_WITH_PSLIB
void BindHeterWrapper(py::module* m) {
  py::class_<framework::HeterWrapper, std::shared_ptr<framework::HeterWrapper>>(
      *m, "Heter")
      .def(py::init([]() { return framework::HeterWrapper::GetInstance(); }))
      .def("create_client2xpu_connection",
           &framework::HeterWrapper::CreateClient2XpuConnection)
      .def("set_xpu_list", &framework::HeterWrapper::SetXpuList)
      .def("start_xpu_service", &framework::HeterWrapper::StartXpuService)
      .def("end_pass", &framework::HeterWrapper::EndPass)
      .def("stop_xpu_service", &framework::HeterWrapper::StopXpuService);
}  // end BindHeterWrapper
#endif
}  // end namespace pybind
}  // end namespace paddle
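This binding still has to be registered on Paddle's core module; that call site is in one of the suppressed diffs, so the following registration sketch is an assumption about how it is wired up.

// Sketch of the call site, e.g. inside paddle/fluid/pybind/pybind.cc
// (exact location assumed; that file's diff is not shown here).
#include "paddle/fluid/pybind/heter_wrapper_py.h"

void RegisterHeterBindings(pybind11::module* m) {
#ifdef PADDLE_WITH_PSLIB
  paddle::pybind::BindHeterWrapper(m);  // exposes the Python-side "Heter" class
#endif
}

From Python, the exposed methods mirror the C++ names in snake_case (set_xpu_list, start_xpu_service, end_pass, stop_xpu_service), matching the .def(...) calls above.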
@@ -0,0 +1,29 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "pybind11/pybind11.h"
#include "pybind11/stl.h"

namespace py = pybind11;

namespace paddle {
namespace pybind {

#ifdef PADDLE_WITH_PSLIB
void BindHeterWrapper(py::module* m);
#endif
}  // namespace pybind
}  // namespace paddle
Some files were not shown because too many files have changed in this diff.