parent
452cb0dd4e
commit
a3b9218919
@ -0,0 +1,207 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "cxx_api/model/model_converter_utils/multi_process.h"
|
||||
#include <unistd.h>
|
||||
#include <sys/wait.h>
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include <thread>
|
||||
#include "mindspore/core/utils/log_adapter.h"
|
||||
#include "cxx_api/model/model_converter_utils/shared_memory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace api {
|
||||
|
||||
namespace {
|
||||
uint64_t kSharedMemorySize = 100ull << 20; // 100 MB
|
||||
}
|
||||
|
||||
MultiProcess::MultiProcess() = default;
|
||||
|
||||
MultiProcess::~MultiProcess() = default;
|
||||
|
||||
Status MultiProcess::MainProcess(ProcessFuncCall parent_process, ProcessFuncCall child_process) {
|
||||
MS_EXCEPTION_IF_NULL(parent_process);
|
||||
MS_EXCEPTION_IF_NULL(child_process);
|
||||
Status ret;
|
||||
memory_size_ = kSharedMemorySize; // 100 MB
|
||||
SharedMemory shared_memory;
|
||||
ret = shared_memory.Create(memory_size_);
|
||||
if (!ret.IsSuccess()) {
|
||||
MS_LOG_ERROR << "Create shared memory failed";
|
||||
return ret;
|
||||
}
|
||||
pid_t pid = fork();
|
||||
if (pid < 0) {
|
||||
shared_memory.Destroy();
|
||||
MS_LOG_ERROR << "Fork process to convert model failed";
|
||||
return FAILED;
|
||||
}
|
||||
ret = shared_memory.Attach();
|
||||
if (!ret.IsSuccess()) {
|
||||
MS_LOG_ERROR << "Process attach shared memory failed, pid " << pid;
|
||||
return ret;
|
||||
}
|
||||
shmat_addr_ = shared_memory.GetSharedMemoryAddr();
|
||||
if (shmat_addr_ == nullptr) {
|
||||
MS_LOG_ERROR << "Get shared memory failed";
|
||||
return ret;
|
||||
}
|
||||
shmat_data_addr_ = shmat_addr_ + sizeof(MessageFlag) * 2;
|
||||
shmat_data_max_size_ = memory_size_ - (shmat_data_addr_ - shmat_addr_);
|
||||
|
||||
MS_LOG_INFO << "Shm addr " << (uint64_t)shmat_addr_;
|
||||
if (pid == 0) {
|
||||
ChildProcess(child_process);
|
||||
shared_memory.Detach();
|
||||
MS_LOG_INFO << "Model converter: child process exit";
|
||||
exit(0);
|
||||
} else { // parent process
|
||||
ret = ParentProcess(parent_process);
|
||||
shared_memory.Detach();
|
||||
int status;
|
||||
wait(&status);
|
||||
shared_memory.Destroy();
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
Status MultiProcess::ParentProcess(ProcessFuncCall parent_process) {
|
||||
auto parent_msg = reinterpret_cast<MessageFlag *>(shmat_addr_);
|
||||
auto child_msg = reinterpret_cast<MessageFlag *>(shmat_addr_ + sizeof(MessageFlag));
|
||||
send_msg_ = parent_msg;
|
||||
receive_msg_ = child_msg;
|
||||
std::thread heartbeat_thread(MultiProcess::HeartbeatThreadFunc, this);
|
||||
Status ret;
|
||||
try {
|
||||
ret = parent_process(this);
|
||||
if (!ret.IsSuccess()) {
|
||||
MS_LOG_ERROR << "Parent process process failed";
|
||||
}
|
||||
} catch (const std::runtime_error &ex) {
|
||||
MS_LOG_ERROR << "Catch parent process runtime error: " << ex.what();
|
||||
ret = FAILED;
|
||||
}
|
||||
stopped_ = true;
|
||||
send_msg_->stop = true;
|
||||
heartbeat_thread.join();
|
||||
return ret;
|
||||
}
|
||||
|
||||
void MultiProcess::ChildProcess(ProcessFuncCall child_process) {
|
||||
auto parent_msg = reinterpret_cast<MessageFlag *>(shmat_addr_);
|
||||
auto child_msg = reinterpret_cast<MessageFlag *>(shmat_addr_ + sizeof(MessageFlag));
|
||||
send_msg_ = child_msg;
|
||||
receive_msg_ = parent_msg;
|
||||
std::thread heartbeat_thread(MultiProcess::HeartbeatThreadFunc, this);
|
||||
try {
|
||||
auto ret = child_process(this);
|
||||
if (!ret.IsSuccess()) {
|
||||
MS_LOG_ERROR << "Child process process failed";
|
||||
}
|
||||
} catch (const std::runtime_error &ex) {
|
||||
MS_LOG_ERROR << "Catch child process runtime error: " << ex.what();
|
||||
}
|
||||
stopped_ = true;
|
||||
send_msg_->stop = true;
|
||||
heartbeat_thread.join();
|
||||
}
|
||||
|
||||
Status MultiProcess::SendMsg(const void *buffer, uint64_t msg_len) {
|
||||
MS_LOG_INFO << "Start to send message to peer process, msg len " << msg_len;
|
||||
send_msg_->msg_total_len = msg_len;
|
||||
uint64_t cur_offset = 0;
|
||||
while (msg_len > cur_offset) {
|
||||
uint64_t sub_msg_len = std::min(msg_len - cur_offset, shmat_data_max_size_);
|
||||
|
||||
memcpy_s(shmat_data_addr_, shmat_data_max_size_, static_cast<const uint8_t *>(buffer) + cur_offset, sub_msg_len);
|
||||
cur_offset += sub_msg_len;
|
||||
|
||||
send_msg_->msg_len = sub_msg_len;
|
||||
send_msg_->read_finish_flag = false;
|
||||
send_msg_->read_ready_flag = true;
|
||||
MS_LOG_INFO << "Send start " << cur_offset << ", msg len " << sub_msg_len << ", total len " << msg_len;
|
||||
while (!send_msg_->read_finish_flag && !peer_stopped_) {
|
||||
usleep(1000); // 1ms
|
||||
}
|
||||
if (peer_stopped_) {
|
||||
if (!send_msg_->read_finish_flag) {
|
||||
return FAILED;
|
||||
}
|
||||
break;
|
||||
}
|
||||
MS_LOG_INFO << "Send end " << cur_offset << ", msg len " << sub_msg_len << ", total len " << msg_len;
|
||||
}
|
||||
MS_LOG_INFO << "End to send message to peer process, msg len " << msg_len;
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status MultiProcess::ReceiveMsg(CreateBufferCall create_buffer_call) {
|
||||
uint64_t cur_offset = 0;
|
||||
uint8_t *msg_buffer = nullptr;
|
||||
uint64_t msg_len = 0;
|
||||
do {
|
||||
MS_LOG_INFO << "Receive start from " << cur_offset;
|
||||
while (!receive_msg_->read_ready_flag && !peer_stopped_) {
|
||||
usleep(1000); // 1ms
|
||||
}
|
||||
if (peer_stopped_) {
|
||||
return FAILED;
|
||||
}
|
||||
if (msg_buffer == nullptr) {
|
||||
msg_len = receive_msg_->msg_total_len;
|
||||
msg_buffer = create_buffer_call(msg_len);
|
||||
}
|
||||
memcpy_s(msg_buffer + cur_offset, msg_len - cur_offset, shmat_data_addr_, receive_msg_->msg_len);
|
||||
cur_offset += receive_msg_->msg_len;
|
||||
receive_msg_->read_ready_flag = false;
|
||||
receive_msg_->read_finish_flag = true;
|
||||
MS_LOG_INFO << "Receive end, current length " << cur_offset << ", total length " << msg_len << std::endl;
|
||||
} while (msg_len > cur_offset);
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
void MultiProcess::HeartbeatThreadFunc(MultiProcess *multi_process) { multi_process->HeartbeatThreadFuncInner(); }
|
||||
|
||||
void MultiProcess::HeartbeatThreadFuncInner() {
|
||||
uint64_t last_beat_cnt = 0;
|
||||
uint64_t repeat_cnt = 0;
|
||||
while (!stopped_) {
|
||||
if (receive_msg_->stop) {
|
||||
peer_stopped_ = true;
|
||||
MS_LOG_WARNING << "Peer stopped";
|
||||
break;
|
||||
}
|
||||
uint64_t heartbeat_gap = receive_msg_->heartbeat - last_beat_cnt;
|
||||
if (heartbeat_gap > 0 && heartbeat_gap < 1024) {
|
||||
last_beat_cnt = receive_msg_->heartbeat;
|
||||
repeat_cnt = 0;
|
||||
} else {
|
||||
repeat_cnt++;
|
||||
if (repeat_cnt > 30) { // 30*100ms = 3s no reply
|
||||
peer_stopped_ = true;
|
||||
MS_LOG_WARNING << "Peer stopped";
|
||||
break;
|
||||
}
|
||||
}
|
||||
send_msg_->heartbeat += 1;
|
||||
usleep(100000); // sleep 100 ms
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace api
|
||||
} // namespace mindspore
|
@ -0,0 +1,68 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_CXXAPI_MULTI_PROCESS_H
|
||||
#define MINDSPORE_CCSRC_CXXAPI_MULTI_PROCESS_H
|
||||
#include <iostream>
|
||||
#include <functional>
|
||||
#include "include/api/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace api {
|
||||
|
||||
struct MessageFlag {
|
||||
uint64_t heartbeat = 0;
|
||||
uint64_t stop = false;
|
||||
uint64_t msg_len = 0;
|
||||
uint64_t msg_total_len = 0;
|
||||
uint64_t read_ready_flag = false;
|
||||
uint64_t read_finish_flag = false;
|
||||
};
|
||||
|
||||
class MultiProcess;
|
||||
using ProcessFuncCall = std::function<Status(MultiProcess *multi_process)>;
|
||||
using CreateBufferCall = std::function<uint8_t *(size_t msg_len)>;
|
||||
|
||||
class MultiProcess {
|
||||
public:
|
||||
MultiProcess();
|
||||
~MultiProcess();
|
||||
|
||||
Status MainProcess(ProcessFuncCall parent_process, ProcessFuncCall child_process);
|
||||
Status SendMsg(const void *buffer, uint64_t msg_len);
|
||||
Status ReceiveMsg(CreateBufferCall create_buffer_call);
|
||||
|
||||
private:
|
||||
uint8_t *shmat_addr_ = nullptr;
|
||||
uint8_t *shmat_data_addr_ = nullptr;
|
||||
uint64_t shmat_data_max_size_ = 0;
|
||||
uint64_t memory_size_ = 0;
|
||||
|
||||
bool peer_stopped_ = false;
|
||||
bool stopped_ = false;
|
||||
MessageFlag *send_msg_ = nullptr;
|
||||
MessageFlag *receive_msg_ = nullptr;
|
||||
|
||||
static void HeartbeatThreadFunc(MultiProcess *multi_process);
|
||||
void HeartbeatThreadFuncInner();
|
||||
Status ParentProcess(ProcessFuncCall parent_process);
|
||||
void ChildProcess(ProcessFuncCall child_process);
|
||||
};
|
||||
|
||||
} // namespace api
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_CXXAPI_MULTI_PROCESS_H
|
@ -0,0 +1,69 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "cxx_api/model/model_converter_utils/shared_memory.h"
|
||||
#include <sys/shm.h>
|
||||
#include <sys/stat.h>
|
||||
#include <string>
|
||||
#include "mindspore/core/utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace api {
|
||||
|
||||
Status SharedMemory::Create(uint64_t memory_size) {
|
||||
auto access_mode = S_IRUSR | S_IWUSR | S_IROTH | S_IWOTH | S_IRGRP | S_IWGRP;
|
||||
shm_id_ = shmget(IPC_PRIVATE, memory_size, IPC_CREAT | IPC_EXCL | access_mode);
|
||||
if (shm_id_ == -1) {
|
||||
MS_LOG_ERROR << "Shared memory creation failed. Errno " + std::to_string(errno);
|
||||
return FAILED;
|
||||
}
|
||||
MS_LOG_INFO << "shmget success, shm id " << shm_id_;
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status SharedMemory::Attach() {
|
||||
void *shmat_addr = shmat(shm_id_, nullptr, 0);
|
||||
if (shmat_addr == reinterpret_cast<void *>(-1)) {
|
||||
MS_LOG_ERROR << "Shared memory attach failed. Errno " + std::to_string(errno);
|
||||
return FAILED;
|
||||
}
|
||||
shmat_addr_ = reinterpret_cast<uint8_t *>(shmat_addr);
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
void SharedMemory::Detach() {
|
||||
if (shmat_addr_) {
|
||||
auto err = shmdt(shmat_addr_);
|
||||
if (err == -1) {
|
||||
MS_LOG_ERROR << "Shared memory detach failed. Errno " + std::to_string(errno);
|
||||
return;
|
||||
}
|
||||
}
|
||||
shmat_addr_ = nullptr;
|
||||
}
|
||||
|
||||
void SharedMemory::Destroy() {
|
||||
// Remove the shared memory and never mind about the return code.
|
||||
auto err = shmctl(shm_id_, IPC_RMID, nullptr);
|
||||
if (err == -1) {
|
||||
std::string errMsg = "Unable to remove shared memory with id " + std::to_string(shm_id_);
|
||||
errMsg += ". Errno :" + std::to_string(errno);
|
||||
errMsg += "\nPlesae remove it manually using ipcrm -m command";
|
||||
MS_LOG_ERROR << errMsg;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace api
|
||||
} // namespace mindspore
|
@ -0,0 +1,41 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_CXXAPI_SHARED_MEMORY_H
|
||||
#define MINDSPORE_CCSRC_CXXAPI_SHARED_MEMORY_H
|
||||
#include <iostream>
|
||||
#include "include/api/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace api {
|
||||
|
||||
class SharedMemory {
|
||||
public:
|
||||
Status Create(uint64_t memory_size);
|
||||
Status Attach();
|
||||
void Detach();
|
||||
void Destroy();
|
||||
uint8_t *GetSharedMemoryAddr() { return shmat_addr_; }
|
||||
|
||||
private:
|
||||
int shm_id_ = -1;
|
||||
uint8_t *shmat_addr_ = nullptr;
|
||||
};
|
||||
|
||||
} // namespace api
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_CXXAPI_SHARED_MEMORY_H
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,85 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_SESSION_SESSION_H
|
||||
#define MINDSPORE_CCSRC_SESSION_SESSION_H
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
|
||||
#include "backend/session/session_basic.h"
|
||||
#include "ir/anf.h"
|
||||
#include "include/api/status.h"
|
||||
#include "cxx_api/model/model_impl.h"
|
||||
|
||||
#ifdef ENABLE_D
|
||||
#include "runtime/context.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace api {
|
||||
class MsModel : public ModelImpl {
|
||||
public:
|
||||
explicit MsModel(uint32_t device_id);
|
||||
~MsModel();
|
||||
|
||||
Status LoadModel(const Buffer &model_data, ModelType type,
|
||||
const std::map<std::string, std::string> &options) override;
|
||||
Status LoadModel(const std::string &file_name, ModelType type,
|
||||
const std::map<std::string, std::string> &options) override;
|
||||
Status UnloadModel() override;
|
||||
|
||||
Status Train(const DataSet &dataset, std::map<std::string, Buffer> *outputs) override;
|
||||
Status Eval(const DataSet &dataset, std::map<std::string, Buffer> *outputs) override;
|
||||
Status Predict(const std::map<std::string, Buffer> &inputs, std::map<std::string, Buffer> *outputs) override;
|
||||
|
||||
Status GetInputsInfo(std::vector<Tensor> *tensor_list) const override;
|
||||
Status GetOutputsInfo(std::vector<Tensor> *tensor_list) const override;
|
||||
|
||||
Status InitEnv(const std::unordered_map<std::string, std::string> &other_options);
|
||||
Status FinalizeEnv();
|
||||
|
||||
private:
|
||||
std::shared_ptr<session::SessionBasic> session_impl_ = nullptr;
|
||||
uint32_t graph_id_;
|
||||
std::string device_type_;
|
||||
int32_t device_id_ = 0;
|
||||
#ifdef ENABLE_D
|
||||
rtContext_t context_ = nullptr;
|
||||
#endif
|
||||
std::vector<tensor::TensorPtr> inputs_;
|
||||
std::vector<tensor::TensorPtr> outputs_;
|
||||
std::vector<std::string> input_names_;
|
||||
std::vector<std::string> output_names_;
|
||||
bool load_flag_ = false;
|
||||
|
||||
std::shared_ptr<FuncGraph> LoadModel(const char *model_buf, size_t size, const std::string &device);
|
||||
Buffer ReadFile(const std::string &file);
|
||||
static void RegAllOp();
|
||||
Status CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr);
|
||||
Status CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs) const;
|
||||
std::vector<tensor::TensorPtr> RunGraph(const std::vector<tensor::TensorPtr> &inputs);
|
||||
Status ExecuteModel(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs);
|
||||
};
|
||||
|
||||
API_REG_MODEL(AscendMS, MsModel);
|
||||
|
||||
} // namespace api
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_SESSION_SESSION_BASIC_H
|
Loading…
Reference in new issue