add heter ps mode (#25682)
* add heter ps mode * code style test=develop * add with_pslib test=develop * unitest test=develop * code style test=develop * code style test=develop * code style test=develop * code style test=develop * code style test=develop * code style test=develop * code style test=develop * code style test=develop * test monitor test=develop * prepare trainer test=develop * code style test=develop (branch: revert-24895-update_cub)
parent
c8d0d1419b
commit
0cb60c700d
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,123 @@
|
|||||||
|
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <atomic>
|
||||||
|
#include <ctime>
|
||||||
|
#include <map>
|
||||||
|
#include <memory>
|
||||||
|
#include <random>
|
||||||
|
#include <string>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#ifdef PADDLE_WITH_PSLIB
|
||||||
|
#include "paddle/fluid/framework/heter_service.h"
|
||||||
|
#include "paddle/fluid/framework/scope.h"
|
||||||
|
#include "paddle/fluid/framework/tensor.h"
|
||||||
|
#include "paddle/fluid/framework/variable_helper.h"
|
||||||
|
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace framework {
|
||||||
|
|
||||||
|
class HeterCpuWorker;
|
||||||
|
|
||||||
|
typedef std::function<void(void*)> HeterRpcCallbackFunc;
|
||||||
|
|
||||||
|
class OnHeterRpcDone : public google::protobuf::Closure {
|
||||||
|
public:
|
||||||
|
OnHeterRpcDone(HeterRpcCallbackFunc func) : handler_(func) {}
|
||||||
|
virtual ~OnHeterRpcDone() {}
|
||||||
|
void Run() {
|
||||||
|
std::unique_ptr<OnHeterRpcDone> self_guard(this);
|
||||||
|
handler_(this);
|
||||||
|
}
|
||||||
|
|
||||||
|
HeterRpcCallbackFunc handler_;
|
||||||
|
HeterResponse response;
|
||||||
|
brpc::Controller cntl;
|
||||||
|
};
|
||||||
|
|
||||||
|
class HeterWrapper {
|
||||||
|
public:
|
||||||
|
virtual ~HeterWrapper() {
|
||||||
|
server_.Stop(1000);
|
||||||
|
server_.Join();
|
||||||
|
}
|
||||||
|
|
||||||
|
HeterWrapper() {}
|
||||||
|
|
||||||
|
static void HeterRpcCallBack(HeterResponse* response, brpc::Controller* cntl,
|
||||||
|
HeterCpuWorker* worker,
|
||||||
|
std::shared_ptr<HeterTask> task);
|
||||||
|
|
||||||
|
void CreateClient2XpuConnection();
|
||||||
|
|
||||||
|
void RegisterServiceHandler(int cmd, HeterServiceHandler func);
|
||||||
|
|
||||||
|
void StartXpuService(const std::string& ip, uint32_t port);
|
||||||
|
|
||||||
|
void CallRemoteXpu(std::shared_ptr<HeterTask> task, HeterCpuWorker* worker,
|
||||||
|
int mpi_rank, std::vector<std::string>& send_vars);
|
||||||
|
|
||||||
|
void CallRemoteXpuSync(std::shared_ptr<HeterTask> task,
|
||||||
|
HeterCpuWorker* worker, int mpi_rank,
|
||||||
|
std::vector<std::string>& send_vars);
|
||||||
|
|
||||||
|
void StopXpuService(int num);
|
||||||
|
|
||||||
|
void EndPass(Scope* scope, int num);
|
||||||
|
|
||||||
|
void SerializeToReq(const std::string& varname, Scope* scope,
|
||||||
|
VariableMessage* req_var);
|
||||||
|
|
||||||
|
framework::proto::VarType::Type ToVarType(VariableMessage::Type type);
|
||||||
|
|
||||||
|
#ifdef PADDLE_WITH_CUDA
|
||||||
|
void DeSerializeToTensor(Scope* scope, const VariableMessage& req_var,
|
||||||
|
platform::Place place,
|
||||||
|
cudaStream_t stream = nullptr);
|
||||||
|
#else
|
||||||
|
void DeSerializeToTensor(Scope* scope, const VariableMessage& req_var,
|
||||||
|
platform::Place place);
|
||||||
|
#endif
|
||||||
|
// HeterWrapper singleton
|
||||||
|
static std::shared_ptr<HeterWrapper> GetInstance() {
|
||||||
|
if (NULL == s_instance_) {
|
||||||
|
s_instance_.reset(new paddle::framework::HeterWrapper());
|
||||||
|
}
|
||||||
|
return s_instance_;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::string>& GetXpuList() { return xpu_list_; }
|
||||||
|
|
||||||
|
void SetXpuList(const std::vector<std::string>& xpu_list);
|
||||||
|
|
||||||
|
private:
|
||||||
|
static std::shared_ptr<HeterWrapper> s_instance_;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
std::vector<std::shared_ptr<brpc::Channel>> xpu_channels_;
|
||||||
|
brpc::Server server_;
|
||||||
|
HeterXpuService service_;
|
||||||
|
static bool is_initialized_;
|
||||||
|
DISABLE_COPY_AND_ASSIGN(HeterWrapper);
|
||||||
|
std::vector<std::string> xpu_list_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // end namespace framework
|
||||||
|
} // end namespace paddle
|
||||||
|
#endif
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,69 @@
|
|||||||
|
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
syntax = "proto2";
|
||||||
|
package paddle.framework;
|
||||||
|
option cc_generic_services = true;
|
||||||
|
|
||||||
|
// It can be: LoDTensor, SelectedRows or NCCL_ID
enum VarType {
  LOD_TENSOR = 0;
  SELECTED_ROWS = 1;
  NCCL_ID = 2;
}
|
||||||
|
|
||||||
|
// VariableMessage is serialized paddle variable message.
// NOTICE(gongwb): don't modify this proto if you are not
// familiar with how we serialize in sendrecvop_utils.h
// and deserialize it in variable_response.h.
message VariableMessage {
  enum Type {
    // Pod Types
    BOOL = 0;
    INT16 = 1;
    INT32 = 2;
    INT64 = 3;
    FP16 = 4;
    FP32 = 5;
    FP64 = 6;
  }

  message LodData { repeated int64 lod_data = 1; }
  optional string varname = 1;
  // TODO(Yancey1989): reference framework::proto::VarDesc::VarType
  optional VarType type = 2;
  // bool persistable is not needed for sending.
  // tensor info:
  optional Type data_type = 3;
  repeated int64 dims = 4;

  // lod details:
  optional int64 lod_level = 5;
  repeated LodData lod = 6;
  // selected_rows height, aka. original dim0
  optional int64 slr_height = 7;
  // tensor data
  optional bytes data = 8;
}
|
||||||
|
// Request sent to the heter service.
message HeterRequest {
  // Command id — presumably dispatched to the handler registered for this
  // cmd on the server side; confirm against the service implementation.
  required int32 cmd = 1;
  optional int32 cur_batch = 2;
  // Serialized variables shipped with the request.
  repeated VariableMessage vars = 3;
};
|
||||||
|
|
||||||
|
// Response returned by the heter service, carrying serialized variables.
message HeterResponse {
  // optional VariableMessage vars = 1;
  repeated VariableMessage vars = 1;
};
|
||||||
|
|
||||||
|
// Single generic RPC; individual operations are distinguished by
// HeterRequest.cmd rather than by separate rpc methods.
service HeterService { rpc service(HeterRequest) returns (HeterResponse); };
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,50 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
#include <fcntl.h>
|
||||||
|
|
||||||
|
#ifdef _POSIX_C_SOURCE
|
||||||
|
#undef _POSIX_C_SOURCE
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifdef _XOPEN_SOURCE
|
||||||
|
#undef _XOPEN_SOURCE
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "google/protobuf/io/zero_copy_stream_impl.h"
|
||||||
|
#include "google/protobuf/text_format.h"
|
||||||
|
#include "paddle/fluid/framework/fleet/heter_wrapper.h"
|
||||||
|
#include "paddle/fluid/pybind/heter_wrapper_py.h"
|
||||||
|
|
||||||
|
namespace py = pybind11;
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace pybind {
|
||||||
|
#ifdef PADDLE_WITH_PSLIB
|
||||||
|
// Expose framework::HeterWrapper to Python as the "Heter" class.
void BindHeterWrapper(py::module* m) {
  using HeterWrapper = framework::HeterWrapper;
  py::class_<HeterWrapper, std::shared_ptr<HeterWrapper>> heter(*m, "Heter");
  // Constructing Heter() from Python hands back the process-wide singleton.
  heter.def(py::init([]() { return HeterWrapper::GetInstance(); }));
  heter.def("create_client2xpu_connection",
            &HeterWrapper::CreateClient2XpuConnection);
  heter.def("set_xpu_list", &HeterWrapper::SetXpuList);
  heter.def("start_xpu_service", &HeterWrapper::StartXpuService);
  heter.def("end_pass", &HeterWrapper::EndPass);
  heter.def("stop_xpu_service", &HeterWrapper::StopXpuService);
}  // end BindHeterWrapper
|
||||||
|
#endif
|
||||||
|
} // end namespace pybind
|
||||||
|
} // end namespace paddle
|
@ -0,0 +1,29 @@
|
|||||||
|
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "pybind11/pybind11.h"
|
||||||
|
#include "pybind11/stl.h"
|
||||||
|
|
||||||
|
namespace py = pybind11;
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace pybind {
|
||||||
|
|
||||||
|
#ifdef PADDLE_WITH_PSLIB
|
||||||
|
void BindHeterWrapper(py::module* m);
|
||||||
|
#endif
|
||||||
|
} // namespace pybind
|
||||||
|
} // namespace paddle
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue