From d3b978147fca821092cdc6bbdcfa30b29dbedc21 Mon Sep 17 00:00:00 2001
From: caifubi
Date: Fri, 9 Oct 2020 09:45:49 +0800
Subject: [PATCH] Ascend Dynamic Shape

---
 .../parallel_compile/tbe_compiler/compiler.py | 27 ++-
 .../parallel_compile/tbe_compiler/helper.py | 4 +
 mindspore/ccsrc/CMakeLists.txt | 7 +-
 .../backend/kernel_compiler/CMakeLists.txt | 1 +
 .../aicpu/aicpu_kernel_build.cc | 93 +++++---
 .../kernel_compiler/aicpu/aicpu_kernel_mod.cc | 17 +-
 .../kernel_compiler/aicpu/aicpu_kernel_mod.h | 1 +
 .../kernel_compiler/aicpu/aicpu_util.cc | 34 ++-
 .../kernel_compiler/aicpu/aicpu_util.h | 16 +-
 .../kernel_compiler/hccl/hccl_kernel.cc | 44 ++++
 .../kernel_compiler/hccl/hccl_kernel.h | 1 +
 .../host/dynamic_shape_kernel.cc | 52 +++++
 .../host/dynamic_shape_kernel.h | 43 ++++
 .../kernel_compiler/host/host_kernel_build.cc | 42 ++++
 .../kernel_compiler/host/host_kernel_build.h | 27 +++
 .../host/host_kernel_metadata.cc | 59 +++++
 .../host/host_kernel_metadata.h | 30 +++
 .../kernel_compiler/host/host_kernel_mod.cc | 98 ++++++++
 .../kernel_compiler/host/host_kernel_mod.h | 86 +++++++
 .../kernel_compiler/kash/kernel_pack.cc | 3 +
 .../ccsrc/backend/kernel_compiler/kernel.h | 15 +-
 .../backend/kernel_compiler/kernel_fusion.cc | 6 +-
 .../backend/kernel_compiler/kernel_query.cc | 4 +
 .../backend/kernel_compiler/oplib/opinfo.h | 7 +-
 .../backend/kernel_compiler/oplib/oplib.cc | 13 +-
 .../backend/kernel_compiler/oplib/oplib.h | 3 +-
 .../kernel_compiler/rts/memcpy_async.cc | 31 +++
 .../kernel_compiler/rts/memcpy_async.h | 1 +
 .../rts/profiling_kernel_mod.cc | 6 +
 .../rts/profiling_kernel_mod.h | 1 +
 .../kernel_compiler/tbe/tbe_adapter.cc | 151 ------------
 .../backend/kernel_compiler/tbe/tbe_adapter.h | 1 -
 .../tbe/tbe_dynaminc_shape_util.cc | 139 +++++++++++
 .../tbe/tbe_dynaminc_shape_util.h | 49 ++++
 .../kernel_compiler/tbe/tbe_kernel_build.cc | 105 ++++++---
 .../kernel_compiler/tbe/tbe_kernel_build.h | 2 +-
 .../kernel_compiler/tbe/tbe_kernel_mod.cc | 46 ++++
 .../kernel_compiler/tbe/tbe_kernel_mod.h | 1 +
 .../tbe/tbe_kernel_parallel_build.cc | 78 +++++--
 .../tbe/tbe_kernel_parallel_build.h | 13 +-
 .../tbe_kernel_select/tbe_kernel_select.cc | 5 +-
 .../backend/optimizer/ascend/ascend_helper.h | 7 +-
 .../enhancer/insert_pad_for_nms_with_mask.cc | 3 +
 .../format_type/deal_ref_trans_and_cast.cc | 4 +-
 .../ascend/ir_fission/concat_fission.cc | 3 +
 .../ir_fission/layer_norm_grad_split.cc | 3 +
 .../ascend/ir_fission/pack_fission.cc | 3 +
 .../ascend/ir_fission/reduce_min_fission.cc | 3 +
 .../ascend/ir_fission/split_fission.cc | 3 +
 .../optimizer/ascend/ir_fission/topk_split.cc | 3 +
 .../ascend/ir_fusion/add_input_to_output.cc | 2 +-
 .../ir_fusion/confusion_mul_grad_fusion.cc | 3 +
 .../ascend/ir_fusion/derelu_fusion.cc | 3 +
 .../ir_fusion/momentum_lossscale_fusion.cc | 3 +
 .../ascend/ir_fusion/mul_addn_fusion.cc | 3 +
 .../ascend/ir_fusion/remove_reshape_pair.cc | 4 +
 .../ir_fusion/reshape_transpose_fusion.cc | 3 +
 .../ir_fusion/transpose_reshape_fusion.cc | 3 +
 .../ccsrc/backend/optimizer/common/helper.cc | 3 +
 .../pass/convert_const_input_to_attr.cc | 4 +
 .../backend/session/anf_runtime_algorithm.cc | 122 ++++++++--
 .../backend/session/anf_runtime_algorithm.h | 6 +
 .../ccsrc/backend/session/ascend_session.cc | 4 +
 .../ccsrc/backend/session/kernel_graph.cc | 11 +
 .../ccsrc/backend/session/kernel_graph.h | 14 +-
 .../ccsrc/backend/session/session_basic.cc | 92 ++++++++
 .../ccsrc/backend/session/session_basic.h | 1 +
 .../jit/static_analysis/static_analysis.cc | 11 +
 .../jit/static_analysis/static_analysis.h | 2 +
mindspore/ccsrc/runtime/device/CMakeLists.txt | 2 +- .../device/ascend/ascend_device_address.cc | 8 +- .../device/ascend/ascend_kernel_runtime.cc | 123 ++++++++++ .../device/ascend/ascend_kernel_runtime.h | 2 + .../device/ascend/ascend_stream_assign.cc | 2 +- .../runtime/device/ascend/dump/data_dumper.cc | 12 +- .../ascend/executor/ai_core_dynamic_kernel.cc | 182 +++++++++++++++ .../ascend/executor/ai_core_dynamic_kernel.h | 70 ++++++ .../ascend/executor/ai_cpu_dynamic_kernel.cc | 204 ++++++++++++++++ .../ascend/executor/ai_cpu_dynamic_kernel.h | 76 ++++++ .../ascend/executor/aicpu_ext_info_handle.cc | 218 ++++++++++++++++++ .../ascend/executor/aicpu_ext_info_handle.h | 88 +++++++ .../ascend/executor/executor_callback.cc | 41 ++++ .../ascend/executor/executor_callback.h | 49 ++++ .../ascend/executor/hccl_dynamic_kernel.cc | 187 +++++++++++++++ .../ascend/executor/hccl_dynamic_kernel.h | 82 +++++++ .../ascend/executor/host_dynamic_kernel.h | 36 +++ .../executor/rts/memcpy_rts_dynamic_kernel.cc | 32 +++ .../executor/rts/memcpy_rts_dynamic_kernel.h | 45 ++++ .../rts/profiling_rts_dynamic_kernel.cc | 32 +++ .../rts/profiling_rts_dynamic_kernel.h | 43 ++++ .../executor/tiling/op_tiling_calculater.cc | 188 +++++++++++++++ .../executor/tiling/op_tiling_calculater.h | 55 +++++ .../runtime/device/ascend/ge_types_convert.cc | 137 +++++++++++ .../{dump/ge_dump.h => ge_types_convert.h} | 57 +---- .../device/ascend/kernel_build_ascend.cc | 5 + .../device/ascend/kernel_select_ascend.cc | 6 +- .../runtime/device/cpu/cpu_kernel_runtime.h | 2 + .../runtime/device/executor/dynamic_kernel.cc | 128 ++++++++++ .../runtime/device/executor/dynamic_kernel.h | 62 +++++ .../runtime/device/gpu/gpu_kernel_runtime.h | 2 + .../ccsrc/runtime/device/kernel_adjust.cc | 8 + .../ccsrc/runtime/device/kernel_runtime.cc | 10 + .../ccsrc/runtime/device/kernel_runtime.h | 14 +- mindspore/ccsrc/utils/utils.h | 10 + mindspore/core/abstract/infer_functions.h | 31 +++ mindspore/core/abstract/prim_arrays.cc | 165 ++++++++++--- mindspore/core/abstract/prim_maths.cc | 50 ++++ mindspore/core/abstract/prim_others.cc | 68 ++++++ .../core/abstract/primitive_infer_map.cc | 16 ++ mindspore/core/abstract/utils.cc | 35 +++ mindspore/core/abstract/utils.h | 2 + mindspore/core/base/core_ops.h | 6 + mindspore/core/utils/convert_utils_base.h | 3 +- mindspore/ops/_op_impl/aicpu/__init__.py | 1 + mindspore/ops/_op_impl/aicpu/dynamic_shape.py | 40 ++++ mindspore/ops/_op_impl/aicpu/unique.py | 31 +++ mindspore/ops/_op_impl/tbe/__init__.py | 12 +- mindspore/ops/_op_impl/tbe/accumulate_n_v2.py | 4 +- mindspore/ops/_op_impl/tbe/apply_adam.py | 2 +- mindspore/ops/_op_impl/tbe/apply_ftrl.py | 2 +- mindspore/ops/_op_impl/tbe/apply_momentum.py | 2 +- mindspore/ops/_op_impl/tbe/assign_add.py | 4 +- mindspore/ops/_op_impl/tbe/batchnorm_grad.py | 4 +- mindspore/ops/_op_impl/tbe/bias_add_grad.py | 4 +- .../ops/_op_impl/tbe/confusion_mul_grad.py | 2 + mindspore/ops/_op_impl/tbe/div_ds.py | 42 ++++ mindspore/ops/_op_impl/tbe/floor_div.py | 4 +- .../_op_impl/tbe/fused_mul_add_n_l2loss.py | 4 +- mindspore/ops/_op_impl/tbe/gather_v2_ds.py | 67 ++++++ mindspore/ops/_op_impl/tbe/lin_space.py | 2 +- mindspore/ops/_op_impl/tbe/logsoftmax.py | 2 +- mindspore/ops/_op_impl/tbe/matmul.py | 4 +- mindspore/ops/_op_impl/tbe/matrix_set_diag.py | 2 +- mindspore/ops/_op_impl/tbe/mul_ds.py | 38 +++ mindspore/ops/_op_impl/tbe/one_hot.py | 2 +- mindspore/ops/_op_impl/tbe/real_div.py | 2 +- mindspore/ops/_op_impl/tbe/real_div_ds.py | 39 ++++ mindspore/ops/_op_impl/tbe/reduce_mean.py | 
2 +- mindspore/ops/_op_impl/tbe/relu_grad.py | 4 +- .../_op_impl/tbe/resize_nearest_neighbor.py | 2 +- .../tbe/resize_nearest_neighbor_grad.py | 2 +- mindspore/ops/_op_impl/tbe/scatter_add_ds.py | 43 ++++ .../ops/_op_impl/tbe/scatter_update_ds.py | 43 ++++ mindspore/ops/_op_impl/tbe/softmax.py | 2 +- .../ops/_op_impl/tbe/sparse_apply_ftrl_d.py | 4 +- .../_op_impl/tbe/sparse_apply_ftrl_d_ds.py | 52 +++++ .../tbe/sparse_apply_proximal_adagrad.py | 2 +- mindspore/ops/_op_impl/tbe/sqrt_ds.py | 38 +++ .../tbe/{reduce_mean_d.py => square_ds.py} | 22 +- mindspore/ops/_op_impl/tbe/square_sum_all.py | 2 +- mindspore/ops/_op_impl/tbe/tensor_add_ds.py | 43 ++++ .../ops/_op_impl/tbe/tensor_scatter_update.py | 4 +- mindspore/ops/_op_impl/tbe/top_k.py | 4 +- .../_op_impl/tbe/unsorted_segment_sum_ds.py | 38 +++ mindspore/ops/op_info_register.py | 12 + mindspore/ops/operations/array_ops.py | 47 ++-- mindspore/ops/operations/nn_ops.py | 2 + .../test_tbe_ops/test_unsorted_segment_sum.py | 16 +- .../ir_fusion/add_input_to_output_test.cc | 2 +- .../stub/dynamic_shape/dynamic_shape_stub.cc | 84 +++++++ 160 files changed, 4605 insertions(+), 463 deletions(-) create mode 100644 mindspore/ccsrc/backend/kernel_compiler/host/dynamic_shape_kernel.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/host/dynamic_shape_kernel.h create mode 100644 mindspore/ccsrc/backend/kernel_compiler/host/host_kernel_build.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/host/host_kernel_build.h create mode 100644 mindspore/ccsrc/backend/kernel_compiler/host/host_kernel_metadata.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/host/host_kernel_metadata.h create mode 100644 mindspore/ccsrc/backend/kernel_compiler/host/host_kernel_mod.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/host/host_kernel_mod.h create mode 100644 mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h create mode 100644 mindspore/ccsrc/runtime/device/ascend/executor/ai_core_dynamic_kernel.cc create mode 100644 mindspore/ccsrc/runtime/device/ascend/executor/ai_core_dynamic_kernel.h create mode 100644 mindspore/ccsrc/runtime/device/ascend/executor/ai_cpu_dynamic_kernel.cc create mode 100644 mindspore/ccsrc/runtime/device/ascend/executor/ai_cpu_dynamic_kernel.h create mode 100644 mindspore/ccsrc/runtime/device/ascend/executor/aicpu_ext_info_handle.cc create mode 100644 mindspore/ccsrc/runtime/device/ascend/executor/aicpu_ext_info_handle.h create mode 100644 mindspore/ccsrc/runtime/device/ascend/executor/executor_callback.cc create mode 100644 mindspore/ccsrc/runtime/device/ascend/executor/executor_callback.h create mode 100644 mindspore/ccsrc/runtime/device/ascend/executor/hccl_dynamic_kernel.cc create mode 100644 mindspore/ccsrc/runtime/device/ascend/executor/hccl_dynamic_kernel.h create mode 100644 mindspore/ccsrc/runtime/device/ascend/executor/host_dynamic_kernel.h create mode 100644 mindspore/ccsrc/runtime/device/ascend/executor/rts/memcpy_rts_dynamic_kernel.cc create mode 100644 mindspore/ccsrc/runtime/device/ascend/executor/rts/memcpy_rts_dynamic_kernel.h create mode 100644 mindspore/ccsrc/runtime/device/ascend/executor/rts/profiling_rts_dynamic_kernel.cc create mode 100644 mindspore/ccsrc/runtime/device/ascend/executor/rts/profiling_rts_dynamic_kernel.h create mode 100644 mindspore/ccsrc/runtime/device/ascend/executor/tiling/op_tiling_calculater.cc create mode 100644 
mindspore/ccsrc/runtime/device/ascend/executor/tiling/op_tiling_calculater.h create mode 100644 mindspore/ccsrc/runtime/device/ascend/ge_types_convert.cc rename mindspore/ccsrc/runtime/device/ascend/{dump/ge_dump.h => ge_types_convert.h} (51%) create mode 100644 mindspore/ccsrc/runtime/device/executor/dynamic_kernel.cc create mode 100644 mindspore/ccsrc/runtime/device/executor/dynamic_kernel.h create mode 100644 mindspore/ops/_op_impl/aicpu/dynamic_shape.py create mode 100644 mindspore/ops/_op_impl/aicpu/unique.py create mode 100644 mindspore/ops/_op_impl/tbe/div_ds.py create mode 100644 mindspore/ops/_op_impl/tbe/gather_v2_ds.py create mode 100644 mindspore/ops/_op_impl/tbe/mul_ds.py create mode 100644 mindspore/ops/_op_impl/tbe/real_div_ds.py create mode 100644 mindspore/ops/_op_impl/tbe/scatter_add_ds.py create mode 100644 mindspore/ops/_op_impl/tbe/scatter_update_ds.py create mode 100644 mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d_ds.py create mode 100644 mindspore/ops/_op_impl/tbe/sqrt_ds.py rename mindspore/ops/_op_impl/tbe/{reduce_mean_d.py => square_ds.py} (69%) create mode 100644 mindspore/ops/_op_impl/tbe/tensor_add_ds.py create mode 100644 mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py create mode 100644 tests/ut/cpp/stub/dynamic_shape/dynamic_shape_stub.cc diff --git a/mindspore/_extends/parallel_compile/tbe_compiler/compiler.py b/mindspore/_extends/parallel_compile/tbe_compiler/compiler.py index c6e39a41a8..e4b8397016 100755 --- a/mindspore/_extends/parallel_compile/tbe_compiler/compiler.py +++ b/mindspore/_extends/parallel_compile/tbe_compiler/compiler.py @@ -18,6 +18,7 @@ import os import sys from te.platform.cce_conf import te_set_version from te.platform.fusion_util import fusion_op +import te from common import check_kernel_info, get_args, get_build_in_impl_path build_in_impl_path = get_build_in_impl_path() @@ -38,6 +39,16 @@ def _initialize(impl_path): sys.path.insert(0, op_module_name) +def _replace_range(args): + for arg in args: + if not arg.__contains__('range'): + continue + shape_range = arg["range"] + for range_item in shape_range: + for index, value in enumerate(range_item): + if value < 0: + range_item[index] = None + def build_op(build_type, json_str): """ call op functions with function name and input args json_str @@ -71,11 +82,18 @@ def build_op(build_type, json_str): outputs_args = get_args(kernel_info['op_info'], 'outputs') attrs_args = get_args(kernel_info['op_info'], 'attrs') kernel_name = kernel_info['op_info']['kernel_name'] + is_dynamic_shape = kernel_info['op_info']['is_dynamic_shape'] + if is_dynamic_shape: + _replace_range(inputs_args) + _replace_range(outputs_args) if custom_flag: op_module = __import__(op_name) else: - op_module = __import__("impl."+op_name, globals(), locals(), [op_name], 0) + if is_dynamic_shape: + op_module = __import__("impl.dynamic."+op_name, globals(), locals(), [op_name], 0) + else: + op_module = __import__("impl."+op_name, globals(), locals(), [op_name], 0) # get function if build_type == op_build: if custom_flag: @@ -92,7 +110,12 @@ def build_op(build_type, json_str): if kernel_name[0:19] == "bounding_box_encode": return op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name_val=kernel_name) - return op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name) + if is_dynamic_shape: + with te.op.dynamic(): + op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name) + return te.op.get_compile_info() + else: + return op_func(*inputs_args, *outputs_args, *attrs_args, 
kernel_name=kernel_name) except Exception as e: raise RuntimeError(e) diff --git a/mindspore/_extends/parallel_compile/tbe_compiler/helper.py b/mindspore/_extends/parallel_compile/tbe_compiler/helper.py index bb4c057c1a..e223525e9a 100644 --- a/mindspore/_extends/parallel_compile/tbe_compiler/helper.py +++ b/mindspore/_extends/parallel_compile/tbe_compiler/helper.py @@ -78,6 +78,7 @@ def _check_supported(kernel_info): """ try: op_name = kernel_info['op_info']['name'] + is_dynamic_shape = kernel_info['op_info']['is_dynamic_shape'] impl_path = build_in_impl_path custom_flag = False if 'impl_path' in kernel_info and kernel_info['impl_path'] is not None: @@ -92,8 +93,11 @@ def _check_supported(kernel_info): if custom_flag: op_module = __import__(op_name) + elif is_dynamic_shape: + op_module = __import__("impl.dynamic." + op_name, globals(), locals(), [op_name], 0) else: op_module = __import__("impl." + op_name, globals(), locals(), [op_name], 0) + # get function if not hasattr(op_module, "check_supported"): return "" diff --git a/mindspore/ccsrc/CMakeLists.txt b/mindspore/ccsrc/CMakeLists.txt index 69375d7631..d46d1a23bb 100644 --- a/mindspore/ccsrc/CMakeLists.txt +++ b/mindspore/ccsrc/CMakeLists.txt @@ -219,6 +219,7 @@ if (ENABLE_D) set(ASCEND_DRIVER_PATH ${ASCEND_PATH}/driver/lib64/common) set(ASCEND_DRIVER_BACK_PATH ${ASCEND_PATH}/driver/lib64/driver) set(ASCEND_RUNTIME_PATH ${ASCEND_PATH}/fwkacllib/lib64) + set(ASCEND_OPP_PATH ${ASCEND_PATH}/opp/op_impl/built-in/ai_core/tbe/op_tiling) endif() MESSAGE("USE DAV LIB PATH: ${ASCEND_PATH}") @@ -228,7 +229,8 @@ if (ENABLE_D) find_library(TSDCLIENT tsdclient HINTS ${ASCEND_RUNTIME_PATH} ${ASCEND_DRIVER_BACK_PATH}) find_library(DATATRANSFER datatransfer HINTS ${ASCEND_RUNTIME_PATH} ${ASCEND_DRIVER_BACK_PATH}) find_library(PROFILING msprofiler ${ASCEND_RUNTIME_PATH}) - target_link_libraries(mindspore ge_runtime ${CCE_LIB} ${RUNTIME_LIB} ${TSDCLIENT} ${HCCL} ${DATATRANSFER}) + find_library(OPTILING optiling ${ASCEND_OPP_PATH}) + target_link_libraries(mindspore ge_runtime ${CCE_LIB} ${RUNTIME_LIB} ${TSDCLIENT} ${HCCL} ${DATATRANSFER} ${OPTILING}) target_link_libraries(mindspore -Wl,--start-group proto_input ${PROFILING} mindspore::protobuf -Wl,--end-group) elseif (CMAKE_SYSTEM_NAME MATCHES "Windows") target_link_libraries(mindspore -Wl,--start-group proto_input mindspore::protobuf mindspore::sentencepiece -Wl,--end-group) @@ -258,6 +260,7 @@ if (ENABLE_D) set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/ascend-toolkit/latest/fwkacllib/lib64) set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/fwkacllib/lib64) set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/add-ons) + set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/opp/op_impl/built-in/ai_core/tbe/op_tiling) elseif (ENABLE_GPU) set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/cuda/lib64) endif () @@ -315,6 +318,8 @@ add_library(inference SHARED ${CMAKE_CURRENT_SOURCE_DIR}/backend/session/infer_session.cc ${LOAD_ONNX_SRC} ) + +set_target_properties(inference PROPERTIES INSTALL_RPATH ${MINDSPORE_RPATH}) target_link_libraries(inference PRIVATE ${PYTHON_LIBRARIES} ${SECUREC_LIBRARY} -Wl,--whole-archive mindspore proto_input -Wl,--no-whole-archive mindspore_gvar) diff --git a/mindspore/ccsrc/backend/kernel_compiler/CMakeLists.txt b/mindspore/ccsrc/backend/kernel_compiler/CMakeLists.txt index bde8328beb..bda98a304e 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/CMakeLists.txt +++ b/mindspore/ccsrc/backend/kernel_compiler/CMakeLists.txt @@ -15,6 +15,7 @@ if (ENABLE_D) 
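The compiler.py change above routes dynamic-shape kernels through impl.dynamic and builds them under te.op.dynamic(), after _replace_range() rewrites every negative bound in an input/output "range" field to None (presumably TBE's marker for an unbounded dimension). A minimal standalone sketch of that normalization; the argument dict below is illustrative, not taken from a real kernel_info:

def _replace_range(args):
    # Mirrors the helper added to compiler.py: negative range bounds become None.
    for arg in args:
        if 'range' not in arg:
            continue
        for range_item in arg['range']:
            for index, value in enumerate(range_item):
                if value < 0:
                    range_item[index] = None

# Hypothetical argument as produced by get_args() for a dynamic-shape input.
inputs_args = [{'shape': [-1, 16], 'range': [[1, -1], [16, 16]]}]
_replace_range(inputs_args)
print(inputs_args)  # [{'shape': [-1, 16], 'range': [[1, None], [16, 16]]}]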
"akg/akg_kernel_attrs_process.cc" "akg/akg_kernel_metadata.cc" "tbe/*.cc" + "host/*.cc" "aicpu/*.cc" "rts/*.cc" "hccl/*.cc" diff --git a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_build.cc b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_build.cc index e6015f099f..9e1af32026 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_build.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_build.cc @@ -289,51 +289,25 @@ bool CreateNodeDefBytes(const std::shared_ptr &anf_node, return true; } -bool CreateExtInfo(const std::shared_ptr &anf_node, const std::shared_ptr &kernel_mod_ptr) { - if (!anf_node->isa()) { - return true; - } - - if (!AnfAlgo::IsDynamicShape(anf_node)) { - return true; - } - - MS_LOG(INFO) << "CreateExtInfo start, " << anf_node->fullname_with_scope(); - - int32_t unknown_shape_type = UnknowShapeOpType::DEPEND_COMPUTE; - uint64_t ext_info_head_len = kExtInfoHeadSize; - std::string ext_info; - size_t input_num = AnfAlgo::GetInputTensorNum(anf_node); - size_t output_num = AnfAlgo::GetOutputTensorNum(anf_node); - - // 1.addr:unknown shape type - uint64_t ext_info_len = ext_info.size(); - ext_info_len += ext_info_head_len + sizeof(int32_t); - - // 2.addr:input ShapeAndType - ext_info_len += ext_info_head_len + input_num * sizeof(ShapeAndType); - - // 3.addr:output ShapeAndType - ext_info_len += ext_info_head_len + output_num * sizeof(ShapeAndType); - - uint64_t ext_info_offset = ext_info.size(); - ext_info.resize(ext_info_len, 0); - char *ext_info_buf = ext_info.data(); - +uint64_t SetExtInfoShapeType(char *ext_info_buf, uint64_t ext_info_offset) { // deal1: unknown shape type ExtInfo *info = reinterpret_cast(ext_info_buf + ext_info_offset); info->infoType = FWK_ADPT_EXT_SHAPE_TYPE; info->infoLen = sizeof(int32_t); - ext_info_offset += ext_info_head_len; + ext_info_offset += kExtInfoHeadSize; int32_t *shape_type = reinterpret_cast(ext_info_buf + ext_info_offset); - *shape_type = unknown_shape_type; + *shape_type = UnknowShapeOpType::DEPEND_COMPUTE; ext_info_offset += info->infoLen; + return ext_info_offset; +} +uint64_t SetExtInfoInputShapeType(char *ext_info_buf, uint64_t ext_info_offset, + const std::shared_ptr &anf_node, size_t input_num) { // deal2:input ShapeAndType - info = reinterpret_cast(ext_info_buf + ext_info_offset); + ExtInfo *info = reinterpret_cast(ext_info_buf + ext_info_offset); info->infoType = FWK_ADPT_EXT_INPUT_SHAPE; info->infoLen = input_num * sizeof(ShapeAndType); - ext_info_offset += ext_info_head_len; + ext_info_offset += kExtInfoHeadSize; ShapeAndType *inputs = reinterpret_cast(ext_info_buf + ext_info_offset); for (size_t input_index = 0; input_index < input_num; input_index++) { @@ -364,12 +338,16 @@ bool CreateExtInfo(const std::shared_ptr &anf_node, const std::shared_p } } ext_info_offset += info->infoLen; + return ext_info_offset; +} +uint64_t SetExtInfoOutputShapeType(char *ext_info_buf, uint64_t ext_info_offset, + const std::shared_ptr &anf_node, size_t output_num) { // deal3:output ShapeAndType - info = reinterpret_cast(ext_info_buf + ext_info_offset); + ExtInfo *info = reinterpret_cast(ext_info_buf + ext_info_offset); info->infoType = FWK_ADPT_EXT_OUTPUT_SHAPE; info->infoLen = output_num * sizeof(ShapeAndType); - ext_info_offset += ext_info_head_len; + ext_info_offset += kExtInfoHeadSize; ShapeAndType *outputs = reinterpret_cast(ext_info_buf + ext_info_offset); for (size_t output_index = 0; output_index < output_num; output_index++) { @@ -387,6 +365,47 @@ bool CreateExtInfo(const 
std::shared_ptr &anf_node, const std::shared_p } } + ext_info_offset += info->infoLen; + return ext_info_offset; +} + +bool CreateExtInfo(const std::shared_ptr &anf_node, const std::shared_ptr &kernel_mod_ptr) { + MS_EXCEPTION_IF_NULL(anf_node); + MS_EXCEPTION_IF_NULL(kernel_mod_ptr); + if (!anf_node->isa()) { + return true; + } + + if (!AnfAlgo::IsDynamicShape(anf_node)) { + return true; + } + + MS_LOG(INFO) << "CreateExtInfo start, " << anf_node->fullname_with_scope(); + + uint64_t ext_info_head_len = kExtInfoHeadSize; + std::string ext_info; + size_t input_num = AnfAlgo::GetInputTensorNum(anf_node); + size_t output_num = AnfAlgo::GetOutputTensorNum(anf_node); + + // 1.addr:unknown shape type + uint64_t ext_info_len = ext_info.size(); + ext_info_len += ext_info_head_len + sizeof(int32_t); + + // 2.addr:input ShapeAndType + ext_info_len += ext_info_head_len + input_num * sizeof(ShapeAndType); + + // 3.addr:output ShapeAndType + ext_info_len += ext_info_head_len + output_num * sizeof(ShapeAndType); + + uint64_t ext_info_offset = ext_info.size(); + ext_info.resize(ext_info_len, 0); + char *ext_info_buf = ext_info.data(); + + ext_info_offset = SetExtInfoShapeType(ext_info_buf, ext_info_offset); + ext_info_offset = SetExtInfoInputShapeType(ext_info_buf, ext_info_offset, anf_node, input_num); + ext_info_offset = SetExtInfoOutputShapeType(ext_info_buf, ext_info_offset, anf_node, output_num); + + MS_LOG(INFO) << "Check ext_info_len:" << ext_info_len << " ext_info_offset:" << ext_info_offset; // set ext info kernel_mod_ptr->SetExtInfo(ext_info); return true; diff --git a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_mod.cc b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_mod.cc index d6fafbcf43..98e84b83b1 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_mod.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_mod.cc @@ -26,8 +26,13 @@ #include "utils/convert_utils.h" #include "backend/kernel_compiler/aicpu/aicpu_util.h" #include "utils/ms_context.h" +#include "runtime/device/ascend/executor/ai_cpu_dynamic_kernel.h" +#include "runtime/device/kernel_runtime.h" +#include "runtime/device/ascend/executor/host_dynamic_kernel.h" using AicpuTaskInfoPtr = std::shared_ptr; +using AicpuDynamicKernel = mindspore::device::ascend::AiCpuDynamicKernel; +using HostDynamicKernel = mindspore::device::ascend::HostDynamicKernel; namespace mindspore { namespace kernel { @@ -93,7 +98,7 @@ void AicpuOpKernelMod::CreateCpuKernelInfo(const std::vector &inputs param_len += node_def_len; param_len += sizeof(uint32_t); - AicpuParamHead aicpu_param_head; + AicpuParamHead aicpu_param_head{}; aicpu_param_head.length = param_len; aicpu_param_head.ioAddrNum = io_addrs_num; @@ -178,5 +183,15 @@ std::vector AicpuOpKernelMod::GenTask(const std::vector MS_LOG(INFO) << "AicpuOpKernelMod GenTask end"; return {task_info_ptr}; } + +device::DynamicKernelPtr AicpuOpKernelMod::GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) { + AddressPtrList kernel_inputs; + AddressPtrList kernel_workspaces; + AddressPtrList kernel_outputs; + device::KernelRuntime::GenLaunchArgs(*this, cnode_ptr, &kernel_inputs, &kernel_workspaces, &kernel_outputs); + + CreateCpuKernelInfo(kernel_inputs, kernel_outputs); + return std::make_shared(stream_ptr, cnode_ptr, args_, ext_info_, node_so_, node_name_); +} } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_mod.h 
b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_mod.h index 7d006cc67d..71768416ed 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_mod.h +++ b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_mod.h @@ -31,6 +31,7 @@ class AicpuOpKernelMod : public AscendKernelMod { std::vector GenTask(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs, uint32_t stream_id) override; + device::DynamicKernelPtr GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) override; void SetInputList(const std::vector &inputList); void SetOutputList(const std::vector &outputList); diff --git a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_util.cc b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_util.cc index 2f17967c03..cc5a11f8bb 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_util.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_util.cc @@ -20,7 +20,7 @@ namespace mindspore { namespace kernel { -static std::map MS_PROTO_DATA_TYPE_MAP = { +static const std::map kMsProtoDataTypeMap = { {mindspore::TypeId::kTypeUnknown, mindspore::DataType::MS_UNKNOWN}, {mindspore::TypeId::kNumberTypeBool, mindspore::DataType::MS_BOOL}, {mindspore::TypeId::kNumberTypeInt, mindspore::DataType::MS_INT32}, @@ -39,14 +39,38 @@ static std::map MS_PROTO_DATA_TYPE_MAP = { {mindspore::TypeId::kNumberTypeFloat64, mindspore::DataType::MS_FLOAT64}, }; +static const std::map kProtoDataTypeToMsDataTypeMap = { + {mindspore::DataType::MS_UNKNOWN, mindspore::TypeId::kTypeUnknown}, + {mindspore::DataType::MS_BOOL, mindspore::TypeId::kNumberTypeBool}, + {mindspore::DataType::MS_INT32, mindspore::TypeId::kNumberTypeInt32}, + {mindspore::DataType::MS_INT8, mindspore::TypeId::kNumberTypeInt8}, + {mindspore::DataType::MS_INT16, mindspore::TypeId::kNumberTypeInt16}, + {mindspore::DataType::MS_INT64, mindspore::TypeId::kNumberTypeInt64}, + {mindspore::DataType::MS_UINT8, mindspore::TypeId::kNumberTypeUInt8}, + {mindspore::DataType::MS_UINT16, mindspore::TypeId::kNumberTypeUInt16}, + {mindspore::DataType::MS_UINT32, mindspore::TypeId::kNumberTypeUInt32}, + {mindspore::DataType::MS_UINT64, mindspore::TypeId::kNumberTypeUInt64}, + {mindspore::DataType::MS_FLOAT16, mindspore::TypeId::kNumberTypeFloat16}, + {mindspore::DataType::MS_FLOAT32, mindspore::TypeId::kNumberTypeFloat32}, + {mindspore::DataType::MS_FLOAT64, mindspore::TypeId::kNumberTypeFloat64}, +}; + int AicpuOpUtil::MsTypeToProtoType(TypeId ms_type) { - auto iter = MS_PROTO_DATA_TYPE_MAP.find(ms_type); - if (iter != MS_PROTO_DATA_TYPE_MAP.end()) { - return MS_PROTO_DATA_TYPE_MAP[ms_type]; - } else { + auto iter = kMsProtoDataTypeMap.find(ms_type); + if (iter == kMsProtoDataTypeMap.end()) { MS_LOG(ERROR) << "UnSupported ms_type value" << static_cast(ms_type); return -1; } + return iter->second; +} + +int AicpuOpUtil::ProtoTypeToMsType(int proto_type) { + auto iter = kProtoDataTypeToMsDataTypeMap.find(proto_type); + if (iter == kProtoDataTypeToMsDataTypeMap.end()) { + MS_LOG(ERROR) << "UnSupported proto_type value:" << proto_type; + return -1; + } + return iter->second; } } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_util.h b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_util.h index 5381812ca9..e6bd0fc975 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_util.h +++ b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_util.h @@ -55,13 +55,6 @@ struct AicpuParamHead { uint64_t extInfoAddr; // 
extInfo address } __attribute__((packed)); -const uint32_t kExtInfoHeadSize = 8; -struct ExtInfo { - int32_t infoType; // extend type - uint32_t infoLen; // length for infoMsg - char infoMsg[0]; // extend value -} __attribute__((packed)); - // Extent info ShapeAndType const uint32_t kMaxShapeDims = 8; struct ShapeAndType { @@ -69,6 +62,14 @@ struct ShapeAndType { int64_t dims[kMaxShapeDims]; } __attribute__((packed)); +// Extend info structure for extInfoAddr +const uint32_t kExtInfoHeadSize = 8; +struct ExtInfo { + int32_t infoType; // extend type + uint32_t infoLen; // length for infoMsg + char infoMsg[0]; // extend value +} __attribute__((packed)); + // Extend Info type for task enum FWKTaskExtInfoType { FWK_ADPT_EXT_SHAPE_TYPE = 0, @@ -88,6 +89,7 @@ enum UnknowShapeOpType { class AicpuOpUtil { public: static int MsTypeToProtoType(TypeId ms_type); + static int ProtoTypeToMsType(int proto_type); private: // kernel id diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc index 0ec7c6c625..c4b108b021 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.cc @@ -15,15 +15,34 @@ */ #include "backend/kernel_compiler/hccl/hccl_kernel.h" + +#include #include "runtime/device/ascend/tasksink/runtime_utils.h" #include "backend/session/anf_runtime_algorithm.h" #include "utils/utils.h" #include "utils/ms_context.h" +#include "runtime/device/kernel_runtime.h" +#include "runtime/device/ascend/executor/hccl_dynamic_kernel.h" using HcclTaskInfoPtr = std::shared_ptr; using ge::model_runner::HcclTaskInfo; using mindspore::device::ascend::tasksink::RuntimeUtils; +namespace { +static std::map kMsOpNameToHcomHcclType = { + {mindspore::kAllReduceOpName, mindspore::kHcomOpTypeAllReduce}, + {mindspore::kAllGatherOpName, mindspore::kHcomOpTypeAllGather}, + {mindspore::kBroadcastOpName, mindspore::kHcomOpTypeBroadcast}, + {mindspore::kReduceScatterOpName, mindspore::kHcomOpTypeReduceScatter}}; +std::string MsOpNameToHcomOpType(const std::string &ms_op_type) { + auto iter = kMsOpNameToHcomHcclType.find(ms_op_type); + if (iter == kMsOpNameToHcomHcclType.end()) { + MS_LOG(EXCEPTION) << "Invalid MsOpType:" << ms_op_type; + } + return iter->second; +} +} // namespace + namespace mindspore { namespace kernel { void HcclKernelFactory::Registe(const std::string &name, HcclKernelCreater &&fun) { @@ -156,5 +175,30 @@ std::vector HcclKernel::GenTask(const std::vector &inpu MS_EXCEPTION_IF_NULL(task_info_ptr); return {task_info_ptr}; } + +device::DynamicKernelPtr HcclKernel::GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) { + AddressPtrList inputs; + AddressPtrList workspaces; + AddressPtrList outputs; + device::KernelRuntime::GenLaunchArgs(*this, cnode_ptr, &inputs, &workspaces, &outputs); + + std::string hccl_type = MsOpNameToHcomOpType(AnfAlgo::GetCNodeName(anf_node_)); + + if (inputs.empty()) { + MS_LOG(EXCEPTION) << "Hccl kernel input is empty"; + } + if (hccl_data_type_list_.empty()) { + MS_LOG(EXCEPTION) << "Hccl data type list is empty"; + } + MS_EXCEPTION_IF_NULL(inputs.at(0)); + auto input_data_addr = inputs.at(0)->addr; + MS_EXCEPTION_IF_NULL(outputs.at(0)); + auto output_data_addr = outputs.at(0)->addr; + HcclDataType data_type = hccl_data_type_list_[0]; + + auto executor = std::make_shared( + hccl_type, input_data_addr, output_data_addr, hccl_count_, data_type, op_type_, root_id_, stream_ptr, cnode_ptr); + return executor; +} } // 
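The AICPU CreateExtInfo helpers and the ExtInfo/ShapeAndType definitions above build the extInfoAddr buffer as three [8-byte head | payload] records laid end to end: the unknown-shape type (one int32), the input ShapeAndType array, and the output ShapeAndType array. A sketch of the same length arithmetic in Python; SHAPE_AND_TYPE_SIZE assumes ShapeAndType packs an int32 type plus kMaxShapeDims (8) int64 dims, which should be checked against sizeof(ShapeAndType) in aicpu_util.h:

EXT_INFO_HEAD_SIZE = 8            # kExtInfoHeadSize: packed int32 infoType + uint32 infoLen
SHAPE_TYPE_PAYLOAD = 4            # one int32 holding the UnknowShapeOpType value
SHAPE_AND_TYPE_SIZE = 4 + 8 * 8   # assumed layout: int32 type + int64 dims[kMaxShapeDims]

def ext_info_len(input_num, output_num):
    # Total size CreateExtInfo resizes ext_info to before filling the three records.
    total = EXT_INFO_HEAD_SIZE + SHAPE_TYPE_PAYLOAD                   # record 1: shape type
    total += EXT_INFO_HEAD_SIZE + input_num * SHAPE_AND_TYPE_SIZE     # record 2: input shapes
    total += EXT_INFO_HEAD_SIZE + output_num * SHAPE_AND_TYPE_SIZE    # record 3: output shapes
    return total

print(ext_info_len(2, 1))  # 232 bytes under these assumptions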
namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.h index b7fe21945b..2ba888ecee 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel.h @@ -41,6 +41,7 @@ class HcclKernel : public AscendKernelMod { const std::vector &GetWorkspaceSizeList() const override; std::vector GenTask(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs, uint32_t stream_id) override; + device::DynamicKernelPtr GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) override; protected: std::vector> hccl_kernel_input_shape_list_; diff --git a/mindspore/ccsrc/backend/kernel_compiler/host/dynamic_shape_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/host/dynamic_shape_kernel.cc new file mode 100644 index 0000000000..30d50ae6ed --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/host/dynamic_shape_kernel.cc @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/host/dynamic_shape_kernel.h" +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace kernel { +void DynamicShapeKernel::Execute() { + MS_LOG(INFO) << "Execute DynamicShapeKernel Start"; + auto input_num = AnfAlgo::GetInputTensorNum(cnode_ptr_); + if (input_num != 1) { + MS_LOG(EXCEPTION) << "Invalid Input Num:" << input_num; + } + + auto prev_output_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, 0); + auto output_shape = std::vector(SizeToInt(prev_output_shape.size())); + + auto output_type = TypeId::kNumberTypeInt32; + + auto output_tensor_for_sync = std::make_shared(output_type, output_shape); + auto data_ptr = static_cast(output_tensor_for_sync->data_c()); + for (size_t i = 0; i < prev_output_shape.size(); ++i) { + MS_LOG(INFO) << "DEBUG prev_output_shape[" << i << "]:" << prev_output_shape[i]; + *(data_ptr + i) = prev_output_shape[i]; + } + + auto output_addr = AnfAlgo::GetOutputAddr(cnode_ptr_, 0); + MS_EXCEPTION_IF_NULL(output_addr); + output_addr->SyncHostToDevice(output_shape, LongToSize(output_tensor_for_sync->data().nbytes()), + output_tensor_for_sync->data_type(), output_tensor_for_sync->data_c()); + MS_LOG(INFO) << "Execute DynamicShapeKernel End"; +} + +device::DynamicKernelPtr DynamicShapeKernelMod::GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) { + return std::make_shared(stream_ptr, cnode_ptr); +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/host/dynamic_shape_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/host/dynamic_shape_kernel.h new file mode 100644 index 0000000000..61c419e90e --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/host/dynamic_shape_kernel.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache 
License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_DYNAMIC_SHAPE_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_DYNAMIC_SHAPE_KERNEL_H_ +#include +#include +#include +#include "runtime/device/ascend/executor/host_dynamic_kernel.h" +#include "backend/kernel_compiler/host/host_kernel_mod.h" +using HostDynamicKernel = mindspore::device::ascend::HostDynamicKernel; +namespace mindspore { +namespace kernel { +class DynamicShapeKernel : public HostDynamicKernel { + public: + DynamicShapeKernel(void *stream, const CNodePtr &cnode_ptr) : HostDynamicKernel(stream, cnode_ptr) {} + ~DynamicShapeKernel() override = default; + void Execute() override; +}; + +class DynamicShapeKernelMod : public HostKernelMod { + public: + DynamicShapeKernelMod() = default; + ~DynamicShapeKernelMod() override = default; + device::DynamicKernelPtr GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) override; +}; +MS_HOST_REG_KERNEL(DynamicShape, DynamicShapeKernelMod); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_DYNAMIC_SHAPE_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/host/host_kernel_build.cc b/mindspore/ccsrc/backend/kernel_compiler/host/host_kernel_build.cc new file mode 100644 index 0000000000..7cd9e43ddd --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/host/host_kernel_build.cc @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
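DynamicShapeKernel::Execute above is the whole host-side implementation of the DynamicShape op: it reads the inferred shape of its single input, packs it into a 1-D int32 tensor, and syncs that tensor to the op's device output address. In numpy terms the produced value is simply (illustrative only):

import numpy as np

def dynamic_shape_output(prev_output_shape):
    # Value the DynamicShape host kernel writes to its output address.
    return np.array(prev_output_shape, dtype=np.int32)

print(dynamic_shape_output([8, 3, 64]))  # [ 8  3 64]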
+ */ +#include "backend/kernel_compiler/host/host_kernel_build.h" +#include +#include "runtime/device/kernel_runtime.h" +#include "backend/kernel_compiler/host/host_kernel_mod.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/session/kernel_graph.h" +#include "backend/kernel_compiler/common_utils.h" + +namespace mindspore { +namespace kernel { +KernelModPtr HostOpBuild(const std::shared_ptr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + std::string opname = AnfAlgo::GetCNodeName(anf_node); + MS_LOG(INFO) << "Host op [" << opname << "]"; + auto kerPtr = HostKernelFactory::Get(opname); + if (kerPtr == nullptr) { + MS_LOG(ERROR) << "Host can't find Kernel[" << opname << "]"; + return nullptr; + } + if (!kerPtr->Init(anf_node)) { + MS_LOG(ERROR) << "Host Kernel initialize failed!"; + return nullptr; + } + return kerPtr; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/host/host_kernel_build.h b/mindspore/ccsrc/backend/kernel_compiler/host/host_kernel_build.h new file mode 100644 index 0000000000..f330905de8 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/host/host_kernel_build.h @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_HOST_KERNEL_BUILD_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_HOST_KERNEL_BUILD_H_ +#include +#include "backend/kernel_compiler/kernel.h" + +namespace mindspore { +namespace kernel { +KernelModPtr HostOpBuild(const std::shared_ptr &anf_node); +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_HOST_KERNEL_BUILD_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/host/host_kernel_metadata.cc b/mindspore/ccsrc/backend/kernel_compiler/host/host_kernel_metadata.cc new file mode 100644 index 0000000000..d2a04e4160 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/host/host_kernel_metadata.cc @@ -0,0 +1,59 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "backend/kernel_compiler/host/host_kernel_metadata.h" +#include +#include +#include "backend/kernel_compiler/oplib/oplib.h" +#include "backend/kernel_compiler/common_utils.h" +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace kernel { +constexpr auto kDynamicShape = "DynamicShape"; + +void HostMetadataInfo(const CNodePtr &kernel_node, std::vector> *kernel_info_list) { + MS_LOG(INFO) << "HostMetadataInfo."; + MS_EXCEPTION_IF_NULL(kernel_node); + MS_EXCEPTION_IF_NULL(kernel_info_list); + std::string op_name = AnfAlgo::GetCNodeName(kernel_node); + if (op_name != kDynamicShape) { + MS_LOG(DEBUG) << "Host does not have op [" << op_name << "]"; + return; + } + + std::vector inputs_format{}; + std::vector inputs_type{}; + for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { + inputs_format.emplace_back(kOpFormat_DEFAULT); + inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index)); + } + std::vector outputs_format; + std::vector outputs_type; + for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) { + outputs_format.emplace_back(kOpFormat_DEFAULT); + outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index)); + } + auto builder = KernelBuildInfo::KernelBuildInfoBuilder(); + builder.SetInputsFormat(inputs_format); + builder.SetInputsDeviceType(inputs_type); + builder.SetOutputsFormat(outputs_format); + builder.SetOutputsDeviceType(outputs_type); + builder.SetKernelType(HOST_KERNEL); + kernel_info_list->push_back(builder.Build()); +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/host/host_kernel_metadata.h b/mindspore/ccsrc/backend/kernel_compiler/host/host_kernel_metadata.h new file mode 100644 index 0000000000..dc0deab5fc --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/host/host_kernel_metadata.h @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_HOST_KERNEL_META_DATA_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_HOST_KERNEL_META_DATA_H_ + +#include +#include +#include +#include "backend/kernel_compiler/kernel_build_info.h" + +namespace mindspore { +namespace kernel { +void HostMetadataInfo(const CNodePtr &kernel_node, std::vector> *kernel_info_list); +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_HOST_KERNEL_META_DATA_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/host/host_kernel_mod.cc b/mindspore/ccsrc/backend/kernel_compiler/host/host_kernel_mod.cc new file mode 100644 index 0000000000..820b398388 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/host/host_kernel_mod.cc @@ -0,0 +1,98 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/host/host_kernel_mod.h" + +#include +#include +#include +#include +#include "runtime/mem.h" +#include "utils/ms_context.h" +#include "runtime/device/kernel_runtime.h" +#include "runtime/device/ascend/executor/host_dynamic_kernel.h" + +namespace mindspore { +namespace kernel { +void HostKernelFactory::Registe(const std::string &name, HostKernelCreater &&fun) { + hostKernelMap_.emplace(name, std::move(fun)); +} + +std::shared_ptr HostKernelFactory::Get(const std::string &name) { + const auto &map = Get().hostKernelMap_; + auto it = map.find(name); + if (it != map.end() && it->second) { + return (it->second)(); + } + return nullptr; +} + +HostKernelFactory &HostKernelFactory::Get() { + static HostKernelFactory instance; + return instance; +} + +const std::vector &HostKernelMod::GetInputSizeList() const { return input_size_list_; } +const std::vector &HostKernelMod::GetOutputSizeList() const { return output_size_list_; } +const std::vector &HostKernelMod::GetWorkspaceSizeList() const { return workspace_size_list_; } +bool HostKernelMod::Init(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + size_t input_num = AnfAlgo::GetInputTensorNum(anf_node); + size_t output_num = AnfAlgo::GetOutputTensorNum(anf_node); + + for (size_t i = 0; i < input_num; i++) { + std::vector shape_i = AnfAlgo::GetInputDeviceShape(anf_node, i); + TypePtr type_ptr = TypeIdToType(AnfAlgo::GetInputDeviceDataType(anf_node, i)); + MS_EXCEPTION_IF_NULL(type_ptr); + int64_t size_i = 1; + for (size_t j = 0; j < shape_i.size(); j++) { + size_i = LongMulWithOverflowCheck(size_i, static_cast(shape_i[j])); + } + size_t type_byte = GetTypeByte(type_ptr); + if (type_byte == 0) { + return false; + } + size_i = LongMulWithOverflowCheck(size_i, SizeToInt(type_byte)); + input_size_list_.push_back(LongToSize(size_i)); + } + + for (size_t i = 0; i < output_num; i++) { + std::vector shape_i = AnfAlgo::GetOutputDeviceShape(anf_node, i); + TypePtr type_ptr = TypeIdToType(AnfAlgo::GetOutputDeviceDataType(anf_node, i)); + MS_EXCEPTION_IF_NULL(type_ptr); + int64_t size_i = 1; + for (size_t j 
= 0; j < shape_i.size(); j++) { + size_i = LongMulWithOverflowCheck(size_i, static_cast(shape_i[j])); + } + size_t type_byte = GetTypeByte(type_ptr); + if (type_byte == 0) { + return false; + } + size_i = LongMulWithOverflowCheck(size_i, SizeToInt(type_byte)); + output_size_list_.push_back(LongToSize(size_i)); + } + return true; +} +bool HostKernelMod::Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) { + return true; +} +std::vector HostKernelMod::GenTask(const std::vector &, const std::vector &, + const std::vector &, uint32_t) { + return {}; +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/host/host_kernel_mod.h b/mindspore/ccsrc/backend/kernel_compiler/host/host_kernel_mod.h new file mode 100644 index 0000000000..8e980334af --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/host/host_kernel_mod.h @@ -0,0 +1,86 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_HOST_KERNEL_MOD_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_HOST_KERNEL_MOD_H_ +#include +#include +#include +#include +#include +#include "backend/kernel_compiler/ascend_kernel_mod.h" +namespace mindspore { +namespace kernel { +class HostKernelMod : public AscendKernelMod { + public: + HostKernelMod() = default; + ~HostKernelMod() override = default; + const std::vector &GetInputSizeList() const override; + const std::vector &GetOutputSizeList() const override; + const std::vector &GetWorkspaceSizeList() const override; + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *stream_ptr) override; + std::vector GenTask(const std::vector &, const std::vector &, + const std::vector &, uint32_t) override; + device::DynamicKernelPtr GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) override = 0; + bool Init(const AnfNodePtr &anf_node); + + protected: + AnfNodePtr anf_node_; + std::string op_name_; + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; + +using HostKernelModPtr = std::shared_ptr; +using HostKernelModPtrList = std::vector; +using HostKernelCreater = std::function()>; + +class HostKernelFactory { + HostKernelFactory() = default; + ~HostKernelFactory() = default; + + public: + static HostKernelFactory &Get(); + void Registe(const string &name, HostKernelCreater &&fun); + static std::shared_ptr Get(const string &name); + + private: + std::map hostKernelMap_; +}; + +class _HostKernelRegister { + public: + _HostKernelRegister(const string &name, HostKernelCreater &&fun) { + HostKernelFactory::Get().Registe(name, std::move(fun)); + } + ~_HostKernelRegister() = default; +}; + +#define _MS_HOST_REG_KERNEL_REG(KNAME, clazz) \ + static_assert(std::is_base_of::value, " must be base of HostKernelMod"); \ + static const _HostKernelRegister 
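HostKernelMod::Init above fills input_size_list_ and output_size_list_ with byte sizes: the product of each tensor's device-shape dims times the byte width of its device dtype, with LongMulWithOverflowCheck guarding every multiplication. A small Python equivalent, using a hypothetical shape/dtype pair:

import numpy as np

def tensor_nbytes(shape, dtype):
    # Byte size HostKernelMod::Init would record for one input/output tensor.
    size = 1
    for dim in shape:
        size *= int(dim)   # the C++ side uses LongMulWithOverflowCheck here
    return size * np.dtype(dtype).itemsize

print(tensor_nbytes([32, 128], np.float16))  # 8192 bytes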
g_##KNAME##_##_kernel_reg(#KNAME, []() { \ + std::shared_ptr ptr = nullptr; \ + ptr = std::make_shared(); \ + MS_EXCEPTION_IF_NULL(ptr); \ + return ptr; \ + }); + +#define MS_HOST_REG_KERNEL(KNAME, clazz) _MS_HOST_REG_KERNEL_REG(KNAME, clazz) +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_HOST_KERNEL_MOD_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/kash/kernel_pack.cc b/mindspore/ccsrc/backend/kernel_compiler/kash/kernel_pack.cc index 0ede79f699..1e3d7592d4 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/kash/kernel_pack.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/kash/kernel_pack.cc @@ -174,6 +174,9 @@ void KernelPack::ParseKernelJson(const nlohmann::json &js) { kernel_json_info_.block_dim = js["blockDim"]; kernel_json_info_.kernel_name = js["kernelName"]; kernel_json_info_.magic = js["magic"]; + if (js.contains("opParaSize")) { + kernel_json_info_.op_para_size = js["opParaSize"]; + } if (js.find("parameters") != js.end()) { if (!js.at("parameters").is_array()) { MS_LOG(DEBUG) << "Format error!,parameters should be array."; diff --git a/mindspore/ccsrc/backend/kernel_compiler/kernel.h b/mindspore/ccsrc/backend/kernel_compiler/kernel.h index 5b0bb87658..5bfbe4cd3f 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/kernel.h @@ -25,9 +25,18 @@ #include "ir/tensor.h" #include "abstract/dshape.h" #include "utils/log_adapter.h" +#include "runtime/device/executor/dynamic_kernel.h" namespace mindspore { -enum KernelType : int { UNKNOWN_KERNEL_TYPE = 0, AKG_KERNEL, AICPU_KERNEL, RT_KERNEL, HCCL_KERNEL, TBE_KERNEL }; +enum KernelType : int { + UNKNOWN_KERNEL_TYPE = 0, + AKG_KERNEL, + AICPU_KERNEL, + RT_KERNEL, + HCCL_KERNEL, + TBE_KERNEL, + HOST_KERNEL +}; namespace kernel { // Supported fusion type @@ -69,7 +78,8 @@ struct KernelJsonInfo { std::vector parameters; std::string sha256; std::vector workspaces; - KernelJsonInfo() : block_dim(0) {} + uint32_t op_para_size; + KernelJsonInfo() : block_dim(0), op_para_size(0) {} }; class KernelPack { @@ -118,6 +128,7 @@ class KernelMod { virtual const std::vector &GetWorkspaceSizeList() const = 0; virtual bool Launch(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs, void *stream_ptr) = 0; + virtual device::DynamicKernelPtr GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) { return nullptr; } virtual std::vector GenParameters() { return {}; } virtual void ReleaseResource() {} diff --git a/mindspore/ccsrc/backend/kernel_compiler/kernel_fusion.cc b/mindspore/ccsrc/backend/kernel_compiler/kernel_fusion.cc index 55d7468617..a1508ef08e 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/kernel_fusion.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/kernel_fusion.cc @@ -83,8 +83,8 @@ std::map KernelFusion(const std::vector while (!build_manger->IsAllTaskFinish()) { int task_id = -1; std::string task_result; - std::string pre_build_result; - auto ret = build_manger->WaitOne(&task_id, &task_result, &pre_build_result); + std::string build_result; + auto ret = build_manger->WaitOne(&task_id, &task_result, &build_result); if (!ret) { MS_EXCEPTION(ArgumentError) << "Build Failed. 
wait one ret:" << ret << ", task id:" << task_id; } @@ -94,7 +94,7 @@ std::map KernelFusion(const std::vector << " change to single op build."; build_failed_num++; } - auto kernel_mod_item = build_manger->TaskFinishProcess(task_id, false); + auto kernel_mod_item = build_manger->TaskFinishProcess(task_id, build_result, false); if (kernel_mod_item.second != nullptr) { (void)kernel_mod_ret.emplace(kernel_mod_item); } diff --git a/mindspore/ccsrc/backend/kernel_compiler/kernel_query.cc b/mindspore/ccsrc/backend/kernel_compiler/kernel_query.cc index 4e32121cbb..704f47f4a8 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/kernel_query.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/kernel_query.cc @@ -18,6 +18,7 @@ #include #include #include "backend/kernel_compiler/aicpu/aicpu_kernel_metadata.h" +#include "backend/kernel_compiler/host/host_kernel_metadata.h" #include "backend/kernel_compiler/rts/rt_kernel_info.h" #include "backend/kernel_compiler/hccl/hccl_kernel_metadata.h" #include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.h" @@ -86,6 +87,9 @@ void KernelQueryAll(const CNodePtr &kernel_node, if (kernel_info_list->empty()) { HcclMetadataInfo(kernel_node, kernel_info_list); } + if (kernel_info_list->empty()) { + HostMetadataInfo(kernel_node, kernel_info_list); + } if (kernel_info_list->empty()) { MS_EXCEPTION(NotExistsError) << "Failed to obtain operator info, Please check whether the operator info is registered, Op full name:" diff --git a/mindspore/ccsrc/backend/kernel_compiler/oplib/opinfo.h b/mindspore/ccsrc/backend/kernel_compiler/oplib/opinfo.h index cf566a7d16..1a1918769f 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/oplib/opinfo.h +++ b/mindspore/ccsrc/backend/kernel_compiler/oplib/opinfo.h @@ -102,6 +102,7 @@ class OpInfo { kernel_name_ = opinfo.kernel_name(); partial_flag_ = opinfo.partial_flag_; dynamic_format_ = opinfo.dynamic_format_; + dynamic_shape_ = opinfo.dynamic_shape_; op_pattern_ = opinfo.op_pattern(); processor_ = opinfo.processor_; for (const auto &attr : opinfo.attrs_ptr()) { @@ -122,12 +123,14 @@ class OpInfo { std::string fusion_type() const { return fusion_type_; } std::string kernel_name() const { return kernel_name_; } OpPattern op_pattern() const { return op_pattern_; } + bool dynamic_shape() const { return dynamic_shape_; } std::string processor() const { return processor_; } std::vector> attrs_ptr() const { return attrs_ptr_; } std::vector> inputs_ptr() const { return inputs_ptr_; } std::vector> outputs_ptr() const { return outputs_ptr_; } const std::unordered_map &ref_infos() const { return ref_infos_; } + void set_dynamic_shape(bool dynamic_shape) { dynamic_shape_ = dynamic_shape; } void set_op_name(const std::string &op_name) { op_name_ = op_name; } void set_imply_type(const OpImplyType imply_type) { imply_type_ = imply_type; } void set_impl_path(const std::string &impl_path) { impl_path_ = impl_path; } @@ -149,7 +152,8 @@ class OpInfo { void ClearOutputs() { (void)outputs_ptr_.clear(); } bool equals_to(const std::shared_ptr &other_info) const { return this->op_name_ == other_info->op_name_ && this->imply_type_ == other_info->imply_type_ && - this->processor_ == other_info->processor_; + this->processor_ == other_info->processor_ && this->op_pattern_ == other_info->op_pattern_ && + this->dynamic_shape_ == other_info->dynamic_shape_; } private: @@ -163,6 +167,7 @@ class OpInfo { std::string kernel_name_; bool partial_flag_ = false; bool dynamic_format_ = false; + bool dynamic_shape_ = false; OpPattern op_pattern_ = 
kCommonPattern; std::string processor_; std::vector> attrs_ptr_; diff --git a/mindspore/ccsrc/backend/kernel_compiler/oplib/oplib.cc b/mindspore/ccsrc/backend/kernel_compiler/oplib/oplib.cc index 69209ed985..d5cecf79ed 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/oplib/oplib.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/oplib/oplib.cc @@ -38,6 +38,7 @@ constexpr auto kDynamicFormat = "dynamicFormat"; constexpr auto kFormatAgnostic = "formatAgnostic"; constexpr auto kBroadcast = "broadcast"; constexpr auto kReduce = "reduce"; +constexpr auto kDynamicShape = "dynamic_shape"; constexpr auto kDtypeFormat = "dtype_format"; constexpr auto kAttr = "attr"; constexpr auto kIputs = "inputs"; @@ -111,6 +112,10 @@ void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_p op_info->set_kernel_name(obj.at(kKernelName)); op_info->set_partial_flag(obj.at(kPartialFlag)); + if (obj.find(kDynamicShape) != obj.end()) { + op_info->set_dynamic_shape(obj.at(kDynamicShape)); + } + if (obj.find(kOpPattern) != obj.end()) { std::string op_pattern = obj.at(kOpPattern); auto find_iter = kOpPatternMap.find(op_pattern); @@ -322,7 +327,7 @@ bool OpLib::DecodeInputOutput(const nlohmann::json &obj, const OpImplyType imply return ret; } -std::shared_ptr OpLib::FindOp(const std::string &op_name, OpImplyType imply_type) { +std::shared_ptr OpLib::FindOp(const std::string &op_name, OpImplyType imply_type, bool is_dynamic_shape) { if (!OpLib::RegOpFromLocalInfo()) { MS_LOG(INFO) << "Warning reg local op info failed."; } @@ -338,16 +343,20 @@ std::shared_ptr OpLib::FindOp(const std::string &op_name, OpImplyType im for (auto [iter, end] = op_info_.equal_range(op_name); iter != end; ++iter) { auto &op_info = iter->second; MS_EXCEPTION_IF_NULL(op_info); + if (op_info->imply_type() != imply_type) { continue; } if (imply_type == kAKG && op_info->processor() != target_processor) { continue; } + if (is_dynamic_shape && !op_info->dynamic_shape()) { + continue; + } return op_info; } MS_LOG(INFO) << "FindOp failed: opname: " << op_name << ", imply_type: " << ImplTypeToStr(imply_type) - << ", current op num: " << op_info_.size(); + << ", current op num: " << op_info_.size() << " is_dynamic_shape:" << is_dynamic_shape; return nullptr; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/oplib/oplib.h b/mindspore/ccsrc/backend/kernel_compiler/oplib/oplib.h index 2dfa0ea772..90137be6f0 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/oplib/oplib.h +++ b/mindspore/ccsrc/backend/kernel_compiler/oplib/oplib.h @@ -32,7 +32,8 @@ class OpLib { virtual ~OpLib() = default; static bool RegOp(const std::string &json_string, const std::string &impl_path); static void RegOpInfo(const std::shared_ptr &opinfo) { op_info_.emplace(opinfo->op_name(), opinfo); } - static std::shared_ptr FindOp(const std::string &op_name, OpImplyType imply_type); + static std::shared_ptr FindOp(const std::string &op_name, OpImplyType imply_type, + bool is_dynamic_shape = false); static const std::multimap> &GetAllOpsInfo() { return op_info_; } protected: diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/memcpy_async.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/memcpy_async.cc index fa175b3805..699a1d61a8 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/rts/memcpy_async.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/memcpy_async.cc @@ -21,9 +21,14 @@ #include "backend/session/anf_runtime_algorithm.h" #include "common/trans.h" #include "utils/ms_context.h" +#include "runtime/device/kernel_runtime.h" +#include 
"runtime/device/ascend/executor/rts/memcpy_rts_dynamic_kernel.h" using ge::model_runner::MemcpyAsyncTaskInfo; using MemcpyAsyncTaskInfoPtr = std::shared_ptr; +using AddressPtrList = std::vector; +using mindspore::device::ascend::MemcpyRtsDynamicKernel; +using MemcpyRtsDynamicKernelPtr = std::shared_ptr; namespace mindspore { namespace kernel { @@ -122,6 +127,32 @@ std::vector MemCpyAsyncKernel::GenTask(const std::vectorsize < kernel_inputs[0]->size) { + MS_LOG(EXCEPTION) << "Check rtMemcpyAsync destMax < src size"; + } + // input x -> memcpy_async -> AllReduce + if (kernel_outputs[0]->size > kernel_inputs[0]->size) { + MS_LOG(WARNING) << "Check rtMemcpyAsync destMax > src size"; + } + + return std::make_shared(stream_ptr, cnode_ptr, kernel_outputs[0]->addr, + kernel_outputs[0]->size, kernel_inputs[0]->addr, + kernel_inputs[0]->size); +} const std::vector data_type_list{kNumberTypeInt, kNumberTypeInt8, kNumberTypeInt16, kNumberTypeInt32, kNumberTypeInt64, kNumberTypeUInt, kNumberTypeUInt8, kNumberTypeUInt16, diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/memcpy_async.h b/mindspore/ccsrc/backend/kernel_compiler/rts/memcpy_async.h index 4e66a212b2..30a8b6be44 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/rts/memcpy_async.h +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/memcpy_async.h @@ -34,6 +34,7 @@ class MemCpyAsyncKernel : public RtKernel { const std::vector &outputs, void *stream_ptr) override; std::vector GenTask(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs, uint32_t stream_id) override; + device::DynamicKernelPtr GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) override; private: void GetInputOutputDataType(const AnfNodePtr &anf_node); diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/profiling_kernel_mod.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/profiling_kernel_mod.cc index e9548481e6..cbbfba380c 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/rts/profiling_kernel_mod.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/profiling_kernel_mod.cc @@ -21,8 +21,10 @@ #include "framework/ge_runtime/task_info.h" #include "runtime/device/ascend/profiling/profiling_utils.h" #include "backend/session/anf_runtime_algorithm.h" +#include "runtime/device/ascend/executor/rts/profiling_rts_dynamic_kernel.h" using ProfilerTraceTaskInfo = ge::model_runner::ProfilerTraceTaskInfo; +using mindspore::device::ascend::ProfilingRtsDynamicKernel; using mindspore::device::ascend::ProfilingUtils; namespace mindspore { @@ -64,5 +66,9 @@ std::vector ProfilingKernelMod::GenTask(const std::vector(kernel_name_, stream_id, log_id_, notify_, flags_); return {task_info_ptr}; } + +device::DynamicKernelPtr ProfilingKernelMod::GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) { + return std::make_shared(stream_ptr, cnode_ptr, log_id_, notify_, flags_); +} } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/profiling_kernel_mod.h b/mindspore/ccsrc/backend/kernel_compiler/rts/profiling_kernel_mod.h index 239cf8e222..c69c6eaed8 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/rts/profiling_kernel_mod.h +++ b/mindspore/ccsrc/backend/kernel_compiler/rts/profiling_kernel_mod.h @@ -27,6 +27,7 @@ class ProfilingKernelMod : public RtKernel { const std::vector &outputs, void *stream_ptr) override; std::vector GenTask(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs, uint32_t stream_id) override; + device::DynamicKernelPtr 
GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) override; bool Init(const AnfNodePtr &anf_node) override; private: diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_adapter.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_adapter.cc index f8b4653988..3b7f0c91b4 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_adapter.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_adapter.cc @@ -29,157 +29,6 @@ namespace mindspore { namespace kernel { namespace tbe { -static std::map tbe_func_adapter_map = { - {"softmax", "softmax_v2"}, - {"log_softmax", "log_softmax_v2"}, - {"apply_momentum", "apply_momentum_d"}, - {"apply_ftrl", "apply_ftrl_d"}, - {"re_lu6", "relu6"}, - {"re_lu6_grad", "relu6_grad"}, - {"re_lu", "relu"}, - {"reverse_v2", "reverse_v2_d"}, - {"re_luv2", "relu_v2"}, - {"p_re_lu", "prelu"}, - {"p_re_lu_grad", "prelu_grad"}, - {"tensor_add", "add"}, - {"reduce_mean", "reduce_mean_d"}, - {"reduce_max", "reduce_max_d"}, - {"reduce_min", "reduce_min_d"}, - {"avg_pool_grad", "avg_pool_grad_d"}, - {"avg_pool_grad_vm", "avg_pool_grad_d"}, - {"conv2d_backprop_filter", "conv2d_backprop_filter_d"}, - {"conv2d_backprop_input", "conv2d_backprop_input_d"}, - {"depthwise_conv2d_native", "depthwise_conv2d"}, - {"depthwise_conv2d_native_backprop_filter", "depthwise_conv2d_backprop_filter_d"}, - {"depthwise_conv2d_native_backprop_input", "depthwise_conv2d_backprop_input_d"}, - {"scatter_nd", "scatter_nd_d"}, - {"tile", "tile_d"}, - {"gather_v2", "gather_v2_d"}, - {"sparse_gather_v2", "gather_v2_d"}, - {"batch_mat_mul", "batch_matmul"}, - {"b_n_training_reduce", "bn_training_reduce"}, - {"b_n_training_update", "bn_training_update"}, - {"b_n_training_update_v2", "bn_training_update_v2"}, - {"b_n_training_update_v3", "bn_training_update_v3"}, - {"b_n_training_reduce_grad", "bn_training_reduce_grad"}, - {"b_n_training_update_grad", "bn_training_update_grad"}, - {"b_n_infer", "bn_infer"}, - {"b_n_infer_grad", "bn_infer_grad"}, - {"b_n_inference", "bninference_d"}, - {"n_pu_clear_float_status", "n_p_u_clear_float_status"}, - {"n_pu_get_float_status", "n_p_u_get_float_status"}, - {"n_pu_alloc_float_status", "n_p_u_alloc_float_status"}, - {"dropout_do_mask", "drop_out_do_mask"}, - {"strided_slice", "strided_slice_d"}, - {"strided_slice_grad", "strided_slice_grad_d"}, - {"sparse_apply_ftrl", "sparse_apply_ftrl_d"}, - {"sparse_apply_ftrl_v2", "sparse_apply_ftrl_v2_d"}, - {"apply_ada_max", "apply_ada_max_d"}, - {"apply_adadelta", "apply_adadelta_d"}, - {"apply_adagrad", "apply_adagrad_d"}, - {"apply_adagrad_v2", "apply_adagradv2_d"}, - {"sparse_apply_adagrad", "sparse_apply_adagrad_d"}, - {"sparse_apply_adagrad_v2", "sparse_apply_adagrad_v2_d"}, - {"apply_proximal_adagrad", "apply_proximal_adagrad_d"}, - {"sparse_apply_proximal_adagrad", "sparse_apply_proximal_adagrad_d"}, - {"apply_add_sign", "apply_add_sign_d"}, - {"apply_power_sign", "apply_power_sign_d"}, - {"apply_centered_rms_prop", "apply_centered_rms_prop_d"}, - {"transpose", "transpose_d"}, - {"fill", "fill_d"}, - {"unsorted_segment_sum", "unsorted_segment_sum_d"}, - {"unsorted_segment_prod", "unsorted_segment_prod_d"}, - {"concat", "concat_d"}, - {"slice", "slice_d"}, - {"reduce_sum", "reduce_sum_d"}, - {"inplace_add", "inplace_add_d"}, - {"inplace_sub", "inplace_sub_d"}, - {"one_hot", "one_hot_d"}, - {"sum", "reduce_sum_d"}, - {"lamb_next_mv_with_decay", "lamb_next_m_v_with_decay"}, - {"lamb_next_mv", "lamb_next_m_v"}, - {"split", "split_d"}, - {"split_v", "split_v_d"}, - {"resize_nearest_neighbor", 
"resize_nearest_neighbor_v2_d"}, - {"resize_nearest_neighbor_grad", "resize_nearest_neighbor_v2_grad_d"}, - {"pad", "pad_d"}, - {"argmax", "arg_max_d"}, - {"argmin", "arg_min_d"}, - {"space_to_batch", "space_to_batch_d"}, - {"batch_to_space", "batch_to_space_d"}, - {"space_to_batch_nd", "space_to_batch_nd_d"}, - {"batch_to_space_nd", "batch_to_space_nd_d"}, - {"resize_bilinear", "resize_bilinear_v2_d"}, - {"resize_bilinear_grad", "resize_bilinear_v2_grad"}, - {"adam", "apply_adam_d"}, - {"r_oi_align", "roi_align"}, - {"r_oi_align_grad", "roi_align_grad"}, - {"i_ou", "iou"}, - {"s_gd", "sgd"}, - {"l_rn", "lrn"}, - {"l_rn_grad", "lrn_grad"}, - {"l_ars_update", "lars_v2_update"}, - {"n_ms_with_mask", "nms_with_mask"}, - {"square_sum_all", "square_sum_all"}, - {"cum_sum", "cumsum_d"}, - {"range", "range_d"}, - {"lin_space", "lin_space_d"}, - {"inv_grad", "inv_grad"}, - {"apply_rms_prop", "apply_rms_prop_d"}, - {"cum_prod", "cumprod_d"}, - {"reduce_all", "reduce_all_d"}, - {"reduce_any", "reduce_any_d"}, - {"sparse_apply_adagrad", "sparse_apply_adagrad_d"}, - {"unsorted_segment_min", "unsorted_segment_min_d"}, - {"reduce_prod", "reduce_prod_d"}, - {"a_cos", "acos"}, - {"a_cos_grad", "acos_grad"}, - {"histogram_fixed_width", "histogram_fixed_width_d"}, - {"broadcast_to", "broadcast_to_d"}, - {"inplace_update", "inplace_update_d"}, - {"i_fmr", "ifmr"}, - {"matrix_diag", "matrix_diag_d"}, - {"matrix_diag_part", "matrix_diag_part_d"}, - {"matrix_set_diag", "matrix_set_diag_d"}, - {"l_stm_input_grad", "lstm_input_grad"}}; - -void TbeAdapter::NormalizeFuncName(std::string *func_name) { - if (func_name == nullptr) { - MS_LOG(EXCEPTION) << "func_name is null"; - } - std::string name_tmp; - bool sub_head = false; - for (string::iterator iter = func_name->begin(); iter != func_name->end(); ++iter) { - if (islower(*iter)) { - sub_head = false; - } - if (isdigit(*iter)) { - sub_head = true; - } - if (isupper(*iter) && iter != func_name->begin()) { - if (!sub_head) { - (void)name_tmp.insert(name_tmp.end(), '_'); - sub_head = true; - } else { - string::iterator iter_next = iter + 1; - if (iter_next != func_name->end()) { - if (islower(*iter_next)) { - (void)name_tmp.insert(name_tmp.end(), '_'); - } - } - } - } - (void)name_tmp.insert(name_tmp.end(), *iter); - } - (void)transform(name_tmp.begin(), name_tmp.end(), name_tmp.begin(), ::tolower); - *func_name = name_tmp; - auto iter = tbe_func_adapter_map.find(*func_name); - if (iter != tbe_func_adapter_map.end()) { - MS_LOG(INFO) << "Map actual op from me: " << *func_name << " to tbe op: " << iter->second; - *func_name = iter->second; - } -} - std::unordered_set input_order_adjusted_ops = { "Conv2DBackpropInput", "Conv2DBackpropFilter", "LogSoftmaxGrad", "LayerNormGrad", "LayerNormXBackprop", "LayerNormBetaGammaBackprop", "MinimumGrad", "MaximumGrad", "ApplyCenteredRMSProp"}; diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_adapter.h b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_adapter.h index 027b8e4b88..e6dddfb974 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_adapter.h +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_adapter.h @@ -35,7 +35,6 @@ class TbeAdapter { public: TbeAdapter() = default; ~TbeAdapter() = default; - static void NormalizeFuncName(std::string *func_name); static void InputOrderPass(const std::string &op_name, std::vector> const &inputs_list, nlohmann::json *inputs_json); diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.cc 
b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.cc new file mode 100644 index 0000000000..e90a17b414 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.cc @@ -0,0 +1,139 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h" +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace kernel { +namespace tbe { + +bool TbeDynamicShapeUtil::IsDynamicShapeNode(const CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(cnode); + auto input_num = AnfAlgo ::GetInputTensorNum(cnode); + for (size_t i = 0; i < input_num; ++i) { + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, i); + if (std::any_of(input_shape.begin(), input_shape.end(), [](const size_t &dim) { return dim < 0; })) { + MS_LOG(INFO) << "Node(" << cnode->fullname_with_scope() << ") is dynamic shape node."; + return true; + } + } + auto output_num = AnfAlgo ::GetOutputTensorNum(cnode); + for (size_t i = 0; i < output_num; ++i) { + auto output_shape = AnfAlgo::GetOutputInferShape(cnode, i); + if (std::any_of(output_shape.begin(), output_shape.end(), [](const size_t &dim) { return dim < 0; })) { + MS_LOG(INFO) << "Node(" << cnode->fullname_with_scope() << ") is dynamic shape node."; + return true; + } + } + return false; +} + +bool TbeDynamicShapeUtil::IsDynamicShapeNode(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + if (anf_node->isa()) { + auto cnode = anf_node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + return IsDynamicShapeNode(cnode); + } + return false; +} + +void TbeDynamicShapeUtil::SetDynamicShapeAttr(const CNodePtr &cnode) { + auto is_dyanmic_shape = IsDynamicShapeNode(cnode); + AnfAlgo::SetNodeAttr(kAttrIsDynamicShape, MakeValue(is_dyanmic_shape), cnode); +} + +bool TbeDynamicShapeUtil::GetDynamicShapeAttr(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + if (anf_node->isa()) { + auto cnode = anf_node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + return GetDynamicShapeAttr(cnode); + } + return false; +} + +bool TbeDynamicShapeUtil::GetDynamicShapeAttr(const CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(cnode); + auto is_dynamic_shape = AnfAlgo::HasNodeAttr(kAttrIsDynamicShape, cnode); + if (!is_dynamic_shape) { + MS_LOG(INFO) << "Node(" << cnode->fullname_with_scope() << ") does not has is_dynamic_shape attribute."; + return false; + } + is_dynamic_shape = AnfAlgo::GetNodeAttr(cnode, kAttrIsDynamicShape); + return is_dynamic_shape; +} + +std::shared_ptr TbeDynamicShapeUtil::FindOp(const std::string &op_name, const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + if (anf_node->isa()) { + auto cnode = anf_node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + return FindOp(op_name, cnode); + } + return nullptr; +} + +std::shared_ptr TbeDynamicShapeUtil::FindOp(const std::string &op_name, const CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(cnode); + 
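IsDynamicShapeNode above reduces to one question: does any inferred input or output dimension carry the "unknown" marker? A standalone sketch of that check, using signed dimensions so the negative marker is representable; the shapes below are made-up examples, not values from the patch.

// --- illustrative sketch (not part of the patch): detecting a dynamic-shape node ---
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

bool IsDynamicShape(const std::vector<int64_t> &shape) {
  return std::any_of(shape.begin(), shape.end(), [](int64_t dim) { return dim < 0; });
}

int main() {
  std::vector<int64_t> static_shape = {32, 128};
  std::vector<int64_t> dynamic_shape = {-1, 128};  // first dim unknown until runtime
  std::cout << IsDynamicShape(static_shape) << " " << IsDynamicShape(dynamic_shape) << "\n";  // 0 1
  return 0;
}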
auto is_dynamic_shape = GetDynamicShapeAttr(cnode); + return mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kTBE, is_dynamic_shape); +} + +std::vector> TbeDynamicShapeUtil::GetInputDynamicRange(const AnfNodePtr &anf_node, size_t index) { + MS_EXCEPTION_IF_NULL(anf_node); + auto input_range_min = AnfAlgo::GetInputMinShape(anf_node, index); + auto input_range_max = AnfAlgo::GetInputMaxShape(anf_node, index); + if (input_range_min.size() != input_range_max.size()) { + MS_EXCEPTION(ArgumentError) << "Input range size is not equal, min size: " << input_range_min.size() + << "max size: " << input_range_max.size(); + } + if (input_range_min.empty() && input_range_max.empty()) { + return {{1, 1}}; + } + std::vector> ret; + for (size_t i = 0; i < input_range_min.size(); ++i) { + ret.emplace_back(input_range_min[i], input_range_max[i]); + } + return ret; +} + +std::vector> TbeDynamicShapeUtil::GetOutputDynamicRange(const AnfNodePtr &anf_node, size_t index) { + MS_EXCEPTION_IF_NULL(anf_node); + auto output_range_min = AnfAlgo::GetOutputMinShape(anf_node, index); + auto output_range_max = AnfAlgo::GetOutputMaxShape(anf_node, index); + if (output_range_min.size() != output_range_max.size()) { + MS_EXCEPTION(ArgumentError) << "Onput range size is not equal, min size: " << output_range_min.size() + << "max size: " << output_range_max.size(); + } + if (output_range_max.empty() && output_range_min.empty()) { + return {{1, 1}}; + } + std::vector> ret; + for (size_t i = 0; i < output_range_min.size(); ++i) { + ret.emplace_back(output_range_min[i], output_range_max[i]); + } + return ret; +} + +} // namespace tbe +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h new file mode 100644 index 0000000000..c37846796e --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
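GetInputDynamicRange and GetOutputDynamicRange pair the per-dimension min and max shapes into [min, max] ranges for the TBE kernel json, falling back to {{1, 1}} when both are empty. A minimal sketch of that pairing under the same rank-mismatch check, with hypothetical shape values.

// --- illustrative sketch (not part of the patch): pairing min/max shapes into a range ---
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <utility>
#include <vector>

std::vector<std::pair<int64_t, int64_t>> BuildDynamicRange(const std::vector<int64_t> &min_shape,
                                                           const std::vector<int64_t> &max_shape) {
  if (min_shape.size() != max_shape.size()) {
    throw std::invalid_argument("min/max shape rank mismatch");
  }
  if (min_shape.empty()) {
    return {{1, 1}};  // same default the patch uses when no range is recorded
  }
  std::vector<std::pair<int64_t, int64_t>> range;
  for (size_t i = 0; i < min_shape.size(); ++i) {
    range.emplace_back(min_shape[i], max_shape[i]);
  }
  return range;
}

int main() {
  for (const auto &r : BuildDynamicRange({1, 128}, {64, 128})) {
    std::cout << "[" << r.first << ", " << r.second << "] ";
  }
  std::cout << "\n";  // [1, 64] [128, 128]
  return 0;
}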
+ */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_TBE_DYNAMINC_SHAPE_UTIL_H +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_TBE_DYNAMINC_SHAPE_UTIL_H + +#include +#include +#include +#include +#include "mindspore/core/ir/anf.h" +#include "backend/kernel_compiler/oplib/oplib.h" +namespace mindspore { +namespace kernel { +namespace tbe { + +class TbeDynamicShapeUtil { + public: + TbeDynamicShapeUtil() = default; + ~TbeDynamicShapeUtil() = default; + static bool IsDynamicShapeNode(const CNodePtr &cnode); + static bool IsDynamicShapeNode(const AnfNodePtr &anf_node); + static void SetDynamicShapeAttr(const CNodePtr &cnode); + static bool GetDynamicShapeAttr(const CNodePtr &cnode); + static bool GetDynamicShapeAttr(const AnfNodePtr &anf_node); + static std::shared_ptr FindOp(const std::string &op_name, const AnfNodePtr &anf_node); + static std::shared_ptr FindOp(const std::string &op_name, const CNodePtr &cnode); + static std::vector> GetInputDynamicRange(const AnfNodePtr &anf_node, size_t index); + static std::vector> GetOutputDynamicRange(const AnfNodePtr &anf_node, size_t index); +}; + +} // namespace tbe +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_TBE_DYNAMINC_SHAPE_UTIL_H diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc index 84d36c4de4..c6ec50125e 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc @@ -23,6 +23,7 @@ #include "backend/session/anf_runtime_algorithm.h" #include "backend/kernel_compiler/tbe/tbe_adapter.h" #include "backend/kernel_compiler/tbe/tbe_convert_utils.h" +#include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h" #include "backend/kernel_compiler/tbe/tbe_utils.h" #include "utils/ms_context.h" #include "runtime/dev.h" @@ -61,6 +62,7 @@ constexpr auto kJDataType = "data_type"; constexpr auto kJOutputIndex = "output_index"; constexpr auto kJOutputDesc = "output_desc"; constexpr auto kJInputDesc = "input_desc"; +constexpr auto kJRange = "range"; constexpr auto kVTypeInt = "int"; constexpr auto kVTypeStr = "str"; constexpr auto kVTypeBool = "bool"; @@ -89,24 +91,21 @@ constexpr auto kJKwdArgs = "kwds_args"; constexpr auto kJListArgs = "list_args"; constexpr auto kJSocVersion = "socVersion"; constexpr auto kSOC_VERSION = "SOC_VERSION"; +constexpr auto kJIsDynamicShape = "is_dynamic_shape"; bool TbeKernelJsonCreator::GenTbeSingleKernelJson(const std::shared_ptr &anf_node, nlohmann::json *kernel_json) { MS_EXCEPTION_IF_NULL(anf_node); MS_EXCEPTION_IF_NULL(kernel_json); std::string op_name = AnfAlgo::GetCNodeName(anf_node); - auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kTBE); + auto op_info_ptr = mindspore::kernel::tbe::TbeDynamicShapeUtil::FindOp(op_name, anf_node); MS_EXCEPTION_IF_NULL(op_info_ptr); (*kernel_json)[kPlatform] = kPlatTBE; (*kernel_json)[kGenModel] = kSingle; (*kernel_json)[kImplPath] = op_info_ptr->impl_path(); nlohmann::json op_info_json; - if (op_info_ptr->impl_path().empty()) { - tbe::TbeAdapter::NormalizeFuncName(&op_name); - } else { - op_name = op_info_ptr->kernel_name(); - } - op_info_json[kJName] = op_name; + op_info_json[kJIsDynamicShape] = tbe::TbeDynamicShapeUtil::GetDynamicShapeAttr(anf_node->cast()); + op_info_json[kJName] = op_info_ptr->kernel_name(); // generate inputs json nlohmann::json inputs_json; if (!GenTbeInputsJson(anf_node, op_info_ptr, 
&inputs_json)) { @@ -180,6 +179,7 @@ bool TbeKernelJsonCreator::GenInputDescJson(const std::shared_ptr &anf_ input_desc_json[kJFormat] = format; input_desc_json[kJValid] = value; input_desc_json[kJParamType] = input_ptr->param_type(); + input_desc_json[kJRange] = tbe::TbeDynamicShapeUtil::GetInputDynamicRange(anf_node, real_input_index); input_list->emplace_back(input_desc_json); } return true; @@ -359,8 +359,13 @@ void TbeKernelJsonCreator::GenOutputList(const std::shared_ptr &anf_nod for (size_t i = 0; i < output_obj_num; i++) { auto dtype = GetDeviceOutputType(anf_node, *output_idx); auto format = GetDeviceOutputFormat(anf_node, *output_idx); - auto shape = GetDeviceOutputShape(anf_node, *output_idx); - std::vector ori_shape = AnfAlgo::GetOutputInferShape(anf_node, *output_idx); + + std::vector shape; + AnfAlgo::GetRealDynamicShape(GetDeviceOutputShape(anf_node, *output_idx), NOT_NULL(&shape)); + + std::vector ori_shape; + AnfAlgo::GetRealDynamicShape(AnfAlgo::GetOutputInferShape(anf_node, *output_idx), NOT_NULL(&ori_shape)); + // std::vector ori_shape = AnfAlgo::GetOutputInferShape(anf_node, *output_idx); if (ori_shape.empty()) { ori_shape.emplace_back(1); } @@ -373,6 +378,7 @@ void TbeKernelJsonCreator::GenOutputList(const std::shared_ptr &anf_nod output_obj[kJName] = output_ptr->name(); output_obj[kJValid] = true; output_obj[kJParamType] = output_ptr->param_type(); + output_obj[kJRange] = tbe::TbeDynamicShapeUtil::GetOutputDynamicRange(anf_node, *output_idx); output_list->emplace_back(output_obj); (*output_idx)++; } @@ -575,48 +581,76 @@ std::string TbeKernelJsonCreator::GetDeviceOutputFormat(const AnfNodePtr &anf_no return format; } -bool TbeKernelBuild::GetIOSize(const nlohmann::json &kernel_json, std::vector *input_size_list, - std::vector *output_size_list) { - if (input_size_list == nullptr || output_size_list == nullptr) { - MS_LOG(ERROR) << "Input size or output size is nullptr"; - return false; - } - input_size_list->clear(); - output_size_list->clear(); - for (size_t i = 0; i < kernel_json[kJOpInfo][kJInputs].size(); i++) { - for (size_t m = 0; m < kernel_json[kJOpInfo][kJInputs][i].size(); m++) { +void GetInputSizeList(const nlohmann::json &input_json, std::vector *input_size_list, + const AnfNodePtr &anf_node) { + for (size_t i = 0; i < input_json.size(); i++) { + for (size_t m = 0; m < input_json[i].size(); m++) { size_t size_i = 1; - if (kernel_json[kJOpInfo][kJInputs][i][m][kJValid] == false) { - std::string input_name = kernel_json[kJOpInfo][kJInputs][i][m][kJName]; + if (input_json[i][m][kJValid] == false) { + std::string input_name = input_json[i][m][kJName]; MS_LOG(INFO) << "Input name:" << input_name << "is optional, valid is false."; continue; } - for (const auto &j : kernel_json[kJOpInfo][kJInputs][i][m][kJShape]) { - size_i *= static_cast(j); + for (size_t j = 0; j < input_json[i][m][kJShape].size(); ++j) { + if (input_json[i][m][kJShape][j] == -1) { + auto input_max_shape = AnfAlgo::GetInputMaxShape(anf_node, i); + if (j >= input_max_shape.size()) { + MS_LOG(EXCEPTION) << "Invalid Dynamic Shape Max Shape"; + } + MS_LOG(INFO) << "Change -1 Shape to Max Shape:" << input_max_shape[j]; + size_i *= input_max_shape[j]; + continue; + } + size_i *= static_cast(input_json[i][m][kJShape][j]); } - std::string dtype = kernel_json[kJOpInfo][kJInputs][i][m][kJDtype]; + std::string dtype = input_json[i][m][kJDtype]; size_t nbyte = tbe::GetDtypeNbyte(dtype); size_i *= nbyte; input_size_list->push_back(size_i); } } - for (size_t i = 0; i < 
kernel_json[kJOpInfo][kJOutputs].size(); i++) { - for (size_t m = 0; m < kernel_json[kJOpInfo][kJOutputs][i].size(); m++) { +} + +void GetOutputSizeList(const nlohmann::json &output_json, std::vector *output_size_list, + const AnfNodePtr &anf_node) { + for (size_t i = 0; i < output_json.size(); i++) { + for (size_t m = 0; m < output_json[i].size(); m++) { size_t size_i = 1; - if (kernel_json[kJOpInfo][kJOutputs][i][m][kJValid] == false) { - std::string output_name = kernel_json[kJOpInfo][kJOutputs][i][m][kJName]; + if (output_json[i][m][kJValid] == false) { + std::string output_name = output_json[i][m][kJName]; MS_LOG(INFO) << "Output name:" << output_name << " is optional, valid is false."; continue; } - for (const auto &j : kernel_json[kJOpInfo][kJOutputs][i][m][kJShape]) { - size_i *= static_cast(j); + for (size_t j = 0; j < output_json[i][m][kJShape].size(); ++j) { + if (output_json[i][m][kJShape][j] == -1) { + auto output_max_shape = AnfAlgo::GetOutputMaxShape(anf_node, i); + if (j >= output_max_shape.size()) { + MS_LOG(EXCEPTION) << "Invalid Dynamic Shape Max Shape"; + } + MS_LOG(INFO) << "Change -1 Shape to Max Shape:" << output_max_shape[j]; + size_i *= output_max_shape[j]; + continue; + } + size_i *= static_cast(output_json[i][m][kJShape][j]); } - std::string dtype = kernel_json[kJOpInfo][kJOutputs][i][m][kJDtype]; + std::string dtype = output_json[i][m][kJDtype]; size_t nbyte = tbe::GetDtypeNbyte(dtype); size_i *= nbyte; output_size_list->push_back(size_i); } } +} + +bool TbeKernelBuild::GetIOSize(const nlohmann::json &kernel_json, std::vector *input_size_list, + std::vector *output_size_list, const AnfNodePtr &anf_node) { + if (input_size_list == nullptr || output_size_list == nullptr) { + MS_LOG(ERROR) << "Input size or output size is nullptr"; + return false; + } + input_size_list->clear(); + output_size_list->clear(); + GetInputSizeList(kernel_json[kJOpInfo][kJInputs], input_size_list, anf_node); + GetOutputSizeList(kernel_json[kJOpInfo][kJOutputs], output_size_list, anf_node); return true; } @@ -678,17 +712,18 @@ void TbeKernelBuild::GenFusionComputeCommonJson(const mindspore::CNodePtr &cnode MS_EXCEPTION_IF_NULL(fusion_kernel_name); // gen others auto origin_type = AnfAlgo::GetCNodeName(cnode); + auto op_info_ptr = tbe::TbeDynamicShapeUtil::FindOp(origin_type, cnode); // replace special op type for buffer fusion op auto type = GetRealOpType(origin_type); (*compute_op_str)[kJtype] = type; - tbe::TbeAdapter::NormalizeFuncName(&type); - (*compute_op_str)[kJFuncName] = type; + auto kernel_name = op_info_ptr->kernel_name(); + (*compute_op_str)[kJFuncName] = kernel_name; (*compute_op_str)[kJModuleName] = std::string("impl.") + type; (*compute_op_str)[kJName] = cnode->fullname_with_scope(); (*compute_op_str)[kJPattern] = GetNodeFusionType(cnode); (*compute_op_str)[kJPyModulePath] = "/usr/local/Ascend/opp/op_impl/build_in/ai_core/tbe"; (void)(*fusion_kernel_name).append("_"); - (void)(*fusion_kernel_name).append(type); + (void)(*fusion_kernel_name).append(kernel_name); } void TbeKernelBuild::GenFusionComputePreBuildJson(const mindspore::CNodePtr &cnode, nlohmann::json *compute_op_str) { @@ -952,7 +987,7 @@ size_t TbeKernelBuild::GetOptionalInput(const mindspore::CNodePtr &cnode, bool i } MS_EXCEPTION_IF_NULL(cnode); auto node_name = AnfAlgo::GetCNodeName(cnode); - auto op_info = OpLib::FindOp(node_name, kTBE); + auto op_info = tbe::TbeDynamicShapeUtil::FindOp(node_name, cnode); MS_EXCEPTION_IF_NULL(cnode); if (op_info->inputs_ptr().size() < (cnode->inputs().size() - 1)) { 
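GetInputSizeList and GetOutputSizeList above size a dynamic tensor by swapping every -1 dimension for the corresponding max-shape value before multiplying in the dtype width, so the allocated buffer covers the worst case. A standalone sketch of that sizing rule; the dtype width and shapes are hypothetical inputs.

// --- illustrative sketch (not part of the patch): sizing a dynamic tensor by its max shape ---
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <vector>

size_t TensorByteSize(const std::vector<int64_t> &shape, const std::vector<int64_t> &max_shape,
                      size_t dtype_bytes) {
  size_t size = dtype_bytes;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (shape[i] == -1) {
      if (i >= max_shape.size()) {
        throw std::runtime_error("invalid dynamic-shape max shape");
      }
      size *= static_cast<size_t>(max_shape[i]);  // reserve for the worst case
    } else {
      size *= static_cast<size_t>(shape[i]);
    }
  }
  return size;
}

int main() {
  // float16 tensor, first dim dynamic with an upper bound of 64
  std::cout << TensorByteSize({-1, 128}, {64, 128}, 2) << " bytes\n";  // 16384 bytes
  return 0;
}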
MS_EXCEPTION(ArgumentError) << "op info error, node name:" << cnode->fullname_with_scope(); diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.h b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.h index 9c760bdcbb..c91ac7c666 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.h +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.h @@ -38,7 +38,7 @@ class TbeKernelBuild { public: static bool GetIOSize(const nlohmann::json &kernel_json, std::vector *input_size_list, - std::vector *output_size_list); + std::vector *output_size_list, const AnfNodePtr &anf_node); // Ub Fuison static bool GenFusionScopeJson(const std::vector &input_nodes, const std::vector &compute_nodes, nlohmann::json *fusion_json, diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_mod.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_mod.cc index 933fcb1566..a054f58583 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_mod.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_mod.cc @@ -19,11 +19,14 @@ #include "runtime/rt.h" #include "utils/ms_context.h" #include "graphengine/inc/framework/ge_runtime/task_info.h" +#include "runtime/device/ascend/executor/ai_core_dynamic_kernel.h" +#include "runtime/device/kernel_runtime.h" namespace mindspore { namespace kernel { using TbeTaskInfoPtr = std::shared_ptr; using tbe::KernelManager; +using AddressPtrList = std::vector; bool TbeKernelMod::Launch(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs, void *stream_ptr) { @@ -105,6 +108,49 @@ std::vector TbeKernelMod::GenTask(const std::vector &in return {task_info_ptr}; } +device::DynamicKernelPtr TbeKernelMod::GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) { + AddressPtrList kernel_inputs; + AddressPtrList kernel_workspaces; + AddressPtrList kernel_outputs; + device::KernelRuntime::GenLaunchArgs(*this, cnode_ptr, &kernel_inputs, &kernel_workspaces, &kernel_outputs); + + // Get para_size from json + auto kernel_json_info = kernel_pack_->kernel_json_info(); + auto op_para_size = kernel_json_info.op_para_size; + + // Get stub_function + uint32_t block_dim = 1; // default blockdim equal to 1. 
+ auto func_stub = KernelManager::GenFuncStub(*kernel_pack_, false, &block_dim); + if (func_stub == 0) { + MS_LOG(EXCEPTION) << "GenFuncStub failed."; + } + const void *stub_func_ptr = reinterpret_cast(func_stub); + + // Generate args + std::vector runtime_args; + (void)std::transform(std::begin(kernel_inputs), std::end(kernel_inputs), std::back_inserter(runtime_args), + [](const AddressPtr &input) -> void * { return input->addr; }); + (void)std::transform(std::begin(kernel_outputs), std::end(kernel_outputs), std::back_inserter(runtime_args), + [](const AddressPtr &output) -> void * { return output->addr; }); + if (!kernel_workspaces.empty()) { + (void)std::transform(std::begin(kernel_workspaces), std::end(kernel_workspaces), std::back_inserter(runtime_args), + [](const AddressPtr &addr) -> void * { return addr->addr; }); + } + + void *tiling_data_ptr = nullptr; + if (op_para_size > 0) { + auto ret = rtMalloc(&tiling_data_ptr, op_para_size, RT_MEMORY_HBM); + if (ret != RT_ERROR_NONE) { + MS_LOG(EXCEPTION) << "rtMalloc tiling data failed"; + } + runtime_args.push_back(tiling_data_ptr); + } + + auto executor = std::make_shared( + stub_func_ptr, block_dim, tiling_data_ptr, op_para_size, stream_ptr, cnode_ptr, runtime_args); + return executor; +} + vector TbeKernelMod::GenParameters() { auto kernel_json_info = kernel_pack_->kernel_json_info(); return kernel_json_info.parameters; diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_mod.h b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_mod.h index 70d17a02c6..25d1427150 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_mod.h +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_mod.h @@ -42,6 +42,7 @@ class TbeKernelMod : public AscendKernelMod { const std::vector &outputs, void *stream_ptr) override; std::vector GenTask(const std::vector &inputs, const std::vector &workspaces, const std::vector &outputs, uint32_t stream_id) override; + device::DynamicKernelPtr GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) override; std::vector GenParameters() override; private: diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_parallel_build.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_parallel_build.cc index 41f03b7f5d..04676a02c6 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_parallel_build.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_parallel_build.cc @@ -15,13 +15,11 @@ */ #include "backend/kernel_compiler/tbe/tbe_kernel_parallel_build.h" - #include #include #include #include #include - #include "utils/ms_context.h" #include "backend/kernel_compiler/tbe/tbe_adapter.h" #include "backend/kernel_compiler/tbe/tbe_kernel_build.h" @@ -29,6 +27,7 @@ #include "backend/session/anf_runtime_algorithm.h" #include "backend/kernel_compiler/tbe/tbe_convert_utils.h" #include "backend/kernel_compiler/tbe/tbe_utils.h" +#include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h" namespace mindspore { namespace kernel { @@ -52,15 +51,18 @@ bool TbeOpParallelBuild(const std::vector &anf_nodes) { // get size std::vector input_size_list; std::vector output_size_list; - (void)TbeKernelBuild::GetIOSize(kernel_json, &input_size_list, &output_size_list); + (void)TbeKernelBuild::GetIOSize(kernel_json, &input_size_list, &output_size_list, anf_node); // search cache const std::string &json_name = creator.json_name(); - if (build_manger->SearchInCache(json_name, processor, input_size_list, output_size_list, anf_node.get())) { - MS_LOG(INFO) << 
"Use cached kernel, kernel json name:." << json_name; + auto IsDynamicShape = tbe::TbeDynamicShapeUtil::GetDynamicShapeAttr(anf_node); + if (build_manger->SearchInCache(json_name, processor, input_size_list, output_size_list, anf_node.get()) && + !IsDynamicShape) { + MS_LOG(INFO) << "Node:" << anf_node->fullname_with_scope() << " Use cached kernel, kernel json name:." + << json_name; continue; } // same op not need build, but need wait build finish to set kernel mode - if (processed_kernel.find(json_name) != processed_kernel.end()) { + if (processed_kernel.find(json_name) != processed_kernel.end() && !IsDynamicShape) { build_manger->SaveSameOpInfo(anf_node, json_name, input_size_list, output_size_list); continue; } @@ -72,8 +74,8 @@ bool TbeOpParallelBuild(const std::vector &anf_nodes) { while (!build_manger->IsAllTaskFinish()) { int task_id = -1; std::string task_result; - std::string pre_build_result; - auto ret = build_manger->WaitOne(&task_id, &task_result, &pre_build_result); + std::string build_result; + auto ret = build_manger->WaitOne(&task_id, &task_result, &build_result); if (!ret) { MS_EXCEPTION(ArgumentError) << "Build Failed. wait one ret:" << ret << ", task id:" << task_id; } @@ -81,7 +83,7 @@ bool TbeOpParallelBuild(const std::vector &anf_nodes) { if (task_result != "Success") { MS_EXCEPTION(ArgumentError) << "task compile Failed, task id:" << task_id << ", cause:" << task_result; } - (void)build_manger->TaskFinishProcess(task_id); + (void)build_manger->TaskFinishProcess(task_id, build_result); } return build_manger->GenSameOpKernelMod(); } @@ -93,7 +95,7 @@ void ParallelBuildManager::SaveTaskInfo(int32_t task_id, const mindspore::AnfNod const std::vector &output_size_list, int32_t scope_id) { MS_LOG(INFO) << "SaveTaskInfo, task id: " << task_id; struct KernelBuildTaskInfo task_info; - task_info.node = anf_node.get(); + task_info.node = anf_node; task_info.json_name = json_name; if (anf_node == nullptr) { task_info.processor = tbe::kProcessorAiCore; @@ -111,7 +113,38 @@ bool ParallelBuildManager::IsAllTaskFinish() const { return task_map_.empty(); } -std::pair ParallelBuildManager::TaskFinishProcess(int32_t task_id, bool set_kernel_mod) { +void ParallelBuildManager::PreTaskFinishProcess(int32_t task_id, const std::string &pre_build_result) { + auto task_iter = pre_task_map_.find(task_id); + if (task_iter == pre_task_map_.end()) { + MS_EXCEPTION(ArgumentError) << "can find pre task_id:" << task_id; + } + auto node = task_iter->second; + auto builder = + std::make_shared(AnfAlgo::GetSelectKernelBuildInfo(node)); + std::string start_flag = "fusion_pattern_start"; + std::string end_flag = "fusion_pattern_end"; + int start = pre_build_result.find(start_flag); + int end = pre_build_result.find(end_flag); + if (start != -1 && end != -1 && end >= start) { + std::string result = pre_build_result.substr(start + start_flag.size(), end - start - start_flag.size()); + if (result.empty()) { + (void)pre_task_map_.erase(task_iter); + return; + } + transform(result.begin(), result.end(), result.begin(), ::toupper); + AnfAlgo::SetNodeAttr(kAttrFusionType, MakeValue(result), node); + FusionType fusion_type = tbe::GetFusionType(result); + builder->SetFusionType(fusion_type); + AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get()); + } + (void)pre_task_map_.erase(task_iter); +} + +std::pair ParallelBuildManager::TaskFinishProcess(int32_t task_id, const std::string &build_ret, + bool set_kernel_mod) { + auto compile_info = ProcessBuildRetStr(build_ret); + MS_LOG(DEBUG) << "Tbe build 
ret:" << compile_info; + auto task_iter = task_map_.find(task_id); if (task_iter == task_map_.end()) { MS_EXCEPTION(ArgumentError) << "can find task_id:" << task_id; @@ -133,7 +166,9 @@ std::pair ParallelBuildManager::TaskFinishProcess(int32_t task_iter->second.output_size_list, kernel_pack); MS_EXCEPTION_IF_NULL(kernel_mod); if (set_kernel_mod) { - AnfAlgo::SetKernelMod(kernel_mod, task_iter->second.node); + AnfAlgo::SetKernelMod(kernel_mod, task_iter->second.node.get()); + AnfAlgo::SetNodeAttr(kAttrCompileInfo, MakeValue(compile_info), task_iter->second.node); + MS_LOG(DEBUG) << "Set Node Attr compile_info:" << compile_info; } auto ret = std::make_pair(task_iter->second.scope_id, kernel_mod); (void)task_map_.erase(task_iter); @@ -145,7 +180,7 @@ void ParallelBuildManager::SaveSameOpInfo(const mindspore::AnfNodePtr &anf_node, const std::vector &input_size_list, const std::vector &output_size_list) { struct KernelBuildTaskInfo task_info; - task_info.node = anf_node.get(); + task_info.node = anf_node; task_info.json_name = json_name; task_info.processor = tbe::GetProcessor(anf_node); task_info.input_size_list.assign(input_size_list.begin(), input_size_list.end()); @@ -156,7 +191,7 @@ void ParallelBuildManager::SaveSameOpInfo(const mindspore::AnfNodePtr &anf_node, bool ParallelBuildManager::GenSameOpKernelMod() const { for (const auto &task_info : same_op_list_) { bool ret = SearchInCache(task_info.json_name, task_info.processor, task_info.input_size_list, - task_info.output_size_list, task_info.node); + task_info.output_size_list, task_info.node.get()); if (!ret) { MS_LOG(INFO) << "can't find " << task_info.json_name << " in cache."; return false; @@ -212,5 +247,20 @@ void ParallelBuildManager::ResetTaskInfo() { same_op_list_.clear(); AscendKernelBuildClient::Instance().TbeReset(); } + +std::string ParallelBuildManager::ProcessBuildRetStr(const std::string &build_result) { + std::string start_flag = "fusion_pattern_start"; + std::string end_flag = "fusion_pattern_end"; + int start = build_result.find(start_flag); + int end = build_result.find(end_flag); + if (start != -1 && end != -1 && end >= start) { + std::string result = build_result.substr(start + start_flag.size(), end - start - start_flag.size()); + if (!result.empty()) { + return result; + } + } + return ""; +} + } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_parallel_build.h b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_parallel_build.h index a7a28d4502..48d9833f06 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_parallel_build.h +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_parallel_build.h @@ -31,7 +31,7 @@ namespace kernel { bool TbeOpParallelBuild(const std::vector &anf_nodes); struct KernelBuildTaskInfo { - AnfNode *node; + AnfNodePtr node; std::string processor; std::string json_name; std::vector input_size_list; @@ -53,16 +53,21 @@ class ParallelBuildManager { const std::vector &input_size_list, const std::vector &output_size_list, AnfNode *node) const; bool IsAllTaskFinish() const; - std::pair TaskFinishProcess(int32_t task_id, bool set_kernel_mod = true); + void PreTaskFinishProcess(int32_t task_id, const std::string &pre_build_result); + std::pair TaskFinishProcess(int32_t task_id, const std::string &build_ret, + bool set_kernel_mod = true); KernelModPtr GenKernelMod(const string &json_name, const string &processor, const std::vector &input_size_list, const std::vector &output_size_list, const KernelPackPtr 
&kernel_pack) const; // Interactive with real backend, who could be implemented by Python. - int StartCompileOp(const nlohmann::json &kernel_json); - bool WaitOne(int *task_id, std::string *task_result, std::string *pre_build_result); + static int StartCompileOp(const nlohmann::json &kernel_json); + static bool WaitOne(int *task_id, std::string *task_result, std::string *build_result); void ResetTaskInfo(); + private: + std::string ProcessBuildRetStr(const std::string &build_result); + private: std::map pre_task_map_; std::map task_map_; diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc index da366cecac..a4a41042e0 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc @@ -30,6 +30,7 @@ #include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h" #include "backend/kernel_compiler/tbe/tbe_kernel_select/common_utils.h" #include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_property_checker.h" +#include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h" #include "backend/session/kernel_build_client.h" namespace mindspore { @@ -54,7 +55,8 @@ void TbeKernelSelect::TbeMetadataInfoEx() { MS_EXCEPTION_IF_NULL(cnode_ptr_); MS_EXCEPTION_IF_NULL(kernel_info_list_); node_name_ = AnfAlgo::GetCNodeName(cnode_ptr_); - auto op_info_ptr = OpLib::FindOp(node_name_, kTBE); + + auto op_info_ptr = tbe::TbeDynamicShapeUtil::FindOp(node_name_, cnode_ptr_); if (!op_info_ptr) { MS_LOG(INFO) << "Warning: Cann't find tbe core opinfo, node type: " << node_name_; return; @@ -81,6 +83,7 @@ void TbeKernelSelect::TbeMetadataInfoEx() { } // check support FilterInVaildKernelInfo(); + MS_LOG(INFO) << "End get kernel build info size: " << kernel_info_list_->size() << ", after tbe select."; } void TbeKernelSelect::GetCommonPatternKernelInfo(const OpInfo &op_info) { diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.h b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.h index b7a86efa4c..e2a01cc87d 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.h @@ -23,6 +23,7 @@ #include "backend/kernel_compiler/kernel_query.h" #include "backend/kernel_compiler/oplib/oplib.h" #include "backend/session/anf_runtime_algorithm.h" +#include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h" namespace mindspore { namespace opt { @@ -62,7 +63,7 @@ class KernelQuery { if (!node->isa()) { return false; } - auto op_info = mindspore::kernel::OpLib::FindOp(AnfAlgo::GetCNodeName(node), kernel::kTBE); + auto op_info = mindspore::kernel::tbe::TbeDynamicShapeUtil::FindOp(AnfAlgo::GetCNodeName(node), node); if (op_info != nullptr) { return op_info->is_ref(); } @@ -75,8 +76,8 @@ class OpFinder { public: OpFinder() = default; virtual ~OpFinder() = default; - virtual int GetOpRegisteredOutputNum(const std::string &op_name) { - auto op_info = kernel::OpLib::FindOp(op_name, kernel::kTBE); + virtual int GetOpRegisteredOutputNum(const std::string &op_name, const CNodePtr &cnode) { + auto op_info = kernel::tbe::TbeDynamicShapeUtil::FindOp(op_name, cnode); if (op_info == nullptr) { return -1; } diff --git a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_pad_for_nms_with_mask.cc 
b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_pad_for_nms_with_mask.cc index 7337bf9a22..301dbdc0a4 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_pad_for_nms_with_mask.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_pad_for_nms_with_mask.cc @@ -46,6 +46,9 @@ const AnfNodePtr InsertPadForNMSWithMask::Process(const FuncGraphPtr &func_graph const EquivPtr &) const { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(node); + if (AnfAlgo::IsDynamicShape(node)) { + return nullptr; + } auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.cc index 41de6650a7..376d0e5623 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.cc @@ -36,7 +36,7 @@ session::KernelWithIndex FindRefOriginNode(const AnfNodePtr &node) { auto cnode = cur_node->cast(); MS_EXCEPTION_IF_NULL(cnode); std::string op_name = AnfAlgo::GetCNodeName(cnode); - auto op_info = mindspore::kernel::OpLib::FindOp(op_name, kernel::kTBE); + auto op_info = mindspore::kernel::tbe::TbeDynamicShapeUtil::FindOp(op_name, cnode); // deal ref op if (op_info != nullptr && op_info->is_ref()) { auto ref_infos = op_info->ref_infos(); @@ -223,7 +223,7 @@ const AnfNodePtr DealRefTransAndCast::Process(const FuncGraphPtr &graph, const A DealBroadCastAsRef(graph, cnode); auto op_name = AnfAlgo::GetCNodeName(cnode); - auto op_info = mindspore::kernel::OpLib::FindOp(op_name, kernel::kTBE); + auto op_info = mindspore::kernel::tbe::TbeDynamicShapeUtil::FindOp(op_name, cnode); if (op_info == nullptr || !op_info->is_ref()) { return nullptr; } diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/concat_fission.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/concat_fission.cc index 63fb40fa8b..dba3592e63 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/concat_fission.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/concat_fission.cc @@ -65,6 +65,9 @@ const AnfNodePtr ConcatFission::Process(const FuncGraphPtr &func_graph, const An const EquivPtr &) const { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(node); + if (AnfAlgo::IsDynamicShape(node)) { + return nullptr; + } auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); // The real input begins with index 1. 
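PreTaskFinishProcess and ProcessBuildRetStr in the tbe_kernel_parallel_build.cc hunks above both carve the fusion/compile info out of the build result by locating the fusion_pattern_start and fusion_pattern_end markers. A standalone version of that extraction; it sidesteps the int-versus-npos comparison by keeping find()'s return type, and the sample build_result string is invented for illustration.

// --- illustrative sketch (not part of the patch): extracting compile info between markers ---
#include <iostream>
#include <string>

std::string ExtractBetweenMarkers(const std::string &text, const std::string &start_flag,
                                  const std::string &end_flag) {
  auto start = text.find(start_flag);
  auto end = text.find(end_flag);
  if (start == std::string::npos || end == std::string::npos || end < start) {
    return "";
  }
  return text.substr(start + start_flag.size(), end - start - start_flag.size());
}

int main() {
  std::string build_result = "log ... fusion_pattern_startElemWisefusion_pattern_end ...";
  std::cout << ExtractBetweenMarkers(build_result, "fusion_pattern_start", "fusion_pattern_end")
            << "\n";  // ElemWise
  return 0;
}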
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/layer_norm_grad_split.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/layer_norm_grad_split.cc index 1eca3298e4..d79e9e7818 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/layer_norm_grad_split.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/layer_norm_grad_split.cc @@ -86,6 +86,9 @@ const AnfNodePtr LayerNormGradSplit::Process(const FuncGraphPtr &graph, const An const EquivPtr &) const { MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(node); + if (AnfAlgo::IsDynamicShape(node)) { + return nullptr; + } auto cnode = node->cast(); if (cnode->inputs().size() != kLayerNormGradInputNum) { return nullptr; diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/pack_fission.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/pack_fission.cc index 2037f61b4b..873bec1f70 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/pack_fission.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/pack_fission.cc @@ -72,6 +72,9 @@ const BaseRef PackFission::DefinePattern() const { const AnfNodePtr PackFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(node); + if (AnfAlgo::IsDynamicShape(node)) { + return nullptr; + } auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); // The real input begins with index 1. diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/reduce_min_fission.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/reduce_min_fission.cc index f6a941cc69..8b237d7e73 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/reduce_min_fission.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/reduce_min_fission.cc @@ -105,6 +105,9 @@ const AnfNodePtr ReduceMinFission::Process(const FuncGraphPtr &graph, const AnfN if (graph == nullptr || node == nullptr) { return nullptr; } + if (AnfAlgo::IsDynamicShape(node)) { + return nullptr; + } auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); CheckCNodeInputSize(cnode, 2); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/split_fission.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/split_fission.cc index a37c2f38a9..fd94170623 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/split_fission.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/split_fission.cc @@ -174,6 +174,9 @@ const BaseRef SplitFission::DefinePattern() const { const AnfNodePtr SplitFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const { MS_EXCEPTION_IF_NULL(node); + if (AnfAlgo::IsDynamicShape(node)) { + return nullptr; + } auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); // Check output num diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/topk_split.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/topk_split.cc index d9b0f4616e..213b435f1c 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/topk_split.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/topk_split.cc @@ -127,6 +127,9 @@ const BaseRef TopKSplit::DefinePattern() const { const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(node); + if (AnfAlgo::IsDynamicShape(node)) { + return nullptr; + } auto kernel_graph = func_graph->cast(); // set value node as topk's input auto cnode = 
node->cast(); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/add_input_to_output.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/add_input_to_output.cc index cc58d2b057..b055847908 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/add_input_to_output.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/add_input_to_output.cc @@ -86,7 +86,7 @@ const AnfNodePtr AddInputToOutput::Process(const FuncGraphPtr &func_graph, const if (!InputToOutputRegistry::Instance().GetRegisterByOpName(op_name, ®)) { return nullptr; } - int output_num = op_finder_->GetOpRegisteredOutputNum(op_name); + int output_num = op_finder_->GetOpRegisteredOutputNum(op_name, cnode); // No need add output when it is not a tbe op. if (output_num == -1) { return nullptr; diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/confusion_mul_grad_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/confusion_mul_grad_fusion.cc index 37243fbeeb..f051c7594a 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/confusion_mul_grad_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/confusion_mul_grad_fusion.cc @@ -84,6 +84,9 @@ bool QuitFusion(const FuncGraphPtr &graph, const AnfNodePtr &mul0_anf, const Anf MS_LOG(INFO) << "mul's second input is not addn"; return true; } + if (AnfAlgo::IsDynamicShape(addn)) { + return true; + } std::vector shape = AnfAlgo::GetOutputInferShape(addn, 0); if (shape.size() != 2 || !(shape[1] == 1024 || shape[1] == 768)) { MS_LOG(INFO) << "Addn's infer shape is not equal [x,1024] or [x,768]"; diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/derelu_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/derelu_fusion.cc index 4cf83df43c..fa111405c6 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/derelu_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/derelu_fusion.cc @@ -53,6 +53,9 @@ CNodePtr CreateReluV2(const FuncGraphPtr &graph, const CNodePtr &relu) { // ReluV2's 2rd output is mask whose data type is uint8 TypeId mask_dtype = kNumberTypeUInt8; + if (AnfAlgo::IsDynamicShape(relu)) { + return nullptr; + } std::vector mask_shape = AnfAlgo::GetOutputInferShape(relu, 0); if (mask_shape.size() != 4) { MS_LOG(DEBUG) << "relu's infer shape size not equal 4"; diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/momentum_lossscale_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/momentum_lossscale_fusion.cc index 90c5ac19a9..9c3a59ed7c 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/momentum_lossscale_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/momentum_lossscale_fusion.cc @@ -29,6 +29,9 @@ bool CheckValueNodeInputOfMul(const AnfNodePtr &node) { if (!node->isa()) { return false; } + if (AnfAlgo::IsDynamicShape(node)) { + return false; + } std::vector mul_input_shape = AnfAlgo::GetOutputInferShape(node, 0); return mul_input_shape.empty() || (mul_input_shape.size() == 1 && mul_input_shape[0] == 1); } diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/mul_addn_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/mul_addn_fusion.cc index 73f4e61241..a0f803e6a3 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/mul_addn_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/mul_addn_fusion.cc @@ -85,6 +85,9 @@ const AnfNodePtr MulAddNFusion::Process(const FuncGraphPtr &graph, const AnfNode break; } } + if 
(AnfAlgo::IsDynamicShape(mul->input(lossscale_input_index))) { + return nullptr; + } auto constant_shape = AnfAlgo::GetOutputInferShape(mul->input(lossscale_input_index), 0); if (!(constant_shape.size() == 0 || (constant_shape.size() == 1 && constant_shape[0] == 1))) { MS_LOG(DEBUG) << "The const input of Mul node must be scalar or shape=(1,), but shape size is " diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/remove_reshape_pair.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/remove_reshape_pair.cc index 24010e1858..706c8e4fc0 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/remove_reshape_pair.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/remove_reshape_pair.cc @@ -45,6 +45,10 @@ const AnfNodePtr RemoveReshapePair::Process(const FuncGraphPtr &func_graph, cons if (IsUsedByOthers(func_graph, in_reshape)) { return nullptr; } + + if (AnfAlgo::IsDynamicShape(out_reshape) || AnfAlgo::IsDynamicShape(in_reshape)) { + return nullptr; + } auto output_shape = AnfAlgo::GetOutputDeviceShape(out_reshape, 0); auto input_shape = AnfAlgo::GetInputDeviceShape(in_reshape, 0); if (kernel::IsSameShape(input_shape, output_shape)) { diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/reshape_transpose_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/reshape_transpose_fusion.cc index d4ad4c431f..76f2445dbb 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/reshape_transpose_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/reshape_transpose_fusion.cc @@ -50,6 +50,9 @@ const AnfNodePtr ReshapeTransposeFusion::Process(const FuncGraphPtr &func_graph, MS_EXCEPTION_IF_NULL(transpose_cnode); auto reshape_cnode = CheckAnfNodeIfCNodeAndInputSize(transpose_cnode->input(1), kBackendReshapeInputNum); MS_EXCEPTION_IF_NULL(reshape_cnode); + if (AnfAlgo::IsDynamicShape(transpose_cnode) || AnfAlgo::IsDynamicShape(reshape_cnode)) { + return nullptr; + } std::vector reshape_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(reshape_cnode, 0); std::vector transpose_output0_shape = AnfAlgo::GetOutputInferShape(transpose_cnode, 0); if (!CheckShapeDimInfo(reshape_input0_shape) || !CheckShapeDimInfo(transpose_output0_shape)) { diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/transpose_reshape_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/transpose_reshape_fusion.cc index ccbdfa8791..7e34b90fa0 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/transpose_reshape_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/transpose_reshape_fusion.cc @@ -50,6 +50,9 @@ const AnfNodePtr TransposeReshapeFusion::Process(const FuncGraphPtr &func_graph, MS_EXCEPTION_IF_NULL(reshape_cnode); auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(reshape_cnode->input(1), kBackendReshapeInputNum); MS_EXCEPTION_IF_NULL(transpose_cnode); + if (AnfAlgo::IsDynamicShape(transpose_cnode) || AnfAlgo::IsDynamicShape(reshape_cnode)) { + return nullptr; + } std::vector reshape_output0_shape = AnfAlgo::GetOutputInferShape(reshape_cnode, 0); std::vector transpose_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(transpose_cnode, 0); if (!CheckShapeDimInfo(reshape_output0_shape) || !CheckShapeDimInfo(transpose_input0_shape)) { diff --git a/mindspore/ccsrc/backend/optimizer/common/helper.cc b/mindspore/ccsrc/backend/optimizer/common/helper.cc index 81a1c9135a..388858be75 100644 --- a/mindspore/ccsrc/backend/optimizer/common/helper.cc +++ b/mindspore/ccsrc/backend/optimizer/common/helper.cc @@ -26,6 
+26,8 @@ #include "base/base_ref.h" #include "backend/session/anf_runtime_algorithm.h" #include "base/core_ops.h" +#include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h" +#include "frontend/operator/ops.h" #include "utils/ms_utils.h" #include "runtime/device/kernel_info.h" #include "utils/ms_context.h" @@ -394,6 +396,7 @@ bool IsNopNode(const AnfNodePtr &node) { context_ptr->get_param(MS_CTX_DEVICE_TARGET) != kGPUDevice) { return false; } + static std::unordered_set nop_nodes = {prim::kPrimReshape->name(), kExpandDimsOpName, prim::kPrimSqueeze->name(), prim::kPrimFlatten->name(), kFlattenGradOpName}; diff --git a/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.cc b/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.cc index 06e319d59f..a652bd3fd0 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.cc @@ -55,6 +55,10 @@ const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const An continue; } } + if (AnfAlgo::IsDynamicShape(cnode)) { + MS_LOG(INFO) << "current node is dynamic shape " << cnode->fullname_with_scope(); + continue; + } ConstInputToAttr(cnode, reg.GetConstInputAttrInfo()); } return node; diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc index 1ec38e991d..2b267728a8 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc @@ -28,6 +28,7 @@ #include "backend/kernel_compiler/kernel.h" #include "backend/kernel_compiler/kernel_build_info.h" #include "common/trans.h" +#include "abstract/param_validator.h" namespace mindspore { namespace session { @@ -42,12 +43,27 @@ namespace { constexpr size_t kNopNodeInputSize = 2; constexpr size_t kNopNodeRealInputIndex = 1; +bool IsShapeDynamic(const abstract::ShapePtr &shape) { + MS_EXCEPTION_IF_NULL(shape); + return std::any_of(shape->shape().begin(), shape->shape().end(), [](int s) { return s < 0; }); +} + std::vector TransShapeToSizet(const abstract::ShapePtr &shape) { MS_EXCEPTION_IF_NULL(shape); std::vector shape_size_t; - std::transform(shape->shape().begin(), shape->shape().end(), std::back_inserter(shape_size_t), IntToSize); + if (IsShapeDynamic(shape)) { + if (std::all_of(shape->max_shape().begin(), shape->max_shape().end(), [](int s) { return s >= 0; })) { + std::transform(shape->max_shape().begin(), shape->max_shape().end(), std::back_inserter(shape_size_t), IntToSize); + } else { + MS_LOG(EXCEPTION) << "Invalid Max Shape"; + } + } else { + std::transform(shape->shape().begin(), shape->shape().end(), std::back_inserter(shape_size_t), IntToSize); + } return shape_size_t; } + +enum ShapeType { kMaxShape, kMinShape }; } // namespace AnfNodePtr AnfRuntimeAlgorithm::GetTupleGetItemRealInput(const CNodePtr &tuple_get_item) { @@ -1206,19 +1222,6 @@ TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputPrecision(const AnfNodePtr &node, s return GetCNodeOutputPrecision(kernel_with_index.first); } -bool AnfRuntimeAlgorithm::IsDynamicShape(const AnfNodePtr &node) { - if (!node->isa()) { - return false; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto has_attr = AnfAlgo::HasNodeAttr(kAttrIsDynamicShape, cnode); - if (!has_attr) { - return false; - } - return AnfAlgo::GetNodeAttr(node, kAttrIsDynamicShape); -} - bool AnfRuntimeAlgorithm::IsCondControlKernel(const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); if 
(node->inputs().empty()) { @@ -1252,5 +1255,96 @@ bool AnfRuntimeAlgorithm::IsIndependentNode(const CNodePtr &node) { } return true; } + +bool AnfRuntimeAlgorithm::GetBooleanAttr(const AnfNodePtr &node, const std::string &attr) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + return false; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto has_attr = AnfAlgo::HasNodeAttr(attr, cnode); + if (!has_attr) { + return false; + } + return AnfAlgo::GetNodeAttr(node, attr); +} + +bool AnfRuntimeAlgorithm::IsDynamicShape(const AnfNodePtr &node) { + return GetBooleanAttr(node, kAttrInputIsDynamicShape) || GetBooleanAttr(node, kAttrOutputIsDynamicShape); +} + +void AnfRuntimeAlgorithm::GetRealDynamicShape(const std::vector &shape, + NotNull *> dynamic_shape) { + for (auto size : shape) { + if (size == SIZE_MAX) { + dynamic_shape->push_back(-1); + } else { + dynamic_shape->push_back(SizeToLong(size)); + } + } +} + +std::vector GetShapeFromSequeueShape(const abstract::SequeueShapePtr &sequeue_shape_ptr, size_t index, + ShapeType type) { + MS_EXCEPTION_IF_NULL(sequeue_shape_ptr); + auto shape_list = sequeue_shape_ptr->shape(); + if (index >= shape_list.size()) { + MS_LOG(EXCEPTION) << "Output Index:" << index << " >= " << shape_list.size(); + } + + auto shape = shape_list[index]; + MS_EXCEPTION_IF_NULL(shape); + if (shape->isa()) { + auto shape_ptr = shape->cast(); + if (type == kMaxShape) { + return shape_ptr->max_shape().empty() ? shape_ptr->shape() : shape_ptr->max_shape(); + } else { + return shape_ptr->min_shape().empty() ? shape_ptr->shape() : shape_ptr->min_shape(); + } + } else { + MS_LOG(EXCEPTION) << "Invalid Shape Type In Shape List"; + } +} + +std::vector AnfRuntimeAlgorithm::GetInputMaxShape(const AnfNodePtr &anf_node, size_t index) { + auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, index); + return GetOutputMaxShape(input_node_with_index.first, input_node_with_index.second); +} + +std::vector AnfRuntimeAlgorithm::GetInputMinShape(const AnfNodePtr &anf_node, size_t index) { + auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, index); + return GetOutputMinShape(input_node_with_index.first, input_node_with_index.second); +} + +std::vector AnfRuntimeAlgorithm::GetOutputMaxShape(const AnfNodePtr &anf_node, size_t index) { + MS_EXCEPTION_IF_NULL(anf_node); + auto shape = anf_node->Shape(); + MS_EXCEPTION_IF_NULL(shape); + if (shape->isa()) { + auto shape_ptr = shape->cast(); + return shape_ptr->max_shape().empty() ? shape_ptr->shape() : shape_ptr->max_shape(); + } else if (shape->isa()) { + auto shape_ptr = shape->cast(); + return GetShapeFromSequeueShape(shape_ptr, index, kMaxShape); + } else { + MS_LOG(EXCEPTION) << "Invalid Shape Type"; + } +} + +std::vector AnfRuntimeAlgorithm::GetOutputMinShape(const AnfNodePtr &anf_node, size_t index) { + MS_EXCEPTION_IF_NULL(anf_node); + auto shape = anf_node->Shape(); + MS_EXCEPTION_IF_NULL(shape); + if (shape->isa()) { + auto shape_ptr = shape->cast(); + return shape_ptr->min_shape().empty() ? 
shape_ptr->shape() : shape_ptr->min_shape(); + } else if (shape->isa()) { + auto shape_ptr = shape->cast(); + return GetShapeFromSequeueShape(shape_ptr, index, kMinShape); + } else { + MS_LOG(EXCEPTION) << "Invalid Shape Type"; + } +} } // namespace session } // namespace mindspore diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h index 017afe036c..b909570a76 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h @@ -221,6 +221,12 @@ class AnfRuntimeAlgorithm { static bool IsDynamicShape(const AnfNodePtr &node); static bool IsCondControlKernel(const CNodePtr &node); static bool IsIndependentNode(const CNodePtr &node); + static bool GetBooleanAttr(const AnfNodePtr &node, const std::string &attr); + static void GetRealDynamicShape(const std::vector &shape, NotNull *> dynamic_shape); + static std::vector GetInputMaxShape(const AnfNodePtr &anf_node, size_t index); + static std::vector GetInputMinShape(const AnfNodePtr &anf_node, size_t index); + static std::vector GetOutputMaxShape(const AnfNodePtr &anf_node, size_t index); + static std::vector GetOutputMinShape(const AnfNodePtr &anf_node, size_t index); }; } // namespace session using AnfAlgo = session::AnfRuntimeAlgorithm; diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc index 4b9fc0e2e1..736fc3afe1 100644 --- a/mindspore/ccsrc/backend/session/ascend_session.cc +++ b/mindspore/ccsrc/backend/session/ascend_session.cc @@ -127,6 +127,9 @@ GraphId AscendSession::CompileGraphImpl(NotNull func_graph) { MS_LOG(INFO) << "Start"; std::vector all_graphs; auto root_graph = ConstructKernelGraph(func_graph, &all_graphs); + // Update Graph Dynamic Shape Attr + UpdateGraphDynamicShapeAttr(NOT_NULL(root_graph)); + root_graph->UpdateGraphDynamicAttr(); BackendOptimization(all_graphs); // empty graph dont entry to backend if (root_graph->execution_order().empty()) { @@ -136,6 +139,7 @@ GraphId AscendSession::CompileGraphImpl(NotNull func_graph) { InitRuntimeResource(); return root_graph->graph_id(); } + // create parameter for multiple branch std::set memo; CreateMultiBranchOutput(NOT_NULL(root_graph), NOT_NULL(&memo)); diff --git a/mindspore/ccsrc/backend/session/kernel_graph.cc b/mindspore/ccsrc/backend/session/kernel_graph.cc index 31c3cab446..5d55821c97 100644 --- a/mindspore/ccsrc/backend/session/kernel_graph.cc +++ b/mindspore/ccsrc/backend/session/kernel_graph.cc @@ -1201,6 +1201,17 @@ void KernelGraph::RemoveNodeFromGraph(const AnfNodePtr &node) { } } +void KernelGraph::UpdateGraphDynamicAttr() { + for (const auto &cnode : execution_order_) { + if (AnfAlgo::IsDynamicShape(cnode)) { + MS_LOG(INFO) << "Update Graph Dynamic Attr"; + is_dynamic_shape_ = true; + return; + } + } + is_dynamic_shape_ = false; +} + std::string KernelGraph::ToString() const { return std::string("kernel_graph_").append(std::to_string(graph_id_)); } KernelGraph::~KernelGraph() { diff --git a/mindspore/ccsrc/backend/session/kernel_graph.h b/mindspore/ccsrc/backend/session/kernel_graph.h index ff01501bbc..25f54f1a58 100644 --- a/mindspore/ccsrc/backend/session/kernel_graph.h +++ b/mindspore/ccsrc/backend/session/kernel_graph.h @@ -37,7 +37,13 @@ namespace session { using AnfWithOutIndex = std::pair; class KernelGraph : public FuncGraph { public: - KernelGraph() : graph_id_(0), start_label_(nullptr), end_goto_(nullptr), null_output_(false), current_epoch_(0) { + 
KernelGraph() + : graph_id_(0), + start_label_(nullptr), + end_goto_(nullptr), + null_output_(false), + current_epoch_(0), + is_dynamic_shape_(false) { inputs_ = std::make_shared>(); execution_order_ = {}; executable_ = true; @@ -161,6 +167,7 @@ class KernelGraph : public FuncGraph { void set_child_graph_result(const std::vector &child_graph_result) { child_graph_result_ = child_graph_result; } + void InsertTupleParameterToMakeTupleMap(const AnfNodePtr ¶m, const AnfNodePtr &make_tuple) { if (tuple_parameter_to_make_tuple_map_.find(param) != tuple_parameter_to_make_tuple_map_.end()) { return; @@ -176,6 +183,9 @@ class KernelGraph : public FuncGraph { } void RemoveNodeFromGraph(const AnfNodePtr &node); + void UpdateGraphDynamicAttr(); + bool is_dynamic_shape() const { return is_dynamic_shape_; } + private: // remove value node form graph bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node); @@ -247,10 +257,10 @@ class KernelGraph : public FuncGraph { std::unordered_map> internal_outputs_tensor_map_; uint32_t current_epoch_; std::unordered_map tuple_parameter_to_make_tuple_map_; - std::set visited_nodes_; std::map edge_to_; std::stack loop_nodes_; + bool is_dynamic_shape_; }; } // namespace session using KernelGraphPtr = std::shared_ptr; diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index ffd5b48dda..5bc577a1f3 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -35,6 +35,7 @@ #include "ir/dtype.h" #include "ir/anf.h" #include "ir/func_graph_cloner.h" +#include "utils/utils.h" #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #include "ps/worker.h" #include "ps/common.h" @@ -1405,6 +1406,97 @@ void SessionBasic::RunGraphAsync(const GraphId &graph_id, const std::vectorRunGraphAsync(shared_from_this(), graph_id, inputs, outputs); } +bool IsDynamicShape(const NotNull &shape) { + return !std::all_of(shape->shape().begin(), shape->shape().end(), [](int s) { return s > 0; }); +} + +bool IsNodeOutputDynamicShape(const CNodePtr &anf_node_ptr) { + MS_EXCEPTION_IF_NULL(anf_node_ptr); + auto base_shape = anf_node_ptr->Shape(); + if (base_shape == nullptr) { + MS_LOG(INFO) << "Invalid bash shape ptr, node:" << anf_node_ptr->fullname_with_scope(); + return false; + } + if (base_shape->isa()) { + if (IsDynamicShape(NOT_NULL(base_shape->cast()))) { + return true; + } + } else if (base_shape->isa()) { + auto tuple_shape = base_shape->cast(); + MS_EXCEPTION_IF_NULL(tuple_shape); + + for (size_t i = 0; i < tuple_shape->size(); ++i) { + auto b_shp = (*tuple_shape)[i]; + if (!b_shp->isa()) { + continue; + } + if (IsDynamicShape(NOT_NULL(b_shp->cast()))) { + return true; + } + } + } + return false; +} + +bool IsNodeInputDynamicShape(const CNodePtr &anf_node_ptr) { + MS_EXCEPTION_IF_NULL(anf_node_ptr); + auto input_num = AnfAlgo::GetInputTensorNum(anf_node_ptr); + for (size_t i = 0; i < input_num; ++i) { + auto input_with_index = AnfAlgo::GetPrevNodeOutput(anf_node_ptr, i); + auto input = input_with_index.first; + auto index = input_with_index.second; + MS_EXCEPTION_IF_NULL(input); + + auto base_shape = input->Shape(); + if (base_shape == nullptr) { + MS_LOG(INFO) << "Invalid shape ptr, node:" << input->fullname_with_scope(); + continue; + } + if (base_shape->isa()) { + if (IsDynamicShape(NOT_NULL(base_shape->cast()))) { + return true; + } + } else if (base_shape->isa()) { + auto tuple_shape = base_shape->cast(); + MS_EXCEPTION_IF_NULL(tuple_shape); + + if (index >= 
tuple_shape->size()) { + MS_LOG(INFO) << "Node:" << anf_node_ptr->fullname_with_scope() << "Invalid index:" << index + << " and tuple_shape size:" << tuple_shape->size(); + continue; + } + + auto b_shp = (*tuple_shape)[index]; + if (!b_shp->isa()) { + continue; + } + if (IsDynamicShape(NOT_NULL(b_shp->cast()))) { + return true; + } + } + } + return false; +} + +void SessionBasic::UpdateGraphDynamicShapeAttr(const NotNull &root_graph) { + for (const auto &cnode : root_graph->execution_order()) { + auto output_dynamic = IsNodeOutputDynamicShape(NOT_NULL(cnode)); + auto input_dynamic = IsNodeInputDynamicShape(NOT_NULL(cnode)); + if (output_dynamic || input_dynamic) { + AnfAlgo::SetNodeAttr(kAttrIsDynamicShape, MakeValue(true), cnode); + MS_LOG(INFO) << "Set Dynamic Shape Attr to Node:" << cnode->fullname_with_scope(); + } + if (output_dynamic) { + AnfAlgo::SetNodeAttr(kAttrOutputIsDynamicShape, MakeValue(true), cnode); + MS_LOG(INFO) << "Set Output Dynamic Shape Attr to Node:" << cnode->fullname_with_scope(); + } + if (input_dynamic) { + AnfAlgo::SetNodeAttr(kAttrInputIsDynamicShape, MakeValue(true), cnode); + MS_LOG(INFO) << "Set Input Dynamic Shape Attr to Node:" << cnode->fullname_with_scope(); + } + } +} + #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) void SessionBasic::AssignParamKey(const KernelGraphPtr &kernel_graph) { if (!ps::Util::IsRoleOfWorker()) { diff --git a/mindspore/ccsrc/backend/session/session_basic.h b/mindspore/ccsrc/backend/session/session_basic.h index 39142ce0ff..6071912126 100644 --- a/mindspore/ccsrc/backend/session/session_basic.h +++ b/mindspore/ccsrc/backend/session/session_basic.h @@ -172,6 +172,7 @@ class SessionBasic : public std::enable_shared_from_this { void AddParameterToGraphInputs(const std::vector ¶meters, KernelGraph *graph); void InitInternalOutputParameter(const AnfNodePtr &out_node, const AnfNodePtr ¶meter); AnfNodePtr FindPullNode(const AnfNodePtr &push_node, const std::vector &node_list); + void UpdateGraphDynamicShapeAttr(const NotNull &root_graph); std::unordered_map> graphs_; std::unordered_map> run_op_graphs_; diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc index 37d71abe17..163f7aeb3d 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc @@ -713,5 +713,16 @@ EvalResultPtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtrLi auto eval_result = trivial_evaluator->EvalPrim(nullptr, arg_specs); return eval_result; } + +AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list) { + MS_EXCEPTION_IF_NULL(prim); + auto &prim_eval_implement_map = GetPrimitiveToEvalImplMap(); + auto ret = prim_eval_implement_map.find(prim); + if (ret == prim_eval_implement_map.end()) { + MS_LOG(EXCEPTION) << "Get infer shape function failed, primitive name:" << prim->name() + << " primitive type:" << prim->type_name(); + } + return ret->second.impl_(nullptr, prim, args_spec_list); +} } // namespace abstract } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h index cfe667f252..5912b2f3aa 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h @@ -302,6 +302,8 @@ AbstractBasePtr FromValue(const T &value, bool broaden = false) { } 
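The CppInferShape entry point defined above gives run-time code a way to re-derive an operator's output abstract once its real input shapes are known. The short sketch below shows one way a caller could drive it; it is illustrative only — the helper name RefreshNodeAbstract is assumed, not part of this patch, and the patch's own re-infer path does comparable work inside runtime/device/executor/dynamic_kernel.cc.

// Illustrative sketch (assumed helper, not in this patch): refresh a CNode's
// abstract by re-running the registered C++ infer function on its inputs.
void RefreshNodeAbstract(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  // input(0) holds the primitive; real inputs start at index 1.
  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
  MS_EXCEPTION_IF_NULL(prim);
  abstract::AbstractBasePtrList args_spec_list;
  for (size_t i = 1; i < cnode->inputs().size(); ++i) {
    auto input = cnode->input(i);
    MS_EXCEPTION_IF_NULL(input);
    args_spec_list.emplace_back(input->abstract());
  }
  // Re-run shape/type inference and store the refreshed abstract on the node.
  cnode->set_abstract(abstract::CppInferShape(prim, args_spec_list));
}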
EvalResultPtr EvalOnePrim(const PrimitivePtr &p, const AbstractBasePtrList &arg_specs); + +AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list); } // namespace abstract } // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/CMakeLists.txt b/mindspore/ccsrc/runtime/device/CMakeLists.txt index 6666d08c3c..b74dce0523 100644 --- a/mindspore/ccsrc/runtime/device/CMakeLists.txt +++ b/mindspore/ccsrc/runtime/device/CMakeLists.txt @@ -1,5 +1,5 @@ file(GLOB_RECURSE DEVICE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "common/*.cc" - "kernel_info.cc" "kernel_runtime.cc" "memory_manager.cc" "kernel_runtime_manager.cc" "convert_tensor_utils.cc" + "kernel_info.cc" "executor/dynamic_kernel.cc" "kernel_runtime.cc" "memory_manager.cc" "kernel_runtime_manager.cc" "convert_tensor_utils.cc" ) if (ENABLE_GPU) diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc index bbf2ba933c..a44265755a 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc @@ -372,7 +372,7 @@ kernel::KernelModPtr AscendDeviceAddress::CompileTransDataAndObtainKernelMod(con // get size std::vector input_size_list; std::vector output_size_list; - (void)kernel::TbeKernelBuild::GetIOSize(kernel_json, &input_size_list, &output_size_list); + (void)kernel::TbeKernelBuild::GetIOSize(kernel_json, &input_size_list, &output_size_list, nullptr); std::string json_name = kernel_json[op_info_str][kernel_name_str]; // op build if (constructed_kernel.find(json_name) == constructed_kernel.end()) { @@ -382,15 +382,15 @@ kernel::KernelModPtr AscendDeviceAddress::CompileTransDataAndObtainKernelMod(con while (!build_manager->IsAllTaskFinish()) { int task_id = -1; std::string task_result; - std::string pre_build_result; - auto ret = build_manager->WaitOne(&task_id, &task_result, &pre_build_result); + std::string build_result; + auto ret = build_manager->WaitOne(&task_id, &task_result, &build_result); if (!ret) { MS_EXCEPTION(ArgumentError) << "Build Failed. 
wait one ret:" << ret << ", task id:" << task_id; } if (task_result != "Success") { MS_EXCEPTION(ArgumentError) << "task compile Failed, task id:" << task_id << ", cause:" << task_result; } - (void)build_manager->TaskFinishProcess(task_id, false); + (void)build_manager->TaskFinishProcess(task_id, build_result, false); } constructed_kernel.insert(json_name); // search cache diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc index 6d13baf9ac..2043644304 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc @@ -46,12 +46,22 @@ #ifdef MEM_REUSE_DEBUG #include "backend/optimizer/mem_reuse/mem_reuse_checker.h" #endif +#include "runtime/device/ascend/executor/tiling/op_tiling_calculater.h" +#include "runtime/device/ascend/executor/executor_callback.h" +#include "runtime/device/ascend/executor/hccl_dynamic_kernel.h" +#include "profiler/device/ascend/ascend_profiling.h" +#include "profiler/device/ascend/profiling_context.h" +#include "profiler/device/ascend/rt_callback_manager.h" using ge::model_runner::ModelRunner; using mindspore::device::ascend::ProfilingManager; using mindspore::device::ascend::ProfilingUtils; using mindspore::device::ascend::tasksink::TaskGenerator; using mindspore::kernel::tbe::TbeUtils; +using mindspore::profiler::ascend::AscendProfiler; +using mindspore::profiler::ascend::CallbackManager; +using mindspore::profiler::ascend::GetTid; +using mindspore::profiler::ascend::kCallback; using std::vector; constexpr uint32_t kTupleTaskId = 0; @@ -135,6 +145,8 @@ void AscendKernelRuntime::ClearGraphModelMap() { // tell users which dump kernel name not used DumpJsonParser::GetInstance().PrintUnusedKernel(); + graph_dynamic_kernel_map_.clear(); + for (auto &iter : graph_model_map_) { MS_LOG(INFO) << "Ge UnloadModel " << iter.first; auto ret = ModelRunner::Instance().UnloadModel(iter.first); @@ -160,6 +172,13 @@ void AscendKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std MS_LOG(DEBUG) << "GraphId:" << graph_id << " not found"; } + MS_LOG(DEBUG) << "Clear graph:" << graph_id << " dynamic kernels"; + if (auto dynamic_kernel_iter = graph_dynamic_kernel_map_.find(graph_id); + dynamic_kernel_iter != graph_dynamic_kernel_map_.end()) { + MS_LOG(DEBUG) << "Start Clear graph:" << graph_id << " dynamic kernel"; + graph_dynamic_kernel_map_.erase(dynamic_kernel_iter); + } + MS_LOG(DEBUG) << "Clear graph:" << graph_id << " runtime resource"; if (auto model_iter = graph_model_map_.find(graph_id); model_iter != graph_model_map_.end()) { MS_LOG(DEBUG) << "Ge UnloadModel " << graph_id; @@ -233,6 +252,7 @@ bool AscendKernelRuntime::Init() { InnerSetContext(); return true; } + OpTilingCalculater::GetInstance().Init(); // Start up profiling before rtSetDevice bool ret = ProfilingManager::GetInstance().StartupProfiling(device_id_); if (!ret) { @@ -342,6 +362,11 @@ bool AscendKernelRuntime::Load(session::KernelGraph *graph, bool is_task_sink) { if (!is_task_sink) { return true; } + // Do HcomExecutorInitialize + if (graph->is_dynamic_shape() && !HcclExecutorManager::GetInstance().Initialize()) { + MS_LOG(ERROR) << "Init Hccl Executor Failed"; + return false; + } if (!GenTask(graph)) { return false; } @@ -351,8 +376,35 @@ bool AscendKernelRuntime::Load(session::KernelGraph *graph, bool is_task_sink) { return true; } +bool AscendKernelRuntime::GenDynamicKernel(const session::KernelGraph *graph) { + 
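  // GenDynamicKernel walks the graph's execution order, asks each node's KernelMod to
  // build a DynamicKernel bound to the runtime stream, calls Initialize() on it, and
  // caches the resulting list in graph_dynamic_kernel_map_ keyed by graph id so that
  // RunDynamicKernelAsync can replay it without regenerating anything.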
MS_EXCEPTION_IF_NULL(graph); + MS_LOG(INFO) << "GenDynamicKernel start"; + auto cnode_list = graph->execution_order(); + std::vector dynamic_kernels; + for (const auto &cnode : cnode_list) { + MS_EXCEPTION_IF_NULL(cnode); + MS_LOG(INFO) << "Generate node:" << cnode->fullname_with_scope() << " dynamic kernel"; + auto kernel_mod = AnfAlgo::GetKernelMod(cnode); + auto dynamic_kernel = kernel_mod->GenDynamicKernel(cnode, stream_); + MS_EXCEPTION_IF_NULL(dynamic_kernel); + dynamic_kernel->Initialize(); + dynamic_kernels.emplace_back(dynamic_kernel); + } + auto ret = graph_dynamic_kernel_map_.try_emplace(graph->graph_id(), dynamic_kernels); + if (!ret.second) { + MS_LOG(ERROR) << "Graph:" << graph->graph_id() << " already generator executor"; + return false; + } + MS_LOG(INFO) << "GenDynamicKernel end"; + return true; +} + bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) { InnerSetContext(); + if (graph->is_dynamic_shape()) { + MS_LOG(INFO) << "Dynamic Shape Graph Generate Dynamic kernel"; + return GenDynamicKernel(graph); + } if (graph == nullptr) { MS_EXCEPTION(NotExistsError) << "session::KernelGraph is NULL!"; } @@ -407,6 +459,11 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) { bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) { InnerSetContext(); + if (graph->is_dynamic_shape()) { + MS_LOG(INFO) << "Dynamic Shape Graph Skip Load Task Step"; + return true; + } + if (graph == nullptr) { MS_EXCEPTION(NotExistsError) << "Null pointer graph, LoadTask failed. "; } @@ -520,9 +577,70 @@ bool AscendKernelRuntime::Run(session::KernelGraph *graph, bool is_task_sink, De return ret; } +bool AscendKernelRuntime::RunDynamicKernelAsync(const session::KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); + MS_LOG(INFO) << "RunExecutorAsync start. GraphId:" << graph->graph_id(); + + auto iter = graph_dynamic_kernel_map_.find(graph->graph_id()); + if (iter == graph_dynamic_kernel_map_.end()) { + MS_LOG(ERROR) << "GraphId:" << graph->graph_id() << " Not Found! 
Please generator executor first"; + return false; + } + + // Profiling Init + auto &async_profiler = AscendProfiler::GetInstance(); + auto &rt_callback = CallbackManager::GetInstance(stream_); + rt_callback.Init(); + + auto dynamic_kernels = iter->second; + for (const auto &dynamic_kernel : dynamic_kernels) { + if (dynamic_kernel->have_depends()) { + MS_LOG(INFO) << "Match Dynamic Kernel, Start SyncStream"; + if (!SyncStream()) { + MS_LOG(ERROR) << "SyncStream failed"; + return false; + } + } + + if (dynamic_kernel->is_dynamic_shape()) { + ExecutorCallback::GetInstance().Consume(); + dynamic_kernel->InferShape(); + dynamic_kernel->UpdateArgs(); + } + + // Enable profiling trace point start + rt_callback.RegisterCallback( + [&]() { RECORD_CALLBACK_EVENT(&async_profiler, dynamic_kernel->GetKernelName().c_str(), "[Callback] start"); }); + + dynamic_kernel->Execute(); + + // Enable profiling trace point end + rt_callback.RegisterCallback( + [&]() { RECORD_CALLBACK_EVENT(&async_profiler, dynamic_kernel->GetKernelName().c_str(), "[Callback] end"); }); + + ExecutorCallback::GetInstance().RegistCallback([&dynamic_kernel] { dynamic_kernel->PostExecute(); }); + } + + if (!SyncStream()) { + MS_LOG(ERROR) << "SyncStream failed"; + return false; + } + ExecutorCallback::GetInstance().Consume(); + + rt_callback.Destroy(); + async_profiler.Dump(std::cout); + async_profiler.Reset(); + return true; +} + bool AscendKernelRuntime::RunTask(const session::KernelGraph *graph) { InnerSetContext(); MS_EXCEPTION_IF_NULL(graph); + if (graph->is_dynamic_shape()) { + MS_LOG(INFO) << "Dynamic Shape Graph Run Task Async"; + return RunDynamicKernelAsync(graph); + } + MS_LOG(INFO) << "RunTask start. GraphId:" << graph->graph_id(); auto context_ptr = MsContext::GetInstance(); @@ -657,7 +775,12 @@ bool AscendKernelRuntime::DestroyHccl() { MS_LOG(INFO) << "Hccl is not enable, no need to close."; return true; } + // Dynamic Shape Hccl Finalize + if (!HcclExecutorManager::GetInstance().Finalize()) { + MS_LOG(ERROR) << "Dynamic Shape Hccl Finalize Failed"; + } HcclResult res = hcom_destroy(); + if (res != HCCL_SUCCESS) { MS_LOG(ERROR) << "Hccl destroy failed"; return false; diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h index 7014b5a13d..42384490a4 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h @@ -40,6 +40,8 @@ class AscendKernelRuntime : public KernelRuntime { bool Init() override; bool LoadData(session::KernelGraph *graph, Debugger *debugger) override; bool GenTask(const session::KernelGraph *graph); + bool GenDynamicKernel(const session::KernelGraph *graph) override; + bool RunDynamicKernelAsync(const session::KernelGraph *graph) override; bool LoadTask(const session::KernelGraph *graph); bool RunTask(const session::KernelGraph *graph); bool Load(session::KernelGraph *graph, bool is_task_sink) override; diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc index 9a4ff4edca..8fdc0df498 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc @@ -34,7 +34,7 @@ const uint32_t kHcomMaxTask = 5; const uint32_t kCommonMaxTask = 350; void AscendStreamAssign::AssignStream(const NotNull &graph_ptr) { - if (IsTaskSink()) { + if (IsTaskSink() && !graph_ptr->is_dynamic_shape()) { 
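    // Dynamic-shape graphs are excluded from stream assignment: they are not generated
    // or loaded as a sunk task graph (GenTask/LoadTask above skip them), and
    // RunDynamicKernelAsync launches their kernels one by one on the runtime stream.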
Reset(); SetLoopSink(); ReorderIndependentOrders(graph_ptr); diff --git a/mindspore/ccsrc/runtime/device/ascend/dump/data_dumper.cc b/mindspore/ccsrc/runtime/device/ascend/dump/data_dumper.cc index eb6c53daf6..3ac55fa925 100644 --- a/mindspore/ccsrc/runtime/device/ascend/dump/data_dumper.cc +++ b/mindspore/ccsrc/runtime/device/ascend/dump/data_dumper.cc @@ -24,7 +24,7 @@ #include "runtime/mem.h" #include "runtime/kernel.h" #include "runtime/rt_model.h" -#include "runtime/device/ascend/dump/ge_dump.h" +#include "runtime/device/ascend/ge_types_convert.h" #include "proto/op_mapping_info.pb.h" #include "utils/ms_context.h" #include "debug/data_dump/dump_json_parser.h" @@ -369,13 +369,13 @@ void DataDumper::DumpKernelOutput(const CNodePtr &kernel, void *args, NotNulladd_dim(dim); } - output.set_original_output_format(GetGeFormat(output_format, output_shape.size())); + output.set_original_output_format(GeTypesConvert::GetGeFormat(output_format, output_shape.size())); output.set_address(static_cast(reinterpret_cast(args)) + offset); // device address data size auto address = AnfAlgo::GetOutputAddr(kernel, i); @@ -409,8 +409,8 @@ void DataDumper::DumpKernelInput(const CNodePtr &kernel, void *args, NotNulladd_dim(dim); diff --git a/mindspore/ccsrc/runtime/device/ascend/executor/ai_core_dynamic_kernel.cc b/mindspore/ccsrc/runtime/device/ascend/executor/ai_core_dynamic_kernel.cc new file mode 100644 index 0000000000..98cc3f5964 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/executor/ai_core_dynamic_kernel.cc @@ -0,0 +1,182 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "runtime/device/ascend/executor/ai_core_dynamic_kernel.h" + +#include +#include +#include +#include "framework/common/debug/log.h" +#include "utils/log_adapter.h" +#include "runtime/device/ascend/executor/tiling/op_tiling_calculater.h" +#include "register/op_tiling.h" +#include "utils/convert_utils_base.h" +#include "utils/ms_context.h" +#include "runtime/device/kernel_runtime_manager.h" +#include "pipeline/jit/static_analysis/static_analysis.h" +#include "common/trans.h" + +namespace mindspore { +namespace device { +namespace ascend { +AiCoreDynamicKernel::~AiCoreDynamicKernel() { + if (tiling_data_ptr_ != nullptr) { + auto ret = rtFree(tiling_data_ptr_); + if (ret != RT_ERROR_NONE) { + MS_LOG(ERROR) << "rtFree tiling_data_ptr_ failed"; + } + } +} + +void AiCoreDynamicKernel::Execute() { + if (stream_ == nullptr) { + MS_LOG(EXCEPTION) << "stream_ptr should not be nullptr."; + } + MS_LOG(INFO) << "Start Execute node:" << cnode_ptr_->fullname_with_scope(); + rtL2Ctrl_t *l2ctrl = nullptr; + auto args_size = static_cast(UlongToUint(sizeof(void *)) * runtime_args_.size()); + if (RT_ERROR_NONE != rtKernelLaunch(stub_func_, block_dim_, runtime_args_.data(), args_size, l2ctrl, stream_)) { + MS_LOG(EXCEPTION) << "Call runtime rtKernelLaunch error."; + } + MS_LOG(INFO) << "End Execute node:" << cnode_ptr_->fullname_with_scope(); +} + +std::string ReplaceInvalidJsonStr(const std::string &str) { + auto ret = std::regex_replace(str, std::regex("100000000"), R"("100000000")"); + ret = std::regex_replace(ret, std::regex("100000001"), R"("100000001")"); + ret = std::regex_replace(ret, std::regex("100000002"), R"("100000002")"); + ret = std::regex_replace(ret, std::regex("True"), R"(true)"); + ret = std::regex_replace(ret, std::regex("False"), R"(false)"); + return ret; +} + +void AiCoreDynamicKernel::ParseCompileJson() { + if (!AnfAlgo::IsDynamicShape(cnode_ptr_)) { + return; + } + if (!AnfAlgo::HasNodeAttr(kAttrCompileInfo, cnode_ptr_)) { + MS_LOG(EXCEPTION) << "Get compile_info failed"; + } + auto compile_info_attr = AnfAlgo::GetNodeAttr(cnode_ptr_, kAttrCompileInfo); + std::replace(compile_info_attr.begin(), compile_info_attr.end(), '\'', '\"'); + compile_info_attr = ReplaceInvalidJsonStr(compile_info_attr); + MS_LOG(INFO) << "Get compile_info:" << compile_info_attr; + + try { + compile_info_json_ = std::make_shared(nlohmann::json::parse(compile_info_attr)); + } catch (nlohmann::json::parse_error &e) { + MS_LOG(EXCEPTION) << "parse json failed, error:" << e.what(); + } + + if (AnfAlgo::HasNodeAttr(kAttrFusionType, cnode_ptr_)) { + auto fusion_type = AnfAlgo::GetNodeAttr(cnode_ptr_, kAttrFusionType); + MS_LOG(INFO) << "Get fusion_type:" << fusion_type; + (*compile_info_json_)["_pattern"] = fusion_type; + } +} + +void AiCoreDynamicKernel::Initialize() { + DynamicKernel::Initialize(); + ParseCompileJson(); +} + +void AiCoreDynamicKernel::UpdateArgs() { + ComputeTiling(); + + if (!CopyTilingToDevice()) { + MS_LOG(EXCEPTION) << "Copy tiling to device failed"; + } + + AllocateWorkspace(); + + auto kernel_mod = AnfAlgo::GetKernelMod(cnode_ptr_); + MS_EXCEPTION_IF_NULL(kernel_mod); + + AddressPtrList kernel_inputs; + AddressPtrList kernel_workspaces; + AddressPtrList kernel_outputs; + KernelRuntime::GenLaunchArgs(*kernel_mod, cnode_ptr_, &kernel_inputs, &kernel_workspaces, &kernel_outputs); + + runtime_args_.clear(); + + (void)std::transform(std::begin(kernel_inputs), std::end(kernel_inputs), std::back_inserter(runtime_args_), + [](const AddressPtr &input) -> void * { return input->addr; }); 
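  // runtime_args_ is rebuilt on every launch in the order the pointers are passed to
  // rtKernelLaunch: input addresses (above), then output addresses, then any workspace
  // buffers, and finally the device-side tiling data pointer appended at the end.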
+ (void)std::transform(std::begin(kernel_outputs), std::end(kernel_outputs), std::back_inserter(runtime_args_), + [](const AddressPtr &output) -> void * { return output->addr; }); + // Update workspace + if (!workspace_addr_.empty()) { + (void)std::transform(std::begin(workspace_addr_), std::end(workspace_addr_), std::back_inserter(runtime_args_), + [](const DeviceAddressPtr &address_ptr) -> void * { return address_ptr->GetMutablePtr(); }); + } + + if (is_dynamic_shape_ && !tiling_data_.empty() && tiling_data_ptr_ != nullptr) { + runtime_args_.push_back(tiling_data_ptr_); + } +} + +void AiCoreDynamicKernel::ComputeTiling() { + MS_EXCEPTION_IF_NULL(cnode_ptr_); + MS_LOG(INFO) << "Start compute tiling of:" << cnode_ptr_->fullname_with_scope(); + optiling::OpRunInfo op_run_info; + + OpTilingCalculater::GetInstance().CalculateTiling(NOT_NULL(cnode_ptr_), NOT_NULL(compile_info_json_), + depend_tensor_map_, NOT_NULL(&op_run_info)); + block_dim_ = op_run_info.block_dim; + workspaces_size_ = op_run_info.workspaces; + tiling_data_ = op_run_info.tiling_data.str(); +} + +void AiCoreDynamicKernel::AllocateWorkspace() { + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + auto device_id = ms_context->get_param(MS_CTX_DEVICE_ID); + auto runtime_instance = KernelRuntimeManager::Instance().GetSingleKernelRuntime(kAscendDevice, device_id); + MS_EXCEPTION_IF_NULL(runtime_instance); + + workspace_addr_.clear(); + for (auto size : workspaces_size_) { + auto device_address_ptr = std::make_shared(nullptr, size); + auto device_ptr = runtime_instance->MallocMem(MemType::kDynamicMem, size, device_address_ptr); + if (device_ptr == nullptr) { + MS_LOG(EXCEPTION) << "MallocMem from memory pool failed"; + } + workspace_addr_.emplace_back(device_address_ptr); + } +} + +bool AiCoreDynamicKernel::CopyTilingToDevice() { + if (tiling_data_.size() > op_para_size_) { + MS_LOG(EXCEPTION) << "compute tiling size:" << tiling_data_.size() + << " larger than tbe build op_para_size:" << op_para_size_; + } + + if (tiling_data_.empty() || tiling_data_ptr_ == nullptr) { + MS_LOG(INFO) << "tiling size is 0, skip rtMemcpyAsync"; + return true; + } + + auto ret = rtMemcpyAsync(tiling_data_ptr_, tiling_data_.size(), tiling_data_.c_str(), tiling_data_.size(), + RT_MEMCPY_HOST_TO_DEVICE_EX, stream_); + if (ret != RT_ERROR_NONE) { + MS_LOG(EXCEPTION) << "tiling rtMemcpyAsync failed, ret:" << ret; + } + return true; +} + +void AiCoreDynamicKernel::PostExecute() {} +} // namespace ascend +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/ascend/executor/ai_core_dynamic_kernel.h b/mindspore/ccsrc/runtime/device/ascend/executor/ai_core_dynamic_kernel.h new file mode 100644 index 0000000000..218ecccb50 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/executor/ai_core_dynamic_kernel.h @@ -0,0 +1,70 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_AI_CORE_DYNAMIC_KERNEL_H_ +#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_AI_CORE_DYNAMIC_KERNEL_H_ + +#include +#include +#include +#include +#include "nlohmann/json.hpp" +#include "ir/tensor.h" +#include "runtime/device/device_address.h" +#include "mindspore/ccsrc/runtime/device/executor/dynamic_kernel.h" + +namespace mindspore { +namespace device { +namespace ascend { +class AiCoreDynamicKernel : public DynamicKernel { + public: + AiCoreDynamicKernel(const void *stub_fubc, uint32_t block_dim, void *tiling_data_ptr, uint32_t op_para_size, + void *stream, const CNodePtr &cnode_ptr, const std::vector &runtime_args) + : DynamicKernel(stream, cnode_ptr), + stub_func_(stub_fubc), + block_dim_(block_dim), + tiling_data_ptr_(tiling_data_ptr), + op_para_size_(op_para_size), + runtime_args_(runtime_args) {} + ~AiCoreDynamicKernel() override; + + void Execute() override; + void UpdateArgs() override; + void Initialize() override; + void PostExecute() override; + + protected: + void AllocateWorkspace(); + void ParseCompileJson(); + + private: + const void *stub_func_; + uint32_t block_dim_; + void *tiling_data_ptr_; // device ptr + uint32_t op_para_size_; // size of tiling_data_ptr_ + std::vector runtime_args_; + std::string tiling_data_; + std::vector workspaces_size_; + std::vector workspace_addr_; + std::shared_ptr compile_info_json_; + + void ComputeTiling(); + bool CopyTilingToDevice(); +}; +} // namespace ascend +} // namespace device +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_AI_CORE_DYNAMIC_KERNEL_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/executor/ai_cpu_dynamic_kernel.cc b/mindspore/ccsrc/runtime/device/ascend/executor/ai_cpu_dynamic_kernel.cc new file mode 100644 index 0000000000..c083a88ea0 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/executor/ai_cpu_dynamic_kernel.cc @@ -0,0 +1,204 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "runtime/device/ascend/executor/ai_cpu_dynamic_kernel.h" +#include +#include +#include +#include "runtime/mem.h" +#include "runtime/kernel.h" +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/kernel_compiler/aicpu/aicpu_util.h" +#include "runtime/device/ascend/executor/executor_callback.h" + +namespace mindspore { +namespace device { +namespace ascend { +AiCpuDynamicKernel::~AiCpuDynamicKernel() { + // free dev ptr + if (ext_info_addr_dev_ == nullptr) { + return; + } + auto ret = rtFree(ext_info_addr_dev_); + if (ret != RT_ERROR_NONE) { + MS_LOG(ERROR) << "rtFree failed"; + } +} + +void AiCpuDynamicKernel::UpdateArgs() { + if (!UpdateInputOutputAddr()) { + MS_LOG(EXCEPTION) << "Update input output failed"; + } + + if (is_dynamic_shape_ && !UpdateExtInfo()) { + MS_LOG(EXCEPTION) << "Update ExtInfo failed"; + } +} + +void AiCpuDynamicKernel::Execute() { + MS_LOG(INFO) << "Execute AiCpuDynamicKerenl Start"; + auto ret = rtCpuKernelLaunchWithFlag( + reinterpret_cast(so_name_.c_str()), reinterpret_cast(kernel_name_.c_str()), 1, + reinterpret_cast(args_.data()), args_.size(), nullptr, stream_, RT_KERNEL_DEFAULT); + if (ret != RT_ERROR_NONE) { + MS_LOG(EXCEPTION) << "Call rtCpuKernelLaunchWithFlag Failed"; + } +} + +void AiCpuDynamicKernel::Initialize() { + // is dynamic + MS_LOG(INFO) << "Initialize node:" << cnode_ptr_->fullname_with_scope(); + DynamicKernel::Initialize(); + + input_num_ = AnfAlgo::GetInputTensorNum(cnode_ptr_); + output_num_ = AnfAlgo::GetOutputTensorNum(cnode_ptr_); + + // Parse aicpu ext info + if (is_dynamic_shape_) { + MS_EXCEPTION_IF_NULL(cnode_ptr_); + ext_info_handler_ = + std::make_shared(cnode_ptr_->fullname_with_scope(), input_num_, output_num_, DEPEND_COMPUTE); + ext_info_handler_->Parse(ext_info_data_); + } + + if (ext_info_data_.empty()) { + MS_LOG(INFO) << "No need to copy to device, ext_info_data_ is empty. 
"; + return; + } + + // Allocate ext info addr in device + auto ret = rtMalloc(&ext_info_addr_dev_, ext_info_data_.size(), RT_MEMORY_HBM); + if (ret != RT_ERROR_NONE) { + MS_LOG(EXCEPTION) << "Call rtMalloc ext_info_addr_dev_ failed"; + } + ext_info_size_ = ext_info_data_.size(); + + ret = rtMemcpy(ext_info_addr_dev_, ext_info_size_, ext_info_data_.data(), ext_info_data_.size(), + RT_MEMCPY_HOST_TO_DEVICE); + if (ret != RT_ERROR_NONE) { + MS_LOG(EXCEPTION) << "Call rtMemcpy ext_info_addr_dev_ failed"; + } + + auto aicpu_param_head = reinterpret_cast(args_.data()); + aicpu_param_head->extInfoLength = ext_info_size_; + aicpu_param_head->extInfoAddr = reinterpret_cast(ext_info_addr_dev_); +} + +bool AiCpuDynamicKernel::UpdateInputOutputAddr() { + std::vector io_addrs; + io_addrs.reserve(input_num_ + output_num_); + + for (size_t i = 0; i < input_num_; ++i) { + auto input_addr = AnfAlgo::GetPrevNodeOutputAddr(cnode_ptr_, i); + io_addrs.emplace_back(reinterpret_cast(input_addr->GetMutablePtr())); + } + + for (size_t i = 0; i < output_num_; ++i) { + auto output_addr = AnfAlgo::GetOutputAddr(cnode_ptr_, i); + io_addrs.emplace_back(reinterpret_cast(output_addr->GetMutablePtr())); + } + + if (args_.empty()) { + MS_LOG(ERROR) << "args_ is empty"; + return false; + } + + auto io_ptr = args_.data() + sizeof(kernel::AicpuParamHead); + auto ret = + memcpy_s(io_ptr, args_.size() - sizeof(kernel::AicpuParamHead), &io_addrs[0], sizeof(uint64_t) * io_addrs.size()); + if (ret != 0) { + MS_LOG(EXCEPTION) << "Memcpy input output addr failed"; + } + + return true; +} + +bool AiCpuDynamicKernel::UpdateExtInfo() { + MS_LOG(INFO) << "UpdateExtInfo of " << cnode_ptr_->fullname_with_scope() << " start"; + if (input_num_ == 0 && output_num_ == 0) { + MS_LOG(INFO) << "Node:" << cnode_ptr_->fullname_with_scope() << " no need to update output shape"; + return true; + } + + for (size_t i = 0; i < input_num_; ++i) { + ext_info_handler_->UpdateInputShapeAndType(i, NOT_NULL(cnode_ptr_)); + } + + if (unknow_type_ != DEPEND_COMPUTE) { + for (size_t i = 0; i < output_num_; ++i) { + ext_info_handler_->UpdateOutputShapeAndType(i, NOT_NULL(cnode_ptr_)); + } + } + + auto ret = rtMemcpy(ext_info_addr_dev_, ext_info_size_, ext_info_handler_->GetExtInfo(), + ext_info_handler_->GetExtInfoLen(), RT_MEMCPY_HOST_TO_DEVICE); + if (ret != RT_ERROR_NONE) { + MS_LOG(ERROR) << "UpdateExtInfo rtMemcpy failed"; + return false; + } + + MS_LOG(INFO) << "UpdateExtInfo of " << cnode_ptr_->fullname_with_scope() << " end"; + return true; +} + +bool AiCpuDynamicKernel::UpdateOutputShapeFromExtInfo() { + if (input_num_ == 0) { + MS_LOG(WARNING) << "input num is 0"; + return true; + } + MS_LOG(INFO) << "UpdateOutputShapeFromExtInfo start"; + auto ret = rtMemcpy(ext_info_handler_->GetExtInfo(), ext_info_handler_->GetExtInfoLen(), ext_info_addr_dev_, + ext_info_size_, RT_MEMCPY_DEVICE_TO_HOST); + if (ret != RT_ERROR_NONE) { + MS_LOG(ERROR) << "rtMemcpy output shape failed"; + return false; + } + + MS_LOG(INFO) << "rtMemcpy from device to host success"; + + std::vector type_ids; + std::vector> shapes; + + for (size_t i = 0; i < output_num_; ++i) { + MS_LOG(INFO) << "Get output:" << output_num_ << " Shape"; + std::vector shape; + TypeId type_id; + ext_info_handler_->GetOutputShapeAndType(i, NOT_NULL(&shape), NOT_NULL(&type_id)); + + for (auto x : shape) { + MS_LOG(INFO) << "Update output:" << i << " shape:" << x; + } + + type_ids.emplace_back(type_id); + std::vector size_t_shape; + std::transform(shape.begin(), shape.end(), std::back_inserter(size_t_shape), 
LongToSize); + shapes.emplace_back(size_t_shape); + } + + AnfAlgo::SetOutputInferTypeAndShape(type_ids, shapes, cnode_ptr_.get()); + return true; +} + +void AiCpuDynamicKernel::PostExecute() { + MS_LOG(INFO) << "Aicpu " << cnode_ptr_->fullname_with_scope() << " PostExecute"; + if (AnfAlgo::IsDynamicShape(cnode_ptr_) && unknow_type_ == DEPEND_COMPUTE) { + MS_LOG(INFO) << "Update aicpu kernel output shape from ext_info"; + UpdateOutputShapeFromExtInfo(); + } +} +} // namespace ascend +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/ascend/executor/ai_cpu_dynamic_kernel.h b/mindspore/ccsrc/runtime/device/ascend/executor/ai_cpu_dynamic_kernel.h new file mode 100644 index 0000000000..5eedf097b8 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/executor/ai_cpu_dynamic_kernel.h @@ -0,0 +1,76 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_AI_CPU_DYNAMIC_KERNEL_H_ +#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_AI_CPU_DYNAMIC_KERNEL_H_ + +#include +#include +#include "runtime/device/executor/dynamic_kernel.h" +#include "ir/anf.h" +#include "runtime/device/ascend/executor/aicpu_ext_info_handle.h" + +namespace mindspore { +namespace device { +namespace ascend { +class AiCpuDynamicKernel : public DynamicKernel { + public: + AiCpuDynamicKernel(void *stream, const CNodePtr &cnode_ptr, const std::string &args, const std::string &ext_info_data, + const std::string &so_name, const std::string &kernel_name) + : DynamicKernel(stream, cnode_ptr), + args_(args), + ext_info_data_(ext_info_data), + so_name_(so_name), + kernel_name_(kernel_name), + ext_info_handler_(nullptr), + ext_info_addr_dev_(nullptr), + ext_info_size_(0), + input_num_(0), + output_num_(0), + unknow_type_(DEPEND_COMPUTE) {} + + ~AiCpuDynamicKernel() override; + + void UpdateArgs() override; + void Execute() override; + void Initialize() override; + void PostExecute() override; + + // Get Compute Shape from ExtInfo + bool UpdateOutputShapeFromExtInfo(); + + private: + std::string args_; + std::string ext_info_data_; + std::string so_name_; + std::string kernel_name_; + + std::shared_ptr ext_info_handler_; + void *ext_info_addr_dev_; + size_t ext_info_size_; + + size_t input_num_; + size_t output_num_; + + UnknowShapeOpType unknow_type_; + + bool UpdateInputOutputAddr(); + bool UpdateExtInfo(); +}; +} // namespace ascend +} // namespace device +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_AI_CPU_DYNAMIC_KERNEL_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/executor/aicpu_ext_info_handle.cc b/mindspore/ccsrc/runtime/device/ascend/executor/aicpu_ext_info_handle.cc new file mode 100644 index 0000000000..7933d49956 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/executor/aicpu_ext_info_handle.cc @@ -0,0 +1,218 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * 
Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/ascend/executor/aicpu_ext_info_handle.h" +#include +#include "backend/session/anf_runtime_algorithm.h" +#include "backend/kernel_compiler/aicpu/aicpu_util.h" + +namespace mindspore { +namespace device { +namespace ascend { +namespace { +// if dim count is not reach kMaxShapeDims(8), use INT64_MIN to mark dim end. +constexpr int64_t kDimEndFlag = INT64_MIN; +} // namespace +bool AicpuExtInfoHandler::Parse(const std::string &ext_info) { + MS_LOG(INFO) << "Parse Node:" << node_name_ << " start"; + if (ext_info.empty()) { + MS_LOG(ERROR) << "Node:" << node_name_ << " ext_info is empty"; + return false; + } + + ext_info_len_ = ext_info.size(); + ext_info_.reset(new (std::nothrow) uint8_t[ext_info_len_]); + MS_EXCEPTION_IF_NULL(ext_info_); + + (void)memcpy_s(ext_info_.get(), ext_info_len_, ext_info.c_str(), ext_info.size()); + + input_shape_and_type_.clear(); + output_shape_and_type_.clear(); + + auto ext_info_data = ext_info_.get(); + size_t offset = 0; + while (offset + sizeof(AicpuExtInfo) <= ext_info_len_) { + auto aicpu_ext_info = reinterpret_cast(ext_info_data + offset); + MS_EXCEPTION_IF_NULL(aicpu_ext_info); + switch (aicpu_ext_info->infoType) { + case kernel::FWK_ADPT_EXT_SHAPE_TYPE: + if (!ParseExtShapeType(aicpu_ext_info)) { + MS_LOG(EXCEPTION) << "Parse ext shape type failed."; + } + break; + case kernel::FWK_ADPT_EXT_INPUT_SHAPE: + if (!ParseExtInputShape(aicpu_ext_info)) { + MS_LOG(EXCEPTION) << "Parse ext input shape failed."; + } + break; + case kernel::FWK_ADPT_EXT_OUTPUT_SHAPE: + if (!ParseExtOutputShape(aicpu_ext_info)) { + MS_LOG(EXCEPTION) << "Parse ext output shape failed."; + } + break; + default: + MS_LOG(INFO) << "Ignore Node:" << node_name_ << " infoType:" << aicpu_ext_info->infoType + << " infoLen:" << aicpu_ext_info->infoLen; + break; + } + offset += sizeof(AicpuExtInfo); + offset += aicpu_ext_info->infoLen; + } + + if (offset != ext_info_len_) { + MS_LOG(EXCEPTION) << "Node:" << node_name_ << " ext_info format error, parse not reach end, offset=" << offset + << ", ext_info_len" << ext_info_len_; + } + MS_LOG(INFO) << "Node:" << node_name_ << " parse ext info end."; + return true; +} + +bool AicpuExtInfoHandler::ParseExtShapeType(AicpuExtInfo *aicpu_ext_info) { + if (aicpu_ext_info->infoLen != sizeof(int32_t)) { + MS_LOG(ERROR) << "Node:" << node_name_ << " parse ext shape type failed as infoLen must be " << sizeof(int32_t) + << " but got:" << aicpu_ext_info->infoLen; + return false; + } + + auto type = reinterpret_cast(aicpu_ext_info->infoMsg); + + if (*type != unknown_type_) { + MS_LOG(ERROR) << "Node:" << node_name_ << " parse ext shape type failed as need:" << unknown_type_ + << " but got:" << *type; + } + MS_LOG(INFO) << "Node:" << node_name_ << "parse ext shape type success infoLen=" << aicpu_ext_info->infoLen; + return true; +} + +bool AicpuExtInfoHandler::ParseExtInputShape(AicpuExtInfo *aicpu_ext_info) { + auto need_len = input_num_ * sizeof(AicpuShapeAndType); + + if 
(aicpu_ext_info->infoLen != need_len) { + MS_LOG(ERROR) << "Node:" << node_name_ + << " parse ext input shape failed as aicpu_ext_info->infoLen:" << aicpu_ext_info->infoLen + << " and need_len:" << need_len; + } + auto input = reinterpret_cast(aicpu_ext_info->infoMsg); + + for (uint32_t index = 0; index < input_num_; ++index) { + input_shape_and_type_.emplace_back(&input[index]); + } + MS_LOG(INFO) << "Node:" << node_name_.c_str() << " parse ext input shape success infoLen=" << aicpu_ext_info->infoLen; + return true; +} + +bool AicpuExtInfoHandler::ParseExtOutputShape(AicpuExtInfo *aicpu_ext_info) { + auto need_len = output_num_ * sizeof(AicpuShapeAndType); + if (aicpu_ext_info->infoLen != need_len) { + MS_LOG(INFO) << "Node:" << node_name_ + << " parse ext output shape failed, aicpu_ext_info->infoLen:" << aicpu_ext_info->infoLen + << " need_len:" << need_len; + return false; + } + + auto output = reinterpret_cast(aicpu_ext_info->infoMsg); + for (uint32_t index = 0; index < output_num_; ++index) { + output_shape_and_type_.emplace_back(&output[index]); + } + MS_LOG(INFO) << "Node:" << node_name_ << " parse ext output shape success infoLen=" << aicpu_ext_info->infoLen; + return true; +} + +bool AicpuExtInfoHandler::UpdateInputShapeAndType(uint32_t input_index, const NotNull &anf_node) { + if (input_index >= input_num_) { + MS_LOG(ERROR) << "input_index=" << input_index << " >= input_num_:" << input_num_; + return false; + } + + auto input_shape = AnfAlgo::GetInputDeviceShape(anf_node, input_index); + auto data_type = AnfAlgo::GetInputDeviceDataType(anf_node, input_index); + std::vector tmp_shape; + std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(tmp_shape), SizeToLong); + return UpdateShapeAndType(tmp_shape, data_type, NOT_NULL(input_shape_and_type_[input_index])); +} + +bool AicpuExtInfoHandler::UpdateOutputShapeAndType(uint32_t output_index, const NotNull &anf_node) { + if (output_index >= output_num_) { + MS_LOG(ERROR) << "output_index:" << output_index << " >= output_num_:" << output_num_; + return false; + } + + auto shape = AnfAlgo::GetOutputDeviceShape(anf_node, output_index); + auto max_shape = AnfAlgo::GetOutputMaxShape(anf_node, output_index); + if (shape.size() != max_shape.size()) { + MS_LOG(ERROR) << "shape size != max_shape size"; + return true; + } + + for (size_t i = 0; i < shape.size(); ++i) { + if (i < max_shape.size() && shape[i] == SIZE_MAX) { + MS_LOG(INFO) << "Node:" << node_name_ << " update shape from SIZE_MAX to " << max_shape[i]; + shape[i] = max_shape[i]; + } + } + + std::vector tmp_shape; + std::transform(shape.begin(), shape.end(), std::back_inserter(tmp_shape), SizeToLong); + return UpdateShapeAndType(tmp_shape, AnfAlgo::GetOutputDeviceDataType(anf_node, output_index), + NOT_NULL(output_shape_and_type_[output_index])); +} + +bool AicpuExtInfoHandler::GetOutputShapeAndType(uint32_t output_index, NotNull *> shape, + NotNull data_type) { + MS_LOG(INFO) << "Get " << node_name_ << " Output:" << output_index << " Shape And Type"; + GetShapeAndType(NOT_NULL(output_shape_and_type_[output_index]), shape, data_type); + return true; +} + +bool AicpuExtInfoHandler::UpdateShapeAndType(const std::vector &shape, TypeId data_type, + NotNull shape_and_type) { + if (shape.empty() || shape.size() > kernel::kMaxShapeDims) { + MS_LOG(ERROR) << "Invalid shape:" << shape.size(); + return false; + } + + size_t index = 0; + for (; index < shape.size(); ++index) { + shape_and_type->dims[index] = shape[index]; + } + if (index < kernel::kMaxShapeDims) { + 
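      // dims[] has kMaxShapeDims slots; when the real rank is smaller, the first unused
      // slot is stamped with kDimEndFlag (INT64_MIN) so GetShapeAndType can tell where
      // the shape ends when the struct is read back from device memory.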
shape_and_type->dims[index] = kDimEndFlag; + } + + // now only support update shape, type is not support + return true; +} + +void AicpuExtInfoHandler::GetShapeAndType(NotNull shape_and_type, + NotNull *> shape, NotNull data_type) { + for (int64_t tmpDim : shape_and_type->dims) { + if (tmpDim == kDimEndFlag) { + break; + } + shape->emplace_back(tmpDim); + MS_LOG(INFO) << "Debug tmpDim:" << tmpDim; + } + + auto ms_type = kernel::AicpuOpUtil::ProtoTypeToMsType(shape_and_type->type); + if (ms_type == -1) { + MS_LOG(EXCEPTION) << "Unspport Proto Type:" << shape_and_type->type; + } + MS_LOG(INFO) << "Debug ms_type:" << ms_type; + *data_type = static_cast(ms_type); +} +} // namespace ascend +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/ascend/executor/aicpu_ext_info_handle.h b/mindspore/ccsrc/runtime/device/ascend/executor/aicpu_ext_info_handle.h new file mode 100644 index 0000000000..641d1d3f9c --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/executor/aicpu_ext_info_handle.h @@ -0,0 +1,88 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_AICPU_EXT_INFO_HANDLE_H_ +#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_AICPU_EXT_INFO_HANDLE_H_ + +#include +#include +#include +#include +#include "backend/kernel_compiler/aicpu/aicpu_util.h" +#include "utils/contract.h" + +namespace mindspore { +namespace device { +namespace ascend { +// for unknown shape op type +enum UnknowShapeOpType { + DEPEND_IN_SHAPE = 1, // op out shape get by input shape + DEPEND_CONST_VALUE = 2, // op out shape get by const op value + DEPEND_SHAPE_RANGE = 3, // op out shape get by range + DEPEND_COMPUTE = 4 // op out shape get by totally computing +}; + +using AicpuShapeAndType = kernel::ShapeAndType; +using AicpuExtInfo = kernel::ExtInfo; + +class AicpuExtInfoHandler { + public: + AicpuExtInfoHandler(std::string node_name, uint32_t input_num, uint32_t output_num, UnknowShapeOpType unknown_type) + : node_name_(std::move(node_name)), + input_num_(input_num), + output_num_(output_num), + unknown_type_(unknown_type), + ext_info_len_(0) {} + + ~AicpuExtInfoHandler() = default; + + uint8_t *GetExtInfo() const { return ext_info_.get(); } + size_t GetExtInfoLen() const { return ext_info_len_; } + + bool Parse(const std::string &ext_info); + + bool UpdateInputShapeAndType(uint32_t input_index, const NotNull &anf_node); + + bool UpdateOutputShapeAndType(uint32_t output_index, const NotNull &anf_node); + + bool GetOutputShapeAndType(uint32_t output_index, NotNull *> shape, NotNull data_type); + + private: + bool ParseExtShapeType(AicpuExtInfo *aicpu_ext_info); + bool ParseExtInputShape(AicpuExtInfo *aicpu_ext_info); + bool ParseExtOutputShape(AicpuExtInfo *aicpu_ext_info); + + static bool UpdateShapeAndType(const std::vector &shape, TypeId data_type, + NotNull shape_and_type); + + static void GetShapeAndType(NotNull shape_and_type, 
NotNull *> shape, + NotNull data_type); + + private: + const std::string node_name_; + const uint32_t input_num_; + const uint32_t output_num_; + UnknowShapeOpType unknown_type_; + size_t ext_info_len_; + + std::unique_ptr ext_info_; + std::vector input_shape_and_type_; + std::vector output_shape_and_type_; +}; +} // namespace ascend +} // namespace device +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_AICPU_EXT_INFO_HANDLE_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/executor/executor_callback.cc b/mindspore/ccsrc/runtime/device/ascend/executor/executor_callback.cc new file mode 100644 index 0000000000..33d4bb08e0 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/executor/executor_callback.cc @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/ascend/executor/executor_callback.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace device { +namespace ascend { +void ExecutorCallback::RegistCallback(const std::function &callback) { + std::lock_guard guard(lock_); + callback_queue_.push(callback); +} + +void ExecutorCallback::Consume() { + std::lock_guard guard(lock_); + while (!callback_queue_.empty()) { + auto callback_func = callback_queue_.front(); + callback_queue_.pop(); + if (!callback_func) { + MS_LOG(EXCEPTION) << "callback_func is empty"; + } + callback_func(); + } +} +} // namespace ascend +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/ascend/executor/executor_callback.h b/mindspore/ccsrc/runtime/device/ascend/executor/executor_callback.h new file mode 100644 index 0000000000..2994f9b70e --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/executor/executor_callback.h @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_EXECUTOR_CALLBACK_H_ +#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_EXECUTOR_CALLBACK_H_ + +#include +#include +#include +#include "utils/ms_utils.h" + +namespace mindspore { +namespace device { +namespace ascend { +class ExecutorCallback { + public: + static ExecutorCallback &GetInstance() { + static ExecutorCallback instance; + return instance; + } + + void RegistCallback(const std::function &callback); + void Consume(); + + private: + ExecutorCallback() = default; + ~ExecutorCallback() = default; + DISABLE_COPY_AND_ASSIGN(ExecutorCallback); + + std::queue> callback_queue_; + std::mutex lock_; +}; +} // namespace ascend +} // namespace device +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_EXECUTOR_CALLBACK_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/executor/hccl_dynamic_kernel.cc b/mindspore/ccsrc/runtime/device/ascend/executor/hccl_dynamic_kernel.cc new file mode 100644 index 0000000000..de6bd7b985 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/executor/hccl_dynamic_kernel.cc @@ -0,0 +1,187 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "runtime/device/ascend/executor/hccl_dynamic_kernel.h" + +#include +#include +#include "hccl/hcom.h" +#include "common/opskernel/ge_task_info.h" +#include "utils/log_adapter.h" +#include "runtime/device/kernel_runtime.h" +#include "backend/kernel_compiler/hccl/hcom_util.h" + +namespace { +// Find so in RPATH or LD_LIBRARY_PATH (/usr/local/Ascend/fwkacllib/lib64/) +constexpr auto kHcomGraphAdaptorPath = "libhcom_graph_adaptor.so"; +} // namespace + +namespace mindspore { +namespace device { +namespace ascend { +void HcclDynamicKernel::UpdateArgs() { + if (!is_dynamic_shape_) { + MS_LOG(INFO) << "Not Dynamic Shape"; + return; + } + MS_LOG(INFO) << "Start to UpdateArgs"; + auto kernel_mod = AnfAlgo::GetKernelMod(cnode_ptr_); + MS_EXCEPTION_IF_NULL(kernel_mod); + // Update input, output, count + AddressPtrList kernel_inputs; + AddressPtrList kernel_workspaces; + AddressPtrList kernel_outputs; + KernelRuntime::GenLaunchArgs(*kernel_mod, cnode_ptr_, &kernel_inputs, &kernel_workspaces, &kernel_outputs); + if (kernel_inputs.empty() || kernel_outputs.empty()) { + MS_LOG(EXCEPTION) << "Inputs or outputs is empty"; + } + auto input0 = kernel_inputs.at(0); + auto output0 = kernel_outputs.at(0); + MS_EXCEPTION_IF_NULL(input0); + MS_EXCEPTION_IF_NULL(output0); + + // Update Hccl input and output + input_ptr_ = input0->addr; + output_ptr_ = output0->addr; + + std::vector> hccl_kernel_input_shape_list; + if (!HcomUtil::GetKernelInputShape(cnode_ptr_, &hccl_kernel_input_shape_list)) { + MS_LOG(EXCEPTION) << "GetKernelInputShape fail!"; + } + + std::vector hccl_data_type_list; + if (!HcomUtil::GetHcomDataType(cnode_ptr_, &hccl_data_type_list)) { + MS_LOG(EXCEPTION) << "GetHcomDataType fail!"; + } + + // Update Hccl count + if (!HcomUtil::GetHcomCount(cnode_ptr_, hccl_data_type_list, hccl_kernel_input_shape_list, &count_)) { + MS_LOG(EXCEPTION) << "GetHcomCount fail!"; + } + MS_LOG(INFO) << "Update Hccl count:" << count_; +} + +void HcclDynamicKernel::StaticShapeExecute() { + MS_EXCEPTION_IF_NULL(cnode_ptr_); + auto kernel_mod = AnfAlgo::GetKernelMod(cnode_ptr_); + MS_EXCEPTION_IF_NULL(kernel_mod); + AddressPtrList kernel_inputs; + AddressPtrList kernel_workspaces; + AddressPtrList kernel_outputs; + KernelRuntime::GenLaunchArgs(*kernel_mod, cnode_ptr_, &kernel_inputs, &kernel_workspaces, &kernel_outputs); + kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_); +} + +void HcclDynamicKernel::Execute() { + MS_LOG(INFO) << "Start Execute"; + if (!is_dynamic_shape_) { + MS_LOG(INFO) << "Not Dynamic, call hcom api"; + StaticShapeExecute(); + return; + } + auto handle = HcclExecutorManager::GetInstance().handle(); + auto EnqueueHcomOperation = + (HcclResult(*)(ge::HcomOpertion, std::function))dlsym(handle, "EnqueueHcomOpertion"); + if (EnqueueHcomOperation == nullptr) { + MS_LOG(ERROR) << "Failed to get EnqueueHcomOperation function"; + if (dlclose(handle) != 0) { + MS_LOG(WARNING) << "Failed to close hcom handle"; + } + MS_LOG(EXCEPTION) << "Hccl dynamic kernel execute failed"; + return; + } + + ge::HcomOpertion op_info; + op_info.hcclType = hccl_type_; + op_info.inputPtr = input_ptr_; + op_info.outputPtr = output_ptr_; + op_info.dataType = data_type_; + op_info.opType = op_type_; + op_info.root = root_; + op_info.count = count_; + + auto callback = [this](HcclResult status) { + if (status != HCCL_SUCCESS) { + MS_LOG(ERROR) << "HcomExcutorInitialize failed, ret:" << status; + } + std::lock_guard lock(this->hccl_mutex_); + this->cond_.notify_all(); + MS_LOG(INFO) << "hccl 
callback success."; + }; + + auto hccl_ret = EnqueueHcomOperation(op_info, callback); + if (hccl_ret != HCCL_SUCCESS) { + MS_LOG(EXCEPTION) << "Call EnqueueHcomOperation failed"; + } + + std::unique_lock<std::mutex> ulock(hccl_mutex_); + cond_.wait(ulock); + MS_LOG(INFO) << "Execute success"; +} + +void HcclDynamicKernel::PostExecute() {} + +bool HcclExecutorManager::Initialize() { + if (initialized_) { + return true; + } + initialized_ = true; + MS_LOG(INFO) << "Start Initialize Hccl DynamicKernel"; + handle_ = dlopen(kHcomGraphAdaptorPath, RTLD_NOW | RTLD_GLOBAL); + if (handle_ == nullptr) { + MS_LOG(ERROR) << "dlopen failed, path:" << kHcomGraphAdaptorPath; + return false; + } + + auto HcomExecutorInitialize = (HcclResult(*)())dlsym(handle_, "HcomExcutorInitialize"); + if (HcomExecutorInitialize == nullptr) { + MS_LOG(ERROR) << "dlsym HcomExecutorInitialize failed"; + return false; + } + + HcclResult hccl_ret = HcomExecutorInitialize(); + if (hccl_ret == HCCL_E_PTR) { + MS_LOG(WARNING) << "Hccl comm is null, hcom executor initialize is not required"; + } else if (hccl_ret == HCCL_SUCCESS) { + MS_LOG(INFO) << "Hcom DynamicKernel Initialize success"; + } else { + MS_LOG(ERROR) << "Hcom DynamicKernel Initialize failed"; + return false; + } + return true; +} + +bool HcclExecutorManager::Finalize() { + auto HcomExecutorFinalize = (HcclResult(*)())dlsym(handle_, "HcomExcutorFinalize"); + if (HcomExecutorFinalize == nullptr) { + MS_LOG(ERROR) << "Failed to dlsym HcomExecutorFinalize"; + return false; + } + HcclResult hccl_ret = HcomExecutorFinalize(); + if (hccl_ret != HCCL_SUCCESS) { + MS_LOG(ERROR) << "Hcom DynamicKernel Finalize failed"; + return false; + } + if (dlclose(handle_) != 0) { + MS_LOG(ERROR) << "Failed to close hcom handle"; + return false; + } + MS_LOG(INFO) << "Hccl DynamicKernel Finalize success"; + return true; +} +} // namespace ascend +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/ascend/executor/hccl_dynamic_kernel.h b/mindspore/ccsrc/runtime/device/ascend/executor/hccl_dynamic_kernel.h new file mode 100644 index 0000000000..b164bbd986 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/executor/hccl_dynamic_kernel.h @@ -0,0 +1,82 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_HCCL_DYNAMIC_KERNEL_H_ +#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_HCCL_DYNAMIC_KERNEL_H_ + +#include +#include +#include "runtime/device/executor/dynamic_kernel.h" + +#include "utils/ms_utils.h" + +namespace mindspore { +namespace device { +namespace ascend { +class HcclDynamicKernel : public DynamicKernel { + public: + HcclDynamicKernel(const std::string &hccl_type, void *input_ptr, void *output_ptr, uint64_t count, int32_t data_type, + int32_t op_type, int32_t root, void *stream, const CNodePtr &cnode_ptr) + : DynamicKernel(stream, cnode_ptr), + hccl_type_(hccl_type), + input_ptr_(input_ptr), + output_ptr_(output_ptr), + count_(count), + data_type_(data_type), + op_type_(op_type), + root_(root) {} + ~HcclDynamicKernel() override = default; + void UpdateArgs() override; + void Execute() override; + void PostExecute() override; + + private: + std::string hccl_type_; + void *input_ptr_; + void *output_ptr_; + uint64_t count_{0}; + int32_t data_type_{0}; + int32_t op_type_{0}; + int32_t root_{0}; + std::mutex hccl_mutex_; + std::condition_variable cond_; + + void StaticShapeExecute(); +}; + +class HcclExecutorManager { + public: + static HcclExecutorManager &GetInstance() { + static HcclExecutorManager instance; + return instance; + } + + bool Initialize(); + bool Finalize(); + void *handle() { return handle_; } + + private: + HcclExecutorManager() = default; + ~HcclExecutorManager() = default; + DISABLE_COPY_AND_ASSIGN(HcclExecutorManager); + + void *handle_{nullptr}; + bool initialized_{false}; +}; +} // namespace ascend +} // namespace device +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_HCCL_DYNAMIC_KERNEL_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/executor/host_dynamic_kernel.h b/mindspore/ccsrc/runtime/device/ascend/executor/host_dynamic_kernel.h new file mode 100644 index 0000000000..96b8fffa0b --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/executor/host_dynamic_kernel.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_HOST_DYNAMIC_KERNEL_H_ +#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_HOST_DYNAMIC_KERNEL_H_ + +#include "runtime/device/executor/dynamic_kernel.h" + +namespace mindspore { +namespace device { +namespace ascend { +class HostDynamicKernel : public DynamicKernel { + public: + HostDynamicKernel(void *stream, const CNodePtr &cnode_ptr) : DynamicKernel(stream, cnode_ptr) {} + ~HostDynamicKernel() override = default; + void UpdateArgs() override {} + void Execute() override = 0; + void PostExecute() override {} +}; +} // namespace ascend +} // namespace device +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_HOST_DYNAMIC_KERNEL_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/executor/rts/memcpy_rts_dynamic_kernel.cc b/mindspore/ccsrc/runtime/device/ascend/executor/rts/memcpy_rts_dynamic_kernel.cc new file mode 100644 index 0000000000..12c5785bc3 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/executor/rts/memcpy_rts_dynamic_kernel.cc @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/ascend/executor/rts/memcpy_rts_dynamic_kernel.h" + +#include "runtime/mem.h" + +namespace mindspore { +namespace device { +namespace ascend { +void MemcpyRtsDynamicKernel::Execute() { + auto status = rtMemcpyAsync(dst_, dest_max_, src_, count_, RT_MEMCPY_DEVICE_TO_DEVICE, stream_); + if (status != RT_ERROR_NONE) { + MS_LOG(EXCEPTION) << "MemCpyAsync op rtMemcpyAsync failed!"; + } +} +} // namespace ascend +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/ascend/executor/rts/memcpy_rts_dynamic_kernel.h b/mindspore/ccsrc/runtime/device/ascend/executor/rts/memcpy_rts_dynamic_kernel.h new file mode 100644 index 0000000000..c8c2cb1dc2 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/executor/rts/memcpy_rts_dynamic_kernel.h @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_RTS_MEMCPY_RTS_DYNAMIC_KERNEL_H_ +#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_RTS_MEMCPY_RTS_DYNAMIC_KERNEL_H_ + +#include "runtime/device/executor/dynamic_kernel.h" + +namespace mindspore { +namespace device { +namespace ascend { +class MemcpyRtsDynamicKernel : public DynamicKernel { + public: + MemcpyRtsDynamicKernel(void *stream, const CNodePtr &cnode_ptr, void *dst, uint32_t dest_max, void *src, + uint32_t count) + : DynamicKernel(stream, cnode_ptr), dst_(dst), dest_max_(dest_max), src_(src), count_(count) {} + ~MemcpyRtsDynamicKernel() override = default; + + void UpdateArgs() override {} + void Execute() override; + void PostExecute() override {} + + private: + void *dst_; + uint32_t dest_max_; + void *src_; + uint32_t count_; +}; +} // namespace ascend +} // namespace device +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_RTS_MEMCPY_RTS_DYNAMIC_KERNEL_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/executor/rts/profiling_rts_dynamic_kernel.cc b/mindspore/ccsrc/runtime/device/ascend/executor/rts/profiling_rts_dynamic_kernel.cc new file mode 100644 index 0000000000..230430cece --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/executor/rts/profiling_rts_dynamic_kernel.cc @@ -0,0 +1,32 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/ascend/executor/rts/profiling_rts_dynamic_kernel.h" + +#include "runtime/base.h" + +namespace mindspore { +namespace device { +namespace ascend { +void ProfilingRtsDynamicKernel::Execute() { + auto rt_ret = rtProfilerTrace(log_id_, notify_, flags_, stream_); + if (rt_ret != RT_ERROR_NONE) { + MS_LOG(EXCEPTION) << "Call rtProfilerTrace failed"; + } +} +} // namespace ascend +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/ascend/executor/rts/profiling_rts_dynamic_kernel.h b/mindspore/ccsrc/runtime/device/ascend/executor/rts/profiling_rts_dynamic_kernel.h new file mode 100644 index 0000000000..dd070f365d --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/executor/rts/profiling_rts_dynamic_kernel.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_RTS_PROFILING_RTS_DYNAMIC_KERNEL_H_ +#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_RTS_PROFILING_RTS_DYNAMIC_KERNEL_H_ + +#include "runtime/device/executor/dynamic_kernel.h" + +namespace mindspore { +namespace device { +namespace ascend { +class ProfilingRtsDynamicKernel : public DynamicKernel { + public: + ProfilingRtsDynamicKernel(void *stream, const CNodePtr &cnode_ptr, uint64_t log_id, bool notify, uint32_t flags) + : DynamicKernel(stream, cnode_ptr), log_id_(log_id), notify_(notify), flags_(flags) {} + ~ProfilingRtsDynamicKernel() override = default; + + void UpdateArgs() override {} + void Execute() override; + void PostExecute() override {} + + private: + uint64_t log_id_; + bool notify_; + uint32_t flags_; +}; +} // namespace ascend +} // namespace device +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_RTS_PROFILING_RTS_DYNAMIC_KERNEL_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/executor/tiling/op_tiling_calculater.cc b/mindspore/ccsrc/runtime/device/ascend/executor/tiling/op_tiling_calculater.cc new file mode 100644 index 0000000000..500fc9c99c --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/executor/tiling/op_tiling_calculater.cc @@ -0,0 +1,188 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "runtime/device/ascend/executor/tiling/op_tiling_calculater.h" +#include +#include +#include +#include +#include +#include "backend/session/anf_runtime_algorithm.h" +#include "runtime/device/ascend/ge_types_convert.h" +#include "utils/utils.h" +#include "external/graph/tensor.h" + +namespace mindspore { +namespace device { +namespace ascend { +ge::Tensor MakeTempGeTensor(TypeId type_id) { + auto ge_type = GeTypesConvert::TransTypeIdToGeDataType(type_id); + ge::TensorDesc tensor_desc; + tensor_desc.SetDataType(ge_type); + ge::Tensor ge_tensor; + ge_tensor.SetTensorDesc(tensor_desc); + return ge_tensor; +} + +void FeedTeOpTensorInputArg(const NotNull &cnode, + NotNull *> tensor_arg_list) { + MS_LOG(INFO) << "FeedTeOpTensorInputArg start, node:" << cnode->fullname_with_scope(); + auto input_size = AnfAlgo::GetInputTensorNum(cnode.get()); + + // Skip Dynamic Shape Depend Input + + for (size_t i = 0; i < input_size; ++i) { + auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(cnode.get(), i); + auto input_node = input_node_with_index.first; + auto input_index = input_node_with_index.second; + auto output_shape = AnfAlgo::GetOutputDeviceShape(input_node, input_index); + auto output_format = AnfAlgo::GetOutputFormat(input_node, input_index); + auto output_dtype = AnfAlgo::GetOutputDeviceDataType(input_node, input_index); + auto iter = type_name_map.find(output_dtype); + if (iter == type_name_map.end()) { + MS_LOG(EXCEPTION) << "Cannot found typeId:" << output_dtype; + } + auto ge_output_dtype = iter->second; + + optiling::TeOpTensorArg tensor_arg; + optiling::TeOpTensor tensor; + tensor_arg.arg_type = optiling::TA_SINGLE; + tensor.dtype = ge_output_dtype; + tensor.shape.insert(tensor.shape.end(), output_shape.begin(), output_shape.end()); + + tensor.format = GeTypesConvert::GetGeTilingFormat(GeTypesConvert::GetGeFormat(output_format, output_shape.size())); + MS_LOG(INFO) << "Tiling Format:" << tensor.format; + tensor_arg.tensor.emplace_back(tensor); + tensor_arg_list->emplace_back(tensor_arg); + } +} + +void FeedTeOpTensorOutputArg(const NotNull &cnode, + NotNull *> tensor_arg_list) { + MS_LOG(INFO) << "FeedTeOpTensorOutputArg start, node:" << cnode->fullname_with_scope(); + auto output_size = AnfAlgo::GetOutputTensorNum(cnode.get()); + for (size_t i = 0; i < output_size; ++i) { + auto output_shape = AnfAlgo::GetOutputDeviceShape(cnode.get(), i); + auto output_format = AnfAlgo::GetOutputFormat(cnode.get(), i); + auto data_type = AnfAlgo::GetOutputDeviceDataType(cnode.get(), i); + auto iter = type_name_map.find(data_type); + if (iter == type_name_map.end()) { + MS_LOG(EXCEPTION) << "Cannot found typeId:" << data_type; + } + + optiling::TeOpTensorArg tensor_arg; + optiling::TeOpTensor tensor; + tensor_arg.arg_type = optiling::TA_SINGLE; + tensor.dtype = iter->second; + tensor.shape.insert(tensor.shape.end(), output_shape.begin(), output_shape.end()); + tensor.format = GeTypesConvert::GetGeTilingFormat(GeTypesConvert::GetGeFormat(output_format, output_shape.size())); + MS_LOG(INFO) << "Tiling Format:" << tensor.format; + tensor_arg.tensor.emplace_back(tensor); + tensor_arg_list->emplace_back(tensor_arg); + } +} + +void FeedTeOpConstTensor(const NotNull &cnode, const std::map &depend_tensor_map, + NotNull *> const_inputs) { + MS_LOG(INFO) << "FeedTeOpConstTensor start, node:" << cnode->fullname_with_scope(); + if (!AnfAlgo::HasNodeAttr(kDynamicShapeDepends, cnode.get())) { + MS_LOG(INFO) << "No input depend found, " << cnode->fullname_with_scope(); + return; + } + + auto 
depends_list = AnfAlgo::GetNodeAttr>(cnode.get(), kDynamicShapeDepends); + for (auto index : depends_list) { + auto iter = depend_tensor_map.find(IntToSize(index)); + if (iter == depend_tensor_map.end()) { + MS_LOG(EXCEPTION) << "Index not found in depend_tensor_map"; + } + + auto const_tensor = iter->second; + + auto have_input_names_attr = AnfAlgo::HasNodeAttr("input_names", cnode); + if (!have_input_names_attr) { + MS_LOG(EXCEPTION) << "cnode:" << cnode->fullname_with_scope() << " no input_names attr"; + } + auto input_names_attr = AnfAlgo::GetNodeAttr>(cnode.get(), "input_names"); + if (IntToSize(index) >= input_names_attr.size()) { + MS_LOG(EXCEPTION) << "input index" << index << " >= input_name_attr.size:" << input_names_attr.size(); + } + auto input_name = input_names_attr[index]; + MS_LOG(INFO) << "input_name is " << input_name; + auto type_id = AnfAlgo::GetPrevNodeOutputDeviceDataType(cnode.get(), index); + const_inputs->try_emplace( + input_name, optiling::TeConstTensorData{static_cast(const_tensor->data_c()), + IntToSize(const_tensor->DataSize()), MakeTempGeTensor(type_id)}); + } + MS_LOG(INFO) << "FeedTeOpConstTensor end"; +} + +void OpTilingCalculater::Init() { + MS_LOG(INFO) << "Start init OpTilingCalculater"; + tiling_func_map_ = optiling::OpTilingInterf::RegisteredOpInterf(); + MS_LOG(INFO) << "tiling_func_map_ size:" << tiling_func_map_.size(); + for (const auto &iter : tiling_func_map_) { + MS_LOG(INFO) << "Regist tiling func:" << iter.first; + } +} + +std::string GetRealOpType(const std::string &op_type) { + static const std::map kOpTypeMap = { + {"SparseApplyFtrl", "SparseApplyFtrlD"}, + }; + auto iter = kOpTypeMap.find(op_type); + if (iter == kOpTypeMap.end()) { + return op_type; + } + return iter->second; +} + +void OpTilingCalculater::CalculateTiling(const NotNull &cnode, + const NotNull> &compile_info_json, + const std::map &depend_tensor_map, + NotNull op_run_info) { + optiling::TeOpParas op_param; + std::string op_type = AnfAlgo::GetCNodeName(cnode.get()); + MS_LOG(INFO) << "[DynamicShape] calculate tiling, op_type:" << op_type; + + FeedTeOpTensorInputArg(cnode, NOT_NULL(&op_param.inputs)); + FeedTeOpTensorOutputArg(cnode, NOT_NULL(&op_param.outputs)); + FeedTeOpConstTensor(cnode, depend_tensor_map, NOT_NULL(&op_param.const_inputs)); + + op_type = GetRealOpType(op_type); + auto iter = tiling_func_map_.find(op_type); + if (iter == tiling_func_map_.end()) { + iter = tiling_func_map_.find("AutoTiling"); + if (iter == tiling_func_map_.end()) { + MS_LOG(EXCEPTION) << "AutoTiling Func Not Found"; + } + } + + MS_LOG(INFO) << "Get tiling func:" << iter->first; + + if (iter != tiling_func_map_.end()) { + bool ret = (iter->second)(op_type, op_param, *compile_info_json.get(), *op_run_info); + if (!ret) { + MS_LOG(EXCEPTION) << "Calculate tiling failed"; + } + } else { + MS_LOG(EXCEPTION) << "Tiling func not found"; + } + MS_LOG(INFO) << "CalculateTiling success"; +} +} // namespace ascend +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/ascend/executor/tiling/op_tiling_calculater.h b/mindspore/ccsrc/runtime/device/ascend/executor/tiling/op_tiling_calculater.h new file mode 100644 index 0000000000..0331ae3757 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/executor/tiling/op_tiling_calculater.h @@ -0,0 +1,55 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_TILING_OP_TILING_CALCULATE_H_ +#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_TILING_OP_TILING_CALCULATE_H_ + +#include +#include +#include +#include "utils/ms_utils.h" +#include "utils/contract.h" +#include "ir/anf.h" +#include "ir/tensor.h" +#include "register/op_tiling.h" + +namespace mindspore { +namespace device { +namespace ascend { +class OpTilingCalculater { + public: + static OpTilingCalculater &GetInstance() { + static OpTilingCalculater instance; + return instance; + } + + void Init(); + void CalculateTiling(const NotNull &cnode, + const NotNull> &compile_info_json, + const std::map &depend_tensor_map, + NotNull op_run_info); + + private: + OpTilingCalculater() = default; + ~OpTilingCalculater() = default; + DISABLE_COPY_AND_ASSIGN(OpTilingCalculater); + + std::map tiling_func_map_; +}; +} // namespace ascend +} // namespace device +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_TILING_OP_TILING_CALCULATE_H_ diff --git a/mindspore/ccsrc/runtime/device/ascend/ge_types_convert.cc b/mindspore/ccsrc/runtime/device/ascend/ge_types_convert.cc new file mode 100644 index 0000000000..8cd5fa9aed --- /dev/null +++ b/mindspore/ccsrc/runtime/device/ascend/ge_types_convert.cc @@ -0,0 +1,137 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "runtime/device/ascend/ge_types_convert.h" + +namespace mindspore { +namespace device { +namespace ascend { +ge::proto::DataType GeTypesConvert::GetGeDataType(TypeId type_id) { + static const std::map<TypeId, ge::proto::DataType> data_type_map = { + {TypeId::kTypeUnknown, ge::proto::DT_UNDEFINED}, {TypeId::kNumberTypeFloat32, ge::proto::DT_FLOAT}, + {TypeId::kNumberTypeFloat16, ge::proto::DT_FLOAT16}, {TypeId::kNumberTypeInt8, ge::proto::DT_INT8}, + {TypeId::kNumberTypeUInt8, ge::proto::DT_UINT8}, {TypeId::kNumberTypeInt16, ge::proto::DT_INT16}, + {TypeId::kNumberTypeUInt16, ge::proto::DT_UINT16}, {TypeId::kNumberTypeInt32, ge::proto::DT_INT32}, + {TypeId::kNumberTypeInt64, ge::proto::DT_INT64}, {TypeId::kNumberTypeUInt32, ge::proto::DT_UINT32}, + {TypeId::kNumberTypeUInt64, ge::proto::DT_UINT64}, {TypeId::kNumberTypeBool, ge::proto::DT_BOOL}, + {TypeId::kNumberTypeFloat64, ge::proto::DT_DOUBLE}, + }; + MS_LOG(INFO) << "Vm origin type_id:" << type_id; + auto iter = data_type_map.find(type_id); + if (iter == data_type_map.end()) { + MS_LOG(EXCEPTION) << "Invalid data type:" << type_id; + } + return iter->second; +} + +ge::DataType GeTypesConvert::TransTypeIdToGeDataType(TypeId type_id) { + static const std::map<TypeId, ge::DataType> data_type_map = { + {TypeId::kNumberTypeFloat, ge::DataType::DT_FLOAT}, {TypeId::kNumberTypeFloat32, ge::DataType::DT_FLOAT}, + {TypeId::kNumberTypeFloat16, ge::DataType::DT_FLOAT16}, {TypeId::kNumberTypeInt8, ge::DataType::DT_INT8}, + {TypeId::kNumberTypeInt16, ge::DataType::DT_INT16}, {TypeId::kNumberTypeUInt16, ge::DataType::DT_UINT16}, + {TypeId::kNumberTypeUInt8, ge::DataType::DT_UINT8}, {TypeId::kNumberTypeInt32, ge::DataType::DT_INT32}, + {TypeId::kNumberTypeInt, ge::DataType::DT_INT32}, {TypeId::kNumberTypeInt64, ge::DataType::DT_INT64}, + {TypeId::kNumberTypeUInt32, ge::DataType::DT_UINT32}, {TypeId::kNumberTypeUInt, ge::DataType::DT_UINT32}, + {TypeId::kNumberTypeUInt64, ge::DataType::DT_UINT64}, {TypeId::kNumberTypeBool, ge::DataType::DT_BOOL}, + {TypeId::kNumberTypeFloat64, ge::DataType::DT_DOUBLE}, {TypeId::kTypeUnknown, ge::DataType::DT_UNDEFINED}}; + auto iter = data_type_map.find(type_id); + if (iter == data_type_map.end()) { + MS_LOG(EXCEPTION) << "Invalid data type:" << type_id; + } + return iter->second; +} + +GeFormat GeTypesConvert::GetGeFormat(const std::string &format, size_t shape_size) { + static const std::map<std::string, GeFormat> format_map = { + // default format: nchw, fractal_nz? + {kOpFormat_DEFAULT, kFormat_NCHW}, + {kOpFormat_NC1KHKWHWC0, kFormat_NC1KHKWHWC0}, + {kOpFormat_ND, kFormat_ND}, + {kOpFormat_NCHW, kFormat_NCHW}, + {kOpFormat_NHWC, kFormat_NHWC}, + {kOpFormat_HWCN, kFormat_HWCN}, + {kOpFormat_NC1HWC0, kFormat_NC1HWC0}, + {kOpFormat_FRAC_Z, kFormat_FRACTAL_Z}, + {kOpFormat_FRAC_NZ, kFormat_FRACTAL_NZ}, + {kOpFormat_C1HWNCoC0, kFormat_C1HWNCoC0}, + {kOpFormat_NC1HWC0_C04, kFormat_NC1HWC0_C04}, + {kOpFormat_FRACTAL_Z_C04, kFormat_FRACTAL_Z_C04}, + {kOpFormat_NDHWC, kFormat_NDHWC}, + }; + MS_LOG(INFO) << "GetGeFormat format:" << format << " shape_size:" << shape_size; + if (format == kOpFormat_DEFAULT) { + return shape_size == 4 ?
kFormat_NCHW : kFormat_ND; + } + auto iter = format_map.find(format); + if (iter == format_map.end()) { + MS_LOG(EXCEPTION) << "Invalid format:" << format; + } + return iter->second; +} + +std::string GeTypesConvert::GetGeTilingFormat(GeFormat ge_format) { + static const std::map kFormatToStringMap = { + {kFormat_NCHW, "NCHW"}, + {kFormat_NHWC, "NHWC"}, + {kFormat_ND, "ND"}, + {kFormat_NC1HWC0, "NC1HWC0"}, + {kFormat_FRACTAL_Z, "FRACTAL_Z"}, + {kFormat_NC1C0HWPAD, "NC1C0HWPAD"}, + {kFormat_NHWC1C0, "NHWC1C0"}, + {kFormat_FSR_NCHW, "FSR_NCHW"}, + {kFormat_FRACTAL_DECONV, "FRACTAL_DECONV"}, + {kFormat_C1HWNC0, "C1HWNC0"}, + {kFormat_FRACTAL_DECONV_TRANSPOSE, "FRACTAL_DECONV_TRANSPOSE"}, + {kFormat_FRACTAL_DECONV_SP_STRIDE_TRANS, "FRACTAL_DECONV_SP_STRIDE_TRANS"}, + {kFormat_NC1HWC0_C04, "NC1HWC0_C04"}, + {kFormat_FRACTAL_Z_C04, "FRACTAL_Z_C04"}, + {kFormat_CHWN, "CHWN"}, + {kFormat_FRACTAL_DECONV_SP_STRIDE8_TRANS, "DECONV_SP_STRIDE8_TRANS"}, + {kFormat_NC1KHKWHWC0, "NC1KHKWHWC0"}, + {kFormat_BN_WEIGHT, "BN_WEIGHT"}, + {kFormat_FILTER_HWCK, "FILTER_HWCK"}, + {kFormat_HWCN, "HWCN"}, + {kFormat_HASHTABLE_LOOKUP_LOOKUPS, "LOOKUP_LOOKUPS"}, + {kFormat_HASHTABLE_LOOKUP_KEYS, "LOOKUP_KEYS"}, + {kFormat_HASHTABLE_LOOKUP_VALUE, "LOOKUP_VALUE"}, + {kFormat_HASHTABLE_LOOKUP_OUTPUT, "LOOKUP_OUTPUT"}, + {kFormat_HASHTABLE_LOOKUP_HITS, "LOOKUP_HITS"}, + {kFormat_MD, "MD"}, + {kFormat_NDHWC, "NDHWC"}, + {kFormat_NCDHW, "NCDHW"}, + {kFormat_DHWCN, "DHWCN"}, + {kFormat_DHWNC, "DHWNC"}, + {kFormat_NDC1HWC0, "NDC1HWC0"}, + {kFormat_FRACTAL_Z_3D, "FRACTAL_Z_3D"}, + {kFormat_FRACTAL_Z_3D_TRANSPOSE, "FRACTAL_Z_3D_TRANSPOSE"}, + {kFormat_C1HWNCoC0, "C1HWNCoC0"}, + {kFormat_FRACTAL_NZ, "FRACTAL_NZ"}, + {kFormat_CN, "CN"}, + {kFormat_NC, "NC"}, + {kFormat_FRACTAL_ZN_LSTM, "FRACTAL_ZN_LSTM"}, + {kFormat_FRACTAL_Z_G, "FRACTAL_Z_G"}, + {kFormat_RESERVED, "FORMAT_RESERVED"}, + {kFormat_ALL, "ALL"}}; + + auto iter = kFormatToStringMap.find(ge_format); + if (iter == kFormatToStringMap.end()) { + MS_LOG(EXCEPTION) << "Invalid ge_format:" << ge_format; + } + return iter->second; +} +} // namespace ascend +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/ascend/dump/ge_dump.h b/mindspore/ccsrc/runtime/device/ascend/ge_types_convert.h similarity index 51% rename from mindspore/ccsrc/runtime/device/ascend/dump/ge_dump.h rename to mindspore/ccsrc/runtime/device/ascend/ge_types_convert.h index 60af609a48..0881ced928 100644 --- a/mindspore/ccsrc/runtime/device/ascend/dump/ge_dump.h +++ b/mindspore/ccsrc/runtime/device/ascend/ge_types_convert.h @@ -22,28 +22,11 @@ #include "proto/ge_dtype.pb.h" #include "ir/dtype/type_id.h" #include "utils/utils.h" +#include "external/graph/types.h" namespace mindspore { namespace device { namespace ascend { -static ge::proto::DataType GetGeDataType(TypeId type_id) { - static const std::map data_type_map = { - {TypeId::kTypeUnknown, ge::proto::DT_UNDEFINED}, {TypeId::kNumberTypeFloat32, ge::proto::DT_FLOAT}, - {TypeId::kNumberTypeFloat16, ge::proto::DT_FLOAT16}, {TypeId::kNumberTypeInt8, ge::proto::DT_INT8}, - {TypeId::kNumberTypeUInt8, ge::proto::DT_UINT8}, {TypeId::kNumberTypeInt16, ge::proto::DT_INT16}, - {TypeId::kNumberTypeUInt16, ge::proto::DT_UINT16}, {TypeId::kNumberTypeInt32, ge::proto::DT_INT32}, - {TypeId::kNumberTypeInt64, ge::proto::DT_INT64}, {TypeId::kNumberTypeUInt32, ge::proto::DT_UINT32}, - {TypeId::kNumberTypeUInt64, ge::proto::DT_UINT64}, {TypeId::kNumberTypeBool, ge::proto::DT_BOOL}, - {TypeId::kNumberTypeFloat64, 
ge::proto::DT_DOUBLE}, - }; - MS_LOG(INFO) << "Vm origin type_id:" << type_id; - auto iter = data_type_map.find(type_id); - if (iter == data_type_map.end()) { - MS_LOG(EXCEPTION) << "Invalid data type:" << type_id; - } - return iter->second; -} - enum GeFormat { kFormat_NCHW = 0, // NCHW kFormat_NHWC, // NHWC @@ -83,37 +66,21 @@ enum GeFormat { kFormat_NC, kFormat_DHWNC, kFormat_FRACTAL_Z_3D_TRANSPOSE, // 3D filter(transpose) input tensor format + kFormat_FRACTAL_ZN_LSTM, + kFormat_FRACTAL_Z_G, kFormat_RESERVED, kFormat_ALL }; -static GeFormat GetGeFormat(const std::string &format, size_t shape_size) { - static const std::map format_map = { - // default format: nchw, fractal_nz? - {kOpFormat_DEFAULT, kFormat_NCHW}, - {kOpFormat_NC1KHKWHWC0, kFormat_NC1KHKWHWC0}, - {kOpFormat_ND, kFormat_ND}, - {kOpFormat_NCHW, kFormat_NCHW}, - {kOpFormat_NHWC, kFormat_NHWC}, - {kOpFormat_HWCN, kFormat_HWCN}, - {kOpFormat_NC1HWC0, kFormat_NC1HWC0}, - {kOpFormat_FRAC_Z, kFormat_FRACTAL_Z}, - {kOpFormat_FRAC_NZ, kFormat_FRACTAL_NZ}, - {kOpFormat_C1HWNCoC0, kFormat_C1HWNCoC0}, - {kOpFormat_NC1HWC0_C04, kFormat_NC1HWC0_C04}, - {kOpFormat_FRACTAL_Z_C04, kFormat_FRACTAL_Z_C04}, - {kOpFormat_NDHWC, kFormat_NDHWC}, - }; - MS_LOG(INFO) << "GetGeFormat format:" << format << " shape_size:" << shape_size; - if (format == kOpFormat_DEFAULT) { - return shape_size == 4 ? kFormat_NCHW : kFormat_ND; - } - auto iter = format_map.find(format); - if (iter == format_map.end()) { - MS_LOG(EXCEPTION) << "Invalid format:" << format; - } - return iter->second; -} +class GeTypesConvert { + public: + GeTypesConvert() = default; + ~GeTypesConvert() = default; + static ge::proto::DataType GetGeDataType(TypeId type_id); + static GeFormat GetGeFormat(const std::string &format, size_t shape_size); + static std::string GetGeTilingFormat(GeFormat ge_format); + static ge::DataType TransTypeIdToGeDataType(TypeId type_id); +}; } // namespace ascend } // namespace device } // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/ascend/kernel_build_ascend.cc b/mindspore/ccsrc/runtime/device/ascend/kernel_build_ascend.cc index 833104a1c5..58d0480e48 100644 --- a/mindspore/ccsrc/runtime/device/ascend/kernel_build_ascend.cc +++ b/mindspore/ccsrc/runtime/device/ascend/kernel_build_ascend.cc @@ -27,6 +27,7 @@ #include "backend/kernel_compiler/tbe/tbe_kernel_parallel_build.h" #include "backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.h" #include "backend/kernel_compiler/aicpu/aicpu_kernel_build.h" +#include "backend/kernel_compiler/host/host_kernel_build.h" #include "backend/kernel_compiler/hccl/hccl_kernel_build.h" #include "backend/kernel_compiler/rts/rt_kernel_build.h" #include "backend/kernel_compiler/tbe/tbe_utils.h" @@ -47,6 +48,10 @@ static kernel::KernelModPtr SerialCompileImpl(const AnfNodePtr &anf_node) { kernel_mod_ptr = kernel::AicpuOpBuild(anf_node); break; } + case KernelType::HOST_KERNEL: { + kernel_mod_ptr = kernel::HostOpBuild(anf_node); + break; + } case KernelType::RT_KERNEL: { kernel_mod_ptr = kernel::RtOpBuild(anf_node); break; diff --git a/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc b/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc index 5a423a62ae..03f1153821 100644 --- a/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc +++ b/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc @@ -22,6 +22,10 @@ #include #include #include +#include +#include +#include "utils/ms_utils.h" +#include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h" #include 
"debug/anf_ir_dump.h" #include "frontend/operator/ops.h" #include "utils/ms_context.h" @@ -493,7 +497,7 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co } // we set special device info of a input tensor. bool is_ref = false; - auto op_info = kernel::OpLib::FindOp(AnfAlgo::GetCNodeName(kernel_node), kernel::kTBE); + auto op_info = kernel::tbe::TbeDynamicShapeUtil::FindOp(AnfAlgo::GetCNodeName(kernel_node), kernel_node); if (op_info != nullptr) { is_ref = op_info->is_ref(); } diff --git a/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.h b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.h index 4e89be16d4..5f3058d6fa 100644 --- a/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.h +++ b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.h @@ -44,6 +44,8 @@ class CPUKernelRuntime : public KernelRuntime { VectorRef *outputs); void IncreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs); void DecreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs); + bool GenDynamicKernel(const session::KernelGraph *graph) override { return true; } + bool RunDynamicKernelAsync(const session::KernelGraph *graph) override { return true; } protected: bool SyncStream() override { return true; }; diff --git a/mindspore/ccsrc/runtime/device/executor/dynamic_kernel.cc b/mindspore/ccsrc/runtime/device/executor/dynamic_kernel.cc new file mode 100644 index 0000000000..cec87c2f19 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/executor/dynamic_kernel.cc @@ -0,0 +1,128 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "runtime/device/executor/dynamic_kernel.h" +#include +#include "backend/session/anf_runtime_algorithm.h" +#include "common/trans.h" +#include "pipeline/jit/static_analysis/static_analysis.h" +#include "abstract/dshape.h" +#include "abstract/param_validator.h" + +namespace mindspore { +namespace device { +void DynamicKernel::Initialize() { + MS_LOG(INFO) << "Init Start"; + is_dynamic_shape_ = AnfAlgo::IsDynamicShape(cnode_ptr_); + if (!is_dynamic_shape_) { + MS_LOG(INFO) << "cnode is not dynamic shape:" << cnode_ptr_->fullname_with_scope(); + return; + } + + is_input_dynamic_shape_ = AnfAlgo::GetBooleanAttr(cnode_ptr_, kAttrInputIsDynamicShape); + is_output_dynamic_shape_ = AnfAlgo::GetBooleanAttr(cnode_ptr_, kAttrOutputIsDynamicShape); + + auto have_depends = AnfAlgo::HasNodeAttr(kDynamicShapeDepends, cnode_ptr_); + if (!have_depends) { + MS_LOG(WARNING) << "No dynamic_shape_depends found"; + return; + } + MS_LOG(INFO) << "Have depends"; + auto depends_list = AnfAlgo::GetNodeAttr>(cnode_ptr_, kDynamicShapeDepends); + // Save depend input tensor. Sync data in InferShape. 
+ for (auto depend : depends_list) { + auto pre_node_with_index = AnfAlgo::GetPrevNodeOutput(cnode_ptr_, depend); + auto output_addr = AnfAlgo::GetPrevNodeMutableOutputAddr(cnode_ptr_, depend); + std::vector shapes = trans::GetRuntimePaddingShape(pre_node_with_index.first, pre_node_with_index.second); + auto host_type = AnfAlgo::GetOutputInferDataType(pre_node_with_index.first, pre_node_with_index.second); + auto out_tensor = std::make_shared(host_type, shapes); + out_tensor->set_device_address(output_addr); + + auto ret = depend_tensor_map_.try_emplace(depend, out_tensor); + if (!ret.second) { + MS_LOG(EXCEPTION) << "Insert map failed"; + } + } + MS_LOG(INFO) << "Init End"; +} + +bool IsTupleGetItem(const AnfNodePtr &anf_node) { + MS_EXCEPTION_IF_NULL(anf_node); + if (!anf_node->isa()) { + return false; + } + auto cnode = anf_node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto input0 = cnode->input(0); + return IsPrimitive(input0, prim::kPrimTupleGetItem); +} + +void DynamicKernel::InferShape() { + if (!is_input_dynamic_shape_ && is_output_dynamic_shape_ && !have_depends()) { + return; + } + MS_EXCEPTION_IF_NULL(cnode_ptr_); + MS_LOG(INFO) << "InferShape start, node:" << cnode_ptr_->fullname_with_scope(); + + auto inputs = cnode_ptr_->inputs(); + if (inputs.empty()) { + MS_LOG(EXCEPTION) << "Invalid inputs"; + } + AbstractBasePtrList args_spec_list; + auto primitive = GetValueNode(inputs[0]); + + auto input_size = AnfAlgo::GetInputTensorNum(cnode_ptr_); + for (size_t i = 0; i < input_size; ++i) { + auto input_with_index = AnfAlgo::GetPrevNodeOutput(cnode_ptr_, i); + auto real_input = input_with_index.first; + + MS_EXCEPTION_IF_NULL(real_input); + auto ret = depend_tensor_map_.find(i); + if (ret != depend_tensor_map_.end()) { + auto tensor_ptr = ret->second; + MS_EXCEPTION_IF_NULL(tensor_ptr); + // sync data from device to host + tensor_ptr->data_sync(); + real_input->abstract()->set_value(tensor_ptr); + } + + auto cnode_input = cnode_ptr_->input(i + 1); + MS_EXCEPTION_IF_NULL(cnode_input); + if (IsTupleGetItem(cnode_input)) { + auto base_shape = real_input->Shape(); + if (!base_shape->isa()) { + MS_LOG(EXCEPTION) << "Node:" << cnode_ptr_->fullname_with_scope() + << " input is a tuple_get_item but real input node shape is not a TupleShape"; + } + auto tuple_ptr = base_shape->cast(); + MS_EXCEPTION_IF_NULL(tuple_ptr); + auto tuple_get_item_index = AnfAlgo::GetTupleGetItemOutIndex(cnode_input->cast()); + auto real_shape = tuple_ptr->shape().at(tuple_get_item_index); + auto abstract_tensor = cnode_input->abstract()->cast(); + MS_EXCEPTION_IF_NULL(abstract_tensor); + args_spec_list.emplace_back(std::make_shared(abstract_tensor->element(), real_shape)); + } else if (cnode_input->isa() && AnfAlgo::GetCNodeName(cnode_input) == prim::kPrimReshape->name()) { + args_spec_list.emplace_back(cnode_input->abstract()); + } else { + args_spec_list.emplace_back(real_input->abstract()); + } + } + + auto eval_result = abstract::CppInferShape(primitive, args_spec_list); + cnode_ptr_->set_abstract(eval_result); +} +} // namespace device +} // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/executor/dynamic_kernel.h b/mindspore/ccsrc/runtime/device/executor/dynamic_kernel.h new file mode 100644 index 0000000000..4689126af6 --- /dev/null +++ b/mindspore/ccsrc/runtime/device/executor/dynamic_kernel.h @@ -0,0 +1,62 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with 
the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_EXECUTOR_H_ +#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_EXECUTOR_H_ + +#include +#include +#include +#include "ir/anf.h" +#include "ir/tensor.h" + +namespace mindspore { +namespace device { + +constexpr auto kDynamicShapeDepends = "dynamic_shape_depends"; + +class DynamicKernel { + public: + DynamicKernel(void *stream, const CNodePtr &cnode_ptr) + : stream_(stream), + cnode_ptr_(cnode_ptr), + is_dynamic_shape_(false), + is_input_dynamic_shape_(false), + is_output_dynamic_shape_(false) {} + virtual ~DynamicKernel() = default; + virtual void InferShape(); + virtual void UpdateArgs() = 0; + virtual void Execute() = 0; + virtual void PostExecute() = 0; + bool is_dynamic_shape() const { return is_dynamic_shape_; } + bool is_input_dynamic_shape() const { return is_input_dynamic_shape_; } + bool is_output_dynamic_shape() const { return is_output_dynamic_shape_; } + bool have_depends() const { return !depend_tensor_map_.empty(); } + virtual void Initialize(); + std::string GetKernelName() { return cnode_ptr_->fullname_with_scope(); } + + protected: + void *stream_; + const CNodePtr cnode_ptr_; + bool is_dynamic_shape_; + bool is_input_dynamic_shape_; + bool is_output_dynamic_shape_; + std::map depend_tensor_map_; +}; +using DynamicKernelPtr = std::shared_ptr; +} // namespace device +} // namespace mindspore +#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_EXECUTOR_H_ diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h index 8adc434110..55c977cd31 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h @@ -43,6 +43,8 @@ class GPUKernelRuntime : public KernelRuntime { const std::vector &execution_order) override; void AssignMemory(session::KernelGraph *graph) override; bool Run(session::KernelGraph *graph, bool is_task_sink, Debugger *debugger = nullptr) override; + bool GenDynamicKernel(const session::KernelGraph *graph) override { return true; } + bool RunDynamicKernelAsync(const session::KernelGraph *graph) override { return true; } protected: DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, diff --git a/mindspore/ccsrc/runtime/device/kernel_adjust.cc b/mindspore/ccsrc/runtime/device/kernel_adjust.cc index a8ac165e5f..a77089314b 100644 --- a/mindspore/ccsrc/runtime/device/kernel_adjust.cc +++ b/mindspore/ccsrc/runtime/device/kernel_adjust.cc @@ -124,6 +124,10 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr return; } MS_EXCEPTION_IF_NULL(kernel_graph_ptr); + if (kernel_graph_ptr->is_dynamic_shape()) { + MS_LOG(INFO) << "KernelGraph:" << kernel_graph_ptr->graph_id() << " is dynamic shape, skip InsertSwitchLoop"; + return; + } bool eos_mode = ConfigManager::GetInstance().iter_num() == INT32_MAX; ReorderGetNext(kernel_graph_ptr); std::map switch_loop_input; @@ -513,6 +517,10 @@ bool KernelAdjust::StepLoadCtrlInputs(const 
std::shared_ptris_dynamic_shape()) { + MS_LOG(INFO) << "Skip StepLoadCtrlInputs"; + return true; + } auto input_nodes = kernel_graph_ptr->inputs(); std::vector inputs; LoadSwitchInputs(&inputs); diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.cc b/mindspore/ccsrc/runtime/device/kernel_runtime.cc index aa36e62144..50581cb760 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.cc @@ -290,6 +290,7 @@ void KernelRuntime::RunOpAssignOutputNodeMemory(const ValuePtr &pre_output_value void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) { MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(mem_manager_); + MS_LOG(INFO) << "AssignStaticMemoryInput start"; auto graph_inputs = graph->inputs(); auto graph_valid_input = graph->valid_inputs(); graph_inputs.insert(graph_inputs.end(), graph->child_graph_result().begin(), graph->child_graph_result().end()); @@ -334,6 +335,7 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) { } auto tensor_size = CountNodeDeviceMemorySize(item, index); auto address = CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id); + MS_LOG(DEBUG) << "Malloc static memory for " << item->fullname_with_scope(); if (mem_manager_->MallocMem(kStaticMem, tensor_size, address) == nullptr) { MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << tensor_size; } @@ -342,10 +344,12 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) { AnfAlgo::SetOutputAddr(address, index, item.get()); } } + MS_LOG(INFO) << "AssignStaticMemoryInput end"; } void KernelRuntime::AssignStaticMemoryOutput(const session::KernelGraph *graph) { MS_EXCEPTION_IF_NULL(graph); + MS_LOG(INFO) << "AssignStaticMemoryOutput start"; auto nodes = AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem}); std::vector non_communication_op; // Assign Communicate Op Memory firstly. 
@@ -363,8 +367,10 @@ void KernelRuntime::AssignStaticMemoryOutput(const session::KernelGraph *graph) } for (const auto &item_with_index : non_communication_op) { + MS_LOG(DEBUG) << "AssignNodeOutputMem for " << item_with_index.first->fullname_with_scope(); AssignNodeOutputMem(kStaticMem, item_with_index.first, SizeToInt(item_with_index.second)); } + MS_LOG(INFO) << "AssignStaticMemoryOutput end"; } void KernelRuntime::UpdateRefNodeOutputMem(const session::KernelGraph *graph) { @@ -553,6 +559,7 @@ void KernelRuntime::AssignNodeOutputMem(MemType type, const AnfNodePtr &node, in MS_LOG(INFO) << "Already malloc index:" << i; continue; } + MS_LOG(DEBUG) << "Assign Node:" << node->fullname_with_scope() << " output memeory size:" << output_sizes[i]; std::string output_format = AnfAlgo::GetOutputFormat(node, i); auto output_type = AnfAlgo::GetOutputDeviceDataType(node, i); auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type); @@ -612,6 +619,7 @@ void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) { MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(mem_manager_); + MS_LOG(INFO) << "AssignStaticMemoryValueNode start"; auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); for (auto &value_node : graph->graph_value_nodes()) { @@ -622,6 +630,7 @@ void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) { } auto &node_value = value_node->value(); MS_EXCEPTION_IF_NULL(node_value); + MS_LOG(DEBUG) << "Malloc memeory for " << value_node->fullname_with_scope(); if (node_value->isa() || node_value->isa()) { AssignValueNodeTensor(value_node, node_value, 0); } else if (node_value->isa()) { @@ -643,6 +652,7 @@ void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) { } } } + MS_LOG(INFO) << "AssignStaticMemoryValueNode end"; } void KernelRuntime::AssignDynamicMemory(session::KernelGraph *graph) { diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.h b/mindspore/ccsrc/runtime/device/kernel_runtime.h index 1cc4717fe3..706fe3373a 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime.h +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.h @@ -32,6 +32,7 @@ #include "backend/kernel_compiler/kernel.h" #include "utils/ms_context.h" #include "runtime/device/memory_manager.h" +#include "runtime/device/executor/dynamic_kernel.h" using mindspore::tensor::Tensor; using std::vector; @@ -58,6 +59,8 @@ class KernelRuntime { virtual bool LoadData(session::KernelGraph *graph, Debugger *debugger); virtual bool Load(session::KernelGraph *graph, bool is_task_sink); virtual bool Run(session::KernelGraph *graph, bool is_task_sink, Debugger *debugger = nullptr) = 0; + virtual bool GenDynamicKernel(const session::KernelGraph *graph) = 0; + virtual bool RunDynamicKernelAsync(const session::KernelGraph *graph) = 0; bool LaunchKernel(const session::KernelGraph *graph); bool LaunchTaskBasedOnSingleKernel(kernel::KernelModPtr kernel_mod_ptr, const AddressPtrList &kernel_inputs, const AddressPtrList &kernel_outputs, @@ -73,6 +76,12 @@ class KernelRuntime { virtual bool SyncStream() = 0; virtual void ClearGlobalIdleMem() {} virtual void SetContext() {} + uint8_t *MallocMem(MemType type, size_t size, const DeviceAddressPtr &address) { + return mem_manager_->MallocMem(type, size, address); + } + static void GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const AnfNodePtr &kernel, + AddressPtrList 
*kernel_inputs, AddressPtrList *kernel_workspaces, + AddressPtrList *kernel_outputs); // for GPU and D to impl virtual void ReleaseDeviceRes() {} @@ -100,10 +109,8 @@ class KernelRuntime { private: void AssignStaticMemoryOutput(const session::KernelGraph *graph); - void GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const AnfNodePtr &kernel, - AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces, AddressPtrList *kernel_outputs); bool LaunchKernelMod(const session::KernelGraph &graph); - void GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList *kernel_inputs); + static void GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList *kernel_inputs); size_t CountNodeDeviceMemorySize(const AnfNodePtr &node, size_t output_index); void RunOpAssignInputMemory(const std::vector &input_tensors, const session::KernelGraph *graph); void RunOpAssignOutputMemory(const AnfNodePtr &kernel); @@ -119,6 +126,7 @@ class KernelRuntime { #endif void *stream_ = nullptr; std::shared_ptr mem_manager_{nullptr}; + std::map> graph_dynamic_kernel_map_; }; using KernelRuntimePtr = std::shared_ptr; } // namespace device diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 9953745cf3..3d50e07624 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -223,6 +223,12 @@ constexpr auto kBasicLSTMCellOpName = "BasicLSTMCell"; constexpr auto kDynamicRNNOpName = "DynamicRNN"; constexpr auto kLSTMInputGradOpName = "LSTMInputGrad"; +// Hcom Op Type +constexpr auto kHcomOpTypeAllReduce = "HcomAllReduce"; +constexpr auto kHcomOpTypeAllGather = "HcomAllGather"; +constexpr auto kHcomOpTypeBroadcast = "HcomBroadcast"; +constexpr auto kHcomOpTypeReduceScatter = "HcomReduceScatter"; + // attr key name constexpr auto kAttrInputNames = "input_names"; constexpr auto kAttrIsAICPUKernel = "is_AICPU_kernel"; @@ -301,6 +307,10 @@ constexpr auto kAttrNumSegments = "num_segments"; constexpr auto kAttrBegin = "begin"; constexpr auto kAttrSize = "size"; constexpr auto kAttrIsDynamicShape = "is_dynamic_shape"; +constexpr auto kAttrInputIsDynamicShape = "input_is_dynamic_shape"; +constexpr auto kAttrOutputIsDynamicShape = "output_is_dynamic_shape"; +constexpr auto kAttrCompileInfo = "compile_info"; +constexpr auto kAttrFusionType = "fusion_type"; // attr value constexpr auto kValueTargetSwitch = "target_switch"; diff --git a/mindspore/core/abstract/infer_functions.h b/mindspore/core/abstract/infer_functions.h index 19bcd3fa8e..13f6a370d8 100644 --- a/mindspore/core/abstract/infer_functions.h +++ b/mindspore/core/abstract/infer_functions.h @@ -80,6 +80,16 @@ AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const Primiti AbstractBasePtr InferImplMinOrMaxGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); + +AbstractBasePtr InferImplTensorAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); + +AbstractBasePtr InferImplSquare(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplSqrt(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplSqrt(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); @@ -99,6 
+109,8 @@ AbstractBasePtr InferImplMakeList(const AnalysisEnginePtr &, const PrimitivePtr const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplUnsortedSegmentSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive, @@ -179,14 +191,33 @@ AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePt const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplScatterAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplScatterUpdate(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplRealDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); + AbstractBasePtr InferImplGatherV2(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplDynamicShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplSparseApplyFtrl(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplSparseApplyProximalAdagrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplAllReduce(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplBroadcast(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplAllGather(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplReduceScatter(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); template AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) { diff --git a/mindspore/core/abstract/prim_arrays.cc b/mindspore/core/abstract/prim_arrays.cc index fb7684eb98..3fc2fad836 100644 --- a/mindspore/core/abstract/prim_arrays.cc +++ b/mindspore/core/abstract/prim_arrays.cc @@ -23,42 +23,6 @@ namespace mindspore { namespace abstract { -namespace { -ShapeVector BroadcastShape(ShapeVector shpx, ShapeVector shpy) { - int dlen = SizeToInt(shpx.size()) - SizeToInt(shpy.size()); - if (dlen < 0) { - for (int i = 0; i < -dlen; ++i) { - (void)shpx.insert(shpx.begin(), 1); - } - } else if (dlen > 0) { - for (int i = 0; i < dlen; i++) { - (void)shpy.insert(shpy.begin(), 1); - } - } - if (shpx.size() != shpy.size()) { - 
MS_LOG(EXCEPTION) << "Failure: shpx.size() != shpy.size()."; - } - ShapeVector shp; - for (size_t i = 0; i < shpx.size(); i++) { - auto a = shpx[i]; - auto b = shpy[i]; - if (a == 1) { - shp.push_back(b); - } else if (b == 1) { - shp.push_back(a); - } else if (a == -1) { - shp.push_back(b); - } else if (b == -1) { - shp.push_back(a); - } else if (a == b) { - shp.push_back(a); - } else { - return ShapeVector(); - } - } - return shp; -} -} // namespace AbstractBasePtr InferImplScalarToArray(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { // Inputs: a scalar. @@ -229,17 +193,123 @@ AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePt return std::make_shared(ids->element(), ids_idx->shape()); } +AbstractBasePtr InferImplUnsortedSegmentSum(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 3); + auto x = CheckArg(op_name, args_spec_list, 0); + MS_EXCEPTION_IF_NULL(x); + MS_EXCEPTION_IF_NULL(x->shape()); + auto x_shape = x->shape()->shape(); + + auto segment_ids = CheckArg(op_name, args_spec_list, 1); + MS_EXCEPTION_IF_NULL(segment_ids); + MS_EXCEPTION_IF_NULL(segment_ids->shape()); + auto segment_ids_shape = segment_ids->shape()->shape(); + auto num_segments = CheckArg(op_name, args_spec_list, 2); + + std::vector shape; + auto num_segments_value = num_segments->BuildValue(); + MS_EXCEPTION_IF_NULL(num_segments_value); + if (!num_segments_value->isa()) { + MS_LOG(WARNING) << num_segments_value << "evaluator num_segments_value should be tensor, but got " + << num_segments_value->type_name(); + shape.emplace_back(-1); + } else { + auto num_segments_tensor = num_segments_value->cast(); + int value = *(static_cast(num_segments_tensor->data_c())); + MS_LOG(INFO) << "Infer UnsortedSegmentSum output shape:" << value; + shape.emplace_back(value); + } + + shape.insert(shape.end(), x_shape.begin() + segment_ids_shape.size(), x_shape.end()); + + AbstractTensorPtr ret = std::make_shared(x->element(), std::make_shared(shape)); + return ret; +} + +AbstractBasePtr InferImplScatterAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 3); + auto x = CheckArg(op_name, args_spec_list, 0); + MS_EXCEPTION_IF_NULL(x); + MS_EXCEPTION_IF_NULL(x->shape()); + return std::make_shared(x->element(), x->shape()); +} + +AbstractBasePtr InferImplScatterUpdate(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 3); + auto x = CheckArg(op_name, args_spec_list, 0); + MS_EXCEPTION_IF_NULL(x); + MS_EXCEPTION_IF_NULL(x->shape()); + return std::make_shared(x->element(), x->shape()); +} + +AbstractBasePtr InferImplDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + auto x = CheckArg(op_name, args_spec_list, 0); + auto y = CheckArg(op_name, args_spec_list, 1); + MS_EXCEPTION_IF_NULL(x); + MS_EXCEPTION_IF_NULL(x->shape()); + MS_EXCEPTION_IF_NULL(y); + MS_EXCEPTION_IF_NULL(y->shape()); + std::vector x_shape = x->shape()->shape(); + std::vector y_shape = y->shape()->shape(); + 
std::vector out_shape = BroadcastShape(x_shape, y_shape); + return std::make_shared(x->element(), std::make_shared(out_shape)); +} + +AbstractBasePtr InferImplRealDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + auto x = CheckArg(op_name, args_spec_list, 0); + auto y = CheckArg(op_name, args_spec_list, 1); + MS_EXCEPTION_IF_NULL(x); + MS_EXCEPTION_IF_NULL(x->shape()); + MS_EXCEPTION_IF_NULL(y); + MS_EXCEPTION_IF_NULL(y->shape()); + std::vector x_shape = x->shape()->shape(); + std::vector y_shape = y->shape()->shape(); + std::vector out_shape = BroadcastShape(x_shape, y_shape); + if (out_shape.empty()) { + MS_LOG(EXCEPTION) << "BroadcastShape fail: " << args_spec_list[0]->ToString() << "," + << args_spec_list[1]->ToString(); + } + return std::make_shared(x->element(), std::make_shared(out_shape)); +} + AbstractBasePtr InferImplGatherV2(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { const std::string &op_name = primitive->name(); CheckArgsSize(op_name, args_spec_list, 3); AbstractTensorPtr params = CheckArg(op_name, args_spec_list, 0); AbstractTensorPtr indices = CheckArg(op_name, args_spec_list, 1); - AbstractScalarPtr axis = CheckArg(op_name, args_spec_list, 2); + + int axis_val = 0; + // 3rd input is a Tensor when GatherV2 is a dynamic shape operator + if (args_spec_list[2]->isa()) { + auto axis = args_spec_list[2]->cast(); + MS_EXCEPTION_IF_NULL(axis); + auto axis_value_ptr = axis->BuildValue(); + MS_EXCEPTION_IF_NULL(axis_value_ptr); + auto axis_tensor = axis_value_ptr->cast(); + MS_EXCEPTION_IF_NULL(axis_tensor); + axis_val = *static_cast(axis_tensor->data_c()); + } else if (args_spec_list[2]->isa()) { + auto axis = args_spec_list[2]->cast(); + axis_val = GetValue(axis->BuildValue()); + } else { + MS_LOG(EXCEPTION) << "Invalid abstract type:" << args_spec_list[2]->type_name(); + } auto params_shp = params->shape()->shape(); auto indices_shp = indices->shape()->shape(); - auto axis_val = GetValue(axis->BuildValue()); auto params_rank = static_cast(params_shp.size()); if (axis_val < 0) { @@ -265,6 +335,25 @@ AbstractBasePtr InferImplGatherV2(const AnalysisEnginePtr &, const PrimitivePtr return std::make_shared(params->element(), std::make_shared(out_shape)); } +AbstractBasePtr InferImplShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + const std::string &op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 1); + AbstractTensorPtr input = CheckArg(op_name, args_spec_list, 0); + auto shape = input->shape()->shape(); + + AbstractBasePtrList elements; + for (const auto &dim : shape) { + if (dim == Shape::SHP_ANY) { + elements.push_back(std::make_shared(std::make_shared(), std::make_shared(32))); + } else { + elements.push_back(std::make_shared(dim)); + } + } + + return std::make_shared(elements); +} + AbstractBasePtr InferImplDynamicShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { const std::string &op_name = primitive->name(); diff --git a/mindspore/core/abstract/prim_maths.cc b/mindspore/core/abstract/prim_maths.cc index 1a3d8f6c71..230486216a 100644 --- a/mindspore/core/abstract/prim_maths.cc +++ b/mindspore/core/abstract/prim_maths.cc @@ -46,5 +46,55 @@ AbstractBasePtr InferImplSqrt(const AnalysisEnginePtr &, const PrimitivePtr &pri auto inp = 
CheckArg(op_name, args_spec_list, 0); return inp->Clone()->Broaden(); } + +AbstractBasePtr InferImplMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: two tensors. + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + ShapePtr shape_x = dyn_cast(args_spec_list[0]->GetShapeTrack()); + MS_EXCEPTION_IF_NULL(shape_x); + std::vector x_dims = shape_x->shape(); + ShapePtr shape_y = dyn_cast(args_spec_list[1]->GetShapeTrack()); + MS_EXCEPTION_IF_NULL(shape_y); + std::vector y_dims = shape_y->shape(); + auto broadcast_shape = BroadcastShape(x_dims, y_dims); + if (broadcast_shape.empty()) { + MS_LOG(EXCEPTION) << "BroadcastShape fail: " << args_spec_list[0]->ToString() << "," + << args_spec_list[1]->ToString(); + } + auto out = args_spec_list[0]->Broaden(); + out->set_shape(std::make_shared(broadcast_shape)); + return out; +} + +AbstractBasePtr InferImplTensorAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: two tensors. + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + ShapePtr shape_x = dyn_cast(args_spec_list[0]->GetShapeTrack()); + MS_EXCEPTION_IF_NULL(shape_x); + std::vector x_dims = shape_x->shape(); + ShapePtr shape_y = dyn_cast(args_spec_list[1]->GetShapeTrack()); + MS_EXCEPTION_IF_NULL(shape_y); + std::vector y_dims = shape_y->shape(); + auto broadcast_shape = BroadcastShape(x_dims, y_dims); + if (broadcast_shape.empty()) { + MS_LOG(EXCEPTION) << "BroadcastShape fail: " << args_spec_list[0]->ToString() << "," + << args_spec_list[1]->ToString(); + } + auto out = args_spec_list[0]->Broaden(); + out->set_shape(std::make_shared(broadcast_shape)); + return out; +} + +AbstractBasePtr InferImplSquare(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: one tensor. 
+ const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 1); + return args_spec_list[0]->Broaden(); +} } // namespace abstract } // namespace mindspore diff --git a/mindspore/core/abstract/prim_others.cc b/mindspore/core/abstract/prim_others.cc index 868c50aa7c..68bab92e40 100644 --- a/mindspore/core/abstract/prim_others.cc +++ b/mindspore/core/abstract/prim_others.cc @@ -27,6 +27,10 @@ #include "utils/symbolic.h" #include "utils/shape_utils.h" +namespace { +constexpr auto kRankSize = "rank_size"; +} + namespace mindspore { namespace abstract { AbstractBasePtr InferImplIdentity(const AnalysisEnginePtr &, const PrimitivePtr &primitive, @@ -362,5 +366,69 @@ AbstractBasePtr InferImplSparseTensorGetDenseShape(const AnalysisEnginePtr &, co MS_EXCEPTION_IF_NULL(sparse_tensor->dense_shape()); return sparse_tensor->dense_shape(); } + +AbstractBasePtr InferImplAllReduce(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 1); + auto x = CheckArg(op_name, args_spec_list, 0); + MS_EXCEPTION_IF_NULL(x); + MS_EXCEPTION_IF_NULL(x->shape()); + return std::make_shared(x->element(), std::make_shared(x->shape()->shape())); +} + +AbstractBasePtr InferImplBroadcast(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 1); + auto x = CheckArg(op_name, args_spec_list, 0); + MS_EXCEPTION_IF_NULL(x); + MS_EXCEPTION_IF_NULL(x->shape()); + return std::make_shared(x->element(), std::make_shared(x->shape()->shape())); +} + +AbstractBasePtr InferImplAllGather(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 1); + auto x = CheckArg(op_name, args_spec_list, 0); + MS_EXCEPTION_IF_NULL(x); + MS_EXCEPTION_IF_NULL(x->shape()); + auto tmp_shape = x->shape()->shape(); + if (!primitive->HasAttr(kRankSize)) { + MS_LOG(EXCEPTION) << "Primitive doesn't have rank_size attr"; + } + auto rank_size = GetValue(primitive->GetAttr(kRankSize)); + if (rank_size == 0) { + MS_LOG(EXCEPTION) << "rank_size is 0"; + } + if (tmp_shape.empty()) { + MS_LOG(EXCEPTION) << "shape size is 0"; + } + if (tmp_shape[0] % rank_size != 0) { + MS_LOG(EXCEPTION) << "first dimension of x should be divisible by rank_size"; + } + tmp_shape[0] = tmp_shape[0] / rank_size; + return std::make_shared(x->element(), std::make_shared(tmp_shape)); +} + +AbstractBasePtr InferImplReduceScatter(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 1); + auto x = CheckArg(op_name, args_spec_list, 0); + MS_EXCEPTION_IF_NULL(x); + MS_EXCEPTION_IF_NULL(x->shape()); + auto tmp_shape = x->shape()->shape(); + if (!primitive->HasAttr(kRankSize)) { + MS_LOG(EXCEPTION) << "Primitive doesn't have rank_size attr"; + } + auto rank_size = GetValue(primitive->GetAttr(kRankSize)); + if (tmp_shape.empty()) { + MS_LOG(EXCEPTION) << "shape size is 0"; + } + tmp_shape[0] = IntMulWithOverflowCheck(tmp_shape[0], rank_size); + return std::make_shared(x->element(), std::make_shared(tmp_shape)); +} } // namespace abstract } // namespace mindspore diff --git a/mindspore/core/abstract/primitive_infer_map.cc
b/mindspore/core/abstract/primitive_infer_map.cc index dfeda67712..bf92b9ca0a 100644 --- a/mindspore/core/abstract/primitive_infer_map.cc +++ b/mindspore/core/abstract/primitive_infer_map.cc @@ -37,6 +37,11 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { // Maths {prim::kPrimMaximumGrad, {InferImplMinOrMaxGrad, true}}, {prim::kPrimMinimumGrad, {InferImplMinOrMaxGrad, true}}, + + {prim::kPrimMul, {InferImplMul, true}}, + {prim::kPrimTensorAdd, {InferImplTensorAdd, true}}, + {prim::kPrimSquare, {InferImplSquare, true}}, + {prim::kPrimSqrt, {InferImplSqrt, true}}, // Array {prim::kPrimScalarToArray, {InferImplScalarToArray, true}}, @@ -47,6 +52,12 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { {prim::kPrimUniqueGrad, {InferImplUniqueGrad, true}}, {prim::kPrimGatherV2, {InferImplGatherV2, true}}, {prim::kPrimSparseGatherV2, {InferImplGatherV2, true}}, + {prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, true}}, + {prim::kPrimScatterAdd, {InferImplScatterAdd, true}}, + {prim::kPrimScatterUpdate, {InferImplScatterUpdate, true}}, + {prim::kPrimDiv, {InferImplDiv, true}}, + {prim::kPrimRealDiv, {InferImplRealDiv, true}}, + {prim::kPrimShape, {InferImplShape, false}}, {prim::kPrimDynamicShape, {InferImplDynamicShape, true}}, // Structure {prim::kPrimMakeTuple, {InferImplMakeTuple, true}}, @@ -109,6 +120,11 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { {prim::kPrimRowTensorGetValues, {InferImplRowTensorGetValues, true}}, {prim::kPrimRowTensorGetIndices, {InferImplRowTensorGetIndices, true}}, {prim::kPrimRowTensorGetDenseShape, {InferImplRowTensorGetDenseShape, true}}, + // Comm Ops + {prim::kPrimAllReduce, {InferImplAllReduce, true}}, + {prim::kPrimBroadcast, {InferImplBroadcast, true}}, + {prim::kPrimAllGather, {InferImplAllGather, true}}, + {prim::kPrimReduceScatter, {InferImplReduceScatter, true}}, }; return prim_eval_implement_map; } diff --git a/mindspore/core/abstract/utils.cc b/mindspore/core/abstract/utils.cc index 100badd0f3..2262ed1ade 100644 --- a/mindspore/core/abstract/utils.cc +++ b/mindspore/core/abstract/utils.cc @@ -245,6 +245,41 @@ ShapeVector RealBroadcast(const std::string &op, ShapeVector x_shape, ShapeVecto return broadcast_shape; } +ShapeVector BroadcastShape(ShapeVector shpx, ShapeVector shpy) { + int dlen = SizeToInt(shpx.size()) - SizeToInt(shpy.size()); + if (dlen < 0) { + for (int i = 0; i < -dlen; ++i) { + (void)shpx.insert(shpx.begin(), 1); + } + } else if (dlen > 0) { + for (int i = 0; i < dlen; i++) { + (void)shpy.insert(shpy.begin(), 1); + } + } + if (shpx.size() != shpy.size()) { + MS_LOG(EXCEPTION) << "Failure: shpx.size() != shpy.size()."; + } + ShapeVector shp; + for (size_t i = 0; i < shpx.size(); i++) { + auto a = shpx[i]; + auto b = shpy[i]; + if (a == 1) { + shp.push_back(b); + } else if (b == 1) { + shp.push_back(a); + } else if (a == -1) { + shp.push_back(b); + } else if (b == -1) { + shp.push_back(a); + } else if (a == b) { + shp.push_back(a); + } else { + return ShapeVector(); + } + } + return shp; +} + ShapePtr GetBroadcastShape(const std::string &op, const AbstractTensorPtr &tensor_x, const AbstractTensorPtr &tensor_y) { mindspore::abstract::ShapePtr tensor_x_shape = tensor_x->shape(); diff --git a/mindspore/core/abstract/utils.h b/mindspore/core/abstract/utils.h index 75ba63aa0b..413ff79940 100644 --- a/mindspore/core/abstract/utils.h +++ b/mindspore/core/abstract/utils.h @@ -48,6 +48,8 @@ bool CheckType(const TypePtr &expected_type, const TypePtr &x); int GetPositiveAxis(int axis_value, size_t increment); 
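The BroadcastShape helper moved into abstract/utils.cc above pads the shorter shape with leading 1s, lets a dimension of 1 or an unknown dimension (-1) take the other side's size, and returns an empty shape when the dimensions genuinely conflict. A minimal Python sketch of that same rule, where broadcast_shape is an illustrative name and not part of the patch:

# Illustrative mirror of the BroadcastShape rule; -1 marks a dynamic dimension.
def broadcast_shape(shpx, shpy):
    dlen = len(shpx) - len(shpy)
    if dlen < 0:
        shpx = [1] * (-dlen) + list(shpx)   # left-pad the shorter shape with 1s
    elif dlen > 0:
        shpy = [1] * dlen + list(shpy)
    out = []
    for a, b in zip(shpx, shpy):
        if a == 1:
            out.append(b)
        elif b == 1:
            out.append(a)
        elif a == -1:
            out.append(b)
        elif b == -1:
            out.append(a)
        elif a == b:
            out.append(a)
        else:
            return []   # incompatible shapes, matching the C++ helper's empty result
    return out

assert broadcast_shape([2, 1, 4], [-1, 3, 4]) == [2, 3, 4]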
+std::vector BroadcastShape(std::vector shpx, std::vector shpy); + // Get broadcasted shape for binary element-wise operation ShapePtr GetBroadcastShape(const std::string &op, const AbstractTensorPtr &tensor_x, const AbstractTensorPtr &tensor_y); } // namespace abstract diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h index 1e0dedecd0..35799bcdd2 100644 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ -108,6 +108,9 @@ inline const PrimitivePtr kPrimUniqueGrad = std::make_shared("UniqueG inline const PrimitivePtr kPrimExtractImagePatches = std::make_shared("ExtractImagePatches"); inline const PrimitivePtr kPrimDynamicRNN = std::make_shared("DynamicRNN"); inline const PrimitivePtr kPrimDynamicRNNGrad = std::make_shared("DynamicRNNGrad"); +inline const PrimitivePtr kPrimScatterAdd = std::make_shared("ScatterAdd"); +inline const PrimitivePtr kPrimScatterUpdate = std::make_shared("ScatterUpdate"); +inline const PrimitivePtr kPrimDiv = std::make_shared("Div"); // NN inline const PrimitivePtr kPrimFlatten = std::make_shared("Flatten"); @@ -171,6 +174,9 @@ inline const PrimitivePtr kPrimMirror = std::make_shared("_MirrorOper inline const PrimitivePtr kPrimVirtualDiv = std::make_shared("_VirtualDiv"); inline const PrimitivePtr kPrimVirtualDataset = std::make_shared("_VirtualDataset"); inline const PrimitivePtr kPrimAllReduce = std::make_shared("AllReduce"); +inline const PrimitivePtr kPrimBroadcast = std::make_shared("Broadcast"); +inline const PrimitivePtr kPrimAllGather = std::make_shared("AllGather"); +inline const PrimitivePtr kPrimReduceScatter = std::make_shared("ReduceScatter"); // RowTensor inline const PrimitivePtr kPrimMakeRowTensor = std::make_shared("MakeRowTensor"); diff --git a/mindspore/core/utils/convert_utils_base.h b/mindspore/core/utils/convert_utils_base.h index ade7c3a967..6aa70346fe 100644 --- a/mindspore/core/utils/convert_utils_base.h +++ b/mindspore/core/utils/convert_utils_base.h @@ -46,7 +46,8 @@ inline int64_t SizeToLong(size_t u) { inline size_t IntToSize(int u) { if (u < 0) { - MS_LOG(EXCEPTION) << "The int value(" << u << ") is less than 0."; + MS_LOG(WARNING) << "The int value(" << u << ") is less than 0."; + return SIZE_MAX; } return static_cast(u); } diff --git a/mindspore/ops/_op_impl/aicpu/__init__.py b/mindspore/ops/_op_impl/aicpu/__init__.py index e181c97010..dd8b536363 100644 --- a/mindspore/ops/_op_impl/aicpu/__init__.py +++ b/mindspore/ops/_op_impl/aicpu/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. """aicpu ops""" +from .unique import _unique_aicpu from .init_data_set_queue import _init_data_set_queue_aicpu from .embedding_lookup import _embedding_lookup_aicpu from .padding import _padding_aicpu diff --git a/mindspore/ops/_op_impl/aicpu/dynamic_shape.py b/mindspore/ops/_op_impl/aicpu/dynamic_shape.py new file mode 100644 index 0000000000..cbcbd41b52 --- /dev/null +++ b/mindspore/ops/_op_impl/aicpu/dynamic_shape.py @@ -0,0 +1,40 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""DynamicShape op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +dynamic_shape_op_info = AiCPURegOp("DynamicShape") \ + .fusion_type("OPAQUE") \ + .input(0, "x", "required") \ + .output(0, "y", "required") \ + .dtype_format(DataType.I8_Default, DataType.I32_Default) \ + .dtype_format(DataType.I16_Default, DataType.I32_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.I64_Default, DataType.I32_Default) \ + .dtype_format(DataType.U8_Default, DataType.I32_Default) \ + .dtype_format(DataType.U16_Default, DataType.I32_Default) \ + .dtype_format(DataType.U32_Default, DataType.I32_Default) \ + .dtype_format(DataType.U64_Default, DataType.I32_Default) \ + .dtype_format(DataType.F16_Default, DataType.I32_Default) \ + .dtype_format(DataType.F32_Default, DataType.I32_Default) \ + .dtype_format(DataType.F64_Default, DataType.I32_Default) \ + .dtype_format(DataType.BOOL_Default, DataType.I32_Default) \ + .get_op_info() + +@op_info_register(dynamic_shape_op_info) +def _dynamic_shape_aicpu(): + """DynamicShape AiCPU register""" + return diff --git a/mindspore/ops/_op_impl/aicpu/unique.py b/mindspore/ops/_op_impl/aicpu/unique.py new file mode 100644 index 0000000000..849e969609 --- /dev/null +++ b/mindspore/ops/_op_impl/aicpu/unique.py @@ -0,0 +1,31 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# ============================================================================ + +"""Unique op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +unique_op_info = AiCPURegOp("Unique") \ + .fusion_type("OPAQUE") \ + .input(0, "x", "required") \ + .output(0, "y", "required") \ + .output(1, "idx", "required") \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \ + .get_op_info() + +@op_info_register(unique_op_info) +def _unique_aicpu(): + """Unique AiCPU register""" + return diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py index 195934e07a..5b3c1fec2f 100644 --- a/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/ops/_op_impl/tbe/__init__.py @@ -62,7 +62,9 @@ from .max_pool_grad import _max_pool_grad_tbe from .max_pool_grad_with_argmax import _max_pool_grad_with_argmax_tbe from .max_pool_with_argmax import _max_pool_with_argmax_tbe from .mul import _mul_tbe +from .mul_ds import _mul_ds_tbe from .real_div import _real_div_tbe +from .real_div_ds import _real_div_ds_tbe from .relu import _relu_tbe from .relu_grad import _relu_grad_tbe from .relu6 import _relu6_tbe @@ -73,11 +75,11 @@ from .softmax_cross_entropy_with_logits import _softmax_cross_entropy_with_logit from .sigmoid_cross_entropy_with_logits import _sigmoid_cross_entropy_with_logits_tbe from .sigmoid_cross_entropy_with_logits_grad import _sigmoid_cross_entropy_with_logits_grad_tbe from .tensor_add import _tensor_add_tbe +from .tensor_add_ds import _tensor_add_ds_tbe from .trans_data import _trans_data_tbe from .top_k import _top_k_tbe from .matmul import _matmul_tbe from .sub import _sub_tbe -from .reduce_mean_d import _reduce_mean_d_tbe from .scatter_nd import _scatter_nd_tbe from .scatter_nd_d import _scatter_nd_d_tbe from .scatter_nd_add import _scatter_nd_add_tbe @@ -87,6 +89,7 @@ from .reduce_mean import _reduce_mean_tbe from .tile import _tile_tbe from .atomic_addr_clean import _atomic_addr_clean_tbe from .gather_v2 import _gather_v2_tbe +from .gather_v2_ds import _gather_v2_ds_tbe from .gather_nd import _gather_nd_tbe from .bn_training_reduce import _bn_training_reduce_tbe from .bn_training_reduce_grad import _bn_training_reduce_grad_tbe @@ -106,6 +109,7 @@ from .expm1 import _expm1_tbe from .elu import _elu_tbe from .elu_grad import _elu_grad_tbe from .div import _div_tbe +from .div_ds import _div_ds_tbe from .log import _log_tbe from .xdivy import _xdivy_tbe from .xlogy import _xlogy_tbe @@ -134,15 +138,19 @@ from .softplus import _softplus_tbe from .softplus_grad import _softplus_grad_tbe from .softmax_grad_ext import _softmax_grad_ext_tbe from .square import _square_tbe +from .square_ds import _square_ds_tbe from .squared_difference import _squared_difference_tbe from .sqrt import _sqrt_tbe +from .sqrt_ds import _sqrt_ds_tbe from .sparse_apply_ftrl_d import _sparse_apply_ftrl_d +from .sparse_apply_ftrl_d_ds import _sparse_apply_ftrl_d_ds from .sparse_apply_proximal_adagrad import _sparse_apply_proximal_adagrad from .apply_proximal_adagrad import _apply_proximal_adagrad from .transpose_d import _transpose_d_tbe from .truncate_div import _truncate_div_tbe from .truncate_mod import _truncate_mod_tbe from .unsorted_segment_sum import _unsorted_segment_sum_tbe +from .unsorted_segment_sum_ds import _unsorted_segment_sum_ds_tbe from .unsorted_segment_prod import _unsorted_segment_prod_tbe from .logsoftmax_grad import 
_logsoftmax_grad_tbe from .logsoftmax import _logsoftmax_tbe @@ -229,6 +237,7 @@ from .square_sum_all import _square_sum_all_tbe from .pack import _pack_tbe from .unpack import _unpack_tbe from .scatter_update import _scatter_update_tbe +from .scatter_update_ds import _scatter_update_ds_tbe from .prelu import _prelu_tbe from .prelu_grad import _prelu_grad_tbe from .binary_cross_entropy import _binary_cross_entropy_tbe @@ -245,6 +254,7 @@ from .sqrt_grad import _sqrt_grad_tbe from .rsqrt_grad import _rsqrt_grad_tbe from .flatten_grad import _flatten_grad_tbe from .scatter_add import _scatter_add_tbe +from .scatter_add_ds import _scatter_add_ds_tbe from .atan2 import _atan2_tbe from .bessel_i0e import _bessel_i0e_tbe from .bessel_i1e import _bessel_i1e_tbe diff --git a/mindspore/ops/_op_impl/tbe/accumulate_n_v2.py b/mindspore/ops/_op_impl/tbe/accumulate_n_v2.py index b16233c37e..87bb8c8541 100644 --- a/mindspore/ops/_op_impl/tbe/accumulate_n_v2.py +++ b/mindspore/ops/_op_impl/tbe/accumulate_n_v2.py @@ -19,9 +19,9 @@ from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType accumulate_n_v2_op_info = TBERegOp("AccumulateNV2") \ .fusion_type("ELEMWISE") \ .async_flag(False) \ - .binfile_name("accumulate_n_v2.so") \ + .binfile_name("accumulate_nv2.so") \ .compute_cost(10) \ - .kernel_name("accumulate_n_v2") \ + .kernel_name("accumulate_nv2") \ .partial_flag(True) \ .attr("n", "required", "int", "all") \ .input(0, "x", False, "dynamic", "all") \ diff --git a/mindspore/ops/_op_impl/tbe/apply_adam.py b/mindspore/ops/_op_impl/tbe/apply_adam.py index 6fd7205567..e7d0ee4975 100644 --- a/mindspore/ops/_op_impl/tbe/apply_adam.py +++ b/mindspore/ops/_op_impl/tbe/apply_adam.py @@ -21,7 +21,7 @@ apply_adam_op_info = TBERegOp("Adam") \ .async_flag(False) \ .binfile_name("apply_adam.so") \ .compute_cost(10) \ - .kernel_name("apply_adam") \ + .kernel_name("apply_adam_d") \ .partial_flag(True) \ .attr("use_locking", "optional", "bool", "true,false", "false") \ .attr("use_nesterov", "optional", "bool", "true,false", "false") \ diff --git a/mindspore/ops/_op_impl/tbe/apply_ftrl.py b/mindspore/ops/_op_impl/tbe/apply_ftrl.py index 56c6bf3612..48bc64da54 100644 --- a/mindspore/ops/_op_impl/tbe/apply_ftrl.py +++ b/mindspore/ops/_op_impl/tbe/apply_ftrl.py @@ -21,7 +21,7 @@ apply_ftrl_op_info = TBERegOp("ApplyFtrl") \ .async_flag(False) \ .binfile_name("apply_ftrl.so") \ .compute_cost(10) \ - .kernel_name("apply_ftrl") \ + .kernel_name("apply_ftrl_d") \ .partial_flag(True) \ .input(0, "var", False, "required", "all") \ .input(1, "accum", False, "required", "all") \ diff --git a/mindspore/ops/_op_impl/tbe/apply_momentum.py b/mindspore/ops/_op_impl/tbe/apply_momentum.py index deb8f0d387..11a7037e2c 100644 --- a/mindspore/ops/_op_impl/tbe/apply_momentum.py +++ b/mindspore/ops/_op_impl/tbe/apply_momentum.py @@ -21,7 +21,7 @@ apply_momentum_op_info = TBERegOp("ApplyMomentum") \ .async_flag(False) \ .binfile_name("apply_momentum.so") \ .compute_cost(10) \ - .kernel_name("apply_momentum") \ + .kernel_name("apply_momentum_d") \ .partial_flag(True) \ .attr("use_nesterov", "optional", "bool", "true,false", "false") \ .input(0, "var", False, "required", "all") \ diff --git a/mindspore/ops/_op_impl/tbe/assign_add.py b/mindspore/ops/_op_impl/tbe/assign_add.py index 7ad23ff3bc..5003157b2f 100644 --- a/mindspore/ops/_op_impl/tbe/assign_add.py +++ b/mindspore/ops/_op_impl/tbe/assign_add.py @@ -19,9 +19,9 @@ from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType assign_add_op_info = 
TBERegOp("AssignAdd") \ .fusion_type("OPAQUE") \ .async_flag(False) \ - .binfile_name("assignadd.so") \ + .binfile_name("assign_add.so") \ .compute_cost(10) \ - .kernel_name("assignadd") \ + .kernel_name("assign_add") \ .partial_flag(True) \ .input(0, "ref", False, "required", "all") \ .input(1, "value", False, "required", "all") \ diff --git a/mindspore/ops/_op_impl/tbe/batchnorm_grad.py b/mindspore/ops/_op_impl/tbe/batchnorm_grad.py index 973f3709e5..286fb27301 100644 --- a/mindspore/ops/_op_impl/tbe/batchnorm_grad.py +++ b/mindspore/ops/_op_impl/tbe/batchnorm_grad.py @@ -19,9 +19,9 @@ from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType batch_norm_grad_op_info = TBERegOp("BatchNormGrad") \ .fusion_type("OPAQUE") \ .async_flag(False) \ - .binfile_name("batchnormgrad.so") \ + .binfile_name("batch_norm_grad.so") \ .compute_cost(10) \ - .kernel_name("batchnormgrad") \ + .kernel_name("batch_norm_grad") \ .partial_flag(True) \ .attr("epsilon", "optional", "float", "all") \ .attr("data_format", "optional", "str", "all") \ diff --git a/mindspore/ops/_op_impl/tbe/bias_add_grad.py b/mindspore/ops/_op_impl/tbe/bias_add_grad.py index e59c197bce..2205bc6dcb 100644 --- a/mindspore/ops/_op_impl/tbe/bias_add_grad.py +++ b/mindspore/ops/_op_impl/tbe/bias_add_grad.py @@ -19,9 +19,9 @@ from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType bias_add_grad_op_info = TBERegOp("BiasAddGrad") \ .fusion_type("COMMREDUCE") \ .async_flag(False) \ - .binfile_name("biasaddgrad.so") \ + .binfile_name("bias_add_grad.so") \ .compute_cost(10) \ - .kernel_name("biasaddgrad") \ + .kernel_name("bias_add_grad") \ .partial_flag(True) \ .attr("data_format", "required", "str", "all") \ .input(0, "output_backprop", False, "required", "all") \ diff --git a/mindspore/ops/_op_impl/tbe/confusion_mul_grad.py b/mindspore/ops/_op_impl/tbe/confusion_mul_grad.py index e49d5386f2..06a12fed23 100644 --- a/mindspore/ops/_op_impl/tbe/confusion_mul_grad.py +++ b/mindspore/ops/_op_impl/tbe/confusion_mul_grad.py @@ -18,6 +18,8 @@ from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType confusion_mul_grad_op_info = TBERegOp("ConfusionMulGrad") \ .fusion_type("OPAQUE") \ + .binfile_name("confusion_mul_grad.so") \ + .kernel_name("confusion_mul_grad") \ .attr("axis", "required", "listInt", "all") \ .attr("keep_dims", "required", "bool", "all") \ .input(0, "input0", False, "required", "all") \ diff --git a/mindspore/ops/_op_impl/tbe/div_ds.py b/mindspore/ops/_op_impl/tbe/div_ds.py new file mode 100644 index 0000000000..f3229bf5ae --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/div_ds.py @@ -0,0 +1,42 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +"""Div op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +div_op_info = TBERegOp("Div") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("div.so") \ + .compute_cost(10) \ + .kernel_name("div") \ + .partial_flag(True) \ + .dynamic_shape(True) \ + .input(0, "x1", False, "required", "all") \ + .input(1, "x2", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .op_pattern("broadcast") \ + .dtype_format(DataType.I8_None, DataType.I8_None, DataType.I8_None) \ + .dtype_format(DataType.U8_None, DataType.U8_None, DataType.U8_None) \ + .dtype_format(DataType.I32_None, DataType.I32_None, DataType.I32_None) \ + .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \ + .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \ + .get_op_info() + + +@op_info_register(div_op_info) +def _div_ds_tbe(): + """Div TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/floor_div.py b/mindspore/ops/_op_impl/tbe/floor_div.py index c700e8d15a..9368bdbc57 100644 --- a/mindspore/ops/_op_impl/tbe/floor_div.py +++ b/mindspore/ops/_op_impl/tbe/floor_div.py @@ -19,9 +19,9 @@ from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType floordiv_op_info = TBERegOp("FloorDiv") \ .fusion_type("ELEMWISE") \ .async_flag(False) \ - .binfile_name("floordiv.so") \ + .binfile_name("floor_div.so") \ .compute_cost(10) \ - .kernel_name("floordiv") \ + .kernel_name("floor_div") \ .partial_flag(True) \ .input(0, "x1", False, "required", "all") \ .input(1, "x2", False, "required", "all") \ diff --git a/mindspore/ops/_op_impl/tbe/fused_mul_add_n_l2loss.py b/mindspore/ops/_op_impl/tbe/fused_mul_add_n_l2loss.py index e4f3f8be16..d2db9eb20d 100644 --- a/mindspore/ops/_op_impl/tbe/fused_mul_add_n_l2loss.py +++ b/mindspore/ops/_op_impl/tbe/fused_mul_add_n_l2loss.py @@ -19,9 +19,9 @@ from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType fused_mul_add_n_l2loss_op_info = TBERegOp("FusedMulAddNL2loss") \ .fusion_type("OPAQUE") \ .async_flag(False) \ - .binfile_name("fused_mul_addn_l2loss.so") \ + .binfile_name("fused_mul_addn_l2_loss.so") \ .compute_cost(10) \ - .kernel_name("fused_mul_addn_l2loss") \ + .kernel_name("fused_mul_addn_l2_loss") \ .partial_flag(True) \ .input(0, "x1", False, "required", "all") \ .input(1, "x2", False, "required", "all") \ diff --git a/mindspore/ops/_op_impl/tbe/gather_v2_ds.py b/mindspore/ops/_op_impl/tbe/gather_v2_ds.py new file mode 100644 index 0000000000..468571844a --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/gather_v2_ds.py @@ -0,0 +1,67 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +"""GatherV2 op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +gather_v2_op_info = TBERegOp("GatherV2") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("gather_v2.so") \ + .compute_cost(10) \ + .kernel_name("gather_v2") \ + .partial_flag(True) \ + .dynamic_shape(True) \ + .input(0, "x", False, "required", "all") \ + .input(1, "indices", False, "required", "all") \ + .input(2, "axis", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I32_Default, DataType.I8_Default) \ + .dtype_format(DataType.I8_Default, DataType.I64_Default, DataType.I32_Default, DataType.I8_Default) \ + .dtype_format(DataType.I8_5HD, DataType.I32_5HD, DataType.I32_5HD, DataType.I8_5HD) \ + .dtype_format(DataType.I8_5HD, DataType.I64_5HD, DataType.I32_5HD, DataType.I8_5HD) \ + .dtype_format(DataType.I8_FracZ, DataType.I32_FracZ, DataType.I32_FracZ, DataType.I8_FracZ) \ + .dtype_format(DataType.I8_FracZ, DataType.I64_FracZ, DataType.I32_FracZ, DataType.I8_FracZ) \ + .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.I32_Default, DataType.U8_Default) \ + .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.I32_Default, DataType.U8_Default) \ + .dtype_format(DataType.U8_5HD, DataType.I32_5HD, DataType.I32_5HD, DataType.U8_5HD) \ + .dtype_format(DataType.U8_5HD, DataType.I64_5HD, DataType.I32_5HD, DataType.U8_5HD) \ + .dtype_format(DataType.U8_FracZ, DataType.I32_FracZ, DataType.I32_FracZ, DataType.U8_FracZ) \ + .dtype_format(DataType.U8_FracZ, DataType.I64_FracZ, DataType.I32_FracZ, DataType.U8_FracZ) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \ + .dtype_format(DataType.I32_5HD, DataType.I64_5HD, DataType.I32_5HD, DataType.I32_5HD) \ + .dtype_format(DataType.I32_FracZ, DataType.I32_FracZ, DataType.I32_FracZ, DataType.I32_FracZ) \ + .dtype_format(DataType.I32_FracZ, DataType.I64_FracZ, DataType.I32_FracZ, DataType.I32_FracZ) \ + .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.I32_Default, DataType.F16_Default) \ + .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.I32_Default, DataType.F16_Default) \ + .dtype_format(DataType.F16_5HD, DataType.I32_5HD, DataType.I32_5HD, DataType.F16_5HD) \ + .dtype_format(DataType.F16_5HD, DataType.I64_5HD, DataType.I32_5HD, DataType.F16_5HD) \ + .dtype_format(DataType.F16_FracZ, DataType.I32_FracZ, DataType.I32_FracZ, DataType.F16_FracZ) \ + .dtype_format(DataType.F16_FracZ, DataType.I64_FracZ, DataType.I32_FracZ, DataType.F16_FracZ) \ + .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.I32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F32_5HD, DataType.I32_5HD, DataType.I32_5HD, DataType.F32_5HD) \ + .dtype_format(DataType.F32_5HD, DataType.I64_5HD, DataType.I32_5HD, DataType.F32_5HD) \ + .dtype_format(DataType.F32_FracZ, DataType.I32_FracZ, DataType.I32_FracZ, DataType.F32_FracZ) \ + .dtype_format(DataType.F32_FracZ, DataType.I64_FracZ, DataType.I32_FracZ, DataType.F32_FracZ) \ + .get_op_info() + +
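This dynamic-shape GatherV2 registration pairs with the InferImplGatherV2 change earlier in the patch, which reads the axis from either a scalar or a 0-d tensor input and then applies the usual GatherV2 output-shape rule (the indices shape spliced into the params shape at the axis). A hedged Python sketch of that rule only, where gather_v2_out_shape is an illustrative name and not part of the patch:

# Illustration of the GatherV2 output-shape rule the C++ infer function relies on.
def gather_v2_out_shape(params_shape, indices_shape, axis):
    if axis < 0:
        axis += len(params_shape)   # negative axis counts from the end, as in the C++ code
    return list(params_shape[:axis]) + list(indices_shape) + list(params_shape[axis + 1:])

assert gather_v2_out_shape([8, 16, 32], [5, 2], 1) == [8, 5, 2, 32]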
+@op_info_register(gather_v2_op_info) +def _gather_v2_ds_tbe(): + """GatherV2 TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/lin_space.py b/mindspore/ops/_op_impl/tbe/lin_space.py index 6e474c50ea..9093894972 100644 --- a/mindspore/ops/_op_impl/tbe/lin_space.py +++ b/mindspore/ops/_op_impl/tbe/lin_space.py @@ -21,7 +21,7 @@ lin_space_op_info = TBERegOp("LinSpace") \ .async_flag(False) \ .binfile_name("lin_space.so") \ .compute_cost(10) \ - .kernel_name("lin_space") \ + .kernel_name("lin_space_d") \ .partial_flag(True) \ .op_pattern("broadcast") \ .input(0, "assist", False, "required", "all") \ diff --git a/mindspore/ops/_op_impl/tbe/logsoftmax.py b/mindspore/ops/_op_impl/tbe/logsoftmax.py index 9bf0baf3f2..59aac3400a 100644 --- a/mindspore/ops/_op_impl/tbe/logsoftmax.py +++ b/mindspore/ops/_op_impl/tbe/logsoftmax.py @@ -21,7 +21,7 @@ log_softmax_op_info = TBERegOp("LogSoftmax") \ .async_flag(False) \ .binfile_name("log_softmax.so") \ .compute_cost(10) \ - .kernel_name("log_softmax") \ + .kernel_name("log_softmax_v2") \ .partial_flag(True) \ .attr("axis", "optional", "listInt", "all") \ .input(0, "logits", False, "required", "all") \ diff --git a/mindspore/ops/_op_impl/tbe/matmul.py b/mindspore/ops/_op_impl/tbe/matmul.py index e773191ae8..be7f7303e4 100644 --- a/mindspore/ops/_op_impl/tbe/matmul.py +++ b/mindspore/ops/_op_impl/tbe/matmul.py @@ -19,9 +19,9 @@ from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType matmul_op_info = TBERegOp("MatMul") \ .fusion_type("OPAQUE") \ .async_flag(False) \ - .binfile_name("matmul.so") \ + .binfile_name("mat_mul.so") \ .compute_cost(10) \ - .kernel_name("matmul") \ + .kernel_name("mat_mul") \ .partial_flag(True) \ .attr("transpose_x1", "required", "bool", "all") \ .attr("transpose_x2", "required", "bool", "all") \ diff --git a/mindspore/ops/_op_impl/tbe/matrix_set_diag.py b/mindspore/ops/_op_impl/tbe/matrix_set_diag.py index db0b460084..90eec88db1 100644 --- a/mindspore/ops/_op_impl/tbe/matrix_set_diag.py +++ b/mindspore/ops/_op_impl/tbe/matrix_set_diag.py @@ -21,7 +21,7 @@ matrix_diag_d_op_info = TBERegOp("MatrixSetDiag") \ .async_flag(False) \ .binfile_name("matrix_diag_d.so") \ .compute_cost(10) \ - .kernel_name("matrix_diag_d") \ + .kernel_name("matrix_set_diag_d") \ .partial_flag(True) \ .input(0, "x", False, "required", "all") \ .input(1, "diagonal", False, "required", "all") \ diff --git a/mindspore/ops/_op_impl/tbe/mul_ds.py b/mindspore/ops/_op_impl/tbe/mul_ds.py new file mode 100644 index 0000000000..328fe4f669 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/mul_ds.py @@ -0,0 +1,38 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +"""Mul op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +mul_ds_op_info = TBERegOp("Mul") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("mul.so") \ + .compute_cost(10) \ + .kernel_name("mul") \ + .partial_flag(True) \ + .dynamic_shape(True) \ + .input(0, "x", False, "required", "all") \ + .input(1, "y", False, "required", "all") \ + .output(0, "output", False, "required", "all") \ + .op_pattern("dynamicFormat") \ + .dtype_format(DataType.None_None, DataType.None_None, DataType.None_None) \ + .get_op_info() + + +@op_info_register(mul_ds_op_info) +def _mul_ds_tbe(): + """Mul TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/one_hot.py b/mindspore/ops/_op_impl/tbe/one_hot.py index 81a80bf759..616c1738ba 100644 --- a/mindspore/ops/_op_impl/tbe/one_hot.py +++ b/mindspore/ops/_op_impl/tbe/one_hot.py @@ -21,7 +21,7 @@ one_hot_op_info = TBERegOp("OneHot") \ .async_flag(False) \ .binfile_name("one_hot.so") \ .compute_cost(10) \ - .kernel_name("one_hot") \ + .kernel_name("one_hot_d") \ .partial_flag(True) \ .attr("depth", "required", "int", "all") \ .attr("axis", "required", "int", "all") \ diff --git a/mindspore/ops/_op_impl/tbe/real_div.py b/mindspore/ops/_op_impl/tbe/real_div.py index 9c6d9e0b27..679273cf7b 100644 --- a/mindspore/ops/_op_impl/tbe/real_div.py +++ b/mindspore/ops/_op_impl/tbe/real_div.py @@ -21,7 +21,7 @@ realdiv_op_info = TBERegOp("RealDiv") \ .async_flag(False) \ .binfile_name("realdiv.so") \ .compute_cost(10) \ - .kernel_name("realdiv") \ + .kernel_name("real_div") \ .partial_flag(True) \ .input(0, "x", False, "required", "all") \ .input(1, "y", False, "required", "all") \ diff --git a/mindspore/ops/_op_impl/tbe/real_div_ds.py b/mindspore/ops/_op_impl/tbe/real_div_ds.py new file mode 100644 index 0000000000..1038807468 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/real_div_ds.py @@ -0,0 +1,39 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +"""RealDiv op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +realdiv_op_info = TBERegOp("RealDiv") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("realdiv.so") \ + .compute_cost(10) \ + .kernel_name("real_div") \ + .partial_flag(True) \ + .dynamic_shape(True) \ + .input(0, "x", False, "required", "all") \ + .input(1, "y", False, "required", "all") \ + .output(0, "z", False, "required", "all") \ + .op_pattern("broadcast") \ + .dtype_format(DataType.F16_None, DataType.F16_None, DataType.F16_None) \ + .dtype_format(DataType.F32_None, DataType.F32_None, DataType.F32_None) \ + .get_op_info() + + +@op_info_register(realdiv_op_info) +def _real_div_ds_tbe(): + """RealDiv TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/reduce_mean.py b/mindspore/ops/_op_impl/tbe/reduce_mean.py index b01fd3bebd..9cb9bd41d8 100644 --- a/mindspore/ops/_op_impl/tbe/reduce_mean.py +++ b/mindspore/ops/_op_impl/tbe/reduce_mean.py @@ -21,7 +21,7 @@ reduce_mean_op_info = TBERegOp("ReduceMean") \ .async_flag(False) \ .binfile_name("reduce_mean.so") \ .compute_cost(10) \ - .kernel_name("reduce_mean") \ + .kernel_name("reduce_mean_d") \ .partial_flag(True) \ .attr("axis", "optional", "listInt", "all") \ .attr("keep_dims", "optional", "bool", "all") \ diff --git a/mindspore/ops/_op_impl/tbe/relu_grad.py b/mindspore/ops/_op_impl/tbe/relu_grad.py index 040294f973..bf809ca41d 100644 --- a/mindspore/ops/_op_impl/tbe/relu_grad.py +++ b/mindspore/ops/_op_impl/tbe/relu_grad.py @@ -19,9 +19,9 @@ from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType relugrad_op_info = TBERegOp("ReluGrad") \ .fusion_type("ELEMWISE") \ .async_flag(False) \ - .binfile_name("relugrad.so") \ + .binfile_name("relu_grad.so") \ .compute_cost(10) \ - .kernel_name("relugrad") \ + .kernel_name("relu_grad") \ .partial_flag(True) \ .input(0, "gradients", False, "required", "all") \ .input(1, "features", False, "required", "all") \ diff --git a/mindspore/ops/_op_impl/tbe/resize_nearest_neighbor.py b/mindspore/ops/_op_impl/tbe/resize_nearest_neighbor.py index 4fa3107f59..38da10c9f2 100644 --- a/mindspore/ops/_op_impl/tbe/resize_nearest_neighbor.py +++ b/mindspore/ops/_op_impl/tbe/resize_nearest_neighbor.py @@ -21,7 +21,7 @@ resize_nearest_neighbor_op_info = TBERegOp("ResizeNearestNeighbor") \ .async_flag(False) \ .binfile_name("resize_nearest_neighbor_d.so") \ .compute_cost(10) \ - .kernel_name("resize_nearest_neighbor_d") \ + .kernel_name("resize_nearest_neighbor_v2_d") \ .partial_flag(True) \ .attr("size", "required", "listInt", "all") \ .attr("align_corners", "optional", "bool", "all") \ diff --git a/mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad.py b/mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad.py index d13150b5fb..bf046de668 100644 --- a/mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad.py +++ b/mindspore/ops/_op_impl/tbe/resize_nearest_neighbor_grad.py @@ -21,7 +21,7 @@ resize_nearest_neighbor_grad_op_info = TBERegOp("ResizeNearestNeighborGrad") \ .async_flag(False) \ .binfile_name("resize_nearest_neighbor_grad_d.so") \ .compute_cost(10) \ - .kernel_name("resize_nearest_neighbor_grad_d") \ + .kernel_name("resize_nearest_neighbor_v2_grad_d") \ .partial_flag(True) \ .attr("size", "required", "listInt", "all") \ .attr("align_corners", "optional", "bool", "all") \ diff --git a/mindspore/ops/_op_impl/tbe/scatter_add_ds.py 
b/mindspore/ops/_op_impl/tbe/scatter_add_ds.py new file mode 100644 index 0000000000..011eaf04af --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/scatter_add_ds.py @@ -0,0 +1,43 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""ScatterAdd op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +scatter_add_ds_op_info = TBERegOp("ScatterAdd") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("scatter_add.so") \ + .compute_cost(10) \ + .kernel_name("scatter_add") \ + .partial_flag(True) \ + .dynamic_shape(True) \ + .attr("use_locking", "optional", "bool", "all") \ + .input(0, "var", False, "required", "all") \ + .input(1, "indices", False, "required", "all") \ + .input(2, "updates", False, "required", "all") \ + .output(0, "var", False, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \ + .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \ + .get_op_info() + + +@op_info_register(scatter_add_ds_op_info) +def _scatter_add_ds_tbe(): + """ScatterAdd TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/scatter_update_ds.py b/mindspore/ops/_op_impl/tbe/scatter_update_ds.py new file mode 100644 index 0000000000..83df66ef95 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/scatter_update_ds.py @@ -0,0 +1,43 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================
+
+"""ScatterUpdate op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+scatter_update_op_info = TBERegOp("ScatterUpdate") \
+    .fusion_type("ELEMWISE") \
+    .async_flag(False) \
+    .binfile_name("scatter_update.so") \
+    .compute_cost(10) \
+    .kernel_name("scatter_update") \
+    .partial_flag(True) \
+    .dynamic_shape(True) \
+    .attr("use_locking", "optional", "bool", "all") \
+    .input(0, "var", False, "required", "all") \
+    .input(1, "indices", False, "required", "all") \
+    .input(2, "updates", False, "required", "all") \
+    .output(0, "var", False, "required", "all") \
+    .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \
+    .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \
+    .dtype_format(DataType.BOOL_Default, DataType.I32_Default, DataType.BOOL_Default, DataType.BOOL_Default) \
+    .get_op_info()
+
+
+@op_info_register(scatter_update_op_info)
+def _scatter_update_ds_tbe():
+    """ScatterUpdate TBE register"""
+    return
diff --git a/mindspore/ops/_op_impl/tbe/softmax.py b/mindspore/ops/_op_impl/tbe/softmax.py
index faefad87ec..6d19c7f6ef 100644
--- a/mindspore/ops/_op_impl/tbe/softmax.py
+++ b/mindspore/ops/_op_impl/tbe/softmax.py
@@ -21,7 +21,7 @@ softmax_op_info = TBERegOp("Softmax") \
     .async_flag(False) \
     .binfile_name("softmax.so") \
     .compute_cost(10) \
-    .kernel_name("softmax") \
+    .kernel_name("softmax_v2") \
     .partial_flag(True) \
     .attr("axis", "optional", "listInt", "all") \
     .input(0, "x", False, "required", "all") \
diff --git a/mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d.py b/mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d.py
index a61f6174b9..2a9a8f3175 100644
--- a/mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d.py
+++ b/mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d.py
@@ -19,9 +19,9 @@ from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
 sparse_apply_ftrl_d_op_info = TBERegOp("SparseApplyFtrl") \
     .fusion_type("OPAQUE") \
     .async_flag(False) \
-    .binfile_name("sparse_apply_ftrl.so") \
+    .binfile_name("sparse_apply_ftrl_d.so") \
     .compute_cost(10) \
-    .kernel_name("sparse_apply_ftrl") \
+    .kernel_name("sparse_apply_ftrl_d") \
     .partial_flag(True) \
     .attr("lr", "required", "float", "all") \
     .attr("l1", "required", "float", "all") \
diff --git a/mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d_ds.py b/mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d_ds.py
new file mode 100644
index 0000000000..4ac3ec8a9e
--- /dev/null
+++ b/mindspore/ops/_op_impl/tbe/sparse_apply_ftrl_d_ds.py
@@ -0,0 +1,52 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================ + +"""SparseApplyFtrl op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +sparse_apply_ftrl_d_op_info = TBERegOp("SparseApplyFtrl") \ + .fusion_type("DYNAMIC") \ + .async_flag(False) \ + .binfile_name("sparse_apply_ftrl.so") \ + .compute_cost(10) \ + .kernel_name("sparse_apply_ftrl_d") \ + .partial_flag(True) \ + .dynamic_shape(True) \ + .attr("lr", "required", "float", "all") \ + .attr("l1", "required", "float", "all") \ + .attr("l2", "required", "float", "all") \ + .attr("lr_power", "required", "float", "all") \ + .attr("use_locking", "optional", "bool", "true,false", "false") \ + .input(0, "var", False, "required", "all") \ + .input(1, "accum", False, "required", "all") \ + .input(2, "linear", False, "required", "all") \ + .input(3, "grad", False, "required", "all") \ + .input(4, "indices", False, "required", "all") \ + .output(0, "var", False, "required", "all") \ + .output(1, "accum", False, "required", "all") \ + .output(2, "linear", False, "required", "all") \ + .dtype_format(DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, + DataType.I32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW) \ + .dtype_format(DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, + DataType.I32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC, DataType.F32_NHWC) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.I32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + .get_op_info() + + +@op_info_register(sparse_apply_ftrl_d_op_info) +def _sparse_apply_ftrl_d_ds(): + """SparseApplyFtrl TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad.py b/mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad.py index 782be983fa..32ca2a6e5f 100644 --- a/mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad.py +++ b/mindspore/ops/_op_impl/tbe/sparse_apply_proximal_adagrad.py @@ -21,7 +21,7 @@ sparse_apply_proximal_adagrad_d_op_info = TBERegOp("SparseApplyProximalAdagrad") .async_flag(False) \ .binfile_name("sparse_apply_proximal_adagrad.so") \ .compute_cost(10) \ - .kernel_name("sparse_apply_proximal_adagrad") \ + .kernel_name("sparse_apply_proximal_adagrad_d") \ .partial_flag(True) \ .attr("use_locking", "optional", "bool", "true,false", "false") \ .input(0, "var", False, "required", "all") \ diff --git a/mindspore/ops/_op_impl/tbe/sqrt_ds.py b/mindspore/ops/_op_impl/tbe/sqrt_ds.py new file mode 100644 index 0000000000..e45b66e2eb --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/sqrt_ds.py @@ -0,0 +1,38 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +"""Sqrt op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +sqrt_op_info = TBERegOp("Sqrt") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("sqrt.so") \ + .compute_cost(10) \ + .kernel_name("sqrt") \ + .partial_flag(True) \ + .dynamic_shape(True) \ + .input(0, "x", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .op_pattern("formatAgnostic") \ + .dtype_format(DataType.F16_None, DataType.F16_None) \ + .dtype_format(DataType.F32_None, DataType.F32_None) \ + .get_op_info() + + +@op_info_register(sqrt_op_info) +def _sqrt_ds_tbe(): + """Sqrt TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/reduce_mean_d.py b/mindspore/ops/_op_impl/tbe/square_ds.py similarity index 69% rename from mindspore/ops/_op_impl/tbe/reduce_mean_d.py rename to mindspore/ops/_op_impl/tbe/square_ds.py index a0890816d2..d8d6c1e066 100644 --- a/mindspore/ops/_op_impl/tbe/reduce_mean_d.py +++ b/mindspore/ops/_op_impl/tbe/square_ds.py @@ -13,29 +13,27 @@ # limitations under the License. # ============================================================================ -"""ReduceMeanD op""" +"""Square op""" from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType -reduce_mean_d_op_info = TBERegOp("ReduceMeanD") \ +square_ds_op_info = TBERegOp("Square") \ .fusion_type("OPAQUE") \ .async_flag(False) \ - .binfile_name("reduce_mean_d.so") \ + .binfile_name("square.so") \ .compute_cost(10) \ - .kernel_name("reduce_mean_d") \ + .kernel_name("square") \ .partial_flag(True) \ - .attr("axis", "optional", "listInt", "all") \ - .attr("keep_dims", "optional", "bool", "all") \ .input(0, "x", False, "required", "all") \ .output(0, "y", False, "required", "all") \ - .op_pattern("reduce") \ - .dtype_format(DataType.I8_None, DataType.I8_None) \ - .dtype_format(DataType.U8_None, DataType.U8_None) \ + .op_pattern("formatAgnostic") \ + .dynamic_shape(True) \ + .dtype_format(DataType.I32_None, DataType.I32_None) \ .dtype_format(DataType.F16_None, DataType.F16_None) \ .dtype_format(DataType.F32_None, DataType.F32_None) \ .get_op_info() -@op_info_register(reduce_mean_d_op_info) -def _reduce_mean_d_tbe(): - """Conv2D TBE register""" +@op_info_register(square_ds_op_info) +def _square_ds_tbe(): + """Square TBE register""" return diff --git a/mindspore/ops/_op_impl/tbe/square_sum_all.py b/mindspore/ops/_op_impl/tbe/square_sum_all.py index e9d56e44b1..2f9185aafb 100644 --- a/mindspore/ops/_op_impl/tbe/square_sum_all.py +++ b/mindspore/ops/_op_impl/tbe/square_sum_all.py @@ -21,7 +21,7 @@ square_sum_all_op_info = TBERegOp("SquareSumAll") \ .async_flag(False) \ .binfile_name("square_sum_all.so") \ .compute_cost(10) \ - .kernel_name("square_sum") \ + .kernel_name("square_sum_all") \ .partial_flag(True) \ .input(0, "x1", False, "required", "all") \ .input(1, "x2", False, "required", "all") \ diff --git a/mindspore/ops/_op_impl/tbe/tensor_add_ds.py b/mindspore/ops/_op_impl/tbe/tensor_add_ds.py new file mode 100644 index 0000000000..3a39acaa7c --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/tensor_add_ds.py @@ -0,0 +1,43 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""TensorAdd op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +tensor_add_op_info = TBERegOp("TensorAdd") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("add.so") \ + .compute_cost(10) \ + .kernel_name("add") \ + .partial_flag(True) \ + .dynamic_shape(True) \ + .input(0, "x1", False, "required", "all") \ + .input(1, "x2", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .op_pattern("dynamicFormat") \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \ + .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ + .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ + .get_op_info() + + +@op_info_register(tensor_add_op_info) +def _tensor_add_ds_tbe(): + """Add TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/tensor_scatter_update.py b/mindspore/ops/_op_impl/tbe/tensor_scatter_update.py index 46d6b20357..5f46d38f3d 100644 --- a/mindspore/ops/_op_impl/tbe/tensor_scatter_update.py +++ b/mindspore/ops/_op_impl/tbe/tensor_scatter_update.py @@ -19,9 +19,9 @@ from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType tensor_scatter_update_op_info = TBERegOp("TensorScatterUpdate") \ .fusion_type("ELEMWISE") \ .async_flag(False) \ - .binfile_name("tensor_scatter_update.so") \ + .binfile_name("scatter_update.so") \ .compute_cost(10) \ - .kernel_name("tensor_scatter_update") \ + .kernel_name("scatter_update") \ .partial_flag(True) \ .input(0, "x", False, "required", "all") \ .input(1, "indices", False, "required", "all") \ diff --git a/mindspore/ops/_op_impl/tbe/top_k.py b/mindspore/ops/_op_impl/tbe/top_k.py index a97ecadae0..32e8f8844e 100644 --- a/mindspore/ops/_op_impl/tbe/top_k.py +++ b/mindspore/ops/_op_impl/tbe/top_k.py @@ -19,9 +19,9 @@ from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType top_k_op_info = TBERegOp("TopK") \ .fusion_type("OPAQUE") \ .async_flag(False) \ - .binfile_name("top_k_d.so") \ + .binfile_name("top_k.so") \ .compute_cost(10) \ - .kernel_name("top_k_d") \ + .kernel_name("top_k") \ .partial_flag(True) \ .attr("dim", "optional", "int", "all") \ .attr("k", "required", "int", "all") \ diff --git a/mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py b/mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py new file mode 100644 index 0000000000..acebde97f3 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/unsorted_segment_sum_ds.py @@ -0,0 +1,38 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""UnsortedSegmentSum op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +unsorted_segment_sum_ds_op_info = TBERegOp("UnsortedSegmentSum") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("unsorted_segment_sum.so") \ + .compute_cost(10) \ + .kernel_name("unsorted_segment_sum") \ + .partial_flag(True) \ + .dynamic_shape(True) \ + .input(0, "x", False, "required", "all") \ + .input(1, "segment_ids", False, "required", "all") \ + .input(2, "num_segments", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.F32_Default) \ + .get_op_info() + + +@op_info_register(unsorted_segment_sum_ds_op_info) +def _unsorted_segment_sum_ds_tbe(): + """UnsortedSegmentSumUnknown TBE register""" + return diff --git a/mindspore/ops/op_info_register.py b/mindspore/ops/op_info_register.py index 0397b3ecab..ced224b7c9 100644 --- a/mindspore/ops/op_info_register.py +++ b/mindspore/ops/op_info_register.py @@ -354,6 +354,7 @@ class TBERegOp(RegOp): self.partial_flag_ = False self.reshape_type_ = '' self.dynamic_format_ = False + self.dynamic_shape_ = False self.op_pattern_ = "" def async_flag(self, async_flag): @@ -433,6 +434,17 @@ class TBERegOp(RegOp): self.dynamic_format_ = dynamic_format return self + def dynamic_shape(self, dynamic_shape): + """ + Whether the operator supports dynamic shape. + + Args: + dynamic_shape (bool): Value of dynamic shape. Default: false. + """ + self._is_bool(dynamic_shape) + self.dynamic_shape_ = dynamic_shape + return self + def op_pattern(self, pattern=None): """ The behavior type of opeator, such as broadcast, reduce and so on. diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 7fe3449008..044ab3a7f1 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -431,7 +431,7 @@ class Reshape(PrimitiveWithInfer): return out -class Shape(PrimitiveWithInfer): +class Shape(Primitive): """ Returns the shape of input tensor. 
@@ -452,13 +452,6 @@ class Shape(PrimitiveWithInfer): def __init__(self): """Initialize Shape""" - def __infer__(self, x): - validator.check_subclass("input_x", x['dtype'], mstype.tensor, self.name) - out = {'shape': (), - 'dtype': mstype.tuple_, - 'value': tuple(x['shape'])} - return out - class DynamicShape(Primitive): """ @@ -478,8 +471,10 @@ class DynamicShape(Primitive): @prim_attr_register def __init__(self): - """Initialize Shape""" - + """init Shape""" + self.init_prim_io_names(inputs=['tensor'], outputs=['output']) + self.add_prim_attr('is_dynamic_shape', True) + self.add_prim_attr("dynamic_shape_depends", [0]) class Squeeze(PrimitiveWithInfer): """ @@ -643,6 +638,7 @@ class GatherV2(PrimitiveWithCheck): def __init__(self): """Initialize index_select""" self.init_prim_io_names(inputs=['params', 'indices', 'axis'], outputs=['output']) + self.add_prim_attr("dynamic_shape_depends", [2,]) def __check__(self, params, indices, axis): validator.check_subclass("params", params['dtype'], mstype.tensor, self.name) @@ -653,6 +649,17 @@ class GatherV2(PrimitiveWithCheck): rank = len(params_shp) validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name) + if axis_v < 0: + axis_v += rank + out_shape = params_shp[:axis_v] + indices['shape'] + params_shp[axis_v + 1:] + out = {'shape': out_shape, + 'dtype': params['dtype'], + 'value': None} + if 'min_shape' in indices and 'max_shape' in indices: + out['min_shape'] = params_shp[:axis_v] + indices['min_shape'] + params_shp[axis_v + 1:] + out['max_shape'] = params_shp[:axis_v] + indices['max_shape'] + params_shp[axis_v + 1:] + return out + class SparseGatherV2(GatherV2): """ @@ -1475,6 +1482,7 @@ class UnsortedSegmentSum(PrimitiveWithInfer): def __init__(self): """Initialize UnsortedSegmentSum""" self.init_prim_io_names(inputs=['x', 'segment_ids', 'num_segments'], outputs=['y']) + self.add_prim_attr("dynamic_shape_depends", [2,]) def __infer__(self, x, segment_ids, num_segments): x_type = x['dtype'] @@ -1494,11 +1502,24 @@ class UnsortedSegmentSum(PrimitiveWithInfer): for i, value in enumerate(segment_ids_shp): validator.check("ids[%d]" % i, value, 'input[%d]' % i, x_shp[i], Rel.EQ, self.name) num_segments_v = num_segments['value'] - validator.check_value_type('num_segments', num_segments_v, [int], self.name) - validator.check_positive_int(num_segments_v, "num_segments", self.name) - shp = [num_segments_v] + num_segments_type = num_segments['dtype'] + validator.check_subclass("num_segments", num_segments_type, [mstype.tensor, mstype.number], self.name) + if isinstance(num_segments_type, type(mstype.tensor)): + validator.check_tensor_type_same({"num_segments": num_segments_type}, [mstype.int32], self.name) + shp = [-1] + else: + validator.check_value_type('num_segments', num_segments_v, [int], self.name) + validator.check_positive_int(num_segments_v, "num_segments", self.name) + shp = [num_segments_v] + shp += x_shp[segment_ids_shp_len:] + if 'max_shape' in x: + output_max_shape = x['max_shape'] + else: + output_max_shape = x_shp out = {'shape': shp, + 'max_shape': output_max_shape, + 'min_shape': [1] * segment_ids_shp_len + x_shp[segment_ids_shp_len:], 'dtype': mstype.tensor_type(x_type.element_type()), 'value': None} return out diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 934f72d209..591839e458 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -5122,6 +5122,8 @@ class SparseApplyFtrl(PrimitiveWithCheck): self.l2 = 
validator.check_number_range("l2", l2, 0.0, float("inf"), Rel.INC_LEFT, self.name) self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name) self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) + self.init_prim_io_names(inputs=['var', 'accum', 'linear', 'grad', 'indices'], + outputs=['var', 'accum', 'linear']) def check_shape(self, var_shape, accum_shape, linear_shape, grad_shape, indices_shape): validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name) diff --git a/tests/st/ops/ascend/test_tbe_ops/test_unsorted_segment_sum.py b/tests/st/ops/ascend/test_tbe_ops/test_unsorted_segment_sum.py index e58ea23a1f..4da0aa6cf3 100644 --- a/tests/st/ops/ascend/test_tbe_ops/test_unsorted_segment_sum.py +++ b/tests/st/ops/ascend/test_tbe_ops/test_unsorted_segment_sum.py @@ -18,27 +18,24 @@ import mindspore.context as context import mindspore.nn as nn from mindspore import Tensor from mindspore.ops import operations as P -from mindspore.train.model import Model +import mindspore.common.dtype as mstype context.set_context(device_target="Ascend") class Net(nn.Cell): - def __init__(self, num_segments): + def __init__(self): super(Net, self).__init__() self.seg_sum = P.UnsortedSegmentSum() - self.num_segments = num_segments - def construct(self, x, segment_ids): - return self.seg_sum(x, segment_ids, self.num_segments) + def construct(self, x, segment_ids, num_segments): + return self.seg_sum(x, segment_ids, num_segments) def me_un_seg_sum(input_, indices, num_segments): context.set_context(mode=context.GRAPH_MODE) - net = Net(num_segments) - net.set_train() - model = Model(net) - out = model.predict(Tensor(input_), Tensor(indices)) + net = Net() + out = net(Tensor(input_), Tensor(indices), Tensor(num_segments, mstype.int32)) return out.asnumpy() @@ -51,5 +48,6 @@ def comapre_un_seg_sum(shape, indices, num_segments, dtype): def test_net(): + np.random.seed(0) indices = np.random.randint(0, 1280, 1280) comapre_un_seg_sum([1280, 768], indices, 8192, np.float32) diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/add_input_to_output_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/add_input_to_output_test.cc index 5d42ff7069..9a832ec21b 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/add_input_to_output_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/add_input_to_output_test.cc @@ -38,7 +38,7 @@ class MockOpFinder : public OpFinder { public: MockOpFinder() = default; ~MockOpFinder() override = default; - int GetOpRegisteredOutputNum(const std::string &op_name) override { return 2; } + int GetOpRegisteredOutputNum(const std::string &op_name, const CNodePtr &cnode) override { return 2; } }; TEST_F(TestHWAddInputToOutput, test_add_input_to_output) { diff --git a/tests/ut/cpp/stub/dynamic_shape/dynamic_shape_stub.cc b/tests/ut/cpp/stub/dynamic_shape/dynamic_shape_stub.cc new file mode 100644 index 0000000000..4c54ef16d4 --- /dev/null +++ b/tests/ut/cpp/stub/dynamic_shape/dynamic_shape_stub.cc @@ -0,0 +1,84 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "runtime/device/ascend/executor/hccl_dynamic_kernel.h"
+#include "runtime/device/ascend/executor/rts/memcpy_rts_dynamic_kernel.h"
+#include "runtime/device/ascend/executor/rts/profiling_rts_dynamic_kernel.h"
+#include "runtime/device/ascend/executor/ai_core_dynamic_kernel.h"
+#include "profiler/device/ascend/rt_callback_manager.h"
+#include "runtime/device/ascend/executor/executor_callback.h"
+#include "profiler/device/ascend/ascend_profiling.h"
+#include "runtime/device/ascend/executor/tiling/op_tiling_calculater.h"
+#include "backend/kernel_compiler/host/host_kernel_metadata.h"
+#include "backend/kernel_compiler/host/host_kernel_build.h"
+
+namespace mindspore {
+namespace device {
+namespace ascend {
+void HcclDynamicKernel::UpdateArgs() {}
+void HcclDynamicKernel::Execute() {}
+void HcclDynamicKernel::PostExecute() {}
+
+void MemcpyRtsDynamicKernel::Execute() {}
+
+void ProfilingRtsDynamicKernel::Execute() {}
+
+AiCoreDynamicKernel::~AiCoreDynamicKernel() {}
+void AiCoreDynamicKernel::Execute() {}
+void AiCoreDynamicKernel::UpdateArgs() {}
+void AiCoreDynamicKernel::Initialize() {}
+void AiCoreDynamicKernel::PostExecute() {}
+
+bool HcclExecutorManager::Initialize() { return true; }
+bool HcclExecutorManager::Finalize() { return true; }
+
+void ExecutorCallback::RegistCallback(const std::function<void()> &callback) {}
+void ExecutorCallback::Consume() {}
+
+void OpTilingCalculater::Init() {}
+void OpTilingCalculater::CalculateTiling(const NotNull<CNodePtr> &cnode, const NotNull<std::shared_ptr<nlohmann::json>> &compile_info_json,
+                                         const std::map<uint32_t, tensor::TensorPtr> &depend_tensor_map,
+                                         NotNull<optiling::OpRunInfo *> op_run_info) {}
+} // namespace ascend
+} // namespace device
+} // namespace mindspore
+
+namespace mindspore {
+namespace profiler {
+namespace ascend {
+CallbackManager::CallbackManager(rtStream_t stream) : stream_(stream) {}
+Status CallbackManager::Init() { return kSuccess; }
+Status CallbackManager::Destroy() { return kSuccess; }
+Status CallbackManager::RegisterCallback(rtCallback_t callback, const void *user_data) { return kSuccess; }
+Status CallbackManager::RegisterCallback(const std::function<void()> &callback) { return kSuccess; }
+
+AscendProfiler::AscendProfiler() : counter_(0) { Reset(); }
+
+void AscendProfiler::RecordEvent(EventType event_type, const char *fmt, ...) {}
+
+void AscendProfiler::Dump(std::ostream &output_stream) {}
+
+void AscendProfiler::Reset() {}
+} // namespace ascend
+} // namespace profiler
+} // namespace mindspore
+
+namespace mindspore {
+namespace kernel {
+void HostMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list) {}
+KernelModPtr HostOpBuild(const std::shared_ptr<AnfNode> &anf_node) { return nullptr; }
+} // namespace kernel
+} // namespace mindspore
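
For orientation only, the sketch below shows how an operator would opt into the dynamic-shape machinery added by this patch: a *_ds TBE registration carrying the new .dynamic_shape(True) flag, plus the dynamic_shape_depends attribute that marks which input values the output shape depends on. The op name "Abs" and its dtype/pattern choices are illustrative assumptions modeled on sqrt_ds.py above; they are not part of this patch.

"""Illustrative sketch (not part of this patch): a dynamic-shape TBE registration."""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

# Hypothetical "Abs" dynamic-shape registration, following the *_ds.py pattern above.
abs_ds_op_info = TBERegOp("Abs") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("abs.so") \
    .compute_cost(10) \
    .kernel_name("abs") \
    .partial_flag(True) \
    .dynamic_shape(True) \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .op_pattern("formatAgnostic") \
    .dtype_format(DataType.F16_None, DataType.F16_None) \
    .dtype_format(DataType.F32_None, DataType.F32_None) \
    .get_op_info()


@op_info_register(abs_ds_op_info)
def _abs_ds_tbe():
    """Abs TBE register (illustrative)"""
    return

# A primitive whose output shape depends on an input *value* (e.g. num_segments at
# input index 2 of UnsortedSegmentSum) additionally declares that dependency in its
# __init__, as GatherV2 and UnsortedSegmentSum do in this patch:
#     self.add_prim_attr("dynamic_shape_depends", [2,])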