optimize-time-of-getting-dynamic-info-in-single-op

pull/8646/head
lvliang 4 years ago
parent 7689062c7d
commit f279cd92ec

@ -565,6 +565,13 @@ void AscendSession::BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph
void AscendSession::BuildDynamicKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const {
MS_LOG(INFO) << "Start!";
MS_EXCEPTION_IF_NULL(kernel_graph);
const auto &kernels = kernel_graph->execution_order();
auto iter = std::find_if(kernels.begin(), kernels.end(), [](const CNodePtr &kernel) {
return AnfAlgo::GetKernelType(kernel) == AICPU_KERNEL && AnfAlgo::GetBooleanAttr(kernel, kAttrOutputIsDynamicShape);
});
if (iter == kernels.end()) {
return;
}
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
MS_EXCEPTION_IF_NULL(runtime_instance);
if (!runtime_instance->GenDynamicKernel(kernel_graph.get())) {

@ -624,6 +624,7 @@ py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) {
if (abs_list.find(args_spec_list) != abs_list.end()) {
MS_LOG(DEBUG) << "Match prim ok " << op_exec_info->op_name;
op_exec_info->abstract = abs_list[args_spec_list].abs;
op_exec_info->is_dynamic_shape = abs_list[args_spec_list].is_dynamic_shape;
prim->set_evaluate_added_attrs(abs_list[args_spec_list].attrs);
is_find = true;
}
@ -634,19 +635,20 @@ py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) {
if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) {
PynativeInfer(prim, op_exec_info->op_inputs, op_exec_info.get(), args_spec_list);
}
// get output dynamic shape info
auto abstract = op_exec_info->abstract;
MS_EXCEPTION_IF_NULL(abstract);
auto shape = abstract->BuildShape();
MS_EXCEPTION_IF_NULL(shape);
auto shape_info = shape->ToString();
if (shape_info.find("-1") != string::npos) {
op_exec_info->is_dynamic_shape = true;
}
}
if (cnode != nullptr) {
cnode->set_abstract(op_exec_info->abstract);
}
// get output dynamic shape info
MS_EXCEPTION_IF_NULL(op_exec_info->abstract);
auto abstract_info = op_exec_info->abstract->ToString();
if (abstract_info.find("-1") != string::npos) {
op_exec_info->is_dynamic_shape = true;
}
op_exec_info->inputs_mask = op_masks;
MS_EXCEPTION_IF_NULL(op_exec_info);
if (op_exec_info->abstract != nullptr) {
@ -668,6 +670,7 @@ py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) {
// const_value need infer every step
auto &out = prim_abs_list_[prim->id()];
out[args_spec_list].abs = op_exec_info->abstract;
out[args_spec_list].is_dynamic_shape = op_exec_info->is_dynamic_shape;
out[args_spec_list].attrs = prim->evaluate_added_attrs();
MS_LOG(DEBUG) << "Set prim " << op_exec_info->op_name << mindspore::ToString(args_spec_list);
}

@ -46,6 +46,7 @@ using GradOperationPtr = std::shared_ptr<prim::GradOperation>;
struct PrimAbsInfo {
abstract::AbstractBasePtr abs;
bool is_dynamic_shape = false;
std::unordered_map<std::string, ValuePtr> attrs;
};

@ -831,8 +831,9 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) {
<< " should be equal to the size of kernels " << kernels.size();
}
for (size_t i = 0; i < kernels.size(); ++i) {
auto &kernel = kernels[i];
if (!dynamic_kernel_list.empty() && dynamic_kernel_list[i] != nullptr &&
dynamic_kernel_list[i]->is_dynamic_shape()) {
dynamic_kernel_list[i]->is_dynamic_shape() && AnfAlgo::GetKernelType(kernel) == AICPU_KERNEL) {
dynamic_kernel_list[i]->InferShape();
dynamic_kernel_list[i]->UpdateArgs();
dynamic_kernel_list[i]->Execute();
@ -842,12 +843,12 @@ bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph) {
}
dynamic_kernel_list[i]->PostExecute();
} else {
auto kernel_mod = AnfAlgo::GetKernelMod(kernels[i]);
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
MS_EXCEPTION_IF_NULL(kernel_mod);
AddressPtrList kernel_inputs;
AddressPtrList kernel_workspaces;
AddressPtrList kernel_outputs;
GenLaunchArgs(*kernel_mod, kernels[i], &kernel_inputs, &kernel_workspaces, &kernel_outputs);
GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
auto ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
if (!ret) {
MS_LOG(ERROR) << "Launch kernel failed.";

Loading…
Cancel
Save