parent
8aa78c2c8e
commit
96d8c411e7
@ -0,0 +1,58 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/core/cluster_config.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace core {
|
||||
|
||||
uint32_t ClusterConfig::worker_num_ = 0;
|
||||
uint32_t ClusterConfig::server_num_ = 0;
|
||||
uint32_t ClusterConfig::heartbeat_interval_ = kHeartbeatInterval;
|
||||
std::unique_ptr<std::string> ClusterConfig::scheduler_host_ = nullptr;
|
||||
uint16_t ClusterConfig::scheduler_port_ = 0;
|
||||
|
||||
void ClusterConfig::Init(const uint32_t &worker_num, const uint32_t &server_num,
|
||||
std::unique_ptr<std::string> scheduler_host, const uint16_t &scheduler_port) {
|
||||
worker_num_ = worker_num;
|
||||
server_num_ = server_num;
|
||||
if (!CommUtil::CheckIp(*scheduler_host.get())) {
|
||||
MS_LOG(EXCEPTION) << "The scheduler_host:" << *scheduler_host.get() << " is illegal!";
|
||||
}
|
||||
scheduler_host_ = std::move(scheduler_host);
|
||||
scheduler_port_ = scheduler_port;
|
||||
}
|
||||
|
||||
uint32_t ClusterConfig::worker_num() { return worker_num_; }
|
||||
|
||||
uint32_t ClusterConfig::server_num() { return server_num_; }
|
||||
|
||||
uint32_t ClusterConfig::heartbeat_interval() { return heartbeat_interval_; }
|
||||
|
||||
void ClusterConfig::set_heartbeat_interval(const uint32_t &heartbeat_interval) {
|
||||
heartbeat_interval_ = heartbeat_interval;
|
||||
}
|
||||
|
||||
std::string ClusterConfig::scheduler_host() { return *scheduler_host_.get(); }
|
||||
|
||||
uint16_t ClusterConfig::scheduler_port() { return scheduler_port_; }
|
||||
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
@ -0,0 +1,56 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_CORE_CLUSTER_CONFIG_H_
|
||||
#define MINDSPORE_CCSRC_PS_CORE_CLUSTER_CONFIG_H_
|
||||
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "utils/log_adapter.h"
|
||||
#include "ps/core/comm_util.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace core {
|
||||
|
||||
constexpr uint32_t kHeartbeatInterval = 3;
|
||||
|
||||
class ClusterConfig {
|
||||
public:
|
||||
static void Init(const uint32_t &worker_num, const uint32_t &server_num, std::unique_ptr<std::string> scheduler_host,
|
||||
const uint16_t &scheduler_port);
|
||||
static uint32_t worker_num();
|
||||
static uint32_t server_num();
|
||||
static uint32_t heartbeat_interval();
|
||||
static void set_heartbeat_interval(const uint32_t &heartbeat_interval);
|
||||
static std::string scheduler_host();
|
||||
static uint16_t scheduler_port();
|
||||
|
||||
private:
|
||||
static uint32_t worker_num_;
|
||||
static uint32_t server_num_;
|
||||
static uint32_t heartbeat_interval_;
|
||||
static std::unique_ptr<std::string> scheduler_host_;
|
||||
static uint16_t scheduler_port_;
|
||||
};
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_PS_CORE_CLUSTER_CONFIG_H_
|
@ -0,0 +1,46 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "common/common_test.h"
|
||||
#include "ps/core/cluster_config.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace core {
|
||||
class TestClusterConfig : public UT::Common {
|
||||
public:
|
||||
TestClusterConfig() = default;
|
||||
virtual ~TestClusterConfig() = default;
|
||||
|
||||
void SetUp() override {}
|
||||
void TearDown() override {}
|
||||
};
|
||||
|
||||
TEST_F(TestClusterConfig, HeartbeatInterval) {
|
||||
ClusterConfig::Init(2, 2, std::make_unique<std::string>("127.0.0.1"), 8080);
|
||||
EXPECT_TRUE(ClusterConfig::heartbeat_interval() == 3);
|
||||
ClusterConfig::set_heartbeat_interval(100);
|
||||
EXPECT_TRUE(ClusterConfig::heartbeat_interval() == 100);
|
||||
EXPECT_STREQ(ClusterConfig::scheduler_host().c_str(), "127.0.0.1");
|
||||
EXPECT_TRUE(ClusterConfig::scheduler_port() == 8080);
|
||||
}
|
||||
|
||||
} // namespace core
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
@ -0,0 +1,44 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "common/common_test.h"
|
||||
#include "ps/core/comm_util.h"
|
||||
|
||||
#include <memory>
|
||||
#include <thread>
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace core {
|
||||
class TestCommUtil : public UT::Common {
|
||||
public:
|
||||
TestCommUtil() = default;
|
||||
virtual ~TestCommUtil() = default;
|
||||
|
||||
void SetUp() override {}
|
||||
void TearDown() override {}
|
||||
};
|
||||
|
||||
TEST_F(TestCommUtil, GetAvailableInterfaceAndIP) {
|
||||
std::string interface;
|
||||
std::string ip;
|
||||
CommUtil::GetAvailableInterfaceAndIP(&interface, &ip);
|
||||
EXPECT_TRUE(!interface.empty());
|
||||
EXPECT_TRUE(!ip.empty());
|
||||
}
|
||||
} // namespace comm
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
@ -0,0 +1,163 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/core/tcp_message_handler.h"
|
||||
#include "common/common_test.h"
|
||||
|
||||
#include <memory>
|
||||
#include <thread>
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
namespace core {
|
||||
class TestTcpMessageHandler : public UT::Common {
|
||||
public:
|
||||
using messageReceive = std::function<void(const CommMessage &message)>;
|
||||
TestTcpMessageHandler() = default;
|
||||
virtual ~TestTcpMessageHandler() = default;
|
||||
|
||||
void SetUp() override {}
|
||||
void TearDown() override {}
|
||||
};
|
||||
|
||||
TEST_F(TestTcpMessageHandler, 4_Header_1003_Data) {
|
||||
TcpMessageHandler handler;
|
||||
handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 1000); });
|
||||
|
||||
std::string data(1000, 'a');
|
||||
CommMessage message;
|
||||
message.set_data(data);
|
||||
uint32_t buf_size = message.ByteSizeLong();
|
||||
char result[1007];
|
||||
int ret = memcpy_s(result, 4, &buf_size, 4);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
|
||||
std::vector<char> serialized(buf_size);
|
||||
message.SerializeToArray(serialized.data(), static_cast<int>(buf_size));
|
||||
memcpy_s(result + 4, buf_size, serialized.data(), buf_size);
|
||||
handler.ReceiveMessage(result, buf_size + 4);
|
||||
}
|
||||
|
||||
TEST_F(TestTcpMessageHandler, 4_Header_1003_Data_4_Header_1003_Data) {
|
||||
TcpMessageHandler handler;
|
||||
handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 1000); });
|
||||
|
||||
std::string data(1000, 'a');
|
||||
CommMessage message;
|
||||
message.set_data(data);
|
||||
uint32_t buf_size = message.ByteSizeLong();
|
||||
char result[2014];
|
||||
int ret = memcpy_s(result, 4, &buf_size, 4);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
std::vector<char> serialized(buf_size);
|
||||
message.SerializeToArray(serialized.data(), static_cast<int>(buf_size));
|
||||
ret = memcpy_s(result + 4, buf_size, serialized.data(), buf_size);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
ret = memcpy_s(result + 4 + buf_size, 4, &buf_size, 4);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
ret = memcpy_s(result + 4 + buf_size + 4, buf_size, serialized.data(), buf_size);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
|
||||
handler.ReceiveMessage(result, 2 * buf_size + 4 * 2);
|
||||
}
|
||||
|
||||
TEST_F(TestTcpMessageHandler, 4_Header_4090_Data_2_Header_2_header_4090_data) {
|
||||
TcpMessageHandler handler;
|
||||
handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 4087); });
|
||||
|
||||
std::string data(4087, 'a');
|
||||
CommMessage message;
|
||||
message.set_data(data);
|
||||
uint32_t buf_size = message.ByteSizeLong();
|
||||
char result[4096];
|
||||
int ret = memcpy_s(result, 4, &buf_size, 4);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
std::vector<char> serialized(buf_size);
|
||||
message.SerializeToArray(serialized.data(), static_cast<int>(buf_size));
|
||||
ret = memcpy_s(result + 4, buf_size, serialized.data(), buf_size);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
|
||||
ret = memcpy_s(result + 4 + buf_size, 2, &buf_size, 2);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
|
||||
handler.ReceiveMessage(result, 4096);
|
||||
|
||||
ret = memcpy_s(result, 2, &buf_size + 2, 2);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
ret = memcpy_s(result + 2, buf_size, serialized.data(), buf_size);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
|
||||
handler.ReceiveMessage(result, 4092);
|
||||
}
|
||||
|
||||
TEST_F(TestTcpMessageHandler, 4_Header_4088_Data_4_Header_4088_data) {
|
||||
TcpMessageHandler handler;
|
||||
handler.SetCallback([this](const CommMessage &message) { EXPECT_EQ(message.data().size(), 4085); });
|
||||
|
||||
std::string data(4085, 'a');
|
||||
CommMessage message;
|
||||
message.set_data(data);
|
||||
uint32_t buf_size = message.ByteSizeLong();
|
||||
char result[4096];
|
||||
int ret = memcpy_s(result, 4, &buf_size, 4);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
std::vector<char> serialized(buf_size);
|
||||
message.SerializeToArray(serialized.data(), static_cast<int>(buf_size));
|
||||
ret = memcpy_s(result + 4, buf_size, serialized.data(), buf_size);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
|
||||
ret = memcpy_s(result + 4 + buf_size, 4, &buf_size, 4);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
|
||||
handler.ReceiveMessage(result, 4096);
|
||||
|
||||
ret = memcpy_s(result, buf_size, serialized.data(), buf_size);
|
||||
if (ret != 0) {
|
||||
MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
|
||||
}
|
||||
|
||||
handler.ReceiveMessage(result, 4088);
|
||||
}
|
||||
|
||||
} // namespace comm
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
Loading…
Reference in new issue