|
|
|
@ -42,6 +42,40 @@ constexpr auto kStartCompileOp = "start_compile_op";
|
|
|
|
|
// NOTE(review): presumably the name of the Python-side function invoked by
// ParallelBuildManager::WaitOne (its failure log mentions "wait_one") — confirm against the call site.
constexpr auto kWaitOne = "wait_one";
|
|
|
|
|
// NOTE(review): presumably the name of a Python-side function that resets task bookkeeping —
// the call site is not visible here; confirm before relying on this.
constexpr auto kResetTaskInfo = "reset_task_info";
|
|
|
|
|
|
|
|
|
|
bool TbeOpParallelPreBuild(const std::vector<AnfNodePtr> &anf_nodes) {
|
|
|
|
|
auto build_manger = std::make_shared<ParallelBuildManager>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(build_manger);
|
|
|
|
|
for (const auto &anf_node : anf_nodes) {
|
|
|
|
|
// gen kernel json
|
|
|
|
|
nlohmann::json kernel_json;
|
|
|
|
|
TbeKernelJsonCreator creator(OP_PRE_COMPILE);
|
|
|
|
|
if (!creator.GenTbeSingleKernelJson(anf_node, &kernel_json)) {
|
|
|
|
|
MS_LOG(ERROR) << "GenTbeSingleKernelJson failed";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
kernel_json["compile_type"] = "pre_build";
|
|
|
|
|
// op build
|
|
|
|
|
auto task_id = build_manger->StartCompileOp(kernel_json);
|
|
|
|
|
build_manger->SavePreTaskInfo(task_id, anf_node);
|
|
|
|
|
}
|
|
|
|
|
while (!build_manger->IsAllPreTaskFinish()) {
|
|
|
|
|
int task_id = -1;
|
|
|
|
|
char *task_result = nullptr;
|
|
|
|
|
char *pre_build_result = nullptr;
|
|
|
|
|
auto ret = build_manger->WaitOne(&task_id, &task_result, &pre_build_result);
|
|
|
|
|
if (!ret) {
|
|
|
|
|
MS_EXCEPTION(ArgumentError) << "Pre Build Failed. wait one ret:" << ret << ", task id:" << task_id;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if ((task_result != nullptr) && (strcmp(task_result, "Success") != 0)) {
|
|
|
|
|
MS_EXCEPTION(ArgumentError) << "task pre compile Failed, task id:" << task_id << ", cause:" << task_result;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
build_manger->PreTaskFinishProcess(task_id, pre_build_result);
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool TbeOpParallelBuild(std::vector<AnfNodePtr> anf_nodes) {
|
|
|
|
|
auto build_manger = std::make_shared<ParallelBuildManager>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(build_manger);
|
|
|
|
@ -82,7 +116,8 @@ bool TbeOpParallelBuild(std::vector<AnfNodePtr> anf_nodes) {
|
|
|
|
|
while (!build_manger->IsAllTaskFinish()) {
|
|
|
|
|
int task_id = -1;
|
|
|
|
|
char *task_result = nullptr;
|
|
|
|
|
auto ret = build_manger->WaitOne(&task_id, &task_result);
|
|
|
|
|
char *pre_build_result = nullptr;
|
|
|
|
|
auto ret = build_manger->WaitOne(&task_id, &task_result, &pre_build_result);
|
|
|
|
|
if (!ret) {
|
|
|
|
|
MS_EXCEPTION(ArgumentError) << "Build Failed. wait one ret:" << ret << ", task id:" << task_id;
|
|
|
|
|
}
|
|
|
|
@ -116,7 +151,7 @@ int32_t ParallelBuildManager::StartCompileOp(const nlohmann::json &kernel_json)
|
|
|
|
|
return task_id;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool ParallelBuildManager::WaitOne(int *task_id, char **task_result) const {
|
|
|
|
|
bool ParallelBuildManager::WaitOne(int *task_id, char **task_result, char **pre_build_result) const {
|
|
|
|
|
MS_LOG(INFO) << "wait task start.";
|
|
|
|
|
MS_EXCEPTION_IF_NULL(task_id);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(task_result);
|
|
|
|
@ -128,10 +163,15 @@ bool ParallelBuildManager::WaitOne(int *task_id, char **task_result) const {
|
|
|
|
|
MS_EXCEPTION(ArgumentError) << "Failed to call function wait_one";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
(void)PyArg_ParseTuple(pRes, "is", task_id, task_result);
|
|
|
|
|
(void)PyArg_ParseTuple(pRes, "iss", task_id, task_result, pre_build_result);
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Records `anf_node` as the node associated with pre-build task `task_id` in pre_task_map_.
// An existing entry for the same id is overwritten.
void ParallelBuildManager::SavePreTaskInfo(int32_t task_id, const mindspore::AnfNodePtr &anf_node) {
  MS_LOG(INFO) << "SavePreTaskInfo, task id: " << task_id;
  // operator[] creates the entry when the id is new, then the node is assigned into it.
  auto &pending_node = pre_task_map_[task_id];
  pending_node = anf_node;
}
|
|
|
|
|
|
|
|
|
|
void ParallelBuildManager::SaveTaskInfo(int32_t task_id, const mindspore::AnfNodePtr &anf_node,
|
|
|
|
|
const std::string &json_name, const std::vector<size_t> &input_size_list,
|
|
|
|
|
const std::vector<size_t> &output_size_list, int32_t scope_id) {
|
|
|
|
@ -150,11 +190,24 @@ void ParallelBuildManager::SaveTaskInfo(int32_t task_id, const mindspore::AnfNod
|
|
|
|
|
task_map_[task_id] = task_info;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Logs how many pre-build tasks are still recorded in pre_task_map_ and
// returns true when none remain.
bool ParallelBuildManager::IsAllPreTaskFinish() const {
  const auto pending_count = pre_task_map_.size();
  MS_LOG(INFO) << "wait pre build process task_num: " << pending_count;
  return pending_count == 0;
}
|
|
|
|
|
|
|
|
|
|
// Logs how many build tasks are still recorded in task_map_ and
// returns true when none remain.
bool ParallelBuildManager::IsAllTaskFinish() const {
  const auto pending_count = task_map_.size();
  MS_LOG(INFO) << "wait process task_num: " << pending_count;
  return pending_count == 0;
}
|
|
|
|
|
|
|
|
|
|
void ParallelBuildManager::PreTaskFinishProcess(int32_t task_id, const std::string &pre_build_result) {
|
|
|
|
|
auto task_iter = pre_task_map_.find(task_id);
|
|
|
|
|
if (task_iter == pre_task_map_.end()) {
|
|
|
|
|
MS_EXCEPTION(ArgumentError) << "can find pre task_id:" << task_id;
|
|
|
|
|
}
|
|
|
|
|
(void)pre_task_map_.erase(task_iter);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::pair<int32_t, KernelModPtr> ParallelBuildManager::TaskFinishProcess(int32_t task_id, bool set_kernel_mod) {
|
|
|
|
|
auto task_iter = task_map_.find(task_id);
|
|
|
|
|
if (task_iter == task_map_.end()) {
|
|
|
|
|