Fixing workspace mismatch

pull/1170/head
chuxing 4 years ago
parent 90e9c8c1e5
commit b783cf2f79

@ -66,7 +66,7 @@ Status AiCoreNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &nod
} }
AiCoreTaskBuilder builder(node->GetOpDesc(), *task_defs); AiCoreTaskBuilder builder(node->GetOpDesc(), *task_defs);
std::unique_ptr<NodeTask> node_task; std::unique_ptr<AiCoreNodeTask> node_task;
GE_CHK_STATUS_RET(builder.BuildTask(node_task, true, is_single_op), GE_CHK_STATUS_RET(builder.BuildTask(node_task, true, is_single_op),
"[%s] Failed to build op tasks.", node->GetName().c_str()); "[%s] Failed to build op tasks.", node->GetName().c_str());
task = std::move(node_task); task = std::move(node_task);
@ -99,7 +99,7 @@ Status AiCoreNodeExecutor::GenNodeKey(const NodePtr &node, std::string &node_key
return SUCCESS; return SUCCESS;
} }
bool AiCoreNodeTaskRegistry::AddTask(const std::string &node_key, const std::shared_ptr<NodeTask> task) { bool AiCoreNodeTaskRegistry::AddTask(const std::string &node_key, const std::shared_ptr<AiCoreNodeTask> &task) {
GE_CHECK_NOTNULL(task); GE_CHECK_NOTNULL(task);
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
auto iter = reg_node_tasks_.find(node_key); auto iter = reg_node_tasks_.find(node_key);
@ -111,7 +111,7 @@ bool AiCoreNodeTaskRegistry::AddTask(const std::string &node_key, const std::sha
return ret.second; return ret.second;
} }
std::shared_ptr<NodeTask> AiCoreNodeTaskRegistry::GetTask(const std::string &node_key) { std::shared_ptr<AiCoreNodeTask> AiCoreNodeTaskRegistry::GetTask(const std::string &node_key) {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
auto iter = reg_node_tasks_.find(node_key); auto iter = reg_node_tasks_.find(node_key);
return (iter != reg_node_tasks_.end()) ? iter->second : nullptr; return (iter != reg_node_tasks_.end()) ? iter->second : nullptr;
@ -140,9 +140,12 @@ Status AiCoreNodeExecutor::CompileTask(const HybridModel &model,
auto node_key = std::to_string(model.GetModelId()) + "/" + shape_key; auto node_key = std::to_string(model.GetModelId()) + "/" + shape_key;
GELOGD("NodeKey for %s = %s", node->GetName().c_str(), node_key.c_str()); GELOGD("NodeKey for %s = %s", node->GetName().c_str(), node_key.c_str());
task = registry.GetTask(node_key); auto aicore_task = registry.GetTask(node_key);
if (task != nullptr) { if (task != nullptr) {
// The workspaces needed by a operator may differ with different shapes
op_desc->SetWorkspaceBytes(aicore_task->GetWorkspaceSizes());
GELOGI("AiCoreNodeExecutor(%s) CompileTask Skip.", node->GetName().c_str()); GELOGI("AiCoreNodeExecutor(%s) CompileTask Skip.", node->GetName().c_str());
task = std::move(aicore_task);
return SUCCESS; return SUCCESS;
} }
@ -153,16 +156,18 @@ Status AiCoreNodeExecutor::CompileTask(const HybridModel &model,
GELOGD("successfully generated task_defs: %s", node->GetName().c_str()); GELOGD("successfully generated task_defs: %s", node->GetName().c_str());
AiCoreTaskBuilder builder(node->GetOpDesc(), task_defs); AiCoreTaskBuilder builder(node->GetOpDesc(), task_defs);
std::unique_ptr<NodeTask> node_task; std::unique_ptr<AiCoreNodeTask> node_task;
GE_CHK_STATUS_RET(builder.BuildTask(node_task, false), "[%s] Failed to build op tasks.", node->GetName().c_str()); GE_CHK_STATUS_RET(builder.BuildTask(node_task, false), "[%s] Failed to build op tasks.", node->GetName().c_str());
task = std::move(node_task); node_task->SetWorkspaceSizes(op_desc->GetWorkspaceBytes());
aicore_task = std::move(node_task);
GELOGD("successfully created node task: %s", node->GetName().c_str()); GELOGD("successfully created node task: %s", node->GetName().c_str());
if (!registry.AddTask(node_key, task)) { if (!registry.AddTask(node_key, aicore_task)) {
GELOGE(INTERNAL_ERROR, "Add NodeTask failed, op name = %s.", node->GetName().c_str()); GELOGE(INTERNAL_ERROR, "Add NodeTask failed, op name = %s.", node->GetName().c_str());
return INTERNAL_ERROR; return INTERNAL_ERROR;
} }
task = std::move(aicore_task);
GELOGI("AiCoreNodeExecutor(%s) CompileTask End.", node->GetName().c_str()); GELOGI("AiCoreNodeExecutor(%s) CompileTask End.", node->GetName().c_str());
return SUCCESS; return SUCCESS;
} }
@ -247,6 +252,14 @@ bool AiCoreNodeTask::IsSupportDynamicShape() {
return true; return true;
} }
const vector<int64_t> &AiCoreNodeTask::GetWorkspaceSizes() const {
return workspace_sizes_;
}
void AiCoreNodeTask::SetWorkspaceSizes(const vector<int64_t> &workspace_sizes) {
workspace_sizes_ = workspace_sizes;
}
TaskCompilerFactory &TaskCompilerFactory::GetInstance() { TaskCompilerFactory &TaskCompilerFactory::GetInstance() {
static TaskCompilerFactory instance; static TaskCompilerFactory instance;
return instance; return instance;

@ -24,7 +24,6 @@
namespace ge { namespace ge {
namespace hybrid { namespace hybrid {
class TaskCompiler { class TaskCompiler {
public: public:
TaskCompiler() = default; TaskCompiler() = default;
@ -42,11 +41,11 @@ class AiCoreNodeTaskRegistry {
return instance; return instance;
} }
std::shared_ptr<NodeTask> GetTask(const std::string &node_key); std::shared_ptr<AiCoreNodeTask> GetTask(const std::string &node_key);
bool AddTask(const std::string &node_key, const std::shared_ptr<NodeTask> task); bool AddTask(const std::string &node_key, const std::shared_ptr<AiCoreNodeTask> &task);
private: private:
AiCoreNodeTaskRegistry() = default; AiCoreNodeTaskRegistry() = default;
std::map<std::string, std::shared_ptr<NodeTask>> reg_node_tasks_; std::map<std::string, std::shared_ptr<AiCoreNodeTask>> reg_node_tasks_;
std::mutex mutex_; std::mutex mutex_;
}; };
@ -59,8 +58,12 @@ class AiCoreNodeTask : public NodeTask {
Status UpdateArgs(TaskContext &context) override; Status UpdateArgs(TaskContext &context) override;
Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override; Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override;
const vector<int64_t> &GetWorkspaceSizes() const;
void SetWorkspaceSizes(const vector<int64_t> &workspace_sizes);
private: private:
std::vector<std::unique_ptr<AiCoreOpTask>> tasks_; std::vector<std::unique_ptr<AiCoreOpTask>> tasks_;
std::vector<int64_t> workspace_sizes_;
}; };
class AiCoreNodeExecutor : public NodeExecutor { class AiCoreNodeExecutor : public NodeExecutor {

@ -37,7 +37,7 @@ AiCoreTaskBuilder::AiCoreTaskBuilder(const OpDescPtr &op_desc, const std::vector
: op_desc_(op_desc), task_defs_(task_defs) { : op_desc_(op_desc), task_defs_(task_defs) {
} }
Status AiCoreTaskBuilder::BuildTask(std::unique_ptr<NodeTask> &node_task, Status AiCoreTaskBuilder::BuildTask(std::unique_ptr<AiCoreNodeTask> &node_task,
bool ignore_failure_on_atomic, bool ignore_failure_on_atomic,
bool is_single_op) { bool is_single_op) {
GE_CHECK_NOTNULL(op_desc_); GE_CHECK_NOTNULL(op_desc_);

@ -27,6 +27,7 @@
namespace ge { namespace ge {
namespace hybrid { namespace hybrid {
class AiCoreNodeTask;
class AiCoreKernelRegistry { class AiCoreKernelRegistry {
public: public:
~AiCoreKernelRegistry() = default; ~AiCoreKernelRegistry() = default;
@ -47,7 +48,9 @@ class AiCoreTaskBuilder {
AiCoreTaskBuilder(const OpDescPtr &op_desc, const std::vector<domi::TaskDef> &task_defs); AiCoreTaskBuilder(const OpDescPtr &op_desc, const std::vector<domi::TaskDef> &task_defs);
~AiCoreTaskBuilder() = default; ~AiCoreTaskBuilder() = default;
Status BuildTask(std::unique_ptr<NodeTask> &node_task, bool ignore_failure_on_atomic, bool is_single_op = false); Status BuildTask(std::unique_ptr<AiCoreNodeTask> &node_task,
bool ignore_failure_on_atomic,
bool is_single_op = false);
private: private:
bool ExpectAtomicAddrCleanTask(); bool ExpectAtomicAddrCleanTask();

@ -61,11 +61,11 @@ Status AiCoreTaskCompiler::CompileOp(const NodePtr &node, std::vector<domi::Task
GE_CHECK_NOTNULL(node); GE_CHECK_NOTNULL(node);
GELOGI("AiCoreTaskCompiler(%s) CompileOp Start.", node->GetName().c_str()); GELOGI("AiCoreTaskCompiler(%s) CompileOp Start.", node->GetName().c_str());
auto op_desc = node->GetOpDesc();
op_desc->SetWorkspaceBytes({});
GE_CHK_STATUS_RET_NOLOG(DoCompileOp(node)); GE_CHK_STATUS_RET_NOLOG(DoCompileOp(node));
GELOGD("successfully compiled op: %s", node->GetName().c_str()); GELOGD("successfully compiled op: %s", node->GetName().c_str());
auto op_desc = node->GetOpDesc();
std::vector<int64_t> input_offsets(op_desc->GetInputsSize(), kMemBase); std::vector<int64_t> input_offsets(op_desc->GetInputsSize(), kMemBase);
std::vector<int64_t> output_offsets(op_desc->GetOutputsSize(), kMemBase); std::vector<int64_t> output_offsets(op_desc->GetOutputsSize(), kMemBase);
op_desc->SetInputOffset(input_offsets); op_desc->SetInputOffset(input_offsets);

Loading…
Cancel
Save