!36 New control sink

Merge pull request !36 from zhoufeng/control-sink
pull/36/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 45ca7863ac

@ -177,37 +177,44 @@ class AicpuTaskInfo : public TaskInfo {
std::vector<void *> output_data_addrs_;
};
class LabelTaskInfo : public TaskInfo {
class LabelSetTaskInfo : public TaskInfo {
public:
LabelSetTaskInfo(uint32_t stream_id, uint32_t label_id)
: TaskInfo(stream_id, TaskInfoType::LABEL_SET), label_id_(label_id) {}
~LabelSetTaskInfo() override {}
uint32_t label_id() const { return label_id_; }
protected:
LabelTaskInfo(uint32_t stream_id, TaskInfoType type, uint32_t label_id)
: TaskInfo(stream_id, type), label_id_(label_id) {}
virtual ~LabelTaskInfo() override {}
private:
uint32_t label_id_;
};
class LabelSetTaskInfo : public LabelTaskInfo {
class LabelGotoTaskInfo : public TaskInfo {
public:
LabelSetTaskInfo(uint32_t stream_id, uint32_t label_id)
: LabelTaskInfo(stream_id, TaskInfoType::LABEL_SET, label_id) {}
~LabelSetTaskInfo() override {}
LabelGotoTaskInfo(uint32_t stream_id, uint32_t label_id)
: TaskInfo(stream_id, TaskInfoType::LABEL_GOTO), label_id_(label_id) {}
~LabelGotoTaskInfo() override {}
uint32_t label_id() const { return label_id_; }
private:
uint32_t label_id_;
};
class LabelSwitchTaskInfo : public LabelTaskInfo {
class LabelSwitchTaskInfo : public TaskInfo {
public:
LabelSwitchTaskInfo(uint32_t stream_id, uint32_t label_id)
: LabelTaskInfo(stream_id, TaskInfoType::LABEL_SWITCH, label_id) {}
LabelSwitchTaskInfo(uint32_t stream_id, uint32_t label_size, const std::vector<uint32_t> &label_list, void *cond)
: TaskInfo(stream_id, TaskInfoType::LABEL_SWITCH),
label_size_(label_size),
label_list_(label_list),
cond_(cond) {}
~LabelSwitchTaskInfo() override {}
};
uint32_t label_size() { return label_size_; };
const std::vector<uint32_t> &label_list() { return label_list_; };
void *cond() { return cond_; };
class LabelGotoTaskInfo : public LabelTaskInfo {
public:
LabelGotoTaskInfo(uint32_t stream_id, uint32_t label_id)
: LabelTaskInfo(stream_id, TaskInfoType::LABEL_GOTO, label_id) {}
~LabelGotoTaskInfo() override {}
private:
uint32_t label_size_;
std::vector<uint32_t> label_list_;
void *cond_;
};
class EventTaskInfo : public TaskInfo {

@ -116,23 +116,34 @@ bool RuntimeModel::InitEvent(uint32_t event_num) {
return true;
}
bool RuntimeModel::InitLabel(uint32_t batch_num) {
GELOGI("batch number:%u.", batch_num);
for (uint32_t i = 0; (batch_num != 0 && i <= batch_num); ++i) {
rtLabel_t rt_lLabel = nullptr;
rtError_t rt_ret = rtLabelCreate(&rt_lLabel);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api rtLabelCreate failed, i; %u; ret: 0x%X", i, rt_ret);
return false;
bool RuntimeModel::InitLabel(std::shared_ptr<DavinciModel> &davinci_model) {
GELOGI("batch number:%u.", davinci_model->GetBatchNum());
label_list_.resize(davinci_model->GetBatchNum());
for (auto &task_info : davinci_model->GetTaskInfoList()) {
if (task_info == nullptr) {
GELOGE(PARAM_INVALID, "task_info is null.");
continue;
}
if (task_info->type() != TaskInfoType::LABEL_SET) {
continue;
}
auto label_set_task_info = std::static_pointer_cast<LabelSetTaskInfo>(task_info);
if (rt_lLabel == nullptr) {
GELOGE(RT_FAILED, "rtLabel is nullptr!");
if (label_set_task_info->stream_id() >= stream_list_.size()) {
GELOGE(PARAM_INVALID, "Invalid stream id.");
return false;
}
label_list_.emplace_back(rt_lLabel);
rtLabel_t rt_label = nullptr;
rtError_t rt_ret = rtLabelCreateEx(&rt_label, stream_list_[label_set_task_info->stream_id()]);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api rtLabelCreate failed, ret: 0x%X", rt_ret);
return false;
}
label_list_[label_set_task_info->label_id()] = rt_label;
}
return true;
}
@ -164,7 +175,7 @@ bool RuntimeModel::InitResource(std::shared_ptr<DavinciModel> &davinci_model) {
return false;
}
if (!InitLabel(davinci_model->GetBatchNum())) {
if (!InitLabel(davinci_model)) {
return false;
}

@ -48,7 +48,7 @@ class RuntimeModel {
bool LoadTask();
bool InitStream(std::shared_ptr<DavinciModel> &davinci_model);
bool InitEvent(uint32_t event_num);
bool InitLabel(uint32_t batch_num);
bool InitLabel(std::shared_ptr<DavinciModel> &davinci_model);
bool InitDataInfo(std::shared_ptr<DavinciModel> &davinci_model);
bool InitOutputInfo(std::shared_ptr<DavinciModel> &davinci_model);
bool InitConstantInfo(std::shared_ptr<DavinciModel> &davinci_model);

@ -0,0 +1,70 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ge_runtime/task/label_goto_task.h"
#include "ge_runtime/task/task_factory.h"
namespace ge {
namespace model_runner {
LabelGotoTask::LabelGotoTask(const ModelContext &model_context, const std::shared_ptr<LabelGotoTaskInfo> &task_info)
: TaskRepeater<LabelGotoTaskInfo>(model_context, task_info),
task_info_(task_info),
stream_(nullptr),
label_(nullptr) {
if (task_info_ == nullptr) {
GELOGW("task_info_ is null!");
return;
}
auto stream_list = model_context.stream_list();
auto label_list = model_context.label_list();
uint32_t stream_id = task_info->stream_id();
uint32_t label_id = task_info->label_id();
GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id);
GELOGI("Label list size:%zu, label id:%u.", label_list.size(), label_id);
if (stream_id >= stream_list.size() || label_id >= label_list.size()) {
GELOGW("Stream/Label id invalid.");
return;
}
stream_ = stream_list[stream_id];
label_ = label_list[label_id];
}
LabelGotoTask::~LabelGotoTask() {}
bool LabelGotoTask::Distribute() {
GELOGI("LabelGotoTask Distribute start.");
if (stream_ == nullptr) {
GELOGE(PARAM_INVALID, "stream is null!");
return false;
}
if (label_ == nullptr) {
GELOGE(PARAM_INVALID, "label is null!");
return false;
}
rtError_t rt_ret = rtLabelGotoEx(label_, stream_);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return false;
}
GELOGI("DistributeTask end.");
return true;
}
REGISTER_TASK(TaskInfoType::LABEL_GOTO, LabelGotoTask, LabelGotoTaskInfo);
} // namespace model_runner
} // namespace ge

@ -0,0 +1,41 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_
#define GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_
#include <memory>
#include "ge_runtime/task/task.h"
namespace ge {
namespace model_runner {
class LabelGotoTask : public TaskRepeater<LabelGotoTaskInfo> {
public:
LabelGotoTask(const ModelContext &model_context, const std::shared_ptr<LabelGotoTaskInfo> &task_info);
~LabelGotoTask() override;
bool Distribute() override;
private:
std::shared_ptr<LabelGotoTaskInfo> task_info_;
void *stream_;
void *label_;
};
} // namespace model_runner
} // namespace ge
#endif // GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_

@ -0,0 +1,70 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ge_runtime/task/label_set_task.h"
#include "ge_runtime/task/task_factory.h"
namespace ge {
namespace model_runner {
LabelSetTask::LabelSetTask(const ModelContext &model_context, const std::shared_ptr<LabelSetTaskInfo> &task_info)
: TaskRepeater<LabelSetTaskInfo>(model_context, task_info),
task_info_(task_info),
stream_(nullptr),
label_(nullptr) {
if (task_info_ == nullptr) {
GELOGW("task_info_ is null!");
return;
}
auto stream_list = model_context.stream_list();
auto label_list = model_context.label_list();
uint32_t stream_id = task_info->stream_id();
uint32_t label_id = task_info->label_id();
GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id);
GELOGI("Label list size:%zu, label id:%u.", label_list.size(), label_id);
if (stream_id >= stream_list.size() || label_id >= label_list.size()) {
GELOGW("Stream/Label id invalid.");
return;
}
stream_ = stream_list[stream_id];
label_ = label_list[label_id];
}
LabelSetTask::~LabelSetTask() {}
bool LabelSetTask::Distribute() {
GELOGI("LabelSetTask Distribute start.");
if (stream_ == nullptr) {
GELOGE(PARAM_INVALID, "stream is null!");
return false;
}
if (label_ == nullptr) {
GELOGE(PARAM_INVALID, "label is null!");
return false;
}
rtError_t rt_ret = rtLabelSet(label_, stream_);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return false;
}
GELOGI("DistributeTask end.");
return true;
}
REGISTER_TASK(TaskInfoType::LABEL_SET, LabelSetTask, LabelSetTaskInfo);
} // namespace model_runner
} // namespace ge

@ -0,0 +1,41 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_
#define GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_
#include <memory>
#include "ge_runtime/task/task.h"
namespace ge {
namespace model_runner {
class LabelSetTask : public TaskRepeater<LabelSetTaskInfo> {
public:
LabelSetTask(const ModelContext &model_context, const std::shared_ptr<LabelSetTaskInfo> &task_info);
~LabelSetTask() override;
bool Distribute() override;
private:
std::shared_ptr<LabelSetTaskInfo> task_info_;
void *stream_;
void *label_;
};
} // namespace model_runner
} // namespace ge
#endif // GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_

@ -0,0 +1,131 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ge_runtime/task/label_switch_task.h"
#include "ge_runtime/task/task_factory.h"
namespace ge {
namespace model_runner {
LabelSwitchTask::LabelSwitchTask(const ModelContext &model_context,
const std::shared_ptr<LabelSwitchTaskInfo> &task_info)
: TaskRepeater<LabelSwitchTaskInfo>(model_context, task_info),
task_info_(task_info),
stream_(nullptr),
all_label_resource_(),
label_info_(nullptr) {
if (task_info_ == nullptr) {
GELOGW("task_info_ is null!");
return;
}
all_label_resource_ = model_context.label_list();
auto stream_list = model_context.stream_list();
uint32_t stream_id = task_info->stream_id();
GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id);
if (stream_id >= stream_list.size()) {
GELOGW("Stream id invalid.");
return;
}
stream_ = stream_list[stream_id];
}
LabelSwitchTask::~LabelSwitchTask() {
if (label_info_ != nullptr) {
rtError_t rt_ret = rtFree(label_info_);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "rtFree fwkOpBuf failed! ret: 0x%X.", rt_ret);
}
label_info_ = nullptr;
}
}
bool LabelSwitchTask::Distribute() {
GELOGI("LabelSwitchTask Distribute start.");
if (!CheckParamValid()) {
return false;
}
const std::vector<uint32_t> &label_index_list = task_info_->label_list();
std::vector<void *> label_list(task_info_->label_size(), nullptr);
for (size_t i = 0; i < task_info_->label_size(); ++i) {
uint32_t label_index = label_index_list[i];
if (label_index >= all_label_resource_.size()) {
GELOGE(PARAM_INVALID, "label %zu index is %u, but there are %zu labels in total.", i, label_index,
all_label_resource_.size());
return false;
}
label_list[i] = all_label_resource_[label_index];
GELOGI("Case %zu: label id %zu.", i, label_index);
}
uint32_t label_info_size = sizeof(rtLabelDevInfo) * task_info_->label_size();
rtError_t rt_ret = rtMalloc(&label_info_, label_info_size, RT_MEMORY_HBM);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return false;
}
rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info_, label_info_size);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return false;
}
rt_ret = rtLabelSwitchByIndex(task_info_->cond(), label_list.size(), label_info_, stream_);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return false;
}
GELOGI("DistributeTask end.");
return true;
}
bool LabelSwitchTask::CheckParamValid() {
if (stream_ == nullptr) {
GELOGE(PARAM_INVALID, "stream is null!");
return false;
}
if (task_info_->label_list().empty()) {
GELOGE(PARAM_INVALID, "label_list is empty.");
return false;
}
if (task_info_->label_size() != task_info_->label_list().size()) {
GELOGE(PARAM_INVALID, "label_list size %zu but label_size is %u.", task_info_->label_list().size(),
task_info_->label_size());
return false;
}
if (task_info_->label_size() >= UINT32_MAX / sizeof(rtLabelDevInfo)) {
GELOGE(PARAM_INVALID, "label_size %u will overflow.", task_info_->label_size());
return false;
}
if (label_info_ != nullptr) {
GELOGE(PARAM_INVALID, "label_info_ has dirty data.");
return false;
}
return true;
}
REGISTER_TASK(TaskInfoType::LABEL_SWITCH, LabelSwitchTask, LabelSwitchTaskInfo);
} // namespace model_runner
} // namespace ge

@ -0,0 +1,44 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_
#define GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_
#include <memory>
#include "ge_runtime/task/task.h"
namespace ge {
namespace model_runner {
class LabelSwitchTask : public TaskRepeater<LabelSwitchTaskInfo> {
public:
LabelSwitchTask(const ModelContext &model_context, const std::shared_ptr<LabelSwitchTaskInfo> &task_info);
~LabelSwitchTask() override;
bool Distribute() override;
private:
bool CheckParamValid();
std::shared_ptr<LabelSwitchTaskInfo> task_info_;
void *stream_;
std::vector<void *> all_label_resource_;
void *label_info_;
};
} // namespace model_runner
} // namespace ge
#endif // GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_
Loading…
Cancel
Save