From 18d9e39ddf2d045307681cabbfe6c1621fc9866e Mon Sep 17 00:00:00 2001 From: yanghaoran Date: Thu, 4 Jun 2020 11:51:32 +0800 Subject: [PATCH] synchronize with latest Ascend software suite 6 Jun 2020 --- CMakeLists.txt | 3 + inc/common/optimizer/graph_optimizer.h | 3 + inc/common/util/compress/compress.h | 36 + inc/common/util/error_manager/error_manager.h | 83 + inc/common/util/platform_info.h | 2 + inc/common/util/platform_info_def.h | 4 + inc/external/ge/ge_api.h | 4 +- inc/external/ge/ge_api_types.h | 77 +- inc/external/graph/graph.h | 4 + inc/external/graph/inference_context.h | 2 +- inc/external/graph/operator.h | 36 +- inc/external/graph/operator_reg.h | 62 +- inc/external/graph/tensor.h | 6 + inc/external/graph/types.h | 11 + inc/external/register/register.h | 31 +- inc/framework/common/debug/ge_log.h | 16 +- inc/framework/common/debug/log.h | 25 +- inc/framework/common/ge_inner_error_codes.h | 1 - inc/framework/common/ge_types.h | 6 +- inc/framework/common/gflags_util.h | 4 +- inc/framework/common/helper/model_helper.h | 7 +- inc/framework/common/helper/om_file_helper.h | 6 +- inc/framework/common/l2_cache_optimize.h | 4 +- inc/framework/common/op/attr_define.h | 810 -------- inc/framework/common/op/attr_value_util.h | 12 +- inc/framework/common/op/ge_op_utils.h | 2 + inc/framework/common/op/op_parser_util.h | 4 +- inc/framework/common/op_types.h | 4 +- inc/framework/common/scope_guard.h | 6 +- inc/framework/common/string_util.h | 4 +- inc/framework/common/types.h | 38 +- inc/framework/common/util.h | 138 +- inc/framework/ge_runtime/task_info.h | 0 inc/framework/generator/ge_generator.h | 2 + inc/framework/memory/memory_assigner.h | 2 +- inc/framework/omg/omg_inner_types.h | 30 +- inc/framework/omg/version.h | 4 +- inc/graph/attr_value_serializable.h | 18 +- inc/graph/compute_graph.h | 4 +- inc/graph/debug/ge_attr_define.h | 340 +++- inc/graph/detail/model_serialize_imp.h | 4 +- inc/graph/ge_attr_value.h | 45 +- inc/graph/ge_tensor.h | 6 + inc/graph/model.h 
| 4 - inc/graph/op_desc.h | 38 +- inc/graph/ref_relation.h | 79 + inc/graph/usr_types.h | 6 +- inc/graph/utils/attr_utils.h | 8 +- inc/graph/utils/graph_utils.h | 161 +- inc/graph/utils/node_utils.h | 35 +- inc/graph/utils/op_desc_utils.h | 41 +- src/common/graph/CMakeLists.txt | 3 +- src/common/graph/anchor.cc | 1 - src/common/graph/compute_graph.cc | 132 +- src/common/graph/debug/ge_op_types.h | 216 +-- src/common/graph/format_refiner.cc | 141 +- src/common/graph/ge_attr_define.cc | 249 ++- src/common/graph/ge_attr_value.cc | 31 +- src/common/graph/ge_tensor.cc | 40 + src/common/graph/graph.cc | 42 +- src/common/graph/model_serialize.cc | 80 +- src/common/graph/op_desc.cc | 95 +- src/common/graph/operator.cc | 167 +- src/common/graph/operator_factory_impl.cc | 1 - src/common/graph/ref_relation.cc | 422 ++++ src/common/graph/shape_refiner.cc | 21 +- src/common/graph/tensor.cc | 72 +- src/common/graph/utils/graph_utils.cc | 728 +++++-- src/common/graph/utils/node_utils.cc | 181 +- src/common/graph/utils/op_desc_utils.cc | 83 +- src/common/graph/utils/tensor_utils.cc | 1 + src/common/graph/utils/type_utils.cc | 5 +- src/ge/CMakeLists.txt | 256 +-- src/ge/client/CMakeLists.txt | 31 +- src/ge/client/ge_api.cc | 16 +- src/ge/common/CMakeLists.txt | 8 +- src/ge/common/auth/file_saver.cc | 7 +- src/ge/common/auth/file_saver.h | 52 +- src/ge/common/context/ctx.cc | 1 + src/ge/common/convert/pb2json.cc | 2 +- src/ge/common/debug/memory_dumper.cc | 4 +- .../format_transfers/datatype_transfer.cc | 15 +- .../format_transfers/datatype_transfer.h | 2 +- .../format_transfers/format_transfer.cc | 69 - .../format_transfer_c1hwncoc0_hwcn.cc | 21 +- .../format_transfer_c1hwncoc0_hwcn.h | 2 +- .../format_transfer_dhwcn_fracz3D.cc | 5 + .../format_transfer_dhwcn_fracz3D.h | 2 +- ...format_transfer_dhwnc_fracz3D_transpose.cc | 5 + .../format_transfer_dhwnc_fracz3D_transpose.h | 2 +- .../format_transfer_fractal_nz.cc | 10 + .../format_transfer_fractal_nz.h | 2 +- 
.../format_transfer_fractal_z.cc | 15 + .../format_transfer_fractal_z.h | 2 +- .../format_transfer_fractal_zz.cc | 10 + .../format_transfer_fractal_zz.h | 2 +- .../format_transfer_fracz_hwcn.cc | 6 + .../format_transfer_fracz_hwcn.h | 2 +- .../format_transfer_fracz_nchw.cc | 7 + .../format_transfer_fracz_nchw.h | 2 +- .../format_transfer_fracz_nhwc.cc | 6 + .../format_transfer_fracz_nhwc.h | 2 +- .../format_transfer_hwcn_c1hwncoc0.cc | 24 +- .../format_transfer_hwcn_c1hwncoc0.h | 2 +- .../format_transfer_nc1hwc0_nchw.cc | 8 +- .../format_transfer_nc1hwc0_nchw.h | 2 +- .../format_transfer_nc1hwc0_nhwc.cc | 8 +- .../format_transfer_nc1hwc0_nhwc.h | 2 +- .../format_transfer_nchw_fz_c04.cc | 314 +++ .../format_transfer_nchw_fz_c04.h | 35 + .../format_transfer_nchw_nc1hwc0.cc | 54 +- .../format_transfer_nchw_nc1hwc0.h | 2 +- .../format_transfer_nhwc_nc1hwc0.cc | 8 +- .../format_transfer_nhwc_nc1hwc0.h | 2 +- .../format_transfer_transpose.cc | 10 +- .../format_transfer_transpose.h | 2 +- src/ge/common/formats/formats.cc | 12 +- src/ge/common/formats/formats.h | 6 +- .../formats/utils/formats_trans_utils.cc | 6 +- .../formats/utils/formats_trans_utils.h | 3 + src/ge/common/fp16_t.h | 2 +- src/ge/common/ge/plugin_manager.cc | 29 +- src/ge/common/helper/model_cache_helper.cc | 1708 +++++++++++++++++ src/ge/common/helper/model_cache_helper.h | 123 ++ src/ge/common/helper/model_helper.cc | 11 +- src/ge/common/helper/om_file_helper.cc | 6 +- src/ge/common/math/math_util.h | 1 - src/ge/common/math_util.h | 4 +- src/ge/common/model_parser/base.cc | 8 +- src/ge/common/model_saver.cc | 2 +- src/ge/common/op/attr_define.cc | 814 -------- src/ge/common/op/attr_value_util.cc | 4 +- src/ge/common/op/ge_op_utils.cc | 9 +- src/ge/common/profiling/profiling_manager.cc | 219 ++- src/ge/common/profiling/profiling_manager.h | 68 +- src/ge/common/properties_manager.cc | 22 +- src/ge/common/properties_manager.h | 3 + src/ge/common/thread_pool.cc | 5 +- src/ge/common/types.cc | 389 ++-- 
src/ge/common/util.cc | 122 +- src/ge/executor/CMakeLists.txt | 7 +- src/ge/executor/ge_executor.cc | 32 +- src/ge/ge_local_engine/CMakeLists.txt | 6 +- .../ge_local_engine/engine/host_cpu_engine.cc | 2 +- .../ge_local_engine/engine/host_cpu_engine.h | 2 +- .../ge_local_ops_kernel_info.cc | 2 - .../ops_kernel_store/op/ge_deleted_op.cc | 2 +- .../ops_kernel_store/op/no_op.cc | 4 +- src/ge/ge_runtime/runtime_model.cc | 5 +- src/ge/generator/ge_generator.cc | 67 +- src/ge/generator/generator_api.cc | 2 +- src/ge/graph/build/graph_builder.cc | 30 +- src/ge/graph/build/graph_builder.h | 2 +- .../graph/build/logical_stream_allocator.cc | 368 ++-- src/ge/graph/build/logical_stream_allocator.h | 43 +- .../build/memory/binary_block_mem_assigner.cc | 6 +- .../graph/build/memory/block_mem_assigner.cc | 652 +++++-- .../graph/build/memory/block_mem_assigner.h | 99 +- .../graph/build/memory/graph_mem_assigner.cc | 912 +++++---- .../graph/build/memory/graph_mem_assigner.h | 33 +- .../graph/build/memory/hybrid_mem_assigner.cc | 6 +- .../graph/build/memory/hybrid_mem_assigner.h | 7 + src/ge/graph/build/memory/memory_assigner.cc | 25 +- .../graph/build/memory/var_mem_assign_util.cc | 46 +- src/ge/graph/build/model_builder.cc | 213 +- src/ge/graph/build/model_builder.h | 9 +- src/ge/graph/build/run_context.cc | 2 - src/ge/graph/build/stream_allocator.cc | 578 +++++- src/ge/graph/build/stream_allocator.h | 32 +- src/ge/graph/build/stream_graph_optimizer.cc | 131 +- src/ge/graph/build/stream_graph_optimizer.h | 4 +- src/ge/graph/build/task_generator.cc | 478 +++-- src/ge/graph/build/task_generator.h | 52 +- src/ge/graph/common/ge_call_wrapper.h | 38 + src/ge/graph/common/omg_util.cc | 11 +- src/ge/graph/common/transop_util.cc | 4 +- src/ge/graph/execute/graph_execute.cc | 28 +- src/ge/graph/execute/graph_execute.h | 7 +- src/ge/graph/label/case_label_maker.cc | 37 +- src/ge/graph/label/case_label_maker.h | 120 +- src/ge/graph/label/if_label_maker.cc | 55 +- 
src/ge/graph/label/if_label_maker.h | 90 +- src/ge/graph/label/label_maker.cc | 210 +- src/ge/graph/label/label_maker.h | 11 +- .../label/partitioned_call_label_maker.cc | 3 - src/ge/graph/label/while_label_maker.cc | 34 +- src/ge/graph/label/while_label_maker.h | 90 +- src/ge/graph/load/graph_loader.cc | 139 +- src/ge/graph/load/graph_loader.h | 11 +- .../new_model_manager/cpu_queue_schedule.cc | 99 +- .../new_model_manager/cpu_queue_schedule.h | 21 +- .../load/new_model_manager/data_dumper.cc | 315 ++- .../load/new_model_manager/data_dumper.h | 28 +- .../load/new_model_manager/davinci_model.cc | 1512 +++++++-------- .../load/new_model_manager/davinci_model.h | 322 ++-- .../new_model_manager/davinci_model_parser.cc | 4 +- .../load/new_model_manager/model_manager.cc | 199 +- .../load/new_model_manager/model_manager.h | 20 +- .../load/new_model_manager/model_output.cc | 41 - .../load/new_model_manager/model_utils.cc | 131 +- .../load/new_model_manager/model_utils.h | 22 - .../task_info/end_graph_task_info.cc | 9 +- .../task_info/end_graph_task_info.h | 5 +- .../task_info/hccl_task_info.cc | 83 +- .../task_info/hccl_task_info.h | 7 + .../task_info/kernel_ex_task_info.cc | 48 +- .../task_info/kernel_ex_task_info.h | 1 + .../task_info/kernel_task_info.cc | 300 +-- .../task_info/kernel_task_info.h | 16 +- .../task_info/label_goto_ex_task_info.cc | 70 + ..._task_info.h => label_goto_ex_task_info.h} | 12 +- .../task_info/label_goto_task_info.cc | 55 - .../task_info/label_set_task_info.cc | 35 +- .../label_switch_by_index_task_info.cc | 128 ++ .../label_switch_by_index_task_info.h | 42 + .../task_info/memcpy_addr_async_task_info.cc | 151 ++ .../task_info/memcpy_addr_async_task_info.h | 55 + .../task_info/memcpy_async_task_info.cc | 4 +- .../task_info/stream_active_task_info.cc | 2 +- .../task_info/stream_switch_task_info.cc | 12 +- .../task_info/stream_switchn_task_info.cc | 22 +- .../task_info/super_kernel/super_kernel.cc | 12 +- .../task_info/super_kernel/super_kernel.h 
| 15 +- .../super_kernel/super_kernel_factory.cc | 65 +- .../super_kernel/super_kernel_factory.h | 10 +- .../task_info/task_info_factory.h | 3 +- .../load/new_model_manager/zero_copy_task.cc | 179 ++ .../load/new_model_manager/zero_copy_task.h | 100 + src/ge/graph/load/output/output.h | 3 +- src/ge/graph/manager/graph_manager.cc | 780 ++++++-- src/ge/graph/manager/graph_manager.h | 47 +- src/ge/graph/manager/graph_manager_utils.cc | 20 +- src/ge/graph/manager/graph_manager_utils.h | 11 +- src/ge/graph/manager/graph_mem_allocator.cc | 9 +- src/ge/graph/manager/graph_mem_allocator.h | 6 +- src/ge/graph/manager/graph_var_manager.cc | 77 +- src/ge/graph/manager/graph_var_manager.h | 15 + src/ge/graph/manager/util/debug.cc | 2 +- src/ge/graph/manager/util/hcom_util.cc | 8 - .../manager/util/variable_accelerate_ctrl.cc | 2 +- src/ge/graph/optimize/common/params.h | 4 - src/ge/graph/optimize/graph_optimize.cc | 51 +- src/ge/graph/optimize/graph_optimize.h | 3 + src/ge/graph/optimize/summary_optimize.cc | 7 +- .../partition/dynamic_shape_partition.cc | 789 ++++++++ .../graph/partition/dynamic_shape_partition.h | 158 ++ src/ge/graph/partition/graph_partition.cc | 96 +- src/ge/graph/passes/addn_pass.cc | 2 +- .../passes/aicpu_constant_folding_pass.cc | 8 +- .../passes/aicpu_constant_folding_pass.h | 1 + src/ge/graph/passes/assert_pass.cc | 6 +- src/ge/graph/passes/atomic_addr_clean_pass.cc | 60 +- src/ge/graph/passes/base_pass.cc | 2 +- src/ge/graph/passes/cast_remove_pass.cc | 2 - src/ge/graph/passes/cast_translate_pass.cc | 6 +- .../common_subexpression_elimination_pass.cc | 16 +- src/ge/graph/passes/compile_nodes_pass.cc | 2 +- src/ge/graph/passes/compile_nodes_pass.h | 3 + src/ge/graph/passes/cond_pass.cc | 344 ++++ src/ge/graph/passes/cond_pass.h | 116 ++ .../graph/passes/constant_fuse_same_pass.cc | 3 - src/ge/graph/passes/control_op_attr_pass.cc | 256 --- src/ge/graph/passes/control_op_attr_pass.h | 47 - src/ge/graph/passes/control_trigger_pass.cc | 19 - 
src/ge/graph/passes/control_trigger_pass.h | 2 +- src/ge/graph/passes/dimension_adjust_pass.cc | 12 + src/ge/graph/passes/dropout_pass.cc | 2 +- src/ge/graph/passes/end_graph_pass.cc | 75 - src/ge/graph/passes/enter_pass.cc | 5 - src/ge/graph/passes/flow_ctrl_pass.cc | 34 +- .../graph/passes/folding_kernel/add_kernel.cc | 1 - .../graph/passes/folding_kernel/add_kernel.h | 2 +- .../folding_kernel/broadcast_args_kernel.cc | 2 - .../broadcast_gradient_args_kernel.cc | 2 - .../passes/folding_kernel/cast_kernel.cc | 14 +- .../folding_kernel/concat_offset_kernel.cc | 8 +- .../passes/folding_kernel/concat_v2_kernel.cc | 2 - .../folding_kernel/dynamic_stitch_kernel.cc | 7 +- .../passes/folding_kernel/empty_kernel.cc | 2 - .../folding_kernel/expanddims_kernel.cc | 2 - .../passes/folding_kernel/fill_kernel.cc | 4 - .../passes/folding_kernel/floordiv_kernel.cc | 2 - .../passes/folding_kernel/floordiv_kernel.h | 2 +- .../passes/folding_kernel/floormod_kernel.cc | 2 - .../passes/folding_kernel/gather_v2_kernel.cc | 7 +- .../passes/folding_kernel/greater_kernel.cc | 2 - .../passes/folding_kernel/kernel_utils.cc | 18 +- .../passes/folding_kernel/kernel_utils.h | 5 +- .../passes/folding_kernel/maximum_kernel.cc | 3 - .../graph/passes/folding_kernel/mul_kernel.cc | 2 - .../passes/folding_kernel/pack_kernel.cc | 30 +- .../passes/folding_kernel/permute_kernel.cc | 8 - .../passes/folding_kernel/range_kernel.cc | 2 - .../passes/folding_kernel/rank_kernel.cc | 3 - .../folding_kernel/reduce_prod_kernel.cc | 38 +- .../folding_kernel/reduce_prod_kernel.h | 2 +- .../passes/folding_kernel/reformat_kernel.cc | 2 - .../passes/folding_kernel/reshape_kernel.cc | 2 - .../passes/folding_kernel/rsqrt_kernel.cc | 8 +- .../passes/folding_kernel/shape_kernel.cc | 2 - .../passes/folding_kernel/shape_n_kernel.cc | 2 - .../passes/folding_kernel/size_kernel.cc | 3 +- .../passes/folding_kernel/slice_d_kernel.cc | 3 - .../passes/folding_kernel/slice_kernel.cc | 2 - 
.../passes/folding_kernel/squeeze_kernel.cc | 2 - .../folding_kernel/ssd_prior_box_kernel.cc | 15 +- .../folding_kernel/strided_slice_kernel.cc | 9 +- .../graph/passes/folding_kernel/sub_kernel.cc | 1 - .../passes/folding_kernel/transdata_kernel.cc | 10 +- .../passes/folding_kernel/transpose_kernel.cc | 161 ++ .../passes/folding_kernel/transpose_kernel.h | 34 + .../passes/folding_kernel/unpack_kernel.cc | 13 +- src/ge/graph/passes/folding_pass.cc | 16 +- src/ge/graph/passes/for_pass.cc | 732 +++++++ src/ge/graph/passes/for_pass.h | 192 ++ .../graph/passes/get_original_format_pass.cc | 16 +- src/ge/graph/passes/guarantee_const_pass.cc | 2 - src/ge/graph/passes/hccl_memcpy_pass.cc | 13 +- src/ge/graph/passes/identity_pass.cc | 7 +- .../graph/passes/isolated_op_remove_pass.cc | 3 - src/ge/graph/passes/iterator_op_pass.cc | 4 +- .../graph/passes/link_gen_mask_nodes_pass.cc | 10 +- .../graph/passes/link_gen_mask_nodes_pass.h | 4 + src/ge/graph/passes/merge_pass.cc | 4 +- src/ge/graph/passes/multi_batch_pass.cc | 16 +- src/ge/graph/passes/multi_batch_pass.h | 2 +- src/ge/graph/passes/net_output_pass.cc | 77 +- src/ge/graph/passes/net_output_pass.h | 16 +- src/ge/graph/passes/next_iteration_pass.cc | 9 - .../passes/no_use_reshape_remove_pass.cc | 11 +- .../passes/parallel_concat_start_op_pass.cc | 76 + .../parallel_concat_start_op_pass.h} | 20 +- src/ge/graph/passes/pass_manager.cc | 2 - src/ge/graph/passes/pass_utils.cc | 12 +- src/ge/graph/passes/pass_utils.h | 1 + src/ge/graph/passes/permute_pass.cc | 14 +- .../passes/placeholder_with_default_pass.cc | 2 - src/ge/graph/passes/prevent_gradient_pass.cc | 2 - src/ge/graph/passes/print_op_pass.h | 2 +- src/ge/graph/passes/prune_pass.cc | 4 - .../graph/passes/replace_transshape_pass.cc | 140 ++ ...graph_pass.h => replace_transshape_pass.h} | 23 +- .../passes/replace_with_empty_const_pass.cc | 156 ++ .../passes/replace_with_empty_const_pass.h | 34 + src/ge/graph/passes/reshape_remove_pass.cc | 22 +- 
.../same_transdata_breadth_fusion_pass.cc | 9 +- .../passes/shape_operate_op_remove_pass.cc | 3 +- src/ge/graph/passes/snapshot_pass.cc | 2 - src/ge/graph/passes/stop_gradient_pass.cc | 2 - src/ge/graph/passes/subgraph_pass.cc | 214 +++ src/ge/graph/passes/subgraph_pass.h | 91 + .../graph/passes/switch_logic_remove_pass.cc | 2 +- src/ge/graph/passes/switch_op_pass.cc | 65 +- src/ge/graph/passes/switch_op_pass.h | 4 +- src/ge/graph/passes/switch_pass.cc | 4 - .../passes/transop_breadth_fusion_pass.cc | 11 +- .../graph/passes/transop_depth_fusion_pass.cc | 11 +- .../transop_nearby_allreduce_fusion_pass.cc | 4 +- .../transop_symmetry_elimination_pass.cc | 169 ++ .../transop_symmetry_elimination_pass.h | 74 + .../transop_without_reshape_fusion_pass.cc | 23 +- .../graph/passes/transpose_transdata_pass.cc | 16 +- src/ge/graph/passes/unused_const_pass.cc | 4 +- src/ge/graph/passes/unused_op_remove_pass.cc | 6 - .../passes/var_is_initialized_op_pass.cc | 6 +- src/ge/graph/passes/variable_format_pass.cc | 6 +- src/ge/graph/passes/variable_op_pass.cc | 25 +- .../graph/passes/variable_prepare_op_pass.cc | 151 +- .../graph/passes/variable_prepare_op_pass.h | 8 +- .../passes/variable_ref_delete_op_pass.cc | 46 +- src/ge/graph/preprocess/graph_preprocess.cc | 1208 ++++++++++-- src/ge/graph/preprocess/graph_preprocess.h | 27 +- .../preprocess/insert_op/base_insert_op.h | 4 +- .../graph/preprocess/insert_op/ge_aipp_op.cc | 33 +- .../graph/preprocess/insert_op/ge_aipp_op.h | 2 +- .../insert_op/util_insert_aipp_op.cc | 43 +- .../preprocess/multi_batch_copy_graph.cc | 28 +- src/ge/inc/graph_pass.h | 10 +- src/ge/inc/kernel.h | 18 +- src/ge/inc/kernel_factory.h | 2 +- src/ge/inc/node_pass.h | 66 - src/ge/init/gelib.cc | 72 +- src/ge/init/gelib.h | 11 +- src/ge/ir_build/atc_ir_common.cc | 254 +++ src/ge/ir_build/atc_ir_common.h | 43 + src/ge/ir_build/ge_ir_build.cc | 50 +- src/ge/omm/csa_interact.cc | 2 - .../opskernel_manager/ops_kernel_manager.cc | 45 +- 
src/ge/opskernel_manager/ops_kernel_manager.h | 4 + .../optimizer_priority.pbtxt | 1 + src/ge/plugin/engine/CMakeLists.txt | 2 +- src/ge/session/inner_session.cc | 14 +- src/ge/session/inner_session.h | 3 +- src/ge/session/session_manager.cc | 17 +- src/ge/session/session_manager.h | 5 +- src/ge/single_op/single_op.cc | 24 +- src/ge/single_op/single_op.h | 1 + src/ge/single_op/single_op_manager.cc | 32 +- src/ge/single_op/single_op_model.cc | 60 +- src/ge/single_op/single_op_model.h | 7 +- src/ge/single_op/stream_resource.cc | 13 +- src/ge/single_op/stream_resource.h | 7 +- src/ge/single_op/task/aicpu_task_builder.cc | 135 ++ src/ge/single_op/task/aicpu_task_builder.h | 44 + src/ge/single_op/task/op_task.cc | 51 +- src/ge/single_op/task/op_task.h | 31 +- src/ge/single_op/task/tbe_task_builder.cc | 25 +- src/ge/single_op/task/tbe_task_builder.h | 7 +- src/proto/ge_ir.proto | 5 +- src/proto/op_mapping_info.proto | 10 + src/proto/task.proto | 24 +- .../fwkacllib/inc/cce/fwk_adpt_struct.h | 18 + third_party/fwkacllib/inc/hccl/base.h | 6 + third_party/fwkacllib/inc/hccl/hcom.h | 2 +- third_party/fwkacllib/inc/ops/aipp.h | 2 +- third_party/fwkacllib/inc/ops/all_ops.h | 4 +- third_party/fwkacllib/inc/ops/array_ops.h | 95 +- third_party/fwkacllib/inc/ops/condtake_ops.h | 55 + third_party/fwkacllib/inc/ops/data_flow_ops.h | 156 +- .../inc/ops/elewise_calculation_ops.h | 116 +- .../inc/ops/fsrdetectionoutput_ops.h | 67 - .../fwkacllib/inc/ops/functional_ops.h | 98 + third_party/fwkacllib/inc/ops/hcom_ops.h | 5 +- third_party/fwkacllib/inc/ops/image_ops.h | 5 +- third_party/fwkacllib/inc/ops/linalg_ops.h | 2 +- third_party/fwkacllib/inc/ops/math_ops.h | 145 +- .../inc/ops/matrix_calculation_ops.h | 195 +- .../fwkacllib/inc/ops/nn_batch_norm_ops.h | 2 +- .../fwkacllib/inc/ops/nn_calculation_ops.h | 102 +- third_party/fwkacllib/inc/ops/nn_detect_ops.h | 410 +++- third_party/fwkacllib/inc/ops/nn_norm_ops.h | 96 +- third_party/fwkacllib/inc/ops/nn_ops.h | 199 +- 
.../fwkacllib/inc/ops/nn_pooling_ops.h | 211 +- .../fwkacllib/inc/ops/nn_training_ops.h | 1198 +++++++++++- .../fwkacllib/inc/ops/nonlinear_fuc_ops.h | 55 +- .../fwkacllib/inc/ops/npu_loss_scale_ops.h | 2 +- third_party/fwkacllib/inc/ops/outfeed_ops.h | 30 +- third_party/fwkacllib/inc/ops/pad_ops.h | 2 +- third_party/fwkacllib/inc/ops/power_ops.h | 49 - third_party/fwkacllib/inc/ops/quantize_ops.h | 52 +- .../fwkacllib/inc/ops/ragged_array_ops.h | 6 +- .../fwkacllib/inc/ops/ragged_conversion_ops.h | 38 + .../fwkacllib/inc/ops/ragged_math_ops.h | 8 +- third_party/fwkacllib/inc/ops/reduce_ops.h | 51 +- third_party/fwkacllib/inc/ops/rnn.h | 10 +- .../fwkacllib/inc/ops/roipooling_ops.h | 78 - third_party/fwkacllib/inc/ops/rpn_ops.h | 2 +- .../inc/ops/rpn_proposal_post_processing.h | 39 + .../fwkacllib/inc/ops/score_filter_pre_sort.h | 36 + third_party/fwkacllib/inc/ops/sdca_ops.h | 2 +- third_party/fwkacllib/inc/ops/selection_ops.h | 466 ++--- third_party/fwkacllib/inc/ops/sparse_ops.h | 2 +- third_party/fwkacllib/inc/ops/spectral_ops.h | 46 + .../fwkacllib/inc/ops/split_combination_ops.h | 2 +- third_party/fwkacllib/inc/ops/state_ops.h | 10 + .../fwkacllib/inc/ops/stateful_random_ops.h | 18 +- .../fwkacllib/inc/ops/transformation_ops.h | 60 +- .../inc/register/op_kernel_registry.h | 3 +- .../fwkacllib/inc/register/op_registry.h | 3 + third_party/fwkacllib/inc/register/register.h | 53 + .../inc/register/register_format_transfer.h | 22 +- third_party/fwkacllib/inc/runtime/base.h | 44 + third_party/fwkacllib/inc/runtime/config.h | 2 + third_party/fwkacllib/inc/runtime/dev.h | 9 +- third_party/fwkacllib/inc/runtime/kernel.h | 2 +- third_party/fwkacllib/inc/runtime/mem.h | 11 +- third_party/fwkacllib/inc/runtime/rt_model.h | 295 +-- third_party/fwkacllib/inc/runtime/stream.h | 10 +- third_party/fwkacllib/inc/tdt/data_common.h | 22 + third_party/fwkacllib/inc/tdt/status.h | 2 + .../fwkacllib/inc/tdt/tdt_host_interface.h | 18 + third_party/fwkacllib/inc/tdt/tsd_client.h | 82 
+ third_party/fwkacllib/inc/toolchain/slog.h | 31 +- 473 files changed, 22416 insertions(+), 9544 deletions(-) create mode 100644 inc/common/util/compress/compress.h create mode 100644 inc/common/util/error_manager/error_manager.h delete mode 100644 inc/framework/common/op/attr_define.h mode change 100644 => 100755 inc/framework/ge_runtime/task_info.h create mode 100644 inc/graph/ref_relation.h create mode 100644 src/common/graph/ref_relation.cc delete mode 100644 src/ge/common/formats/format_transfers/format_transfer.cc create mode 100644 src/ge/common/formats/format_transfers/format_transfer_nchw_fz_c04.cc create mode 100644 src/ge/common/formats/format_transfers/format_transfer_nchw_fz_c04.h create mode 100644 src/ge/common/helper/model_cache_helper.cc create mode 100644 src/ge/common/helper/model_cache_helper.h delete mode 100644 src/ge/common/op/attr_define.cc create mode 100644 src/ge/graph/common/ge_call_wrapper.h delete mode 100644 src/ge/graph/load/new_model_manager/model_output.cc create mode 100644 src/ge/graph/load/new_model_manager/task_info/label_goto_ex_task_info.cc rename src/ge/graph/load/new_model_manager/task_info/{label_goto_task_info.h => label_goto_ex_task_info.h} (75%) delete mode 100644 src/ge/graph/load/new_model_manager/task_info/label_goto_task_info.cc create mode 100644 src/ge/graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc create mode 100644 src/ge/graph/load/new_model_manager/task_info/label_switch_by_index_task_info.h create mode 100644 src/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc create mode 100644 src/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.h create mode 100644 src/ge/graph/load/new_model_manager/zero_copy_task.cc create mode 100644 src/ge/graph/load/new_model_manager/zero_copy_task.h create mode 100644 src/ge/graph/partition/dynamic_shape_partition.cc create mode 100644 src/ge/graph/partition/dynamic_shape_partition.h create mode 100644 
src/ge/graph/passes/cond_pass.cc create mode 100644 src/ge/graph/passes/cond_pass.h delete mode 100644 src/ge/graph/passes/control_op_attr_pass.cc delete mode 100644 src/ge/graph/passes/control_op_attr_pass.h delete mode 100644 src/ge/graph/passes/end_graph_pass.cc create mode 100644 src/ge/graph/passes/folding_kernel/transpose_kernel.cc create mode 100644 src/ge/graph/passes/folding_kernel/transpose_kernel.h create mode 100644 src/ge/graph/passes/for_pass.cc create mode 100644 src/ge/graph/passes/for_pass.h create mode 100644 src/ge/graph/passes/parallel_concat_start_op_pass.cc rename src/ge/graph/{load/new_model_manager/model_output.h => passes/parallel_concat_start_op_pass.h} (57%) create mode 100644 src/ge/graph/passes/replace_transshape_pass.cc rename src/ge/graph/passes/{end_graph_pass.h => replace_transshape_pass.h} (59%) create mode 100644 src/ge/graph/passes/replace_with_empty_const_pass.cc create mode 100644 src/ge/graph/passes/replace_with_empty_const_pass.h create mode 100644 src/ge/graph/passes/subgraph_pass.cc create mode 100644 src/ge/graph/passes/subgraph_pass.h create mode 100644 src/ge/graph/passes/transop_symmetry_elimination_pass.cc create mode 100644 src/ge/graph/passes/transop_symmetry_elimination_pass.h delete mode 100644 src/ge/inc/node_pass.h create mode 100644 src/ge/ir_build/atc_ir_common.cc create mode 100644 src/ge/ir_build/atc_ir_common.h create mode 100644 src/ge/opskernel_manager/optimizer_priority.pbtxt create mode 100644 src/ge/single_op/task/aicpu_task_builder.cc create mode 100644 src/ge/single_op/task/aicpu_task_builder.h create mode 100644 third_party/fwkacllib/inc/ops/condtake_ops.h delete mode 100644 third_party/fwkacllib/inc/ops/fsrdetectionoutput_ops.h delete mode 100644 third_party/fwkacllib/inc/ops/power_ops.h delete mode 100644 third_party/fwkacllib/inc/ops/roipooling_ops.h create mode 100644 third_party/fwkacllib/inc/ops/rpn_proposal_post_processing.h create mode 100644 
third_party/fwkacllib/inc/ops/score_filter_pre_sort.h create mode 100644 third_party/fwkacllib/inc/ops/spectral_ops.h create mode 100644 third_party/fwkacllib/inc/register/register.h rename src/ge/common/formats/format_transfers/format_transfer.h => third_party/fwkacllib/inc/register/register_format_transfer.h (83%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 373edcf2..cd8192f8 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -16,6 +16,7 @@ cmake_minimum_required(VERSION 3.14) project (GraphEngine[CXX]) set(CMAKE_CXX_STANDARD 14) +add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0) set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}) set(GE_SOURCE_DIR ${CMAKE_CURRENT_LIST_DIR}) @@ -71,6 +72,7 @@ elseif(DEFINED ENV{D_LINK_PATH}) find_library(register libregister.so ${GE_LIB_PATH}) find_library(hccl libhccl.so ${GE_LIB_PATH}) find_library(resource libresource.so ${GE_LIB_PATH}) + find_library(error_manager liberror_manager.so ${GE_LIB_PATH}) else() # Ascend mode if(DEFINED ENV{ASCEND_CUSTOM_PATH}) @@ -88,6 +90,7 @@ else() find_library(runtime libruntime.so ${ASCEND_RUNTIME_DIR}) find_library(register libregister.so ${ASCEND_RUNTIME_DIR}) find_library(resource libresource.so ${ASCEND_RUNTIME_DIR}) + find_library(error_manager liberror_manager.so ${ASCEND_RUNTIME_DIR}) endif() # add compile flags diff --git a/inc/common/optimizer/graph_optimizer.h b/inc/common/optimizer/graph_optimizer.h index bce3cb18..5897842f 100644 --- a/inc/common/optimizer/graph_optimizer.h +++ b/inc/common/optimizer/graph_optimizer.h @@ -44,6 +44,9 @@ class GraphOptimizer { // optimize original graph, using in graph preparation stage virtual Status OptimizeOriginalGraph(ComputeGraph &graph) = 0; + // optimize original graph, using for conversion operator insert in graph preparation stage + virtual Status OptimizeOriginalGraphJudgeInsert(ComputeGraph &graph) { return SUCCESS; } + // optimize fused graph virtual Status OptimizeFusedGraph(ComputeGraph &graph) = 0; diff --git 
a/inc/common/util/compress/compress.h b/inc/common/util/compress/compress.h new file mode 100644 index 00000000..6908fb75 --- /dev/null +++ b/inc/common/util/compress/compress.h @@ -0,0 +1,36 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef COMPRESS_H +#define COMPRESS_H + +#include + +enum CmpStatus { RET_SUCCESS = 0, RET_ERROR = -1 }; + +struct CompressConfig { + size_t inputSize; // length of data to compress + size_t engineNum; // how many decompress engines + size_t maxRatio; // how much size of a basic compression block, only 64 supported now (8x: 64 4x: 32) + size_t channel; // channels of L2 or DDR. For load balance + size_t fractalSize; // size of compressing block + bool isTight; // whether compose compressed data tightly +}; + +CmpStatus CompressWeights(char* input, const CompressConfig& compressConfig, char* indexs, char* output, + size_t& compressedLength); + +#endif // COMPRESS_H diff --git a/inc/common/util/error_manager/error_manager.h b/inc/common/util/error_manager/error_manager.h new file mode 100644 index 00000000..76d5ce33 --- /dev/null +++ b/inc/common/util/error_manager/error_manager.h @@ -0,0 +1,83 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ERROR_MANAGER_H_ +#define ERROR_MANAGER_H_ + +#include +#include +#include + +class ErrorManager { + public: + /// + /// @brief Obtain ErrorManager instance + /// @return ErrorManager instance + /// + static ErrorManager &GetInstance(); + + /// + /// @brief init + /// @param [in] path current so path + /// @return int 0(success) -1(fail) + /// + int Init(std::string path); + + /// + /// @brief Report error message + /// @param [in] errCode error code + /// @param [in] mapArgs parameter map + /// @return int 0(success) -1(fail) + /// + int ReportErrMessage(std::string error_code, const std::map &args_map); + + /// @brief output error message + /// @param [in] handle print handle + /// @return int 0(success) -1(fail) + /// + int OutputErrMessage(int handle); + + /// @brief Report error message + /// @param [in] vector parameter key, vector parameter value + /// + void ATCReportErrMessage(std::string error_code, const std::vector &key = {}, + const std::vector &value = {}); + + private: + struct ErrorInfo { + std::string error_id; + std::string error_message; + std::vector arglist; + }; + + ErrorManager() {} + ~ErrorManager() {} + + ErrorManager(const ErrorManager &) = delete; + ErrorManager(ErrorManager &&) = delete; + ErrorManager &operator=(const ErrorManager &) = delete; + ErrorManager &operator=(ErrorManager &&) = delete; + + int ParseJsonFile(std::string path); + + int ReadJsonFile(const std::string &file_path, void *handle); + + bool is_init_ = false; + std::map error_map_; + std::vector error_message_evc_; +}; + +#endif // 
ERROR_MANAGER_H_ diff --git a/inc/common/util/platform_info.h b/inc/common/util/platform_info.h index 52dc0621..cd143fcc 100644 --- a/inc/common/util/platform_info.h +++ b/inc/common/util/platform_info.h @@ -65,6 +65,8 @@ class PlatformInfoManager { void ParseUBOfAICoreSpec(map &aiCoreSpecMap, PlatformInfo &platformInfoTemp); + void ParseUnzipOfAICoreSpec(map &aiCoreSpecMap, PlatformInfo &platformInfoTemp); + void ParseAICoreSpec(map &aiCoreSpecMap, PlatformInfo &platformInfoTemp); void ParseBufferOfAICoreMemoryRates(map &aiCoreMemoryRatesMap, PlatformInfo &platformInfoTemp); diff --git a/inc/common/util/platform_info_def.h b/inc/common/util/platform_info_def.h index 663a2cae..e840a8b9 100644 --- a/inc/common/util/platform_info_def.h +++ b/inc/common/util/platform_info_def.h @@ -65,6 +65,10 @@ typedef struct tagAiCoreSpec { uint64_t ubbankNum; uint64_t ubburstInOneBlock; uint64_t ubbankGroupNum; + uint32_t unzipEngines; + uint32_t unzipMaxRatios; + uint32_t unzipChannels; + uint8_t unzipIsTight; } AiCoreSpec; typedef struct tagAiCoreMemoryRates { diff --git a/inc/external/ge/ge_api.h b/inc/external/ge/ge_api.h index e9beae6f..f3e9fcb6 100644 --- a/inc/external/ge/ge_api.h +++ b/inc/external/ge/ge_api.h @@ -82,14 +82,12 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Session { /// @brief run graph in the session with specific session id asynchronously /// @param [in] graphId: graph id /// @param [in] inputs: input data - /// @param [out] outputs: output data /// @param [out] callback: callback while runing graph has been finished. /// The callback function will not be checked. /// Please ensure that the implementation of the function is trusted. 
/// @return Status result of function /// - Status RunGraphAsync(uint32_t graphId, const std::vector &inputs, - std::vector &outputs, std::function callback); + Status RunGraphAsync(uint32_t graphId, const std::vector &inputs, RunAsyncCallback callback); /// /// @ingroup ge_graph diff --git a/inc/external/ge/ge_api_types.h b/inc/external/ge/ge_api_types.h index bf9a10b4..6fa269ce 100644 --- a/inc/external/ge/ge_api_types.h +++ b/inc/external/ge/ge_api_types.h @@ -21,6 +21,8 @@ #include #include #include +#include +#include namespace ge { // Option key: graph run mode @@ -40,6 +42,12 @@ const char *const OPTION_EXEC_EXTERN_PLUGIN_PATH = "ge.soLoadPath"; const char *const OPTION_EXEC_ENABLE_DUMP = "ge.exec.enableDump"; const char *const OPTION_EXEC_DUMP_PATH = "ge.exec.dumpPath"; const char *const OPTION_EXEC_DUMP_STEP = "ge.exec.dumpStep"; +const char *const OPTION_EXEC_DUMP_MODE = "ge.exec.dumpMode"; +const char *const OPTION_EXEC_ENABLE_INCRE_BUILD = "ge.exec.enableIncreBuild"; +const char *const OPTION_EXEC_INCRE_BUILD_CACHE_PATH = "ge.exec.increBuildCachePath"; +// profiling flag +const char *const OPTION_EXEC_PROFILING_MODE = "ge.exec.profilingMode"; +const char *const OPTION_EXEC_PROFILING_OPTIONS = "ge.exec.profilingOptions"; // Hccl flag, if ge.exec.hcclFlag =1, it means load plugin for opskernel, else:ge.exec.hcclFlag =0 const char *const OPTION_EXEC_HCCL_FLAG = "ge.exec.hcclFlag"; const char *const OPTION_EXEC_ATOMIC_FLAG = "ge.exec.enable_atomic"; @@ -173,6 +181,9 @@ const std::string AICORE_NUM = "ge.aicoreNum"; // Configure L1FUSION const std::string L1_FUSION = "ge.l1Fusion"; +// Configure l1,l2,and others optimize option +const std::string BUFFER_OPTIMIZE = "ge.bufferOptimize"; + // Configure Small Channel flag const std::string ENABLE_SMALL_CHANNEL = "ge.enableSmallChannel"; @@ -188,6 +199,9 @@ const std::string SAVE_ORIGINAL_MODEL = "ge.saveOriginalModel"; // Save original model file name const std::string ORIGINAL_MODEL_FILE = 
"ge.originalModelFile"; +// FE enable quant optimize +const std::string QUANT_OPTIMIZE = "ge.quantOptimize"; + const char *const OPTION_GE_MAX_DUMP_FILE_NUM = "ge.maxDumpFileNum"; const char *const OPTION_GE_MAX_DUMP_FILE_SIZE = "ge.maxDumpFileSize"; const char *const OPTION_GE_MAX_DUMP_OP_NUM = "ge.maxDumpOpNum"; @@ -196,36 +210,49 @@ const char *const OPTION_GE_MAX_DUMP_OP_NUM = "ge.maxDumpOpNum"; // Its value should be "0" or "1", default value is "1" const char *const ENABLE_PRINT_OP_PASS = "ge.enablePrintOpPass"; +// Configure whether to use single stream. +// Its value should be "true" or "false", default value is "false" +const char *const ENABLE_SINGLE_STREAM = "ge.enableSingleStream"; + // Graph run mode enum GraphRunMode { PREDICTION = 0, TRAIN }; -// Data description -struct DataDesc { - void *data = nullptr; // data address - uint32_t length = 0; // data size - bool isDataSupportMemShare = false; +// Input/Output tensor info +struct InputTensorInfo { + uint32_t data_type; // data type + std::vector dims; // shape description + void *data; // tensor data + int64_t length; // tensor length }; -// Input/Output shape description -struct ShapeDesc { - int64_t num = 0; - int64_t channel = 0; - int64_t height = 0; - int64_t width = 0; - std::vector dims; +struct OutputTensorInfo { + uint32_t data_type; // data type + std::vector dims; // shape description + std::unique_ptr data; // tensor data + int64_t length; // tensor length + OutputTensorInfo() : data_type(0), dims({}), data(nullptr), length(0) {} + OutputTensorInfo(OutputTensorInfo &&out) + : data_type(out.data_type), dims(out.dims), data(std::move(out.data)), length(out.length) {} + + OutputTensorInfo &operator=(OutputTensorInfo &&out) { + if (this != &out) { + data_type = out.data_type; + dims = out.dims; + data = std::move(out.data); + length = out.length; + } + return *this; + } + OutputTensorInfo(const OutputTensorInfo &) = delete; + OutputTensorInfo &operator=(const OutputTensorInfo &) = delete; }; 
-// Input/Output tensor info -struct TensorInfo { - uint32_t dataType; // data type - DataDesc data; // tensor data - ShapeDesc shapeInfo; // tensor shape -}; +using Status = uint32_t; +using RunAsyncCallback = std::function &)>; // for ir build namespace ir_option { static const char *const INPUT_FORMAT = "input_format"; static const char *const INPUT_SHAPE = "input_shape"; -static const char *const OP_NAME_MAP = "op_name_map"; static const char *const DYNAMIC_BATCH_SIZE = kDynamicBatchSize; static const char *const DYNAMIC_IMAGE_SIZE = kDynamicImageSize; static const char *const INSERT_OP_FILE = ge::INSERT_OP_FILE.c_str(); @@ -235,13 +262,15 @@ static const char *const HEAD_STREAM = ge::HEAD_STREAM.c_str(); static const char *const AUTO_TUNE_MODE = ge::AUTO_TUNE_MODE.c_str(); static const char *const CORE_TYPE = ge::CORE_TYPE.c_str(); static const char *const SOC_VERSION = ge::SOC_VERSION.c_str(); +static const char *const ENABLE_SINGLE_STREAM = ge::ENABLE_SINGLE_STREAM; + // for interface: aclgrphBuildModel -const std::set ir_builder_suppported_options = { - INPUT_FORMAT, INPUT_SHAPE, OP_NAME_MAP, DYNAMIC_BATCH_SIZE, - DYNAMIC_IMAGE_SIZE, INSERT_OP_FILE, PRECISION_MODE, EXEC_DISABLE_REUSED_MEMORY, - AUTO_TUNE_MODE}; +const std::set ir_builder_suppported_options = {INPUT_FORMAT, INPUT_SHAPE, DYNAMIC_BATCH_SIZE, + DYNAMIC_IMAGE_SIZE, INSERT_OP_FILE}; // for interface: aclgrphBuildInitialize -const std::set global_options = {HEAD_STREAM, CORE_TYPE, SOC_VERSION}; +const std::set global_options = { + HEAD_STREAM, CORE_TYPE, SOC_VERSION, PRECISION_MODE, EXEC_DISABLE_REUSED_MEMORY, + AUTO_TUNE_MODE, ENABLE_SINGLE_STREAM}; } // namespace ir_option } // namespace ge diff --git a/inc/external/graph/graph.h b/inc/external/graph/graph.h index b4ebb435..30886733 100644 --- a/inc/external/graph/graph.h +++ b/inc/external/graph/graph.h @@ -55,12 +55,16 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Graph { graphStatus FindOpByName(const string &name, ge::Operator &op) 
const; + graphStatus FindOpByType(const string &type, std::vector &ops) const; + graphStatus GetAllOpName(std::vector &op_name) const; graphStatus SaveToFile(const string &file_name) const; graphStatus LoadFromFile(const string &file_name); + const std::string &GetName() const; + /// /// Set is need train iteration. /// If set true, it means this graph need to be run iteration some diff --git a/inc/external/graph/inference_context.h b/inc/external/graph/inference_context.h index 68a9ecf5..69079142 100644 --- a/inc/external/graph/inference_context.h +++ b/inc/external/graph/inference_context.h @@ -69,7 +69,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InferenceContext { static std::unique_ptr Create(); private: - InferenceContext(std::unique_ptr &impl); + explicit InferenceContext(std::unique_ptr &impl); std::shared_ptr inference_context_impl_; }; } // namespace ge diff --git a/inc/external/graph/operator.h b/inc/external/graph/operator.h index ed2e639a..be7f10db 100644 --- a/inc/external/graph/operator.h +++ b/inc/external/graph/operator.h @@ -44,11 +44,16 @@ namespace ge { class OperatorImpl; - +class NamedAttrs; +class Graph; class AttrValue; +using SubgraphBuilder = std::function; using OperatorImplPtr = std::shared_ptr; +class Graph; +using GraphBuilderCallback = std::function; + class OpIO; using OutHandler = std::shared_ptr; using InHandler = std::shared_ptr; @@ -69,6 +74,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { using OpBool = bool; using OpTensor = Tensor; using OpType = ge::DataType; + using OpNamedAttrs = ge::NamedAttrs; using OpListInt = std::vector; using OpListFloat = std::vector; using OpListString = std::vector; @@ -77,6 +83,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { using OpBytes = std::vector; using OpListListInt = std::vector>; using OpListType = std::vector; + using OpListNamedAttrs = std::vector; Operator() {} @@ -132,6 +139,12 @@ class GE_FUNC_DEV_VISIBILITY 
GE_FUNC_HOST_VISIBILITY Operator { void SetInferenceContext(const InferenceContextPtr &inference_context); InferenceContextPtr GetInferenceContext() const; + void SetGraphBuilder(const GraphBuilderCallback &builder); + graphStatus GetGraphBuilder(GraphBuilderCallback &builder) const; + + void AddSubgraphName(const string &name); + string GetSubgraphName(int index) const; + graphStatus VerifyAllAttr(bool disable_common_verifier = false); size_t GetInputsSize() const; @@ -190,8 +203,21 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { Operator &SetAttr(const string &name, const ge::DataType &attr_value); graphStatus GetAttr(const string &name, ge::DataType &attr_value) const; + // func type + Operator &SetAttr(const string &name, const ge::NamedAttrs &attr_value); + graphStatus GetAttr(const string &name, ge::NamedAttrs &attr_value) const; + Operator &SetAttr(const string &name, const std::vector &attr_value); + graphStatus GetAttr(const string &name, std::vector &attr_value) const; + void BreakConnect() const; + size_t GetSubgraphNamesCount() const; + std::vector GetSubgraphNames() const; + SubgraphBuilder GetSubgraphBuilder(const string &name) const; + Graph GetSubgraph(const string &name) const; + SubgraphBuilder GetDynamicSubgraphBuilder(const string &name, uint32_t index) const; + Graph GetDynamicSubgraph(const string &name, uint32_t index) const; + protected: void AttrRegister(const string &name, float attr_value); void AttrRegister(const string &name, const std::vector &attr_value); @@ -207,6 +233,8 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { void AttrRegister(const string &name, const std::vector> &attr_value); void AttrRegister(const string &name, const std::vector &attr_value); void AttrRegister(const string &name, const ge::DataType &attr_value); + void AttrRegister(const string &name, const ge::NamedAttrs &attr_value); + void AttrRegister(const string &name, const std::vector &attr_value); explicit 
Operator(OperatorImplPtr &&op_impl); @@ -224,6 +252,8 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { void DynamicInputRegister(const string &name, const unsigned int num, bool is_push_back = true); + void DynamicInputRegisterByIndex(const string &name, const unsigned int num, size_t index); + void DynamicOutputRegister(const string &name, const unsigned int num, bool is_push_back = true); void RequiredAttrRegister(const string &name); @@ -235,6 +265,10 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { Operator &SetInput(const string &dst_name, uint32_t dst_index, const Operator &src_oprt, const string &name); + void SubgraphRegister(const std::string &name, bool dynamic); + void SubgraphCountRegister(const std::string &name, uint32_t count); + void SetSubgraphBuilder(const std::string &name, uint32_t index, const SubgraphBuilder &builder); + private: Operator &SetInput(const string &dst_name, const OutHandler &out_handler); diff --git a/inc/external/graph/operator_reg.h b/inc/external/graph/operator_reg.h index 2878b4eb..57b1f8fe 100644 --- a/inc/external/graph/operator_reg.h +++ b/inc/external/graph/operator_reg.h @@ -22,10 +22,11 @@ #include #include -#include "./operator.h" -#include "./operator_factory.h" -#include "./tensor.h" -#include "./types.h" +#include "graph/operator.h" +#include "graph/operator_factory.h" +#include "graph/tensor.h" +#include "graph/types.h" +#include "graph/graph.h" namespace ge { using std::function; @@ -46,6 +47,10 @@ class OpReg { OpReg &OUTPUT() { return *this; } + OpReg &GRAPH() { return *this; } + + OpReg &DYNAMIC_GRAPH() { return *this; } + OpReg &INFER_SHAPE_AND_TYPE() { return *this; } }; @@ -191,6 +196,10 @@ class OpReg { Operator::DynamicInputRegister(#x, num, isPushBack); \ return *this; \ } \ + _THIS_TYPE &create_dynamic_input_byindex_##x(unsigned int num, size_t index) { \ + Operator::DynamicInputRegisterByIndex(#x, num, index); \ + return *this; \ + } \ TensorDesc 
get_dynamic_input_desc_##x(unsigned int index) const { return Operator::GetDynamicInputDesc(#x, index); } \ graphStatus update_dynamic_input_desc_##x(unsigned int index, const TensorDesc &tensorDesc) { \ return Operator::UpdateDynamicInputDesc(#x, index, tensorDesc); \ @@ -229,6 +238,51 @@ class OpReg { void __dy_output_##x() { \ (void)OpReg() +#define GRAPH(x) \ + N(); \ + __graph_##x(); \ + } \ + \ + public: \ + static const string name_graph_##x() { return #x; } \ + SubgraphBuilder get_subgraph_builder_##x() const { return Operator::GetSubgraphBuilder(#x); } \ + _THIS_TYPE &set_subgraph_builder_##x(const SubgraphBuilder &v) { \ + Operator::SetSubgraphBuilder(#x, 0, v); \ + return *this; \ + } \ + Graph get_subgraph_##x() const { return Operator::GetSubgraph(#x); } \ + \ + private: \ + void __graph_##x() { \ + Operator::SubgraphRegister(#x, false); \ + Operator::SubgraphCountRegister(#x, 1); \ + (void)OpReg() + +#define DYNAMIC_GRAPH(x) \ + N(); \ + __graph_##x(); \ + } \ + \ + public: \ + static const string name_graph_##x() { return #x; } \ + _THIS_TYPE &create_dynamic_subgraph_##x(unsigned int num) { \ + Operator::SubgraphCountRegister(#x, num); \ + return *this; \ + } \ + SubgraphBuilder get_dynamic_subgraph_builder_##x(unsigned int index) const { \ + return Operator::GetDynamicSubgraphBuilder(#x, index); \ + } \ + Graph get_dynamic_subgraph_##x(unsigned int index) const { return Operator::GetDynamicSubgraph(#x, index); } \ + _THIS_TYPE &set_dynamic_subgraph_builder_##x(unsigned int index, const SubgraphBuilder &v) { \ + Operator::SetSubgraphBuilder(#x, index, v); \ + return *this; \ + } \ + \ + private: \ + void __graph_##x() { \ + Operator::SubgraphRegister(#x, true); \ + (void)OpReg() + #define PASTE(g_register, y) g_register##y #define __OP_END_IMPL__(x, y) \ N(); \ diff --git a/inc/external/graph/tensor.h b/inc/external/graph/tensor.h index f60d245b..5174c248 100644 --- a/inc/external/graph/tensor.h +++ b/inc/external/graph/tensor.h @@ -21,6 +21,7 @@ 
#include #include #include +#include #include "./ge_error_codes.h" #include "./types.h" @@ -62,6 +63,11 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TensorDesc { void Update(const Shape &shape, Format format = FORMAT_ND, DataType dt = DT_FLOAT); Shape GetShape() const; void SetShape(const Shape &shape); + // set shape with -2, it stand for unknown shape + graphStatus SetUnknownDimNumShape(); + // for unknown shape + graphStatus SetShapeRange(const std::vector> &range); + graphStatus GetShapeRange(std::vector> &range) const; Format GetFormat() const; void SetFormat(Format format); diff --git a/inc/external/graph/types.h b/inc/external/graph/types.h index c76c5556..46cb34b9 100644 --- a/inc/external/graph/types.h +++ b/inc/external/graph/types.h @@ -23,7 +23,9 @@ namespace ge { static const int64_t UNKNOWN_DIM = -1; +static const int64_t UNKNOWN_DIM_NUM = -2; static const std::vector UNKNOWN_SHAPE = {0}; +static const std::vector UNKNOWN_RANK = {-2}; #ifdef HOST_VISIBILITY #define GE_FUNC_HOST_VISIBILITY __attribute__((visibility("default"))) @@ -140,10 +142,19 @@ enum Format { FORMAT_NC, FORMAT_DHWNC, FORMAT_FRACTAL_Z_3D_TRANSPOSE, // 3D filter(transpose) input tensor format + FORMAT_FRACTAL_ZN_LSTM, FORMAT_RESERVED, FORMAT_ALL }; +// for unknown shape op type +enum UnknowShapeOpType { + DEPEND_IN_SHAPE = 1, // op out shape get by input shape + DEPEND_CONST_VALUE = 2, // op out shape get by const op value + DEPEND_SHAPE_RANGE = 3, // op out shape get by range + DEPEND_COMPUTE = 4 // op out shape get by totally computing +}; + struct TensorDescInfo { Format format_ = FORMAT_RESERVED; // tbe op register support format DataType dataType_ = DT_UNDEFINED; // tbe op register support datatype diff --git a/inc/external/register/register.h b/inc/external/register/register.h index 045a1570..28c984bf 100644 --- a/inc/external/register/register.h +++ b/inc/external/register/register.h @@ -58,12 +58,18 @@ Status AutoMappingFn(const google::protobuf::Message *op_src, 
ge::Operator &op); Status AutoMappingFnDynamic(const google::protobuf::Message *op_src, ge::Operator &op, std::map> dynamic_name_attr_value, int in_pos = -1, int out_pos = -1); +Status AutoMappingSubgraphIndex(const ge::Graph &graph, const std::function &input, + const std::function &output); +Status AutoMappingSubgraphIndex(const ge::Graph &graph, + const std::function &input, + const std::function &output); using google::protobuf::Message; class OpRegistrationDataImpl; using ParseParamFunc = std::function; using FusionParseParamFunc = std::function, ge::Operator &)>; +using ParseSubgraphFunc = std::function; class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData { public: @@ -81,6 +87,8 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData { OpRegistrationData &FusionParseParamsFn(const FusionParseParamFunc &fusionParseParamFn); + OpRegistrationData &ParseSubgraphPostFn(const ParseSubgraphFunc &subgraph_post_fn); + OpRegistrationData &ImplyType(const domi::ImplyType &imply_type); OpRegistrationData &DelInputWithCond(int inputIdx, const std::string &attrName, bool attrValue); @@ -93,6 +101,7 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData { domi::FrameworkType GetFrameworkType() const; ParseParamFunc GetParseParamFn() const; FusionParseParamFunc GetFusionParseParamFn() const; + ParseSubgraphFunc GetParseSubgraphPostFn() const; private: std::shared_ptr impl_; @@ -116,27 +125,5 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpReceiver { namespace ge { using OpRegistrationData = domi::OpRegistrationData; using OpReceiver = domi::OpReceiver; - -class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY HostCpuOp { - public: - HostCpuOp() = default; - virtual ~HostCpuOp() = default; - - virtual graphStatus Compute(Operator &op, const std::map &inputs, - std::map &outputs) = 0; -}; - -class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY HostCpuOpRegistrar { - public: - 
HostCpuOpRegistrar(const char *op_type, HostCpuOp *(*create_fn)()); -}; - -#define REGISTER_HOST_CPU_OP_BUILDER(name, op) REGISTER_HOST_CPU_OP_BUILDER_UNIQ_HELPER(__COUNTER__, name, op) - -#define REGISTER_HOST_CPU_OP_BUILDER_UNIQ_HELPER(ctr, name, op) REGISTER_HOST_CPU_OP_BUILDER_UNIQ(ctr, name, op) - -#define REGISTER_HOST_CPU_OP_BUILDER_UNIQ(ctr, name, op) \ - static ::ge::HostCpuOpRegistrar register_host_cpu_op##ctr __attribute__((unused)) = \ - ::ge::HostCpuOpRegistrar(name, []() -> ::ge::HostCpuOp * { return new (std::nothrow) op(); }) } // namespace ge #endif // INC_EXTERNAL_REGISTER_REGISTER_H_ diff --git a/inc/framework/common/debug/ge_log.h b/inc/framework/common/debug/ge_log.h index f2df79a7..e2023cb8 100644 --- a/inc/framework/common/debug/ge_log.h +++ b/inc/framework/common/debug/ge_log.h @@ -51,24 +51,24 @@ inline pid_t GetTid() { return tid; } -#define GE_TIMESTAMP_START(stage) uint64_t startUsec_##stage = domi::GetCurrentTimestap() +#define GE_TIMESTAMP_START(stage) uint64_t startUsec_##stage = ge::GetCurrentTimestap() #define GE_TIMESTAMP_END(stage, stage_name) \ do { \ - uint64_t endUsec_##stage = domi::GetCurrentTimestap(); \ + uint64_t endUsec_##stage = ge::GetCurrentTimestap(); \ GEEVENT("[GEPERFTRACE] The time cost of %s is [%lu] micro second.", (stage_name), \ (endUsec_##stage - startUsec_##stage)); \ } while (0); -#define GE_TIMESTAMP_CALLNUM_START(stage) \ - uint64_t startUsec_##stage = domi::GetCurrentTimestap(); \ - uint64_t call_num_of##stage = 0; \ +#define GE_TIMESTAMP_CALLNUM_START(stage) \ + uint64_t startUsec_##stage = ge::GetCurrentTimestap(); \ + uint64_t call_num_of##stage = 0; \ uint64_t time_of##stage = 0 -#define GE_TIMESTAMP_RESTART(stage) (startUsec_##stage = domi::GetCurrentTimestap()) +#define GE_TIMESTAMP_RESTART(stage) (startUsec_##stage = ge::GetCurrentTimestap()) -#define GE_TIMESTAMP_ADD(stage) \ - time_of##stage += domi::GetCurrentTimestap() - startUsec_##stage; \ +#define GE_TIMESTAMP_ADD(stage) \ + time_of##stage 
+= ge::GetCurrentTimestap() - startUsec_##stage; \ call_num_of##stage++ #define GE_TIMESTAMP_CALLNUM_END(stage, stage_name) \ diff --git a/inc/framework/common/debug/log.h b/inc/framework/common/debug/log.h index 9a192a82..147c3bdf 100644 --- a/inc/framework/common/debug/log.h +++ b/inc/framework/common/debug/log.h @@ -22,7 +22,6 @@ #include "cce/cce_def.hpp" #include "common/string_util.h" #include "common/util.h" -#include "dlog/log.h" #include "framework/common/debug/ge_log.h" #include "ge/ge_api_error_codes.h" @@ -30,7 +29,7 @@ using cce::CC_STATUS_SUCCESS; using cce::ccStatus_t; #if !defined(__ANDROID__) && !defined(ANDROID) -#define DOMI_LOGE(...) DAV_LOGE("DOMI", __VA_ARGS__) +#define DOMI_LOGE(...) GE_LOG_ERROR(GE_MODULE_NAME, ge::FAILED, __VA_ARGS__) #else #include #if defined(BUILD_VERSION_PERF) @@ -103,17 +102,17 @@ using cce::ccStatus_t; } while (0); // If expr is not true, print the log and return the specified status -#define GE_CHK_BOOL_RET_STATUS(expr, _status, ...) \ - do { \ - bool b = (expr); \ - if (!b) { \ - std::string msg; \ - (void)msg.append(domi::StringUtils::FormatString(__VA_ARGS__)); \ - (void)msg.append( \ - domi::StringUtils::FormatString(" Error Code:0x%X(%s)", _status, GET_ERRORNO_STR(_status).c_str())); \ - DOMI_LOGE("%s", msg.c_str()); \ - return _status; \ - } \ +#define GE_CHK_BOOL_RET_STATUS(expr, _status, ...) 
\ + do { \ + bool b = (expr); \ + if (!b) { \ + std::string msg; \ + (void)msg.append(ge::StringUtils::FormatString(__VA_ARGS__)); \ + (void)msg.append( \ + ge::StringUtils::FormatString(" Error Code:0x%X(%s)", _status, GET_ERRORNO_STR(_status).c_str())); \ + DOMI_LOGE("%s", msg.c_str()); \ + return _status; \ + } \ } while (0); // If expr is not true, print the log and return the specified status diff --git a/inc/framework/common/ge_inner_error_codes.h b/inc/framework/common/ge_inner_error_codes.h index b563aef7..4b5538d3 100644 --- a/inc/framework/common/ge_inner_error_codes.h +++ b/inc/framework/common/ge_inner_error_codes.h @@ -152,7 +152,6 @@ GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_RUN_GRAPH_INVALID, 11, "Get computeGraph by g GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_INSERT_DYN_OP_FAILED, 12, "Graph which insert dynamic op failed."); // 1343242252 GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_PREPROCESS_FAILED, 13, "Graph preprocess failed."); // 1343242253 GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_GRAPH_FUSION_FAILED, 14, "Graph fusion failed."); // 1343242254 -GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_TINY_CAL_CHECK_FAILED, 15, "Check tiny calibration failed."); // 1343242255 GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_CALIBRATION_FAILED, 16, "Calibration failed."); // 1343242256 GE_ERRORNO_GRAPH(GE_GRAPH_SUBGRAPH_NUM_ZERO, 17, "Graph partition success, but subGraph num is 0."); // 1343242257 GE_ERRORNO_GRAPH(GE_GRAPH_SUBGRAPH_ENGINENAME_REPEATED, 18, "Graph subGraph engine name is repeated."); // 1343242258 diff --git a/inc/framework/common/ge_types.h b/inc/framework/common/ge_types.h index 6ff3404e..6c70aa4c 100644 --- a/inc/framework/common/ge_types.h +++ b/inc/framework/common/ge_types.h @@ -25,6 +25,7 @@ #include "common/fmk_error_codes.h" #include "ge/ge_api_error_codes.h" #include "external/graph/types.h" +#include "external/ge/ge_api_types.h" namespace ge { enum RuntimeType { HOST = 0, DEVICE = 1 }; @@ -130,7 +131,8 @@ class ModelListener { /// @param [in] data_index Index of the input_data /// 
@param [in] resultCode Execution results /// - virtual Status OnComputeDone(uint32_t model_id, uint32_t data_index, uint32_t result_code) = 0; + virtual Status OnComputeDone(uint32_t model_id, uint32_t data_index, uint32_t result_code, + std::vector &outputs) = 0; }; // OMM configuration item @@ -147,6 +149,8 @@ struct Options { std::string rankTableFile; int32_t ge_hccl_flag = 0; int32_t physical_device_id; + std::string profiling_mode; + std::string profiling_options; }; // Profiling info of task diff --git a/inc/framework/common/gflags_util.h b/inc/framework/common/gflags_util.h index 4fb9511f..94d66ffb 100644 --- a/inc/framework/common/gflags_util.h +++ b/inc/framework/common/gflags_util.h @@ -20,7 +20,7 @@ #include #include -namespace domi { +namespace ge { class GflagsUtils { public: static bool IsSetCommandTrue(const char *name) { @@ -66,6 +66,6 @@ class GflagsUtils { } } }; -} // namespace domi +} // namespace ge #endif // INC_FRAMEWORK_COMMON_GFLAGS_UTIL_H_ diff --git a/inc/framework/common/helper/model_helper.h b/inc/framework/common/helper/model_helper.h index c16e3c23..bd9a6c57 100644 --- a/inc/framework/common/helper/model_helper.h +++ b/inc/framework/common/helper/model_helper.h @@ -26,7 +26,7 @@ #include "graph/model.h" #include "model/ge_model.h" -namespace domi { +namespace ge { class ModelHelper { public: ModelHelper() = default; @@ -38,7 +38,7 @@ class ModelHelper { Status LoadModel(const ge::ModelData& model_data); Status GetModelBufferData(ge::ModelBufferData& model); - ModelFileHeader* GetFileHeader() { return file_header_; } + const ModelFileHeader* GetFileHeader() const { return file_header_; } GeModelPtr GetGeModel(); void SetSaveMode(bool val) { is_offline_ = val; } @@ -65,9 +65,8 @@ class ModelHelper { Status LoadTask(OmFileLoadHelper& om_load_helper); Status LoadTBEKernelStore(OmFileLoadHelper& om_load_helper); Status ReleaseLocalModelData() noexcept; - Status SaveModelPartition(std::shared_ptr& om_file_save_helper, ModelPartitionType 
type, const uint8_t* data, size_t size); }; -} // namespace domi +} // namespace ge #endif // INC_FRAMEWORK_COMMON_HELPER_MODEL_HELPER_H_ diff --git a/inc/framework/common/helper/om_file_helper.h b/inc/framework/common/helper/om_file_helper.h index 7c301f97..1e4cee9b 100644 --- a/inc/framework/common/helper/om_file_helper.h +++ b/inc/framework/common/helper/om_file_helper.h @@ -26,8 +26,10 @@ #include "framework/common/ge_types.h" using ProcParam = struct PROC_PARAM; +using std::string; +using std::vector; -namespace domi { +namespace ge { struct ModelPartition { ModelPartitionType type; uint8_t *data = 0; @@ -88,5 +90,5 @@ class OmFileSaveHelper { ModelFileHeader model_header_; OmFileContext context_; }; -} // namespace domi +} // namespace ge #endif // INC_FRAMEWORK_COMMON_HELPER_OM_FILE_HELPER_H_ diff --git a/inc/framework/common/l2_cache_optimize.h b/inc/framework/common/l2_cache_optimize.h index 8aa0a5d1..c65f67b3 100644 --- a/inc/framework/common/l2_cache_optimize.h +++ b/inc/framework/common/l2_cache_optimize.h @@ -30,7 +30,7 @@ using std::vector; -namespace domi { +namespace ge { // Size of RC memory alignment, 2M constexpr size_t ALIGN_SIZE = 2097152; @@ -118,6 +118,6 @@ class L2CacheOptimize { bool Cross(const RCMemoryBlock &l_block, const RCMemoryBlock &r_block); bool Connect(const RCMemoryBlock &l_block, const RCMemoryBlock &r_block); }; -} // namespace domi +} // namespace ge #endif // INC_FRAMEWORK_COMMON_L2_CACHE_OPTIMIZE_H_ \ No newline at end of file diff --git a/inc/framework/common/op/attr_define.h b/inc/framework/common/op/attr_define.h deleted file mode 100644 index 536a860e..00000000 --- a/inc/framework/common/op/attr_define.h +++ /dev/null @@ -1,810 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_FRAMEWORK_COMMON_OP_ATTR_DEFINE_H_ -#define INC_FRAMEWORK_COMMON_OP_ATTR_DEFINE_H_ - -#include -#include "framework/common/fmk_types.h" - -namespace domi { -// Public Attribute -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_NAME; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_TYPE; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_WEIGHT_NAME; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_IS_QUANTIZE_FACTOR; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_ALPHA; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_BETA; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_PADMODE; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_PADMODES; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_MODE; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FILTER; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_BIAS; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_BIAS_TERM; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_HAS_BIAS_VALUE; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_PAD; - -extern 
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_PADS; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_PAD_SIZE; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_PAD_MODE; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_SCALE; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_WINDOWS; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_GLOBAL_POOLING; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_CEIL_MODE; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_STRIDE_SIZE; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_RELU_FLAG; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_ALGO; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FORMAT; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FILTER_FORMAT; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_LRN_K; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_LRN_NORM_REGION; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_LRN_LOCAL_SIZE; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_LRN_ALPHA; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_LRN_BETA; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_AXIS; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_BROADCAST; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_OUTPUT_NUM; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY 
const std::string ATTR_NAME_TIDX; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_TPADDINGS; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_IMG_H; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_IMG_W; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NET_H; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NET_W; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_TMULTIPLES; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_MULTIPLES; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_T; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_N; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_TSHAPE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_NAN_OPT; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_AIPP; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string NEW_AIPP_CONV_OP; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_SESSION_GRAPH_ID; - -static const std::string ATTR_NAME_INPUT_DATATYPE = "input_datatype"; -static const std::string ATTR_NAME_OUTPUT_DATATYPE = "output_datatype"; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_MULTISHAPE_BATCHLIST; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_MULTISHAPE_BATCHLIST_SIZE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_BATCH_NUM; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_INPUT_FORMAT; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string 
ATTR_NAME_OUTPUT_FORMAT; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FRAMEWORK_NODE_DEF; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FRAMEWORK_OP_DEF; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FRAMEWORK_FWK_TYPE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FRAMEWORK_FUNC_DEF; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_INPUT_TENSOR_DESC; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_OUTPUT_TENSOR_DESC; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_INFERRED_FORMAT; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_PRED_PERMUTE_DELETED; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_IGNORE_PRED_FORMAT; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_WEIGHTS; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_BROACAST_REAL_DIM_CNT; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_DIM_ALIGN; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_RTSWITCH_RECV_EVENT_ID; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_AUTOMIC_ADD_START; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE; - -// To be deleted -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_TO_BE_DELETED; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const 
std::string PERMUTE_RESHAPE_FUSION; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PERMUTE_RESHAPE_FUSION_CONV_PROPOSAL; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PERMUTE_RESHAPE_FUSION_CONV_DECODEBBOX; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PERMUTE_RESHAPE_FUSION_BOX_TYPE_NUM; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_MBOX_LOC_FUSION; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_MBOX_CONF_FUSION; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_MBOX_OCR_FUSION; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_MBOX_FUSION_BOX_TYPE_NUM; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_RESHAPE_SLICE_CONCAT_FUSION; - -// Refinedet -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string REFINEDET_MBOX_LOC_FUSION; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string REFINEDET_MBOX_CONF_FUSION; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string REFINEDET_MBOX_FUSION_BOX_TYPE_NUM; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIORBOX_CONCAT; - -// _Arg -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_INDEX; -// _RetVal -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RETVAL_ATTR_NAME_INDEX; -// Data -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DATA_ATTR_NAME_DATA_TYPE; - -// Send -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SEND_ATTR_EVENT_ID; - -// Recv -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RECV_ATTR_EVENT_ID; - -// convolution -extern 
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_COEF; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_STRIDE; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_STRIDES; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_DILATION; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_DILATIONS; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_MODE; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_ALGO; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_GROUP; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_PAD_MODE; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_PAD; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_STRIDE; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_DILATION; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_NUM_OUTPUT; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_KERNEL; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_FILTER; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_BIAS; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_RELU_FLAG; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_ADJ; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_TARGET_SHAPE; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_BEFORE_PAD; - -extern 
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_HAS_BIAS; - -// Pooling -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_MODE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_NAN_OPT; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_PAD_MODE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_GLOBAL_POOLING; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_WINDOW; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_PAD; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_STRIDE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_CEIL_MODE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_DATA_MODE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_BEFORE_PAD; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_NAME_ALGO; - -// Eltwise -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ELTWISE_ATTR_MODE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ELTWISE_ATTR_COEFF; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ELTWISE_ATTR_WEIGHT; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ELTWISE_ATTR_RELU_FLAG; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ELTWISE_ATTR_ALPHA; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ELTWISE_ATTR_BETA; - -// BatchNorm -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_MODE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_EPSILON; -extern 
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_USE_GLOBAL_STATS; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_MOVING_AVERAGE_FRACTION; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_ESTIMATED_MEAN; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_ESTIMATED_VARIANCE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_SCALE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_BIAS; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_DATA_FORMAT; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_IS_TRAINING; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_IS_TRAINING_FUSION; - -// Huberloss -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HUBER_LOSS_ATTR_DELTA; - -// SSDRealDivTileMul -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_REAL_DIV_TILE_MUL_ATTR_TILE_PARA; - -// SSDSumMulRealDivMean -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string - SSD_SUM_MUL_REALDIV_MEAN_ATTR_REDUCTION_INDICES; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_AXIS; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_MEAN_PARA; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_HAS_SUM; - -/// ConcatFive2Four -/// ConcatFour2Five -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_CLASS_NUM; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_FEATURE_MAP_SIZE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string 
TRANS_FOR_LOSS_MODE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_BOX_TYPE_NUM; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_FEATURE_MAP_HIGH; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_FEATURE_MAP_WIDTH; - -// Scale -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SCALE_ATTR_SCALE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SCALE_ATTR_BIAS; - -// FullConnection -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FULL_CONNECTION_ATTR_FILTER; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FULL_CONNECTION_ATTR_BIAS; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FULL_CONNECTION_ATTR_NUM_OUTPUT; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FULL_CONNECTION_ATTR_RELU_FLAG; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FULL_ATTR_NAME_ALGO; - -// SoftmaxOpParams -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SOFTMAX_ATTR_ALGO; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SOFTMAX_ATTR_MODE; - -// SparseSoftmaxCrossEntropy -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SPARSE_SOFTMAX_CROSS_ENTROPY_ATTR_MODE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SPARSE_SOFTMAX_CROSS_ENTROPY_IS_GRAD; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SOFTMAX_CROSS_ENTROPY_LABELSMOOTHING; - -// Activation -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ACTIVATION_ATTR_MODE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ACTIVATION_ATTR_COEF; - -// Concat -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONCAT_ATTR_NAME_AXIS; - -// Const -extern FMK_FUNC_HOST_VISIBILITY 
FMK_FUNC_DEV_VISIBILITY const std::string CONST_ATTR_NAME_DATA_TRANSTYPE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONST_ATTR_NAME_OUTPUT_FORMAT; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONST_ATTR_NAME_OUTPUT_TYPE; - -// Roipooling -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ROIPOOLING_ATTR_NAME_POOLED_H; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ROIPOOLING_ATTR_NAME_POOLED_W; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ROIPOOLING_ATTR_NAME_SPATIAL_SCALE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ROIPOOLING_ATTR_NAME_RIO_POOLING_MODE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ROIPOOLING_ATTR_NAME_POOLING_MODE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ROIPOOLING_ATTR_NAME_SAMPLING_RATIO; - -// DetectionOutput -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_NUM_CLASSES; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_OCR_NUM_CLASSES; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_NMS_THRESHOLD; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_TOP_K; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_CONFIDENCE_THRESHOLD; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_IMG_H; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_IMG_W; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_BATCH_SIZE; -// Ssd DetectionOutput -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_ETA; -extern 
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_SHARED_LOCATION; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_BACKGROUND_LABEL_ID; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_CODE_TYPE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string - DETECTIONOUTPUT_ATTR_VARIANCE_ENCODED_IN_TARGET; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_KEEP_TOP_K; -// Refinedet DetectionOutput -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_OBJECTNESS_SCORE; -// yolo DetectionOutput -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_ClASSES; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_BIASES; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_RELATIVE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_OBJECTNESS_THRESHOLD; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_CLASS_THRESHOLD; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_POST_TOP_K; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_IOU_THRESHOLD_DECAY; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_COOR_SCALE_FACTOR; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_YOLO_VERSION; - -// DetectionPostprocess -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POSTPROCESS_ATTR_NAME_CLS_NUM; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POSTPROCESS_ATTR_NAME_CONF_THRESH; -extern 
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POSTPROCESS_ATTR_NAME_NMS_THRESH; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POSTPROCESS_ATTR_POST_NMS_TOPN; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POSTPROCESS_ATTR_NAME_BBOX_REG_WEIGHT; - -// Spatialtransfrom -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SPTIALTF_ATTR_NAME_OUTPUT_H; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SPTIALTF_ATTR_NAME_OUTPUT_W; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SPTIALTF_ATTR_NAME_BORDER_VALUE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SPTIALTF_ATTR_NAME_AFFINE_TRANSFORM; - -// Proposal -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_NAME_FEAT_STRIDE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_NAME_BASE_SIZE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_NAME_MIN_SIZE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_NAME_RATIO; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_NAME_SCALE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_NAME_PRE_NMS_TOPN; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_NAME_POST_NMS_TOPN; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_NAME_NMS_THRESH; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_NAME_TOP_SIZE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_IMG_H; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_IMG_W; -// Softmax -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const 
std::string SOFTMAX_ATTR_AXIS; - -// Permute -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PERMUTE_ATTR_ORDER; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PERMUTE_ATTR_PERM; - -// SSD Normalize -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSDNORMALIZE_ATTR_ACCROSS_SPATIAL; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSDNORMALIZE_ATTR_CHANNEL_SHARED; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSDNORMALIZE_ATTR_EPS; - -// Flatten -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FLATTEN_ATTR_AXIS; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FLATTEN_ATTR_END_AXIS; - -// SsdPRIORBOX -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_FLIP; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_CLIP; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_IMG_H; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_IMG_W; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_STEP_H; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_STEP_W; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_OFFSET; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_MIN_SIZE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_MAX_SIZE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_MIN_SIZE_NUM; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_MAX_SIZE_NUM; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const 
std::string SSD_PRIOR_BOX_ATTR_ASPECT_RATIO; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_ASPECT_RATIO_NUM; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_VARIANCE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_VARIANCE_NUM; - -// RefinedetPRIORBOX -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE; -// PRelu -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PRELU_ATTR_CHANNEL_SHARED; - -// Psroi pooling -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PSROIPOOLING_ATTR_SPATIAL_SCALE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PSROIPOOLING_ATTR_OUTPUT_DIM; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PSROIPOOLING_ATTR_GROUP_SIZE; - -// Power -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POWER_ATTR_NAME_POWER; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POWER_ATTR_NAME_SCALE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POWER_ATTR_NAME_SHIFT; - -// Log -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LOG_ATTR_NAME_SCALE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LOG_ATTR_NAME_SHIFT; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LOG_ATTR_NAME_BASE; - -// Pack -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PACK_ATTR_NAME_NUM; - -// Dynamic stitch -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DYNAMIC_STITCH_ATTR_NAME_NUM; - -// Unpack -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string 
UNPACK_ATTR_NAME_NUM; -// Gathernd -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string GATHERND_ATTR_NAME_TINDICES; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string GATHERND_ATTR_NAME_TPARAMS; - -// Argmax -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ARGMAX_ATTR_NAME_TOPK; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ARGMAX_ATTR_NAME_REDUCESIZE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ARGMAX_ATTR_NAME_REDUCESTRIDE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ARGMAX_ATTR_NAME_OUTMAX; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ARGMAX_ATTR_NAME_AXIS; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ARGMAX_ATTR_NAME_AXISTYPE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ARGMAX_ATTR_NAME_KEEPDIMS; - -// Upsample -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string UPSAMPLE_ATTR_NAME_SCALE_H; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string UPSAMPLE_ATTR_NAME_SCALE_W; - -// Relu -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_NEGATIVE_SLOPE; - -// FreeSpaceExtract -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FREESPACEEXTRACT_ATTR_NAME_ORG_HEIGHT; - -// split -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SPLIT_ATTR_NAME_SLICE_POINT; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SPLIT_ATTR_NAME_SIZE_SPLIT; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SPLIT_ATTR_NAME_NUM_SPLIT; - -// Tvm -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string TVM_ATTR_NAME_MAGIC; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string TVM_ATTR_NAME_BLOCKDIM; -extern FMK_FUNC_HOST_VISIBILITY 
FMK_FUNC_DEV_VISIBILITY const std::string TVM_ATTR_NAME_METADATA; - -// Squeeze -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SQUEEZE_ATTR_AXIS; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SQUEEZE_ATTR_DIMS; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SQUEEZE_OP_NAME; - -// Stride slice -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string STRIDE_SLICE_ATTR_BEGIN_MASK; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string STRIDE_SLICE_ATTR_END_MASK; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string STRIDE_SLICE_ATTR_ELLIPSIS_MASK; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string STRIDE_SLICE_ATTR_NEW_AXIS_MASK; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string STRIDE_SLICE_ATTR_SHRINK_AXIS_MASK; - -// Slice -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SLICE_ATTR_NAME_BEGINS; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SLICE_ATTR_NAME_SIZES; - -// Roialign -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ROIALIGN_ATTR_SPATIAL_SCALE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ROIALIGN_ATTR_SAMPLING_RATIO; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ROIALIGN_ATTR_NAME_POOLED_H; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ROIALIGN_ATTR_NAME_POOLED_W; - -// Generate_rpn_proposal -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string GENERATE_RPN_PROPOSAL_ATTR_PRE_NMS_TOPK; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string GENERATE_RPN_PROPOSAL_ATTR_POST_NMS_TOPK; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string GENERATE_RPN_PROPOSAL_ATTR_RPN_MINI_SIZE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const 
std::string - GENERATE_RPN_PROPOSAL_ATTR_RPN_PROPOSAL_NMS_THRESH; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string - GENERATE_RPN_PROPOSAL_ATTR_RPN_PROPOSAL_FILTER_THRESH; -// Decode_bbox -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DECODE_BBOX_ATTR_DECODECLIP; - -// Cast -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CAST_ATTR_DSTT; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CAST_ATTR_SRCT; - -// Fastrcnnn predications -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FASTRCNN_PREDICTIONS_ATTR_TOPK; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FASTRCNN_PREDICTIONS_ATTR_SCORE_THRESHOLD; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FASTRCNN_PREDICTIONS_ATTR_NMS_THRESHOLD; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FASTRCNN_PREDICTIONS_ATTR_NUM_CLASSES; - -// REORG -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string REORG_ATTR_STRIDE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string REORG_ATTR_REVERSE; - -// MERGE -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MERGE_DEAD_INDEX; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MERGE_PRENODE_FLAG; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string TO_BE_OUTPUT; -static const std::string NOT_NET_OUTPUT = "not_net_output"; - -// Concatv2 -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONCAT_V2_ATTR_TIDX; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONCAT_V2_ATTR_N; -// SUM -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SUM_ATTR_TIDX; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SUM_ATTR_AXIS; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY 
const std::string SUM_ATTR_KEEP_DIMS; - -// ResizeBilinear -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESIZE_BILINEAR_ATTR_MODE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESIZE_BILINEAR_ATTR_ALIGN_CORNERS; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESIZE_BILINEAR_ATTR_HEIGHT; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESIZE_BILINEAR_ATTR_WIDTH; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESIZE_BILINEAR_ATTR_ZOOM_FACTOR; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESIZE_BILINEAR_ATTR_SHRINK_FACTOR; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESIZE_BILINEAR_ATTR_PAD_BEGIN; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESIZE_BILINEAR_ATTR_PAD_END; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESIZE_BILINEAR_ATTR_ALPHA; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESIZE_BILINEAR_ATTR_BETA; - -// RetinaNet -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RETINANET_FILTER_BACKGROUND_TRUE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RETINANET_ANCHOR_FUSION; - -// MatMul -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MATMUL_TRANSPOSE_X; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MATMUL_TRANSPOSE_W; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MATMUL_HAS_BIAS; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MATMUL_ATTR_IS_TRAINING; - -// Flatten -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FLATTEN_START_AXIS; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FLATTEN_END_AXIS; - -// Reshape -extern FMK_FUNC_HOST_VISIBILITY 
FMK_FUNC_DEV_VISIBILITY const std::string RESHAPE_ATTR_AXIS; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESHAPE_ATTR_NUM_AXES; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESHAPE_ATTR_FORMAT; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESHAPE_ATTR_SHAPE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESHAPE_ATTR_ALPHA; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESHAPE_ATTR_BETA; - -// Frameoworkop -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string T_IN_DATATYPE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string T_OUT_DATATYPE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_OUT_N; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_OUT_C; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_OUT_H; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_OUT_W; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_PAD_DEPTH_CONV; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_PAD_CONV; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_BEFORE_PAD; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ANN_MEAN_KEEPDIMS; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PAD_ATTR_PADDINGDS; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PAD_ATTR_CONSTANT_VALUE; - -// ConvGradFilter -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_GRAD_FILTER_OUTPUT_SHAPE; -// ConvGradInput -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_GRAD_INPUT_OUTPUT_SHAPE; - -// Rnn -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY 
const std::string RNN_MODE_STATIC; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MUTI_RNN; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CELL_MODE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CNN_RNN; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LSTM_CELL; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string GRU_CELL; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RNN_HT; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RNN_XT_HT; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RNN_BATCH_SIZE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LSTM_CELL_CLIP; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LSTM_PROJ_CLIP; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LSTM_ACTIVATE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LSTM_OUT_MAP; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LSTM_OUT_MODE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LSTM_STATE_OUT_MODE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LSTM_TIME_MAJOR; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LSTM_IS_INPUT_PRE_PROCESS; - -// Upsample -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string UPSAMPLE_ATTR_NAME_SCALE; - -// PadV2 -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PADV2_ATTR_NAME_MODE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PADV2_ATTR_NAME_PADS; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PADV2_ATTR_NAME_T; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PADV2_ATTR_NAME_PAD_FORMAT; -extern 
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PADV2_ATTR_NAME_CONST_VALUE; - -// MirrorPad -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MIRRORPAD_ATTR_NAME_MODE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MIRRORPAD_ATTR_NAME_PADS; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MIRRORPAD_ATTR_NAME_PAD_FORMAT; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MIRRORPAD_ATTR_NAME_CONST_VALUE; - -// Filler -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FILLER_TYPE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FILLER_VALUE; - -// Shufflechannel -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SHUFFLE_CHANNEL_GROUP; - -// TopKV2 -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string TOPKV2_ATTR_K; - -// Calibaration -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string STRIDE_H_INDEX; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string STRIDE_W_INDEX; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PAD_TOP_INDEX; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PAD_BOTTOM_INDEX; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PAD_RIGHT_INDEX; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PAD_LEFT_INDEX; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_IS_CONST; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_GROUP; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_DILATION_SIZE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_EPSILON; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_POOLING_MODE; -extern 
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_CLASS_NUM; - -// model -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_TARGET_TYPE; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_STREAM_NUM; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_EVENT_NUM; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_LABEL_NUM; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_MEMORY_SIZE; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_WEIGHT_SIZE; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_TASK_GEN_BASE_ADDR; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_TASK_GEN_WEIGHT_ADDR; - -// Public Attribute -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_IMPLY_TYPE; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_BYTE_SIZE; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FUSION_INFERENCE_ID; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FUSION_OPDEF; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FUSION_SCOPE; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_OPATTR; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_RELUFLAG; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_SEQLEN_INDEX; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_X_INDEX; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_CONT_INDEX; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const 
std::string ATTR_NAME_XSTATIC_INDEX; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string TARGET_TYPE_MINI; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string TARGET_TYPE_TINY; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string TARGET_TYPE_LITE; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_STREAM_LABEL; - -// L2_normalize -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string L2_NORMALIZE_ATTR_AXIS; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string L2_NORMALIZE_ATTR_EPS; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOL_PARAMA_ATTR_WINDOW; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOL_PARAMA_ATTR_CEIL_MODE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOL_PARAMA_ATTR_DATA_MODE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOL_PARAMA_ATTR_GLOBAL_POOLING; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOL_PARAMA_ATTR_NAN_OP; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOL_PARAMA_ATTR_PAD_MOD; -// HCOM -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_ROOT_RANK; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_REDUCE_TYPE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_RANK_SIZE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_REDUCTION; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_GROUP; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_SR_TAG; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_SRC_RANK; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const 
std::string HCOM_ATTR_DEST_RANK; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_FUSION; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_SHAPE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_DATA_TYPE; -// Log time stamp -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LOG_TIME_STAMP_LOGID; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LOG_TIME_STAMP_NOTIFY; -// SpaceToDepth/DepthToSpace -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_BLOCK_SIZE; - -// SparseSoftmaxCrossEntropyWithLogits -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SPARSE_SOFT_MAX_ATTR_TLABLES; - -// MaxPoolGradWithArgmax -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MAX_POOL_GRAD_OUTPUT_SHAPE; - -// AvgPoolGrad -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string AVG_POOL_GRAD_OUTPUT_SHAPE; - -// Pad -extern const std::string ATTR_PAD_FORMAT; - -// Varible -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_FORMAT; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_NAME; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_FRACTALZ_FORMAT; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_4D_FORMAT; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_5D_FORMAT; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_DATA_TYPE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_VAR_IN_NAME; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_VAR_IN_INDEX; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_VAR_OUT_INDEX; -extern 
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_SHAPE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HALF_VAR_NAME_END; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_CONTAINER; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_SHARED_NAME; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_DTYPE; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_ADDR_OFFSET; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_VAR_IN_INDEX_KEY; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_VAR_OUT_INDEX_KEY; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_SRC_VAR_NAME; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_VAR_IS_SAVE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_VAR_IS_RESTORE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_VAR_IS_BROADCAST; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string REF_VAR_SRC_VAR_NAME; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string REF_VAR_PRE_PEER_OUT_INDEX; - -// Assign -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ASSIGN_VALIDATE_SHAPE; - -// ShapeN -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SHAPEN_ATTR_N; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SHAPEN_ATTR_IN_TYPE; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SHAPEN_ATTR_OUT_TYPE; - -// Space2bacth batch2space -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCH_SPACE_ATTR_BLOCK; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCH_SPACE_ATTR_PADDING; 
-// Depth_to_space space_to_depth -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE; -// FakeQuantWithMinMaxVars -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FakeQuantWithMinMaxVars_ATTR_MAX; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FakeQuantWithMinMaxVars_ATTR_MIN; -// Mobilenet_ssd_conv_fusion -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_BOXPREDICTOR_BOXES_FUSION; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_BOXPREDICTOR_SCORES_FUSION; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_BOXPREDICTOR_FUSION_BOX_TYPE_NUM; - -// Lsh project -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LSH_PROJ_TYPE; - -// Control flow -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_ITERATORS_PER_LOOP; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_TRUE_BRANCH_STREAM; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FLOW_CTRL_NODE_FLAG; - -// GatherV2 attr def -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string GATHERV2_ATTR_NAME_TAXIS; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string GATHERV2_ATTR_NAME_TINDICES; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string GATHERV2_ATTR_NAME_TPARAMS; - -// Reshape attr def -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESHAPE_ATTR_NAME_INPUT_DESC; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESHAPE_ATTR_NAME_OUTPUT_DESC; - -// Axis attr def -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_AXIS_ORG_OP; - -// The node link with SparseSoftmaxCrossEntropyWithLogits -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const 
std::string ATTR_NAME_LINK_WITH_SPARE; - -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_NET_OUTPUT_FORMAT; -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_NET_OUTPUT_DATATYPE; -// For constant folding -extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NO_NEED_CONSTANT_FOLDING; -} // namespace domi - -#endif // INC_FRAMEWORK_COMMON_OP_ATTR_DEFINE_H_ diff --git a/inc/framework/common/op/attr_value_util.h b/inc/framework/common/op/attr_value_util.h index b55d3391..8a90cfa2 100644 --- a/inc/framework/common/op/attr_value_util.h +++ b/inc/framework/common/op/attr_value_util.h @@ -21,11 +21,17 @@ #include #include -#include "common/op/attr_define.h" #include "common/types.h" +#include "graph/debug/ge_attr_define.h" #include "proto/om.pb.h" -namespace domi { +using domi::AttrDef; +using domi::AttrDef_ListValue; +using domi::ModelDef; +using domi::NamedAttrs; +using domi::OpDef; + +namespace ge { using AttrDefMap = ::google::protobuf::Map<::std::string, ::domi::AttrDef>; using AttrDefPair = ::google::protobuf::MapPair; @@ -150,6 +156,6 @@ bool GetAttrDefListValue(const std::string &key, int idx, int32_t *value, const bool GetAttrDefListValue(const std::string &key, int idx, uint32_t *value, const AttrDefMap &attr); bool GetAttrDefListValue(const std::string &key, int idx, float *value, const AttrDefMap &attr); bool GetAttrDefListValue(const std::string &key, int idx, double *value, const AttrDefMap &attr); -} // namespace domi +} // namespace ge #endif // INC_FRAMEWORK_COMMON_OP_ATTR_VALUE_UTIL_H_ diff --git a/inc/framework/common/op/ge_op_utils.h b/inc/framework/common/op/ge_op_utils.h index b3730f16..87cf54d8 100644 --- a/inc/framework/common/op/ge_op_utils.h +++ b/inc/framework/common/op/ge_op_utils.h @@ -62,6 +62,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t FOR_LIMIT FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t 
FOR_DELTA_INPUT; FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t FOR_DATA_INPUT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const int NORMAL_TENSOR_SIZE; + class OpUtils { public: /// diff --git a/inc/framework/common/op/op_parser_util.h b/inc/framework/common/op/op_parser_util.h index e64ddc92..49b4350a 100644 --- a/inc/framework/common/op/op_parser_util.h +++ b/inc/framework/common/op/op_parser_util.h @@ -22,7 +22,7 @@ #include #include -namespace domi { +namespace ge { // general const float DEFAULT_ALPHA_VALUE = 1.0; const float DEFAULT_BETA_VALUE = 0.0; @@ -421,5 +421,5 @@ const uint32_t MULTI_SHAPE_INPUT_NUM = 2; // Shufflechannel const uint32_t SHUFFLECHANNEL_DEFAULT_GROUP = 1; -} // namespace domi +} // namespace ge #endif // INC_FRAMEWORK_COMMON_OP_OP_PARSER_UTIL_H_ diff --git a/inc/framework/common/op_types.h b/inc/framework/common/op_types.h index 8d859169..4555d5c3 100644 --- a/inc/framework/common/op_types.h +++ b/inc/framework/common/op_types.h @@ -20,7 +20,7 @@ #include #include -namespace domi { +namespace ge { class OpTypeContainer { public: static OpTypeContainer *Instance() { @@ -57,6 +57,6 @@ class OpTypeRegistrar { const OpTypeRegistrar g_##var_name##_reg(str_name); #define IS_OPTYPE_EXISTING(str_name) (OpTypeContainer::Instance()->IsExisting(str_name)) -} // namespace domi +} // namespace ge #endif // INC_FRAMEWORK_COMMON_OP_TYPES_H_ diff --git a/inc/framework/common/scope_guard.h b/inc/framework/common/scope_guard.h index 6e5c4b4a..2154648d 100644 --- a/inc/framework/common/scope_guard.h +++ b/inc/framework/common/scope_guard.h @@ -25,10 +25,10 @@ /// MAKE_GUARD([&] { Release Resource 1 }) /// Acquire Resource 2 // MAKE_GUARD([&] { Release Resource 2 }) -#define GE_MAKE_GUARD(var, callback) domi::ScopeGuard make_guard_##var(callback) +#define GE_MAKE_GUARD(var, callback) ScopeGuard make_guard_##var(callback) #define GE_DISMISS_GUARD(var) make_guard_##var.Dismiss() -namespace domi { +namespace ge { class 
ScopeGuard { public: // Noncopyable @@ -55,6 +55,6 @@ class ScopeGuard { std::function on_exit_scope_; bool dismissed_; }; -} // namespace domi +} // namespace ge #endif // INC_FRAMEWORK_COMMON_SCOPE_GUARD_H_ diff --git a/inc/framework/common/string_util.h b/inc/framework/common/string_util.h index 42d5a2cd..b74eddcf 100644 --- a/inc/framework/common/string_util.h +++ b/inc/framework/common/string_util.h @@ -25,7 +25,7 @@ #include #include -namespace domi { +namespace ge { class StringUtils { public: static std::string &Ltrim(std::string &s) { @@ -151,6 +151,6 @@ class StringUtils { return ret > 0 ? buffer : ""; } }; -} // namespace domi +} // namespace ge #endif // INC_FRAMEWORK_COMMON_STRING_UTIL_H_ diff --git a/inc/framework/common/types.h b/inc/framework/common/types.h index d98f784c..7bb8d5e7 100644 --- a/inc/framework/common/types.h +++ b/inc/framework/common/types.h @@ -26,6 +26,7 @@ #include #include #include + #include "framework/common/fmk_error_codes.h" #include "framework/common/fmk_types.h" #include "framework/common/op_types.h" @@ -46,9 +47,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_A FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_STATUS; FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_LAYER; FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_FILE_PATH; -} // namespace ge +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_MODE; -namespace domi { // Supported public properties name FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROP_OME_START_TIME; // Start time FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROP_OME_DUMP_PATH; // Dump path @@ -68,14 +68,6 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROFIL FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::map PROFILE_COMPONENT_MAP; 
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROFILE_CONFIG; -/// @brief Data structure definition related to task sinking -/// Build model -enum BuildMode { - GEN_TASK_WITHOUT_L2FUSION = 3, // Carrying task data (L2 convergence function disabled) - GEN_TASK_WITHOUT_FUSION = 4, // Carrying task data (all convergence functions disabled) - GEN_TASK_WITH_FUSION = 5 // Carrying task data (with UB/L1/L2 enabled for all convergence functions) -}; - FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string MODEL_ATTR_TASKS; FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string MODEL_ATTR_TASK_GEN_BASE_ADDR; FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string MODEL_ATTR_TASK_GEN_WEIGHT_ADDR; @@ -341,8 +333,9 @@ REGISTER_OPTYPE_DECLARE(END, "End"); REGISTER_OPTYPE_DECLARE(BASICLSTMCELL, "BasicLSTMCell"); REGISTER_OPTYPE_DECLARE(GETNEXT, "GetNext"); REGISTER_OPTYPE_DECLARE(INITDATA, "InitData"); +REGISTER_OPTYPE_DECLARE(TRANSSHAPE, "TransShape") -/***************ANN dedicated operator *************************/ +// ANN dedicated operator REGISTER_OPTYPE_DECLARE(ANN_MEAN, "AnnMean"); REGISTER_OPTYPE_DECLARE(ANN_CONVOLUTION, "AnnConvolution"); REGISTER_OPTYPE_DECLARE(ANN_DEPCONVOLUTION, "AnnDepthConv"); @@ -359,7 +352,7 @@ REGISTER_OPTYPE_DECLARE(ANN_QUANTIZE, "AnnQuant"); REGISTER_OPTYPE_DECLARE(ANN_PAD, "AnnPad"); REGISTER_OPTYPE_DECLARE(ANN_RESIZE_BILINEAR, "AnnResizeBilinear"); -/********************Training operator ***********************/ +// Training operator REGISTER_OPTYPE_DECLARE(GATHERV2, "GatherV2"); REGISTER_OPTYPE_DECLARE(CONVGRADFILTER, "Conv2DBackpropFilter"); REGISTER_OPTYPE_DECLARE(CONV2D, "Conv2D"); @@ -438,11 +431,13 @@ REGISTER_OPTYPE_DECLARE(HCOMRECEIVE, "HcomReceive"); REGISTER_OPTYPE_DECLARE(VARASSIGN, "VarAssign"); REGISTER_OPTYPE_DECLARE(VARISINITIALIZEDOP, "VarIsInitializedOp"); REGISTER_OPTYPE_DECLARE(LogTimeStamp, "LogTimeStamp"); 
+REGISTER_OPTYPE_DECLARE(PARALLELCONCATSTART, "_ParallelConcatStart"); REGISTER_OPTYPE_DECLARE(CONSTANTOP, "Constant"); REGISTER_OPTYPE_DECLARE(STREAMSWITCH, "StreamSwitch"); REGISTER_OPTYPE_DECLARE(STREAMSWITCHN, "StreamSwitchN"); REGISTER_OPTYPE_DECLARE(STREAMACTIVE, "StreamActive"); REGISTER_OPTYPE_DECLARE(MEMCPYASYNC, "MemcpyAsync"); +REGISTER_OPTYPE_DECLARE(MEMCPYADDRASYNC, "MemcpyAddrAsync"); REGISTER_OPTYPE_DECLARE(STREAMMERGE, "StreamMerge"); REGISTER_OPTYPE_DECLARE(ENDGRAPH, "EndGraph"); REGISTER_OPTYPE_DECLARE(SEND, "Send"); @@ -450,6 +445,7 @@ REGISTER_OPTYPE_DECLARE(RECV, "Recv"); REGISTER_OPTYPE_DECLARE(LABELSET, "LabelSet"); REGISTER_OPTYPE_DECLARE(LABELGOTO, "LabelGoto"); +REGISTER_OPTYPE_DECLARE(LABELGOTOEX, "LabelGotoEx"); REGISTER_OPTYPE_DECLARE(LABELSWITCH, "LabelSwitch"); REGISTER_OPTYPE_DECLARE(LABELSWITCHBYINDEX, "LabelSwitchByIndex"); @@ -828,9 +824,6 @@ static constexpr int32_t PARTITION_TYPE_TASK_INFO = 2; // number of partitions in the current model static constexpr uint32_t PARTITION_SIZE = 4; -#define SIZE_OF_MODEL_PARTITION_TABLE(table) \ - (sizeof(domi::ModelPartitionTable) + sizeof(domi::ModelPartitionMemInfo) * (table).num) - enum ModelPartitionType { MODEL_DEF = 0, WEIGHTS_DATA, TASK_INFO, TBE_KERNELS }; struct ModelPartitionMemInfo { @@ -844,6 +837,8 @@ struct ModelPartitionTable { ModelPartitionMemInfo partition[0]; }; +#define SIZE_OF_MODEL_PARTITION_TABLE(table) (sizeof(ModelPartitionTable) + sizeof(ModelPartitionMemInfo) * (table).num) + static constexpr int32_t PTHREAD_CREAT_SUCCESS = 0; // pthread_creat success // Filter format @@ -975,8 +970,8 @@ typedef enum tagDomiNanPropagation { // mode of cropandresize typedef enum tagDomiCropAndResizeMode { - DOMI_RESIZE_METHOD_BILINEAR = 0, /**< resize bilinear */ - DOMI_RESIZE_METHOD_NEAREST, /**< resize nearest */ + DOMI_RESIZE_METHOD_BILINEAR = 0, // resize bilinear + DOMI_RESIZE_METHOD_NEAREST, // resize nearest DOMI_RESIZE_RESERVED } domiCropAndResizeMode_t; @@ -1063,6 +1058,15 
@@ struct BasicInfo { uint32_t total_size; // total memory size }; #pragma pack() // Cancels single-byte alignment +} // namespace ge + +namespace domi { +/// @brief Data structure definition related to task sinking +enum BuildMode { + GEN_TASK_WITHOUT_L2FUSION = 3, // Carrying task data (L2 convergence function disabled) + GEN_TASK_WITHOUT_FUSION = 4, // Carrying task data (all convergence functions disabled) + GEN_TASK_WITH_FUSION = 5 // Carrying task data (with UB/L1/L2 enabled for all convergence functions) +}; } // namespace domi #endif // INC_FRAMEWORK_COMMON_TYPES_H_ diff --git a/inc/framework/common/util.h b/inc/framework/common/util.h index 6447340f..952ce955 100644 --- a/inc/framework/common/util.h +++ b/inc/framework/common/util.h @@ -30,12 +30,12 @@ #include "framework/common/ge_inner_error_codes.h" #include "mmpa/mmpa_api.h" -#define GE_CHECK_POSITIVE_SIZE_RANGE(size) \ - do { \ - if (size <= 0) { \ - DOMI_LOGE(param[#size] is not a positive number); \ - return PARAM_INVALID; \ - } \ +#define GE_CHECK_POSITIVE_SIZE_RANGE(size) \ + do { \ + if (size <= 0) { \ + DOMI_LOGE("param[%s] is not a positive number", #size); \ + return PARAM_INVALID; \ + } \ } while (0) #define CHECK_FALSE_EXEC(expr, exec_expr, ...) \ @@ -44,7 +44,7 @@ if (!b) { \ exec_expr; \ } \ - }; + } // new ge marco // Encapsulate common resource releases @@ -113,101 +113,101 @@ } while (0) // Check if the parameter is null. If yes, return PARAM_INVALID and record the error -#define GE_CHECK_NOTNULL(val) \ - do { \ - if (val == nullptr) { \ - DOMI_LOGE(param[#val] must not be null.); \ - return ge::PARAM_INVALID; \ - } \ +#define GE_CHECK_NOTNULL(val) \ + do { \ + if (val == nullptr) { \ + DOMI_LOGE("param[%s] must not be null.", #val); \ + return ge::PARAM_INVALID; \ + } \ } while (0) // Check if the parameter is null. 
If yes, just return and record the error -#define GE_CHECK_NOTNULL_JUST_RETURN(val) \ - do { \ - if (val == nullptr) { \ - DOMI_LOGE(param[#val] must not be null.); \ - return; \ - } \ +#define GE_CHECK_NOTNULL_JUST_RETURN(val) \ + do { \ + if (val == nullptr) { \ + DOMI_LOGE("param[%s] must not be null.", #val); \ + return; \ + } \ } while (0) // Check whether the parameter is null. If so, execute the exec_expr expression and record the error log -#define GE_CHECK_NOTNULL_EXEC(val, exec_expr) \ - do { \ - if (val == nullptr) { \ - DOMI_LOGE(param[#val] must not be null.); \ - exec_expr; \ - } \ +#define GE_CHECK_NOTNULL_EXEC(val, exec_expr) \ + do { \ + if (val == nullptr) { \ + DOMI_LOGE("param[%s] must not be null.", #val); \ + exec_expr; \ + } \ } while (0) // Check whether the parameter is null. If yes, return directly and record the error log -#define GE_RT_VOID_CHECK_NOTNULL(val) \ - do { \ - if (val == nullptr) { \ - DOMI_LOGE(param[#val] must not be null.); \ - return; \ - } \ +#define GE_RT_VOID_CHECK_NOTNULL(val) \ + do { \ + if (val == nullptr) { \ + DOMI_LOGE("param[%s] must not be null.", #val); \ + return; \ + } \ } while (0) // Check if the parameter is null. 
If yes, return false and record the error log -#define GE_RT_FALSE_CHECK_NOTNULL(val) \ - do { \ - if (val == nullptr) { \ - DOMI_LOGE(param[#val] must not be null.); \ - return false; \ - } \ +#define GE_RT_FALSE_CHECK_NOTNULL(val) \ + do { \ + if (val == nullptr) { \ + DOMI_LOGE("param[%s] must not be null.", #val); \ + return false; \ + } \ } while (0) // Check if the parameter is out of bounds -#define GE_CHECK_SIZE(size) \ - do { \ - if (size == 0) { \ - DOMI_LOGE(param[#size] is out of range); \ - return ge::PARAM_INVALID; \ - } \ +#define GE_CHECK_SIZE(size) \ + do { \ + if (size == 0) { \ + DOMI_LOGE("param[%s] is out of range", #size); \ + return ge::PARAM_INVALID; \ + } \ } while (0) // Check if the container is empty -#define GE_CHECK_VECTOR_NOT_EMPTY(vector) \ - do { \ - if (vector.empty()) { \ - DOMI_LOGE(param[#vector] is empty !); \ - return ge::FAILED; \ - } \ +#define GE_CHECK_VECTOR_NOT_EMPTY(vector) \ + do { \ + if (vector.empty()) { \ + DOMI_LOGE("param[%s] is empty!", #vector); \ + return ge::FAILED; \ + } \ } while (0) // Check if the value on the left is greater than or equal to the value on the right -#define GE_CHECK_GE(lhs, rhs) \ - do { \ - if (lhs < rhs) { \ - DOMI_LOGE(param[#lhs] is less than[#rhs]); \ - return ge::PARAM_INVALID; \ - } \ +#define GE_CHECK_GE(lhs, rhs) \ + do { \ + if (lhs < rhs) { \ + DOMI_LOGE("param[%s] is less than[%s]", #lhs, #rhs); \ + return ge::PARAM_INVALID; \ + } \ } while (0) // Check if the value on the left is less than or equal to the value on the right -#define GE_CHECK_LE(lhs, rhs) \ - do { \ - if (lhs > rhs) { \ - DOMI_LOGE(param[#lhs] is greater than[#rhs]); \ - return ge::PARAM_INVALID; \ - } \ +#define GE_CHECK_LE(lhs, rhs) \ + do { \ + if (lhs > rhs) { \ + DOMI_LOGE("param[%s] is greater than[%s]", #lhs, #rhs); \ + return ge::PARAM_INVALID; \ + } \ } while (0) #define GE_DELETE_NEW_SINGLE(var) \ - { \ + do { \ if (var != nullptr) { \ delete var; \ var = nullptr; \ } \ - }; + } while (0) #define 
GE_DELETE_NEW_ARRAY(var) \ - { \ + do { \ if (var != nullptr) { \ delete[] var; \ var = nullptr; \ } \ - }; + } while (0) /** * @ingroup domi_common @@ -220,7 +220,7 @@ static constexpr int32_t OM_PROTO_VERSION = 2; */ #define CEIL(N, n) (((N) + (n)-1) / (n)) -namespace domi { +namespace ge { using google::protobuf::Message; /// @@ -373,7 +373,7 @@ std::string RealPath(const char *path); /// @param [in] file_path path of input file /// @param [out] result /// -bool CheckInputPathValid(const std::string &file_path); +bool CheckInputPathValid(const std::string &file_path, const std::string &atc_param = ""); /// /// @ingroup domi_common @@ -381,7 +381,7 @@ bool CheckInputPathValid(const std::string &file_path); /// @param [in] file_path path of output file /// @param [out] result /// -bool CheckOutputPathValid(const std::string &file_path); +bool CheckOutputPathValid(const std::string &file_path, const std::string &atc_param = ""); /// /// @ingroup domi_common @@ -390,6 +390,6 @@ bool CheckOutputPathValid(const std::string &file_path); /// @param [out] result /// bool ValidateStr(const std::string &filePath, const std::string &mode); -} // namespace domi +} // namespace ge #endif // INC_FRAMEWORK_COMMON_UTIL_H_ diff --git a/inc/framework/ge_runtime/task_info.h b/inc/framework/ge_runtime/task_info.h old mode 100644 new mode 100755 diff --git a/inc/framework/generator/ge_generator.h b/inc/framework/generator/ge_generator.h index 9bc13f24..a18e730d 100644 --- a/inc/framework/generator/ge_generator.h +++ b/inc/framework/generator/ge_generator.h @@ -47,6 +47,8 @@ class GeGenerator { Status GenerateOnlineModel(const Graph &graph, const vector &inputs, ge::ModelBufferData &model); + Status GenerateInfershapeGraph(const Graph &graph); + /// /// @ingroup ge /// @brief: Build single OP in Model. 
diff --git a/inc/framework/memory/memory_assigner.h b/inc/framework/memory/memory_assigner.h index 34c58d26..bbec014b 100644 --- a/inc/framework/memory/memory_assigner.h +++ b/inc/framework/memory/memory_assigner.h @@ -33,7 +33,7 @@ class MemoryAssigner { MemoryAssigner &operator=(const MemoryAssigner &) = delete; - Status AssignMemory(bool is_loop_graph, size_t &mem_offset); + Status AssignMemory(bool is_loop_graph, size_t &mem_offset, size_t &zero_copy_mem_size); private: ge::ComputeGraphPtr compute_graph_; diff --git a/inc/framework/omg/omg_inner_types.h b/inc/framework/omg/omg_inner_types.h index 925aa9dd..547fbe2f 100644 --- a/inc/framework/omg/omg_inner_types.h +++ b/inc/framework/omg/omg_inner_types.h @@ -28,21 +28,27 @@ #include "framework/common/types.h" #include "register/register_fmk_types.h" +using domi::DOMI_TENSOR_ND; +using domi::DOMI_TENSOR_RESERVED; +using domi::domiTensorFormat_t; +using domi::FMK_TYPE_RESERVED; +using domi::FrameworkType; using std::map; using std::string; using std::unordered_map; using std::vector; -namespace domi { +namespace ge { /** * @ingroup domi_omg * @brief run model */ enum RunMode { - GEN_OM_MODEL = 0, // generate offline model file - MODEL_TO_JSON = 1, // convert to JSON file - ONLY_PRE_CHECK = 3, // only for pre-check - PBTXT_TO_JSON = 5 // pbtxt to json + GEN_OM_MODEL = 0, // generate offline model file + MODEL_TO_JSON = 1, // convert to JSON file + MODEL_TO_JSON_WITH_SHAPE = 2, // convert to json file with shape + ONLY_PRE_CHECK = 3, // only for pre-check + PBTXT_TO_JSON = 5 // pbtxt to json }; /// @@ -93,7 +99,7 @@ struct OmgContext { std::string ddk_version; // preferential format used by the entire network domiTensorFormat_t net_format = DOMI_TENSOR_RESERVED; - FrameworkType type = FMK_TYPE_RESERVED; + domi::FrameworkType type = domi::FMK_TYPE_RESERVED; RunMode run_mode = ONLY_PRE_CHECK; bool train_flag = false; // whether to use FP16 high precision @@ -102,23 +108,25 @@ struct OmgContext { std::string 
output_type; // Save the name of the entire network: Some special operators are used to determine a network. Some operators in the - // network require special processing based on the specific network. - // e.g:faster-rcnn, the FirstStageProcessor module is determined as the Faster-R-CNN network based on the scope - // fusion. Then, the conv+reshape operators in the FirstStageBoxPredictor/BoxEncodingPredictor scope are combined. The - // convolution kernel rearrangement reshape operator needs to be deleted for the convolution kernel. + // network require special processing based on the specific network. e.g:faster-rcnn, the FirstStageProcessor module + // is determined as the Faster-R-CNN network based on the scope fusion. Then, the conv+reshape operators in the + // FirstStageBoxPredictor/BoxEncodingPredictor scope are combined. The convolution kernel rearrangement reshape + // operator needs to be deleted for the convolution kernel. std::string net_name; // Whether to use dynamic batch size or dynamic image size bool is_dynamic_input = false; std::string dynamic_batch_size; std::string dynamic_image_size; }; +} // namespace ge +namespace domi { /** * @ingroup domi_omg * @brief get OMG context * @return OmgContext context */ -OmgContext &GetContext(); +ge::OmgContext &GetContext(); struct TEBinInfo { // It is obsolete. It will be automatically obtained from the binfilename field of the JSON file later. 
diff --git a/inc/framework/omg/version.h b/inc/framework/omg/version.h index 300f32eb..ac649d83 100644 --- a/inc/framework/omg/version.h +++ b/inc/framework/omg/version.h @@ -26,7 +26,7 @@ #include "common/string_util.h" #include "framework/common/debug/ge_log.h" -namespace domi { +namespace ge { class PlatformVersionManager { public: PlatformVersionManager() = delete; @@ -40,6 +40,6 @@ class PlatformVersionManager { return SUCCESS; } }; // class PlatformManager -} // namespace domi +} // namespace ge #endif // INC_FRAMEWORK_OMG_VERSION_H_ diff --git a/inc/graph/attr_value_serializable.h b/inc/graph/attr_value_serializable.h index 2b2a7733..a69beb96 100644 --- a/inc/graph/attr_value_serializable.h +++ b/inc/graph/attr_value_serializable.h @@ -86,16 +86,16 @@ class _GeSerializable { } template - static void SaveItem(GeAttrValue::NamedAttrs &namedAttrs, string itemName, T &item, Args &... args) { + static void SaveItem(GeAttrValue::NAMED_ATTRS &namedAttrs, string itemName, T &item, Args &... args) { GeAttrValue itemVal = SaveItemAsAttrValue(item); (void)namedAttrs.SetAttr(itemName, itemVal); SaveItem(namedAttrs, args...); } - static void SaveItem(GeAttrValue::NamedAttrs &namedAttrs __attribute__((__unused__))) {} + static void SaveItem(GeAttrValue::NAMED_ATTRS &namedAttrs __attribute__((__unused__))) {} template - static graphStatus LoadItem(GeAttrValue::NamedAttrs &namedAttrs, string itemName, T &item, Args &... args) { + static graphStatus LoadItem(GeAttrValue::NAMED_ATTRS &namedAttrs, string itemName, T &item, Args &... 
args) { auto itemVal = namedAttrs.GetItem(itemName); auto status = LoadItemFromAttrValue(item, itemVal); if (status != GRAPH_SUCCESS) { @@ -104,7 +104,9 @@ class _GeSerializable { return LoadItem(namedAttrs, args...); } - static graphStatus LoadItem(GeAttrValue::NamedAttrs &namedAttrs __attribute__((__unused__))) { return GRAPH_SUCCESS; } + static graphStatus LoadItem(GeAttrValue::NAMED_ATTRS &namedAttrs __attribute__((__unused__))) { + return GRAPH_SUCCESS; + } }; #define _GE_FI(a) #a, a @@ -171,13 +173,13 @@ class _GeSerializable { \ private: \ ge::graphStatus Save(GeAttrValue &ar) const { \ - GeAttrValue::NamedAttrs named_attrs; \ + GeAttrValue::NAMED_ATTRS named_attrs; \ _GeSerializable::SaveItem(named_attrs, _GE_INVOKE_VAR_MACRO(__VA_ARGS__)); \ - return ar.SetValue(named_attrs); \ + return ar.SetValue(named_attrs); \ } \ ge::graphStatus Load(const GeAttrValue &ar) { \ - GeAttrValue::NamedAttrs named_attrs; \ - ge::graphStatus status = ar.GetValue(named_attrs); \ + GeAttrValue::NAMED_ATTRS named_attrs; \ + ge::graphStatus status = ar.GetValue(named_attrs); \ if (status != GRAPH_SUCCESS) { \ return status; \ } \ diff --git a/inc/graph/compute_graph.h b/inc/graph/compute_graph.h index c63494f8..dbde46f5 100644 --- a/inc/graph/compute_graph.h +++ b/inc/graph/compute_graph.h @@ -83,6 +83,7 @@ class ComputeGraph : public std::enable_shared_from_this, public A // AddNode with NodePtr NodePtr AddNode(NodePtr node); NodePtr AddNode(OpDescPtr op); + NodePtr AddNode(OpDescPtr op, int64_t id); // for unserialize. 
NodePtr AddNodeFront(NodePtr node); NodePtr AddNodeFront(const OpDescPtr &op); NodePtr AddInputNode(NodePtr node); @@ -236,8 +237,9 @@ class ComputeGraph : public std::enable_shared_from_this, public A std::deque &stack); graphStatus CollectBreadthOutNode(const NodePtr &node, std::map &map_in_edge_num, std::map &breadth_node_map); - graphStatus TopologicalSortingSubgraph(); + graphStatus TopologicalSortingGraph(); graphStatus SortNodes(std::vector &stack, std::map &mapInEdgeNum); + Vistor AllGraphNodes(std::vector> &subgraphs) const; size_t GetInEdgeSize(const NodePtr &node); size_t GetOutEdgeSize(const NodePtr &node); graphStatus RemoveExtraOutEdge(const NodePtr &node); diff --git a/inc/graph/debug/ge_attr_define.h b/inc/graph/debug/ge_attr_define.h index ed992a62..23bb114a 100644 --- a/inc/graph/debug/ge_attr_define.h +++ b/inc/graph/debug/ge_attr_define.h @@ -32,6 +32,12 @@ namespace ge { #define GE_FUNC_DEV_VISIBILITY #endif // Public attribute +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_UNKNOWN_SHAPE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_UNKNOWN_SHAPE_TYPE; + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NAME; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TYPE; @@ -58,6 +64,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BIAS_TERM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_HAS_BIAS_VALUE; + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PAD; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PADS; @@ -74,8 +82,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern 
const std::string ATTR_NAM GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CEIL_MODE; -// GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string -// ATTR_NAME_WEIGHTS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STRIDE_SIZE; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RELUMODE; @@ -123,6 +130,13 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NAN_OPT; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AIPP; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string NEW_AIPP_CONV_OP; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SESSION_GRAPH_ID; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MULTISHAPE_BATCHLIST; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MULTISHAPE_BATCHLIST_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_BATCH_NUM; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_FORMAT; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_FORMAT; @@ -140,12 +154,24 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PRED_PERMUTE_DELETED; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IGNORE_PRED_FORMAT; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHTS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHTS_DATA; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BROACAST_REAL_DIM_CNT; 
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DIM_ALIGN; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RTSWITCH_RECV_EVENT_ID; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_START; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHTS_DATA; + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SESSION_GRAPH_ID; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_BATCH_NUM; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_START; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_LABEL; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG; + // to be deleted GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_TO_BE_DELETED; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_RESHAPE_FUSION; @@ -158,15 +184,15 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_MBOX GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_MBOX_FUSION_BOX_TYPE_NUM; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_RESHAPE_SLICE_CONCAT_FUSION; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIORBOX_CONCAT; - GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_LOC_FUSION; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const 
std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_CONF_FUSION; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_FUSION_BOX_TYPE_NUM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION; + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIORBOX_CONCAT; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string NEED_INFER; // _Arg @@ -255,7 +281,29 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNOR GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_ESTIMATED_VARIANCE; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_SCALE; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_BIAS; - +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_DATA_FORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_IS_TRAINING; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_IS_TRAINING_FUSION; + +// Huberloss +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HUBER_LOSS_ATTR_DELTA; + +// SSDRealDivTileMul +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_REAL_DIV_TILE_MUL_ATTR_TILE_PARA; + +// SSDSumMulRealDivMean +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_REDUCTION_INDICES; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_AXIS; 
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_MEAN_PARA; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_HAS_SUM; +/// ConcatFive2Four +/// ConcatFour2Five +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_CLASS_NUM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_FEATURE_MAP_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TRANS_FOR_LOSS_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOX_TYPE_NUM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_FEATURE_MAP_HIGH; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_FEATURE_MAP_WIDTH; // Scale GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_ATTR_SCALE; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_ATTR_BIAS; @@ -292,7 +340,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_AT GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_ATTR_NAME_OUTPUT_FORMAT; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_ATTR_NAME_OUTPUT_TYPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string L2_NORMALIZE_ATTR_EPS; // Roipooling @@ -305,6 +352,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIPOOLI // DetectionOutput GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_NUM_CLASSES; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_OCR_NUM_CLASSES; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_NMS_THRESHOLD; GE_FUNC_DEV_VISIBILITY 
GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_TOP_K; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_CONFIDENCE_THRESHOLD; @@ -363,6 +411,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SOFTMAX_ // Permute GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_ATTR_ORDER; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_ATTR_PERM; // SSD Normalize GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSDNORMALIZE_ATTR_ACCROSS_SPATIAL; @@ -403,9 +452,15 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_AT GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_ATTR_NAME_SCALE; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_ATTR_NAME_SHIFT; +// Log +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_ATTR_NAME_SCALE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_ATTR_NAME_SHIFT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_ATTR_NAME_BASE; // Pack GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PACK_ATTR_NAME_NUM; +// Dynamic stitch +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DYNAMIC_STITCH_ATTR_NAME_NUM; // Unpack GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UNPACK_ATTR_NAME_NUM; // Gathernd @@ -414,8 +469,16 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERND // Argmax GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_TOPK; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_REDUCESIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_REDUCESTRIDE; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern 
const std::string ARGMAX_ATTR_NAME_OUTMAX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_AXIS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_AXISTYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_KEEPDIMS; +// Upsample +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UPSAMPLE_ATTR_NAME_SCALE_H; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UPSAMPLE_ATTR_NAME_SCALE_W; // Relu GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NEGATIVE_SLOPE; @@ -486,6 +549,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REORG_AT GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MERGE_DEAD_INDEX; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MERGE_PRENODE_FLAG; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TO_BE_OUTPUT; +static const std::string NOT_NET_OUTPUT = "not_net_output"; // ENTER GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ENTER_ATTR_FRAME_NAME; @@ -511,6 +575,9 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_B GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_ALPHA; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_BETA; +// RetinaNet +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RETINANET_FILTER_BACKGROUND_TRUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RETINANET_ANCHOR_FUSION; // MatMul GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MATMUL_TRANSPOSE_X; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MATMUL_TRANSPOSE_W; @@ -559,10 +626,30 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const 
std::string GRU_CELL GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_HT; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_XT_HT; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_BATCH_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_CELL_CLIP; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_PROJ_CLIP; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_ACTIVATE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_OUT_MAP; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_OUT_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_STATE_OUT_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_TIME_MAJOR; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_IS_INPUT_PRE_PROCESS; // Upsample GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UPSAMPLE_ATTR_NAME_SCALE; +// PadV2 +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_PADS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_T; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_PAD_FORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_CONST_VALUE; + +// MirrorPad +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_PADS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_PAD_FORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string 
MIRRORPAD_ATTR_NAME_CONST_VALUE; // Filler GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FILLER_TYPE; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FILLER_VALUE; @@ -583,36 +670,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PAD_LEFT GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_ALGO_ATTR; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_TYPE_ATTR; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_VALUE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_OFFSET; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_VALUE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_OFFSET; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_VALUE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_OFFSET; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_VALUE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_OFFSET; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_VALUE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_OFFSET; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_TYPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_OFFSET; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_VALUE; 
-GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_OFFSET; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_VALUE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_OFFSET; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_VALUE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_OFFSET; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_VALUE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_OFFSET; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_VALUE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_OFFSET; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_VALUE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_OFFSET; - GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_CONST; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_GROUP; @@ -627,24 +684,20 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MOD GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_EVENT_NUM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_HUGE_STREAM_LIST; + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_LABEL_NUM; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_MEMORY_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string 
ATTR_MODEL_ZERO_COPY_MEMORY_SIZE; + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_WEIGHT_SIZE; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_BASE_ADDR; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_WEIGHT_ADDR; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_VAR_ADDR; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_VAR_SIZE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_INDEX_OP_NAME; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_CORE_TYPE; - // Public attribute GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IMPLY_TYPE; @@ -678,6 +731,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TARGET_T GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CONTINUOUS_INPUT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CONTINUOUS_INPUT_ALLOC; + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CONTINUOUS_OUTPUT; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_REFERENCE; @@ -696,6 +751,161 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ATOMIC_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_VAR_ADDR; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_LABEL; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CONTINUOUS_STREAM_LABEL; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_VAR_SIZE; + +GE_FUNC_DEV_VISIBILITY 
GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_INDEX_OP_NAME; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_CORE_TYPE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_OFFSET; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_OFFSET; + 
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_OFFSET; + +// L2_normalize +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string L2_NORMALIZE_ATTR_AXIS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string L2_NORMALIZE_ATTR_EPS; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_WINDOW; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_CEIL_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_DATA_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_GLOBAL_POOLING; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_NAN_OP; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_PAD_MOD; +// HCOM +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_ROOT_RANK; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_REDUCE_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const 
std::string HCOM_ATTR_RANK_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_REDUCTION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_GROUP; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SR_TAG; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SRC_RANK; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_DEST_RANK; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_FUSION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SHAPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_DATA_TYPE; + +// Log time stamp +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_TIME_STAMP_LOGID; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_TIME_STAMP_NOTIFY; +// SpaceToDepth/DepthToSpace +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BLOCK_SIZE; + +// SparseSoftmaxCrossEntropyWithLogits +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPARSE_SOFT_MAX_ATTR_TLABLES; + +// MaxPoolGradWithArgmax +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MAX_POOL_GRAD_OUTPUT_SHAPE; + +// AvgPoolGrad +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string AVG_POOL_GRAD_OUTPUT_SHAPE; + +// Varible +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_FORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_FRACTALZ_FORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_4D_FORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_5D_FORMAT; 
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_DATA_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IN_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IN_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_OUT_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SHAPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HALF_VAR_NAME_END; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_CONTAINER; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SHARED_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_DTYPE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_ADDR_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IN_INDEX_KEY; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_OUT_INDEX_KEY; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SRC_VAR_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_SAVE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_RESTORE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_BROADCAST; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_SRC_VAR_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_PRE_PEER_OUT_INDEX; + +// Assign +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ASSIGN_VALIDATE_SHAPE; + +// ShapeN +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHAPEN_ATTR_N; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const 
std::string SHAPEN_ATTR_IN_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHAPEN_ATTR_OUT_TYPE; + +// Space2bacth batch2space +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCH_SPACE_ATTR_BLOCK; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCH_SPACE_ATTR_PADDING; +// Depth_to_space space_to_depth +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE; +// FakeQuantWithMinMaxVars +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FakeQuantWithMinMaxVars_ATTR_MAX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FakeQuantWithMinMaxVars_ATTR_MIN; +// Mobilenet_ssd_conv_fusion +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOXPREDICTOR_BOXES_FUSION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOXPREDICTOR_SCORES_FUSION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOXPREDICTOR_FUSION_BOX_TYPE_NUM; + +// Lsh project +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSH_PROJ_TYPE; + +// Control flow +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ITERATORS_PER_LOOP; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TRUE_BRANCH_STREAM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FLOW_CTRL_NODE_FLAG; + +// GatherV2 attr def +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERV2_ATTR_NAME_TAXIS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERV2_ATTR_NAME_TINDICES; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERV2_ATTR_NAME_TPARAMS; + +// Reshape attr def +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_NAME_INPUT_DESC; +GE_FUNC_DEV_VISIBILITY 
GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_NAME_OUTPUT_DESC; + +// Axis attr def +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AXIS_ORG_OP; +// The node link with SparseSoftmaxCrossEntropyWithLogits +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LINK_WITH_SPARE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NET_OUTPUT_FORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NET_OUTPUT_DATATYPE; +// For constant folding +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NO_NEED_CONSTANT_FOLDING; + // Used for mark the active label list to find stream of activated node GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ACTIVE_LABEL_LIST; @@ -708,7 +918,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM // Control flow GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_SWITCH_COND; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TRUE_BRANCH_STREAM; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ACTIVE_STREAM_LIST; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCHN_PRED_VALUE; @@ -722,6 +931,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM // Function Op GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PARENT_NODE_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PARENT_CONST_TYPE; // Used for mark the active node is for loop, type:bool GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_LOOP_ACTIVE; @@ -752,6 +962,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NEE // For mutil-batch 
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INSERT_BY_MBATCH; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MBATCH_ORIGIN_INPUT_DIMS; // For inserted op GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INSERTED_BY_GE; @@ -772,6 +983,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM // used for l1 fusion and other fusion in future GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_GROUP_ID; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_GROUP_KEY; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_GROUP_KEY; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_VIRTUAL_OP; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_GROUP_TYPE; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_MEM_TYPE_LIST; @@ -782,10 +994,44 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_FOR_L1_FUSION; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_N_BATCH_SPILT; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NO_TASK_AND_DUMP_NEEDED; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_BUFFER_FUSION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L2_FUSION_GROUP_ID; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_FOR_L2_FUSION; + +// functional ops attr +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WHILE_COND; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string 
ATTR_NAME_WHILE_BODY; // used for label switch GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LABEL_SWITCH_INDEX; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LABEL_SWITCH_LIST; + +// Varible +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_SRC_VAR_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SRC_VAR_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_PRE_PEER_OUT_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_BROADCAST; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_RESTORE; + +// HCOM +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_ROOT_RANK; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_REDUCE_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_RANK_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SHAPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_DATA_TYPE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_DATATYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_DATATYPE; +// used for LX tiling +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_L1_SPACE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_TYPE_LIST; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_VALID_INPUT_SHAPE_LIST_LIST; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_VALID_OUTPUT_SHAPE_LIST_LIST; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SLICE_INPUT_OFFSET_LIST_LIST; 
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SLICE_OUTPUT_OFFSET_LIST_LIST; + +// Dynamic stitch +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DYNAMIC_STITCH_ATTR_NAME_NUM; } // namespace ge #endif // INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_ diff --git a/inc/graph/detail/model_serialize_imp.h b/inc/graph/detail/model_serialize_imp.h index 1d50577c..b8b3916a 100644 --- a/inc/graph/detail/model_serialize_imp.h +++ b/inc/graph/detail/model_serialize_imp.h @@ -22,7 +22,7 @@ #include #include #include "graph/anchor.h" -#include "detail/attributes_holder.h" +#include "graph/detail/attributes_holder.h" #include "graph/ge_tensor.h" #include "graph/graph.h" #include "graph/node.h" @@ -77,6 +77,8 @@ class ModelSerializeImp { void SetProtobufOwner(const ProtoMsgOwner &bufferProtobufOnwer) { protobuf_owner_ = bufferProtobufOnwer; } private: + bool RebuildOwnership(ComputeGraphPtr &compute_graph, std::map &subgraphs); + std::vector graph_input_node_names_; std::vector graph_output_node_names_; std::vector node_input_node_names_; diff --git a/inc/graph/ge_attr_value.h b/inc/graph/ge_attr_value.h index c5186fd1..b665beba 100644 --- a/inc/graph/ge_attr_value.h +++ b/inc/graph/ge_attr_value.h @@ -43,30 +43,31 @@ using ComputeGraphPtr = std::shared_ptr; using ConstComputeGraphPtr = std::shared_ptr; class GeTensorDesc; - +class GeAttrValue; class GeAttrValueImp; -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeAttrValue { +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NamedAttrs : public AttrHolder { public: - class NamedAttrs : public AttrHolder { - public: - NamedAttrs(); - virtual ~NamedAttrs() = default; - void SetName(const std::string &name); - string GetName() const; - GeAttrValue GetItem(const string &key) const; - - protected: - ProtoAttrMapHelper MutableAttrMap() override; - ConstProtoAttrMapHelper GetAttrMap() const override; - - private: - // Create namedAttrs from protobuf obj - NamedAttrs(const 
ProtoMsgOwner &owner, proto::NamedAttrs *protoMsg); - GeIrProtoHelper named_attrs_; - friend class GeAttrValueImp; - }; + NamedAttrs(); + virtual ~NamedAttrs() = default; + void SetName(const std::string &name); + string GetName() const; + GeAttrValue GetItem(const string &key) const; + + protected: + ProtoAttrMapHelper MutableAttrMap() override; + ConstProtoAttrMapHelper GetAttrMap() const override; + + private: + // Create namedAttrs from protobuf obj + NamedAttrs(const ProtoMsgOwner &owner, proto::NamedAttrs *protoMsg); + GeIrProtoHelper named_attrs_; + friend class GeAttrValueImp; + friend class GeAttrValue; +}; +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeAttrValue { + public: using INT = int64_t; using FLOAT = float; using BOOL = bool; @@ -75,7 +76,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeAttrValue { using TENSOR_DESC = GeTensorDesc; using GRAPH = ComputeGraphPtr; using BYTES = Buffer; - using NAMED_ATTRS = NamedAttrs; + using NAMED_ATTRS = ge::NamedAttrs; using DATA_TYPE = ge::DataType; using LIST_INT = vector; @@ -90,6 +91,8 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeAttrValue { using LIST_LIST_INT = vector>; using LIST_DATA_TYPE = vector; + using NamedAttrs = ge::NamedAttrs; // for cce use (ge::GeAttrValue::NamedAttrs). 
+ enum ValueType { VT_NONE = 0, VT_STRING, diff --git a/inc/graph/ge_tensor.h b/inc/graph/ge_tensor.h index 7a3eed68..a434591e 100644 --- a/inc/graph/ge_tensor.h +++ b/inc/graph/ge_tensor.h @@ -87,6 +87,12 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc : public AttrH GeShape &MutableShape(); void SetShape(GeShape shape); + // set shape with -2, it stand for unknown shape + void SetUnknownDimNumShape(); + // for unknown shape + graphStatus SetShapeRange(const std::vector> &range); + graphStatus GetShapeRange(std::vector> &range) const; + GeShape GetOriginShape() const; void SetOriginShape(const GeShape &originShape); diff --git a/inc/graph/model.h b/inc/graph/model.h index 464a2401..38ea501b 100644 --- a/inc/graph/model.h +++ b/inc/graph/model.h @@ -25,11 +25,7 @@ #include "graph/ge_attr_value.h" #include "graph/graph.h" -namespace domi { -class ModelHelper; -} namespace ge { -using domi::ModelHelper; using std::map; using std::string; using std::vector; diff --git a/inc/graph/op_desc.h b/inc/graph/op_desc.h index ab59155e..1827e6be 100644 --- a/inc/graph/op_desc.h +++ b/inc/graph/op_desc.h @@ -50,6 +50,8 @@ class GeAttrValue; using ConstOpDesc = const OpDesc; +enum SubgraphType { kStatic, kDynamic, kSubgraphTypeEnd }; + class OpDesc : public std::enable_shared_from_this, public AttrHolder { public: template @@ -83,6 +85,8 @@ class OpDesc : public std::enable_shared_from_this, public AttrHolder { graphStatus AddInputDescForward(const string &name, const unsigned int num); + graphStatus AddInputDescMiddle(const string &name, const unsigned int num, size_t index); + graphStatus AddOutputDescForward(const string &name, const unsigned int num); graphStatus AddOptionalInputDesc(const string &name, const GeTensorDesc &input_desc); @@ -141,6 +145,8 @@ class OpDesc : public std::enable_shared_from_this, public AttrHolder { graphStatus AddDynamicInputDesc(const string &name, const unsigned int num, bool isPushBack = true); + graphStatus 
AddDynamicInputDescByIndex(const string &name, const unsigned int num, size_t index); + graphStatus AddDynamicOutputDesc(const string &name, const unsigned int num, bool isPushBack = true); bool IsOptionalInput(const string &name) const; @@ -214,6 +220,9 @@ class OpDesc : public std::enable_shared_from_this, public AttrHolder { void SetIsInputConst(const vector &is_input_const); vector GetIsInputConst() const; + void SetOpInferDepends(const vector &depend_names); + vector GetOpInferDepends() const; + string GetInputNameByIndex(uint32_t index) const; int GetInputIndexByName(const string &name) const; @@ -236,12 +245,23 @@ class OpDesc : public std::enable_shared_from_this, public AttrHolder { std::string GetOpEngineName() const; + void RegisterSubgraphIrName(const std::string &name, SubgraphType type); + const std::map &GetSubgraphIrNames() const; + SubgraphType GetSubgraphTypeByIrName(const std::string &name) const; + graphStatus AddSubgraphName(const std::string &name); const std::map &GetSubgraphNameIndexes() const; std::string GetSubgraphInstanceName(uint32_t index) const; const std::vector &GetSubgraphInstanceNames() const; - void AddSubgraphInstanceName(std::string name); + /// Does not provide functions `AddSubgraphInstance` or `AppendSubgraphInstance`, + /// because this kind of functions will only append a new subgraph instance name + /// at the tail of `subgraph_instance_names_` and ignore the synchronous change of `subgraph_names_to_index_`. + /// If we want to append a new subgraph instance name, the function `AddSubgraphName` should be called first. 
+ /// \param index + /// \param name + /// \return + graphStatus SetSubgraphInstanceName(uint32_t index, const std::string &name); void RemoveSubgraphInstanceName(const std::string &name); protected: @@ -256,7 +276,23 @@ class OpDesc : public std::enable_shared_from_this, public AttrHolder { GeIrProtoHelper op_def_; std::vector subgraph_instance_names_; + + // subgraph names to index, for a `if` operator: + // then_branch: 0 + // else_branch: 1 + // or for a `case` node: + // branches0: 0 + // branches1: 1 + // branches2: 2 std::map subgraph_names_to_index_; + + // subgraph ir names to type, for a `if` operator: + // then_branch: static + // else_branch: dynamic + // or for a `case` op: + // branches: dynamic + std::map subgraph_ir_names_to_type_; + vector inputs_desc_{}; vector outputs_desc_{}; map output_name_idx_{}; diff --git a/inc/graph/ref_relation.h b/inc/graph/ref_relation.h new file mode 100644 index 00000000..71457916 --- /dev/null +++ b/inc/graph/ref_relation.h @@ -0,0 +1,79 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef COMMON_GRAPH_REF_RELATION_H_ +#define COMMON_GRAPH_REF_RELATION_H_ + +#include +#include +#include +#include + +#include "graph/compute_graph.h" +#include "graph/types.h" +#include "graph/ge_error_codes.h" +#include "node.h" + +namespace ge { +enum InOutFlag { + NODE_IN = 0, // input flag + NODE_OUT = 1, // output flag +}; + +struct RefCell { + std::string node_name; + ge::NodePtr node = nullptr; + InOutFlag in_out = NODE_IN; + int in_out_idx = 0; + + bool operator==(const RefCell &c) const { + return node_name == c.node_name && node == c.node && in_out == c.in_out && in_out_idx == c.in_out_idx; + } + + RefCell() = default; + RefCell(std::string name, ge::NodePtr node_ptr, InOutFlag in_out_flag, int idx) { + node_name = name; + node = node_ptr; + in_out = in_out_flag; + in_out_idx = idx; + }; + ~RefCell() = default; +}; + +struct RefCellHash { + size_t operator()(const RefCell &c) const { + unsigned long number = reinterpret_cast(reinterpret_cast(c.node.get())); + string tmp = c.node_name + std::to_string(c.in_out) + std::to_string(c.in_out_idx) + std::to_string(number); + return std::hash()(tmp); + } +}; + +class RefRelations { + public: + graphStatus LookUpRefRelations(const RefCell &key, std::unordered_set &result); + graphStatus BuildRefRelations(ge::ComputeGraph &root_graph); + graphStatus Clear(); + + RefRelations(); + ~RefRelations() = default; + + public: + class Impl; + std::shared_ptr impl_ = nullptr; +}; + +} // namespace ge +#endif // COMMON_GRAPH_REF_RELATION_H_ diff --git a/inc/graph/usr_types.h b/inc/graph/usr_types.h index 796a70a3..90e02001 100644 --- a/inc/graph/usr_types.h +++ b/inc/graph/usr_types.h @@ -14,8 +14,8 @@ * limitations under the License. 
*/ -#ifndef INC_EXTERNAL_GRAPH_USR_TYPES_H_ -#define INC_EXTERNAL_GRAPH_USR_TYPES_H_ +#ifndef INC_GRAPH_USR_TYPES_H_ +#define INC_GRAPH_USR_TYPES_H_ #include #include @@ -130,4 +130,4 @@ struct UsrQuantizeFactorParams { #undef USR_TYPE_BYTES_DEC } // namespace ge -#endif // INC_EXTERNAL_GRAPH_USR_TYPES_H_ +#endif // INC_GRAPH_USR_TYPES_H_ diff --git a/inc/graph/utils/attr_utils.h b/inc/graph/utils/attr_utils.h index ab89ebc7..15a815d4 100644 --- a/inc/graph/utils/attr_utils.h +++ b/inc/graph/utils/attr_utils.h @@ -62,9 +62,9 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrUtils { static bool SetListGraph(AttrHolderAdapter &&obj, const string &name, const vector &value); static bool SetBytes(AttrHolderAdapter &&obj, const string &name, const GeAttrValue::BYTES &value); static bool SetListBytes(AttrHolderAdapter &&obj, const string &name, const vector &value); - static bool SetNamedAttrs(AttrHolderAdapter &&obj, const string &name, const GeAttrValue::NamedAttrs &value); + static bool SetNamedAttrs(AttrHolderAdapter &&obj, const string &name, const GeAttrValue::NAMED_ATTRS &value); static bool SetListNamedAttrs(AttrHolderAdapter &&obj, const string &name, - const vector &value); + const vector &value); static bool SetListOpDesc(AttrHolderAdapter &&obj, const string &name, const vector &value); static bool SetListOpDesc(AttrHolderAdapter &&obj, const string &name, const vector &value); @@ -91,9 +91,9 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrUtils { static bool GetListGraph(ConstAttrHolderAdapter &&obj, const string &name, vector &value); static bool GetBytes(ConstAttrHolderAdapter &&obj, const string &name, GeAttrValue::BYTES &value); static bool GetListBytes(ConstAttrHolderAdapter &&obj, const string &name, vector &value); - static bool GetNamedAttrs(ConstAttrHolderAdapter &&obj, const string &name, GeAttrValue::NamedAttrs &value); + static bool GetNamedAttrs(ConstAttrHolderAdapter &&obj, const string &name, GeAttrValue::NAMED_ATTRS 
&value); static bool GetListNamedAttrs(ConstAttrHolderAdapter &&obj, const string &name, - vector &value); + vector &value); static bool GetListOpDesc(ConstAttrHolderAdapter &&obj, const string &name, vector &value); // Value will be moved static bool SetZeroCopyBytes(AttrHolderAdapter &&obj, const string &name, Buffer &&buffer); diff --git a/inc/graph/utils/graph_utils.h b/inc/graph/utils/graph_utils.h index 8066e8b5..904684e7 100644 --- a/inc/graph/utils/graph_utils.h +++ b/inc/graph/utils/graph_utils.h @@ -95,12 +95,35 @@ }; namespace ge { +enum IOType { kIn, kOut }; + +struct NodeIndexIO { + NodeIndexIO(ge::NodePtr node, uint32_t index, IOType io_type) + : node(std::move(node)), index(index), io_type(io_type) {} + NodeIndexIO(ge::NodePtr node, int index, IOType io_type) + : node(std::move(node)), index(static_cast(index)), io_type(io_type) {} + ~NodeIndexIO() {} + + NodePtr node = nullptr; + uint32_t index = 0; + IOType io_type = kOut; + + std::string ToString() const { + if ((node == nullptr) || (node->GetOwnerComputeGraph() == nullptr)) { + return ""; + } + return node->GetName() + (io_type == kOut ? 
"_out_" : "_in_") + std::to_string(index); + } +}; + class GraphUtils { public: static ComputeGraphPtr GetComputeGraph(const Graph &graph); static Graph CreateGraphFromComputeGraph(const ComputeGraphPtr compute_graph); + static graphStatus RecoverGraphOperators(const Graph &graph); + static ComputeGraphPtr CreateGraphFromOperator(const string &name, const std::vector &inputs); static graphStatus AddEdge(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst); @@ -262,6 +285,108 @@ class GraphUtils { static graphStatus MoveOutCtrlEdges(NodePtr &src_node, NodePtr &dst_node); static ComputeGraphPtr FindRootGraph(ComputeGraphPtr graph); + + static graphStatus TopologicalSortingByName(const ge::ComputeGraphPtr &compute_graph, vector &node_vec); + + /// + /// Get reference-mapping of all data_anchors in graph + /// @param [in] graph + /// @param [out] symbol_to_anchors + /// @param [out] anchor_to_symbol + /// @return success: GRAPH_SUCESS + /// + static graphStatus GetRefMapping(const ComputeGraphPtr &graph, + std::map> &symbol_to_anchors, + std::map &anchor_to_symbol); + + private: + /// + /// Get reference-mapping for in_data_anchors of node + /// @param [in] node + /// @param [out] symbol_to_anchors + /// @param [out] anchor_to_symbol + /// @return success: GRAPH_SUCESS + /// + static graphStatus HandleInAnchorMapping(const NodePtr &node, + std::map> &symbol_to_anchors, + std::map &anchor_to_symbol); + + /// + /// Get reference-mapping for out_data_anchors of node + /// @param [in] node + /// @param [out] symbol_to_anchors + /// @param [out] anchor_to_symbol + /// @return success: GRAPH_SUCESS + /// + static graphStatus HandleOutAnchorMapping(const NodePtr &node, + std::map> &symbol_to_anchors, + std::map &anchor_to_symbol); + + /// + /// Handle input of subgraph + /// @param [in] node + /// @param [out] symbol_to_anchors + /// @param [out] anchor_to_symbol + /// @return success: GRAPH_SUCESS + /// + static graphStatus HandleSubgraphInput(const NodePtr &node, + 
std::map> &symbol_to_anchors, + std::map &anchor_to_symbol); + + /// + /// Handle input of Merge op + /// @param [in] node + /// @param [out] symbol_to_anchors + /// @param [out] anchor_to_symbol + /// @return success: GRAPH_SUCESS + /// + static graphStatus HandleMergeInput(const NodePtr &node, + std::map> &symbol_to_anchors, + std::map &anchor_to_symbol); + + /// + /// Handle output of subgraph + /// @param [in] node + /// @param [out] symbol_to_anchors + /// @param [out] anchor_to_symbol + /// @return success: GRAPH_SUCESS + /// + static graphStatus HandleSubgraphOutput(const NodePtr &node, + std::map> &symbol_to_anchors, + std::map &anchor_to_symbol); + + /// + /// Union ref-mapping + /// @param [in] exist_node_info1 + /// @param [in] exist_node_info2 + /// @param [out] symbol_to_anchors + /// @param [out] anchor_to_symbol + /// @param [out] symbol + /// @return success: GRAPH_SUCESS + /// + static graphStatus UnionSymbolMapping(const NodeIndexIO &exist_node_info1, const NodeIndexIO &exist_node_info2, + std::map> &symbol_to_anchors, + std::map &anchor_to_symbol, std::string &symbol); + + /// + /// Update symbol mapping with a new reference pair + /// @param [in] cur_node_info + /// @param [in] exist_node_info + /// @param [out] symbol_to_anchors + /// @param [out] anchor_to_symbol + /// @return success: GRAPH_SUCESS + /// + static graphStatus UpdateRefMapping(const NodeIndexIO &cur_node_info, const NodeIndexIO &exist_node_info, + std::map> &symbol_to_anchors, + std::map &anchor_to_symbol); + + /// + /// Check if out_data_anchor is reference of input + /// @param [in] out_data_anchor + /// @param [out] reuse_in_index + /// @return bool + /// + static bool IsRefFromInput(const OutDataAnchorPtr &out_data_anchor, int32_t &reuse_in_index); }; class ComputeGraphBuilder { @@ -441,12 +566,12 @@ class CompleteGraphBuilder : public ComputeGraphBuilder { private: /// - /// @brief Build inputs + /// @brief Add data nodes /// @param [out] error_code /// @param [out] 
error_msg /// @return void /// - void BuildInputs(graphStatus &error_code, std::string &error_msg); + void AddDataNodes(graphStatus &error_code, std::string &error_msg); /// /// @brief Add data node @@ -455,41 +580,15 @@ class CompleteGraphBuilder : public ComputeGraphBuilder { /// @param [out] error_msg /// @return void /// - NodePtr AddDateNode(uint32_t index, graphStatus &error_code, std::string &error_msg); + NodePtr AddDataNode(uint32_t index, graphStatus &error_code, std::string &error_msg); /// - /// @brief Build outputs + /// @brief Add RetVal nodes /// @param [out] error_code /// @param [out] error_msg /// @return void /// - void BuildOutputs(graphStatus &error_code, std::string &error_msg); - - /// - /// @brief Add NetOutput node - /// @param [out] error_code - /// @param [out] error_msg - /// @return NodePtr - /// - NodePtr AddNetOutputNode(graphStatus &error_code, std::string &error_msg); - - /// - /// @brief Add input/output tensor for NetOutput node - /// @param [in] out_nodes_info - /// @param [out] net_output_desc - /// @return graphStatus - /// - graphStatus BuildInOutForNetOutput(const std::vector> &out_nodes_info, - OpDescPtr &net_output_desc); - - /// - /// @brief Add edge for NetOutput node - /// @param [in] out_nodes_info - /// @param [out] net_output_node - /// @return graphStatus - /// - graphStatus AddEdgeForNetOutput(const std::vector> &out_nodes_info, - const NodePtr &net_output_node); + void AddRetValNodes(graphStatus &error_code, std::string &error_msg); std::string name_; NodePtr parent_node_; diff --git a/inc/graph/utils/node_utils.h b/inc/graph/utils/node_utils.h index c979f727..e4c18d51 100644 --- a/inc/graph/utils/node_utils.h +++ b/inc/graph/utils/node_utils.h @@ -55,11 +55,44 @@ class NodeUtils { static GeTensorDesc GetInputDesc(const Node &node, uint32_t index); static graphStatus UpdateOutputShape(const Node &node, uint32_t index, const GeShape &shape); static graphStatus UpdateInputShape(const Node &node, uint32_t index, const 
GeShape &shape); + // check node whether unknown shape.If node shape contain -1 or -2,out param "is_unknow" will be true; + // for func op, it will check subgraph yet, if some node shape of subgraph contain -1 or -2, + // the out param "is_unknow" will be true too + static graphStatus GetNodeUnknownShapeStatus(const Node &node, bool &is_unknow); static std::string GetNodeType(const Node &node); static ComputeGraphPtr GetSubgraph(const Node &node, uint32_t index); - static graphStatus AddSubgraph(Node &node, const ComputeGraphPtr &subgraph); + static graphStatus SetSubgraph(Node &node, uint32_t index, const ComputeGraphPtr &subgraph); + + /// + /// Check if node is input of subgraph + /// @param [in] node + /// @return bool + /// + static bool IsSubgraphInput(const NodePtr &node); + + /// + /// Check if node is output of subgraph + /// @param [in] node + /// @return bool + /// + static bool IsSubgraphOutput(const NodePtr &node); + + /// + /// @brief Get subgraph original input node. + /// @param [in] node + /// @return Node + /// + static NodePtr GetParentInput(const NodePtr &node); + + /// + /// @brief Get subgraph input is constant. 
+ /// @param [in] node + /// @param [out] string + /// @return bool + /// + static bool GetConstOpType(const NodePtr &in_node, std::string &op_type); private: static std::map> map_send_info_; diff --git a/inc/graph/utils/op_desc_utils.h b/inc/graph/utils/op_desc_utils.h index 210ba0a5..6a9a4695 100644 --- a/inc/graph/utils/op_desc_utils.h +++ b/inc/graph/utils/op_desc_utils.h @@ -81,6 +81,9 @@ class OpDescUtils { static OpDescPtr CreateConstOp(const GeTensorPtr& tensor_ptr); + static graphStatus SetSubgraphInstanceName(const std::string& subgraph_name, + const std::string& subgraph_instance_name, OpDescPtr& op_desc); + private: static GeTensorPtr MutableWeights(ge::OpDesc& op_desc); static GeTensorPtr MutableWeights(ge::OpDescPtr op_desc); @@ -104,6 +107,14 @@ class OpDescBuilder { /// OpDescBuilder& AddInput(const std::string& name); + /// + /// @brief Add input + /// @param [in] name + /// @param [in] tensor + /// @return OpDescBuilder + /// + OpDescBuilder& AddInput(const std::string& name, const GeTensorDesc& tensor); + /// /// @brief Add dynamic input /// @param [in] name @@ -112,6 +123,15 @@ class OpDescBuilder { /// OpDescBuilder& AddDynamicInput(const std::string& name, uint32_t num); + /// + /// @brief Add dynamic input + /// @param [in] name + /// @param [in] num + /// @param [in] tensor + /// @return OpDescBuilder + /// + OpDescBuilder& AddDynamicInput(const std::string& name, uint32_t num, const GeTensorDesc& tensor); + /// /// @brief Add output /// @param [in] name @@ -119,6 +139,14 @@ class OpDescBuilder { /// OpDescBuilder& AddOutput(const std::string& name); + /// + /// @brief Add output + /// @param [in] name + /// @param [in] tensor + /// @return OpDescBuilder + /// + OpDescBuilder& AddOutput(const std::string& name, const GeTensorDesc& tensor); + /// /// @brief Add dynamic output /// @param [in] name @@ -127,6 +155,15 @@ class OpDescBuilder { /// OpDescBuilder& AddDynamicOutput(const std::string& name, uint32_t num); + /// + /// @brief Add 
dynamic output + /// @param [in] name + /// @param [in] num + /// @param [in] tensor + /// @return OpDescBuilder + /// + OpDescBuilder& AddDynamicOutput(const std::string& name, uint32_t num, const GeTensorDesc& tensor); + /// /// @brief Build op_desc /// @return OpDescPtr @@ -136,8 +173,8 @@ class OpDescBuilder { private: std::string name_; std::string type_; - std::vector inputs_; - std::vector outputs_; + std::vector> inputs_; + std::vector> outputs_; }; } // namespace ge diff --git a/src/common/graph/CMakeLists.txt b/src/common/graph/CMakeLists.txt index c0f8ccaf..43f5b597 100755 --- a/src/common/graph/CMakeLists.txt +++ b/src/common/graph/CMakeLists.txt @@ -34,13 +34,12 @@ ge_protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) ge_protobuf_generate(ge PROTO_ONNX_SRCS PROTO_ONNX_HDRS ${ONNX_PROTO_LIST}) # need to remove dependencies on pb files later -file(GLOB_RECURSE SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} +file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "*.cc" "utils/*.cc" "opsproto/*.cc" "detail/*.cc" "debug/*.cc" - "op_imp.cc" "option/*.cc" ) diff --git a/src/common/graph/anchor.cc b/src/common/graph/anchor.cc index 0b9eb00a..f02037e5 100644 --- a/src/common/graph/anchor.cc +++ b/src/common/graph/anchor.cc @@ -53,7 +53,6 @@ void Anchor::UnlinkAll() noexcept { if (Unlink(peer_anchor_ptr) != GRAPH_SUCCESS) { GELOGW("unlink peer_anchor_ptr failed."); } - } while (!peer_anchors_.empty()); } } diff --git a/src/common/graph/compute_graph.cc b/src/common/graph/compute_graph.cc index a35747d4..591ff0b5 100644 --- a/src/common/graph/compute_graph.cc +++ b/src/common/graph/compute_graph.cc @@ -42,8 +42,11 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::ComputeGraph(const : name_(name), nodes_(), input_nodes_(), sub_graph_(), is_valid_flag_(false), need_iteration_(false) { attrs_.InitDefault(); } + ComputeGraph::~ComputeGraph() {} + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY string ComputeGraph::GetName() const { return name_; } + 
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetName(const string &name) { name_ = name; } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t ComputeGraph::GetAllNodesSize() const { @@ -53,24 +56,50 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t ComputeGraph::GetAllNodesS } return s; } + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor ComputeGraph::GetAllNodes() const { - vector all_nodes(nodes_.size()); - (void)std::copy(nodes_.begin(), nodes_.end(), all_nodes.begin()); - for (const auto &sub_graph : sub_graph_) { - if (sub_graph == nullptr) { - GELOGW("sub graph is nullptr"); + if (sub_graph_.empty()) { + return Vistor(shared_from_this(), nodes_); + } + + std::vector> subgraphs; + return AllGraphNodes(subgraphs); +} + +ComputeGraph::Vistor ComputeGraph::AllGraphNodes(std::vector> &subgraphs) const { + std::vector all_nodes; + std::deque candidates; + + candidates.insert(candidates.begin(), nodes_.begin(), nodes_.end()); + while (!candidates.empty()) { + NodePtr node = candidates.front(); + all_nodes.emplace_back(node); + candidates.pop_front(); + + OpDescPtr op_desc = node->GetOpDesc(); + if (op_desc == nullptr) { continue; } - for (const auto &node : sub_graph->GetAllNodes()) { - all_nodes.push_back(node); + + const auto &subgraph_names = op_desc->GetSubgraphInstanceNames(); + for (auto name_iter = subgraph_names.rbegin(); name_iter != subgraph_names.rend(); ++name_iter) { + auto subgraph = GetSubgraph(*name_iter); + if (subgraph != nullptr) { + subgraphs.emplace_back(subgraph); + candidates.insert(candidates.begin(), subgraph->nodes_.begin(), subgraph->nodes_.end()); + } } } + return Vistor(shared_from_this(), all_nodes); } + size_t ComputeGraph::GetDirectNodesSize() const { return nodes_.size(); } + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor ComputeGraph::GetDirectNode() const { return Vistor(shared_from_this(), nodes_); } + ComputeGraph::Vistor ComputeGraph::GetInputNodes() const { 
return Vistor(shared_from_this(), input_nodes_); } @@ -82,6 +111,7 @@ ComputeGraph::Vistor ComputeGraph::GetOutputNodes() const { } return Vistor(shared_from_this(), result); } + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr ComputeGraph::FindNode(const std::string &name) const { for (const auto &node : nodes_) { if (node == nullptr) { @@ -203,10 +233,6 @@ NodePtr ComputeGraph::AddNodeFront(NodePtr node) { return nullptr; } node->GetOpDesc()->SetId(nodes_.size()); - if (nodes_[0] == nullptr) { - GELOGE(GRAPH_FAILED, "nodes_ size or nodes_[0] is nullptr"); - return nullptr; - } if (nodes_.size() > 0 && nodes_[0]->GetType() == DATA) { (void)nodes_.insert(nodes_.begin() + 1, node); } else { @@ -248,6 +274,20 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr ComputeGraph::AddNode(OpD GE_IF_BOOL_EXEC(node_ptr->Init() != GRAPH_SUCCESS, GELOGE(GRAPH_FAILED, "node init fail."); return nullptr); return AddNode(node_ptr); } + +NodePtr ComputeGraph::AddNode(OpDescPtr op, int64_t id) { // for unserialize. 
+ if (op == nullptr) { + GELOGE(GRAPH_FAILED, "The OpDesc ptr should be not null."); + return nullptr; + } + op->SetId(id); + NodePtr node = shared_ptr(new (std::nothrow) Node(op, shared_from_this())); + GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "node_ptr is NULL!!!"); return nullptr); + GE_IF_BOOL_EXEC(node->Init() != GRAPH_SUCCESS, GELOGE(GRAPH_FAILED, "node init fail."); return nullptr); + nodes_.push_back(node); + return node; +} + NodePtr ComputeGraph::AddInputNode(NodePtr node) { if (node == nullptr) { GELOGE(GRAPH_FAILED, "The node ptr should be not null."); @@ -259,6 +299,7 @@ NodePtr ComputeGraph::AddInputNode(NodePtr node) { } return node; } + NodePtr ComputeGraph::AddOutputNode(NodePtr node) { if (node == nullptr || node->GetOpDesc() == nullptr) { GELOGE(GRAPH_FAILED, "The node ptr or opdesc should be not null."); @@ -336,6 +377,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::RemoveN } return GRAPH_FAILED; } + // Used in sub_graph scenes graphStatus ComputeGraph::RemoveInputNode(const NodePtr &node) { if (node == nullptr) { @@ -372,20 +414,24 @@ graphStatus ComputeGraph::RemoveOutputNode(const NodePtr &node) { GE_IF_BOOL_EXEC(find_node == false, return GRAPH_FAILED); return GRAPH_SUCCESS; } + std::shared_ptr ComputeGraph::AddSubGraph(std::shared_ptr sub_graph) { if (sub_graph == nullptr) { GELOGE(GRAPH_FAILED, "The graph ptr should be not null."); return nullptr; } sub_graph_.push_back(sub_graph); + names_to_subgraph_[sub_graph->GetName()] = sub_graph; return sub_graph; } + graphStatus ComputeGraph::RemoveSubGraph(const std::shared_ptr &sub_graph) { if (sub_graph == nullptr) { GELOGE(GRAPH_FAILED, "The graph ptr should be not null."); return GRAPH_FAILED; } + names_to_subgraph_.erase(sub_graph->GetName()); auto iter = find(sub_graph_.begin(), sub_graph_.end(), sub_graph); if (iter != sub_graph_.end()) { (void)sub_graph_.erase(iter); @@ -462,8 +508,13 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void 
ComputeGraph::RemoveSubgraph GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::shared_ptr ComputeGraph::GetSubgraph( const std::string &name) const { - auto iter = names_to_subgraph_.find(name); - return iter == names_to_subgraph_.end() ? nullptr : iter->second; + std::shared_ptr parent = parent_graph_.lock(); + if (parent == nullptr) { + auto iter = names_to_subgraph_.find(name); + return iter == names_to_subgraph_.end() ? nullptr : iter->second; + } else { + return parent->GetSubgraph(name); + } } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::vector> @@ -495,7 +546,11 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetParentNode( /// GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::UpdateInputMapping(const std::map &input_mapping) { - for (auto &input : input_nodes_) { + size_t update_num = 0; + for (auto &input : nodes_) { + if (update_num >= input_mapping.size()) { + break; + } uint32_t cur_index = 0; if (!ge::AttrUtils::GetInt(input->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, cur_index)) { continue; @@ -508,6 +563,7 @@ ComputeGraph::UpdateInputMapping(const std::map &input_mappi GE_LOGE("UpdateInputMapping failed: set attr ATTR_NAME_PARENT_NODE_INDEX failed."); return GRAPH_FAILED; } + update_num++; } return GRAPH_SUCCESS; @@ -520,9 +576,9 @@ ComputeGraph::UpdateInputMapping(const std::map &input_mappi /// GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::UpdateOutputMapping(const std::map &output_mapping) { - NodePtr net_output = FindNode(kNodeNameNetOutput); + NodePtr net_output = FindNode(NODE_NAME_NET_OUTPUT); if (net_output == nullptr) { - GE_LOGE("UpdateOutputMapping failed: node %s not exist in graph.", kNodeNameNetOutput); + GE_LOGE("UpdateOutputMapping failed: node %s not exist in graph.", NODE_NAME_NET_OUTPUT); return GRAPH_FAILED; } OpDescPtr op_desc = net_output->GetOpDesc(); @@ -557,13 +613,13 @@ ComputeGraph::UpdateOutputMapping(const std::map &output_map 
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InsertEventNodes() { std::vector node_vec = nodes_; - for (const auto &node : GetAllNodes()) { + for (const auto &node : GetDirectNode()) { if (node == nullptr || node->GetOpDesc() == nullptr) { GELOGW("node or OpDescPtr is nullptr."); continue; } GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "The node should be not null."); return GRAPH_FAILED); - if (node->GetOpDesc()->GetType() == kRecvType) { + if (node->GetOpDesc()->GetType() == RECV) { auto iter = find(node_vec.begin(), node_vec.end(), node); if (iter == node_vec.end()) { GELOGW("no node found."); @@ -574,7 +630,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InsertE auto dst_iter = find(node_vec.begin(), node_vec.end(), node->GetOutControlNodes().at(0)); (void)node_vec.insert(dst_iter, node); } - if (node->GetOpDesc()->GetType() == kSendType) { + if (node->GetOpDesc()->GetType() == SEND) { auto iter = find(node_vec.begin(), node_vec.end(), node); if (iter == node_vec.end()) { GELOGW("no node found."); @@ -602,7 +658,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InsertE graphStatus ComputeGraph::DFSTopologicalSorting(std::vector &node_vec, std::map &map_in_edge_num, std::vector &stack) { - GELOGI("Runing_Dfs_Sort"); + GELOGI("Runing_Dfs_Sort: %s", name_.c_str()); // Record the number of non data nodes but no input nodes GE_CHK_BOOL_EXEC(SortNodes(stack, map_in_edge_num) == GRAPH_SUCCESS, return GRAPH_FAILED, "sort nodes failed"); @@ -647,7 +703,7 @@ graphStatus ComputeGraph::DFSTopologicalSorting(std::vector &node_vec, graphStatus ComputeGraph::BFSTopologicalSorting(std::vector &node_vec, std::map &map_in_edge_num, std::deque &stack) { - GELOGI("Runing_Bfs_Sort"); + GELOGI("Runing_Bfs_Sort: %s", name_.c_str()); std::vector stack_input; std::map breadth_node_map; // Record the number of non data nodes but no input nodes @@ -708,23 +764,36 @@ graphStatus 
ComputeGraph::CollectBreadthOutNode(const NodePtr &node, std::mapTopologicalSortingSubgraph(); + for (const auto &sub_graph : sub_graph_) { + ret = sub_graph->TopologicalSortingGraph(); if (ret != SUCCESS) { GELOGE(ret, "Sub graph topological sort Failed"); return ret; } } + + std::vector> subgraphs; + (void)AllGraphNodes(subgraphs); + if (sub_graph_.size() != subgraphs.size()) { // Graph Partition use subgraph, Keep original + GELOGW("Keep original subgraph for graph size %zu not equal %zu.", sub_graph_.size(), subgraphs.size()); + return SUCCESS; + } + sub_graph_.swap(subgraphs); return SUCCESS; } -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::TopologicalSortingSubgraph() { +graphStatus ComputeGraph::TopologicalSortingGraph() { std::vector node_vec; std::map map_in_edge_num; bool use_BFS = false; @@ -735,7 +804,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::Topolog use_BFS = true; } } else { - GELOGW("Get OPTION_GRAPH_RUN_MODE failed, use BFSTopologicalSorting by default."); + GELOGW("OPTION_GRAPH_RUN_MODE not set, use BFSTopologicalSorting by default."); } if (use_BFS) { @@ -793,8 +862,8 @@ graphStatus ComputeGraph::SortNodes(std::vector &stack, std::mapGetOpDesc() == nullptr, continue); map_in_edge_num[node] = static_cast(GetInEdgeSize(node)); if (map_in_edge_num[node] == 0) { - if ((node->GetOpDesc()->GetType() != kDataType) && (node->GetOpDesc()->GetType() != kAippDataType) && - (node->GetOpDesc()->GetType() != kInputType) && (node->GetOpDesc()->GetType() != kAnnDataType)) { + if ((node->GetOpDesc()->GetType() != DATA) && (node->GetOpDesc()->GetType() != AIPPDATA) && + (node->GetOpDesc()->GetType() != INPUT_TYPE) && (node->GetOpDesc()->GetType() != ANN_DATA)) { // At present, can only judge the isolated point without input and output. // It is impossible to judge the situation with multiple output nodes. 
if (verify_isolated && GetOutEdgeSize(node) == 0) { @@ -832,6 +901,7 @@ graphStatus ComputeGraph::SortNodes(std::vector &stack, std::mapGetOutAllNodes(); @@ -954,6 +1026,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::Isolate } } } + // If there is an input control side auto in_ctrl_anchor = node->GetInControlAnchor(); GE_CHECK_NOTNULL(in_ctrl_anchor); @@ -991,6 +1064,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::Isolate return RemoveExtraOutEdge(node); } + graphStatus ComputeGraph::RemoveExtraOutEdge(const NodePtr &node) { GE_CHECK_NOTNULL(node); // Remove redundant output edges @@ -1041,7 +1115,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InferSh node_ptr->GetName().c_str()); graphStatus status = node_ptr->InferShapeAndType(); - GE_CHK_BOOL_EXEC_INFO(node_ptr->GetType() == kDataType || GRAPH_PARAM_INVALID != status, break, + GE_CHK_BOOL_EXEC_INFO(node_ptr->GetType() == DATA || GRAPH_PARAM_INVALID != status, break, "Op %s does not have the IMPLEMT_INFERFUNC definition," " and subsequent operators no longer perform shape inference.", node_ptr->GetName().c_str()); diff --git a/src/common/graph/debug/ge_op_types.h b/src/common/graph/debug/ge_op_types.h index d79eece4..3c511bdd 100644 --- a/src/common/graph/debug/ge_op_types.h +++ b/src/common/graph/debug/ge_op_types.h @@ -16,237 +16,41 @@ #ifndef COMMON_GRAPH_DEBUG_GE_OP_TYPES_H_ #define COMMON_GRAPH_DEBUG_GE_OP_TYPES_H_ -#include -#include -#include -#include -#include -#include -#include namespace ge { -#define GE_REGISTER_OPTYPE(var_name, str_name) static const char* var_name __attribute__((unused)) = str_name +#define GE_REGISTER_OPTYPE(var_name, str_name) static const char *var_name __attribute__((unused)) = str_name GE_REGISTER_OPTYPE(DATA, "Data"); GE_REGISTER_OPTYPE(AIPPDATA, "AippData"); -GE_REGISTER_OPTYPE(CONVOLUTION, "Convolution"); -GE_REGISTER_OPTYPE(CORRELATION, "Correlation"); 
-GE_REGISTER_OPTYPE(CORRELATIONV2, "Correlation_V2"); -GE_REGISTER_OPTYPE(DECONVOLUTION, "Deconvolution"); -GE_REGISTER_OPTYPE(POOLING, "Pooling"); -GE_REGISTER_OPTYPE(ELTWISE, "Eltwise"); -GE_REGISTER_OPTYPE(RELU, "ReLU"); -GE_REGISTER_OPTYPE(RELU6, "ReLU6"); -GE_REGISTER_OPTYPE(SIGMOID, "Sigmoid"); -GE_REGISTER_OPTYPE(ABSVAL, "AbsVal"); -GE_REGISTER_OPTYPE(TANH, "TanH"); -GE_REGISTER_OPTYPE(PRELU, "PReLU"); -GE_REGISTER_OPTYPE(BATCHNORM, "BatchNorm"); -GE_REGISTER_OPTYPE(FUSIONBATCHNORM, "FusionBatchNorm"); -GE_REGISTER_OPTYPE(SCALE, "Scale"); -GE_REGISTER_OPTYPE(FULL_CONNECTION, "FullConnection"); -GE_REGISTER_OPTYPE(SOFTMAX, "Softmax"); -GE_REGISTER_OPTYPE(PLUS, "Plus"); -GE_REGISTER_OPTYPE(ACTIVATION, "Activation"); -GE_REGISTER_OPTYPE(FLATTEN, "Flatten"); -GE_REGISTER_OPTYPE(ADD, "Add"); -GE_REGISTER_OPTYPE(SUB, "Sub"); -GE_REGISTER_OPTYPE(MUL, "Mul"); GE_REGISTER_OPTYPE(MATMUL, "MatMul"); -GE_REGISTER_OPTYPE(RSQRT, "Rsqrt"); -GE_REGISTER_OPTYPE(BIASADD, "BiasAdd"); GE_REGISTER_OPTYPE(RESHAPE, "Reshape"); -GE_REGISTER_OPTYPE(DEPCONVOLUTION, "ConvolutionDepthwise"); -GE_REGISTER_OPTYPE(DROPOUT, "Dropout"); -GE_REGISTER_OPTYPE(CONCAT, "Concat"); -GE_REGISTER_OPTYPE(ROIPOOLING, "ROIPooling"); -GE_REGISTER_OPTYPE(PROPOSAL, "Proposal"); -GE_REGISTER_OPTYPE(FSRDETECTIONOUTPUT, "FSRDetectionOutput"); -GE_REGISTER_OPTYPE(DETECTIONPOSTPROCESS, "Detectpostprocess"); -GE_REGISTER_OPTYPE(LRN, "LRN"); -GE_REGISTER_OPTYPE(TRANSDATA, "TransData"); GE_REGISTER_OPTYPE(PERMUTE, "Permute"); -GE_REGISTER_OPTYPE(SSDNORMALIZE, "SSDNormalize"); -GE_REGISTER_OPTYPE(SSDPRIORBOX, "SSDPriorBox"); GE_REGISTER_OPTYPE(NETOUTPUT, "NetOutput"); -GE_REGISTER_OPTYPE(SSDDETECTIONOUTPUT, "SSDDetectionOutput"); -GE_REGISTER_OPTYPE(CHANNELAXPY, "ChannelAxpy"); -GE_REGISTER_OPTYPE(PSROIPOOLING, "PSROIPooling"); -GE_REGISTER_OPTYPE(POWER, "Power"); -GE_REGISTER_OPTYPE(ROIALIGN, "ROIAlign"); -GE_REGISTER_OPTYPE(PYTHON, "Python"); -GE_REGISTER_OPTYPE(FREESPACEEXTRACT, "FreespaceExtract"); 
-GE_REGISTER_OPTYPE(SPATIALTF, "SpatialTransform"); -GE_REGISTER_OPTYPE(SHAPE, "Shape"); -GE_REGISTER_OPTYPE(ARGMAX, "ArgMax"); -GE_REGISTER_OPTYPE(GATHERND, "GatherNd"); -GE_REGISTER_OPTYPE(GATHER, "Gather"); -GE_REGISTER_OPTYPE(REALDIV, "RealDiv"); -GE_REGISTER_OPTYPE(PACK, "Pack"); -GE_REGISTER_OPTYPE(SLICE, "Slice"); -GE_REGISTER_OPTYPE(FLOORDIV, "FloorDiv"); +GE_REGISTER_OPTYPE(_WHILE, "_While"); +GE_REGISTER_OPTYPE(WHILE, "While"); +GE_REGISTER_OPTYPE(STATELESSWHILE, "StatelessWhile"); GE_REGISTER_OPTYPE(SQUEEZE, "Squeeze"); -GE_REGISTER_OPTYPE(STRIDEDSLICE, "StridedSlice"); -GE_REGISTER_OPTYPE(RANGE, "Range"); -GE_REGISTER_OPTYPE(RPNPROPOSALS, "GenerateRpnProposals"); -GE_REGISTER_OPTYPE(DECODEBBOX, "DecodeBBox"); -GE_REGISTER_OPTYPE(PAD, "Pad"); -GE_REGISTER_OPTYPE(TILE, "Tile"); -GE_REGISTER_OPTYPE(SIZE, "Size"); -GE_REGISTER_OPTYPE(CLIPBOXES, "Clipboxes"); -GE_REGISTER_OPTYPE(FASTRCNNPREDICTIONS, "FastrcnnPredictions"); -GE_REGISTER_OPTYPE(SPLIT, "Split"); GE_REGISTER_OPTYPE(EXPANDDIMS, "ExpandDims"); -GE_REGISTER_OPTYPE(MEAN, "Mean"); -GE_REGISTER_OPTYPE(GREATER, "Greater"); GE_REGISTER_OPTYPE(SWITCH, "Switch"); -GE_REGISTER_OPTYPE(REFSWITCH, "RefSwitch"); GE_REGISTER_OPTYPE(MERGE, "Merge"); -GE_REGISTER_OPTYPE(REFMERGE, "RefMerge"); -GE_REGISTER_OPTYPE(ENTER, "Enter"); -GE_REGISTER_OPTYPE(REFENTER, "RefEnter"); -GE_REGISTER_OPTYPE(LOOPCOND, "LoopCond"); +GE_REGISTER_OPTYPE(STREAMMERGE, "StreamMerge"); GE_REGISTER_OPTYPE(NEXTITERATION, "NextIteration"); GE_REGISTER_OPTYPE(REFNEXTITERATION, "RefNextIteration"); -GE_REGISTER_OPTYPE(EXIT, "Exit"); -GE_REGISTER_OPTYPE(REFEXIT, "RefExit"); -GE_REGISTER_OPTYPE(CONTROLTRIGGER, "ControlTrigger"); -GE_REGISTER_OPTYPE(TRANSPOSE, "Transpose"); -GE_REGISTER_OPTYPE(CAST, "Cast"); -GE_REGISTER_OPTYPE(REGION, "Region"); -GE_REGISTER_OPTYPE(YOLO, "Yolo"); -GE_REGISTER_OPTYPE(YOLODETECTIONOUTPUT, "YoloDetectionOutput"); -GE_REGISTER_OPTYPE(FILL, "Fill"); -GE_REGISTER_OPTYPE(REVERSE, "Reverse"); 
-GE_REGISTER_OPTYPE(UNPACK, "Unpack"); -GE_REGISTER_OPTYPE(YOLO2REORG, "Yolo2Reorg"); -GE_REGISTER_OPTYPE(REDUCESUM, "ReduceSum"); GE_REGISTER_OPTYPE(CONSTANT, "Const"); -GE_REGISTER_OPTYPE(RESIZEBILINEAR, "ResizeBilinear"); -GE_REGISTER_OPTYPE(MAXIMUM, "Maximum"); GE_REGISTER_OPTYPE(FRAMEWORKOP, "FrameworkOp"); -GE_REGISTER_OPTYPE(ARG, "_Arg"); -GE_REGISTER_OPTYPE(FUSEDBATCHNORMGRAD, "FusedBatchNormGrad"); -GE_REGISTER_OPTYPE(LSTM, "LSTM"); -GE_REGISTER_OPTYPE(HIGHWAY, "HighWay"); -GE_REGISTER_OPTYPE(RNN, "RNN"); -GE_REGISTER_OPTYPE(ATTENTIONDECODER, "AttentionDecoder"); -GE_REGISTER_OPTYPE(LOGICAL_NOT, "LogicalNot"); -GE_REGISTER_OPTYPE(LOGICAL_AND, "LogicalAnd"); -GE_REGISTER_OPTYPE(EQUAL, "Equal"); -GE_REGISTER_OPTYPE(INTERP, "Interp"); -GE_REGISTER_OPTYPE(SHUFFLECHANNEL, "ShuffleChannel"); -GE_REGISTER_OPTYPE(AIPP, "Aipp"); - -GE_REGISTER_OPTYPE(CROPANDRESIZE, "CropAndResize"); -GE_REGISTER_OPTYPE(UNUSEDCONST, "UnusedConst"); -GE_REGISTER_OPTYPE(BROADCASTGRADIENTARGS, "BroadcastGradientArgs"); -GE_REGISTER_OPTYPE(BROADCASTARGS, "BroadcastArgs"); -GE_REGISTER_OPTYPE(STOPGRADIENT, "StopGradient"); -GE_REGISTER_OPTYPE(PPREVENTGRADIENT, "PreventGradient"); -GE_REGISTER_OPTYPE(GUARANTEECONST, "GuaranteeConst"); -GE_REGISTER_OPTYPE(SPARSETODENSE, "SparseToDense"); -GE_REGISTER_OPTYPE(NONMAXSUPPRESSION, "NonMaxSuppression"); -GE_REGISTER_OPTYPE(TOPKV2, "TopKV2"); -GE_REGISTER_OPTYPE(INVERTPERMUTATION, "InvertPermutation"); -GE_REGISTER_OPTYPE(MULTINOMIAL, "Multinomial"); -GE_REGISTER_OPTYPE(REVERSESEQUENCE, "ReverseSequence"); GE_REGISTER_OPTYPE(GETNEXT, "GetNext"); GE_REGISTER_OPTYPE(INITDATA, "InitData"); - -// ANN specific operator -GE_REGISTER_OPTYPE(ANN_MEAN, "AnnMean"); -GE_REGISTER_OPTYPE(ANN_CONVOLUTION, "AnnConvolution"); -GE_REGISTER_OPTYPE(ANN_DEPCONVOLUTION, "AnnDepthConv"); -GE_REGISTER_OPTYPE(DIV, "Div"); -GE_REGISTER_OPTYPE(ANN_FULLCONNECTION, "AnnFullConnection"); -GE_REGISTER_OPTYPE(ANN_NETOUTPUT, "AnnNetOutput"); GE_REGISTER_OPTYPE(ANN_DATA, 
"AnnData"); -// Training operator -GE_REGISTER_OPTYPE(CONVGRADFILTER, "Conv2DBackpropFilter"); -GE_REGISTER_OPTYPE(CONV2D, "Conv2D"); -GE_REGISTER_OPTYPE(CONV2DBACKPROPINPUT, "Conv2DBackpropInput"); -GE_REGISTER_OPTYPE(ACTIVATIONGRAD, "ReluGrad"); GE_REGISTER_OPTYPE(CONSTANTOP, "Constant"); -GE_REGISTER_OPTYPE(AVGPOOLGRAD, "AvgPoolGrad"); -GE_REGISTER_OPTYPE(SQUARE, "Square"); -GE_REGISTER_OPTYPE(PLACEHOLDER, "PlaceHolder"); -GE_REGISTER_OPTYPE(END, "End"); GE_REGISTER_OPTYPE(VARIABLE, "Variable"); +GE_REGISTER_OPTYPE(VARIABLEV2, "VariableV2"); -/// @ingroup domi_omg -/// @brief INPUT node type -static const char* const kInputType = "Input"; - -/// -/// @ingroup domi_omg -/// @brief AIPP tag, tag for aipp conv operator -/// -static const char* const kAippConvFlag = "Aipp_Conv_Flag"; - -/// -/// @ingroup domi_omg -/// @brief AIPP tag, tag for aipp data operator -/// -static const char* const kAippDataFlag = "Aipp_Data_Flag"; - -/// -/// @ingroup domi_omg -/// @brief AIPP tag, tag for aipp data operator -/// -static const char* const kAippDataType = "AippData"; - -/// -/// @ingroup domi_omg -/// @brief DATA node type -/// -static const char* const kDataType = "Data"; - -/// -/// @ingroup domi_omg -/// @brief Frame operator type -/// -static const char* const kFrameworkOpType = "FrameworkOp"; - -/// -/// @ingroup domi_omg -/// @brief Data node type -/// -static const char* const kAnnDataType = "AnnData"; -static const char* const kAnnNetoutputType = "AnnNetOutput"; -/// -/// @ingroup domi_omg -/// @brief Convolution node type -/// -static const char* const kNodeNameNetOutput = "Node_Output"; - -/// -/// @ingroup domi_omg -/// @brief RECV node type -/// -static const char* const kRecvType = "Recv"; +GE_REGISTER_OPTYPE(INPUT_TYPE, "Input"); -/// -/// @ingroup domi_omg -/// @brief SEND node type -/// -static const char* const kSendType = "Send"; +GE_REGISTER_OPTYPE(NODE_NAME_NET_OUTPUT, "Node_Output"); -/// -/// @ingroup domi_omg -/// @brief Convolution node type -/// 
-static const char* const kOpTypeConvolution = "Convolution"; -/// -/// @ingroup domi_omg -/// @brief Add convolution node name to hard AIPP -/// -static const char* const kAippConvOpNmae = "aipp_conv_op"; -/// -/// @ingroup domi_omg -/// @brief Operator configuration item separator -/// -static const char* const kOpConfDelimiter = ":"; +GE_REGISTER_OPTYPE(RECV, "Recv"); +GE_REGISTER_OPTYPE(SEND, "Send"); }; // namespace ge #endif // COMMON_GRAPH_DEBUG_GE_OP_TYPES_H_ diff --git a/src/common/graph/format_refiner.cc b/src/common/graph/format_refiner.cc index 04294180..91d388d0 100644 --- a/src/common/graph/format_refiner.cc +++ b/src/common/graph/format_refiner.cc @@ -15,11 +15,14 @@ */ #include "format_refiner.h" + #include #include #include #include #include + +#include "graph/ref_relation.h" #include "./compute_graph.h" #include "./ge_error_codes.h" #include "./graph/ge_tensor.h" @@ -34,14 +37,41 @@ #include "utils/tensor_utils.h" #include "utils/type_utils.h" +using namespace ge; +using namespace std; namespace ge { namespace { static const std::unordered_set kChangeDimNodes = {RESHAPE, PERMUTE, EXPANDDIMS, SQUEEZE}; static bool net_format_is_nd = true; static Format g_user_set_format = FORMAT_ND; static bool is_first_infer = true; +static RefRelations reflection_builder; } // namespace +graphStatus ReflectionProcess(const std::unordered_set &reflection, + std::deque &nodes, ge::Format to_be_set_format) { + for (const auto &cell : reflection) { + auto node = cell.node; + auto in_out_idx = cell.in_out_idx; + GE_CHECK_NOTNULL(node); + GE_CHECK_NOTNULL(node->GetOpDesc()); + if (cell.in_out == ge::NODE_IN) { + auto desc = node->GetOpDesc()->GetInputDesc(static_cast(in_out_idx)); + desc.SetOriginFormat(to_be_set_format); + desc.SetFormat(to_be_set_format); + (void)node->GetOpDesc()->UpdateInputDesc(static_cast(in_out_idx), desc); + } else { + auto desc = node->GetOpDesc()->GetOutputDesc(static_cast(in_out_idx)); + desc.SetOriginFormat(to_be_set_format); + 
desc.SetFormat(to_be_set_format); + (void)node->GetOpDesc()->UpdateOutputDesc(static_cast(in_out_idx), desc); + } + nodes.push_back(cell.node); + } + + return GRAPH_SUCCESS; +} + graphStatus FormatRefiner::RefreshConstantOutProcess(const OpDescPtr &op_desc) { GE_CHECK_NOTNULL(op_desc); if (op_desc->GetType() == CONSTANTOP && is_first_infer == true) { @@ -66,7 +96,6 @@ graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std anchor_points.clear(); // Get all anchor point nodes and switch nodes for (const auto &node_ptr : graph->GetAllNodes()) { - std::vector is_node_set_format; if (node_ptr == nullptr) { return GRAPH_FAILED; } @@ -86,7 +115,7 @@ graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std for (uint32_t i = 0; i < input_size; i++) { // Operator pre-set format but not origin format auto input_format = op_desc->MutableInputDesc(i)->GetFormat(); - // Pre-save data node and default infer fail + // Pre-save data node (only main graph data) and default infer fail if (node_ptr->GetType() == DATA) { data_nodes.push_back(node_ptr); } @@ -163,6 +192,16 @@ graphStatus FormatRefiner::BackInferProcess(std::deque &nodes, ge:: } // Check format whether have been set int idx = peer_out_data_anchor->GetIdx(); + // do peer_out_node name and index as key to lookup reflections + ge::RefCell key(peer_out_data_node->GetName(), peer_out_data_node, ge::NODE_OUT, idx); + std::unordered_set reflection; + auto status = reflection_builder.LookUpRefRelations(key, reflection); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "LookUpRefRelations failed!Node is [%s],the %d out edge", + (peer_out_data_node->GetName()).c_str(), idx); + return GRAPH_FAILED; + } + auto ge_tensor_desc = peer_out_data_node->GetOpDesc()->GetOutputDesc(static_cast(idx)); if (ge_tensor_desc.GetOriginFormat() == FORMAT_ND) { auto dim_num = ge_tensor_desc.GetShape().GetDimNum(); @@ -181,18 +220,26 @@ graphStatus FormatRefiner::BackInferProcess(std::deque 
&nodes, ge:: continue; } - ge_tensor_desc.SetOriginFormat(to_be_set_format); - ge_tensor_desc.SetFormat(to_be_set_format); - (void)peer_out_data_node->GetOpDesc()->UpdateOutputDesc(static_cast(idx), ge_tensor_desc); + if (reflection.empty()) { + ge_tensor_desc.SetOriginFormat(to_be_set_format); + ge_tensor_desc.SetFormat(to_be_set_format); + (void)peer_out_data_node->GetOpDesc()->UpdateOutputDesc(static_cast(idx), ge_tensor_desc); - // Call operator infer format api (forward) to get out format - GELOGD("call infer format func[Back]!Node is [%s] ", (peer_out_data_node->GetName()).c_str()); - graphStatus status = peer_out_data_node->InferOriginFormat(); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Node[%s] infer format failed", (peer_out_data_node->GetName()).c_str()); - return GRAPH_FAILED; + // Call operator infer format api (forward) to get out format + GELOGD("call infer format func[Back]!Node is [%s] ", (peer_out_data_node->GetName()).c_str()); + status = peer_out_data_node->InferOriginFormat(); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Node[%s] infer format failed", (peer_out_data_node->GetName()).c_str()); + return GRAPH_FAILED; + } + nodes.push_back(peer_out_data_node); + } else { + auto status = ReflectionProcess(reflection, nodes, to_be_set_format); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "reflection process failed!"); + return GRAPH_FAILED; + } } - nodes.push_back(peer_out_data_node); } } return GRAPH_SUCCESS; @@ -213,17 +260,23 @@ graphStatus FormatRefiner::ForwardInferProcess(std::deque &nodes, g continue; } for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { - if (peer_in_data_anchor == nullptr) { - GELOGW("Node[%s] some peer_in_anchor is null", (node->GetName()).c_str()); - continue; - } + GE_IF_BOOL_EXEC(peer_in_data_anchor == nullptr, continue); + auto peer_in_data_node = peer_in_data_anchor->GetOwnerNode(); - if (peer_in_data_node == nullptr || peer_in_data_node->GetOpDesc() == 
nullptr) { - GELOGW("Node[%s] peer_in_data_node or peer_in_data_node desc is null", node->GetName().c_str()); - continue; - } + GE_IF_BOOL_EXEC(peer_in_data_node == nullptr, continue); + GE_IF_BOOL_EXEC(peer_in_data_node->GetOpDesc() == nullptr, continue); + // Check format whether have been set int idx = peer_in_data_anchor->GetIdx(); + // do peer_out_node name and index as key to lookup reflections + ge::RefCell key(peer_in_data_node->GetName(), peer_in_data_node, ge::NODE_IN, idx); + std::unordered_set reflection; + auto status = reflection_builder.LookUpRefRelations(key, reflection); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "LookUpRefRelations failed!Node is [%s],the %d input edge", + (peer_in_data_node->GetName()).c_str(), idx); + return GRAPH_FAILED; + } auto ge_tensor_desc = peer_in_data_node->GetOpDesc()->GetInputDesc(static_cast(idx)); if (ge_tensor_desc.GetOriginFormat() == FORMAT_ND) { auto dim_num = ge_tensor_desc.GetShape().GetDimNum(); @@ -240,24 +293,33 @@ graphStatus FormatRefiner::ForwardInferProcess(std::deque &nodes, g GELOGD("Node[%s] is change dim node. 
do not infer origin format", (peer_in_data_node->GetName()).c_str()); continue; } - ge_tensor_desc.SetOriginFormat(to_be_set_format); - ge_tensor_desc.SetFormat(to_be_set_format); - (void)peer_in_data_node->GetOpDesc()->UpdateInputDesc(idx, ge_tensor_desc); - /// Because netoutput node added before infer format ,so netoutput is end condition - /// must set netoutput format , because saved result depend on format - if (peer_in_data_node_type == NETOUTPUT) { - continue; - } + if (reflection.empty()) { + ge_tensor_desc.SetOriginFormat(to_be_set_format); + ge_tensor_desc.SetFormat(to_be_set_format); + (void)peer_in_data_node->GetOpDesc()->UpdateInputDesc(static_cast(idx), ge_tensor_desc); - // Call operator infer format api (forward) to get out format - GELOGD("call infer format func[Forward]!Node is [%s] ", (peer_in_data_node->GetName()).c_str()); - graphStatus status = peer_in_data_node->InferOriginFormat(); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Node[%s] infer format failed", (peer_in_data_node->GetName()).c_str()); - return GRAPH_FAILED; + /// Because netoutput node added before infer format ,so netoutput is end condition + /// must set netoutput format , because saved result depend on format + if (peer_in_data_node_type == NETOUTPUT) { + continue; + } + + // Call operator infer format api (forward) to get out format + GELOGD("call infer format func[Back]!Node is [%s] ", (peer_in_data_node->GetName()).c_str()); + status = peer_in_data_node->InferOriginFormat(); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Node[%s] infer format failed", (peer_in_data_node->GetName()).c_str()); + return GRAPH_FAILED; + } + nodes.push_back(peer_in_data_node); + } else { + auto status = ReflectionProcess(reflection, nodes, to_be_set_format); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "reflection process failed!"); + return GRAPH_FAILED; + } } - nodes.push_back(peer_in_data_node); } } } @@ -355,8 +417,15 @@ graphStatus 
FormatRefiner::InferOrigineFormat(const ge::ComputeGraphPtr &graph) GELOGE(GRAPH_FAILED, "input graph is null"); return GRAPH_FAILED; } + // build reflection relations of boundary + (void)reflection_builder.Clear(); + auto status = reflection_builder.BuildRefRelations(*graph); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "build reflection relations failed for main and subgraph!"); + return GRAPH_FAILED; + } // User set global net format - graphStatus status = GetAnchorPoints(graph, anchor_points, data_nodes, node_status); + status = GetAnchorPoints(graph, anchor_points, data_nodes, node_status); if (status != GRAPH_SUCCESS) { GELOGE(GRAPH_FAILED, "GetAnchorPoints Process Faild!"); return GRAPH_FAILED; diff --git a/src/common/graph/ge_attr_define.cc b/src/common/graph/ge_attr_define.cc index 139bb4f3..23c1cff0 100644 --- a/src/common/graph/ge_attr_define.cc +++ b/src/common/graph/ge_attr_define.cc @@ -18,6 +18,12 @@ namespace ge { // Public attribute +const std::string ATTR_NAME_IS_UNKNOWN_SHAPE = "_is_unknown_shape"; + +const std::string ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED = "_dynamic_shape_partitioned"; + +const std::string ATTR_NAME_UNKNOWN_SHAPE_TYPE = "_unknown_shape_type"; + const std::string ATTR_NAME_NAME = "name"; const std::string ATTR_NAME_TYPE = "type"; @@ -42,6 +48,8 @@ const std::string ATTR_NAME_BIAS = "bias"; const std::string ATTR_NAME_BIAS_TERM = "bias_term"; +const std::string ATTR_NAME_HAS_BIAS_VALUE = "has_bias_value"; + const std::string ATTR_NAME_PAD = "pad"; const std::string ATTR_NAME_PADS = "pad"; @@ -83,6 +91,7 @@ const std::string ATTR_NAME_LRN_BETA = "lrn_beta"; const std::string ATTR_NAME_AXIS = "axis"; const std::string ATTR_NAME_BROADCAST = "broadcast"; +const std::string ATTR_NAME_OUTPUT = "output"; const std::string ATTR_NAME_OUTPUT_NUM = "output_num"; const std::string ATTR_NAME_TIDX = "t_idx"; @@ -103,6 +112,13 @@ const std::string ATTR_NAME_TSHAPE = "Tshape"; const std::string ATTR_NAME_NAN_OPT = "nan_opt"; const 
std::string ATTR_NAME_AIPP = "aipp"; +const std::string NEW_AIPP_CONV_OP = "new_conv_op_for_aipp"; + +const std::string ATTR_NAME_SESSION_GRAPH_ID = "_session_graph_id"; + +const std::string ATTR_NAME_MULTISHAPE_BATCHLIST = "multi_shape_batchlist"; +const std::string ATTR_NAME_MULTISHAPE_BATCHLIST_SIZE = "multi_shape_batchlist_size"; +const std::string ATTR_MODEL_BATCH_NUM = "batch_num"; const std::string ATTR_NAME_INPUT_FORMAT = "input_format"; const std::string ATTR_NAME_OUTPUT_FORMAT = "output_format"; @@ -111,6 +127,7 @@ const std::string ATTR_NAME_FRAMEWORK_NODE_DEF = "node_def"; const std::string ATTR_NAME_FRAMEWORK_OP_DEF = "op_def"; const std::string ATTR_NAME_FRAMEWORK_FWK_TYPE = "framework_type"; const std::string ATTR_NAME_FRAMEWORK_FUNC_DEF = "func_def"; +const std::string ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE = "original_type"; const std::string ATTR_NAME_INPUT_TENSOR_DESC = "input_tensor_desc"; const std::string ATTR_NAME_OUTPUT_TENSOR_DESC = "output_tensor_desc"; @@ -122,9 +139,12 @@ const std::string ATTR_NAME_WEIGHTS = "value"; const std::string ATTR_NAME_WEIGHTS_DATA = "weights_data"; const std::string ATTR_NAME_BROACAST_REAL_DIM_CNT = "broacast_real_dim_cnt"; const std::string ATTR_NAME_DIM_ALIGN = "dim_align"; -const std::string ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE = "original_type"; - -const std::string ATTR_NAME_SESSION_GRAPH_ID = "_session_graph_id"; +const std::string ATTR_NAME_STREAM_LABEL = "_stream_label"; +const std::string ATTR_NAME_CONTINUOUS_STREAM_LABEL = "_continuous_stream_label"; +const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG = "need_stream_cycle_event"; +const std::string ATTR_NAME_RTSWITCH_RECV_EVENT_ID = "rtswitch_event_id"; +const std::string ATTR_NAME_AUTOMIC_ADD_START = "automic_add_addr_start"; +const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE = "automic_add_mem_size"; // To be deleted const std::string ATTR_TO_BE_DELETED = "to_be_deleted"; @@ -138,15 +158,13 @@ const std::string SSD_MBOX_OCR_FUSION = 
"permute_flatten_ocr_fusion"; const std::string SSD_MBOX_FUSION_BOX_TYPE_NUM = "ssd_mbox_fusion_box_type_num"; const std::string SSD_RESHAPE_SLICE_CONCAT_FUSION = "reshape_slice_concat_fusion"; -const std::string SSD_PRIORBOX_CONCAT = "ssd_mbox_conf_priorbox_concat_flag"; - // Refinedet const std::string REFINEDET_MBOX_LOC_FUSION = "permute_flatten_fusion"; -const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION = "reshape_slice_concat_fusion"; + const std::string REFINEDET_MBOX_CONF_FUSION = "permute_flatten_reshape_flatten_fusion"; const std::string REFINEDET_MBOX_FUSION_BOX_TYPE_NUM = "ssd_mbox_fusion_box_type_num"; -const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE = "variance"; -const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM = "variance_num"; +const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION = "reshape_slice_concat_fusion"; +const std::string SSD_PRIORBOX_CONCAT = "ssd_mbox_conf_priorbox_concat_flag"; // _Arg const std::string ATTR_NAME_INDEX = "index"; @@ -236,6 +254,30 @@ const std::string BATCHNORM_ATTR_ESTIMATED_MEAN = "estimated_mean"; const std::string BATCHNORM_ATTR_ESTIMATED_VARIANCE = "estimated_variance"; const std::string BATCHNORM_ATTR_SCALE = "scale"; const std::string BATCHNORM_ATTR_BIAS = "bias"; +const std::string BATCHNORM_ATTR_DATA_FORMAT = "data_format"; +const std::string BATCHNORM_ATTR_IS_TRAINING = "is_training"; +const std::string BATCHNORM_ATTR_IS_TRAINING_FUSION = "is_training_fusion"; + +// huberloss +const std::string HUBER_LOSS_ATTR_DELTA = "delta"; + +// SSDRealDivTileMul +const std::string SSD_REAL_DIV_TILE_MUL_ATTR_TILE_PARA = "tilepara"; + +// SSDSumMulRealDivMean +const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_REDUCTION_INDICES = "reduction_indices"; +const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_AXIS = "axis"; +const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_MEAN_PARA = "mean_para"; +const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_HAS_SUM = "has_sum"; + +// ConcatFive2Four +// ConcatFour2Five +const 
std::string SSD_BOX_TYPE_NUM = "box_type_num"; +const std::string SSD_CLASS_NUM = "class_num"; +const std::string TRANS_FOR_LOSS_MODE = "trans_for_loss_mode"; +const std::string SSD_FEATURE_MAP_SIZE = "feature_map_size"; +const std::string SSD_FEATURE_MAP_HIGH = "feature_map_high"; +const std::string SSD_FEATURE_MAP_WIDTH = "feature_map_width"; // Scale const std::string SCALE_ATTR_SCALE = "scale"; @@ -340,6 +382,7 @@ const std::string SOFTMAX_ATTR_AXIS = "axis"; // Permute const std::string PERMUTE_ATTR_ORDER = "order"; +const std::string PERMUTE_ATTR_PERM = "perm"; // SSD Normalize const std::string SSDNORMALIZE_ATTR_ACCROSS_SPATIAL = "across_spatial"; @@ -367,6 +410,10 @@ const std::string SSD_PRIOR_BOX_ATTR_ASPECT_RATIO_NUM = "aspect_ratio_num"; const std::string SSD_PRIOR_BOX_ATTR_VARIANCE = "variance"; const std::string SSD_PRIOR_BOX_ATTR_VARIANCE_NUM = "variance_num"; +// RefinedetDetectionOutput +const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM = "variance_num"; +const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE = "variance"; + // PRelu const std::string PRELU_ATTR_CHANNEL_SHARED = "channel_shared"; @@ -380,11 +427,16 @@ const std::string POWER_ATTR_NAME_POWER = "power"; const std::string POWER_ATTR_NAME_SCALE = "scale"; const std::string POWER_ATTR_NAME_SHIFT = "shift"; +// log +const std::string LOG_ATTR_NAME_SCALE = "scale"; +const std::string LOG_ATTR_NAME_SHIFT = "shift"; +const std::string LOG_ATTR_NAME_BASE = "base"; // Pack const std::string PACK_ATTR_NAME_NUM = "N"; // Unpack const std::string UNPACK_ATTR_NAME_NUM = "num"; +const std::string DYNAMIC_STITCH_ATTR_NAME_NUM = "DynamicStitchN_"; // Gathernd const std::string GATHERND_ATTR_NAME_TINDICES = "Tindices"; const std::string GATHERND_ATTR_NAME_TPARAMS = "Tparams"; @@ -394,6 +446,13 @@ const std::string ARGMAX_ATTR_NAME_TOPK = "topk"; const std::string ARGMAX_ATTR_NAME_REDUCESIZE = "reduce_size"; const std::string ARGMAX_ATTR_NAME_REDUCESTRIDE = "reduce_stride"; const std::string 
ARGMAX_ATTR_NAME_OUTMAX = "outmaxval"; +const std::string ARGMAX_ATTR_NAME_AXIS = "axis"; +const std::string ARGMAX_ATTR_NAME_AXISTYPE = "axis_type"; +const std::string ARGMAX_ATTR_NAME_KEEPDIMS = "keep_dims"; + +// upsample +const std::string UPSAMPLE_ATTR_NAME_SCALE_H = "scale_h"; +const std::string UPSAMPLE_ATTR_NAME_SCALE_W = "scale_w"; // Relu const std::string ATTR_NAME_NEGATIVE_SLOPE = "negative_slope"; @@ -531,19 +590,41 @@ const std::string CONV_GRAD_FILTER_OUTPUT_SHAPE = "conv_grad_filter_output_shape const std::string CONV_GRAD_INPUT_OUTPUT_SHAPE = "conv_grad_input_output_shape"; // Rnn -const std::string RNN_MODE_ = "rnn_"; -const std::string CNN_RNN = "cnn_rnn"; +const std::string RNN_MODE_STATIC = "rnn_static"; const std::string MUTI_RNN = "multi_rnn"; +const std::string CNN_RNN = "cnn_rnn"; +const std::string RNN_MODE_ = "rnn_"; + const std::string CELL_MODE = "mode"; const std::string LSTM_CELL = "lstm_cell"; const std::string GRU_CELL = "gru_cell"; const std::string RNN_HT = "ht"; const std::string RNN_XT_HT = "xt_ht"; const std::string RNN_BATCH_SIZE = "batch_size"; +const std::string LSTM_CELL_CLIP = "lstm_cell_clip"; +const std::string LSTM_PROJ_CLIP = "lstm_proj_clip"; +const std::string LSTM_ACTIVATE = "lstm_activate"; +const std::string LSTM_OUT_MAP = "lstm_out_map"; +const std::string LSTM_OUT_MODE = "lstm_out_mode"; +const std::string LSTM_STATE_OUT_MODE = "lstm_state_out_mode"; +const std::string LSTM_TIME_MAJOR = "lstm_time_major"; +const std::string LSTM_IS_INPUT_PRE_PROCESS = "lstm_is_input_pre_process"; // Upsample const std::string UPSAMPLE_ATTR_NAME_SCALE = "scale"; +// PadV2 +const std::string PADV2_ATTR_NAME_MODE = "mode"; +const std::string PADV2_ATTR_NAME_PADS = "paddings"; +const std::string PADV2_ATTR_NAME_T = "T"; +const std::string PADV2_ATTR_NAME_PAD_FORMAT = "pad_format"; +const std::string PADV2_ATTR_NAME_CONST_VALUE = "const_value"; + +// MirrorPad +const std::string MIRRORPAD_ATTR_NAME_MODE = "mode"; +const std::string 
MIRRORPAD_ATTR_NAME_PADS = "paddings"; +const std::string MIRRORPAD_ATTR_NAME_PAD_FORMAT = "pad_format"; +const std::string MIRRORPAD_ATTR_NAME_CONST_VALUE = "const_value"; // Filler const std::string FILLER_TYPE = "filler_type"; const std::string FILLER_VALUE = "filler_value"; @@ -554,9 +635,6 @@ const std::string SHUFFLE_CHANNEL_GROUP = "group"; // TopKV2 const std::string TOPKV2_ATTR_K = "k"; -const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE = "block_size"; -const std::string L2_NORMALIZE_ATTR_EPS = "eps"; - // Calibaration const std::string STRIDE_H_INDEX = "STRIDE_H_INDEX"; const std::string STRIDE_W_INDEX = "STRIDE_W_INDEX"; @@ -611,10 +689,14 @@ const std::string ATTR_MODEL_STREAM_NUM = "stream_num"; const std::string ATTR_MODEL_EVENT_NUM = "event_num"; +const std::string ATTR_MODEL_HUGE_STREAM_LIST = "huge_stream_list"; + const std::string ATTR_MODEL_LABEL_NUM = "label_num"; const std::string ATTR_MODEL_MEMORY_SIZE = "memory_size"; +const std::string ATTR_MODEL_ZERO_COPY_MEMORY_SIZE = "zero_copy_memory_size"; + const std::string ATTR_MODEL_WEIGHT_SIZE = "weight_size"; const std::string ATTR_MODEL_TASK_GEN_BASE_ADDR = "task_gen_base_addr"; @@ -660,8 +742,125 @@ const std::string TARGET_TYPE_TINY = "TINY"; const std::string TARGET_TYPE_LITE = "LITE"; +// l2_normalize +const std::string L2_NORMALIZE_ATTR_AXIS = "axis"; +const std::string L2_NORMALIZE_ATTR_EPS = "eps"; + +const std::string POOL_PARAMA_ATTR_WINDOW = "window"; +const std::string POOL_PARAMA_ATTR_CEIL_MODE = "ceil_mode"; +const std::string POOL_PARAMA_ATTR_DATA_MODE = "data_mode"; +const std::string POOL_PARAMA_ATTR_GLOBAL_POOLING = "global_pooling"; +const std::string POOL_PARAMA_ATTR_NAN_OP = "nan_opt"; +const std::string POOL_PARAMA_ATTR_PAD_MOD = "pad_mode"; + +// HCOM +const std::string HCOM_ATTR_ROOT_RANK = "root_rank"; +const std::string HCOM_ATTR_RANK_SIZE = "rank_size"; + +const std::string HCOM_ATTR_REDUCE_TYPE = "reduction"; +const std::string HCOM_ATTR_GROUP = "group"; +const std::string 
HCOM_ATTR_SR_TAG = "sr_tag"; +const std::string HCOM_ATTR_SRC_RANK = "src_rank"; +const std::string HCOM_ATTR_DEST_RANK = "dest_rank"; +const std::string HCOM_ATTR_FUSION = "fusion"; +const std::string HCOM_ATTR_SHAPE = "shape"; +const std::string HCOM_ATTR_DATA_TYPE = "dtype"; + +// SpaceToDepth/DepthToSpace +const std::string ATTR_NAME_BLOCK_SIZE = "block_size"; + +// SparseSoftmaxCrossEntropyWithLogits +const std::string SPARSE_SOFT_MAX_ATTR_TLABLES = "Tlabels"; + +// MaxPoolGradWithArgmax +const std::string MAX_POOL_GRAD_OUTPUT_SHAPE = "max_pool_grad_output_shape"; + +// AvgPoolGrad +const std::string AVG_POOL_GRAD_OUTPUT_SHAPE = "avg_pool_grad_output_shape"; + +// Pad +const std::string ATTR_PAD_FORMAT = "attr_pad_format"; + +// Varible +const std::string VAR_ATTR_FORMAT = "_var_format"; +const std::string VAR_ATTR_NAME = "var_name"; +const std::string VAR_ATTR_FRACTALZ_FORMAT = "FZ"; +const std::string VAR_ATTR_4D_FORMAT = "4D"; +const std::string VAR_ATTR_5D_FORMAT = "5D"; +const std::string VAR_ATTR_DATA_TYPE = "data_format"; +const std::string VAR_ATTR_VAR_IN_NAME = "var_in_name"; +const std::string VAR_ATTR_VAR_IN_INDEX = "var_in_index"; +const std::string VAR_ATTR_VAR_OUT_INDEX = "var_out_index"; +const std::string VAR_ATTR_SHAPE = "shape"; +const std::string HALF_VAR_NAME_END = "_fp16"; +const std::string VAR_ATTR_INITED = "var_is_inited"; + +const std::string VAR_ATTR_CONTAINER = "container"; +const std::string VAR_ATTR_SHARED_NAME = "shared_name"; +const std::string VAR_ATTR_DTYPE = "dtype"; + +const std::string VAR_ATTR_SRC_VAR_NAME = "_src_var_name"; +const std::string VAR_ATTR_VAR_IS_SAVE = "_var_is_save"; +const std::string VAR_ATTR_VAR_IS_RESTORE = "_var_is_restore"; +const std::string VAR_ATTR_VAR_IS_BROADCAST = "_var_is_broadcast"; +const std::string REF_VAR_SRC_VAR_NAME = "ref_var_src_var_name"; +const std::string REF_VAR_PRE_PEER_OUT_INDEX = "ref_var_pre_peer_out_index"; + +// Assign +const std::string ASSIGN_VALIDATE_SHAPE = 
"validate_shape"; + +// space2bacth batch2space +const std::string BATCH_SPACE_ATTR_BLOCK = "block"; +const std::string BATCH_SPACE_ATTR_PADDING = "padding"; + +// depth_to_space space_to_depth +const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE = "block_size"; + +// FakeQuantWithMinMaxVars +const std::string FakeQuantWithMinMaxVars_ATTR_MAX = "max"; +const std::string FakeQuantWithMinMaxVars_ATTR_MIN = "min"; + +// mobilenet_ssd_conv_fusion +const std::string SSD_BOXPREDICTOR_BOXES_FUSION = "ssd_boxpredictor_boxes_fusion"; +const std::string SSD_BOXPREDICTOR_SCORES_FUSION = "ssd_boxpredictor_scores_fusion"; +const std::string SSD_BOXPREDICTOR_FUSION_BOX_TYPE_NUM = "ssd_boxpredictor_fusion_box_type_num"; + +// lsh project +const std::string LSH_PROJ_TYPE = "lsh_project_type"; + +// log time stamp +const std::string LOG_TIME_STAMP_LOGID = "logid"; +const std::string LOG_TIME_STAMP_NOTIFY = "notify"; + +// ShapeN +const std::string SHAPEN_ATTR_N = "N"; +const std::string SHAPEN_ATTR_IN_TYPE = "in_type"; +const std::string SHAPEN_ATTR_OUT_TYPE = "dtype"; + +// GatherV2 attr def +const std::string GATHERV2_ATTR_NAME_TAXIS = "Taxis"; +const std::string GATHERV2_ATTR_NAME_TINDICES = "Tindices"; +const std::string GATHERV2_ATTR_NAME_TPARAMS = "Tparams"; + +// Reshape attr def +const std::string RESHAPE_ATTR_NAME_INPUT_DESC = "input_desc_reshape"; +const std::string RESHAPE_ATTR_NAME_OUTPUT_DESC = "output_desc_reshape"; + +// axis attr def +const std::string ATTR_NAME_AXIS_ORG_OP = "axis_org_op"; + +const std::string ATTR_NAME_LINK_WITH_SPARE = "link_with_sparse"; + +const std::string ATTR_NAME_NET_OUTPUT_FORMAT = "net_output_format"; +const std::string ATTR_NAME_NET_OUTPUT_DATATYPE = "net_output_datatype"; + +// For constant folding +const std::string ATTR_NO_NEED_CONSTANT_FOLDING = "no_need_constant_folding"; + const std::string ATTR_NAME_CONTINUOUS_INPUT = "continuous_input"; +const std::string ATTR_NAME_CONTINUOUS_INPUT_ALLOC = "continuous_input_alloc"; + const std::string 
ATTR_NAME_CONTINUOUS_OUTPUT = "continuous_output"; const std::string ATTR_NAME_REFERENCE = "reference"; @@ -694,6 +893,8 @@ const std::string ATTR_NAME_STREAM_SWITCH_COND = "switch_condition"; const std::string ATTR_NAME_TRUE_BRANCH_STREAM = "true_branch_stream"; const std::string ATTR_NAME_ACTIVE_STREAM_LIST = "active_stream_list"; const std::string ATTR_NAME_SWITCHN_PRED_VALUE = "switch_pred_value"; +const std::string ATTR_NAME_ITERATORS_PER_LOOP = "iterations_per_loop"; +const std::string ATTR_NAME_FLOW_CTRL_NODE_FLAG = "is_flow_ctrl_node"; const std::string ATTR_NAME_SWITCH_BRANCH_NODE_LABEL = "_switch_branch_node_label"; const std::string ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG = "_switch_true_branch_flag"; @@ -705,6 +906,7 @@ const std::string ATTR_NAME_NEXT_ITERATION = "_next_iteration_node"; // Function Op const std::string ATTR_NAME_PARENT_NODE_INDEX = "_parent_node_index"; +const std::string ATTR_NAME_PARENT_CONST_TYPE = "_parent_const_type"; // Used for mark the active node is for loop, type:bool const std::string ATTR_NAME_IS_LOOP_ACTIVE = "is_loop_active"; @@ -719,6 +921,7 @@ const std::string MODEL_ATTR_SESSION_ID = "session_id"; // l1 fusion and other fusion in future const std::string ATTR_NAME_L1_FUSION_GROUP_ID = "_l1_fusion_group_id"; +const std::string ATTR_NAME_FUSION_GROUP_KEY = "_fusion_group_key"; const std::string ATTR_NAME_L1_FUSION_GROUP_KEY = "_l1_fusion_group_key"; const std::string ATTR_NAME_FUSION_VIRTUAL_OP = "_fusion_virtual_op"; const std::string ATTR_NAME_FUSION_GROUP_TYPE = "_fusion_group_type"; @@ -730,6 +933,9 @@ const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_L1_FUSION = "_output_offset_for_l1 const std::string ATTR_NAME_SWITCH_FOR_L1_FUSION = "_enable_l1_fusion"; const std::string ATTR_N_BATCH_SPILT = "_is_n_batch_split"; const std::string ATTR_NO_TASK_AND_DUMP_NEEDED = "_no_task_and_dump_needed"; +const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_BUFFER_FUSION = "_output_offset_for_buffer_fusion"; +const std::string 
ATTR_NAME_L2_FUSION_GROUP_ID = "_l2_fusion_group_id"; +const std::string ATTR_NAME_SWITCH_FOR_L2_FUSION = "_enable_l2_fusion"; // Atomic addr clean attrs const std::string ATOMIC_ATTR_INPUT_INDEX = "atomic_input_index"; @@ -748,6 +954,8 @@ const std::string ATTR_NEED_COMPILE = "_node_need_compile"; const std::string ATTR_INSERT_BY_MBATCH = "mbatch-inserted-node"; +const std::string ATTR_MBATCH_ORIGIN_INPUT_DIMS = "_mbatch_origin_input_dims"; + // For inserted op const std::string ATTR_INSERTED_BY_GE = "_inserted_by_ge"; @@ -764,7 +972,22 @@ const std::string ATTR_NAME_DATA_DUMP_ORIGIN_OUTPUT_INDEX = "_datadump_origin_ou const std::string ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT = "_datadump_origin_format"; const std::string ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE = "_datadump_origin_data_type"; +// functional ops attr +const std::string ATTR_NAME_WHILE_COND = "cond"; +const std::string ATTR_NAME_WHILE_BODY = "body"; + // used for label switch const std::string ATTR_NAME_LABEL_SWITCH_INDEX = "_label_switch_index"; const std::string ATTR_NAME_LABEL_SWITCH_LIST = "_label_switch_list"; + +const std::string ATTR_NAME_INPUT_DATATYPE = "input_datatype"; +const std::string ATTR_NAME_OUTPUT_DATATYPE = "output_datatype"; + +// used for LX tiling +const std::string ATTR_NAME_OP_L1_SPACE = "_l1_space"; +const std::string ATTR_NAME_FUSION_TYPE_LIST = "_fusion_type_list"; +const std::string ATTR_NAME_VALID_INPUT_SHAPE_LIST_LIST = "_valid_input_shape_list_list"; +const std::string ATTR_NAME_VALID_OUTPUT_SHAPE_LIST_LIST = "_valid_output_shape_list_list"; +const std::string ATTR_NAME_SLICE_INPUT_OFFSET_LIST_LIST = "_input_offset_list_list"; +const std::string ATTR_NAME_SLICE_OUTPUT_OFFSET_LIST_LIST = "_input_offset_list_list"; } // namespace ge diff --git a/src/common/graph/ge_attr_value.cc b/src/common/graph/ge_attr_value.cc index 0a2893a4..004d0227 100644 --- a/src/common/graph/ge_attr_value.cc +++ b/src/common/graph/ge_attr_value.cc @@ -31,19 +31,18 @@ using std::string; using 
std::vector; namespace ge { -GeAttrValue::NamedAttrs::NamedAttrs() { named_attrs_.InitDefault(); } +NamedAttrs::NamedAttrs() { named_attrs_.InitDefault(); } -GeAttrValue::NamedAttrs::NamedAttrs(const ProtoMsgOwner &owner, proto::NamedAttrs *proto_msg) - : named_attrs_(owner, proto_msg) {} +NamedAttrs::NamedAttrs(const ProtoMsgOwner &owner, proto::NamedAttrs *proto_msg) : named_attrs_(owner, proto_msg) {} -void GeAttrValue::NamedAttrs::SetName(const std::string &name) { +void NamedAttrs::SetName(const std::string &name) { auto proto_msg = named_attrs_.GetProtoMsg(); if (proto_msg != nullptr) { proto_msg->set_name(name); } } -string GeAttrValue::NamedAttrs::GetName() const { +string NamedAttrs::GetName() const { auto proto_msg = named_attrs_.GetProtoMsg(); if (proto_msg != nullptr) { return proto_msg->name(); @@ -51,13 +50,13 @@ string GeAttrValue::NamedAttrs::GetName() const { return string(); } -GeAttrValue GeAttrValue::NamedAttrs::GetItem(const string &key) const { +GeAttrValue NamedAttrs::GetItem(const string &key) const { GeAttrValue value; - GetAttr(key, value); + (void)GetAttr(key, value); return value; } -ProtoAttrMapHelper GeAttrValue::NamedAttrs::MutableAttrMap() { +ProtoAttrMapHelper NamedAttrs::MutableAttrMap() { auto proto_msg = named_attrs_.GetProtoMsg(); if (proto_msg != nullptr) { return ProtoAttrMapHelper(named_attrs_.GetProtoOwner(), proto_msg->mutable_attr()); @@ -65,7 +64,7 @@ ProtoAttrMapHelper GeAttrValue::NamedAttrs::MutableAttrMap() { return ProtoAttrMapHelper(named_attrs_.GetProtoOwner(), nullptr); } -ConstProtoAttrMapHelper GeAttrValue::NamedAttrs::GetAttrMap() const { +ConstProtoAttrMapHelper NamedAttrs::GetAttrMap() const { auto proto_msg = named_attrs_.GetProtoMsg(); if (proto_msg != nullptr) { return ConstProtoAttrMapHelper(named_attrs_.GetProtoOwner(), &proto_msg->attr()); @@ -515,7 +514,7 @@ bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector &value) { +bool GeAttrValueImp::SetValue(proto::AttrDef 
&proto_attr_val, const vector &value) { if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_NAMED_ATTRS)) { return false; @@ -739,7 +738,7 @@ bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoM } bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, - GeAttrValue::NamedAttrs &value) { + GeAttrValue::NAMED_ATTRS &value) { if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kFunc)) { return false; } @@ -752,7 +751,7 @@ bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoM } bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, - vector &value) { + vector &value) { value.clear(); if (!AttrUtilsHelper::GetValueCheckListType( proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_NAMED_ATTRS, ListValueItemCheck(na))) { @@ -760,7 +759,7 @@ bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoM } auto &list = proto_attr_val.list(); for (const auto &item : list.na()) { - value.emplace_back(GeAttrValue::NamedAttrs()); + value.emplace_back(GeAttrValue::NAMED_ATTRS()); if (value.empty()) { return false; } @@ -967,7 +966,7 @@ ATTR_UTILS_SET_GET_IMP(TensorDesc, GeTensorDesc) ATTR_UTILS_SET_IMP(Tensor, GeTensorPtr) ATTR_UTILS_SET_IMP(Tensor, ConstGeTensorPtr) ATTR_UTILS_SET_IMP(Tensor, GeTensor) -ATTR_UTILS_SET_GET_IMP(NamedAttrs, GeAttrValue::NamedAttrs) +ATTR_UTILS_SET_GET_IMP(NamedAttrs, GeAttrValue::NAMED_ATTRS) ATTR_UTILS_SET_GET_IMP(Bytes, Buffer) ATTR_UTILS_SET_GET_IMP(Graph, ComputeGraphPtr) ATTR_UTILS_SET_GET_IMP(ListListInt, vector>) @@ -982,7 +981,7 @@ ATTR_UTILS_SET_GET_IMP(ListTensorDesc, vector) ATTR_UTILS_SET_IMP(ListTensor, vector) ATTR_UTILS_SET_IMP(ListTensor, vector) ATTR_UTILS_SET_IMP(ListTensor, vector) -ATTR_UTILS_SET_GET_IMP(ListNamedAttrs, vector) +ATTR_UTILS_SET_GET_IMP(ListNamedAttrs, vector) 
ATTR_UTILS_SET_GET_IMP(ListBytes, vector) ATTR_UTILS_SET_GET_IMP(ListGraph, vector) ATTR_UTILS_SET_GET_IMP(ListDataType, vector) diff --git a/src/common/graph/ge_tensor.cc b/src/common/graph/ge_tensor.cc index ccf757fa..5d7b6a2e 100644 --- a/src/common/graph/ge_tensor.cc +++ b/src/common/graph/ge_tensor.cc @@ -83,6 +83,12 @@ size_t GeShape::GetDimNum() const { auto proto_msg = shape_def_.GetProtoMsg(); if (proto_msg != nullptr) { if (proto_msg->dim_size() >= 0) { + // check whether contain -2, if true, return -1 + for (auto i : proto_msg->dim()) { + if (i == UNKNOWN_DIM_NUM) { + return 0; + } + } return proto_msg->dim_size(); } else { return 0; @@ -157,6 +163,10 @@ int64_t GeShape::GetShapeSize() const { return 0; } for (auto i : proto_msg->dim()) { + // if unknown shape, return -1 + if (i == UNKNOWN_DIM || i == UNKNOWN_DIM_NUM) { + return UNKNOWN_DIM; + } res *= i; } } @@ -209,6 +219,7 @@ const string TENSOR_UTILS_RC = "rc"; const string TENSOR_UTILS_ORIGIN_SHAPE = "origin_shape"; const string TENSOR_UTILS_ORIGIN_FORMAT = "origin_format"; const string TENSOR_UTILS_ORIGIN_DATA_TYPE = "origin_data_type"; +const string TENSOR_UTILS_SHAPE_RANGE = "shape_range"; GeShape::GeShape(const ProtoMsgOwner &proto_owner, proto::ShapeDef *proto_msg) : shape_def_(proto_owner, proto_msg) {} @@ -396,6 +407,35 @@ GeShape &GeTensorDesc::MutableShape() { return ShapeReference(); } void GeTensorDesc::SetShape(GeShape shape) { ShapeReference() = std::move(shape); } +// set shape with -2, it stand for unknown shape +void GeTensorDesc::SetUnknownDimNumShape() { SetShape(GeShape({UNKNOWN_DIM_NUM})); } + +// for unknown shape +graphStatus GeTensorDesc::SetShapeRange(const std::vector> &range) { + std::vector> shape_range; + for (const auto &ele : range) { + shape_range.emplace_back(std::vector({ele.first, ele.second})); + } + auto ret = AttrUtils::SetListListInt(this, TENSOR_UTILS_SHAPE_RANGE, shape_range); + return ret ? 
GRAPH_SUCCESS : GRAPH_FAILED; +} +graphStatus GeTensorDesc::GetShapeRange(std::vector> &range) const { + std::vector> shape_range; + (void)AttrUtils::GetListListInt(this, TENSOR_UTILS_SHAPE_RANGE, shape_range); + + for (const auto &ele : shape_range) { + // here must be only two elemenet because pair + if (ele.size() != 2) { + GELOGE(GRAPH_FAILED, "shape_range must contain only 2 value but really is %lu", ele.size()); + return GRAPH_FAILED; + } + std::pair pair({ele[0], ele[1]}); + range.push_back(pair); + } + + return GRAPH_SUCCESS; +} + GeShape GeTensorDesc::GetOriginShape() const { vector origin_shape; if (!AttrUtils::GetListInt(this, TENSOR_UTILS_ORIGIN_SHAPE, origin_shape)) { diff --git a/src/common/graph/graph.cc b/src/common/graph/graph.cc index 4d7c2a3b..09d4fd56 100644 --- a/src/common/graph/graph.cc +++ b/src/common/graph/graph.cc @@ -16,11 +16,12 @@ #include "external/graph/graph.h" #include "debug/ge_util.h" -#include "external/graph/operator.h" #include "framework/common/debug/ge_log.h" -#include "graph/ge_attr_value.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/debug/ge_op_types.h" #include "graph/model.h" #include "graph/utils/graph_utils.h" +#include "graph/utils/op_desc_utils.h" using std::map; using std::pair; @@ -214,6 +215,23 @@ class GraphImpl { return GRAPH_SUCCESS; } + graphStatus FindOpByType(const string &type, std::vector &ops) const { + for (auto &op : op_list_) { + auto op_type = op.second.GetOpType(); + if (op_type == type) { + ops.push_back(op.second); + continue; + } + if (op_type == ge::FRAMEWORKOP) { + op.second.GetAttr(ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, op_type); + if (op_type == type) { + ops.push_back(op.second); + } + } + } + return GRAPH_SUCCESS; + } + void SetNeedIteration(bool need_iteration) { if (compute_graph_ == nullptr) { GELOGE(GRAPH_FAILED, "Set need iteration failed, as compute graph is null."); @@ -222,6 +240,8 @@ class GraphImpl { compute_graph_->SetNeedIteration(need_iteration); } + const 
std::string &GetName() const { return name_; } + private: std::string name_; std::string output_name_; @@ -255,6 +275,11 @@ graphStatus Graph::FindOpByName(const std::string &name, Operator &op) const { return impl_->FindOpByName(name, op); } +graphStatus Graph::FindOpByType(const string &type, std::vector &ops) const { + GE_CHECK_NOTNULL(impl_); + return impl_->FindOpByType(type, ops); +} + Graph &Graph::SetInputs(const vector &inputs) { GE_CHK_BOOL_EXEC(impl_ != nullptr, return *this, "SetInputs failed: graph can not be used, impl is nullptr.") GE_CHK_BOOL_EXEC(inputs.size() > 0, return *this, "SetInputs failed: input operator size can not be 0."); @@ -331,6 +356,8 @@ graphStatus Graph::LoadFromFile(const string &file_name) { return GRAPH_SUCCESS; } +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::string &Graph::GetName() const { return impl_->GetName(); } + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Graph GraphUtils::CreateGraphFromComputeGraph(const ge::ComputeGraphPtr compute_graph) { GE_CHK_BOOL_EXEC_NOLOG(compute_graph != nullptr, return Graph("")); @@ -343,4 +370,15 @@ GraphUtils::CreateGraphFromComputeGraph(const ge::ComputeGraphPtr compute_graph) return graph; } + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::RecoverGraphOperators(const Graph &graph) { + GE_CHECK_NOTNULL(graph.impl_); + GE_CHECK_NOTNULL(graph.impl_->compute_graph_); + + graph.impl_->op_list_.clear(); + for (const auto &node : graph.impl_->compute_graph_->GetDirectNode()) { + graph.impl_->op_list_[node->GetName()] = OpDescUtils::CreateOperatorFromNode(node); + } + return SUCCESS; +} } // namespace ge diff --git a/src/common/graph/model_serialize.cc b/src/common/graph/model_serialize.cc index 0ec4a2eb..a3b7a936 100644 --- a/src/common/graph/model_serialize.cc +++ b/src/common/graph/model_serialize.cc @@ -16,7 +16,10 @@ #include "graph/model_serialize.h" #include + +#include #include + #include "debug/ge_attr_define.h" #include "debug/ge_log.h" 
#include "debug/ge_util.h" @@ -26,6 +29,7 @@ #include "utils/graph_utils.h" #include "debug/ge_op_types.h" +using std::map; using std::string; namespace ge { @@ -121,6 +125,11 @@ bool ModelSerializeImp::SerializeOpDesc(const ConstOpDescPtr &op_desc, proto::Op } } } + + op_def_proto->set_id(op_desc->GetId()); + for (const std::string &name : op_desc->GetSubgraphInstanceNames()) { + op_def_proto->add_subgraph_name(name); + } } return true; } @@ -196,6 +205,14 @@ bool ModelSerializeImp::SerializeModel(const Model &model, proto::ModelDef *mode GELOGE(GRAPH_FAILED, "SerializeGraph fail"); return false; } + + for (auto subgraph : compute_graph->GetAllSubgraphs()) { + if (!SerializeGraph(subgraph, model_proto->add_graph(), is_dump)) { + GELOGE(GRAPH_FAILED, "Serialize subgraph failed"); + return false; + } + } + return true; } @@ -228,6 +245,14 @@ bool ModelSerializeImp::UnserializeOpDesc(OpDescPtr &op_desc, proto::OpDef &op_d GE_CHK_BOOL_RET_STATUS(temp_value != nullptr, false, "temp_value is nullptr"); op_desc->outputs_desc_.push_back(temp_value); } + + op_desc->SetId(op_def_proto.id()); + uint32_t graph_index = 0; + for (const std::string &name : op_def_proto.subgraph_name()) { + op_desc->AddSubgraphName(name); + op_desc->SetSubgraphInstanceName(graph_index++, name); + } + return true; } @@ -238,7 +263,7 @@ bool ModelSerializeImp::UnserializeNode(ComputeGraphPtr &graph, proto::OpDef &op GELOGW("UnserializeOpDesc error."); } - NodePtr node = graph->AddNode(op_desc); + NodePtr node = graph->AddNode(op_desc, op_desc->GetId()); GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is nullptr."); // Inputs @@ -319,6 +344,35 @@ bool ModelSerializeImp::HandleNodeNameRef() { return true; } +bool ModelSerializeImp::RebuildOwnership(ComputeGraphPtr &compute_graph, map &subgraphs) { + std::queue all_graphs; + all_graphs.emplace(compute_graph); + while (!all_graphs.empty()) { + ComputeGraphPtr graph = all_graphs.front(); + all_graphs.pop(); + + for (const NodePtr &node : 
graph->GetDirectNode()) { + const OpDescPtr op_desc = node->GetOpDesc(); + for (const std::string &name : op_desc->GetSubgraphInstanceNames()) { + auto it = subgraphs.find(name); + if (it == subgraphs.end()) { + GELOGE(GRAPH_FAILED, "Node:%s, Subgraph:%s not found, num:%zu.", op_desc->GetName().c_str(), name.c_str(), + subgraphs.size()); + return false; + } + + ComputeGraphPtr &subgraph = it->second; + subgraph->SetParentGraph(graph); + subgraph->SetParentNode(node); + compute_graph->AddSubgraph(subgraph->GetName(), subgraph); + all_graphs.emplace(subgraph); + } + } + } + + return true; +} + bool ModelSerializeImp::UnserializeModel(Model &model, proto::ModelDef &model_proto) { model.name_ = model_proto.name(); model.version_ = model_proto.version(); @@ -332,7 +386,31 @@ bool ModelSerializeImp::UnserializeModel(Model &model, proto::ModelDef &model_pr if (UnserializeGraphWithoutEdge(compute_graph_ptr, graph_proto)) { model.graph_ = GraphUtils::CreateGraphFromComputeGraph(compute_graph_ptr); } + + // 0 is main graph, following is subgraph. 
+ map subgraphs; + for (int idx = 1; idx < graphs_proto.size(); ++idx) { + ComputeGraphPtr subgraph; + ModelSerializeImp impl; + if (!impl.UnserializeGraphWithoutEdge(subgraph, graphs_proto[idx])) { + GELOGE(GRAPH_FAILED, "UnserializeGraphWithoutEdge failed"); + return false; + } + + if (!impl.HandleNodeNameRef()) { + GELOGE(GRAPH_FAILED, "HandleNodeNameRef failed"); + return false; + } + + subgraphs[subgraph->GetName()] = subgraph; + } + + if (!RebuildOwnership(compute_graph_ptr, subgraphs)) { + GELOGE(GRAPH_FAILED, "Rebuild graph ownership failed"); + return false; + } } + if (!HandleNodeNameRef()) { GELOGE(GRAPH_FAILED, "HandleNodeNameRef failed"); return false; diff --git a/src/common/graph/op_desc.cc b/src/common/graph/op_desc.cc index 620c815c..582cfa9a 100644 --- a/src/common/graph/op_desc.cc +++ b/src/common/graph/op_desc.cc @@ -61,6 +61,8 @@ const std::string ATTR_NAME_WORKSPACE_BYTES = "workspace_bytes"; const std::string ATTR_NAME_IS_INPUT_CONST = "is_input_const"; +const std::string ATTR_NAME_OP_INFER_DEPENDS = "_op_infer_depends"; + const std::string ATTR_NAME_OPT_INPUT = "_opt_input"; const std::string ATTR_NAME_INPUT_NAME_IDX_KEY = "_input_name_idx_key"; @@ -227,6 +229,40 @@ graphStatus OpDesc::AddInputDesc(const string &name, const ge::GeTensorDesc &inp } } +graphStatus OpDesc::AddInputDescMiddle(const string &name, const unsigned int num, size_t index) { + auto input_name_idx = GetAllInputName(); + for (unsigned int i = 0; i < num; i++) { + string input_name = name + std::to_string(i); + GE_CHK_BOOL_RET_STATUS((input_name_idx.find(input_name) == input_name_idx.end()), GRAPH_FAILED, + "Add input tensor_desc is existed. 
name[%s]", input_name.c_str()); + + std::shared_ptr in_desc = ComGraphMakeShared(GeTensorDesc()); + if (in_desc == nullptr) { + GELOGE(GRAPH_FAILED, "AddInputDescMiddle failed, malloc shared_ptr failed."); + return GRAPH_FAILED; + } + + if (index > inputs_desc_.size()) { + GELOGE(GRAPH_FAILED, "AddInputDescMiddle failed, insert index should not more than inputs size."); + return GRAPH_FAILED; + } + + (void)inputs_desc_.insert(inputs_desc_.begin() + index + i, in_desc); + + // Update index in input_name_idx + for (auto it = input_name_idx.begin(); it != input_name_idx.end(); ++it) { + if (it->second >= (index + i)) { + it->second += 1; + } + } + + (void)input_name_idx.insert(make_pair(input_name, i + index)); + } + SetAllInputName(input_name_idx); + + return GRAPH_SUCCESS; +} + graphStatus OpDesc::AddInputDescForward(const string &name, const unsigned int num) { auto input_name_idx = GetAllInputName(); for (unsigned int i = 0; i < num; i++) { @@ -239,7 +275,6 @@ graphStatus OpDesc::AddInputDescForward(const string &name, const unsigned int n GELOGE(GRAPH_FAILED, "AddInputDescForward failed, malloc shared_ptr failed."); return GRAPH_FAILED; } - (void)inputs_desc_.insert(inputs_desc_.begin(), in_desc); // Update index in input_name_idx @@ -634,6 +669,13 @@ graphStatus OpDesc::AddDynamicInputDesc(const string &name, const unsigned int n return GRAPH_SUCCESS; } +graphStatus OpDesc::AddDynamicInputDescByIndex(const string &name, const unsigned int num, size_t index) { + if (AddInputDescMiddle(name, num, index) != GRAPH_SUCCESS) { + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + graphStatus OpDesc::AddDynamicOutputDesc(const string &name, const unsigned int num, bool is_push_back) { if (is_push_back) { for (unsigned int i = 0; i < num; i++) { @@ -1054,6 +1096,19 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetDstName return dst_name; } +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetOpInferDepends(const vector &depend_names) { + 
auto ret = AttrUtils::SetListStr(this, ATTR_NAME_OP_INFER_DEPENDS, depend_names); + if (ret != true) { + GELOGE(GRAPH_FAILED, "set op_infer_depends fail."); + } +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetOpInferDepends() const { + vector depend_names; + (void)AttrUtils::GetListStr(this, ATTR_NAME_OP_INFER_DEPENDS, depend_names); + return depend_names; +} + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetDstIndex(const vector &dst_index) { auto proto_msg = op_def_.GetProtoMsg(); if (proto_msg != nullptr) { @@ -1199,20 +1254,17 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::vector &O return subgraph_instance_names_; } -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::AddSubgraphInstanceName(std::string name) { - subgraph_instance_names_.emplace_back(std::move(name)); -} - GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::RemoveSubgraphInstanceName(const std::string &name) { for (auto iter = subgraph_instance_names_.begin(); iter != subgraph_instance_names_.end(); ++iter) { if (*iter == name) { - subgraph_instance_names_.erase(iter); + *iter = ""; return; } } } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::AddSubgraphName(const std::string &name) { + GELOGI("Add subgraph name is %s", name.c_str()); auto iter = subgraph_names_to_index_.find(name); if (iter != subgraph_names_to_index_.end()) { GELOGW("The subgraph name %s exists, index %u", name.c_str(), iter->second); @@ -1220,6 +1272,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::AddSubgraphNa } auto size = subgraph_names_to_index_.size(); subgraph_names_to_index_[name] = size; + subgraph_instance_names_.resize(size + 1); return GRAPH_SUCCESS; } @@ -1227,4 +1280,34 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::map= subgraph_instance_names_.size()) { + GE_LOGE("The index %u exceeds the max instance coutn %zu", index, subgraph_instance_names_.size()); + return GRAPH_PARAM_INVALID; + 
} + subgraph_instance_names_[index] = name; + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::RegisterSubgraphIrName(const string &name, + SubgraphType type) { + subgraph_ir_names_to_type_[name] = type; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::map &OpDesc::GetSubgraphIrNames() + const { + return subgraph_ir_names_to_type_; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY SubgraphType +OpDesc::GetSubgraphTypeByIrName(const std::string &name) const { + auto iter = subgraph_ir_names_to_type_.find(name); + if (iter == subgraph_ir_names_to_type_.end()) { + return kSubgraphTypeEnd; + } + return iter->second; +} } // namespace ge diff --git a/src/common/graph/operator.cc b/src/common/graph/operator.cc index 6d372297..c4ff7ac5 100644 --- a/src/common/graph/operator.cc +++ b/src/common/graph/operator.cc @@ -15,6 +15,7 @@ */ #include "external/graph/operator.h" +#include "external/graph/operator_factory.h" #include #include #include @@ -38,6 +39,11 @@ #include "utils/tensor_adapter.h" #include "utils/tensor_utils.h" #include "utils/type_utils.h" +#include +#include +#include +#include +#include using std::enable_shared_from_this; using std::make_pair; @@ -343,15 +349,71 @@ class OperatorImpl : public std::enable_shared_from_this { InferenceContextPtr GetInferenceContext() const { return inference_context_; } + void SubgraphRegister(const std::string &name, bool dynamic) { + op_desc_->RegisterSubgraphIrName(name, dynamic ? 
kDynamic : kStatic); + } + + void SubgraphCountRegister(const std::string &name, uint32_t count) { + if (op_desc_->GetSubgraphTypeByIrName(name) == kStatic) { + op_desc_->AddSubgraphName(name); + } else { + for (uint32_t i = 0; i < count; ++i) { + op_desc_->AddSubgraphName(name + std::to_string(i)); + } + } + + subgraph_names_to_builders_[name].resize(count, nullptr); + } + + void SetSubgraphBuilder(const std::string &name, uint32_t index, const SubgraphBuilder &builder) { + auto iter = subgraph_names_to_builders_.find(name); + if (iter == subgraph_names_to_builders_.end()) { + GELOGE(PARAM_INVALID, "Failed to set subgraph builder for name %s index %u, invalid name", name.c_str(), index); + return; + } + if (iter->second.size() <= index) { + GELOGE(PARAM_INVALID, "Failed to set subgraph builder for name %s index %u, excceds the max size %zu", + name.c_str(), index, iter->second.size()); + return; + } + iter->second[index] = builder; + } + + SubgraphBuilder GetSubgraphBuilder(const std::string &name, uint32_t index) const { + auto iter = subgraph_names_to_builders_.find(name); + if (iter == subgraph_names_to_builders_.end()) { + GELOGE(PARAM_INVALID, "Failed to get subgraph builder for name %s index %u, invalid name", name.c_str(), index); + return nullptr; + } + if (iter->second.size() <= index) { + GELOGE(PARAM_INVALID, "Failed to get subgraph builder for name %s index %u, excceds the max size %zu", + name.c_str(), index, iter->second.size()); + return nullptr; + } + return iter->second[index]; + } + + std::vector GetSubgraphNames() const { + std::vector names; + for (const auto &subgraph_name_to_type : op_desc_->GetSubgraphIrNames()) { + names.emplace_back(subgraph_name_to_type.first); + } + return names; + } + + size_t GetSubgraphNamesCount() const { return op_desc_->GetSubgraphIrNames().size(); } + OpDescPtr op_desc_ = nullptr; private: ge::ConstNodePtr node_{nullptr}; ge::InferenceContextPtr inference_context_; + GraphBuilderCallback graph_builder_callback_; 
std::map> output_links_{}; std::map input_link_{}; std::vector> control_input_link_{}; std::vector> control_output_link_{}; + std::map> subgraph_names_to_builders_; }; // Used to manage OperatorImpl instances created by ge api. @@ -559,7 +621,6 @@ InferenceContextPtr Operator::GetInferenceContext() const { GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return nullptr, "operator impl is nullptr."); return operator_impl_->GetInferenceContext(); } - TensorDesc Operator::GetInputDesc(uint32_t index) const { GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr."); return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetInputDesc(index)); @@ -698,7 +759,7 @@ const std::map Operator::GetAllAttrNamesAndTypes() con void Operator::InputRegister(const string &name) { GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); - operator_impl_->GetOpDescImpl()->AddInputDesc(name, GeTensorDesc()); + (void)operator_impl_->GetOpDescImpl()->AddInputDesc(name, GeTensorDesc()); } void Operator::OptionalInputRegister(const string &name) { @@ -745,6 +806,12 @@ void Operator::DynamicInputRegister(const string &name, const unsigned int num, (void)operator_impl_->GetOpDescImpl()->AddDynamicInputDesc(name, num, is_push_back); } +void Operator::DynamicInputRegisterByIndex(const string &name, const unsigned int num, size_t index) { + GE_CHK_BOOL_EXEC(!!operator_impl_, return, "operator impl is nullptr."); + GE_CHK_BOOL_EXEC(nullptr != operator_impl_->GetOpDescImpl(), return, "GetOpDescImpl is nullptr."); + operator_impl_->GetOpDescImpl()->AddDynamicInputDescByIndex(name, num, index); +} + int Operator::GetDynamicInputNum(const string &name) const { GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return 0, "operator impl is nullptr."); GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return 0, "GetOpDescImpl is 
nullptr."); @@ -896,6 +963,11 @@ OP_ATTR_GET_IMP(string &, Str) OP_ATTR_SET_IMP(const vector &, ListStr) OP_ATTR_GET_IMP(vector &, ListStr) +OP_ATTR_SET_IMP(const GeAttrValue::NAMED_ATTRS &, NamedAttrs) +OP_ATTR_GET_IMP(GeAttrValue::NAMED_ATTRS &, NamedAttrs) +OP_ATTR_SET_IMP(const vector &, ListNamedAttrs) +OP_ATTR_GET_IMP(vector &, ListNamedAttrs) + OP_ATTR_REG_IMP(int64_t, Int) OP_ATTR_REG_IMP(const vector &, ListInt) OP_ATTR_REG_IMP(float, Float) @@ -905,6 +977,8 @@ OP_ATTR_REG_IMP(const vector &, ListStr) OP_ATTR_REG_IMP(bool, Bool) OP_ATTR_REG_IMP(const vector &, ListBool) OP_ATTR_REG_IMP(const vector> &, ListListInt) +OP_ATTR_REG_IMP(const GeAttrValue::NAMED_ATTRS &, NamedAttrs) +OP_ATTR_REG_IMP(const vector &, ListNamedAttrs) #undef OP_ATTR_SET_IMP #undef OP_ATTR_GET_IMP @@ -1114,6 +1188,95 @@ void Operator::AttrRegister(const string &name, const OpBytes &attr_value) { } } +void Operator::SubgraphRegister(const std::string &name, bool dynamic) { + if (operator_impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); + return; + } + operator_impl_->SubgraphRegister(name, dynamic ? 
kDynamic : kStatic); +} + +void Operator::SubgraphCountRegister(const std::string &name, uint32_t count) { + if (operator_impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); + return; + } + operator_impl_->SubgraphCountRegister(name, count); +} + +void Operator::SetSubgraphBuilder(const std::string &name, uint32_t index, const SubgraphBuilder &builder) { + if (operator_impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); + return; + } + operator_impl_->SetSubgraphBuilder(name, index, builder); +} + +std::vector Operator::GetSubgraphNames() const { return operator_impl_->GetSubgraphNames(); } + +SubgraphBuilder Operator::GetDynamicSubgraphBuilder(const string &name, uint32_t index) const { + if (operator_impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "operator impl is nullptr."); + return nullptr; + } + return operator_impl_->GetSubgraphBuilder(name, index); +} + +SubgraphBuilder Operator::GetSubgraphBuilder(const string &name) const { return GetDynamicSubgraphBuilder(name, 0); } + +Graph Operator::GetSubgraph(const string &name) const { + if (operator_impl_ == nullptr) { + GE_LOGE("Failed to get subgraph %s, the operator impl is null", name.c_str()); + return Graph(""); + } + auto op_desc = OpDescUtils::GetOpDescFromOperator(*this); + if (op_desc == nullptr) { + GE_LOGE("Failed to get subgraph %s, the op_desc is null", name.c_str()); + return Graph(""); + } + const auto &subgraph_names_to_index = op_desc->GetSubgraphNameIndexes(); + auto iter = subgraph_names_to_index.find(name); + if (iter == subgraph_names_to_index.end()) { + GE_LOGE("Failed to get subgraph %s, the name may be invalid", name.c_str()); + return Graph(""); + } + auto subgraph_instance_name = op_desc->GetSubgraphInstanceName(iter->second); + if (subgraph_instance_name.empty()) { + GE_LOGE("Failed to get subgraph %s index %u, the subgraph may not be added", name.c_str(), iter->second); + return Graph(""); + } + + auto 
node = operator_impl_->GetNode(); + if (node == nullptr) { + GE_LOGE("Failed to get subgraph %s, the node is null", name.c_str()); + return Graph(""); + } + auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); + if (root_graph == nullptr) { + GE_LOGE("Failed to get subgraph %s, can not find the root graph", name.c_str()); + return Graph(""); + } + auto subgraph = root_graph->GetSubgraph(subgraph_instance_name); + if (subgraph == nullptr) { + GE_LOGE("Failed to get subgraph %s index %u, can not find the instance %s from the root graph", name.c_str(), + iter->second, subgraph_instance_name.c_str()); + return Graph(""); + } + return GraphUtils::CreateGraphFromComputeGraph(subgraph); +} + +Graph Operator::GetDynamicSubgraph(const string &name, uint32_t index) const { + return GetSubgraph(name + std::to_string(index)); +} + +size_t Operator::GetSubgraphNamesCount() const { + if (operator_impl_ == nullptr) { + GE_LOGE("Failed to get subgraph names count, the operator impl is null"); + return 0; + } + return operator_impl_->GetSubgraphNamesCount(); +} + class GraphBuilderImpl { public: explicit GraphBuilderImpl(const string &name) : graph_(ComGraphMakeShared(name)) { diff --git a/src/common/graph/operator_factory_impl.cc b/src/common/graph/operator_factory_impl.cc index fbfdafc3..026a85bc 100644 --- a/src/common/graph/operator_factory_impl.cc +++ b/src/common/graph/operator_factory_impl.cc @@ -96,7 +96,6 @@ VerifyFunc OperatorFactoryImpl::GetVerifyFunc(const std::string &operator_type) graphStatus OperatorFactoryImpl::RegisterOperatorCreator(const string &operator_type, OpCreator const &op_creator) { if (operator_creators_ == nullptr) { - GELOGI("operator_creators_ init"); operator_creators_.reset(new (std::nothrow) std::map()); } auto it = operator_creators_->find(operator_type); diff --git a/src/common/graph/ref_relation.cc b/src/common/graph/ref_relation.cc new file mode 100644 index 00000000..cacf213f --- /dev/null +++ 
b/src/common/graph/ref_relation.cc @@ -0,0 +1,422 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/ref_relation.h" + +#include +#include + +#include "utils/mem_utils.h" +#include "debug/ge_log.h" +#include "debug/ge_op_types.h" +#include "debug/ge_util.h" +#include "debug/ge_attr_define.h" +#include "graph/ge_error_codes.h" +#include "graph/utils/graph_utils.h" +#include "framework/common/debug/ge_log.h" + +using namespace std; +using namespace ge; +namespace ge { +namespace { +const char *kRefIndex = "_parent_node_index"; +const string kWhile = "While"; +const string kIf = "If"; +const string kCase = "Case"; + +const int kMaxElementNum = 100; + +std::unordered_set function_op = {kWhile, kIf, kCase}; +} // namespace + +/* Impl */ +class RefRelations::Impl { + public: + graphStatus LookUpRefRelations(const RefCell &key, unordered_set &result) { + unsigned long number = static_cast(reinterpret_cast(key.node.get())); + std::string lookup_key = + key.node_name + std::to_string(key.in_out) + std::to_string(key.in_out_idx) + std::to_string(number); + auto iter = look_up_table_.find(lookup_key); + if (iter != look_up_table_.end()) { + for (auto &c : iter->second) { + result.insert(c); + } + return GRAPH_SUCCESS; + } + GELOGW("can not find any relations! 
key value is %s", lookup_key.c_str()); + return GRAPH_SUCCESS; + }; + graphStatus BuildRefRelations(ge::ComputeGraph &root_graph); + graphStatus Clear() { + GELOGD("Start clear boundary reflections between main graph and sub graph!"); + look_up_table_.clear(); + values_.clear(); + return GRAPH_SUCCESS; + }; + + private: + graphStatus BuildLookUpTables(); + graphStatus BuildRefRelationsForBranch(const NodePtr &root_node, const vector> &classed_data_nodes, + const vector>> &classed_netoutput_nodes, + vector> &node_refs); + graphStatus BuildRefRelationsForWhile(const NodePtr &root_node, const vector> &classed_data_nodes, + const vector>> &classed_netoutput_nodes, + vector> &node_refs); + graphStatus BuildRelationsWithFuncNodeType(const NodePtr &root_node, + const vector> &classed_data_nodes, + const vector>> &classed_netoutput_nodes, + vector> &node_refs); + void GetDataAndNetoutputOfSubGraph(const ge::ComputeGraph &root_graph, vector &data_nodes, + vector &netoutput_nodes, const std::vector &sub_graph_names, + const std::string &node_type); + + graphStatus GetRootGraph(ge::ComputeGraph &graph, ge::ComputeGraph &root_graph); + graphStatus ProcessSubgraphDataNodes(vector &data_nodes, vector> &classed_data_nodes); + graphStatus ProcessSubgraphNetoutput(const vector &netoutput_nodes, + vector>> &classed_netoutput_nodes); + + std::unordered_map> look_up_table_; + std::vector>> values_; +}; + +// Node Level +graphStatus RefRelations::Impl::BuildRefRelationsForBranch( + const NodePtr &root_node, const vector> &classed_data_nodes, + const vector>> &classed_netoutput_nodes, vector> &node_refs) { + GELOGD("Enter BuildRefRelationsForBranch!"); + + size_t ref_i = 0; + for (const auto &ref_i_data_nodes : classed_data_nodes) { + vector in_ref_i_all_refs; + RefCell cell_root; + cell_root.node_name = root_node->GetName(); + cell_root.node = root_node; + cell_root.in_out = NODE_IN; + cell_root.in_out_idx = ref_i; + in_ref_i_all_refs.emplace_back(cell_root); + for (const auto &data : 
ref_i_data_nodes) { + RefCell cell_in; + RefCell cell_out; + cell_in.node_name = data->GetName(); + cell_in.node = data; + cell_in.in_out = NODE_IN; + cell_in.in_out_idx = 0; + cell_out.node_name = data->GetName(); + cell_out.node = data; + cell_out.in_out = NODE_OUT; + cell_out.in_out_idx = 0; + in_ref_i_all_refs.emplace_back(cell_in); + in_ref_i_all_refs.emplace_back(cell_out); + } + node_refs.emplace_back(in_ref_i_all_refs); + ref_i++; + } + + size_t ref_o = 0; + for (const auto &ref_o_net_nodes : classed_netoutput_nodes) { + vector out_ref_i_all_refs; + RefCell cell_root; + cell_root.node_name = root_node->GetName(); + cell_root.node = root_node; + cell_root.in_out = NODE_OUT; + cell_root.in_out_idx = ref_o; + out_ref_i_all_refs.emplace_back(cell_root); + for (const auto &ele : ref_o_net_nodes) { + RefCell cell_netoutput_in; + RefCell cell_netoutput_out; + cell_netoutput_in.node_name = (ele.first)->GetName(); + cell_netoutput_in.node = ele.first; + cell_netoutput_in.in_out = NODE_IN; + cell_netoutput_in.in_out_idx = ele.second; + cell_netoutput_out.node_name = (ele.first)->GetName(); + cell_netoutput_out.node = ele.first; + cell_netoutput_out.in_out = NODE_OUT; + cell_netoutput_out.in_out_idx = ele.second; + out_ref_i_all_refs.emplace_back(cell_netoutput_in); + out_ref_i_all_refs.emplace_back(cell_netoutput_out); + } + node_refs.emplace_back(out_ref_i_all_refs); + ref_o++; + } + return GRAPH_SUCCESS; +} + +graphStatus RefRelations::Impl::BuildLookUpTables() { + for (size_t i = 0; i < values_.size(); i++) { + vector> &val = values_[i]; + for (const auto &ele : val) { + for (const auto &ref_cell : ele) { + string key = ref_cell.node_name + std::to_string(ref_cell.in_out) + std::to_string(ref_cell.in_out_idx) + + std::to_string(static_cast(reinterpret_cast(ref_cell.node.get()))); + look_up_table_[key] = ele; + } + } + } + return GRAPH_SUCCESS; +} + +graphStatus RefRelations::Impl::BuildRefRelationsForWhile( + const NodePtr &root_node, const vector> 
&classed_data_nodes, + const vector>> &classed_netoutput_nodes, vector> &node_refs) { + GELOGD("Enter BuildRefRelations for while op!"); + // data_nodes has been sorted + // for while, input num must be same as output num + auto input_num = root_node->GetAllInDataAnchorsSize(); + + size_t ref_i = 0; + while (ref_i < input_num) { + auto &ref_i_data_nodes = classed_data_nodes[ref_i]; + auto &ref_i_net_nodes = classed_netoutput_nodes[ref_i]; + + vector ref_i_all_refs; + RefCell cell_root_i; + RefCell cell_root_o; + cell_root_i.node_name = root_node->GetName(); + cell_root_i.node = root_node; + cell_root_i.in_out = NODE_IN; + cell_root_i.in_out_idx = ref_i; + ref_i_all_refs.emplace_back(cell_root_i); + cell_root_o.node_name = root_node->GetName(); + cell_root_o.node = root_node; + cell_root_o.in_out = NODE_OUT; + cell_root_o.in_out_idx = ref_i; + ref_i_all_refs.emplace_back(cell_root_o); + for (const auto &data : ref_i_data_nodes) { + RefCell cell_in; + RefCell cell_out; + cell_in.node_name = data->GetName(); + cell_in.node = data; + cell_in.in_out = NODE_IN; + cell_in.in_out_idx = 0; + cell_out.node_name = data->GetName(); + cell_out.node = data; + cell_out.in_out = NODE_OUT; + cell_out.in_out_idx = 0; + ref_i_all_refs.emplace_back(cell_in); + ref_i_all_refs.emplace_back(cell_out); + } + + for (const auto &ele : ref_i_net_nodes) { + RefCell cell_netoutput_in; + RefCell cell_netoutput_out; + cell_netoutput_in.node_name = (ele.first)->GetName(); + cell_netoutput_in.node = ele.first; + cell_netoutput_in.in_out = NODE_IN; + cell_netoutput_in.in_out_idx = ele.second; + cell_netoutput_out.node_name = (ele.first)->GetName(); + cell_netoutput_out.node = ele.first; + cell_netoutput_out.in_out = NODE_OUT; + cell_netoutput_out.in_out_idx = ele.second; + ref_i_all_refs.emplace_back(cell_netoutput_in); + ref_i_all_refs.emplace_back(cell_netoutput_out); + } + node_refs.emplace_back(ref_i_all_refs); + ref_i++; + } + + return GRAPH_SUCCESS; +} +// build ref relations according to 
diff func op type +graphStatus RefRelations::Impl::BuildRelationsWithFuncNodeType( + const NodePtr &root_node, const vector> &classed_data_nodes, + const vector>> &classed_netoutput_nodes, vector> &node_refs) { + // data_nodes has been sorted + auto node_type = root_node->GetType(); + + auto status = GRAPH_SUCCESS; + if (node_type == kIf || node_type == kCase) { + status = BuildRefRelationsForBranch(root_node, classed_data_nodes, classed_netoutput_nodes, node_refs); + } else if (node_type == kWhile) { + status = BuildRefRelationsForWhile(root_node, classed_data_nodes, classed_netoutput_nodes, node_refs); + } else { + GELOGE(GRAPH_PARAM_INVALID, "Node type [%s] is not supported for build ref relations!", node_type.c_str()); + status = GRAPH_PARAM_INVALID; + } + return status; +} + +void RefRelations::Impl::GetDataAndNetoutputOfSubGraph(const ge::ComputeGraph &root_graph, vector &data_nodes, + vector &netoutput_nodes, + const std::vector &sub_graph_names, + const std::string &node_type) { + int sub_graph_idx = 0; + for (const auto &name : sub_graph_names) { + auto sub_graph = root_graph.GetSubgraph(name); + for (const auto &sub_graph_node : sub_graph->GetDirectNode()) { + auto sub_graph_node_type = sub_graph_node->GetType(); + + if (sub_graph_node_type == DATA) { + data_nodes.emplace_back(sub_graph_node); + } else if (sub_graph_node_type == NETOUTPUT) { + // if while, the first subgraph must be cond subgraph. 
+ // There is no meaning for refs ,so continue + if (node_type == kWhile && sub_graph_idx == 0) { + continue; + } + netoutput_nodes.emplace_back(sub_graph_node); + } + continue; + } + sub_graph_idx++; + } +} + +graphStatus RefRelations::Impl::GetRootGraph(ge::ComputeGraph &graph, ge::ComputeGraph &root_graph) { + auto parent_graph_ptr = graph.GetParentGraph(); + if (parent_graph_ptr == nullptr) { + root_graph = graph; + return GRAPH_SUCCESS; + } + auto root_graph_ptr = GraphUtils::FindRootGraph(parent_graph_ptr); + if (root_graph_ptr == nullptr) { + GE_LOGE("Get null root graph"); + return GRAPH_PARAM_INVALID; + } + root_graph = *root_graph_ptr; + return GRAPH_SUCCESS; +} + +graphStatus RefRelations::Impl::ProcessSubgraphDataNodes(vector &data_nodes, + vector> &classed_data_nodes) { + int max_ref_idx = 0; + for (const auto &e : data_nodes) { + int i; + bool is_exist = true; + is_exist = AttrUtils::GetInt(e->GetOpDesc(), kRefIndex, i); + if (!is_exist) { + GELOGE(GRAPH_FAILED, "Invalid SubGraph NetOutput node[%s].no attr %s", e->GetName().c_str(), kRefIndex); + return GRAPH_FAILED; + } + max_ref_idx = (i > max_ref_idx) ? 
i : max_ref_idx; + } + + while (!data_nodes.empty()) { + auto data = data_nodes.back(); + data_nodes.pop_back(); + int ref_idx = 0; + (void)AttrUtils::GetInt(data->GetOpDesc(), kRefIndex, ref_idx); + classed_data_nodes[ref_idx].emplace_back(data); + } + return GRAPH_SUCCESS; +} + +graphStatus RefRelations::Impl::ProcessSubgraphNetoutput( + const vector &netoutput_nodes, vector>> &classed_netoutput_nodes) { + for (const auto &sub_netoutput_node : netoutput_nodes) { + auto op_desc = sub_netoutput_node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + + for (const auto &in_data_anchor : sub_netoutput_node->GetAllInDataAnchors()) { + auto in_desc = op_desc->MutableInputDesc(in_data_anchor->GetIdx()); + if (in_desc == nullptr) { + GELOGE(GRAPH_FAILED, "Invalid NetOutput node [%s] idx [%lu], no tensor on it", + sub_netoutput_node->GetName().c_str(), in_data_anchor->GetIdx()); + return GRAPH_FAILED; + } + int ref_o; + if (AttrUtils::GetInt(in_desc, kRefIndex, ref_o)) { + if (ref_o >= kMaxElementNum) { + return GRAPH_FAILED; + } + classed_netoutput_nodes[ref_o].emplace_back( + std::pair({sub_netoutput_node, static_cast(in_data_anchor->GetIdx())})); + } + } + } + return GRAPH_SUCCESS; +} + +graphStatus RefRelations::Impl::BuildRefRelations(ge::ComputeGraph &graph) { + /* First Step: Get root graph */ + ge::ComputeGraph &root_graph = graph; + auto status = GetRootGraph(graph, root_graph); + if (status != GRAPH_SUCCESS) { + return status; + } + + for (const auto &node : graph.GetAllNodes()) { + auto node_type = node->GetType(); + if (function_op.find(node_type) == function_op.end()) { + continue; + } + std::vector ref_nodes; + auto op_desc = node->GetOpDesc(); + auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); + vector data_nodes; + vector netoutput_nodes; + // Get data and netoutput of sub_graph + GetDataAndNetoutputOfSubGraph(root_graph, data_nodes, netoutput_nodes, sub_graph_names, node_type); + vector> classed_data_nodes(kMaxElementNum); // according to ref_idx + 
vector>> classed_netoutput_nodes(kMaxElementNum); // according to ref_idx + status = ProcessSubgraphDataNodes(data_nodes, classed_data_nodes); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "classfy data nodes failed!"); + return status; + } + + // for netoutput + // check netoutput + // here main graph output number must be the same as every sub_graph netoutput node + // key: netoutput node_ptr , + status = ProcessSubgraphNetoutput(netoutput_nodes, classed_netoutput_nodes); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "process netoutput failed!"); + return status; + } + + vector> node_refs; + status = BuildRelationsWithFuncNodeType(node, classed_data_nodes, classed_netoutput_nodes, node_refs); + if (status != GRAPH_SUCCESS) { + GELOGE(status, "BuildRelationsWithFuncNodeType Failed! Node is [%s]!", node->GetName().c_str()); + return status; + } + if (!node_refs.empty()) { + values_.push_back(node_refs); + } + } + /* Seconde Step: generate map */ + status = BuildLookUpTables(); + if (status != GRAPH_SUCCESS) { + GELOGE(status, "Build look up tables failed!"); + return status; + } + return GRAPH_SUCCESS; +} + +/* Ref Relations Interface */ +RefRelations::RefRelations() { + impl_ = MakeShared(); + if (impl_ == nullptr) { + GELOGE(GRAPH_FAILED, "MakeShared failed!"); + return; + } +} + +graphStatus RefRelations::LookUpRefRelations(const RefCell &key, unordered_set &result) { + GE_CHECK_NOTNULL(impl_); + return impl_->LookUpRefRelations(key, result); +} + +graphStatus RefRelations::BuildRefRelations(ge::ComputeGraph &root_graph) { + GE_CHECK_NOTNULL(impl_); + return impl_->BuildRefRelations(root_graph); +} + +graphStatus RefRelations::Clear() { + GE_CHECK_NOTNULL(impl_); + return impl_->Clear(); +} +} // namespace ge \ No newline at end of file diff --git a/src/common/graph/shape_refiner.cc b/src/common/graph/shape_refiner.cc index da4388f9..845fe494 100644 --- a/src/common/graph/shape_refiner.cc +++ b/src/common/graph/shape_refiner.cc @@ -21,7 +21,7 
@@ #include #include #include -#include "framework/common/types.h" +#include "graph/debug/ge_attr_define.h" #include "graph/utils/graph_utils.h" #include "debug/ge_log.h" @@ -37,7 +37,6 @@ namespace ge { namespace { -constexpr const char *kRefIndex = "parent_node_index"; graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) { auto op_desc = node->GetOpDesc(); auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); @@ -47,6 +46,10 @@ graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) { auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); for (const auto &name : sub_graph_names) { + if (name.empty()) { + GELOGW("The node %s contains empty subgraph instance name", node->GetName().c_str()); + continue; + } auto sub_graph = root_graph->GetSubgraph(name); if (sub_graph == nullptr) { GE_LOGE("Can node find the subgrpah %s for node %s", name.c_str(), node->GetName().c_str()); @@ -63,7 +66,7 @@ graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) { node->GetName().c_str()); return GRAPH_FAILED; } - if (!AttrUtils::GetInt(node_sub->GetOpDesc(), kRefIndex, ref_i)) { + if (!AttrUtils::GetInt(node_sub->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, ref_i)) { GE_LOGE("Invalid data node on the sub graph %s parent node %s, no ref-index attribute", name.c_str(), node->GetName().c_str()); return GRAPH_FAILED; @@ -76,7 +79,10 @@ graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) { ref_i, node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(), node->GetAllOutDataAnchorsSize()); return GRAPH_FAILED; } + GELOGI("Ref index is %d, input_desc dtype is %d, node name is %s", ref_i, input_desc->GetDataType(), + node->GetName().c_str()); auto ret = data_opdesc->UpdateInputDesc(0, *input_desc); + if (ret != GRAPH_SUCCESS) { GE_LOGE("Failed to update input desc of data %s on the sub graph %s parent node %s", node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str()); @@ -101,6 +107,10 @@ graphStatus 
UpdateParentNodeOutTensor(const ConstNodePtr &node) { auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); for (const auto &name : sub_graph_names) { + if (name.empty()) { + GELOGW("The node %s contains empty subgraph instance name", node->GetName().c_str()); + continue; + } auto sub_graph = root_graph->GetSubgraph(name); if (sub_graph == nullptr) { GE_LOGE("Can node find the subgrpah %s for node %s", name.c_str(), node->GetName().c_str()); @@ -132,11 +142,14 @@ graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) { node->GetName().c_str(), edge_anchor->GetIdx()); return GRAPH_FAILED; } + GELOGI("Netoutput in anchor index is %zu, input tensor dim is %zu", edge_anchor->GetIdx(), + edge_desc->GetShape().GetDimNum()); int ref_i; - if (!AttrUtils::GetInt(edge_desc, kRefIndex, ref_i)) { + if (!AttrUtils::GetInt(edge_desc, ATTR_NAME_PARENT_NODE_INDEX, ref_i)) { // if there is no ref index on the TensorDesc, it means the output data will be ignored outer. continue; } + GELOGI("Parent node index of edge desc is %d", ref_i); auto output_desc = op_desc->MutableOutputDesc(static_cast(ref_i)); if (output_desc == nullptr) { GE_LOGE( diff --git a/src/common/graph/tensor.cc b/src/common/graph/tensor.cc index 8e0a9c7d..d5d304b7 100644 --- a/src/common/graph/tensor.cc +++ b/src/common/graph/tensor.cc @@ -29,6 +29,7 @@ namespace { /// Extra 1 byte store '\0' const int EXTRA_STORE_POINTER_FOR_STRING = 8; const int EXTRA_STORE_POINTER_FOR_STRING_AND_END_SYMBOL = 9; +const int64_t UNKNOWN_DIM_SIZE = -1; } // namespace namespace ge { @@ -65,6 +66,7 @@ class TensorDescImpl { TensorDescImpl(const Shape &shape, Format format, DataType dt) : shape_(shape), format_(format), data_type_(dt) {} Shape shape_; + std::vector> range_; Format format_ = FORMAT_ND; Format origin_format_ = FORMAT_ND; DataType data_type_ = DT_FLOAT; @@ -94,7 +96,16 @@ class ShapeImpl { public: ShapeImpl() = default; ~ShapeImpl() = default; - explicit ShapeImpl(const std::vector &dims) : 
dims_(dims) {} + explicit ShapeImpl(const std::vector &dims) { + bool is_unknown_dim_num = false; + for (const auto &dim : dims) { + if (dim == UNKNOWN_DIM_NUM) { + is_unknown_dim_num = true; + break; + } + } + dims_ = is_unknown_dim_num ? std::vector({UNKNOWN_DIM_NUM}) : dims; + } std::vector dims_; }; @@ -105,6 +116,11 @@ Shape::Shape(const std::vector &dims) { impl_ = ComGraphMakeShareddims_) { + if (i == UNKNOWN_DIM_NUM) { + return 0; + } + } return impl_->dims_.size(); } return 0; @@ -146,6 +162,10 @@ int64_t Shape::GetShapeSize() const { } int64_t size = 1; for (auto i : impl_->dims_) { + if (i == UNKNOWN_DIM_NUM || i == UNKNOWN_DIM) { + return UNKNOWN_DIM_SIZE; + } + if (!Int64MulNotOverflow(size, i)) { GELOGE(GRAPH_FAILED, "mul overflow: %ld, %ld", size, i); size = 0; @@ -217,6 +237,34 @@ void TensorDesc::SetShape(const Shape &shape) { } } +// set shape with -2, it stand for unknown shape +graphStatus TensorDesc::SetUnknownDimNumShape() { + if (impl != nullptr) { + impl->shape_ = Shape({UNKNOWN_DIM_NUM}); + return GRAPH_SUCCESS; + } + GELOGE(GRAPH_FAILED, "Set unknown shape failed,because no impl class!"); + return GRAPH_FAILED; +} + +// for unknown shape +graphStatus TensorDesc::SetShapeRange(const std::vector> &range) { + if (impl != nullptr) { + impl->range_ = range; + return GRAPH_SUCCESS; + } + GELOGE(GRAPH_FAILED, "SetShapeRange failed!impl is nullptr!"); + return GRAPH_FAILED; +} +graphStatus TensorDesc::GetShapeRange(std::vector> &range) const { + if (impl != nullptr) { + range = impl->range_; + return GRAPH_SUCCESS; + } + GELOGE(GRAPH_FAILED, "impl is nullptr!"); + return GRAPH_FAILED; +} + Shape TensorDesc::GetOriginShape() const { if (impl != nullptr) { return impl->origin_shape_; @@ -541,6 +589,17 @@ GeTensorDesc TensorAdapter::TensorDesc2GeTensorDesc(const TensorDesc &tensor_des tensor_desc.GetDataType()); ge_tensor_desc.SetOriginShape(GeShape(tensor_desc.GetOriginShape().GetDims())); 
ge_tensor_desc.SetOriginFormat(tensor_desc.GetOriginFormat()); + std::vector> shape_range; + auto status = tensor_desc.GetShapeRange(shape_range); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Get shape range failed!"); + return ge_tensor_desc; + } + status = ge_tensor_desc.SetShapeRange(shape_range); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Set shape range failed!"); + return ge_tensor_desc; + } auto size = tensor_desc.GetSize(); TensorUtils::SetSize(ge_tensor_desc, size); @@ -554,6 +613,17 @@ TensorDesc TensorAdapter::GeTensorDesc2TensorDesc(const GeTensorDesc &ge_tensor_ ge_tensor_desc.GetDataType()); tensor_desc.SetOriginShape(Shape(ge_tensor_desc.GetOriginShape().GetDims())); tensor_desc.SetOriginFormat(ge_tensor_desc.GetOriginFormat()); + std::vector> shape_range; + auto status = ge_tensor_desc.GetShapeRange(shape_range); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Get shape range failed!"); + return tensor_desc; + } + status = tensor_desc.SetShapeRange(shape_range); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Set shape range failed!"); + return tensor_desc; + } int64_t size = 0; (void)TensorUtils::GetSize(ge_tensor_desc, size); tensor_desc.SetSize(size); diff --git a/src/common/graph/utils/graph_utils.cc b/src/common/graph/utils/graph_utils.cc index c5e45516..c495ffc9 100644 --- a/src/common/graph/utils/graph_utils.cc +++ b/src/common/graph/utils/graph_utils.cc @@ -28,6 +28,7 @@ #include #include #include +#include #include "./ge_context.h" #include "debug/ge_util.h" @@ -390,8 +391,8 @@ GraphUtils::InsertNodeBefore(const OutDataAnchorPtr &src, const std::vectorGetOutDataAnchor(output_index), dst) != GRAPH_SUCCESS)) { + (void)RemoveEdge(src, dst); + if (AddEdge(insert_node->GetOutDataAnchor(output_index), dst) != GRAPH_SUCCESS) { GELOGE(GRAPH_FAILED, "ReplaceEdge from %s->%s to %s->%s failed.", src_node->GetName().c_str(), dst_node->GetName().c_str(), insert_node->GetName().c_str(), 
dst_node->GetName().c_str()); return GRAPH_FAILED; @@ -399,7 +400,7 @@ GraphUtils::InsertNodeBefore(const OutDataAnchorPtr &src, const std::vectorGetOutControlAnchor(); GE_CHECK_NOTNULL(new_out_ctrl_anchor); - for (InControlAnchorPtr peer_in_ctrl_anchor : src_out_ctrl_anchor->GetPeerInControlAnchors()) { + for (const InControlAnchorPtr &peer_in_ctrl_anchor : src_out_ctrl_anchor->GetPeerInControlAnchors()) { if ((RemoveEdge(src_out_ctrl_anchor, peer_in_ctrl_anchor) != GRAPH_SUCCESS) || (AddEdge(new_out_ctrl_anchor, peer_in_ctrl_anchor) != GRAPH_SUCCESS)) { GELOGE(GRAPH_FAILED, "ReplaceEdge from %s->%s to %s->%s failed.", src_node->GetName().c_str(), @@ -706,7 +707,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGraphToOnn GELOGE(GRAPH_FAILED, "File name is too longer!"); return; } - std::unique_ptr real_path(new (std::nothrow) char[PATH_MAX]{0}); + std::unique_ptr real_path(new (std::nothrow) char[PATH_MAX]{0}); if (real_path == nullptr) { GELOGE(GRAPH_FAILED, "New real_path failed."); return; @@ -1275,6 +1276,423 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraphPtr GraphUtils::FindR return result; } +/// +/// Get reference-mapping of all data_anchors in graph +/// @param [in] graph +/// @param [out] symbol_to_anchors +/// @param [out] anchor_to_symbol +/// @return success: GRAPH_SUCESS +/// +graphStatus GraphUtils::GetRefMapping(const ComputeGraphPtr &graph, + std::map> &symbol_to_anchors, + std::map &anchor_to_symbol) { + GE_CHECK_NOTNULL(graph); + for (auto &node : graph->GetAllNodes()) { + // in_data_anchor + if (HandleInAnchorMapping(node, symbol_to_anchors, anchor_to_symbol) != GRAPH_SUCCESS) { + GE_LOGE("Find ref_mapping for in_data_anchors of node %s failed.", node->GetName().c_str()); + return GRAPH_FAILED; + } + + // out_data_anchor + if (HandleOutAnchorMapping(node, symbol_to_anchors, anchor_to_symbol) != GRAPH_SUCCESS) { + GE_LOGE("Find ref_mapping for out_data_anchors of node %s failed.", node->GetName().c_str()); 
+ return GRAPH_FAILED; + } + } + + return GRAPH_SUCCESS; +} + +/// +/// Get reference-mapping for in_data_anchors of node +/// @param [in] node +/// @param [out] symbol_to_anchors +/// @param [out] anchor_to_symbol +/// @return success: GRAPH_SUCESS +/// +graphStatus GraphUtils::HandleInAnchorMapping(const NodePtr &node, + std::map> &symbol_to_anchors, + std::map &anchor_to_symbol) { + GE_CHECK_NOTNULL(node); + + if (NodeUtils::IsSubgraphOutput(node)) { + return HandleSubgraphOutput(node, symbol_to_anchors, anchor_to_symbol); + } + + if (NodeUtils::IsSubgraphInput(node)) { + return HandleSubgraphInput(node, symbol_to_anchors, anchor_to_symbol); + } + + std::string type = node->GetType(); + if ((type == MERGE) || (type == STREAMMERGE)) { + return HandleMergeInput(node, symbol_to_anchors, anchor_to_symbol); + } + + for (auto &in_data_anchor : node->GetAllInDataAnchors()) { + NodeIndexIO cur_node_info = NodeIndexIO(node, in_data_anchor->GetIdx(), kIn); + OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); + if (peer_out_anchor == nullptr) { + std::string symbol = cur_node_info.ToString(); + GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str()); + symbol_to_anchors[symbol] = {cur_node_info}; + anchor_to_symbol[symbol] = symbol; + } else { + NodeIndexIO exist_node_info = NodeIndexIO(peer_out_anchor->GetOwnerNode(), peer_out_anchor->GetIdx(), kOut); + if (UpdateRefMapping(cur_node_info, exist_node_info, symbol_to_anchors, anchor_to_symbol) != GRAPH_SUCCESS) { + GE_LOGE("Update symbol mapping failed."); + return GRAPH_FAILED; + } + } + } + + return GRAPH_SUCCESS; +} + +/// +/// Get reference-mapping for out_data_anchors of node +/// @param [in] node +/// @param [out] symbol_to_anchors +/// @param [out] anchor_to_symbol +/// @return success: GRAPH_SUCESS +/// +graphStatus GraphUtils::HandleOutAnchorMapping(const NodePtr &node, + std::map> &symbol_to_anchors, + std::map &anchor_to_symbol) { + GE_CHECK_NOTNULL(node); + for 
(auto &out_data_anchor : node->GetAllOutDataAnchors()) { + NodeIndexIO cur_node_info = NodeIndexIO(node, out_data_anchor->GetIdx(), kOut); + if (anchor_to_symbol.find(cur_node_info.ToString()) != anchor_to_symbol.end()) { + continue; + } + + int32_t reuse_in_index = -1; + if (IsRefFromInput(out_data_anchor, reuse_in_index)) { + NodeIndexIO exist_node_info = NodeIndexIO(node, reuse_in_index, kIn); + if (UpdateRefMapping(cur_node_info, exist_node_info, symbol_to_anchors, anchor_to_symbol) != GRAPH_SUCCESS) { + GE_LOGE("Update symbol mapping failed."); + return GRAPH_FAILED; + } + } else { + std::string symbol = cur_node_info.ToString(); + GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str()); + symbol_to_anchors.emplace(std::make_pair(symbol, std::vector{cur_node_info})); + anchor_to_symbol.emplace(std::make_pair(symbol, symbol)); + } + } + + return GRAPH_SUCCESS; +} + +/// +/// Handle input of subgraph +/// @param [in] node +/// @param [out] symbol_to_anchors +/// @param [out] anchor_to_symbol +/// @return success: GRAPH_SUCESS +/// +graphStatus GraphUtils::HandleSubgraphInput(const NodePtr &node, + std::map> &symbol_to_anchors, + std::map &anchor_to_symbol) { + GE_CHECK_NOTNULL(node); + GE_CHECK_NOTNULL(node->GetOpDesc()); + + // Data in subgraph + uint32_t index = 0; + if (!ge::AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, index)) { + GE_LOGE("Get attr ATTR_NAME_PARENT_NODE_INDEX failed, node:%s.", node->GetName().c_str()); + return GRAPH_FAILED; + } + NodePtr parent_node = node->GetOwnerComputeGraph()->GetParentNode(); + GE_CHECK_NOTNULL(parent_node); + InDataAnchorPtr parent_in_anchor = parent_node->GetInDataAnchor(index); + GE_CHECK_NOTNULL(parent_in_anchor); + OutDataAnchorPtr peer_out_anchor = parent_in_anchor->GetPeerOutAnchor(); + if (peer_out_anchor != nullptr) { + // Data has and only has one input + NodeIndexIO cur_node_info = NodeIndexIO(node, 0, kIn); + NodeIndexIO exist_node_info = 
NodeIndexIO(peer_out_anchor->GetOwnerNode(), peer_out_anchor->GetIdx(), kOut); + if (UpdateRefMapping(cur_node_info, exist_node_info, symbol_to_anchors, anchor_to_symbol) != GRAPH_SUCCESS) { + GE_LOGE("Update symbol mapping failed."); + return GRAPH_FAILED; + } + } + + return GRAPH_SUCCESS; +} + +/// +/// Handle input of Merge op +/// @param [in] node +/// @param [out] symbol_to_anchors +/// @param [out] anchor_to_symbol +/// @return success: GRAPH_SUCESS +/// +graphStatus GraphUtils::HandleMergeInput(const NodePtr &node, + std::map> &symbol_to_anchors, + std::map &anchor_to_symbol) { + GE_CHECK_NOTNULL(node); + std::vector exist_node_infos; + std::vector cur_node_infos; + for (auto &in_data_anchor : node->GetAllInDataAnchors()) { + auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); + if (peer_out_anchor == nullptr) { + std::string next_name; + if (AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_NEXT_ITERATION, next_name) && !next_name.empty()) { + ComputeGraphPtr graph = node->GetOwnerComputeGraph(); + GE_CHECK_NOTNULL(graph); + ge::NodePtr next_node = graph->FindNode(next_name); + GE_CHECK_NOTNULL(next_node); + // NextIteration has and only has one output + peer_out_anchor = next_node->GetOutDataAnchor(0); + GE_CHECK_NOTNULL(peer_out_anchor); + cur_node_infos.emplace_back(NodeIndexIO(node, in_data_anchor->GetIdx(), kIn)); + cur_node_infos.emplace_back(NodeIndexIO(next_node, peer_out_anchor->GetIdx(), kOut)); + } + } else { + cur_node_infos.emplace_back(NodeIndexIO(node, in_data_anchor->GetIdx(), kIn)); + exist_node_infos.emplace_back(NodeIndexIO(peer_out_anchor->GetOwnerNode(), peer_out_anchor->GetIdx(), kOut)); + } + } + + size_t anchor_nums = 0; + NodeIndexIO max_node_index_io(nullptr, 0, kOut); + for (auto &temp_node_info : exist_node_infos) { + auto iter1 = anchor_to_symbol.find(temp_node_info.ToString()); + if (iter1 != anchor_to_symbol.end()) { + std::string temp_symbol = iter1->second; + auto iter2 = symbol_to_anchors.find(temp_symbol); + if (iter2 
!= symbol_to_anchors.end()) { + if (iter2->second.size() > anchor_nums) { + max_node_index_io = temp_node_info; + anchor_nums = iter2->second.size(); + } + } + } + } + + std::string symbol; + for (auto &temp_node_info : exist_node_infos) { + if ((UnionSymbolMapping(max_node_index_io, temp_node_info, symbol_to_anchors, anchor_to_symbol, symbol) != + GRAPH_SUCCESS) || + symbol.empty()) { + GE_LOGE("Union symbol map anchor1:%s & anchor2:%s.", max_node_index_io.ToString().c_str(), + temp_node_info.ToString().c_str()); + return GRAPH_FAILED; + } + } + + auto iter = symbol_to_anchors.find(symbol); + if (iter != symbol_to_anchors.end()) { + for (auto &temp_node_info : cur_node_infos) { + GELOGD("Add anchor %s, symbol %s.", temp_node_info.ToString().c_str(), symbol.c_str()); + iter->second.emplace_back(temp_node_info); + anchor_to_symbol.emplace(std::make_pair(temp_node_info.ToString(), symbol)); + } + } + + return GRAPH_SUCCESS; +} + +/// +/// Handle output of subgraph +/// @param [in] node +/// @param [out] symbol_to_anchors +/// @param [out] anchor_to_symbol +/// @return success: GRAPH_SUCESS +/// +graphStatus GraphUtils::HandleSubgraphOutput(const NodePtr &node, + std::map> &symbol_to_anchors, + std::map &anchor_to_symbol) { + GE_CHECK_NOTNULL(node); + ComputeGraphPtr owner_graph = node->GetOwnerComputeGraph(); + GE_CHECK_NOTNULL(owner_graph); + NodePtr parent_node = owner_graph->GetParentNode(); + GE_CHECK_NOTNULL(parent_node); + + OpDescPtr op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + for (auto &in_data_anchor : node->GetAllInDataAnchors()) { + OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(peer_out_anchor); + + GeTensorDesc in_tensor = op_desc->GetInputDesc(in_data_anchor->GetIdx()); + uint32_t index = 0; + if (!ge::AttrUtils::GetInt(in_tensor, ATTR_NAME_PARENT_NODE_INDEX, index)) { + continue; + } + GE_CHECK_NOTNULL(parent_node->GetOutDataAnchor(index)); + // Union symbol of peer_out_anchor & 
parent_out_anchor + NodeIndexIO peer_node_info = NodeIndexIO(peer_out_anchor->GetOwnerNode(), peer_out_anchor->GetIdx(), kOut); + NodeIndexIO parent_node_info = NodeIndexIO(parent_node, index, kOut); + std::string symbol; + if ((UnionSymbolMapping(peer_node_info, parent_node_info, symbol_to_anchors, anchor_to_symbol, symbol) != + GRAPH_SUCCESS) || + symbol.empty()) { + GE_LOGE("Union symbol map anchor1:%s, anchor2:%s.", peer_node_info.ToString().c_str(), + parent_node_info.ToString().c_str()); + return GRAPH_FAILED; + } + + NodeIndexIO cur_node_info = NodeIndexIO(node, in_data_anchor->GetIdx(), kIn); + GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str()); + symbol_to_anchors[symbol].emplace_back(cur_node_info); + anchor_to_symbol.emplace(std::make_pair(cur_node_info.ToString(), symbol)); + } + + return GRAPH_SUCCESS; +} + +/// +/// Union ref-mapping +/// @param [in] exist_node_info1 +/// @param [in] exist_node_info2 +/// @param [out] symbol_to_anchors +/// @param [out] anchor_to_symbol +/// @param [out] symbol +/// @return success: GRAPH_SUCESS +/// +graphStatus GraphUtils::UnionSymbolMapping(const NodeIndexIO &exist_node_info1, const NodeIndexIO &exist_node_info2, + std::map> &symbol_to_anchors, + std::map &anchor_to_symbol, std::string &symbol) { + std::string symbol1 = anchor_to_symbol[exist_node_info1.ToString()]; + std::string symbol2 = anchor_to_symbol[exist_node_info2.ToString()]; + if (symbol1 == symbol2) { + symbol = symbol1; + GELOGI("no need to union."); + return GRAPH_SUCCESS; + } + + auto iter1 = symbol_to_anchors.find(symbol1); + auto iter2 = symbol_to_anchors.find(symbol2); + if ((iter1 == symbol_to_anchors.end()) || (iter2 == symbol_to_anchors.end())) { + GE_LOGE("symbol %s or %s not exist.", symbol1.c_str(), symbol2.c_str()); + return GRAPH_FAILED; + } + + auto &max_iter = (iter1->second.size() > iter2->second.size() ? iter1 : iter2); + auto &min_iter = (iter1->second.size() > iter2->second.size() ? 
iter2 : iter1); + symbol = (iter1->second.size() > iter2->second.size() ? symbol1 : symbol2); + std::string min_symbol = (iter1->second.size() > iter2->second.size() ? symbol2 : symbol1); + for (auto &node_index_io : min_iter->second) { + GELOGD("Update anchor %s, symbol %s.", node_index_io.ToString().c_str(), symbol.c_str()); + max_iter->second.emplace_back(node_index_io); + auto iter = anchor_to_symbol.find(node_index_io.ToString()); + if (iter == anchor_to_symbol.end()) { + GE_LOGE("anchor %s not exist.", node_index_io.ToString().c_str()); + return GRAPH_FAILED; + } + if (iter->second != min_symbol) { + GELOGW("not expected symbol of anchor %s, expect %s but %s exactly.", iter->first.c_str(), min_symbol.c_str(), + iter->second.c_str()); + } + iter->second = symbol; + } + + GELOGI("Union symbol %s and %s succ.", symbol.c_str(), min_symbol.c_str()); + symbol_to_anchors.erase(min_iter); + return GRAPH_SUCCESS; +} + +/// +/// Update symbol mapping with a new reference pair +/// @param [in] cur_node_info +/// @param [in] exist_node_info +/// @param [out] symbol_to_anchors +/// @param [out] anchor_to_symbol +/// @return success: GRAPH_SUCESS +/// +graphStatus GraphUtils::UpdateRefMapping(const NodeIndexIO &cur_node_info, const NodeIndexIO &exist_node_info, + std::map> &symbol_to_anchors, + std::map &anchor_to_symbol) { + auto iter1 = anchor_to_symbol.find(exist_node_info.ToString()); + if (iter1 == anchor_to_symbol.end()) { + GE_LOGE("data_anchor %s is not visible before data_anchor %s, maybe TopoSorting is missing.", + exist_node_info.ToString().c_str(), cur_node_info.ToString().c_str()); + return GRAPH_FAILED; + } + + std::string symbol = iter1->second; + auto iter2 = symbol_to_anchors.find(symbol); + if (iter2 == symbol_to_anchors.end()) { + GE_LOGE("symbol %s not found.", symbol.c_str()); + return GRAPH_FAILED; + } + GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str()); + iter2->second.emplace_back(cur_node_info); + 
anchor_to_symbol.emplace(std::make_pair(cur_node_info.ToString(), symbol)); + + return GRAPH_SUCCESS; +} + +/// +/// Check if out_data_anchor is reference of input +/// @param [in] out_data_anchor +/// @param [out] reuse_in_index +/// @return bool +/// +bool GraphUtils::IsRefFromInput(const OutDataAnchorPtr &out_data_anchor, int32_t &reuse_in_index) { + if (out_data_anchor == nullptr) { + GELOGW("out_data_anchor is NULL."); + return false; + } + int32_t output_index = out_data_anchor->GetIdx(); + + // pass-through op + NodePtr node = out_data_anchor->GetOwnerNode(); + std::string type = node->GetType(); + const std::set pass_through_set = {NETOUTPUT, WHILE, _WHILE, STATELESSWHILE}; + if ((pass_through_set.count(type) > 0) || (NodeUtils::IsSubgraphInput(node))) { + reuse_in_index = output_index; + GELOGI("Pass-Through node name[%s] index[%u].", node->GetName().c_str(), reuse_in_index); + return true; + } + + // Merge op 0th output + if ((type == MERGE) && (output_index == 0)) { + reuse_in_index = 0; + GELOGI("Merge name[%s] output_index[0].", node->GetName().c_str()); + return true; + } + + // ref op + OpDescPtr op_desc = node->GetOpDesc(); + if (op_desc == nullptr) { + GELOGW("op_desc is NULL."); + return false; + } + bool is_ref = false; + (void)ge::AttrUtils::GetBool(op_desc, ATTR_NAME_REFERENCE, is_ref); + if (is_ref) { + const string &output_name = op_desc->GetOutputNameByIndex(output_index); + for (const auto &input_name : op_desc->GetAllInputNames()) { + if (!input_name.empty() && (output_name == input_name)) { + reuse_in_index = op_desc->GetInputIndexByName(input_name); + GELOGI("Reference name[%s] output[%s][%u] ref to input[%s][%d].", op_desc->GetName().c_str(), + output_name.c_str(), output_index, input_name.c_str(), reuse_in_index); + return true; + } + } + } + + // reuse input + auto output_op_desc = op_desc->GetOutputDescPtr(output_index); + bool reuse_input = false; + if (output_op_desc != nullptr) { + if ((TensorUtils::GetReuseInput(*output_op_desc, 
reuse_input) == GRAPH_SUCCESS) && reuse_input) { + uint32_t reuse_input_index = 0; + if (TensorUtils::GetReuseInputIndex(*output_op_desc, reuse_input_index) == GRAPH_SUCCESS) { + reuse_in_index = static_cast(reuse_input_index); + GELOGI("ReuseInput name[%s] output[%u] reuse input[%d].", op_desc->GetName().c_str(), output_index, + reuse_in_index); + return true; + } + } + } + + return false; +} + /// /// @brief Add node to graph /// @param [in] op_desc @@ -1561,13 +1979,14 @@ CompleteGraphBuilder &CompleteGraphBuilder::SetOutputMapping(const std::map(new (std::nothrow) ComputeGraph(name_)); - if (owner_graph_ == nullptr) { + if ((owner_graph_ == nullptr) || (parent_node_ == nullptr)) { error_code = GRAPH_FAILED; - error_msg = "graph is NULL."; + error_msg = "graph / parent_node is NULL."; return nullptr; } owner_graph_->SetParentNode(parent_node_); + owner_graph_->SetParentGraph(parent_node_->GetOwnerComputeGraph()); BuildNodes(error_code, error_msg); if (error_code != GRAPH_SUCCESS) { @@ -1584,41 +2003,58 @@ ComputeGraphPtr CompleteGraphBuilder::Build(graphStatus &error_code, std::string return nullptr; } - BuildInputs(error_code, error_msg); + AddDataNodes(error_code, error_msg); if (error_code != GRAPH_SUCCESS) { return nullptr; } - BuildOutputs(error_code, error_msg); + AddRetValNodes(error_code, error_msg); if (error_code != GRAPH_SUCCESS) { return nullptr; } - if (AddNetOutputNode(error_code, error_msg) == nullptr) { + // ATTR_NAME_SESSION_GRAPH_ID + std::string graph_id; + if (!AttrUtils::GetStr(parent_node_->GetOwnerComputeGraph(), ATTR_NAME_SESSION_GRAPH_ID, graph_id)) { + error_code = GRAPH_FAILED; + error_msg = "Get attr session_graph_id failed."; return nullptr; } + if (!AttrUtils::SetStr(owner_graph_, ATTR_NAME_SESSION_GRAPH_ID, graph_id)) { + error_code = GRAPH_FAILED; + error_msg = "Set attr session_graph_id failed."; + return nullptr; + } + + // refresh node name + for (const NodePtr &node : owner_graph_->GetDirectNode()) { + if ((node->GetOpDesc() 
== nullptr) || (node->GetType() == VARIABLE) || (node->GetType() == VARIABLEV2)) { + continue; + } + node->GetOpDesc()->SetName(owner_graph_->GetName() + "/" + node->GetName()); + } return owner_graph_; } /// -/// @brief Build inputs +/// @brief Add data nodes /// @param [out] error_code /// @param [out] error_msg /// @return void /// -void CompleteGraphBuilder::BuildInputs(graphStatus &error_code, std::string &error_msg) { +void CompleteGraphBuilder::AddDataNodes(graphStatus &error_code, std::string &error_msg) { for (auto &input : graph_inputs_) { - NodePtr data_node = AddDateNode(input.first, error_code, error_msg); + NodePtr data_node = AddDataNode(input.first, error_code, error_msg); if (data_node == nullptr) { error_code = GRAPH_FAILED; - error_msg = "BuildInputs failed: add node Data:" + std::to_string(input.first) + +" failed."; + error_msg = "AddDataNodes failed: add node Data:" + std::to_string(input.first) + +" failed."; return; } if (owner_graph_->AddInputNode(data_node) == nullptr) { error_code = GRAPH_FAILED; - error_msg = "BuildInputs failed: add input node Data:" + std::to_string(input.first) + +" failed."; + error_msg = "AddDataNodes failed: add input node Data:" + std::to_string(input.first) + +" failed."; return; } @@ -1627,7 +2063,7 @@ void CompleteGraphBuilder::BuildInputs(graphStatus &error_code, std::string &err std::vector anchor_indes = input.second.second; if (input_names.size() != anchor_indes.size()) { error_code = GRAPH_FAILED; - error_msg = "BuildInputs failed: num of input_names and indexs not equal."; + error_msg = "AddDataNodes failed: num of input_names and indexs not equal."; return; } if (input_names.empty()) { @@ -1641,29 +2077,29 @@ void CompleteGraphBuilder::BuildInputs(graphStatus &error_code, std::string &err auto iter = node_names_.find(input_name); if (iter == node_names_.end()) { error_code = GRAPH_FAILED; - error_msg = "BuildInputs failed: node " + input_name + " not exist in graph."; + error_msg = "AddDataNodes failed: 
node " + input_name + " not exist in graph."; return; } NodePtr in_node = node_names_[input_name]; if (in_node == nullptr) { error_code = GRAPH_FAILED; - error_msg = "BuildInputs failed: node " + input_name + " is NULL."; + error_msg = "AddDataNodes failed: node " + input_name + " is NULL."; return; } if (GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), in_node->GetInDataAnchor(ind)) != GRAPH_SUCCESS) { error_code = GRAPH_FAILED; - error_msg = "BuildInputs failed: add data-edge Data:" + std::to_string(input.first) + ":0->" + input_name + + error_msg = "AddDataNodes failed: add data-edge Data:" + std::to_string(input.first) + ":0->" + input_name + ":" + std::to_string(ind) + " failed."; return; } } - GELOGD("BuildInputs : Add %u input succ.", input.first); + GELOGD("AddDataNodes : Add %u input succ.", input.first); } - GELOGD("BuildInputs succ."); + GELOGD("AddDataNodes succ."); } /// @@ -1673,13 +2109,13 @@ void CompleteGraphBuilder::BuildInputs(graphStatus &error_code, std::string &err /// @param [out] error_msg /// @return void /// -NodePtr CompleteGraphBuilder::AddDateNode(uint32_t index, graphStatus &error_code, std::string &error_msg) { +NodePtr CompleteGraphBuilder::AddDataNode(uint32_t index, graphStatus &error_code, std::string &error_msg) { std::string data_name = "Data_" + std::to_string(index); OpDescBuilder op_desc_builder(data_name, "Data"); OpDescPtr op_desc = op_desc_builder.AddInput("x").AddOutput("y").Build(); if (op_desc == nullptr) { error_code = GRAPH_FAILED; - error_msg = "BuildInputs failed: create op_desc " + data_name + " failed."; + error_msg = "AddDataNode failed: create op_desc " + data_name + " failed."; return nullptr; } @@ -1687,7 +2123,7 @@ NodePtr CompleteGraphBuilder::AddDateNode(uint32_t index, graphStatus &error_cod if (index_iter != input_mapping_.end()) { if (!ge::AttrUtils::SetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, index_iter->second)) { error_code = GRAPH_FAILED; - error_msg = "BuildInputs failed: set attr 
ATTR_NAME_PARENT_NODE_INDEX for " + data_name + " failed."; + error_msg = "AddDataNode failed: set attr ATTR_NAME_PARENT_NODE_INDEX for " + data_name + " failed."; return nullptr; } } @@ -1695,189 +2131,83 @@ NodePtr CompleteGraphBuilder::AddDateNode(uint32_t index, graphStatus &error_cod NodePtr data_node = owner_graph_->AddNode(op_desc); if (data_node == nullptr) { error_code = GRAPH_FAILED; - error_msg = "BuildInputs failed: add node " + data_name + " failed."; + error_msg = "AddDataNode failed: add node " + data_name + " failed."; return nullptr; } + node_names_[data_name] = data_node; return data_node; } /// -/// @brief Build outputs +/// @brief Add RetVal nodes /// @param [out] error_code /// @param [out] error_msg /// @return void /// -void CompleteGraphBuilder::BuildOutputs(graphStatus &error_code, std::string &error_msg) { - std::map> out_nodes_map; - std::vector> out_nodes_info; - for (auto &pair : graph_outputs_) { - std::string output = pair.first; - int32_t ind = pair.second; - auto out_iter = node_names_.find(output); +void CompleteGraphBuilder::AddRetValNodes(graphStatus &error_code, std::string &error_msg) { + size_t output_num = graph_outputs_.size(); + for (size_t i = 0; i < output_num; i++) { + int32_t index = graph_outputs_[i].second; + auto out_iter = node_names_.find(graph_outputs_[i].first); if (out_iter == node_names_.end()) { error_code = GRAPH_FAILED; - error_msg = "BuildOutputs failed: node " + output + " not exist in graph."; + error_msg = "AddRetValNode failed: node " + graph_outputs_[i].first + " not exist in graph."; return; } - - NodePtr out_node = node_names_[output]; - if (out_node == nullptr) { + NodePtr node = out_iter->second; + if ((node == nullptr) || (node->GetOpDesc() == nullptr)) { error_code = GRAPH_FAILED; - error_msg = "BuildOutputs failed: node " + output + " is NULL."; + error_msg = "AddRetValNode failed: node is NULL."; return; } - OutDataAnchorPtr out_anchor = out_node->GetOutDataAnchor(ind); - if (out_anchor == 
nullptr) { + std::string name = node->GetName() + "_RetVal"; + OpDescPtr ret_val_desc = shared_ptr(new (std::nothrow) OpDesc(name, FRAMEWORKOP)); + if (ret_val_desc == nullptr) { error_code = GRAPH_FAILED; - error_msg = "BuildOutputs failed: anchor " + output + ":" + std::to_string(ind) + " is NULL."; + error_msg = "AddRetValNode " + name + " failed: op_desc is NULL."; return; } - - auto iter = out_nodes_map.find(output); - if (iter == out_nodes_map.end()) { - std::vector vec = {ind}; - out_nodes_map[output] = vec; - } else { - out_nodes_map[output].emplace_back(ind); + ge::GeTensorDesc tensor = node->GetOpDesc()->GetOutputDesc(index); + if ((ret_val_desc->AddInputDesc(tensor) != GRAPH_SUCCESS) || + (ret_val_desc->AddOutputDesc(tensor) != GRAPH_SUCCESS)) { + error_code = GRAPH_FAILED; + error_msg = "AddRetValNode " + name + " failed: add input_desc / output_desc failed."; + return; } - out_nodes_info.emplace_back(std::make_pair(out_node, ind)); - - GELOGD("BuildOutputs : AddOutputAnchor %s:%u succ.", output.c_str(), ind); - } - - owner_graph_->SetGraphOutNodes(out_nodes_map); - owner_graph_->SetGraphOutNodesInfo(out_nodes_info); - GELOGD("BuildOutputs succ."); -} - -/// -/// @brief Add NetOutput node -/// @param [out] error_code -/// @param [out] error_msg -/// @return NodePtr -/// -NodePtr CompleteGraphBuilder::AddNetOutputNode(graphStatus &error_code, std::string &error_msg) { - std::string log_msg = "AddNetOutputNode name:" + std::string(kNodeNameNetOutput) + ", type:" + NETOUTPUT; - OpDescPtr net_output_desc = shared_ptr(new (std::nothrow) OpDesc(kNodeNameNetOutput, NETOUTPUT)); - if (net_output_desc == nullptr) { - error_code = GRAPH_FAILED; - error_msg = log_msg + " failed: op_desc is NULL."; - return nullptr; - } - - std::vector> out_nodes_info = owner_graph_->GetGraphOutNodesInfo(); - error_code = BuildInOutForNetOutput(out_nodes_info, net_output_desc); - if (error_code != GRAPH_SUCCESS) { - error_msg = log_msg + " failed: add input/output tensor failed."; 
- return nullptr; - } - - NodePtr net_output_node = owner_graph_->AddNode(net_output_desc); - if (net_output_node == nullptr) { - error_code = GRAPH_FAILED; - error_msg = log_msg + " failed: add node failed."; - return nullptr; - } - - error_code = AddEdgeForNetOutput(out_nodes_info, net_output_node); - if (error_code != GRAPH_SUCCESS) { - error_msg = log_msg + " failed: link edge failed."; - return nullptr; - } - - GELOGD("%s succ.", log_msg.c_str()); - return net_output_node; -} -/// -/// @brief Add input/output tensor for NetOutput node -/// @param [in] out_nodes_info -/// @param [out] net_output_desc -/// @return graphStatus -/// -graphStatus CompleteGraphBuilder::BuildInOutForNetOutput(const std::vector> &out_nodes_info, - OpDescPtr &net_output_desc) { - size_t output_num = out_nodes_info.size(); - for (size_t i = 0; i < output_num; i++) { - NodePtr src_node = out_nodes_info[i].first; - uint32_t src_index = out_nodes_info[i].second; - if ((src_node == nullptr) || (src_node->GetOpDesc() == nullptr)) { - GE_LOGE("AddInOutForNetOutputOp failed: src_node is NULL."); - return GRAPH_FAILED; + if (!(ge::AttrUtils::SetStr(ret_val_desc, ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, "_RetVal") && + ge::AttrUtils::SetInt(ret_val_desc, RETVAL_ATTR_NAME_INDEX, i))) { + error_code = GRAPH_FAILED; + error_msg = "AddRetValNode " + name + " failed: set FRAMEWORK_ORIGINAL_TYPE / RETVAL_ATTR_NAME_INDEX failed."; + return; } - - ge::GeTensorDesc in_desc = src_node->GetOpDesc()->GetOutputDesc(src_index); auto iter = output_mapping_.find(i); if (iter != output_mapping_.end()) { - if (!ge::AttrUtils::SetInt(in_desc, ATTR_NAME_PARENT_NODE_INDEX, iter->second)) { - GE_LOGE("AddInOutForNetOutputOp failed: set attr ATTR_NAME_PARENT_NODE_INDEX failed."); - return GRAPH_FAILED; + if (!ge::AttrUtils::SetInt(ret_val_desc, ATTR_NAME_PARENT_NODE_INDEX, iter->second)) { + error_code = GRAPH_FAILED; + error_msg = "AddRetValNode " + name + " failed: set attr PARENT_NODE_INDEX failed."; + return; } } - if 
(net_output_desc->AddInputDesc(in_desc) != SUCCESS) { - GE_LOGE("AddInOutForNetOutputOp failed: add input_desc failed."); - return GRAPH_FAILED; - } - - ge::GeTensorDesc out_desc = src_node->GetOpDesc()->GetOutputDesc(src_index); - TensorUtils::SetOutputTensor(out_desc, true); - if (net_output_desc->AddOutputDesc(out_desc) != SUCCESS) { - GE_LOGE("AddInOutForNetOutputOp failed: add output_desc failed."); - return GRAPH_FAILED; - } - } - - GELOGD("Add input/output tensor for NetOutput node succ."); - return GRAPH_SUCCESS; -} - -/// -/// @brief Add edge for NetOutput node -/// @param [in] out_nodes_info -/// @param [out] net_output_node -/// @return graphStatus -/// -graphStatus CompleteGraphBuilder::AddEdgeForNetOutput(const std::vector> &out_nodes_info, - const NodePtr &net_output_node) { - if (net_output_node == nullptr) { - GE_LOGE("AddEdgeForNetOutputOp failed: NetOutput is NULL."); - return GRAPH_FAILED; - } - - size_t out_num = out_nodes_info.size(); - for (size_t i = 0; i < out_num; i++) { - NodePtr src_node = out_nodes_info[i].first; - uint32_t ind = out_nodes_info[i].second; - if (src_node == nullptr) { - GE_LOGE("AddEdgeForNetOutputOp failed: src_node is NULL."); - return GRAPH_FAILED; - } - - if (GraphUtils::AddEdge(src_node->GetOutDataAnchor(ind), net_output_node->GetInDataAnchor(i)) != GRAPH_SUCCESS) { - GE_LOGE("Add data-edge %s:%u->%s:%zu failed.", src_node->GetName().c_str(), ind, - net_output_node->GetName().c_str(), i); - return GRAPH_FAILED; + NodePtr ret_val_node = owner_graph_->AddNode(ret_val_desc); + if (ret_val_node == nullptr) { + error_code = GRAPH_FAILED; + error_msg = "AddRetValNode " + name + " failed: add node failed."; + return; } - } - std::vector leaf_nodes; - for (auto &node : owner_graph_->GetDirectNode()) { - if (node->GetOutNodes().empty()) { - leaf_nodes.emplace_back(node); - } - } - for (auto &node : leaf_nodes) { - if (GraphUtils::AddEdge(node->GetOutControlAnchor(), net_output_node->GetInControlAnchor()) != GRAPH_SUCCESS) { - 
GE_LOGE("Add ctrl-edge %s->%s failed.", node->GetName().c_str(), net_output_node->GetName().c_str()); - return GRAPH_FAILED; + if (GraphUtils::AddEdge(node->GetOutDataAnchor(index), ret_val_node->GetInDataAnchor(0)) != GRAPH_SUCCESS) { + error_code = GRAPH_FAILED; + error_msg = "AddRetValNode " + name + " failed: add data-edge " + node->GetName() + ":" + std::to_string(index) + + "->" + ret_val_node->GetName() + ":0 failed."; + return; } } - GELOGD("Add edge for NetOutput node succ."); - return GRAPH_SUCCESS; + GELOGD("AddRetValNodes succ."); } /// @@ -1999,4 +2329,60 @@ void PartialGraphBuilder::BuildExistNodes(graphStatus &error_code, std::string & GELOGD("Build exist nodes succ."); } + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus +GraphUtils::TopologicalSortingByName(const ge::ComputeGraphPtr &compute_graph, vector &node_vec) { + std::vector stack_input; + std::map map_in_edge_num; + graphStatus ret = compute_graph->SortNodes(stack_input, map_in_edge_num); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Sort nodes failed."); + return GRAPH_FAILED; + } + const size_t non_user_input_index = stack_input.size() - compute_graph->inputs_order_.size() - 1; + std::sort(stack_input.begin(), stack_input.begin() + non_user_input_index, + [](const NodePtr &a, const NodePtr &b) -> bool { return (a->GetName() > b->GetName()); }); + + std::queue stack; + NodePtr cur_node = nullptr; + std::map name_node_map; + vector nodes_name; + while (!stack_input.empty() || !stack.empty()) { + if (!stack.empty()) { + cur_node = stack.front(); + stack.pop(); + } else { + cur_node = stack_input.back(); + stack_input.pop_back(); + } + node_vec.emplace_back(cur_node); + compute_graph->CollectBreadthOutNode(cur_node, map_in_edge_num, name_node_map); + for (const auto &iter : name_node_map) { + nodes_name.emplace_back(iter.first); + } + std::sort(nodes_name.begin(), nodes_name.end()); + for (const auto &iter : nodes_name) { + stack.push(name_node_map[iter]); + } + 
name_node_map.clear(); + nodes_name.clear(); + } + // If they are not equal, there is a closed loop + if (node_vec.size() != compute_graph->nodes_.size()) { + std::set itered_nodes_set; + for (auto &node : node_vec) { + itered_nodes_set.insert(node.get()); + } + GE_LOGE("Failed to do topo sorting total %zu, itered %zu, exist closed loop in graph.", + compute_graph->nodes_.size(), node_vec.size()); + for (auto &node : compute_graph->nodes_) { + if (itered_nodes_set.count(node.get()) == 0) { + GE_LOGE("The node %s does not itered when topological sorting", node->GetName().c_str()); + } + } + return GRAPH_FAILED; + } + return GRAPH_SUCCESS; +} + } // namespace ge diff --git a/src/common/graph/utils/node_utils.cc b/src/common/graph/utils/node_utils.cc index 52d81e43..8c2ff244 100644 --- a/src/common/graph/utils/node_utils.cc +++ b/src/common/graph/utils/node_utils.cc @@ -21,6 +21,7 @@ #include "framework/common/debug/ge_log.h" #include "graph/anchor.h" #include "graph/debug/ge_attr_define.h" +#include "graph/types.h" #include "utils/tensor_utils.h" #include "utils/type_utils.h" @@ -28,6 +29,26 @@ namespace ge { std::map> NodeUtils::map_send_info_{}; std::map> NodeUtils::map_recv_info_{}; +bool OpShapeIsUnknown(const OpDescPtr &desc) { + for (const auto &ptr : desc->GetAllInputsDescPtr()) { + auto ge_shape = ptr->GetShape(); + for (const auto &dim : ge_shape.GetDims()) { + if (dim == UNKNOWN_DIM || dim == UNKNOWN_DIM_NUM) { + return true; + } + } + } + for (const auto &ptr : desc->GetAllOutputsDescPtr()) { + auto ge_shape = ptr->GetShape(); + for (const auto &dim : ge_shape.GetDims()) { + if (dim == UNKNOWN_DIM || dim == UNKNOWN_DIM_NUM) { + return true; + } + } + } + return false; +} + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AddSendEventId(const NodePtr &node, const uint32_t &event_id) { GE_CHECK_NOTNULL(node); @@ -282,18 +303,23 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::UpdatePeer GELOGE(GRAPH_FAILED, 
"peer_anchor opdesc is null"); continue; } - auto peer_input_desc = peer_anchor->GetOwnerNode()->GetOpDesc()->GetInputDescPtr(peer_anchor->GetIdx()); + auto peer_input_desc = peer_anchor->GetOwnerNode()->GetOpDesc()->MutableInputDesc(peer_anchor->GetIdx()); if (peer_input_desc == nullptr) { GELOGE(GRAPH_FAILED, "peer_input_desc is nullptr"); continue; } - output_tensor.SetOriginFormat(peer_input_desc->GetOriginFormat()); - output_tensor.SetFormat(peer_input_desc->GetFormat()); - auto peer_op_desc = peer_anchor->GetOwnerNode()->GetOpDesc(); - GE_IF_BOOL_EXEC(peer_op_desc == nullptr, GELOGE(GRAPH_FAILED, "peer opdesc is null"); continue); - GE_IF_BOOL_EXEC(peer_op_desc->UpdateInputDesc(peer_anchor->GetIdx(), output_tensor) != GRAPH_SUCCESS, - GELOGE(GRAPH_FAILED, "peer opdesc is null"); - continue); + GELOGI("Peer input opdesc name is %s, need to flush: shape size is %zu, datatype is %d, original datatype is %d", + peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), output_tensor.GetShape().GetDimNum(), + output_tensor.GetDataType(), output_tensor.GetOriginDataType()); + peer_input_desc->SetShape(output_tensor.GetShape()); + peer_input_desc->SetOriginShape(output_tensor.GetOriginShape()); + peer_input_desc->SetDataType(output_tensor.GetDataType()); + peer_input_desc->SetOriginDataType(output_tensor.GetOriginDataType()); + ge::TensorUtils::SetRealDimCnt(*peer_input_desc, + static_cast(output_tensor.GetShape().GetDims().size())); + GELOGI("Peer input opdesc name is %s, shape size is %zu, datatype is %d, original datatype is %d", + peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), peer_input_desc->GetShape().GetDimNum(), + peer_input_desc->GetDataType(), peer_input_desc->GetOriginDataType()); } } return GRAPH_SUCCESS; @@ -361,6 +387,41 @@ graphStatus NodeUtils::UpdateInputShape(const Node &node, uint32_t index, const input_desc->SetShape(shape); return GRAPH_SUCCESS; } + +graphStatus NodeUtils::GetNodeUnknownShapeStatus(const Node &node, bool 
&is_unknow) { + auto desc = node.GetOpDesc(); + GE_CHECK_NOTNULL(desc); + + auto sub_graph_names = desc->GetSubgraphInstanceNames(); + if (sub_graph_names.empty()) { + is_unknow = OpShapeIsUnknown(desc); + return GRAPH_SUCCESS; + } else { + auto owner_graph = node.GetOwnerComputeGraph(); + GE_CHECK_NOTNULL(owner_graph); + auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph()); + if (root_graph == nullptr) { + GE_LOGE("Node %s gets null root graph", node.GetName().c_str()); + return GRAPH_PARAM_INVALID; + } + for (auto &sub_graph_name : sub_graph_names) { + auto sub_graph = root_graph->GetSubgraph(sub_graph_name); + GE_CHECK_NOTNULL(sub_graph); + for (const auto &node_ptr : sub_graph->GetDirectNode()) { + auto status = GetNodeUnknownShapeStatus(*node_ptr, is_unknow); + if (status != GRAPH_SUCCESS) { + GE_LOGE("get node unknown shape status failed!"); + return status; + } + if (is_unknow) { + return GRAPH_SUCCESS; + } + } + } + } + return GRAPH_SUCCESS; +} + std::string NodeUtils::GetNodeType(const Node &node) { if (node.GetType() != FRAMEWORKOP) { return node.GetType(); @@ -381,9 +442,9 @@ ComputeGraphPtr NodeUtils::GetSubgraph(const Node &node, uint32_t index) { return root_graph->GetSubgraph(op_desc->GetSubgraphInstanceName(index)); } -graphStatus NodeUtils::AddSubgraph(Node &node, const ComputeGraphPtr &subgraph) { +graphStatus NodeUtils::SetSubgraph(Node &node, uint32_t index, const ComputeGraphPtr &subgraph) { if (subgraph == nullptr) { - GE_LOGE("Failed to add subgraph to node %s, null subgraph", node.GetName().c_str()); + GE_LOGE("Failed to set subgraph to node %s index %u, null subgraph", node.GetName().c_str(), index); return GRAPH_PARAM_INVALID; } auto op_desc = node.GetOpDesc(); @@ -395,11 +456,105 @@ graphStatus NodeUtils::AddSubgraph(Node &node, const ComputeGraphPtr &subgraph) GE_LOGE("Failed to add subgraph to node %s, null root graph", node.GetName().c_str()); return GRAPH_PARAM_INVALID; } - 
op_desc->AddSubgraphInstanceName(subgraph->GetName()); + auto ret = op_desc->SetSubgraphInstanceName(index, subgraph->GetName()); + if (ret != GRAPH_SUCCESS) { + GE_LOGE("Failed to set subgraph to node %s index %u", node.GetName().c_str(), index); + return ret; + } subgraph->SetParentNode(node.shared_from_this()); subgraph->SetParentGraph(node.GetOwnerComputeGraph()); - root_graph->AddSubgraph(subgraph); + return root_graph->AddSubgraph(subgraph); +} - return GRAPH_SUCCESS; +/// +/// Check if node is input of subgraph +/// @param [in] node +/// @return bool +/// +bool NodeUtils::IsSubgraphInput(const NodePtr &node) { + if ((node == nullptr) || (node->GetOpDesc() == nullptr) || + (node->GetOwnerComputeGraph()->GetParentNode() == nullptr)) { + return false; + } + + return node->GetOpDesc()->HasAttr(ATTR_NAME_PARENT_NODE_INDEX); +} + +/// +/// Check if node is output of subgraph +/// @param [in] node +/// @return bool +/// +bool NodeUtils::IsSubgraphOutput(const NodePtr &node) { + if ((node == nullptr) || (node->GetOpDesc() == nullptr) || + (node->GetOwnerComputeGraph()->GetParentNode() == nullptr) || (node->GetType() != NETOUTPUT)) { + return false; + } + + for (GeTensorDesc &tensor : node->GetOpDesc()->GetAllInputsDesc()) { + if (AttrUtils::HasAttr(tensor, ATTR_NAME_PARENT_NODE_INDEX)) { + return true; + } + } + + return false; +} + +/// +/// @brief Get subgraph original input node. +/// @param [in] node +/// @return Node +/// +NodePtr NodeUtils::GetParentInput(const NodePtr &node) { + GE_CHECK_NOTNULL_EXEC(node, return nullptr); + + uint32_t parent_index = 0; + if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { + return nullptr; + } + + // Subgraph Data Node, check for constant input. 
+ const ComputeGraphPtr &graph = node->GetOwnerComputeGraph(); + GE_CHECK_NOTNULL_EXEC(graph, return nullptr); + + const NodePtr &parent_node = graph->GetParentNode(); + GE_CHECK_NOTNULL_EXEC(parent_node, return nullptr); + + const InDataAnchorPtr &in_anchor = parent_node->GetInDataAnchor(parent_index); + GE_CHECK_NOTNULL_EXEC(in_anchor, return nullptr); + + const OutDataAnchorPtr &peer_out_anchor = in_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL_EXEC(peer_out_anchor, return nullptr); + + return peer_out_anchor->GetOwnerNode(); +} + +/// +/// @brief Get subgraph input is constant. +/// @param [in] node +/// @param [out] string +/// @return bool +/// +bool NodeUtils::GetConstOpType(const NodePtr &in_node, std::string &op_type) { + GE_CHECK_NOTNULL_EXEC(in_node, return false); + + if ((in_node->GetType() == CONSTANT) || (in_node->GetType() == CONSTANTOP)) { + op_type = in_node->GetType(); + return true; + } + + if (in_node->GetType() == DATA) { + std::string const_type; + if (!AttrUtils::GetStr(in_node->GetOpDesc(), ATTR_NAME_PARENT_CONST_TYPE, const_type)) { + return false; + } + + if ((const_type == CONSTANT) || (const_type == CONSTANTOP)) { + op_type = const_type; + return true; + } + } + + return false; } } // namespace ge diff --git a/src/common/graph/utils/op_desc_utils.cc b/src/common/graph/utils/op_desc_utils.cc index 89175b56..32ae00cf 100644 --- a/src/common/graph/utils/op_desc_utils.cc +++ b/src/common/graph/utils/op_desc_utils.cc @@ -469,7 +469,7 @@ OpDescUtils::SetWeights(ge::Node &node, const vector &weights) return GRAPH_PARAM_INVALID; } - ge::GeAttrValue::NamedAttrs named_attrs; + ge::GeAttrValue::NAMED_ATTRS named_attrs; (void)ge::AttrUtils::SetListTensor(named_attrs, "key", weights); vector copy_weights; (void)ge::AttrUtils::MutableListTensor(named_attrs, "key", copy_weights); @@ -578,7 +578,19 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDescUtils::ClearWei /// @return OpDescBuilder /// GE_FUNC_DEV_VISIBILITY 
GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddInput(const std::string &name) { - inputs_.emplace_back(name); + inputs_.emplace_back(std::make_pair(name, GeTensorDesc())); + return *this; +} + +/// +/// @brief Add input +/// @param [in] name +/// @param [in] tensor +/// @return OpDescBuilder +/// +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddInput(const std::string &name, + const GeTensorDesc &tensor) { + inputs_.emplace_back(std::make_pair(name, tensor)); return *this; } @@ -591,7 +603,22 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::Add GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicInput(const std::string &name, uint32_t num) { for (uint32_t i = 0; i < num; i++) { - inputs_.emplace_back(name + std::to_string(i)); + inputs_.emplace_back(std::make_pair(name + std::to_string(i), GeTensorDesc())); + } + return *this; +} + +/// +/// @brief Add dynamic input +/// @param [in] name +/// @param [in] num +/// @param [in] tensor +/// @return OpDescBuilder +/// +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicInput( + const std::string &name, uint32_t num, const GeTensorDesc &tensor) { + for (uint32_t i = 0; i < num; i++) { + inputs_.emplace_back(std::make_pair(name + std::to_string(i), tensor)); } return *this; } @@ -602,7 +629,19 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::Add /// @return OpDescBuilder /// GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddOutput(const std::string &name) { - outputs_.emplace_back(name); + outputs_.emplace_back(std::make_pair(name, GeTensorDesc())); + return *this; +} + +/// +/// @brief Add output +/// @param [in] name +/// @param [in] tensor +/// @return OpDescBuilder +/// +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddOutput(const std::string &name, + const GeTensorDesc &tensor) { + 
outputs_.emplace_back(std::make_pair(name, tensor)); return *this; } @@ -615,7 +654,22 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::Add GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicOutput(const std::string &name, uint32_t num) { for (uint32_t i = 0; i < num; i++) { - outputs_.emplace_back(name + std::to_string(i)); + outputs_.emplace_back(std::make_pair(name + std::to_string(i), GeTensorDesc())); + } + return *this; +} + +/// +/// @brief Add dynamic output +/// @param [in] name +/// @param [in] num +/// @param [in] tensor +/// @return OpDescBuilder +/// +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicOutput( + const std::string &name, uint32_t num, const GeTensorDesc &tensor) { + for (uint32_t i = 0; i < num; i++) { + outputs_.emplace_back(std::make_pair(name + std::to_string(i), tensor)); } return *this; } @@ -632,14 +686,14 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr OpDescBuilder::Build() } for (auto &input : inputs_) { - if (op_desc->AddInputDesc(input, GeTensorDesc()) != GRAPH_SUCCESS) { + if (op_desc->AddInputDesc(input.first, input.second) != GRAPH_SUCCESS) { GELOGE(GRAPH_FAILED, "Add input_desc failed."); return nullptr; } } for (auto &output : outputs_) { - if (op_desc->AddOutputDesc(output, GeTensorDesc()) != GRAPH_SUCCESS) { + if (op_desc->AddOutputDesc(output.first, output.second) != GRAPH_SUCCESS) { GELOGE(GRAPH_FAILED, "Add output_desc failed."); return nullptr; } @@ -647,4 +701,19 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr OpDescBuilder::Build() return op_desc; } + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDescUtils::SetSubgraphInstanceName( + const std::string &subgraph_name, const std::string &subgraph_instance_name, OpDescPtr &op_desc) { + const auto &subgraph_names_to_index = op_desc->GetSubgraphNameIndexes(); + auto iter = subgraph_names_to_index.find(subgraph_name); + if (iter 
== subgraph_names_to_index.end()) { + GELOGE(GRAPH_PARAM_INVALID, + "Failed to set subgraph instance %s for node %s type %s, the subgraph name %s does not exists", + subgraph_instance_name.c_str(), op_desc->GetName().c_str(), op_desc->GetType().c_str(), + subgraph_name.c_str()); + return GRAPH_PARAM_INVALID; + } + + return op_desc->SetSubgraphInstanceName(iter->second, subgraph_instance_name); +} } // namespace ge diff --git a/src/common/graph/utils/tensor_utils.cc b/src/common/graph/utils/tensor_utils.cc index 7b8ad3cd..072673c0 100644 --- a/src/common/graph/utils/tensor_utils.cc +++ b/src/common/graph/utils/tensor_utils.cc @@ -282,6 +282,7 @@ static graphStatus CalcTensorElementCnt(const std::vector &dims, Format case FORMAT_FRACTAL_Z_3D: case FORMAT_FRACTAL_Z_3D_TRANSPOSE: case FORMAT_NDC1HWC0: + case FORMAT_FRACTAL_Z_C04: graph_status = CalcElementCntByDims(dims, element_cnt); break; default: diff --git a/src/common/graph/utils/type_utils.cc b/src/common/graph/utils/type_utils.cc index cd316260..e8ad9ed0 100644 --- a/src/common/graph/utils/type_utils.cc +++ b/src/common/graph/utils/type_utils.cc @@ -56,6 +56,7 @@ static const std::map kFormatToStringMap = { {FORMAT_FRACTAL_NZ, "FRACTAL_NZ"}, {FORMAT_CN, "CN"}, {FORMAT_NC, "NC"}, + {FORMAT_FRACTAL_ZN_LSTM, "FRACTAL_ZN_LSTM"}, {FORMAT_RESERVED, "FORMAT_RESERVED"}, {FORMAT_ALL, "ALL"}}; @@ -76,7 +77,8 @@ static const std::unordered_set kInternalFormat = {"NC1HWC0", "FRACTAL_NZ", "NDC1HWC0", "FORMAT_FRACTAL_Z_3D", - "FORMAT_FRACTAL_Z_3D_TRANSPOSE"}; + "FORMAT_FRACTAL_Z_3D_TRANSPOSE", + "FORMAT_FRACTAL_ZN_LSTM"}; static const std::map kDataFormatMap = { {"NCHW", FORMAT_NCHW}, {"NHWC", FORMAT_NHWC}, {"NDHWC", FORMAT_NDHWC}, {"NCDHW", FORMAT_NCDHW}, {"ND", FORMAT_ND}}; @@ -119,6 +121,7 @@ static const std::map kStringToFormatMap = { {"FRACTAL_Z_3D_TRANSPOSE", FORMAT_FRACTAL_Z_3D_TRANSPOSE}, {"CN", FORMAT_CN}, {"NC", FORMAT_NC}, + {"FRACTAL_ZN_LSTM", FORMAT_FRACTAL_ZN_LSTM}, {"FORMAT_RESERVED", FORMAT_RESERVED}, {"ALL", 
FORMAT_ALL}}; diff --git a/src/ge/CMakeLists.txt b/src/ge/CMakeLists.txt index 3e98fbcf..ba6f1d73 100755 --- a/src/ge/CMakeLists.txt +++ b/src/ge/CMakeLists.txt @@ -13,15 +13,18 @@ # limitations under the License. # ============================================================================ -# libge_compiler.so & libge_train.so +# libge_compiler.so & libge_runner.so # will later be integrated into libgraph_runner.so, works for both training and inference # compiling proto files generates some warnings, use no-unused-variable to suppress them set(CMAKE_CXX_FLAGS "-Wno-unused-variable ${CMAKE_CXX_FLAGS}") -file(GLOB_RECURSE PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} +file(GLOB PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "../proto/fusion_model.proto" + "../proto/optimizer_priority.proto" ) - -file(GLOB_RECURSE PROTO_HEADER_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} +file(GLOB PROTO_CLIENT_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} + "../proto/ge_api.proto" + ) +file(GLOB PROTO_HEADER_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "../proto/om.proto" "../proto/task.proto" "../proto/insert_op.proto" @@ -30,57 +33,46 @@ file(GLOB_RECURSE PROTO_HEADER_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "../proto/op_mapping_info.proto" ) ge_protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) +ge_protobuf_generate(ge PROTO_CLIENT_SRCS PROTO_CLIENT_HDRS ${PROTO_CLIENT_LIST}) ge_protobuf_generate(ge PROTO_HEADER_SRCS PROTO_HEADER_HDRS ${PROTO_HEADER_LIST}) # include directories include_directories(${CMAKE_CURRENT_LIST_DIR}) include_directories(${GE_SOURCE_DIR}) include_directories(${GE_SOURCE_DIR}/src) include_directories(${GE_SOURCE_DIR}/inc) +include_directories(${GE_SOURCE_DIR}/inc/common/util) include_directories(${GE_SOURCE_DIR}/inc/external) include_directories(${GE_SOURCE_DIR}/inc/external/graph) include_directories(${GE_SOURCE_DIR}/inc/framework) include_directories(${GE_SOURCE_DIR}/inc/framework/common) include_directories(${GE_SOURCE_DIR}/inc/runtime) 
+include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib) include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/cce) include_directories(${CMAKE_BINARY_DIR}) include_directories(${CMAKE_BINARY_DIR}/proto/ge) -######### libge_train.so ############# +######### libge_runner.so ############# # need to remove dependencies on pb files later file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} + "client/ge_api.cc" "common/formats/format_transfers/*.cc" "common/formats/formats.cc" "common/formats/utils/formats_trans_utils.cc" "common/fp16_t.cc" "common/ge/plugin_manager.cc" + "common/helper/model_cache_helper.cc" "common/profiling/profiling_manager.cc" "engine_manager/dnnengine_manager.cc" "ge_local_engine/engine/host_cpu_engine.cc" "generator/ge_generator.cc" "generator/generator_api.cc" - "graph/build/graph_builder.cc" - "graph/build/label_allocator.cc" - "graph/build/logical_stream_allocator.cc" - "graph/build/model_builder.cc" - "graph/build/run_context.cc" - "graph/build/stream_allocator.cc" - "graph/build/stream_graph_optimizer.cc" - "graph/build/task_generator.cc" - "graph/common/bcast.cc" - "graph/common/omg_util.cc" - "graph/common/transop_util.cc" + "graph/build/*.cc" + "graph/common/*.cc" "graph/execute/graph_execute.cc" "graph/label/*.cc" "graph/load/graph_loader.cc" - "graph/load/new_model_manager/cpu_queue_schedule.cc" - "graph/load/new_model_manager/data_dumper.cc" - "graph/load/new_model_manager/data_inputer.cc" - "graph/load/new_model_manager/davinci_model.cc" - "graph/load/new_model_manager/davinci_model_parser.cc" - "graph/load/new_model_manager/model_manager.cc" - "graph/load/new_model_manager/model_output.cc" - "graph/load/new_model_manager/model_utils.cc" + "graph/load/new_model_manager/*.cc" "graph/load/new_model_manager/task_info/end_graph_task_info.cc" "graph/load/new_model_manager/task_info/event_record_task_info.cc" 
"graph/load/new_model_manager/task_info/event_wait_task_info.cc" @@ -89,8 +81,10 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "graph/load/new_model_manager/task_info/hccl_task_info.cc" "graph/load/new_model_manager/task_info/kernel_ex_task_info.cc" "graph/load/new_model_manager/task_info/kernel_task_info.cc" - "graph/load/new_model_manager/task_info/label_goto_task_info.cc" + "graph/load/new_model_manager/task_info/label_goto_ex_task_info.cc" "graph/load/new_model_manager/task_info/label_set_task_info.cc" + "graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc" + "graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc" "graph/load/new_model_manager/task_info/memcpy_async_task_info.cc" "graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" "graph/load/new_model_manager/task_info/stream_active_task_info.cc" @@ -99,15 +93,9 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc" "graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc" "graph/load/new_model_manager/task_info/task_info.cc" - "graph/load/new_model_manager/tbe_handle_store.cc" "graph/load/output/output.cc" - "graph/manager/graph_context.cc" - "graph/manager/graph_manager.cc" - "graph/manager/graph_manager_utils.cc" - "graph/manager/graph_mem_allocator.cc" - "graph/manager/graph_var_manager.cc" + "graph/manager/*.cc" "graph/manager/model_manager/event_manager.cc" - "graph/manager/trans_var_data_utils.cc" "graph/manager/util/debug.cc" "graph/manager/util/hcom_util.cc" "graph/manager/util/rt_context_util.cc" @@ -115,27 +103,10 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "graph/optimize/graph_optimize.cc" "graph/optimize/optimizer/allreduce_fusion_pass.cc" "graph/optimize/summary_optimize.cc" + "graph/partition/dynamic_shape_partition.cc" "graph/partition/engine_place.cc" "graph/partition/graph_partition.cc" - 
"graph/passes/addn_pass.cc" - "graph/passes/aicpu_constant_folding_pass.cc" - "graph/passes/assert_pass.cc" - "graph/passes/atomic_addr_clean_pass.cc" - "graph/passes/base_pass.cc" - "graph/passes/cast_remove_pass.cc" - "graph/passes/cast_translate_pass.cc" - "graph/passes/common_subexpression_elimination_pass.cc" - "graph/passes/compile_nodes_pass.cc" - "graph/passes/constant_folding_pass.cc" - "graph/passes/constant_fuse_same_pass.cc" - "graph/passes/control_op_attr_pass.cc" - "graph/passes/control_trigger_pass.cc" - "graph/passes/dimension_adjust_pass.cc" - "graph/passes/dimension_compute_pass.cc" - "graph/passes/dropout_pass.cc" - "graph/passes/end_graph_pass.cc" - "graph/passes/enter_pass.cc" - "graph/passes/flow_ctrl_pass.cc" + "graph/passes/*.cc" "graph/passes/folding_kernel/add_kernel.cc" "graph/passes/folding_kernel/broadcast_args_kernel.cc" "graph/passes/folding_kernel/broadcast_gradient_args_kernel.cc" @@ -171,51 +142,6 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "graph/passes/folding_kernel/sub_kernel.cc" "graph/passes/folding_kernel/transdata_kernel.cc" "graph/passes/folding_kernel/unpack_kernel.cc" - "graph/passes/folding_pass.cc" - "graph/passes/get_original_format_pass.cc" - "graph/passes/guarantee_const_pass.cc" - "graph/passes/hccl_memcpy_pass.cc" - "graph/passes/identify_reference_pass.cc" - "graph/passes/identity_pass.cc" - "graph/passes/infershape_pass.cc" - "graph/passes/isolated_op_remove_pass.cc" - "graph/passes/iterator_op_pass.cc" - "graph/passes/link_gen_mask_nodes_pass.cc" - "graph/passes/merge_pass.cc" - "graph/passes/multi_batch_pass.cc" - "graph/passes/net_output_pass.cc" - "graph/passes/next_iteration_pass.cc" - "graph/passes/no_use_reshape_remove_pass.cc" - "graph/passes/pass_manager.cc" - "graph/passes/pass_utils.cc" - "graph/passes/permute_pass.cc" - "graph/passes/placeholder_with_default_pass.cc" - "graph/passes/prevent_gradient_pass.cc" - "graph/passes/print_op_pass.cc" - "graph/passes/prune_pass.cc" - 
"graph/passes/reshape_remove_pass.cc" - "graph/passes/resource_pair_add_control_pass.cc" - "graph/passes/resource_pair_remove_control_pass.cc" - "graph/passes/same_transdata_breadth_fusion_pass.cc" - "graph/passes/save_pass.cc" - "graph/passes/shape_operate_op_remove_pass.cc" - "graph/passes/snapshot_pass.cc" - "graph/passes/stop_gradient_pass.cc" - "graph/passes/switch_logic_remove_pass.cc" - "graph/passes/switch_op_pass.cc" - "graph/passes/switch_pass.cc" - "graph/passes/transop_breadth_fusion_pass.cc" - "graph/passes/transop_depth_fusion_pass.cc" - "graph/passes/transop_nearby_allreduce_fusion_pass.cc" - "graph/passes/transop_without_reshape_fusion_pass.cc" - "graph/passes/transpose_transdata_pass.cc" - "graph/passes/unused_const_pass.cc" - "graph/passes/unused_op_remove_pass.cc" - "graph/passes/var_is_initialized_op_pass.cc" - "graph/passes/variable_format_pass.cc" - "graph/passes/variable_op_pass.cc" - "graph/passes/variable_prepare_op_pass.cc" - "graph/passes/variable_ref_delete_op_pass.cc" "graph/preprocess/graph_preprocess.cc" "graph/preprocess/insert_op/ge_aipp_op.cc" "graph/preprocess/insert_op/util_insert_aipp_op.cc" @@ -231,22 +157,17 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} ) -######### libge_train.so ############# -add_library(ge_train SHARED ${TRAIN_SRC_LIST} ${PROTO_SRCS} ${PROTO_HEADER_HDRS}) -target_compile_definitions(ge_train PRIVATE +######### libge_runner.so ############# +add_library(ge_runner SHARED ${TRAIN_SRC_LIST} ${PROTO_SRCS} ${PROTO_CLIENT_SRCS} ${PROTO_HEADER_HDRS}) +target_compile_definitions(ge_runner PRIVATE PROTOBUF_INLINE_NOT_IN_HEADERS=0 DAVINCI_SUPPORT_PROFILING REUSE_MEMORY=1 - DAVINCI_TRAIN - DAVINCI_CLOUD - FMK_SUPPORT_DEBUG - PLATFORM_CLOUD) -target_link_libraries(ge_train + DAVINCI_CLOUD) +target_link_libraries(ge_runner graph ge_common - "-Wl,--whole-archive" ge_memory - "-Wl,--no-whole-archive" ${PROTOBUF_LIBRARY} ${register} ${c_sec} @@ -267,33 +188,18 @@ file(GLOB INFER_SRC_LIST RELATIVE 
${CMAKE_CURRENT_LIST_DIR} "common/formats/utils/formats_trans_utils.cc" "common/fp16_t.cc" "common/ge/plugin_manager.cc" + "common/helper/model_cache_helper.cc" "common/profiling/profiling_manager.cc" "engine_manager/dnnengine_manager.cc" "ge_local_engine/engine/host_cpu_engine.cc" "generator/ge_generator.cc" "generator/generator_api.cc" - "graph/build/graph_builder.cc" - "graph/build/label_allocator.cc" - "graph/build/logical_stream_allocator.cc" - "graph/build/model_builder.cc" - "graph/build/run_context.cc" - "graph/build/stream_allocator.cc" - "graph/build/stream_graph_optimizer.cc" - "graph/build/task_generator.cc" - "graph/common/bcast.cc" - "graph/common/omg_util.cc" - "graph/common/transop_util.cc" + "graph/build/*.cc" + "graph/common/*.cc" "graph/execute/graph_execute.cc" "graph/label/*.cc" "graph/load/graph_loader.cc" - "graph/load/new_model_manager/cpu_queue_schedule.cc" - "graph/load/new_model_manager/data_dumper.cc" - "graph/load/new_model_manager/data_inputer.cc" - "graph/load/new_model_manager/davinci_model.cc" - "graph/load/new_model_manager/davinci_model_parser.cc" - "graph/load/new_model_manager/model_manager.cc" - "graph/load/new_model_manager/model_output.cc" - "graph/load/new_model_manager/model_utils.cc" + "graph/load/new_model_manager/*.cc" "graph/load/new_model_manager/task_info/end_graph_task_info.cc" "graph/load/new_model_manager/task_info/event_record_task_info.cc" "graph/load/new_model_manager/task_info/event_wait_task_info.cc" @@ -301,8 +207,10 @@ file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "graph/load/new_model_manager/task_info/fusion_stop_task_info.cc" "graph/load/new_model_manager/task_info/kernel_ex_task_info.cc" "graph/load/new_model_manager/task_info/kernel_task_info.cc" - "graph/load/new_model_manager/task_info/label_goto_task_info.cc" + "graph/load/new_model_manager/task_info/label_goto_ex_task_info.cc" "graph/load/new_model_manager/task_info/label_set_task_info.cc" + 
"graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc" + "graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc" "graph/load/new_model_manager/task_info/memcpy_async_task_info.cc" "graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" "graph/load/new_model_manager/task_info/stream_active_task_info.cc" @@ -311,41 +219,18 @@ file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc" "graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc" "graph/load/new_model_manager/task_info/task_info.cc" - "graph/load/new_model_manager/tbe_handle_store.cc" "graph/load/output/output.cc" - "graph/manager/graph_context.cc" - "graph/manager/graph_manager.cc" - "graph/manager/graph_manager_utils.cc" - "graph/manager/graph_mem_allocator.cc" - "graph/manager/graph_var_manager.cc" + "graph/manager/*.cc" "graph/manager/model_manager/event_manager.cc" - "graph/manager/trans_var_data_utils.cc" "graph/manager/util/debug.cc" "graph/manager/util/rt_context_util.cc" "graph/manager/util/variable_accelerate_ctrl.cc" "graph/optimize/graph_optimize.cc" "graph/optimize/summary_optimize.cc" + "graph/partition/dynamic_shape_partition.cc" "graph/partition/engine_place.cc" "graph/partition/graph_partition.cc" - "graph/passes/addn_pass.cc" - "graph/passes/aicpu_constant_folding_pass.cc" - "graph/passes/assert_pass.cc" - "graph/passes/atomic_addr_clean_pass.cc" - "graph/passes/base_pass.cc" - "graph/passes/cast_remove_pass.cc" - "graph/passes/cast_translate_pass.cc" - "graph/passes/common_subexpression_elimination_pass.cc" - "graph/passes/compile_nodes_pass.cc" - "graph/passes/constant_folding_pass.cc" - "graph/passes/constant_fuse_same_pass.cc" - "graph/passes/control_op_attr_pass.cc" - "graph/passes/control_trigger_pass.cc" - "graph/passes/dimension_adjust_pass.cc" - "graph/passes/dimension_compute_pass.cc" - "graph/passes/dropout_pass.cc" - 
"graph/passes/end_graph_pass.cc" - "graph/passes/enter_pass.cc" - "graph/passes/flow_ctrl_pass.cc" + "graph/passes/*.cc" "graph/passes/folding_kernel/add_kernel.cc" "graph/passes/folding_kernel/broadcast_args_kernel.cc" "graph/passes/folding_kernel/broadcast_gradient_args_kernel.cc" @@ -380,87 +265,33 @@ file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "graph/passes/folding_kernel/strided_slice_kernel.cc" "graph/passes/folding_kernel/sub_kernel.cc" "graph/passes/folding_kernel/transdata_kernel.cc" + "graph/passes/folding_kernel/transpose_kernel.cc" "graph/passes/folding_kernel/unpack_kernel.cc" - "graph/passes/folding_pass.cc" - "graph/passes/get_original_format_pass.cc" - "graph/passes/guarantee_const_pass.cc" - "graph/passes/hccl_memcpy_pass.cc" - "graph/passes/identify_reference_pass.cc" - "graph/passes/identity_pass.cc" - "graph/passes/infershape_pass.cc" - "graph/passes/isolated_op_remove_pass.cc" - "graph/passes/iterator_op_pass.cc" - "graph/passes/link_gen_mask_nodes_pass.cc" - "graph/passes/merge_pass.cc" - "graph/passes/multi_batch_pass.cc" - "graph/passes/net_output_pass.cc" - "graph/passes/next_iteration_pass.cc" - "graph/passes/no_use_reshape_remove_pass.cc" - "graph/passes/pass_manager.cc" - "graph/passes/pass_utils.cc" - "graph/passes/permute_pass.cc" - "graph/passes/placeholder_with_default_pass.cc" - "graph/passes/prevent_gradient_pass.cc" - "graph/passes/print_op_pass.cc" - "graph/passes/prune_pass.cc" - "graph/passes/reshape_remove_pass.cc" - "graph/passes/resource_pair_add_control_pass.cc" - "graph/passes/resource_pair_remove_control_pass.cc" - "graph/passes/same_transdata_breadth_fusion_pass.cc" - "graph/passes/save_pass.cc" - "graph/passes/shape_operate_op_remove_pass.cc" - "graph/passes/snapshot_pass.cc" - "graph/passes/stop_gradient_pass.cc" - "graph/passes/switch_logic_remove_pass.cc" - "graph/passes/switch_op_pass.cc" - "graph/passes/switch_pass.cc" - "graph/passes/transop_breadth_fusion_pass.cc" - 
"graph/passes/transop_depth_fusion_pass.cc" - "graph/passes/transop_nearby_allreduce_fusion_pass.cc" - "graph/passes/transop_without_reshape_fusion_pass.cc" - "graph/passes/transpose_transdata_pass.cc" - "graph/passes/unused_const_pass.cc" - "graph/passes/unused_op_remove_pass.cc" - "graph/passes/var_is_initialized_op_pass.cc" - "graph/passes/variable_format_pass.cc" - "graph/passes/variable_op_pass.cc" - "graph/passes/variable_prepare_op_pass.cc" - "graph/passes/variable_ref_delete_op_pass.cc" "graph/preprocess/graph_preprocess.cc" "graph/preprocess/insert_op/ge_aipp_op.cc" "graph/preprocess/insert_op/util_insert_aipp_op.cc" "graph/preprocess/multi_batch_copy_graph.cc" "init/gelib.cc" + "ir_build/atc_ir_common.cc" + "ir_build/ge_ir_build.cc" "model/ge_model.cc" "omm/csa_interact.cc" "opskernel_manager/ops_kernel_manager.cc" "session/inner_session.cc" "session/session_manager.cc" - "single_op/single_op.cc" - "single_op/single_op_manager.cc" - "single_op/single_op_model.cc" - "single_op/stream_resource.cc" - "single_op/task/build_task_utils.cc" - "single_op/task/op_task.cc" - "single_op/task/tbe_task_builder.cc" -########################################## -# "ir_build/ge_ir_build.cc" -# "offline/atc_ir_common.cc" + "single_op/*.cc" + "single_op/task/*.cc" ) add_library(ge_compiler SHARED ${INFER_SRC_LIST} ${PROTO_SRCS} ${PROTO_HEADER_HDRS}) target_compile_definitions(ge_compiler PRIVATE PROTOBUF_INLINE_NOT_IN_HEADERS=0 - DAVINCI_SUPPORT_PROFILING REUSE_MEMORY=1 - FMK_HOST_INFER - PLATFORM_CLOUD) + FMK_HOST_INFER) target_link_libraries(ge_compiler graph ge_common - "-Wl,--whole-archive" ge_memory - "-Wl,--no-whole-archive" ${PROTOBUF_LIBRARY} ${register} ${c_sec} @@ -469,5 +300,6 @@ target_link_libraries(ge_compiler ${msprof} ${runtime} ${resouce} + ${error_manager} rt dl) diff --git a/src/ge/client/CMakeLists.txt b/src/ge/client/CMakeLists.txt index a99b4eb1..afdbd141 100755 --- a/src/ge/client/CMakeLists.txt +++ b/src/ge/client/CMakeLists.txt @@ -13,21 +13,21 @@ # 
limitations under the License. # ============================================================================ -# libge_client.so & libge_client_train.so +# libge_client.so # add all proto files, generate corresponding .h and .cc files set(CMAKE_CXX_FLAGS "-Wno-unused-variable ${CMAKE_CXX_FLAGS}") -file(GLOB_RECURSE PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} +file(GLOB PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "../../proto/ge_api.proto" ) -file(GLOB_RECURSE PROTO_HEADER_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} +file(GLOB PROTO_HEADER_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "../../proto/ge_ir.proto" "../../proto/task.proto" "../../proto/om.proto" "../../proto/insert_op.proto" ) -file(GLOB_RECURSE SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} +file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "ge_api.cc" ) @@ -49,30 +49,9 @@ include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/cce) include_directories(${CMAKE_BINARY_DIR}) include_directories(${CMAKE_BINARY_DIR}/proto/ge) -######### libge_client_train.so ############# -add_library(ge_client_train SHARED ${SRC_LIST} ${PROTO_SRCS} ${PROTO_HEADER_HDRS}) -target_compile_definitions(ge_client_train PRIVATE - Werror - PROTOBUF_INLINE_NOT_IN_HEADERS=0 - REUSE_MEMORY=1 - PLATFORM_CLOUD - DAVINCI_CLOUD) -target_link_libraries(ge_client_train - graph - ge_train - ge_common - ${PROTOBUF_LIBRARY} - ${register} - ${c_sec} - ${slog} - ${mmpa} - ${runtime} - rt - dl) - ############ libge_client.so ################ add_library(ge_client SHARED ${SRC_LIST} ${PROTO_SRCS} ${PROTO_HEADER_HDRS}) -target_compile_definitions(ge_client_train PRIVATE +target_compile_definitions(ge_client PRIVATE Werror PROTOBUF_INLINE_NOT_IN_HEADERS=0 REUSE_MEMORY=1 diff --git a/src/ge/client/ge_api.cc b/src/ge/client/ge_api.cc index 9b9e5568..24126de1 100644 --- a/src/ge/client/ge_api.cc +++ b/src/ge/client/ge_api.cc @@ -32,17 +32,18 @@ using domi::GetContext; using domi::OpRegistry; -using domi::RealPath; -using domi::StringUtils; using 
std::map; using std::string; using std::vector; -namespace ge { -static const int32_t kMaxStrLen = 128; +namespace { +const int32_t kMaxStrLen = 128; +} + static bool kGeInitialized = false; static std::mutex kGeReleaseMutex; // GEFinalize and ~Session use +namespace ge { void GetOpsProtoPath(std::string &opsproto_path) { GELOGI("Enter get ops proto path schedule"); const char *path_env = std::getenv("ASCEND_OPP_PATH"); @@ -394,8 +395,8 @@ Status Session::RegisterCallBackFunc(const std::string &key, const pCallBackFunc return ge::GELib::GetInstance()->SessionManagerObj().RegisterCallBackFunc(sessionId_, key, callback); } -Status Session::RunGraphAsync(uint32_t graph_id, const std::vector &inputs, - std::vector &outputs, std::function callback) { +Status Session::RunGraphAsync(uint32_t graph_id, const std::vector &inputs, + RunAsyncCallback callback) { std::shared_ptr instance_ptr = ge::GELib::GetInstance(); if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { GELOGE(GE_CLI_GE_NOT_INITIALIZED, "SessionConstructor failed"); @@ -405,8 +406,7 @@ Status Session::RunGraphAsync(uint32_t graph_id, const std::vector & GELOGW( "The callback function will not be checked. 
Please ensure that the implementation of the function is trusted."); - Status ret = - ge::GELib::GetInstance()->SessionManagerObj().RunGraphAsync(sessionId_, graph_id, inputs, outputs, callback); + Status ret = ge::GELib::GetInstance()->SessionManagerObj().RunGraphAsync(sessionId_, graph_id, inputs, callback); if (ret != SUCCESS) { GELOGE(ret, "SessionManager RunGraphAsync failed"); return FAILED; diff --git a/src/ge/common/CMakeLists.txt b/src/ge/common/CMakeLists.txt index 2f43e2ff..adcdb1bc 100755 --- a/src/ge/common/CMakeLists.txt +++ b/src/ge/common/CMakeLists.txt @@ -28,7 +28,6 @@ file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "debug/memory_dumper.cc" "fmk_error_codes.cc" "formats/format_transfers/datatype_transfer.cc" - "formats/format_transfers/format_transfer.cc" "formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc" "formats/format_transfers/format_transfer_dhwcn_fracz3D.cc" "formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc" @@ -41,6 +40,7 @@ file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc" "formats/format_transfers/format_transfer_nc1hwc0_nchw.cc" "formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc" + "formats/format_transfers/format_transfer_nchw_fz_c04.cc" "formats/format_transfers/format_transfer_nchw_nc1hwc0.cc" "formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc" "formats/format_transfers/format_transfer_transpose.cc" @@ -54,6 +54,7 @@ file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "helper/om_file_helper.cc" "math/fp16_math.cc" "model_parser/base.cc" + "model_saver.cc" "op/attr_value_util.cc" "op/ge_op_utils.cc" "properties_manager.cc" @@ -61,9 +62,6 @@ file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "thread_pool.cc" "types.cc" "util.cc" - "model_saver.cc" - ############################### - "op/attr_define.cc" ) ge_protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) @@ -73,6 +71,7 @@ 
include_directories(${CMAKE_CURRENT_LIST_DIR}) include_directories(${CMAKE_CURRENT_LIST_DIR}/op) include_directories(${GE_SOURCE_DIR}/src/ge) include_directories(${GE_SOURCE_DIR}/inc) +include_directories(${GE_SOURCE_DIR}/inc/common/util) include_directories(${GE_SOURCE_DIR}/inc/external) include_directories(${GE_SOURCE_DIR}/inc/external/graph) include_directories(${GE_SOURCE_DIR}/inc/framework) @@ -96,5 +95,6 @@ target_link_libraries(ge_common ${slog} ${mmpa} ${resource} + ${error_manager} rt dl) diff --git a/src/ge/common/auth/file_saver.cc b/src/ge/common/auth/file_saver.cc index 04638ecf..1dc42fad 100644 --- a/src/ge/common/auth/file_saver.cc +++ b/src/ge/common/auth/file_saver.cc @@ -17,7 +17,6 @@ #include "common/auth/file_saver.h" #include - #include #include #include @@ -29,10 +28,6 @@ #include "framework/common/debug/log.h" #include "framework/common/util.h" -using domi::CreateDirectory; -using domi::ModelEncryptType; -using ge::ModelBufferData; - namespace { const int kFileOpSuccess = 0; } // namespace @@ -270,4 +265,4 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status FileSaver::SaveToFile(co } return ret; } -} // namespace ge +} // namespace ge diff --git a/src/ge/common/auth/file_saver.h b/src/ge/common/auth/file_saver.h index a4473050..d415746d 100644 --- a/src/ge/common/auth/file_saver.h +++ b/src/ge/common/auth/file_saver.h @@ -26,30 +26,26 @@ #include "graph/buffer.h" #include "mmpa/mmpa_api.h" -using domi::ModelFileHeader; -using domi::ModelPartition; -using domi::ModelPartitionTable; - struct PROC_PARAM { uint8_t *model_name; - /* ISV Ek buffer */ + // ISV Ek buffer uint8_t *model_key; uint32_t model_key_len; - /* ISV root certificate buffer */ + // ISV root certificate buffer uint8_t *root_cert; uint32_t root_cert_len; - /* ISV private key buffer */ + // ISV private key buffer uint8_t *pri_key; uint32_t pri_key_len; - /* Raw AI Module Image buffer */ + // Raw AI Module Image buffer uint8_t *ai_image; uint32_t ai_image_len; - /* ISV HW 
key buffer */ + // ISV HW key buffer uint8_t *hw_key; uint32_t hw_key_len; }; @@ -66,11 +62,11 @@ using std::string; class FileSaver { public: - /** - * @ingroup domi_common - * @brief save model, no encryption - * @return Status result - */ + /// + /// @ingroup domi_common + /// @brief save model, no encryption + /// @return Status result + /// static Status SaveToFile(const string &file_path, const ge::ModelData &model, const ModelFileHeader *model_file_header = nullptr); @@ -84,26 +80,26 @@ class FileSaver { static Status SaveToFile(const string &file_path, const void *data, int len); protected: - /** - * @ingroup domi_common - * @brief Check validity of the file path - * @return Status result - */ + /// + /// @ingroup domi_common + /// @brief Check validity of the file path + /// @return Status result + /// static Status CheckPath(const string &file_path); static Status WriteData(const void *data, uint32_t size, int32_t fd); static Status OpenFile(int32_t &fd, const std::string &file_path); - /** - * @ingroup domi_common - * @brief save model to file - * @param [in] file_path file output path - * @param [in] file_header file header info - * @param [in] data model data - * @param [in] len model length - * @return Status result - */ + /// + /// @ingroup domi_common + /// @brief save model to file + /// @param [in] file_path file output path + /// @param [in] file_header file header info + /// @param [in] data model data + /// @param [in] len model length + /// @return Status result + /// static Status SaveWithFileHeader(const string &file_path, const ModelFileHeader &file_header, const void *data, int len); diff --git a/src/ge/common/context/ctx.cc b/src/ge/common/context/ctx.cc index 34ba5d25..f6ae364d 100644 --- a/src/ge/common/context/ctx.cc +++ b/src/ge/common/context/ctx.cc @@ -16,6 +16,7 @@ #include "framework/omg/omg_inner_types.h" +using ge::OmgContext; namespace domi { FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OmgContext &GetContext() { static 
OmgContext context; diff --git a/src/ge/common/convert/pb2json.cc b/src/ge/common/convert/pb2json.cc index 7c53968a..88b2a332 100644 --- a/src/ge/common/convert/pb2json.cc +++ b/src/ge/common/convert/pb2json.cc @@ -155,7 +155,7 @@ string Pb2Json::TypeBytes2String(string &field_name, string &type_bytes) { void Pb2Json::RepeatedMessage2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, const ProtobufReflection *reflection, const set &black_fields, Json &json, bool enum2str) { - if (nullptr == field || nullptr == reflection) { + if ((field == nullptr) || (reflection == nullptr)) { Message2Json(message, black_fields, json, enum2str); return; } diff --git a/src/ge/common/debug/memory_dumper.cc b/src/ge/common/debug/memory_dumper.cc index 0534ff81..56724be8 100644 --- a/src/ge/common/debug/memory_dumper.cc +++ b/src/ge/common/debug/memory_dumper.cc @@ -28,7 +28,9 @@ using std::string; -static const int kInvalidFd = (-1); +namespace { +const int kInvalidFd = (-1); +} // namespace namespace ge { FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY MemoryDumper::MemoryDumper() : fd_(kInvalidFd) {} diff --git a/src/ge/common/formats/format_transfers/datatype_transfer.cc b/src/ge/common/formats/format_transfers/datatype_transfer.cc index bac3a178..0bd4b8e5 100644 --- a/src/ge/common/formats/format_transfers/datatype_transfer.cc +++ b/src/ge/common/formats/format_transfers/datatype_transfer.cc @@ -16,7 +16,7 @@ #include "common/formats/format_transfers/datatype_transfer.h" -#include +#include #include #include @@ -27,8 +27,6 @@ #include "graph/utils/type_utils.h" #include "securec.h" -using ge::fp16_t; - namespace ge { namespace formats { @@ -134,10 +132,6 @@ Status DataTypeTransfer::TransDataType(const CastArgs &args, TransResult &result } auto trans_mode = iter->second; - if (args.src_data_size == 0) { - GELOGE(PARAM_INVALID, "Invalid src data size %zu", args.src_data_size); - return PARAM_INVALID; - } int size = GetSizeByDataType(args.dst_data_type); if 
(size <= 0) { GELOGE(PARAM_INVALID, "Failed to calc size from data type %s", @@ -149,6 +143,12 @@ Status DataTypeTransfer::TransDataType(const CastArgs &args, TransResult &result return PARAM_INVALID; } size_t total_size = static_cast(args.src_data_size * size); + result.length = total_size; + if (total_size == 0) { + GELOGI("In TransDataType, total_size is zero, has no data."); + return SUCCESS; + } + std::shared_ptr dst(new (std::nothrow) uint8_t[total_size], std::default_delete()); if (dst == nullptr) { GELOGE(OUT_OF_MEMORY, "Failed to alloc the memory for dst buf %zu, data size %zu", total_size, args.src_data_size); @@ -162,7 +162,6 @@ Status DataTypeTransfer::TransDataType(const CastArgs &args, TransResult &result return INTERNAL_ERROR; } result.data = dst; - result.length = total_size; return SUCCESS; } diff --git a/src/ge/common/formats/format_transfers/datatype_transfer.h b/src/ge/common/formats/format_transfers/datatype_transfer.h index fe13a9b8..0702592f 100644 --- a/src/ge/common/formats/format_transfers/datatype_transfer.h +++ b/src/ge/common/formats/format_transfers/datatype_transfer.h @@ -21,7 +21,7 @@ #include #include -#include "common/formats/format_transfers/format_transfer.h" +#include "register/register_format_transfer.h" #include "external/graph/types.h" #include "framework/common/ge_inner_error_codes.h" diff --git a/src/ge/common/formats/format_transfers/format_transfer.cc b/src/ge/common/formats/format_transfers/format_transfer.cc deleted file mode 100644 index 76ba8192..00000000 --- a/src/ge/common/formats/format_transfers/format_transfer.cc +++ /dev/null @@ -1,69 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "common/formats/format_transfers/format_transfer.h" - -#include -#include - -#include "framework/common/debug/ge_log.h" -#include "graph/utils/type_utils.h" - -namespace ge { -namespace formats { -namespace { -struct FormatTransferRegistry { - Status RegisterBuilder(Format src, Format dst, FormatTransferBuilder builder) { - src_dst_builder[src][dst] = std::move(builder); - return SUCCESS; - } - std::map> src_dst_builder; -}; - -FormatTransferRegistry &GetFormatTransferRegistry() { - static FormatTransferRegistry registry; - return registry; -} -} // namespace - -std::shared_ptr BuildFormatTransfer(const TransArgs &args) { - auto registry = GetFormatTransferRegistry(); - auto dst_builder = registry.src_dst_builder.find(args.src_format); - if (dst_builder == registry.src_dst_builder.end()) { - return nullptr; - } - auto builder_iter = dst_builder->second.find(args.dst_format); - if (builder_iter == dst_builder->second.end()) { - return nullptr; - } - return builder_iter->second(); -} - -bool FormatTransferExists(const TransArgs &args) { - auto registry = GetFormatTransferRegistry(); - auto dst_builder = registry.src_dst_builder.find(args.src_format); - if (dst_builder == registry.src_dst_builder.end()) { - return false; - } - return dst_builder->second.count(args.dst_format) > 0; -} - -FormatTransferRegister::FormatTransferRegister(FormatTransferBuilder builder, Format src, Format dst) { - (void)GetFormatTransferRegistry().RegisterBuilder(src, dst, std::move(builder)); - // RegisterBuilder() always return success, no need to check 
value -} -} // namespace formats -} // namespace ge diff --git a/src/ge/common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc b/src/ge/common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc index 28d713b5..40dc749d 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc +++ b/src/ge/common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc @@ -27,7 +27,9 @@ namespace ge { namespace formats { namespace { -bool CheckDataTypeSupported(const DataType &data_type) { return (data_type == DT_FLOAT || data_type == DT_FLOAT16); } +bool CheckDataTypeSupported(const DataType &data_type) { + return (data_type == DT_FLOAT || data_type == DT_FLOAT16 || data_type == DT_INT8); +} Status CheckArgsForC1hwncoc0ToHwcn(const TransArgs &args) { auto src_shape = args.src_shape; @@ -51,10 +53,11 @@ Status CheckArgsForC1hwncoc0ToHwcn(const TransArgs &args) { GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); return PARAM_INVALID; } - if (src_shape.at(kC1hwncoc0C1) != (dst_shape.at(kHwcnC) - 1) / kCubeSize + 1 || + auto cube_size = GetCubeSizeByDataType(args.src_data_type); + if (src_shape.at(kC1hwncoc0C1) != (dst_shape.at(kHwcnC) - 1) / cube_size + 1 || src_shape.at(kC1hwncoc0H) != dst_shape.at(kHwcnH) || src_shape.at(kC1hwncoc0W) != dst_shape.at(kHwcnW) || - src_shape.at(kC1hwncoc0N) != dst_shape.at(kHwcnN) || src_shape.at(kC1hwncoc0Co) != kCubeSize || - src_shape.at(kC1hwncoc0C0) != kCubeSize) { + src_shape.at(kC1hwncoc0N) != dst_shape.at(kHwcnN) || src_shape.at(kC1hwncoc0Co) != cube_size || + src_shape.at(kC1hwncoc0C0) != cube_size) { GELOGE(PARAM_INVALID, "Failed to check relationship between src and dst shape, src shape %s, dst shape %s", ShapeToString(src_shape).c_str(), ShapeToString(dst_shape).c_str()); return PARAM_INVALID; @@ -78,6 +81,7 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, int size auto c0 = args.src_shape.at(kC1hwncoc0C0); auto co = 
args.src_shape.at(kC1hwncoc0Co); auto c = args.dst_shape.at(kHwcnC); + auto cube_size = GetCubeSizeByDataType(args.src_data_type); int64_t cn = c * n; int64_t wcn = w * cn; int64_t coc0 = co * c0; @@ -93,8 +97,8 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, int size int64_t c_head_addr = w_head_addr + c_idx * n; for (int64_t n_idx = 0; n_idx < n; n_idx++) { int64_t dst_idx = c_head_addr + n_idx; - int64_t c1_idx = c_idx / kCubeSize; - int64_t c0_idx = c_idx % kCubeSize; + int64_t c1_idx = c_idx / cube_size; + int64_t c0_idx = c_idx % cube_size; int64_t co_idx = c0_idx; int64_t src_idx = c1_idx * hwncoc0 + h_idx * wncoc0 + w_idx * ncoc0 + n_idx * coc0 + co_idx * c0 + c0_idx; auto src_offset = src_idx * size; @@ -130,6 +134,11 @@ Status FormatTransferC1hwncoc0Hwcn::TransFormat(const TransArgs &args, TransResu int size = GetSizeByDataType(args.src_data_type); int64_t total_size = GetItemNumByShape(args.dst_shape) * size; if (total_size <= 0) { + int64_t src_size = GetItemNumByShape(args.src_shape); + if (total_size == 0 && src_size == 0) { + result.length = static_cast(total_size); + return SUCCESS; + } GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); return PARAM_INVALID; diff --git a/src/ge/common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.h b/src/ge/common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.h index fccc4524..d2156018 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.h +++ b/src/ge/common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.h @@ -19,7 +19,7 @@ #include -#include "common/formats/format_transfers/format_transfer.h" +#include "register/register_format_transfer.h" namespace ge { namespace formats { diff --git a/src/ge/common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc 
b/src/ge/common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc index 45808fa0..dc8e1033 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc +++ b/src/ge/common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc @@ -88,6 +88,11 @@ Status TransFormatDhwckToFz3D(const TransArgs &args, TransResult &result) { dst_size *= dim; } dst_size *= data_size; + if (dst_size == 0) { + result.length = static_cast(dst_size); + return SUCCESS; + } + std::shared_ptr dst(new (std::nothrow) uint8_t[dst_size], std::default_delete()); if (dst == nullptr) { GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", diff --git a/src/ge/common/formats/format_transfers/format_transfer_dhwcn_fracz3D.h b/src/ge/common/formats/format_transfers/format_transfer_dhwcn_fracz3D.h index 55549cb3..41581dec 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_dhwcn_fracz3D.h +++ b/src/ge/common/formats/format_transfers/format_transfer_dhwcn_fracz3D.h @@ -18,7 +18,7 @@ #define GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_DHWCN_FRACTAL_Z_3D_H_ #include -#include "common/formats/format_transfers/format_transfer.h" +#include "register/register_format_transfer.h" namespace ge { namespace formats { diff --git a/src/ge/common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc b/src/ge/common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc index 86c6935d..11e3d270 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc +++ b/src/ge/common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc @@ -89,6 +89,11 @@ Status TransFormatDhwncToFz3DTranspose(const TransArgs &args, TransResult &resul dst_size *= dim; } dst_size *= data_size; + if (dst_size == 0) { + result.length = static_cast(dst_size); + return SUCCESS; + } + std::shared_ptr dst(new (std::nothrow) uint8_t[dst_size], std::default_delete()); if (dst 
== nullptr) { GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", diff --git a/src/ge/common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.h b/src/ge/common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.h index 82a9e14f..1c4986b8 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.h +++ b/src/ge/common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.h @@ -18,7 +18,7 @@ #define GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_DHWNC_FRACTAL_Z_3D_TRANSPOSE_H_ #include -#include "common/formats/format_transfers/format_transfer.h" +#include "register/register_format_transfer.h" namespace ge { namespace formats { diff --git a/src/ge/common/formats/format_transfers/format_transfer_fractal_nz.cc b/src/ge/common/formats/format_transfers/format_transfer_fractal_nz.cc index 76834437..ff7b84a4 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_fractal_nz.cc +++ b/src/ge/common/formats/format_transfers/format_transfer_fractal_nz.cc @@ -116,6 +116,11 @@ Status CheckShapeRelation(const TransArgs &args, ShapeVector &hw_shape) { Status TransFormatFromNdToFracNz(const TransArgs &args, TransResult &result, const ShapeVector &hw_shape) { int size = GetSizeByDataType(args.src_data_type); int64_t dst_size = GetItemNumByShape(args.dst_shape) * size; + if (dst_size == 0) { + result.length = static_cast(dst_size); + return SUCCESS; + } + std::shared_ptr dst(new (std::nothrow) uint8_t[dst_size](), std::default_delete()); if (dst == nullptr) { GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", @@ -184,6 +189,11 @@ Status TransFormatFromNdToFracNz(const TransArgs &args, TransResult &result, con Status TransFormatFromFracNzToNd(const TransArgs &args, TransResult &result, const ShapeVector &dst_hw_shape) { int size = GetSizeByDataType(args.src_data_type); int64_t dst_size = 
GetItemNumByShape(args.dst_shape) * size; + if (dst_size == 0) { + result.length = static_cast(dst_size); + return SUCCESS; + } + std::shared_ptr dst(new (std::nothrow) uint8_t[dst_size], std::default_delete()); if (dst == nullptr) { GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", diff --git a/src/ge/common/formats/format_transfers/format_transfer_fractal_nz.h b/src/ge/common/formats/format_transfers/format_transfer_fractal_nz.h index c593aa7c..49e82884 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_fractal_nz.h +++ b/src/ge/common/formats/format_transfers/format_transfer_fractal_nz.h @@ -19,7 +19,7 @@ #include -#include "common/formats/format_transfers/format_transfer.h" +#include "register/register_format_transfer.h" namespace ge { namespace formats { diff --git a/src/ge/common/formats/format_transfers/format_transfer_fractal_z.cc b/src/ge/common/formats/format_transfers/format_transfer_fractal_z.cc index aedc7589..f3d06496 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_fractal_z.cc +++ b/src/ge/common/formats/format_transfers/format_transfer_fractal_z.cc @@ -119,6 +119,11 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) { int64_t total_ele_cnt = hf_cnt * vf_cnt * fractal_ele_cnt; int size = GetSizeByDataType(args.src_data_type); int64_t dst_size = total_ele_cnt * size; + if (dst_size == 0) { + result.length = static_cast(dst_size); + return SUCCESS; + } + std::shared_ptr dst(new (std::nothrow) uint8_t[dst_size], std::default_delete()); if (dst == nullptr) { GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", @@ -194,6 +199,11 @@ Status TransFormatHwcnToFz(const TransArgs &args, TransResult &result) { dst_size *= dim; } dst_size *= data_size; + if (dst_size == 0) { + result.length = static_cast(dst_size); + return SUCCESS; + } + std::shared_ptr dst(new (std::nothrow) uint8_t[dst_size], 
std::default_delete()); if (dst == nullptr) { GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", @@ -259,6 +269,11 @@ Status TransFormatNhwcToFz(const TransArgs &args, TransResult &result) { dst_size *= dim; } dst_size *= data_size; + if (dst_size == 0) { + result.length = static_cast(dst_size); + return SUCCESS; + } + std::shared_ptr dst(new (std::nothrow) uint8_t[dst_size], std::default_delete()); if (dst == nullptr) { GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", diff --git a/src/ge/common/formats/format_transfers/format_transfer_fractal_z.h b/src/ge/common/formats/format_transfers/format_transfer_fractal_z.h index 9653f3e7..5ae83303 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_fractal_z.h +++ b/src/ge/common/formats/format_transfers/format_transfer_fractal_z.h @@ -19,7 +19,7 @@ #include -#include "common/formats/format_transfers/format_transfer.h" +#include "register/register_format_transfer.h" namespace ge { namespace formats { diff --git a/src/ge/common/formats/format_transfers/format_transfer_fractal_zz.cc b/src/ge/common/formats/format_transfers/format_transfer_fractal_zz.cc index be0c3abb..d5507765 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_fractal_zz.cc +++ b/src/ge/common/formats/format_transfers/format_transfer_fractal_zz.cc @@ -117,6 +117,11 @@ Status CheckShapeRelation(const TransArgs &args, ShapeVector &hw_shape) { Status TransFormatFromNdToFracZz(const TransArgs &args, TransResult &result, const ShapeVector &hw_shape) { int size = GetSizeByDataType(args.src_data_type); int64_t dst_size = GetItemNumByShape(args.dst_shape) * size; + if (dst_size == 0) { + result.length = static_cast(dst_size); + return SUCCESS; + } + std::shared_ptr dst(new (std::nothrow) uint8_t[dst_size](), std::default_delete()); if (dst == nullptr) { GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc 
the memory for dst buf %ld", @@ -189,6 +194,11 @@ Status TransFormatFromNdToFracZz(const TransArgs &args, TransResult &result, con Status TransFormatFromFracZzToNd(const TransArgs &args, TransResult &result, const ShapeVector &dst_hw_shape) { int size = GetSizeByDataType(args.src_data_type); int64_t dst_size = GetItemNumByShape(args.dst_shape) * size; + if (dst_size == 0) { + result.length = static_cast(dst_size); + return SUCCESS; + } + std::shared_ptr dst(new (std::nothrow) uint8_t[dst_size](), std::default_delete()); if (dst == nullptr) { GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", diff --git a/src/ge/common/formats/format_transfers/format_transfer_fractal_zz.h b/src/ge/common/formats/format_transfers/format_transfer_fractal_zz.h index 4250ce93..93f40920 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_fractal_zz.h +++ b/src/ge/common/formats/format_transfers/format_transfer_fractal_zz.h @@ -19,7 +19,7 @@ #include -#include "common/formats/format_transfers/format_transfer.h" +#include "register/register_format_transfer.h" namespace ge { namespace formats { diff --git a/src/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.cc b/src/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.cc index 3453c232..b0eebcfa 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.cc +++ b/src/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.cc @@ -133,6 +133,12 @@ Status FormatTransferFracZHwcn::TransFormat(const TransArgs &args, TransResult & int size = GetSizeByDataType(args.src_data_type); auto total_size = GetItemNumByShape(args.dst_shape) * size; if (total_size <= 0) { + int64_t src_size = GetItemNumByShape(args.src_shape); + if (total_size == 0 && src_size == 0) { + result.length = static_cast(total_size); + return SUCCESS; + } + GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, 
ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); return PARAM_INVALID; diff --git a/src/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.h b/src/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.h index 49d8d336..a7efbfcb 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.h +++ b/src/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.h @@ -19,7 +19,7 @@ #include -#include "common/formats/format_transfers/format_transfer.h" +#include "register/register_format_transfer.h" namespace ge { namespace formats { diff --git a/src/ge/common/formats/format_transfers/format_transfer_fracz_nchw.cc b/src/ge/common/formats/format_transfers/format_transfer_fracz_nchw.cc index 6f616051..9f8d9e39 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_fracz_nchw.cc +++ b/src/ge/common/formats/format_transfers/format_transfer_fracz_nchw.cc @@ -133,6 +133,12 @@ Status FormatTransferFracZNchw::TransFormat(const TransArgs &args, TransResult & int size = GetSizeByDataType(args.src_data_type); auto total_size = GetItemNumByShape(args.dst_shape) * size; if (total_size <= 0) { + int64_t src_size = GetItemNumByShape(args.src_shape); + if (total_size == 0 && src_size == 0) { + result.length = static_cast(total_size); + return SUCCESS; + } + GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); return PARAM_INVALID; @@ -140,6 +146,7 @@ Status FormatTransferFracZNchw::TransFormat(const TransArgs &args, TransResult & GELOGD("Begin to trans format from FracZ to NCHW, src shape %s, data type %s, dst shape %s, memory size %ld", ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), ShapeToString(args.dst_shape).c_str(), total_size); + if (GetDstDataAfterTrans(args, result, size, total_size) != SUCCESS) { GELOGE(INTERNAL_ERROR, "Failed 
to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld", ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), diff --git a/src/ge/common/formats/format_transfers/format_transfer_fracz_nchw.h b/src/ge/common/formats/format_transfers/format_transfer_fracz_nchw.h index 312a10f2..af2cedd0 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_fracz_nchw.h +++ b/src/ge/common/formats/format_transfers/format_transfer_fracz_nchw.h @@ -19,7 +19,7 @@ #include -#include "common/formats/format_transfers/format_transfer.h" +#include "register/register_format_transfer.h" namespace ge { namespace formats { diff --git a/src/ge/common/formats/format_transfers/format_transfer_fracz_nhwc.cc b/src/ge/common/formats/format_transfers/format_transfer_fracz_nhwc.cc index 57b840af..9a1e5f3b 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_fracz_nhwc.cc +++ b/src/ge/common/formats/format_transfers/format_transfer_fracz_nhwc.cc @@ -132,6 +132,12 @@ Status FormatTransferFracZNhwc::TransFormat(const TransArgs &args, TransResult & int size = GetSizeByDataType(args.src_data_type); auto total_size = GetItemNumByShape(args.dst_shape) * size; if (total_size <= 0) { + int64_t src_size = GetItemNumByShape(args.src_shape); + if (total_size == 0 && src_size == 0) { + result.length = static_cast(total_size); + return SUCCESS; + } + GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); return PARAM_INVALID; diff --git a/src/ge/common/formats/format_transfers/format_transfer_fracz_nhwc.h b/src/ge/common/formats/format_transfers/format_transfer_fracz_nhwc.h index 5a908dbb..41654304 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_fracz_nhwc.h +++ b/src/ge/common/formats/format_transfers/format_transfer_fracz_nhwc.h @@ -19,7 +19,7 @@ #include -#include 
"common/formats/format_transfers/format_transfer.h" +#include "register/register_format_transfer.h" namespace ge { namespace formats { diff --git a/src/ge/common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc b/src/ge/common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc index fbadb4c3..7101256a 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc +++ b/src/ge/common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc @@ -27,16 +27,20 @@ namespace ge { namespace formats { namespace { -bool CheckDataTypeSupported(const DataType &data_type) { return (data_type == DT_FLOAT || data_type == DT_FLOAT16); } +bool CheckDataTypeSupported(const DataType &data_type) { + return (data_type == DT_FLOAT || data_type == DT_FLOAT16 || data_type == DT_INT8); +} -Status TransShapeHwcnToC1hwncoc0(const std::vector &src_shape, std::vector &dst_shape) { +Status TransShapeHwcnToC1hwncoc0(const DataType &data_type, const std::vector &src_shape, + std::vector &dst_shape) { + auto cube_size = GetCubeSizeByDataType(data_type); dst_shape.clear(); - dst_shape.push_back((src_shape.at(kHwcnC) - 1) / kCubeSize + 1); + dst_shape.push_back(Ceil(src_shape.at(kHwcnC), static_cast(cube_size))); dst_shape.push_back(src_shape.at(kHwcnH)); dst_shape.push_back(src_shape.at(kHwcnW)); dst_shape.push_back(src_shape.at(kHwcnN)); - dst_shape.push_back(kCubeSize); - dst_shape.push_back(kCubeSize); + dst_shape.push_back(cube_size); + dst_shape.push_back(cube_size); if (!CheckShapeValid(dst_shape, kC1hwncoc0DimsNum)) { GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); return PARAM_INVALID; @@ -65,7 +69,7 @@ Status CheckArgsForHwcnToC1hwncoc0(const TransArgs &args) { return PARAM_INVALID; } std::vector expect_dst_shape; - auto ret = TransShapeHwcnToC1hwncoc0(args.src_shape, expect_dst_shape); + auto ret = TransShapeHwcnToC1hwncoc0(args.src_data_type, args.src_shape, expect_dst_shape); if (ret != SUCCESS) { return 
ret; } @@ -165,6 +169,12 @@ Status FormatTransferHwcnC1hwncoc0::TransFormat(const TransArgs &args, TransResu int size = GetSizeByDataType(args.src_data_type); auto total_size = GetItemNumByShape(args.dst_shape) * size; if (total_size <= 0) { + int64_t src_size = GetItemNumByShape(args.src_shape); + if (total_size == 0 && src_size == 0) { + result.length = static_cast(total_size); + return SUCCESS; + } + GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); return PARAM_INVALID; @@ -188,7 +198,7 @@ Status FormatTransferHwcnC1hwncoc0::TransShape(Format src_format, const std::vec GELOGE(PARAM_INVALID, "Failed to check src shape %s", ShapeToString(src_shape).c_str()); return PARAM_INVALID; } - return TransShapeHwcnToC1hwncoc0(src_shape, dst_shape); + return TransShapeHwcnToC1hwncoc0(data_type, src_shape, dst_shape); } else { return UNSUPPORTED; } diff --git a/src/ge/common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.h b/src/ge/common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.h index 56270cd1..81d7358e 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.h +++ b/src/ge/common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.h @@ -19,7 +19,7 @@ #include -#include "common/formats/format_transfers/format_transfer.h" +#include "register/register_format_transfer.h" namespace ge { namespace formats { diff --git a/src/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nchw.cc b/src/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nchw.cc index 0a5af5ff..57ab1266 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nchw.cc +++ b/src/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nchw.cc @@ -58,7 +58,7 @@ Status CheckArgsForNc1hwc0ToNchw(const TransArgs &args) { } if (src_shape.at(kNc1hwc0H) != dst_shape.at(kNchwH) || src_shape.at(kNc1hwc0W) != 
dst_shape.at(kNchwW) || src_shape.at(kNc1hwc0N) != dst_shape.at(kNchwN) || src_shape.at(kNc1hwc0C0) != c0 || - src_shape.at(kNc1hwc0C1) != (dst_shape.at(kNchwC) - 1) / c0 + 1) { + src_shape.at(kNc1hwc0C1) != (Ceil(dst_shape.at(kNchwC), c0))) { GELOGE(PARAM_INVALID, "Failed to check relationship between src and dst shape, src shape %s, dst shape %s", ShapeToString(src_shape).c_str(), ShapeToString(dst_shape).c_str()); return PARAM_INVALID; @@ -130,6 +130,12 @@ Status FormatTransferNc1hwc0Nchw::TransFormat(const TransArgs &args, TransResult int size = GetSizeByDataType(args.src_data_type); auto total_size = GetItemNumByShape(args.dst_shape) * size; if (total_size <= 0) { + int64_t src_size = GetItemNumByShape(args.src_shape); + if (total_size == 0 && src_size == 0) { + result.length = static_cast(total_size); + return SUCCESS; + } + GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); return PARAM_INVALID; diff --git a/src/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nchw.h b/src/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nchw.h index b3fe65f8..6d599933 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nchw.h +++ b/src/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nchw.h @@ -19,7 +19,7 @@ #include -#include "common/formats/format_transfers/format_transfer.h" +#include "register/register_format_transfer.h" namespace ge { namespace formats { diff --git a/src/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc b/src/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc index 92fd5772..e68e54de 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc +++ b/src/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc @@ -58,7 +58,7 @@ Status CheckArgsForNc1hwc0ToNhwc(const TransArgs &args) { } if (src_shape.at(kNc1hwc0H) != 
dst_shape.at(kNhwcH) || src_shape.at(kNc1hwc0W) != dst_shape.at(kNhwcW) || src_shape.at(kNc1hwc0N) != dst_shape.at(kNhwcN) || src_shape.at(kNc1hwc0C0) != c0 || - src_shape.at(kNc1hwc0C1) != (dst_shape.at(kNhwcC) - 1) / c0 + 1) { + src_shape.at(kNc1hwc0C1) != (Ceil(dst_shape.at(kNhwcC), c0))) { GELOGE(PARAM_INVALID, "Failed to check relationship between src and dst shape, src shape %s, dst shape %s", ShapeToString(src_shape).c_str(), ShapeToString(dst_shape).c_str()); return PARAM_INVALID; @@ -130,6 +130,12 @@ Status FormatTransferNc1hwc0Nhwc::TransFormat(const TransArgs &args, TransResult int size = GetSizeByDataType(args.src_data_type); auto total_size = GetItemNumByShape(args.dst_shape) * size; if (total_size <= 0) { + int64_t src_size = GetItemNumByShape(args.src_shape); + if (total_size == 0 && src_size == 0) { + result.length = static_cast(total_size); + return SUCCESS; + } + GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); return PARAM_INVALID; diff --git a/src/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.h b/src/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.h index 22bc170b..8ff60bb1 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.h +++ b/src/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.h @@ -19,7 +19,7 @@ #include -#include "common/formats/format_transfers/format_transfer.h" +#include "register/register_format_transfer.h" namespace ge { namespace formats { diff --git a/src/ge/common/formats/format_transfers/format_transfer_nchw_fz_c04.cc b/src/ge/common/formats/format_transfers/format_transfer_nchw_fz_c04.cc new file mode 100644 index 00000000..f79d358b --- /dev/null +++ b/src/ge/common/formats/format_transfers/format_transfer_nchw_fz_c04.cc @@ -0,0 +1,314 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache 
License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common/formats/format_transfers/format_transfer_nchw_fz_c04.h" +#include "common/formats/format_transfers/format_transfer_transpose.h" + +#include +#include +#include + +#include "common/formats/utils/formats_definitions.h" +#include "common/formats/utils/formats_trans_utils.h" +#include "common/util.h" +#include "framework/common/debug/ge_log.h" +#include "graph/utils/type_utils.h" + +/** ã€Explain about transfer from nchw to FZ_CO4】 + * First Step: Padding in N and C axis. Here C must be less or equal than 4 + * After Padding, it will be like (n = ceil(n,16)*16, 4, h, w) + * Second Step: transpose. It will be like (n = ceil(n,16)*16, h, w, 4) + * Third Step: View the 4D as 2D , first dim is N, second dim is h*w*c. + * Padding to (N, ceil(Z/16)*16) + * Last Step: View the (N, ceil(Z/16)*16) as 4D (N/16, 16, C/16, 16) and transpose to (C/16, N/16, 16, 16) + */ + +namespace ge { +namespace formats { +namespace { + +constexpr int64_t kMaxDimsNumC = 4; + +Status CheckDataTypeSupport(DataType data_type) { return GetSizeByDataType(data_type) > 0 ? 
SUCCESS : UNSUPPORTED; } + +Status TransShape(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_type, std::vector &dst_shape) { + auto c0 = GetCubeSizeByDataType(data_type); + if (c0 < 0) { + return UNSUPPORTED; + } + auto chw = c * h * w; + + auto first_dim = Ceil(chw, c0); + auto no = Ceil(n, static_cast(c0)); + + dst_shape.clear(); + dst_shape.push_back(first_dim); + dst_shape.push_back(no); + dst_shape.push_back(c0); + dst_shape.push_back(c0); + + if (!IsShapeValid(dst_shape)) { + GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); + return PARAM_INVALID; + } + return SUCCESS; +} + +Status TransShapeNchwToFzC04(const std::vector &src_shape, DataType data_type, + std::vector &dst_shape) { + if (!CheckShapeValid(src_shape, kNchwDimsNum)) { + return PARAM_INVALID; + } + + auto n = src_shape.at(kNchwN); + auto c = src_shape.at(kNchwC); + auto h = src_shape.at(kNchwH); + auto w = src_shape.at(kNchwW); + return TransShape(n, c, h, w, data_type, dst_shape); +} + +Status TransFormatFromNchwToFzC04(const TransArgs &args, TransResult &result) { + int64_t n = args.src_shape.at(kNchwN); + int64_t c = args.src_shape.at(kNchwC); + int64_t h = args.src_shape.at(kNchwH); + int64_t w = args.src_shape.at(kNchwW); + + int64_t c0 = GetCubeSizeByDataType(args.src_data_type); + int size = GetSizeByDataType(args.src_data_type); + + auto data = args.data; + TransResult trans_result_1; + std::vector perm_arg_1 = {0, 2, 3, 1}; + std::vector expect_shape = {n, h, w, c}; + auto ret = ge::formats::Transpose(data, args.src_shape, args.src_data_type, perm_arg_1, trans_result_1); + if (ret != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Failed to Transpose from NCHW to HWCN"); + return NOT_CHANGED; + } + + TransArgs args_tmp = args; + args_tmp.src_shape = expect_shape; + args_tmp.data = trans_result_1.data.get(); + // check size it should be same with original + size_t expect_size = n * c * h * w * size; // before has do check about mul + if 
(trans_result_1.length != expect_size) { + GELOGE(INTERNAL_ERROR, "size is not match after transpose!"); + return NOT_CHANGED; + } + + /* prepare for padding in chw*/ + int64_t tmp = h * w * c; + int64_t n_o = Ceil(n, static_cast(c0)); + int64_t c_o = c0; + int64_t h_o = Ceil(tmp, c0); + int64_t w_o = c0; + std::vector shape_o = {n_o, c_o, h_o, w_o}; + + // data overflow check totally + GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(h_o, w_o), + GELOGE(INTERNAL_ERROR, "int64 mul overflow.A[%lld], B[%lld]", h_o, w_o); + return INTERNAL_ERROR); + GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(n_o, c_o), + GELOGE(INTERNAL_ERROR, "int64 mul overflow.A[%lld], B[%lld]", n_o, c_o); + return INTERNAL_ERROR); + auto t1 = h_o * w_o; + auto t2 = n_o * c_o; + GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(t1, t2), GELOGE(INTERNAL_ERROR, "int64 mul overflow.A[%lld], B[%lld]", t1, t2); + return INTERNAL_ERROR); + + int64_t total_ele_cnt = n_o * c_o * h_o * w_o; + GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(total_ele_cnt, size), + GELOGE(INTERNAL_ERROR, "int64 mul overflow.A[%lld], B[%lld]", total_ele_cnt, size); + return INTERNAL_ERROR); + int64_t dst_size = total_ele_cnt * size; + if (dst_size == 0) { + result.length = 0; + return SUCCESS; + } + + std::shared_ptr dst(new (std::nothrow) uint8_t[dst_size], std::default_delete()); + if (dst == nullptr) { + GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", + TypeUtils::FormatToSerialString(args.src_format).c_str(), + TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); + return OUT_OF_MEMORY; + } + auto retMem = memset_s(dst.get(), dst_size, 0, dst_size); + if (retMem != EOK) { + GELOGE(INTERNAL_ERROR, "memst failed!"); + return INTERNAL_ERROR; + } + // copy data + auto block = c * h * w * size; + auto stride = h_o * w_o * size; + auto p_s = trans_result_1.data.get(); + auto p_d = dst.get(); + auto protectSize = dst_size; + for (auto k = 0; k < n; k++) { + ret = memcpy_s(p_d + k * stride, 
protectSize, p_s + k * block, block); + if (ret != EOK) { + GELOGE(INTERNAL_ERROR, "memcpy_s failed!"); + return INTERNAL_ERROR; + } + protectSize = protectSize - block; + } + + // transpose : 2,0,1,3 + std::vector perm_arg_2 = {2, 0, 1, 3}; + ret = ge::formats::Transpose(dst.get(), shape_o, args.src_data_type, perm_arg_2, result); + if (ret != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Failed to Transpose from NCHW to HWCN"); + return NOT_CHANGED; + } + + return SUCCESS; +} + +Status PaddingNC(const TransArgs &args, TransArgs &args_tmp, std::shared_ptr &dst) { + args_tmp = args; + auto src_shape = args_tmp.src_shape; + if (!CheckShapeValid(src_shape, kNchwDimsNum)) { + return PARAM_INVALID; + } + int64_t c0 = GetCubeSizeByDataType(args.src_data_type); + + auto n = src_shape.at(kNchwN); + auto c = src_shape.at(kNchwC); + auto h = src_shape.at(kNchwH); + auto w = src_shape.at(kNchwW); + + if (c > kMaxDimsNumC) { + GELOGE(PARAM_INVALID, "Invalie dim c num[%lu].It should be in (0,4]", c); + return PARAM_INVALID; + } + + auto n_o = Ceil(n, c0) * c0; + auto c_o = kMaxDimsNumC; + auto h_o = h; + auto w_o = w; + args_tmp.src_shape.at(kNchwN) = n_o; + args_tmp.src_shape.at(kNchwC) = c_o; + args_tmp.src_shape.at(kNchwH) = h_o; + args_tmp.src_shape.at(kNchwW) = w_o; + + // data overflow check + GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(h_o, w_o), + GELOGE(INTERNAL_ERROR, "int64 mul overflow.A[%lld], B[%lld]", h_o, w_o); + return INTERNAL_ERROR); + GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(n_o, c_o), + GELOGE(INTERNAL_ERROR, "int64 mul overflow.A[%lld], B[%lld]", n_o, c_o); + return INTERNAL_ERROR); + auto t1 = h_o * w_o; + auto t2 = n_o * c_o; + GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(t1, t2), GELOGE(INTERNAL_ERROR, "int64 mul overflow.A[%lld], B[%lld]", t1, t2); + return INTERNAL_ERROR); + + int64_t total_ele_cnt = n_o * c_o * h_o * w_o; + int size = GetSizeByDataType(args.src_data_type); + GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(total_ele_cnt, size), + GELOGE(INTERNAL_ERROR, "int64 mul 
overflow.A[%lld], B[%lld]", total_ele_cnt, size); + return INTERNAL_ERROR); + + int64_t dst_size = total_ele_cnt * size; + if (dst_size == 0) { + return SUCCESS; + } + + dst.reset(new (std::nothrow) uint8_t[dst_size], std::default_delete()); + if (dst == nullptr) { + GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", + TypeUtils::FormatToSerialString(args.src_format).c_str(), + TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); + return OUT_OF_MEMORY; + } + auto ret = memset_s(dst.get(), dst_size, 0, dst_size); + if (ret != EOK) { + GELOGE(INTERNAL_ERROR, "memst failed!"); + return INTERNAL_ERROR; + } + + auto p_s = args.data; + auto p_d = dst.get(); + auto block = h * w * size; + auto protectSize = dst_size; + + for (int i = 0; i < n; i++) { + for (int j = 0; j < c; j++) { + ret = memcpy_s(p_d + (i * c_o * h_o * w_o + j * h_o * w_o) * size, protectSize, + p_s + (i * c * h * w + j * h * w) * size, block); + if (ret != EOK) { + GELOGE(INTERNAL_ERROR, "memcpy_s failed!"); + return INTERNAL_ERROR; + } + protectSize = protectSize - block; + } + } + args_tmp.data = dst.get(); + + return SUCCESS; +} +} // namespace + +Status FormatTransferNchwToFZC04::TransFormat(const TransArgs &args, TransResult &result) { + GELOGD("Begin to trans format from %s to %s, src shape %s, data type %s, dst shape %s", + TypeUtils::FormatToSerialString(args.src_format).c_str(), + TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.src_shape).c_str(), + TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), ShapeToString(args.dst_shape).c_str()); + TransArgs args_tmp = args; + std::shared_ptr dst = nullptr; + auto ret = PaddingNC(args, args_tmp, dst); + if (ret != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Padding in NC axis failed!"); + return ret; + } + + std::vector expect_shape; + ret = TransShape(args_tmp.src_format, args_tmp.src_shape, args_tmp.src_data_type, args_tmp.dst_format, 
expect_shape); + if (ret != SUCCESS) { + return ret; + } + + if (!args_tmp.dst_shape.empty() && args_tmp.dst_shape != expect_shape) { + GELOGE(PARAM_INVALID, "Failed to trans format from %s to %s, the dst shape %s is invalid, expect %s", + TypeUtils::FormatToSerialString(args_tmp.src_format).c_str(), + TypeUtils::FormatToSerialString(args_tmp.dst_format).c_str(), ShapeToString(args_tmp.dst_shape).c_str(), + ShapeToString(expect_shape).c_str()); + return PARAM_INVALID; + } + + if (args_tmp.src_format == FORMAT_NCHW && args_tmp.dst_format == FORMAT_FRACTAL_Z_C04) { + return TransFormatFromNchwToFzC04(args_tmp, result); + } + + return UNSUPPORTED; +} + +Status FormatTransferNchwToFZC04::TransShape(Format src_format, const std::vector &src_shape, + DataType data_type, Format dst_format, std::vector &dst_shape) { + if (CheckDataTypeSupport(data_type) != SUCCESS) { + return UNSUPPORTED; + } + if (src_format == FORMAT_NCHW && dst_format == FORMAT_FRACTAL_Z_C04) { + return TransShapeNchwToFzC04(src_shape, data_type, dst_shape); + } + + return UNSUPPORTED; +} + +REGISTER_FORMAT_TRANSFER(FormatTransferNchwToFZC04, FORMAT_NCHW, FORMAT_FRACTAL_Z_C04) + +} // namespace formats +} // namespace ge diff --git a/src/ge/common/formats/format_transfers/format_transfer_nchw_fz_c04.h b/src/ge/common/formats/format_transfers/format_transfer_nchw_fz_c04.h new file mode 100644 index 00000000..4a0fce95 --- /dev/null +++ b/src/ge/common/formats/format_transfers/format_transfer_nchw_fz_c04.h @@ -0,0 +1,35 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_COMMON_FORMATS_FORMAT_TRANSFERS_NCHW_FZC04_H_ +#define GE_COMMON_FORMATS_FORMAT_TRANSFERS_NCHW_FZC04_H_ + +#include + +#include "register/register_format_transfer.h" + +namespace ge { +namespace formats { +class FormatTransferNchwToFZC04 : public FormatTransfer { + public: + Status TransFormat(const ge::formats::TransArgs &args, ge::formats::TransResult &result) override; + Status TransShape(Format src_format, const std::vector &src_shape, DataType data_type, Format dst_format, + std::vector &dst_shape) override; +}; +} // namespace formats +} // namespace ge + +#endif // GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_FRACTAL_Z_H_ diff --git a/src/ge/common/formats/format_transfers/format_transfer_nchw_nc1hwc0.cc b/src/ge/common/formats/format_transfers/format_transfer_nchw_nc1hwc0.cc index 7b90c6a8..b4e92cbc 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_nchw_nc1hwc0.cc +++ b/src/ge/common/formats/format_transfers/format_transfer_nchw_nc1hwc0.cc @@ -40,7 +40,7 @@ Status TransShapeNchwToNc1hwc0(const std::vector &src_shape, DataType d } dst_shape.clear(); dst_shape.push_back(src_shape.at(kNchwN)); - dst_shape.push_back((src_shape.at(kNchwC) - 1) / c0 + 1); + dst_shape.push_back(Ceil(src_shape.at(kNchwC), c0)); dst_shape.push_back(src_shape.at(kNchwH)); dst_shape.push_back(src_shape.at(kNchwW)); dst_shape.push_back(c0); @@ -74,25 +74,8 @@ Status CheckArgsForNchwToNc1hwc0(const TransArgs &args) { return SUCCESS; } -} // namespace -Status FormatTransferNchwNc1hwc0::TransFormat(const TransArgs &args, 
TransResult &result) { - if (CheckArgsForNchwToNc1hwc0(args) != SUCCESS) { - return PARAM_INVALID; - } - // Guarantee the validity of parameters in check function - int size = GetSizeByDataType(args.src_data_type); - auto total_size = GetItemNumByShape(args.dst_shape) * size; - if (total_size <= 0) { - GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, - ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); - return PARAM_INVALID; - } - GELOGD( - "Begin to trans format from NCHW to NC1HWC0, src shape %s, data type " - "%s, dst shape %s memory size %ld", - ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), - ShapeToString(args.dst_shape).c_str(), total_size); +Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const int size, const int64_t total_size) { std::shared_ptr dst(new (std::nothrow) uint8_t[total_size], std::default_delete()); if (dst == nullptr) { GELOGE(OUT_OF_MEMORY, @@ -169,6 +152,39 @@ Status FormatTransferNchwNc1hwc0::TransFormat(const TransArgs &args, TransResult result.length = static_cast(total_size); return SUCCESS; } +} // namespace + +Status FormatTransferNchwNc1hwc0::TransFormat(const TransArgs &args, TransResult &result) { + if (CheckArgsForNchwToNc1hwc0(args) != SUCCESS) { + return PARAM_INVALID; + } + // Guarantee the validity of parameters in check function + int size = GetSizeByDataType(args.src_data_type); + auto total_size = GetItemNumByShape(args.dst_shape) * size; + if (total_size <= 0) { + int64_t src_size = GetItemNumByShape(args.src_shape); + if (total_size == 0 && src_size == 0) { + result.length = static_cast(total_size); + return SUCCESS; + } + + GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, + ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); + return PARAM_INVALID; + } + GELOGD( + "Begin to trans format from NCHW to NC1HWC0, 
src shape %s, data type " + "%s, dst shape %s memory size %ld", + ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), + ShapeToString(args.dst_shape).c_str(), total_size); + if (GetDstDataAfterTrans(args, result, size, total_size) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld", + ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), + ShapeToString(args.dst_shape).c_str(), total_size); + return INTERNAL_ERROR; + } + return SUCCESS; +} Status FormatTransferNchwNc1hwc0::TransShape(Format src_format, const std::vector &src_shape, DataType data_type, Format dst_format, std::vector &dst_shape) { diff --git a/src/ge/common/formats/format_transfers/format_transfer_nchw_nc1hwc0.h b/src/ge/common/formats/format_transfers/format_transfer_nchw_nc1hwc0.h index 272b6a50..c6269579 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_nchw_nc1hwc0.h +++ b/src/ge/common/formats/format_transfers/format_transfer_nchw_nc1hwc0.h @@ -19,7 +19,7 @@ #include -#include "common/formats/format_transfers/format_transfer.h" +#include "register/register_format_transfer.h" namespace ge { namespace formats { diff --git a/src/ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc b/src/ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc index 26e533fc..a5be94ff 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc +++ b/src/ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc @@ -38,7 +38,7 @@ Status TransShapeNhwcToNc1hwc0(const std::vector &src_shape, DataType d } dst_shape.clear(); dst_shape.push_back(src_shape.at(kNhwcN)); - dst_shape.push_back((src_shape.at(kNhwcC) - 1) / c0 + 1); + dst_shape.push_back(Ceil(src_shape.at(kNhwcC), c0)); dst_shape.push_back(src_shape.at(kNhwcH)); dst_shape.push_back(src_shape.at(kNhwcW)); 
dst_shape.push_back(c0); @@ -161,6 +161,12 @@ Status FormatTransferNhwcNc1hwc0::TransFormat(const TransArgs &args, TransResult int size = GetSizeByDataType(args.src_data_type); auto total_size = GetItemNumByShape(args.dst_shape) * size; if (total_size <= 0) { + int64_t src_size = GetItemNumByShape(args.src_shape); + if (total_size == 0 && src_size == 0) { + result.length = static_cast(total_size); + return SUCCESS; + } + GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); return PARAM_INVALID; diff --git a/src/ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.h b/src/ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.h index 401f7e07..fb190f54 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.h +++ b/src/ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.h @@ -19,7 +19,7 @@ #include -#include "common/formats/format_transfers/format_transfer.h" +#include "register/register_format_transfer.h" namespace ge { namespace formats { diff --git a/src/ge/common/formats/format_transfers/format_transfer_transpose.cc b/src/ge/common/formats/format_transfers/format_transfer_transpose.cc index 9b3457ca..ec309543 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_transpose.cc +++ b/src/ge/common/formats/format_transfers/format_transfer_transpose.cc @@ -51,8 +51,8 @@ bool IsShapeArgValid(const std::vector &src_shape, const std::vector &src_shape, Data int64_t dst_ele_num = GetItemNumByShape(dst_shape); int64_t data_size = GetSizeByDataType(src_data_type); int64_t dst_size = data_size * dst_ele_num; - std::shared_ptr dst(new (std::nothrow) uint8_t[dst_size], std::default_delete()); GELOGD("Begin to transpose, src shape %s, perm arg %s, dst shape %s, data type %s", JoinToString(src_shape).c_str(), JoinToString(perm_arg).c_str(), JoinToString(dst_shape).c_str(), 
TypeUtils::DataTypeToSerialString(src_data_type).c_str()); + if (dst_ele_num == 0) { + result.length = static_cast(dst_size); + return SUCCESS; + } + std::shared_ptr dst(new (std::nothrow) uint8_t[dst_size], std::default_delete()); int64_t dst_index = 0; std::vector dst_indexes(dst_shape.size()); while (dst_index < dst_ele_num) { diff --git a/src/ge/common/formats/format_transfers/format_transfer_transpose.h b/src/ge/common/formats/format_transfers/format_transfer_transpose.h index 6866b2e7..476ef024 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_transpose.h +++ b/src/ge/common/formats/format_transfers/format_transfer_transpose.h @@ -20,7 +20,7 @@ #include #include -#include "common/formats/format_transfers/format_transfer.h" +#include "register/register_format_transfer.h" namespace ge { namespace formats { diff --git a/src/ge/common/formats/formats.cc b/src/ge/common/formats/formats.cc index 938f0888..d01d055b 100644 --- a/src/ge/common/formats/formats.cc +++ b/src/ge/common/formats/formats.cc @@ -24,6 +24,7 @@ #include #include +#include "common/formats/utils/formats_trans_utils.h" #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" #include "graph/utils/type_utils.h" @@ -38,10 +39,13 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status TransFormat(const TransArg TypeUtils::FormatToSerialString(args.dst_format).c_str()); return UNSUPPORTED; } - if (args.data == nullptr) { + + auto src_shape_size = GetItemNumByShape(args.src_shape); + if (args.data == nullptr && src_shape_size != 0) { GELOGE(PARAM_INVALID, "Invalid input null data"); return PARAM_INVALID; } + return transfer->TransFormat(args, result); } @@ -71,6 +75,12 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status TransDataType(const CastAr TypeUtils::DataTypeToSerialString(args.dst_data_type).c_str()); return UNSUPPORTED; } + + if (args.data == nullptr && args.src_data_size != 0) { + GELOGE(PARAM_INVALID, "Invalid input null data"); + 
return PARAM_INVALID; + } + return transfer->TransDataType(args, result); } diff --git a/src/ge/common/formats/formats.h b/src/ge/common/formats/formats.h index 09566904..b58c67aa 100644 --- a/src/ge/common/formats/formats.h +++ b/src/ge/common/formats/formats.h @@ -21,7 +21,7 @@ #include #include "common/formats/format_transfers/datatype_transfer.h" -#include "common/formats/format_transfers/format_transfer.h" +#include "register/register_format_transfer.h" #include "external/graph/types.h" #include "framework/common/ge_inner_error_codes.h" #include "graph/ge_tensor.h" @@ -36,8 +36,8 @@ namespace formats { */ Status TransFormat(const TransArgs &args, TransResult &result); -Status TransShape(Format src_format, const std::vector &src_shape, DataType data_type, - Format dst_format, std::vector &dst_shape); +Status TransShape(Format src_format, const std::vector &src_shape, DataType data_type, Format dst_format, + std::vector &dst_shape); Status TransDataType(const CastArgs &args, TransResult &result); diff --git a/src/ge/common/formats/utils/formats_trans_utils.cc b/src/ge/common/formats/utils/formats_trans_utils.cc index 35a0a073..23da0f74 100644 --- a/src/ge/common/formats/utils/formats_trans_utils.cc +++ b/src/ge/common/formats/utils/formats_trans_utils.cc @@ -69,11 +69,11 @@ bool IsShapeValid(const std::vector &shape) { } int64_t num = 1; for (auto dim : shape) { - if (dim < 1) { - GELOGE(PARAM_INVALID, "Invalid zero dim in the shape %s", ShapeToString(shape).c_str()); + if (dim < 0) { + GELOGE(PARAM_INVALID, "Invalid negative dim in the shape %s", ShapeToString(shape).c_str()); return false; } - if (kShapeItemNumMAX / dim < num) { + if (dim != 0 && kShapeItemNumMAX / dim < num) { GELOGE(PARAM_INVALID, "Shape overflow, the total count should be less than %ld!", kShapeItemNumMAX); return false; } diff --git a/src/ge/common/formats/utils/formats_trans_utils.h b/src/ge/common/formats/utils/formats_trans_utils.h index 310aaf38..a8fbd09b 100644 --- 
a/src/ge/common/formats/utils/formats_trans_utils.h +++ b/src/ge/common/formats/utils/formats_trans_utils.h @@ -64,6 +64,9 @@ bool IsShapeEqual(const GeShape &src, const GeShape &dst); template T Ceil(T n1, T n2) { + if (n1 == 0) { + return 0; + } return (n2 != 0) ? (n1 - 1) / n2 + 1 : 0; } diff --git a/src/ge/common/fp16_t.h b/src/ge/common/fp16_t.h index 854df58f..34908b95 100644 --- a/src/ge/common/fp16_t.h +++ b/src/ge/common/fp16_t.h @@ -601,4 +601,4 @@ int16_t GetManBitLength(T man) { return len; } }; // namespace ge -#endif // GE_COMMON_FP16_T_H_ \ No newline at end of file +#endif // GE_COMMON_FP16_T_H_ diff --git a/src/ge/common/ge/plugin_manager.cc b/src/ge/common/ge/plugin_manager.cc index f2eb8f5f..458b584d 100644 --- a/src/ge/common/ge/plugin_manager.cc +++ b/src/ge/common/ge/plugin_manager.cc @@ -17,8 +17,8 @@ #include "common/ge/plugin_manager.h" #include -#include #include +#include #include #include #include @@ -27,13 +27,15 @@ #include #include "framework/common/debug/log.h" - +#include "framework/common/util.h" + +namespace { +const int kMaxNumOfSo = 64; +const int kMaxSizeOfSo = 209100800; // = 200M(unit is Byte) +const int kMaxSizeOfLoadedSo = 522752000; // = 500M(unit is Byte) +const char *const kExt = ".so"; // supported extension of shared object +} // namespace namespace ge { -static const int kMaxNumOfSo = 64; -static const int kMaxSizeOfSo = 209100800; // = 200M(unit is Byte) -static const int kMaxSizeOfLoadedSo = 522752000; // = 500M(unit is Byte) -static const char *const kExt = ".so"; // supported extension of shared object - void PluginManager::ClearHandles_() noexcept { for (const auto &handle : handles_) { if (dlclose(handle.second) != 0) { @@ -100,7 +102,7 @@ Status PluginManager::LoadSo(const string &path, const vector &func_chec } std::string file_name = single_path.substr(single_path.rfind('/') + 1, string::npos); - string file_path_dlopen = domi::RealPath(single_path.c_str()); + string file_path_dlopen = 
RealPath(single_path.c_str()); if (file_path_dlopen.empty()) { GELOGW("Failed to get realpath of %s!", single_path.c_str()); continue; @@ -121,8 +123,6 @@ Status PluginManager::LoadSo(const string &path, const vector &func_chec continue; } - GELOGW("The shared library will not be checked. Please ensure the source of the shared library is trusted."); - // load continue when so is invalid bool is_valid = true; for (const auto &func_name : func_check_list) { @@ -145,10 +145,17 @@ Status PluginManager::LoadSo(const string &path, const vector &func_chec handles_[string(file_name)] = handle; num_of_loaded_so++; } + + GELOGI("load so total num %u", num_of_loaded_so); + for (auto name : so_list_) { + GELOGI("load %s successfully", name.c_str()); + } + if (num_of_loaded_so == 0) { GELOGW("Failed to find any valid so in path %s!", path.c_str()); return SUCCESS; } + return SUCCESS; } @@ -225,7 +232,7 @@ Status PluginManager::Load(const string &path, const vector &func_check_ } std::string canonical_path_str = (std::string(canonical_path) + "/" + file_name); - string file_path_dlopen = domi::RealPath(canonical_path_str.c_str()); + string file_path_dlopen = RealPath(canonical_path_str.c_str()); if (file_path_dlopen.empty()) { GELOGW("failed to get realpath of %s", canonical_path_str.c_str()); continue; diff --git a/src/ge/common/helper/model_cache_helper.cc b/src/ge/common/helper/model_cache_helper.cc new file mode 100644 index 00000000..da1a212e --- /dev/null +++ b/src/ge/common/helper/model_cache_helper.cc @@ -0,0 +1,1708 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include + +#include "common/ge/ge_util.h" +#include "common/helper/model_cache_helper.h" +#include "common/types.h" +#include "framework/common/debug/ge_log.h" +#include "framework/common/ge_types.h" +#include "framework/common/helper/model_helper.h" +#include "framework/common/util.h" +#include "graph/detail/attributes_holder.h" +#include "graph/detail/model_serialize_imp.h" +#include "graph/load/new_model_manager/davinci_model_parser.h" +#include "graph/model.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/tensor_utils.h" +#include "init/gelib.h" +#include "proto/ge_ir.pb.h" + +using namespace std; + +namespace { +const char *const kTbeKernelInfoStoreName = "AIcoreEngine"; +const char *const kGraphName = "temp_name"; +// Keys of json +const char *const kNodeNum = "nodeNum"; +const char *const kEdgeNum = "edgeNum"; +const char *const kGraphHash = "graphHash"; +const char *const kNodeHash = "nodeHash"; +const char *const kHash = "hash"; +const char *const kSessionId = "sessionId"; +const char *const kDeviceId = "deviceId"; +const char *const kJobId = "jobId"; +const char *const kGraphMemMaxSize = "graphMemMaxSize"; +const char *const kVarMemMaxSize = "varMemMaxSize"; +const char *const kVarMemLogicBase = "varMemLogicBase"; +const char *const kUseMaxMemSize = "useMaxMemSize"; +const char *const kMemResourceMap = "memResourceMap"; +const char *const kMemType = "memType"; +const char *const kTotalSize = "totalSize"; +const char *const kVarMemSize = "varMemSize"; +const char 
*const kVarResource = "varResource"; +const char *const kVarAddrMgrMap = "varAddrMgrMap"; +const char *const kName = "name"; +const char *const kAddress = "address"; +const char *const kOffset = "offset"; +const char *const kMemoryType = "memoryType"; +const char *const kTensorDesc = "tensorDesc"; +const char *const kDataType = "dataType"; +const char *const kShape = "shape"; +const char *const kLayout = "layout"; +const char *const kOriginDataType = "originDataType"; +const char *const kOriginShape = "originShape"; +const char *const kOriginLayout = "originLayout"; +const char *const kRealDimCnt = "realDimCnt"; +const char *const kCurVarTensorDescMap = "curVarTensorDescMap"; +const char *const kTransRoads = "transRoads"; +const char *const kTransRoad = "transRoad"; +const char *const kNodeType = "nodeType"; +const char *const kInputTensorDesc = "inputTensorDesc"; +const char *const kOutputTensorDesc = "outputTensorDesc"; +const char *const kChangedGraphId = "changedGraphId"; +const char *const kAllocatedGraphId = "allocatedGraphId"; +const char *const kGraphId = "graphId"; +const char *const kVarBroadcastInfo = "varBroadcastInfo"; +const char *const kBroadcastName = "broadcastName"; +const char *const kIdx = "idx"; +const char *const kInputOffset = "inputOffset"; +const char *const kInputSize = "inputSize"; +const char *const kOutputOffset = "outputOffset"; +const char *const kOutputSize = "outputSize"; +// Suffix of cache files +const char *const kBeforeVarManagerSuffix = "_before_build_var_manager.json"; +const char *const kAfterVarManagerSuffix = "_after_build_var_manager.json"; +const char *const kManifestSuffix = ".manifest"; +const char *const kOmSuffix = ".om"; +} // namespace + +namespace ge { +map ModelCacheHelper::graph_id_run_times_; +ModelCacheHelper::ModelCacheHelper(uint64_t session_id, uint32_t graph_id, ComputeGraphPtr &compute_graph) + : session_id_(session_id), + graph_id_(graph_id), + compute_graph_(compute_graph), + 
is_cache_path_valid_for_output(false) { + if (graph_id_run_times_.count(graph_id) == 0) { + graph_id_run_times_[graph_id] = 1; + } else { + graph_id_run_times_[graph_id] = graph_id_run_times_[graph_id] + 1; + } + for (const auto &node : compute_graph_->GetDirectNode()) { + bool is_variable = (node->GetType() == VARIABLE) || (node->GetType() == VARIABLEV2) || + (node->GetType() == VARHANDLEOP) || (node->GetType() == CONSTANTOP); + if (!is_variable) { + continue; + } + var_names_.insert(node->GetName()); + } + std::shared_ptr instance_ptr = ge::GELib::GetInstance(); + if (instance_ptr != nullptr && instance_ptr->IsIncreBuild()) { + std::string cache_path = instance_ptr->GetIncreBuildCachePath(); + GELOGD("Incre build path conf: %s", cache_path.c_str()); + string fake_file_path = cache_path + to_string(graph_id_) + kManifestSuffix; + if (CheckOutputPathValid(fake_file_path)) { + is_cache_path_valid_for_output = true; + } else { + GELOGW("Invalid cache path for output."); + } + std::string real_cache_path = RealPath(cache_path.c_str()); + if (real_cache_path.empty()) { + GELOGW("Invalid incre build cache path conf: %s", cache_path.c_str()); + return; + } + cache_path_ = real_cache_path + '/'; + GELOGD("Try to use incre build cache path: %s", cache_path_.c_str()); + } +} + +ModelCacheHelper::~ModelCacheHelper() { var_names_.clear(); } + +bool ModelCacheHelper::IsModelCacheHit() const { + CacheInfo cache_info; + if (GetCacheInfo(cache_info) != SUCCESS) { + GELOGI("Get cache info of graph id[%u] failed.", graph_id_); + return false; + } + // Check number of nodes and edges first. 
+ if (cache_info.node_num != compute_graph_->GetDirectNodesSize()) { + GELOGI("Graph id[%u] cache miss: the node number of the graph does not match the cache info.", graph_id_); + return false; + } + size_t edge_num = 0; + for (const auto &node : compute_graph_->GetDirectNode()) { + for (const auto &anchor : node->GetAllInAnchors()) { + edge_num += anchor->GetPeerAnchors().size(); + } + } + if (cache_info.edge_num != edge_num) { + GELOGI("Graph id[%u] cache miss: the edge number of the graph does not match the cache info.", graph_id_); + return false; + } + size_t compute_graph_hash; + auto ret = GetComputeGraphHash(compute_graph_hash); + if (ret != SUCCESS || cache_info.graph_hash != compute_graph_hash) { + GELOGI("Graph id[%u] cache miss: the hash code of the graph does not match the cache info.", graph_id_); + return false; + } + if (!IsNodeHashSameAsCache(cache_info.nodes_hash)) { + GELOGI("Graph id[%u] cache miss: the hash code of node does not match the cache info.", graph_id_); + return false; + } + + string var_manager_cache = + to_string(graph_id_) + "_" + to_string(graph_id_run_times_[graph_id_]) + kBeforeVarManagerSuffix; + Json var_manager_json; + if (LoadJsonFromFile(var_manager_cache, var_manager_json) != SUCCESS) { + GELOGW("Fail to load json from cache file: %s", var_manager_cache.c_str()); + return false; + } + if (!IsVarManagerSameAsCache(var_manager_json)) { + GELOGI("Graph id[%u] cache miss: the VarManager dos not match the cache info.", graph_id_); + return false; + } + GELOGI("Graph id[%u] cache hit.", graph_id_); + return true; +} + +Status ModelCacheHelper::RefreshComputeGraph(const ComputeGraphPtr &compute_graph) { + if (compute_graph->IsValid()) { + compute_graph_ = compute_graph; + var_names_.clear(); + for (const auto &node : compute_graph_->GetDirectNode()) { + bool is_variable = (node->GetType() == VARIABLE) || (node->GetType() == VARIABLEV2) || + (node->GetType() == VARHANDLEOP) || (node->GetType() == CONSTANTOP); + if (!is_variable) 
{ + continue; + } + var_names_.insert(node->GetName()); + } + return SUCCESS; + } else { + GELOGW("Invalid compute graph."); + return FAILED; + } +} + +Status ModelCacheHelper::ClearCache(uint32_t graph_id) const { + if (!is_cache_path_valid_for_output) { + GELOGW("Invalid cache path."); + return SUCCESS; + } + string manifest_file = cache_path_ + to_string(graph_id) + kManifestSuffix; + string manifest_file_path = RealPath(manifest_file.c_str()); + int ret; + if (!manifest_file_path.empty()) { + ret = remove(manifest_file_path.c_str()); + // If remove file failed, print the warning log + if (ret != 0) { + GELOGW("Clear cache [%s] failed.", manifest_file_path.c_str()); + } + } + string before_var_manager_file = cache_path_ + to_string(graph_id) + kManifestSuffix; + string before_var_manager_file_path = RealPath(before_var_manager_file.c_str()); + if (!before_var_manager_file_path.empty()) { + ret = remove(before_var_manager_file_path.c_str()); + if (ret != 0) { + GELOGW("Clear cache [%s] failed.", before_var_manager_file_path.c_str()); + } + } + string after_var_manager_file = cache_path_ + to_string(graph_id) + kManifestSuffix; + string after_var_manager_file_path = RealPath(after_var_manager_file.c_str()); + if (!after_var_manager_file_path.empty()) { + ret = remove(after_var_manager_file_path.c_str()); + if (ret != 0) { + GELOGW("Clear cache [%s] failed.", after_var_manager_file_path.c_str()); + } + } + string om_file = cache_path_ + to_string(graph_id) + kManifestSuffix; + string om_file_path = RealPath(om_file.c_str()); + if (!om_file_path.empty()) { + ret = remove(om_file_path.c_str()); + if (ret != 0) { + GELOGW("Clear cache [%s] failed.", om_file_path.c_str()); + } + } + return SUCCESS; +} + +Status ModelCacheHelper::RecoverVarManagerFromCache() const { + string var_manager_cache = + to_string(graph_id_) + "_" + to_string(graph_id_run_times_[graph_id_]) + kAfterVarManagerSuffix; + Json var_manager_json; + if (LoadJsonFromFile(var_manager_cache, 
var_manager_json) != SUCCESS) { + GELOGW("Fail to load json from cache file: %s", var_manager_cache.c_str()); + return FAILED; + } + + Json mem_resource_json = move(var_manager_json[kMemResourceMap]); + auto ret = RecoverMemResource(mem_resource_json); + if (ret != SUCCESS) { + GELOGW("Recover VarManager from cache failed.[MemResource]"); + return FAILED; + } + Json var_resource_json = move(var_manager_json[kVarResource]); + ret = RecoverAllocatedGraphId(var_resource_json[kAllocatedGraphId]); + if (ret != SUCCESS) { + GELOGW("Recover VarManager from cache failed.[AllocatedGraphId]"); + return FAILED; + } + ret = RecoverChangedGraphId(var_resource_json[kChangedGraphId]); + if (ret != SUCCESS) { + GELOGW("Recover VarManager from cache failed.[ChangedGraphId]"); + return FAILED; + } + ret = RecoverBroadcastInfo(var_resource_json[kVarBroadcastInfo]); + if (ret != SUCCESS) { + GELOGW("Recover VarManager from cache failed.[VarBroadcastInfo]"); + return FAILED; + } + ret = RecoverVarAddrAndTensorDesc(var_resource_json[kVarAddrMgrMap]); + if (ret != SUCCESS) { + GELOGW("Recover VarManager from cache failed.[VarAddrMgrMap & CurVarTensorDesc]"); + return FAILED; + } + ret = RecoverTransRoads(var_resource_json[kTransRoads]); + if (ret != SUCCESS) { + GELOGW("Recover VarManager from cache failed.[TransRoads]"); + return FAILED; + } + GELOGI("Recover VarManager from cache[%s] success.", cache_path_.c_str()); + return SUCCESS; +} + +Status ModelCacheHelper::GetNodesNeedRecompile(ComputeGraphPtr &graph, vector &nodes) { + std::shared_ptr instance = ge::GELib::GetInstance(); + if (instance == nullptr || !instance->InitFlag()) { + GELOGW("RecompileNodes failed."); + return ge::GE_CLI_GE_NOT_INITIALIZED; + } + // Collect aicore ops for recompile + for (auto &node : graph->GetDirectNode()) { + if (node == nullptr) { + continue; + } + auto op_desc = node->GetOpDesc(); + if (op_desc == nullptr) { + continue; + } + // Get op kernel lib name + string kernel_lib_name = 
op_desc->GetOpKernelLibName(); + if (kernel_lib_name.empty()) { + // reset op kernel lib + (void)instance->DNNEngineManagerObj().GetDNNEngineName(op_desc); + kernel_lib_name = op_desc->GetOpKernelLibName(); + if (kernel_lib_name.empty()) { + GELOGW("Get node:%s, type:%s kernel lib failed.", node->GetName().c_str(), op_desc->GetType().c_str()); + continue; + } + } + } + return SUCCESS; +} + +Status ModelCacheHelper::RecompileNodes(GeModelPtr &ge_model) { + std::shared_ptr instance = ge::GELib::GetInstance(); + if (instance == nullptr || !instance->InitFlag()) { + GELOGW("RecompileNodes failed."); + return ge::GE_CLI_GE_NOT_INITIALIZED; + } + // Get aicore ops kernel info store. + OpsKernelInfoStorePtr kernel_info = instance->OpsKernelManagerObj().GetOpsKernelInfoStore(kTbeKernelInfoStoreName); + if (kernel_info == nullptr) { + GELOGW("Get %s ops kernel info store failed", kTbeKernelInfoStoreName); + return INTERNAL_ERROR; + } + + auto compute_graph = GraphUtils::GetComputeGraph(ge_model->GetGraph()); + vector node_vec; + auto ret = GetNodesNeedRecompile(compute_graph, node_vec); + GE_CHK_BOOL_EXEC_WARN(ret == ge::SUCCESS, return ret, "Get nodes need recompiling failed"); + // Recompile aicore ops + ret = kernel_info->CompileOp(node_vec); + GE_CHK_BOOL_EXEC_WARN(ret == ge::SUCCESS, return ret, "Recompile op failed"); + const TBEKernelStore &tbekernel_store = ge_model->GetTBEKernelStore(); + TBEKernelStore tbe_kernel_store; + for (const ge::NodePtr &n : compute_graph->GetDirectNode()) { + auto node_op_desc = n->GetOpDesc(); + GE_IF_BOOL_EXEC(node_op_desc == nullptr, continue); + TBEKernelPtr tbe_kernel = node_op_desc->TryGetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, TBEKernelPtr()); + if (tbe_kernel == nullptr) { + // Load tbe kernel from tbe_kernel_store to op if op was not recompiled + auto op_desc = n->GetOpDesc(); + tbekernel_store.LoadTBEKernelBinToOpDesc(op_desc); + GELOGD("LoadOmModelFromCache: Load tbe kernel bin to op desc[%s].", op_desc->GetName().c_str()); + } 
+ tbe_kernel = node_op_desc->TryGetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, TBEKernelPtr()); + GE_IF_BOOL_EXEC(tbe_kernel == nullptr, continue); + // Refresh tbe kernel in tbe_kernel_store + tbe_kernel_store.AddTBEKernel(tbe_kernel); + GELOGD("Add tbe kernel bin %s", tbe_kernel->GetName().c_str()); + } + GE_CHK_BOOL_EXEC_WARN(tbe_kernel_store.Build(), return FAILED, "TBE Kernels store build failed!"); + ge_model->SetTBEKernelStore(tbe_kernel_store); + return SUCCESS; +} + +Status ModelCacheHelper::GetNodesHash(map &hash_map) const { + vector nodes; + GraphUtils::TopologicalSortingByName(compute_graph_, nodes); + ModelSerializeImp model_serialize_imp; + std::hash node_hash; + for (const auto &node : nodes) { + if (node == nullptr) { + continue; + } + proto::OpDef op_def; + bool is_framework_op = (node->GetType() == FRAMEWORKOP); + int32_t framework_type = 0; + if (is_framework_op) { + AttrUtils::GetInt(node->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_FWK_TYPE, framework_type); + AttrUtils::SetInt(node->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_FWK_TYPE, 0); + } + bool ret = model_serialize_imp.SerializeNode(node, &op_def, is_framework_op); + op_def.set_id(0); // Id of op is not stable because of parallel parsing + // Clear weights attr in constant. 
+ auto attr = op_def.mutable_attr(); + if (op_def.type() == CONSTANT || op_def.type() == CONSTANTOP) { + attr->erase(ATTR_NAME_WEIGHTS); + } + if (is_framework_op) { + AttrUtils::SetInt(node->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_FWK_TYPE, framework_type); + } + if (!ret) { + GELOGW("Fail to serialize node[%s].", node->GetName().c_str()); + return INTERNAL_ERROR; + } + string prototxt; + ret = google::protobuf::TextFormat::PrintToString(op_def, &prototxt); + if (!ret) { + GELOGW("Print OpDef to string failed."); + hash_map.clear(); + return INTERNAL_ERROR; + } + size_t hash_code = node_hash(prototxt); + hash_map[node->GetName()] = hash_code; + } + return SUCCESS; +} + +Status ModelCacheHelper::GetComputeGraphHash(size_t &hash) const { + proto::GraphDef graph_proto; + ModelSerializeImp model_serialize_imp; + // The name of compute graph may be generated randomly, so replace it temporarily. + const string origin_name = compute_graph_->GetName(); + compute_graph_->SetName(kGraphName); + bool serialize_ret = model_serialize_imp.SerializeGraph(compute_graph_, &graph_proto); + graph_proto.clear_op(); + if (!serialize_ret) { + GELOGW("Serialize graph failed."); + hash = 0; + return INTERNAL_ERROR; + } + compute_graph_->SetName(origin_name); + // Generate proto text of GraphDef + string prototxt; + bool print_ret = google::protobuf::TextFormat::PrintToString(graph_proto, &prototxt); + if (!print_ret) { + GELOGW("Print GraphDef to string failed."); + hash = 0; + return INTERNAL_ERROR; + } + // Get the hash code of proto text + std::hash graph_hash; + hash = graph_hash(prototxt); + return SUCCESS; +} + +Status ModelCacheHelper::SaveJsonToFile(const string &file_name, const Json &json) const { + if (!is_cache_path_valid_for_output) { + GELOGW("Invalid cache path."); + return PARAM_INVALID; + } + // Check whether the manifest exists, if not, create it. + string real_path = RealPath(cache_path_.c_str()); + if (real_path.empty()) { + GELOGW("File path is invalid. 
please check cache path: %s", cache_path_.c_str()); + return FAILED; + } + const string path = cache_path_ + file_name; + const int FILE_AUTHORITY = 0600; + int fd = open(path.c_str(), O_WRONLY | O_CREAT | O_TRUNC, FILE_AUTHORITY); + if (fd < 0) { + GELOGW("Fail to open the file: %s.", path.c_str()); + return INTERNAL_ERROR; + } + if (close(fd) != 0) { + GELOGW("Fail to close the file: %s.", path.c_str()); + return INTERNAL_ERROR; + } + + // Write json into cache file + ofstream ofs; + ofs.open(path); + if (!ofs.is_open()) { + GELOGW("Fail to open the file: %s.", path.c_str()); + return INTERNAL_ERROR; + } + ofs << json << std::endl; + ofs.close(); + return SUCCESS; +} + +Status ModelCacheHelper::LoadJsonFromFile(const string &file_name, Json &json) const { + if (!json.is_null()) { + GELOGW("Input param json type should be null."); + return PARAM_INVALID; + } + string real_path = RealPath(cache_path_.c_str()); + if (real_path.empty()) { + GELOGW("File path is invalid. please check cache path: %s", cache_path_.c_str()); + return FAILED; + } + const string path = cache_path_ + file_name; + if (!CheckInputPathValid(path)) { + GELOGW("Invalid cache path for input:%s.", path.c_str()); + return FAILED; + } + string cache_real_path = RealPath(path.c_str()); + if (cache_real_path.empty()) { + GELOGI("File[%s] is not found.", path.c_str()); + return FAILED; + } + // Read json from cache file + ifstream ifs; + ifs.open(path); + if (!ifs.is_open()) { + GELOGW("Fail to open the file: %s.", path.c_str()); + return INTERNAL_ERROR; + } + ifs >> json; + if (!json.is_object()) { + GELOGW("Fail to load the json file: %s.", path.c_str()); + return INTERNAL_ERROR; + } + return SUCCESS; +} + +Status ModelCacheHelper::SaveCacheInfoToCache() const { + // Generate cache json + // example: {"edgeNum":6,"nodeNum":7,"graphCache":134714827475991356} + Json cache_json; + try { + cache_json[kNodeNum] = compute_graph_->GetDirectNodesSize(); + size_t edge_num = 0; + for (const auto &node : 
compute_graph_->GetDirectNode()) { + for (const auto &anchor : node->GetAllInAnchors()) { + edge_num += anchor->GetPeerAnchors().size(); + } + } + cache_json[kEdgeNum] = edge_num; + size_t hash = 0; + auto ret = GetComputeGraphHash(hash); + if (ret != SUCCESS) { + GELOGW("Error occur when generate graph hash code."); + return ret; + } + cache_json[kGraphHash] = hash; + Json nodes_hash_json; + ret = GetNodesHashMapJson(nodes_hash_json); + if (ret != SUCCESS) { + GELOGW("Error occur when generate nodes hash code."); + return ret; + } + cache_json[kNodeHash] = nodes_hash_json; + } catch (const std::exception &e) { + GELOGW("Fail to generate cache info json. Error message: %s", e.what()); + return INTERNAL_ERROR; + } + string cache_manifest = to_string(graph_id_) + "_" + to_string(graph_id_run_times_[graph_id_]) + kManifestSuffix; + + auto ret = SaveJsonToFile(cache_manifest, cache_json); + if (ret != SUCCESS) { + GELOGW("Fail to save cache info to json file, path: %s.", cache_path_.c_str()); + return ret; + } + return SUCCESS; +} + +Status ModelCacheHelper::GetCacheInfo(CacheInfo &cache_info) const { + string cache_manifest = to_string(graph_id_) + "_" + to_string(graph_id_run_times_[graph_id_]) + kManifestSuffix; + Json cache_json; + if (LoadJsonFromFile(cache_manifest, cache_json) != SUCCESS) { + GELOGW("Fail to load json from cache file: %s", cache_manifest.c_str()); + return INTERNAL_ERROR; + } + if (!cache_json.is_object()) { + GELOGW("Manifest should be a json object"); + return INTERNAL_ERROR; + } + try { + cache_info.node_num = cache_json[kNodeNum]; + cache_info.edge_num = cache_json[kEdgeNum]; + cache_info.graph_hash = cache_json[kGraphHash]; + Json nodes_hash_json = cache_json[kNodeHash]; + if (!(nodes_hash_json.is_null() || nodes_hash_json.is_array())) { + GELOGW("Nodes hash in cache be null or array."); + return FAILED; + } + for (const auto &iter : nodes_hash_json) { + cache_info.nodes_hash[iter[kName].get()] = iter[kHash].get(); + } + } catch (const 
std::exception &e) { + GELOGW("Fail to get info from json file. Error message: %s", e.what()); + return INTERNAL_ERROR; + } + return SUCCESS; +} + +bool ModelCacheHelper::IsAllocatedGraphIdSameAsCache(Json &json) const { + if (!(json.is_null() || json.is_array())) { + GELOGW("Input param json type should be null or array."); + return false; + } + // Compare allocated graph id info between json and VarManager + std::unordered_map allocated_graph_id; + auto ret = ParseAllocatedGraphIdFromJson(json, allocated_graph_id); + if (ret != SUCCESS) { + GELOGW("Fail to parse AllocatedGraphId from Json."); + return false; + } + for (const auto &iter : allocated_graph_id) { + uint32_t graph_id = 0; + ret = VarManager::Instance(session_id_)->GetAllocatedGraphId(iter.first, graph_id); + if (ret != SUCCESS) { + GELOGW("Fail to find allocated graph id of var[%s].", iter.first.c_str()); + return false; + } + if (graph_id != iter.second) { + GELOGW("The allocated graph id of variable[%s] in cache is different from VarManager.", iter.first.c_str()); + return false; + } + } + return true; +} + +bool ModelCacheHelper::IsNodeHashSameAsCache(const map &hash_map) const { + map cur_hash_map; + GetNodesHash(cur_hash_map); + if (hash_map.size() != cur_hash_map.size()) { + GELOGI("The number of hash code is different from cache info."); + return false; + } + for (const auto &iter : cur_hash_map) { + if (hash_map.count(iter.first) == 0) { + GELOGI("Node[%s] is not found in cache info.", iter.first.c_str()); + return false; + } + if (hash_map.at(iter.first) != iter.second) { + GELOGI("The hash code of node[%s] is different from cache info.", iter.first.c_str()); + return false; + } + } + return true; +} + +bool ModelCacheHelper::IsMemResourceSameAsCache(Json &json) const { + if (!(json.is_null() || json.is_array())) { + GELOGW("Input param json type should be null or array."); + return false; + } + // Compare var mem size info between json and VarManager + std::map var_mem_size; + auto ret = 
ParseMemResourceFromJson(json, var_mem_size); + if (ret != SUCCESS) { + GELOGW("Fail to parse MemResource from Json."); + return false; + } + for (const auto &iter : var_mem_size) { + int64_t mem_size = VarManager::Instance(session_id_)->GetVarMemSize(iter.first); + if (mem_size != iter.second) { + GELOGW("The var mem size of memory_type[%u] in cache is different from VarManager.", iter.first); + return false; + } + } + return true; +} + +bool ModelCacheHelper::IsChangedGraphIdSameAsCache(Json &json) const { + if (!(json.is_null() || json.is_array())) { + GELOGW("Input param json type should be null or array."); + return false; + } + // Compare variable changed graph id info between json and VarManager + std::unordered_map changed_graph_id; + auto ret = ParseChangedGraphIdFromJson(json, changed_graph_id); + if (ret != SUCCESS) { + GELOGW("Fail to parse ChangedGraphId from Json."); + return false; + } + for (const auto &iter : changed_graph_id) { + uint32_t graph_id = 0; + ret = VarManager::Instance(session_id_)->GetChangedGraphId(iter.first, graph_id); + if (ret != SUCCESS) { + GELOGW("Fail to find changed graph id of var[%s].", iter.first.c_str()); + return false; + } + if (graph_id != iter.second) { + GELOGW("The changed graph id of variable[%s] in cache is different from VarManager.", iter.first.c_str()); + return false; + } + } + return true; +} + +bool ModelCacheHelper::IsCurVarTensorDescSameAsCache(Json &json) const { + if (!(json.is_null() || json.is_array())) { + GELOGW("Input param json type should be null or array."); + return false; + } + // Compare variable tensor desc info between json and VarManager + std::unordered_map cur_var_tensor_desc; + auto ret = ParseCurVarTensorDescMapFromJson(json, cur_var_tensor_desc); + if (ret != SUCCESS) { + GELOGW("Fail to parse CurVarTensorDesc from Json."); + return false; + } + for (const auto &iter : cur_var_tensor_desc) { + GeTensorDesc tensor_desc; + ret = 
VarManager::Instance(session_id_)->GetCurVarDesc(iter.first, tensor_desc); + if (ret != SUCCESS) { + GELOGW("Fail to find tensor desc of var[%s].", iter.first.c_str()); + return false; + } + uint32_t l_real_dim_cnt = 0; + uint32_t r_real_dim_cnt = 0; + TensorUtils::GetRealDimCnt(tensor_desc, l_real_dim_cnt); + TensorUtils::GetRealDimCnt(iter.second, r_real_dim_cnt); + if ((tensor_desc.GetDataType() != iter.second.GetDataType()) || + (tensor_desc.GetOriginDataType() != iter.second.GetOriginDataType()) || + (tensor_desc.GetFormat() != iter.second.GetFormat()) || + (tensor_desc.GetOriginFormat() != iter.second.GetOriginFormat()) || + (tensor_desc.GetShape().ToString() != iter.second.GetShape().ToString()) || + (tensor_desc.GetOriginShape().ToString() != iter.second.GetOriginShape().ToString()) || + (l_real_dim_cnt != r_real_dim_cnt)) { + GELOGW("The var tensor desc of variable[%s] in cache is different from VarManager.", iter.first.c_str()); + return false; + } + } + return true; +} + +bool ModelCacheHelper::IsVarAddrMgrMapSameAsCache(Json &json) const { + if (!(json.is_null() || json.is_array())) { + GELOGW("Input param json type should be null or array."); + return false; + } + // Compare variable address info between json and VarManager + std::vector> var_addr_mgr_vector; + std::unordered_set var_offset_set; + auto ret = ParseVarAddrMgrMapFromJson(json, var_addr_mgr_vector, var_offset_set); + if (ret != SUCCESS) { + GELOGW("Fail to parse VarAddrMgrMap from Json."); + return false; + } + for (const auto &iter : var_addr_mgr_vector) { + uint8_t *dev_ptr = nullptr; + rtMemType_t memory_type; + ret = VarManager::Instance(session_id_)->GetVarAddr(iter.first, iter.second.tensor_desc, &dev_ptr, memory_type); + if (ret != SUCCESS) { + GELOGW("Fail to find tensor desc of var[%s].", iter.first.c_str()); + return false; + } + // Compare memory type and logic address + if (iter.second.memory_type != memory_type || iter.second.address != dev_ptr) { + GELOGW("The VarAddrMgr of 
variable[%s] in cache is different from VarManager.", iter.first.c_str()); + return false; + } + } + return true; +} + +bool ModelCacheHelper::IsBroadcastInfoSameAsCache(Json &json) const { + if (!(json.is_null() || json.is_array())) { + GELOGW("Input param json type should be null or array."); + return false; + } + // Compare broadcast info between json and VarManager + std::unordered_map var_broadcast_info; + auto ret = ParseBroadcastInfoFromJson(json, var_broadcast_info); + if (ret != SUCCESS) { + GELOGW("Fail to parse BroadcastInfo from Json."); + return false; + } + for (const auto &iter : var_broadcast_info) { + VarBroadCastInfo broadcast_info; + if (VarManager::Instance(session_id_)->GetBroadCastInfo(graph_id_, iter.first, broadcast_info) != SUCCESS) { + GELOGW("Fail to find broadcast info of var[%s].", iter.first.c_str()); + return false; + } + if (iter.second.var_name != broadcast_info.var_name || iter.second.idx != broadcast_info.idx || + iter.second.input_size != broadcast_info.input_size || + iter.second.input_offset != broadcast_info.input_offset || + iter.second.output_size != broadcast_info.output_size || + iter.second.output_offset != broadcast_info.output_offset) { + GELOGW("The BroadcastInfo of variable[%s] in cache is different from VarManager.", iter.first.c_str()); + return false; + } + } + return true; +} + +bool ModelCacheHelper::IsTransRoadsSameAsCache(Json &json) const { + if (!(json.is_null() || json.is_array())) { + GELOGW("Input param json type should be null or array."); + return false; + } + // Compare trans road between json and VarManager + std::unordered_map> trans_roads; + auto ret = ParseTransRoadsFromJson(json, trans_roads); + if (ret != SUCCESS) { + GELOGW("Fail to parse TransRoads from Json."); + return false; + } + for (const auto &iter : trans_roads) { + VarTransRoad *trans_road; + trans_road = VarManager::Instance(session_id_)->GetTransRoad(iter.first); + if (trans_road == nullptr) { + GELOGW("Fail to find trans road of 
var[%s].", iter.first.c_str()); + return false; + } + if (trans_road->size() != iter.second.size()) { + GELOGW("The TransRoad of variable[%s] in cache is different from VarManager.", iter.first.c_str()); + return false; + } + // Compare every trans node in trans road. + for (size_t idx = 0; idx < trans_road->size(); idx += 1) { + if (!(trans_road->at(idx).node_type == iter.second.at(idx).node_type && + trans_road->at(idx).input == iter.second.at(idx).input && + trans_road->at(idx).output == iter.second.at(idx).output)) { + GELOGW("The TransRoad of variable[%s] in cache is different from VarManager.", iter.first.c_str()); + return false; + } + } + } + return true; +} + +bool ModelCacheHelper::IsVarManagerParamSameAsCache(Json &json) const { + if (!json.is_object()) { + GELOGW("Input param json type should be object."); + return false; + } + try { + if (json[kSessionId].get() != session_id_) { + GELOGW("Check VarManager cache failed.[sessionId]"); + return false; + } + if (json[kDeviceId].get() != VarManager::Instance(session_id_)->DeviceId()) { + GELOGW("Check VarManager cache failed.[deviceId]"); + return false; + } + if (json[kJobId].get() != VarManager::Instance(session_id_)->JobId()) { + GELOGW("Check VarManager cache failed.[jobId]"); + return false; + } + if (json[kGraphMemMaxSize].get() != VarManager::Instance(session_id_)->GetGraphMemoryMaxSize()) { + GELOGW("Check VarManager cache failed.[graphMemMaxSize]"); + return false; + } + if (json[kVarMemMaxSize].get() != VarManager::Instance(session_id_)->GetVarMemMaxSize()) { + GELOGW("Check VarManager cache failed.[varMemMaxSize]"); + return false; + } + if (json[kVarMemLogicBase].get() != VarManager::Instance(session_id_)->GetVarMemLogicBase()) { + GELOGW("Check VarManager cache failed.[varMemLogicBase]"); + return false; + } + if (json[kUseMaxMemSize].get() != VarManager::Instance(session_id_)->GetUseMaxMemorySize()) { + GELOGW("Check VarManager cache failed.[useMaxMemSize]"); + return false; + } + } catch 
(const std::exception &e) { + GELOGW("Fail to check VarManager json. Error message: %s", e.what()); + return false; + } + return true; +} + +bool ModelCacheHelper::IsVarManagerSameAsCache(Json &json) const { + if (!json.is_object()) { + GELOGW("Input param json type should be object."); + return false; + } + try { + if (!IsVarManagerParamSameAsCache(json)) { + GELOGW("Check VarManager cache failed.[Param]"); + return false; + } + Json mem_resource_json = move(json[kMemResourceMap]); + auto ret = IsMemResourceSameAsCache(mem_resource_json); + if (!ret) { + GELOGW("Check VarManager cache failed.[MemResource]"); + return false; + } + Json var_resource_json = move(json[kVarResource]); + ret = IsAllocatedGraphIdSameAsCache(var_resource_json[kAllocatedGraphId]); + if (!ret) { + GELOGW("Check VarManager cache failed.[AllocatedGraphId]"); + return false; + } + ret = IsChangedGraphIdSameAsCache(var_resource_json[kChangedGraphId]); + if (!ret) { + GELOGW("Check VarManager cache failed.[ChangedGraphId]"); + return false; + } + ret = IsBroadcastInfoSameAsCache(var_resource_json[kVarBroadcastInfo]); + if (!ret) { + GELOGW("Check VarManager cache failed.[VarBroadcastInfo]"); + return false; + } + ret = IsCurVarTensorDescSameAsCache(var_resource_json[kCurVarTensorDescMap]); + if (!ret) { + GELOGW("Check VarManager cache failed.[CurVarTensorDesc]"); + return false; + } + ret = IsVarAddrMgrMapSameAsCache(var_resource_json[kVarAddrMgrMap]); + if (!ret) { + GELOGW("Check VarManager cache failed.[VarAddrMgrMap]"); + return false; + } + ret = IsTransRoadsSameAsCache(var_resource_json[kTransRoads]); + if (!ret) { + GELOGW("Check VarManager cache failed.[TransRoads]"); + return false; + } + } catch (const std::exception &e) { + GELOGW("Fail to check VarManager json. 
Error message: %s", e.what()); + return false; + } + return true; +} + +Status ModelCacheHelper::RecoverMemResource(const Json &json) const { + if (!(json.is_null() || json.is_array())) { + GELOGW("Input param json type should be null or array."); + return PARAM_INVALID; + } + std::map var_mem_size; + auto ret = ParseMemResourceFromJson(json, var_mem_size); + if (ret != SUCCESS) { + GELOGW("Fail to parse MemResource from Json."); + return ret; + } + for (const auto &iter : var_mem_size) { + ret = VarManager::Instance(session_id_)->UpdateVarMemSize(iter.first, iter.second); + if (ret != SUCCESS) { + GELOGW("Fail to recover var mem size."); + return ret; + } + } + return SUCCESS; +} + +Status ModelCacheHelper::RecoverAllocatedGraphId(const Json &json) const { + if (!(json.is_null() || json.is_array())) { + GELOGW("Input param json type should be null or array."); + return PARAM_INVALID; + } + std::unordered_map allocated_graph_id; + auto ret = ParseAllocatedGraphIdFromJson(json, allocated_graph_id); + if (ret != SUCCESS) { + GELOGW("Fail to parse AllocatedGraphId from Json."); + return ret; + } + for (const auto &iter : allocated_graph_id) { + ret = VarManager::Instance(session_id_)->SetAllocatedGraphId(iter.first, iter.second); + if (ret != SUCCESS) { + GELOGW("Fail to recover allocated graph id."); + return ret; + } + } + return SUCCESS; +} + +Status ModelCacheHelper::RecoverChangedGraphId(const Json &json) const { + if (!(json.is_null() || json.is_array())) { + GELOGW("Input param json type should be null or array."); + return PARAM_INVALID; + } + std::unordered_map changed_graph_id; + auto ret = ParseChangedGraphIdFromJson(json, changed_graph_id); + if (ret != SUCCESS) { + GELOGW("Fail to parse AllocatedGraphId from Json."); + return ret; + } + for (const auto &iter : changed_graph_id) { + ret = VarManager::Instance(session_id_)->SetChangedGraphId(iter.first, iter.second); + if (ret != SUCCESS) { + GELOGW("Fail to recover changed graph id."); + return ret; + } + 
} + return SUCCESS; +} + +Status ModelCacheHelper::RecoverVarAddrAndTensorDesc(const Json &json) const { + if (!(json.is_null() || json.is_array())) { + GELOGW("Input param json type should be null or array."); + return PARAM_INVALID; + } + std::vector> var_addr_mgr_vector; + std::unordered_set var_offset_set; + auto ret = ParseVarAddrMgrMapFromJson(json, var_addr_mgr_vector, var_offset_set); + if (ret != SUCCESS) { + GELOGW("Fail to parse VarAddrMgrMap from Json."); + return ret; + } + for (const auto &iter : var_addr_mgr_vector) { + const VarAddrMgr &tensor_addr_mgr = iter.second; + const bool var_exist = VarManager::Instance(session_id_)->IsVarExist(iter.first, tensor_addr_mgr.tensor_desc); + // SaveVarVddr if var does not exist, the logic address will be recorded by VarManager + if (!var_exist) { + auto logic_address = reinterpret_cast(reinterpret_cast(tensor_addr_mgr.address)); + auto offset = (tensor_addr_mgr.offset); + // Check logic address and offset + if (logic_address - offset != VarManager::Instance(session_id_)->GetVarMemLogicBase()) { + GELOGW("Check logic_address[%u] and offset [%u] of %s failed, var mem logic base is %u, abandon", logic_address, + offset, iter.first.c_str(), VarManager::Instance(session_id_)->GetVarMemLogicBase()); + return PARAM_INVALID; + } + // Offset is needed by SaveVarVddr instead of logic address + ret = + VarManager::Instance(session_id_) + ->SaveVarAddr(iter.first, tensor_addr_mgr.tensor_desc, + reinterpret_cast(reinterpret_cast(offset)), tensor_addr_mgr.memory_type); + if (ret != SUCCESS) { + GELOGW("Fail to recover VarAddr or TensorDesc of var[%s].", iter.first.c_str()); + return ret; + } + } + // SetVarAddr to update cur_var_tensor_desc_map_ + ret = VarManager::Instance(session_id_) + ->SetVarAddr(iter.first, tensor_addr_mgr.tensor_desc, tensor_addr_mgr.address, tensor_addr_mgr.memory_type); + if (ret != SUCCESS) { + GELOGW("Fail to recover VarAddr or TensorDesc desc of var[%s].", iter.first.c_str()); + return ret; + } + 
} + return SUCCESS; +} + +Status ModelCacheHelper::RecoverBroadcastInfo(const Json &json) const { + if (!(json.is_null() || json.is_array())) { + GELOGW("Input param json type should be null or array."); + return PARAM_INVALID; + } + std::unordered_map var_broadcast_info; + auto ret = ParseBroadcastInfoFromJson(json, var_broadcast_info); + if (ret != SUCCESS) { + GELOGW("Fail to parse BroadcastInfo from Json."); + return ret; + } + for (const auto &iter : var_broadcast_info) { + VarBroadCastInfo broadcast_info; + ret = VarManager::Instance(session_id_)->SaveBroadCastInfo(graph_id_, iter.second); + if (ret != SUCCESS) { + GELOGW("Fail to recover broadcast info of var[%s].", iter.first.c_str()); + return ret; + } + } + return SUCCESS; +} + +Status ModelCacheHelper::RecoverTransRoads(const Json &json) const { + if (!(json.is_null() || json.is_array())) { + GELOGW("Input param json type should be null or array."); + return PARAM_INVALID; + } + std::unordered_map> trans_roads; + auto ret = ParseTransRoadsFromJson(json, trans_roads); + if (ret != SUCCESS) { + GELOGW("Fail to parse TransRoads from Json."); + return ret; + } + for (const auto &iter : trans_roads) { + ret = VarManager::Instance(session_id_)->SetTransRoad(iter.first, iter.second); + if (ret != SUCCESS) { + GELOGW("Fail to find trans road of var[%s].", iter.first.c_str()); + return ret; + } + } + return SUCCESS; +} + +Status ModelCacheHelper::TensorDescToJson(const GeTensorDesc &ge_tensor_desc, Json &json) { + if (!(json.is_null() || json.is_object())) { + GELOGW("Input param json type should be null or object."); + return PARAM_INVALID; + } + try { + json[kDataType] = static_cast(ge_tensor_desc.GetDataType()); + json[kOriginDataType] = static_cast(ge_tensor_desc.GetOriginDataType()); + json[kLayout] = static_cast(ge_tensor_desc.GetFormat()); + json[kOriginLayout] = static_cast(ge_tensor_desc.GetOriginFormat()); + json[kShape] = ge_tensor_desc.GetShape().GetDims(); + json[kOriginShape] = 
ge_tensor_desc.GetOriginShape().GetDims(); + uint32_t real_dim_cnt = 0; + (void)TensorUtils::GetRealDimCnt(ge_tensor_desc, real_dim_cnt); // [No need to check value] + json[kRealDimCnt] = real_dim_cnt; + } catch (const std::exception &e) { + GELOGW("Fail to trans GeTensorDesc to json. Error message: %s", e.what()); + return INTERNAL_ERROR; + } + return SUCCESS; +} + +Status ModelCacheHelper::JsonToTensorDesc(const Json &json, ge::GeTensorDesc &ge_tensor_desc) { + if (!json.is_object()) { + GELOGW("Input param json type should be object."); + return PARAM_INVALID; + } + try { + ge_tensor_desc.SetDataType(static_cast(json[kDataType].get())); + ge_tensor_desc.SetOriginDataType(static_cast(json[kOriginDataType].get())); + ge_tensor_desc.SetFormat(static_cast(json[kLayout].get())); + ge_tensor_desc.SetOriginFormat(static_cast(json[kOriginLayout].get())); + GeShape shape(json[kShape].get>()); + ge_tensor_desc.SetShape(shape); + GeShape origin_shape(json[kOriginShape].get>()); + ge_tensor_desc.SetOriginShape(origin_shape); + auto real_dim_cnt = json[kRealDimCnt].get(); + (void)TensorUtils::SetRealDimCnt(ge_tensor_desc, real_dim_cnt); // [No need to check value] + } catch (const std::exception &e) { + GELOGW("Fail to trans Json to GeTensorDesc. Error message: %s", e.what()); + return INTERNAL_ERROR; + } + return SUCCESS; +} + +Status ModelCacheHelper::GetNodesHashMapJson(Json &json) const { + if (!(json.is_null() || json.is_array())) { + GELOGW("Input param json type should be null or array."); + return PARAM_INVALID; + } + map hash_map; + GetNodesHash(hash_map); + for (const auto &iter : hash_map) { + Json node_hash_json; + try { + node_hash_json[kName] = iter.first; + node_hash_json[kHash] = iter.second; + json.emplace_back(move(node_hash_json)); + } catch (const std::exception &e) { + GELOGW("Fail to trans node cache to json. 
Error message: %s", e.what()); + return INTERNAL_ERROR; + } + } + return SUCCESS; +} + +Status ModelCacheHelper::GetMemResourceMap(Json &json) const { + if (!(json.is_null() || json.is_array())) { + GELOGW("Input param json type should be null or array."); + return PARAM_INVALID; + } + const auto total_size = VarManager::Instance(session_id_)->GetVarMemMaxSize(); + const auto var_mem_size = VarManager::Instance(session_id_)->GetVarMemSize(RT_MEMORY_HBM); + Json mem_resource_json; + try { + mem_resource_json[kMemType] = RT_MEMORY_HBM; + mem_resource_json[kTotalSize] = total_size; + mem_resource_json[kVarMemSize] = var_mem_size; + json.emplace_back(move(mem_resource_json)); + } catch (const std::exception &e) { + GELOGW("Fail to trans MemResourceMap to json. Error message: %s", e.what()); + return INTERNAL_ERROR; + } + return SUCCESS; +} + +Status ModelCacheHelper::GetVarAddrMgrMapJson(Json &json) const { + if (!(json.is_null() || json.is_array())) { + GELOGW("Input param json type should be null or array."); + return PARAM_INVALID; + } + std::unordered_map var_addr_mgr_map; + VarManager::Instance(session_id_)->GetAllVarAddrMgr(var_addr_mgr_map); + try { + for (const auto &iter : var_addr_mgr_map) { + Json var_addr_json; + string name; + GetVarNameFromVarKey(iter.first, iter.second.tensor_desc, name); + var_addr_json[kName] = name; + var_addr_json[kAddress] = static_cast(reinterpret_cast(iter.second.address)); + var_addr_json[kMemoryType] = iter.second.memory_type; + var_addr_json[kOffset] = iter.second.offset; + + // Copy tensor desc to json. + Json tensor_desc_json; + auto ret = TensorDescToJson(iter.second.tensor_desc, tensor_desc_json); + if (ret != SUCCESS) { + GELOGW("Fail to trans tensor desc to json."); + return INTERNAL_ERROR; + } + var_addr_json[kTensorDesc] = move(tensor_desc_json); + + json.emplace_back(move(var_addr_json)); + } + } catch (const std::exception &e) { + GELOGW("Fail to trans VarAddrMgrMap to json. 
Error message: %s", e.what()); + return INTERNAL_ERROR; + } + return SUCCESS; +} + +Status ModelCacheHelper::GetCurVarTensorDescMapJson(Json &json) const { + if (!(json.is_null() || json.is_array())) { + GELOGW("Input param json type should be null or array."); + return PARAM_INVALID; + } + try { + for (const auto &name : var_names_) { + Json cur_tensor_desc_json; + GeTensorDesc tensor_desc; + auto ret = VarManager::Instance(session_id_)->GetCurVarDesc(name, tensor_desc); + if (ret != SUCCESS) { + GELOGI("Get variable[%s] current tensor desc failed. It will be skipped.", name.c_str()); + continue; + } + cur_tensor_desc_json[kName] = name; + + Json tensor_desc_json; + ret = TensorDescToJson(tensor_desc, tensor_desc_json); + if (ret != SUCCESS) { + GELOGW("Fail to trans tensor desc to json."); + return INTERNAL_ERROR; + } + cur_tensor_desc_json[kTensorDesc] = move(tensor_desc_json); + json.emplace_back(move(cur_tensor_desc_json)); + } + } catch (const std::exception &e) { + GELOGW("Fail to trans CurVarTensorDescMap to json. 
Error message: %s", e.what()); + return INTERNAL_ERROR; + } + return SUCCESS; +} + +Status ModelCacheHelper::GetTransRoadsJson(Json &json) const { + if (!(json.is_null() || json.is_array())) { + GELOGW("Input param json type should be null or array."); + return PARAM_INVALID; + } + try { + for (const auto &name : var_names_) { + auto trans_road = VarManager::Instance(session_id_)->GetTransRoad(name); + if (trans_road == nullptr) { + continue; + } + // Json object, variable name and trans road + Json trans_road_map_json; + trans_road_map_json[kName] = name; + + Json trans_road_json; + Status ret; + // Add nodes' info to json + for (const auto &trans_node_info : *trans_road) { + Json trans_node_info_json; + trans_node_info_json[kNodeType] = trans_node_info.node_type; + Json input_tensor_desc_json; + ret = TensorDescToJson(trans_node_info.input, input_tensor_desc_json); + if (ret != SUCCESS) { + GELOGW("Fail to trans tensor desc to json."); + return INTERNAL_ERROR; + } + trans_node_info_json[kInputTensorDesc] = move(input_tensor_desc_json); + Json output_tensor_desc_json; + ret = TensorDescToJson(trans_node_info.output, output_tensor_desc_json); + if (ret != SUCCESS) { + GELOGW("Fail to trans tensor desc to json."); + return INTERNAL_ERROR; + } + trans_node_info_json[kOutputTensorDesc] = move(output_tensor_desc_json); + trans_road_json.emplace_back(move(trans_node_info_json)); + } + trans_road_map_json[kTransRoad] = move(trans_road_json); + json.emplace_back(move(trans_road_map_json)); + } + } catch (const std::exception &e) { + GELOGW("Fail to trans VarToTransRoad to json. 
Error message: %s", e.what()); + return INTERNAL_ERROR; + } + return SUCCESS; +} + +Status ModelCacheHelper::GetChangedGraphIdJson(Json &json) const { + if (!(json.is_null() || json.is_array())) { + GELOGW("Input param json type should be null or array."); + return PARAM_INVALID; + } + for (const auto &name : var_names_) { + uint32_t changed_graph_id = 0; + Status ret = VarManager::Instance(session_id_)->GetChangedGraphId(name, changed_graph_id); + if (ret != SUCCESS) { + continue; + } + Json name_and_changed_graph_id; + try { + name_and_changed_graph_id[kName] = name; + name_and_changed_graph_id[kGraphId] = changed_graph_id; + json.emplace_back(move(name_and_changed_graph_id)); + } catch (const std::exception &e) { + GELOGW("Fail to trans ChangedGraphId to json. Error message: %s", e.what()); + return INTERNAL_ERROR; + } + } + return SUCCESS; +} + +Status ModelCacheHelper::GetAllocatedGraphIdJson(Json &json) const { + if (!(json.is_null() || json.is_array())) { + GELOGW("Input param json type should be null or array."); + return PARAM_INVALID; + } + for (const auto &name : var_names_) { + uint32_t allocated_graph_id = 0; + Status ret = VarManager::Instance(session_id_)->GetAllocatedGraphId(name, allocated_graph_id); + if (ret != SUCCESS) { + continue; + } + Json name_and_allocated_graph_id; + try { + name_and_allocated_graph_id[kName] = name; + name_and_allocated_graph_id[kGraphId] = allocated_graph_id; + json.emplace_back(move(name_and_allocated_graph_id)); + } catch (const std::exception &e) { + GELOGW("Fail to trans AllocatedGraphId to json. 
Error message: %s", e.what()); + return INTERNAL_ERROR; + } + } + return SUCCESS; +} + +Status ModelCacheHelper::GetBroadcastInfoJson(Json &json) const { + if (!(json.is_null() || json.is_array())) { + GELOGW("Input param json type should be null or array."); + return PARAM_INVALID; + } + for (const auto &name : var_names_) { + VarBroadCastInfo var_broadcast_info; + Status ret = VarManager::Instance(session_id_)->GetBroadCastInfo(graph_id_, name, var_broadcast_info); + if (ret != SUCCESS) { + continue; + } + Json var_broadcast_info_json; + try { + var_broadcast_info_json[kName] = name; + var_broadcast_info_json[kBroadcastName] = var_broadcast_info.broadcast_name; + var_broadcast_info_json[kIdx] = var_broadcast_info.idx; + var_broadcast_info_json[kInputOffset] = var_broadcast_info.input_offset; + var_broadcast_info_json[kInputSize] = var_broadcast_info.input_size; + var_broadcast_info_json[kOutputOffset] = var_broadcast_info.output_offset; + var_broadcast_info_json[kOutputSize] = var_broadcast_info.output_size; + json.emplace_back(move(var_broadcast_info_json)); + } catch (const std::exception &e) { + GELOGW("Fail to trans VarBroadcastInfo to json. 
Error message: %s", e.what()); + return INTERNAL_ERROR; + } + } + return SUCCESS; +} + +Status ModelCacheHelper::GetVarResourceJson(Json &json) const { + if (!(json.is_null() || json.is_object())) { + GELOGW("Input param json type should be null or object."); + return PARAM_INVALID; + } + Json var_addr_mgr_map_json; + Status ret = GetVarAddrMgrMapJson(var_addr_mgr_map_json); + if (ret != SUCCESS) { + GELOGW("GetVarAddrMgrMapJson failed."); + return INTERNAL_ERROR; + } + + Json cur_var_tensor_desc_map_json; + ret = GetCurVarTensorDescMapJson(cur_var_tensor_desc_map_json); + if (ret != SUCCESS) { + GELOGW("GetCurVarTensorDescMapJson failed."); + return INTERNAL_ERROR; + } + + Json trans_roads_json; + ret = GetTransRoadsJson(trans_roads_json); + if (ret != SUCCESS) { + GELOGW("GetTransRoadsJson failed."); + return INTERNAL_ERROR; + } + + Json changed_graph_id_json; + ret = GetChangedGraphIdJson(changed_graph_id_json); + if (ret != SUCCESS) { + GELOGW("GetChangedGraphIdJson failed."); + return INTERNAL_ERROR; + } + + Json allocated_graph_id_json; + ret = GetAllocatedGraphIdJson(allocated_graph_id_json); + if (ret != SUCCESS) { + GELOGW("GetAllocatedGraphIdJson failed."); + return INTERNAL_ERROR; + } + + Json var_broadcast_info_json; + ret = GetBroadcastInfoJson(var_broadcast_info_json); + if (ret != SUCCESS) { + GELOGW("GetBroadcastInfoJson failed."); + return INTERNAL_ERROR; + } + + try { + json[kVarAddrMgrMap] = move(var_addr_mgr_map_json); + json[kCurVarTensorDescMap] = move(cur_var_tensor_desc_map_json); + json[kTransRoads] = move(trans_roads_json); + json[kChangedGraphId] = move(changed_graph_id_json); + json[kAllocatedGraphId] = move(allocated_graph_id_json); + json[kVarBroadcastInfo] = move(var_broadcast_info_json); + } catch (const exception &e) { + GELOGW("Fail to generate VarResource json. 
Error message: %s", e.what()); + return INTERNAL_ERROR; + } + return SUCCESS; +} + +Status ModelCacheHelper::GetVarManagerJson(Json &json) const { + if (!(json.is_null() || json.is_object())) { + GELOGW("Input param json type should be null or object."); + return PARAM_INVALID; + } + + Json mem_resource_map_json; + auto ret = GetMemResourceMap(mem_resource_map_json); + if (ret != SUCCESS) { + GELOGW("GetMemResourceMap failed."); + return INTERNAL_ERROR; + } + + Json var_resource_json; + ret = GetVarResourceJson(var_resource_json); + if (ret != SUCCESS) { + GELOGW("GetVarResourceJson failed."); + return INTERNAL_ERROR; + } + + try { + json[kSessionId] = session_id_; + json[kDeviceId] = VarManager::Instance(session_id_)->DeviceId(); + json[kJobId] = VarManager::Instance(session_id_)->JobId(); + json[kGraphMemMaxSize] = VarManager::Instance(session_id_)->GetGraphMemoryMaxSize(); + json[kVarMemMaxSize] = VarManager::Instance(session_id_)->GetVarMemMaxSize(); + json[kVarMemLogicBase] = VarManager::Instance(session_id_)->GetVarMemLogicBase(); + json[kUseMaxMemSize] = VarManager::Instance(session_id_)->GetUseMaxMemorySize(); + json[kMemResourceMap] = move(mem_resource_map_json); + json[kVarResource] = move(var_resource_json); + } catch (const exception &e) { + GELOGW("Fail to generate VarManager json. Error message: %s", e.what()); + return INTERNAL_ERROR; + } + return SUCCESS; +} + +Status ModelCacheHelper::SaveVarManagerToCache(bool before_build) const { + if (!is_cache_path_valid_for_output) { + GELOGW("Invalid cache path."); + return FAILED; + } + Json var_manager_json; + auto ret = GetVarManagerJson(var_manager_json); + if (ret != SUCCESS) { + GELOGW("Fail to generate VarManager json."); + return FAILED; + } + string var_manager_path = to_string(graph_id_) + "_" + to_string(graph_id_run_times_[graph_id_]) + + (before_build ? 
kBeforeVarManagerSuffix : kAfterVarManagerSuffix); + ret = SaveJsonToFile(var_manager_path, var_manager_json); + if (ret != SUCCESS) { + GELOGW("Fail to save VarManager info to json file, path: %s.", cache_path_.c_str()); + return ret; + } + return SUCCESS; +} + +Status ModelCacheHelper::SaveOmModelToCache(const GeModelPtr &ge_model) const { + if (!is_cache_path_valid_for_output) { + GELOGW("Invalid cache path."); + return FAILED; + } + string om_path = RealPath(cache_path_.c_str()); + if (om_path.empty()) { + GELOGW("file path is invalid. please check path om: %s", cache_path_.c_str()); + return FAILED; + } + string cache_om_path = cache_path_; + cache_om_path += (to_string(graph_id_) + "_" + to_string(graph_id_run_times_[graph_id_]) + kOmSuffix); + GELOGI("SaveOmModelToCache: start to save om model : %s", cache_om_path.c_str()); + ModelHelper model_helper; + SaveParam save_param; + ModelBufferData model; + Status ret = model_helper.SaveToOmModel(ge_model, save_param, cache_om_path, model); + if (ret != SUCCESS) { + GELOGW("SaveOmModelToCache: save mode failed. ret = %u", ret); + return ret; + } + return SUCCESS; +} + +Status ModelCacheHelper::ParseMemResourceFromJson(const Json &json, map &mem_resource) { + if (!(json.is_array() || json.is_null())) { + GELOGW("Input param json type should be null or array."); + return PARAM_INVALID; + } + mem_resource.clear(); + for (const Json &mem_resource_json : json) { + MemResource var_addr_mgr; + try { + rtMemType_t mem_type = mem_resource_json[kMemType].get(); + uint64_t var_mem_size = mem_resource_json[kVarMemSize].get(); + mem_resource[mem_type] = var_mem_size; + } catch (const exception &e) { + GELOGW("Fail to trans Json to MemResource. 
Error message: %s", e.what()); + return INTERNAL_ERROR; + } + } + return SUCCESS; +} + +Status ModelCacheHelper::ParseVarAddrMgrMapFromJson( + const Json &json, std::vector> &var_addr_mgr_vector, + std::unordered_set &var_offset_set) { + if (!(json.is_array() || json.is_null())) { + GELOGW("Input param json type should be null or array."); + return PARAM_INVALID; + } + var_addr_mgr_vector.clear(); + var_offset_set.clear(); + for (const Json &var_addr_json : json) { + VarAddrMgr var_addr_mgr; + try { + auto logic_address = var_addr_json[kAddress].get(); + auto address = reinterpret_cast(reinterpret_cast(logic_address)); + var_addr_mgr.address = address; + var_addr_mgr.offset = var_addr_json[kOffset].get(); + var_addr_mgr.memory_type = var_addr_json[kMemoryType].get(); + auto ret = JsonToTensorDesc(var_addr_json[kTensorDesc], var_addr_mgr.tensor_desc); + if (ret != SUCCESS) { + GELOGW("Fail to trans json to tensor desc."); + return ret; + } + var_addr_mgr_vector.emplace_back(var_addr_json[kName].get(), move(var_addr_mgr)); + var_offset_set.insert(logic_address); + } catch (const exception &e) { + GELOGW("Fail to trans Json to VarAddrMgr. Error message: %s", e.what()); + return INTERNAL_ERROR; + } + } + return SUCCESS; +} + +Status ModelCacheHelper::ParseCurVarTensorDescMapFromJson( + const Json &json, std::unordered_map &cur_var_tensor_desc_map) { + if (!(json.is_array() || json.is_null())) { + GELOGW("Input param json type should be null or array."); + return PARAM_INVALID; + } + cur_var_tensor_desc_map.clear(); + for (const Json &tensor_desc_json : json) { + GeTensorDesc tensor_desc; + try { + auto ret = JsonToTensorDesc(tensor_desc_json[kTensorDesc], tensor_desc); + if (ret != SUCCESS) { + GELOGW("Fail to trans json to tensor desc."); + return ret; + } + cur_var_tensor_desc_map[tensor_desc_json[kName].get()] = move(tensor_desc); + } catch (const exception &e) { + GELOGW("Fail to trans Json to VarAddrMgr. 
Error message: %s", e.what()); + return INTERNAL_ERROR; + } + } + return SUCCESS; +} + +Status ModelCacheHelper::ParseTransRoadsFromJson( + const Json &json, std::unordered_map> &trans_roads) { + if (!(json.is_array() || json.is_null())) { + GELOGW("Input param json type should be null or array."); + return PARAM_INVALID; + } + trans_roads.clear(); + try { + for (const Json &name_trans_road_json : json) { + const Json &trans_road_json = name_trans_road_json[kTransRoad]; + if (!(trans_road_json.is_array() || trans_road_json.is_null())) { + GELOGW("%s json type should be null or object.", kTransRoad); + return PARAM_INVALID; + } + vector trans_road; + for (const Json &trans_node_json : trans_road_json) { + TransNodeInfo trans_node_info; + trans_node_info.node_type = trans_node_json[kNodeType]; + GeTensorDesc input_tensor_desc; + auto ret = JsonToTensorDesc(trans_node_json[kInputTensorDesc], input_tensor_desc); + if (ret != SUCCESS) { + GELOGW("Fail to trans json to tensor desc."); + return ret; + } + trans_node_info.input = move(input_tensor_desc); + GeTensorDesc output_tensor_desc; + ret = JsonToTensorDesc(trans_node_json[kOutputTensorDesc], output_tensor_desc); + if (ret != SUCCESS) { + GELOGW("Fail to trans json to tensor desc."); + return ret; + } + trans_node_info.output = move(output_tensor_desc); + trans_road.emplace_back(move(trans_node_info)); + } + trans_roads[name_trans_road_json[kName].get()] = move(trans_road); + } + } catch (const exception &e) { + GELOGW("Fail to trans Json to TransRoads. 
Error message: %s", e.what()); + return INTERNAL_ERROR; + } + return SUCCESS; +} + +Status ModelCacheHelper::ParseChangedGraphIdFromJson(const Json &json, + std::unordered_map &changed_graph_id) { + if (!(json.is_array() || json.is_null())) { + GELOGW("Input param json type should be null or array."); + return PARAM_INVALID; + } + changed_graph_id.clear(); + for (const Json &name_graph_id_json : json) { + try { + changed_graph_id[name_graph_id_json[kName].get()] = name_graph_id_json[kGraphId].get(); + } catch (const exception &e) { + GELOGW("Fail to trans Json to changed graph id. Error message: %s", e.what()); + return INTERNAL_ERROR; + } + } + return SUCCESS; +} + +Status ModelCacheHelper::ParseAllocatedGraphIdFromJson(const Json &json, + std::unordered_map &allocated_graph_id) { + if (!(json.is_array() || json.is_null())) { + GELOGW("Input param json type should be null or array."); + return PARAM_INVALID; + } + allocated_graph_id.clear(); + for (const Json &name_graph_id_json : json) { + try { + allocated_graph_id[name_graph_id_json[kName].get()] = name_graph_id_json[kGraphId].get(); + } catch (const exception &e) { + GELOGW("Fail to trans Json to allocated graph id. 
Error message: %s", e.what()); + return INTERNAL_ERROR; + } + } + return SUCCESS; +} + +Status ModelCacheHelper::ParseBroadcastInfoFromJson( + const Json &json, std::unordered_map &var_broadcast_info) { + if (!(json.is_array() || json.is_null())) { + GELOGW("Input param json type should be null or array."); + return PARAM_INVALID; + } + for (const Json &broadcast_info_json : json) { + VarBroadCastInfo broadcast_info; + try { + broadcast_info.var_name = broadcast_info_json[kName].get(); + broadcast_info.broadcast_name = broadcast_info_json[kBroadcastName].get(); + broadcast_info.idx = broadcast_info_json[kIdx].get(); + broadcast_info.input_offset = broadcast_info_json[kInputOffset].get(); + broadcast_info.input_size = broadcast_info_json[kInputSize].get(); + broadcast_info.output_offset = broadcast_info_json[kOutputOffset].get(); + broadcast_info.output_size = broadcast_info_json[kOutputSize].get(); + } catch (const exception &e) { + GELOGW("Fail to trans Json to VarBroadCastInfo. Error message: %s", e.what()); + return INTERNAL_ERROR; + } + var_broadcast_info[broadcast_info.var_name] = broadcast_info; + } + return SUCCESS; +} + +Status ModelCacheHelper::LoadOmModelFromCache(GeModelPtr &ge_model) const { + string cache_om = cache_path_ + to_string(graph_id_) + "_" + to_string(graph_id_run_times_[graph_id_]) + kOmSuffix; + if (!CheckInputPathValid(cache_om)) { + GELOGW("Invalid cache path for input:%s.", cache_om.c_str()); + return FAILED; + } + string om_path = RealPath(cache_om.c_str()); + if (om_path.empty()) { + GELOGW("file path is invalid. please check file om: %s", om_path.c_str()); + return FAILED; + } + GELOGI("load model data from file: %s", om_path.c_str()); + Status ret; + string key_path; + int32_t priority = 0; + ModelData model_data; + ret = DavinciModelParser::LoadFromFile(om_path.c_str(), key_path.c_str(), priority, model_data); + if (ret != SUCCESS) { + GELOGW("LoadOmModelFromCache: Load model from file fialed. 
ret = %u", ret); + return ret; + } + + ModelHelper model_helper; + ret = model_helper.LoadModel(model_data); + if (ret != SUCCESS) { + GELOGW("LoadOmModelFromCache: Load model from data failed. ret = %u", ret); + return ret; + } + ge_model = model_helper.GetGeModel(); + ret = RecompileNodes(ge_model); + if (ret != SUCCESS) { + GELOGW("LoadOmModelFromCache: recompile nodes failed. ret = %u", ret); + return ret; + } + return SUCCESS; +} + +Status ModelCacheHelper::GetVarNameFromVarKey(const string &var_key, const GeTensorDesc &tensor_desc, + string &var_name) { + std::string::size_type underline_idx = var_key.rfind('_'); + if (underline_idx == std::string::npos) { + GELOGW("Invalid var key: underline not found"); + return FAILED; + } + std::string::size_type format_idx = + var_key.rfind(std::to_string(static_cast(tensor_desc.GetFormat())), underline_idx); + if (format_idx == std::string::npos) { + GELOGW("Invalid var key: format not found"); + return FAILED; + } + var_name = var_key.substr(0, format_idx); + return SUCCESS; +} +} // namespace ge diff --git a/src/ge/common/helper/model_cache_helper.h b/src/ge/common/helper/model_cache_helper.h new file mode 100644 index 00000000..7524b224 --- /dev/null +++ b/src/ge/common/helper/model_cache_helper.h @@ -0,0 +1,123 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef GE_COMMON_HELPER_MODEL_CACHE_HELPER_H_ +#define GE_COMMON_HELPER_MODEL_CACHE_HELPER_H_ + +#include +#include +#include + +#include "ge/ge_api_error_codes.h" +#include "graph/compute_graph.h" +#include "graph/manager/graph_var_manager.h" +#include "model/ge_model.h" + +namespace ge { +using Json = nlohmann::json; + +struct CacheInfo { + size_t node_num; + size_t edge_num; + size_t graph_hash; + map nodes_hash; + CacheInfo() : node_num(0), edge_num(0), graph_hash(0) {} +}; + +class ModelCacheHelper { + public: + ModelCacheHelper(uint64_t session_id, uint32_t graph_id, ComputeGraphPtr &compute_graph); + ~ModelCacheHelper(); + + Status SaveCacheInfoToCache() const; + Status SaveVarManagerToCache(bool before_build) const; + Status SaveOmModelToCache(const GeModelPtr &ge_model) const; + bool IsModelCacheHit() const; + Status RecoverVarManagerFromCache() const; + Status LoadOmModelFromCache(GeModelPtr &ge_model) const; + Status RefreshComputeGraph(const ComputeGraphPtr &compute_graph); + Status ClearCache(uint32_t graph_id) const; + + private: + Status GetComputeGraphHash(size_t &hash) const; + Status GetNodesHash(map &hash_map) const; + Status GetCacheInfo(CacheInfo &cache_info) const; + + Status RecoverMemResource(const Json &json) const; + Status RecoverAllocatedGraphId(const Json &json) const; + Status RecoverChangedGraphId(const Json &json) const; + Status RecoverVarAddrAndTensorDesc(const Json &json) const; + Status RecoverBroadcastInfo(const Json &json) const; + Status RecoverTransRoads(const Json &json) const; + static Status GetNodesNeedRecompile(ComputeGraphPtr &graph, vector &nodes); + static Status RecompileNodes(GeModelPtr &ge_model); + + bool IsNodeHashSameAsCache(const map &hash_map) const; + bool IsMemResourceSameAsCache(Json &json) const; + bool IsChangedGraphIdSameAsCache(Json &json) const; + bool IsAllocatedGraphIdSameAsCache(Json &json) const; + bool IsCurVarTensorDescSameAsCache(Json &json) const; + bool IsVarAddrMgrMapSameAsCache(Json 
&json) const; + bool IsBroadcastInfoSameAsCache(Json &json) const; + bool IsTransRoadsSameAsCache(Json &json) const; + bool IsVarManagerSameAsCache(Json &json) const; + bool IsVarManagerParamSameAsCache(Json &json) const; + + Status SaveJsonToFile(const string &file_name, const Json &json) const; + Status LoadJsonFromFile(const string &file_name, Json &json) const; + + Status GetNodesHashMapJson(Json &json) const; + Status GetMemResourceMap(Json &json) const; + Status GetVarAddrMgrMapJson(Json &json) const; + Status GetCurVarTensorDescMapJson(Json &json) const; + Status GetTransRoadsJson(Json &json) const; + Status GetChangedGraphIdJson(Json &json) const; + Status GetAllocatedGraphIdJson(Json &json) const; + Status GetBroadcastInfoJson(Json &json) const; + Status GetVarResourceJson(Json &json) const; + Status GetVarManagerJson(Json &json) const; + + static Status TensorDescToJson(const GeTensorDesc &ge_tensor_desc, Json &json); + static Status JsonToTensorDesc(const Json &json, GeTensorDesc &ge_tensor_desc); + static Status ParseMemResourceFromJson(const Json &json, map &mem_resource); + static Status ParseVarAddrMgrMapFromJson(const Json &json, + std::vector> &var_addr_mgr_vector, + std::unordered_set &var_offset_set); + static Status ParseCurVarTensorDescMapFromJson( + const Json &json, std::unordered_map &cur_var_tensor_desc_map); + static Status ParseTransRoadsFromJson(const Json &json, + std::unordered_map> &trans_roads); + static Status ParseChangedGraphIdFromJson(const Json &json, + std::unordered_map &changed_graph_id); + static Status ParseAllocatedGraphIdFromJson(const Json &json, + std::unordered_map &allocated_graph_id); + static Status ParseBroadcastInfoFromJson(const Json &json, + std::unordered_map &var_broadcast_info); + static Status GetVarNameFromVarKey(const string &var_key, const GeTensorDesc &tensor_desc, string &var_name); + + uint64_t session_id_; + uint32_t graph_id_; + string cache_path_; + ComputeGraphPtr compute_graph_; + std::set 
var_names_; + bool is_cache_path_valid_for_output; + static map graph_id_run_times_; +}; + +using ModelCacheHelperPtr = std::shared_ptr; +} // namespace ge + +#endif // GE_COMMON_HELPER_MODEL_CACHE_HELPER_H_ diff --git a/src/ge/common/helper/model_helper.cc b/src/ge/common/helper/model_helper.cc index 29b3ff7d..194ea59f 100644 --- a/src/ge/common/helper/model_helper.cc +++ b/src/ge/common/helper/model_helper.cc @@ -26,15 +26,14 @@ #include "graph/utils/attr_utils.h" #include "graph/utils/graph_utils.h" -using ge::ModelBufferData; -using ge::TBEKernelPtr; -using ge::TBEKernelStore; +using domi::ModelTaskDef; using std::string; + namespace { const int64_t kOriginalOmPartitionNum = 1; } -namespace domi { +namespace ge { FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelHelper::~ModelHelper() { (void)ReleaseLocalModelData(); } Status ModelHelper::SaveModelPartition(std::shared_ptr &om_file_save_helper, ModelPartitionType type, @@ -80,7 +79,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::SaveToOmMod model_tmp->SetAttr(ge_model->MutableAttrMap()); ge::Buffer model_buffer; - model_tmp->Save(model_buffer); + (void)model_tmp->Save(model_buffer); GELOGI("MODEL_DEF size is %zu", model_buffer.GetSize()); if (model_buffer.GetSize() > 0) { if (SaveModelPartition(om_file_save_helper, ModelPartitionType::MODEL_DEF, model_buffer.GetData(), @@ -506,4 +505,4 @@ Status ModelHelper::ReleaseLocalModelData() noexcept { } return result; } -} // namespace domi +} // namespace ge diff --git a/src/ge/common/helper/om_file_helper.cc b/src/ge/common/helper/om_file_helper.cc index 3f2fc833..917807f0 100644 --- a/src/ge/common/helper/om_file_helper.cc +++ b/src/ge/common/helper/om_file_helper.cc @@ -25,11 +25,9 @@ #include "framework/common/ge_inner_error_codes.h" #include "framework/common/util.h" -using ge::FileSaver; -using ge::ModelBufferData; using std::string; -namespace domi { +namespace ge { // For Load FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status 
OmFileLoadHelper::Init(const ge::ModelData &model) { if (CheckModelValid(model) != SUCCESS) { @@ -226,4 +224,4 @@ Status OmFileSaveHelper::SaveModelToFile(const char *output_file, ModelBufferDat return SUCCESS; #endif } -} // namespace domi +} // namespace ge diff --git a/src/ge/common/math/math_util.h b/src/ge/common/math/math_util.h index 56148240..08088eb1 100644 --- a/src/ge/common/math/math_util.h +++ b/src/ge/common/math/math_util.h @@ -26,7 +26,6 @@ #include "framework/common/debug/log.h" #include "framework/common/fmk_error_codes.h" -using ge::fp16_t; namespace ge { /// @ingroup math_util /// @brief check whether int32 addition can result in overflow diff --git a/src/ge/common/math_util.h b/src/ge/common/math_util.h index 87364a2b..5e783e81 100644 --- a/src/ge/common/math_util.h +++ b/src/ge/common/math_util.h @@ -26,7 +26,7 @@ #include "framework/common/util.h" #include "mmpa/mmpa_api.h" -namespace domi { +namespace ge { /** * @ingroup domi_calibration @@ -68,6 +68,6 @@ Status NnSet(const int32_t n, const Dtype alpha, Dtype *output) { return SUCCESS; } -} // end namespace domi +} // end namespace ge #endif // GE_COMMON_MATH_UTIL_H_ diff --git a/src/ge/common/model_parser/base.cc b/src/ge/common/model_parser/base.cc index 8485d799..a9a21ec5 100644 --- a/src/ge/common/model_parser/base.cc +++ b/src/ge/common/model_parser/base.cc @@ -22,15 +22,9 @@ #include #include +#include "framework/common/debug/ge_log.h" #include "framework/common/debug/log.h" #include "framework/common/util.h" -#include "framework/common/debug/ge_log.h" - -using domi::GetFileLength; -using domi::MODEL_FILE_MAGIC_NUM; -using domi::ModelEncryptType; -using domi::ModelFileHeader; -using domi::RealPath; namespace ge { FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelParserBase::ModelParserBase() {} diff --git a/src/ge/common/model_saver.cc b/src/ge/common/model_saver.cc index 424d2f1c..f68051f4 100644 --- a/src/ge/common/model_saver.cc +++ b/src/ge/common/model_saver.cc @@ -63,7 +63,7 
@@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelSaver::SaveJsonToFi const char *model_char = model_str.c_str(); uint32_t len = static_cast(model_str.length()); // Write data to file - int32_t mmpa_ret = mmWrite(fd, const_cast((const void *)model_char), len); + mmSsize_t mmpa_ret = mmWrite(fd, const_cast((const void *)model_char), len); if (mmpa_ret == EN_ERROR || mmpa_ret == EN_INVALID_PARAM) { // Need to both print the error info of mmWrite and mmClose, so return ret after mmClose GELOGE(FAILED, "Write to file failed. errno = %d, %s", mmpa_ret, strerror(errno)); diff --git a/src/ge/common/op/attr_define.cc b/src/ge/common/op/attr_define.cc deleted file mode 100644 index f9929a5e..00000000 --- a/src/ge/common/op/attr_define.cc +++ /dev/null @@ -1,814 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "framework/common/op/attr_define.h" - -namespace domi { -/** - * Public attribute - */ -const std::string ATTR_NAME_NAME = "name"; - -const std::string ATTR_NAME_TYPE = "type"; - -const std::string ATTR_NAME_WEIGHT_NAME = "weight_name"; - -const std::string ATTR_NAME_IS_QUANTIZE_FACTOR = "quantize_factor"; - -const std::string ATTR_NAME_ALPHA = "alpha"; - -const std::string ATTR_NAME_BETA = "beta"; - -const std::string ATTR_NAME_PADMODE = "pad_mode"; - -const std::string ATTR_NAME_PADMODES = "padding"; - -const std::string ATTR_NAME_MODE = "mode"; - -const std::string ATTR_NAME_FILTER = "filter"; - -const std::string ATTR_NAME_BIAS = "bias"; - -const std::string ATTR_NAME_BIAS_TERM = "bias_term"; - -const std::string ATTR_NAME_HAS_BIAS_VALUE = "has_bias_value"; - -const std::string ATTR_NAME_PAD = "pad"; - -const std::string ATTR_NAME_PADS = "pad"; - -const std::string ATTR_NAME_PAD_SIZE = "pad size"; - -const std::string ATTR_NAME_PAD_MODE = "pad mode"; - -const std::string ATTR_NAME_SCALE = "scale"; - -const std::string ATTR_NAME_WINDOWS = "windows"; - -const std::string ATTR_NAME_GLOBAL_POOLING = "global_pooling"; - -const std::string ATTR_NAME_CEIL_MODE = "ceil_mode"; - -const std::string ATTR_NAME_STRIDE_SIZE = "stride size"; - -const std::string ATTR_NAME_RELU_FLAG = "relu_flag"; - -const std::string ATTR_NAME_ALGO = "algo"; - -const std::string ATTR_NAME_FORMAT = "format"; - -const std::string ATTR_NAME_FILTER_FORMAT = "filter_format"; - -const std::string ATTR_NAME_LRN_K = "lrn_k"; - -const std::string ATTR_NAME_LRN_NORM_REGION = "lrn_normregion"; - -const std::string ATTR_NAME_LRN_LOCAL_SIZE = "lrn_localsize"; - -const std::string ATTR_NAME_LRN_ALPHA = "lrn_alpha"; - -const std::string ATTR_NAME_LRN_BETA = "lrn_beta"; - -const std::string ATTR_NAME_AXIS = "axis"; -const std::string ATTR_NAME_BROADCAST = "broadcast"; - -const std::string ATTR_NAME_OUTPUT = "output"; -const std::string ATTR_NAME_OUTPUT_NUM = "output_num"; -const std::string 
ATTR_NAME_TIDX = "t_idx"; - -const std::string ATTR_NAME_TPADDINGS = "t_paddings"; -const std::string ATTR_IMG_H = "img_h"; -const std::string ATTR_IMG_W = "img_w"; -const std::string ATTR_NET_H = "net_h"; -const std::string ATTR_NET_W = "net_w"; - -const std::string ATTR_NAME_TMULTIPLES = "t_multiples"; - -const std::string ATTR_NAME_MULTIPLES = "multiples"; - -const std::string ATTR_NAME_T = "T"; -const std::string ATTR_NAME_N = "N"; - -const std::string ATTR_NAME_TSHAPE = "Tshape"; -const std::string ATTR_NAME_NAN_OPT = "nan_opt"; - -const std::string ATTR_NAME_AIPP = "aipp"; -const std::string NEW_AIPP_CONV_OP = "new_conv_op_for_aipp"; - -const std::string ATTR_NAME_SESSION_GRAPH_ID = "_session_graph_id"; - -const std::string ATTR_NAME_MULTISHAPE_BATCHLIST = "multi_shape_batchlist"; -const std::string ATTR_NAME_MULTISHAPE_BATCHLIST_SIZE = "multi_shape_batchlist_size"; -const std::string ATTR_MODEL_BATCH_NUM = "batch_num"; - -const std::string ATTR_NAME_INPUT_FORMAT = "input_format"; -const std::string ATTR_NAME_OUTPUT_FORMAT = "output_format"; - -const std::string ATTR_NAME_FRAMEWORK_NODE_DEF = "node_def"; -const std::string ATTR_NAME_FRAMEWORK_OP_DEF = "op_def"; -const std::string ATTR_NAME_FRAMEWORK_FWK_TYPE = "framework_type"; -const std::string ATTR_NAME_FRAMEWORK_FUNC_DEF = "func_def"; -const std::string ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE = "original_type"; - -const std::string ATTR_NAME_INPUT_TENSOR_DESC = "input_tensor_desc"; -const std::string ATTR_NAME_OUTPUT_TENSOR_DESC = "output_tensor_desc"; - -const std::string ATTR_NAME_INFERRED_FORMAT = "inferred_format"; -const std::string ATTR_NAME_PRED_PERMUTE_DELETED = "pred_permute_deleted"; -const std::string ATTR_NAME_IGNORE_PRED_FORMAT = "ignore_pred_format"; -const std::string ATTR_NAME_WEIGHTS = "value"; -const std::string ATTR_NAME_BROACAST_REAL_DIM_CNT = "broacast_real_dim_cnt"; -const std::string ATTR_NAME_DIM_ALIGN = "dim_align"; - -const std::string ATTR_NAME_STREAM_LABEL = "_stream_label"; -const 
std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG = "need_stream_cycle_event"; -const std::string ATTR_NAME_RTSWITCH_RECV_EVENT_ID = "rtswitch_event_id"; - -const std::string ATTR_NAME_AUTOMIC_ADD_START = "automic_add_addr_start"; -const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE = "automic_add_mem_size"; - -/* to be deleted*/ -const std::string ATTR_TO_BE_DELETED = "to_be_deleted"; -const std::string PERMUTE_RESHAPE_FUSION = "permute_reshape_fusion"; -const std::string PERMUTE_RESHAPE_FUSION_CONV_PROPOSAL = "fusion_conv_proposal"; -const std::string PERMUTE_RESHAPE_FUSION_CONV_DECODEBBOX = "fusion_conv_decodebbox"; -const std::string PERMUTE_RESHAPE_FUSION_BOX_TYPE_NUM = "box_type_num"; -const std::string SSD_MBOX_LOC_FUSION = "permute_flatten_fusion"; -const std::string SSD_MBOX_CONF_FUSION = "permute_flatten_reshape_flatten_fusion"; -const std::string SSD_MBOX_OCR_FUSION = "permute_flatten_ocr_fusion"; -const std::string SSD_MBOX_FUSION_BOX_TYPE_NUM = "ssd_mbox_fusion_box_type_num"; -const std::string SSD_RESHAPE_SLICE_CONCAT_FUSION = "reshape_slice_concat_fusion"; - -/* refinedet */ -const std::string REFINEDET_MBOX_LOC_FUSION = "permute_flatten_fusion"; -const std::string REFINEDET_MBOX_CONF_FUSION = "permute_flatten_reshape_flatten_fusion"; -const std::string REFINEDET_MBOX_FUSION_BOX_TYPE_NUM = "ssd_mbox_fusion_box_type_num"; -const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION = "reshape_slice_concat_fusion"; - -const std::string SSD_PRIORBOX_CONCAT = "ssd_mbox_conf_priorbox_concat_flag"; - -/* _Arg */ -const std::string ATTR_NAME_INDEX = "index"; -/* _RetVal */ -const std::string RETVAL_ATTR_NAME_INDEX = "retval_index"; -/*Data*/ -const std::string DATA_ATTR_NAME_DATA_TYPE = "data_type"; - -/*Send*/ -const std::string SEND_ATTR_EVENT_ID = "event_id"; - -/*Recv*/ -const std::string RECV_ATTR_EVENT_ID = "event_id"; - -/** - * convolution - */ -const std::string ATTR_NAME_COEF = "coef"; - -const std::string ATTR_NAME_STRIDE = "stride"; - -const std::string 
ATTR_NAME_STRIDES = "stride"; - -const std::string ATTR_NAME_DILATION = "dilation"; - -const std::string ATTR_NAME_DILATIONS = "dilation"; - -const std::string CONV_ATTR_NAME_MODE = "mode"; - -const std::string CONV_ATTR_NAME_ALGO = "algo"; - -const std::string CONV_ATTR_NAME_GROUP = "group"; - -const std::string CONV_ATTR_NAME_PAD_MODE = "pad_mode"; - -const std::string CONV_ATTR_NAME_PAD = "pad"; - -const std::string CONV_ATTR_NAME_STRIDE = "stride"; - -const std::string CONV_ATTR_NAME_DILATION = "dilation"; - -const std::string CONV_ATTR_NAME_NUM_OUTPUT = "num_output"; - -const std::string CONV_ATTR_NAME_KERNEL = "kernel"; - -const std::string CONV_ATTR_NAME_FILTER = "filter"; - -const std::string CONV_ATTR_NAME_BIAS = "bias"; - -const std::string CONV_ATTR_NAME_RELU_FLAG = "relu_flag"; - -const std::string CONV_ATTR_NAME_ADJ = "adj"; - -const std::string CONV_ATTR_NAME_TARGET_SHAPE = "target_shape"; - -const std::string CONV_ATTR_NAME_BEFORE_PAD = "before_pad"; - -const std::string CONV_ATTR_NAME_HAS_BIAS = "has_bias"; - -/*Pooling*/ -const std::string POOLING_ATTR_MODE = "mode"; -const std::string POOLING_ATTR_NAN_OPT = "nan_opt"; -const std::string POOLING_ATTR_PAD_MODE = "pad_mode"; -const std::string POOLING_ATTR_GLOBAL_POOLING = "global_pooling"; -const std::string POOLING_ATTR_WINDOW = "window"; -const std::string POOLING_ATTR_PAD = "pad"; -const std::string POOLING_ATTR_STRIDE = "stride"; -const std::string POOLING_ATTR_CEIL_MODE = "ceil_mode"; -const std::string POOLING_ATTR_DATA_MODE = "data_mode"; -const std::string POOLING_ATTR_BEFORE_PAD = "before_pad"; -const std::string POOLING_ATTR_NAME_ALGO = "algo"; - -/*Eltwise*/ -const std::string ELTWISE_ATTR_MODE = "mode"; -const std::string ELTWISE_ATTR_COEFF = "coeff"; -const std::string ELTWISE_ATTR_WEIGHT = "weight"; -const std::string ELTWISE_ATTR_RELU_FLAG = "relu_flag"; -const std::string ELTWISE_ATTR_ALPHA = "alpha"; -const std::string ELTWISE_ATTR_BETA = "beta"; - -/*BatchNorm*/ -const std::string 
BATCHNORM_ATTR_MODE = "mode"; -const std::string BATCHNORM_ATTR_EPSILON = "epsilon"; -const std::string BATCHNORM_ATTR_USE_GLOBAL_STATS = "use_global_stats"; -const std::string BATCHNORM_ATTR_MOVING_AVERAGE_FRACTION = "moving_average_fraction"; -const std::string BATCHNORM_ATTR_ESTIMATED_MEAN = "estimated_mean"; -const std::string BATCHNORM_ATTR_ESTIMATED_VARIANCE = "estimated_variance"; -const std::string BATCHNORM_ATTR_SCALE = "scale"; -const std::string BATCHNORM_ATTR_BIAS = "bias"; -const std::string BATCHNORM_ATTR_DATA_FORMAT = "data_format"; -const std::string BATCHNORM_ATTR_IS_TRAINING = "is_training"; -const std::string BATCHNORM_ATTR_IS_TRAINING_FUSION = "is_training_fusion"; - -/*huberloss*/ -const std::string HUBER_LOSS_ATTR_DELTA = "delta"; - -/*SSDRealDivTileMul*/ -const std::string SSD_REAL_DIV_TILE_MUL_ATTR_TILE_PARA = "tilepara"; - -/*SSDSumMulRealDivMean*/ -const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_REDUCTION_INDICES = "reduction_indices"; -const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_AXIS = "axis"; -const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_MEAN_PARA = "mean_para"; -const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_HAS_SUM = "has_sum"; - -/*ConcatFive2Four*/ -/*ConcatFour2Five*/ -const std::string SSD_BOX_TYPE_NUM = "box_type_num"; -const std::string SSD_CLASS_NUM = "class_num"; -const std::string TRANS_FOR_LOSS_MODE = "trans_for_loss_mode"; -const std::string SSD_FEATURE_MAP_SIZE = "feature_map_size"; -const std::string SSD_FEATURE_MAP_HIGH = "feature_map_high"; -const std::string SSD_FEATURE_MAP_WIDTH = "feature_map_width"; - -/*Scale*/ -const std::string SCALE_ATTR_SCALE = "scale"; -const std::string SCALE_ATTR_BIAS = "bias"; - -/*FullConnection*/ -const std::string FULL_CONNECTION_ATTR_FILTER = "filter"; -const std::string FULL_CONNECTION_ATTR_BIAS = "bias"; -const std::string FULL_CONNECTION_ATTR_NUM_OUTPUT = "num_output"; -const std::string FULL_CONNECTION_ATTR_RELU_FLAG = "relu_flag"; -const std::string FULL_ATTR_NAME_ALGO = 
"algo"; - -/*SoftmaxOpParams*/ -const std::string SOFTMAX_ATTR_ALGO = "algo"; -const std::string SOFTMAX_ATTR_MODE = "mode"; - -/*SparseSoftmaxCrossEntropy*/ -const std::string SPARSE_SOFTMAX_CROSS_ENTROPY_ATTR_MODE = "cross_entropy_mode"; -const std::string SPARSE_SOFTMAX_CROSS_ENTROPY_IS_GRAD = "cross_entropy_is_grad"; -const std::string SOFTMAX_CROSS_ENTROPY_LABELSMOOTHING = "labelSmoothing"; - -/*Activation*/ -const std::string ACTIVATION_ATTR_MODE = "mode"; -const std::string ACTIVATION_ATTR_COEF = "coef"; - -/* Concat */ -const std::string CONCAT_ATTR_NAME_AXIS = "axis"; - -/* Const */ -const std::string CONST_ATTR_NAME_DATA_TRANSTYPE = "data_transtype"; -const std::string CONST_ATTR_NAME_OUTPUT_FORMAT = "output_format"; -const std::string CONST_ATTR_NAME_OUTPUT_TYPE = "output_type"; - -/* roipooling */ -const std::string ROIPOOLING_ATTR_NAME_POOLED_H = "pooled_h"; -const std::string ROIPOOLING_ATTR_NAME_POOLED_W = "pooled_w"; -const std::string ROIPOOLING_ATTR_NAME_SPATIAL_SCALE = "spatial_scale"; -const std::string ROIPOOLING_ATTR_NAME_RIO_POOLING_MODE = "rio_pooling_mode"; -const std::string ROIPOOLING_ATTR_NAME_POOLING_MODE = "pooling_mode"; -const std::string ROIPOOLING_ATTR_NAME_SAMPLING_RATIO = "sampling_ratio"; - -/* DetectionOutput */ -const std::string DETECTIONOUTPUT_ATTR_NUM_CLASSES = "num_classes"; -const std::string DETECTIONOUTPUT_ATTR_OCR_NUM_CLASSES = "ocr_num_classes"; -const std::string DETECTIONOUTPUT_ATTR_NMS_THRESHOLD = "nms_threshold"; -const std::string DETECTIONOUTPUT_ATTR_TOP_K = "top_k"; -const std::string DETECTIONOUTPUT_ATTR_CONFIDENCE_THRESHOLD = "confidence_threshold"; -const std::string DETECTIONOUTPUT_ATTR_IMG_H = "img_h"; -const std::string DETECTIONOUTPUT_ATTR_IMG_W = "img_w"; -const std::string DETECTIONOUTPUT_ATTR_BATCH_SIZE = "batch_size"; - -/* Ssd DetectionOutput */ -const std::string DETECTIONOUTPUT_ATTR_ETA = "eta"; -const std::string DETECTIONOUTPUT_ATTR_SHARED_LOCATION = "shared_location"; -const std::string 
DETECTIONOUTPUT_ATTR_BACKGROUND_LABEL_ID = "background_label_id"; -const std::string DETECTIONOUTPUT_ATTR_CODE_TYPE = "code_type"; -const std::string DETECTIONOUTPUT_ATTR_VARIANCE_ENCODED_IN_TARGET = "variance_encoded_in_target"; -const std::string DETECTIONOUTPUT_ATTR_KEEP_TOP_K = "keep_top_k"; - -/* Refinedet DetectionOutput */ -const std::string DETECTIONOUTPUT_ATTR_OBJECTNESS_SCORE = "objectness_score"; - -/* yolo DetectionOutput */ -const std::string DETECTIONOUTPUT_ATTR_ClASSES = "classes"; -const std::string DETECTIONOUTPUT_ATTR_BIASES = "biases"; -const std::string DETECTIONOUTPUT_ATTR_RELATIVE = "relative"; -const std::string DETECTIONOUTPUT_ATTR_OBJECTNESS_THRESHOLD = "objectness_threshold"; -const std::string DETECTIONOUTPUT_ATTR_CLASS_THRESHOLD = "class_threshold"; -const std::string DETECTIONOUTPUT_ATTR_POST_TOP_K = "post_top_k"; -const std::string DETECTIONOUTPUT_ATTR_IOU_THRESHOLD_DECAY = "iou_threshold_decay"; -const std::string DETECTIONOUTPUT_ATTR_COOR_SCALE_FACTOR = "coor_scale_factor"; -const std::string DETECTIONOUTPUT_ATTR_YOLO_VERSION = "yolo_version"; - -/* DetectionPostprocess */ -const std::string POSTPROCESS_ATTR_NAME_CLS_NUM = "cls_num"; -const std::string POSTPROCESS_ATTR_NAME_CONF_THRESH = "conf_thresh"; -const std::string POSTPROCESS_ATTR_NAME_NMS_THRESH = "nms_thresh"; -const std::string POSTPROCESS_ATTR_POST_NMS_TOPN = "post_nms_topn"; -const std::string POSTPROCESS_ATTR_NAME_BBOX_REG_WEIGHT = "bbox_reg_weights"; - -/* Spatialtransfrom */ -const std::string SPTIALTF_ATTR_NAME_OUTPUT_H = "output_h"; -const std::string SPTIALTF_ATTR_NAME_OUTPUT_W = "output_w"; -const std::string SPTIALTF_ATTR_NAME_BORDER_VALUE = "border_value"; -const std::string SPTIALTF_ATTR_NAME_AFFINE_TRANSFORM = "affine_transform"; - -/* Proposal */ -const std::string PROPOSAL_ATTR_NAME_FEAT_STRIDE = "feat_stride"; -const std::string PROPOSAL_ATTR_NAME_BASE_SIZE = "base_size"; -const std::string PROPOSAL_ATTR_NAME_MIN_SIZE = "min_size"; -const std::string 
PROPOSAL_ATTR_NAME_RATIO = "ratio"; -const std::string PROPOSAL_ATTR_NAME_SCALE = "scale"; -const std::string PROPOSAL_ATTR_NAME_PRE_NMS_TOPN = "pre_nms_topn"; -const std::string PROPOSAL_ATTR_NAME_POST_NMS_TOPN = "post_nms_topn"; -const std::string PROPOSAL_ATTR_NAME_NMS_THRESH = "nms_thresh"; -const std::string PROPOSAL_ATTR_NAME_TOP_SIZE = "top_size"; -const std::string PROPOSAL_ATTR_IMG_H = "img_h"; -const std::string PROPOSAL_ATTR_IMG_W = "img_w"; - -/* Softmax */ -const std::string SOFTMAX_ATTR_AXIS = "axis"; - -/* Permute */ -const std::string PERMUTE_ATTR_ORDER = "order"; -const std::string PERMUTE_ATTR_PERM = "perm"; - -/*SSD Normalize*/ -const std::string SSDNORMALIZE_ATTR_ACCROSS_SPATIAL = "across_spatial"; -const std::string SSDNORMALIZE_ATTR_CHANNEL_SHARED = "channel_shared"; -const std::string SSDNORMALIZE_ATTR_EPS = "eps"; - -/* Flatten */ -const std::string FLATTEN_ATTR_AXIS = "axis"; -const std::string FLATTEN_ATTR_END_AXIS = "end_axis"; - -/* SsdPRIORBOX */ -const std::string SSD_PRIOR_BOX_ATTR_FLIP = "flip"; -const std::string SSD_PRIOR_BOX_ATTR_CLIP = "clip"; -const std::string SSD_PRIOR_BOX_ATTR_IMG_H = "img_h"; -const std::string SSD_PRIOR_BOX_ATTR_IMG_W = "img_w"; -const std::string SSD_PRIOR_BOX_ATTR_STEP_H = "step_h"; -const std::string SSD_PRIOR_BOX_ATTR_STEP_W = "step_w"; -const std::string SSD_PRIOR_BOX_ATTR_OFFSET = "offset"; -const std::string SSD_PRIOR_BOX_ATTR_MIN_SIZE = "min_size"; -const std::string SSD_PRIOR_BOX_ATTR_MAX_SIZE = "max_size"; -const std::string SSD_PRIOR_BOX_ATTR_MIN_SIZE_NUM = "min_size_num"; -const std::string SSD_PRIOR_BOX_ATTR_MAX_SIZE_NUM = "max_size_num"; -const std::string SSD_PRIOR_BOX_ATTR_ASPECT_RATIO = "aspect_ratio"; -const std::string SSD_PRIOR_BOX_ATTR_ASPECT_RATIO_NUM = "aspect_ratio_num"; -const std::string SSD_PRIOR_BOX_ATTR_VARIANCE = "variance"; -const std::string SSD_PRIOR_BOX_ATTR_VARIANCE_NUM = "variance_num"; - -/* RefinedetDetectionOutput */ -const std::string 
REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM = "variance_num"; -const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE = "variance"; - -/* PRelu */ -const std::string PRELU_ATTR_CHANNEL_SHARED = "channel_shared"; - -/*psroi pooling*/ -const std::string PSROIPOOLING_ATTR_SPATIAL_SCALE = "spatial_scale"; -const std::string PSROIPOOLING_ATTR_OUTPUT_DIM = "output_dim"; -const std::string PSROIPOOLING_ATTR_GROUP_SIZE = "group_size"; - -/* power */ -const std::string POWER_ATTR_NAME_POWER = "power"; -const std::string POWER_ATTR_NAME_SCALE = "scale"; -const std::string POWER_ATTR_NAME_SHIFT = "shift"; - -/* log */ -const std::string LOG_ATTR_NAME_SCALE = "scale"; -const std::string LOG_ATTR_NAME_SHIFT = "shift"; -const std::string LOG_ATTR_NAME_BASE = "base"; - -/*pack*/ -const std::string PACK_ATTR_NAME_NUM = "N"; - -/*unpack*/ -const std::string UNPACK_ATTR_NAME_NUM = "num"; - -const std::string DYNAMIC_STITCH_ATTR_NAME_NUM = "DynamicStitchN_"; - -/*gathernd*/ -const std::string GATHERND_ATTR_NAME_TINDICES = "Tindices"; -const std::string GATHERND_ATTR_NAME_TPARAMS = "Tparams"; - -/*argmax*/ -const std::string ARGMAX_ATTR_NAME_TOPK = "topk"; -const std::string ARGMAX_ATTR_NAME_REDUCESIZE = "reduce_size"; -const std::string ARGMAX_ATTR_NAME_REDUCESTRIDE = "reduce_stride"; -const std::string ARGMAX_ATTR_NAME_OUTMAX = "outmaxval"; -const std::string ARGMAX_ATTR_NAME_AXIS = "axis"; -const std::string ARGMAX_ATTR_NAME_AXISTYPE = "axis_type"; -const std::string ARGMAX_ATTR_NAME_KEEPDIMS = "keep_dims"; - -/*upsample*/ -const std::string UPSAMPLE_ATTR_NAME_SCALE_H = "scale_h"; -const std::string UPSAMPLE_ATTR_NAME_SCALE_W = "scale_w"; - -/* relu */ -const std::string ATTR_NAME_NEGATIVE_SLOPE = "negative_slope"; - -/* FreeSpaceExtract */ -const std::string FREESPACEEXTRACT_ATTR_NAME_ORG_HEIGHT = "org_height"; - -/* split */ -const std::string SPLIT_ATTR_NAME_SLICE_POINT = "slice_point"; -const std::string SPLIT_ATTR_NAME_SIZE_SPLIT = "size_split"; -const std::string 
SPLIT_ATTR_NAME_NUM_SPLIT = "num_split"; - -/* Tvm */ -const std::string TVM_ATTR_NAME_MAGIC = "tvm_magic"; -const std::string TVM_ATTR_NAME_BLOCKDIM = "tvm_blockdim"; -const std::string TVM_ATTR_NAME_METADATA = "tvm_metadata"; - -/*squeeze*/ -const std::string SQUEEZE_ATTR_AXIS = "axis"; -const std::string SQUEEZE_ATTR_DIMS = "squeeze_dims"; -const std::string SQUEEZE_OP_NAME = "Squeeze"; - -/*stride slice*/ -const std::string STRIDE_SLICE_ATTR_BEGIN_MASK = "begin_mask"; -const std::string STRIDE_SLICE_ATTR_END_MASK = "end_mask"; -const std::string STRIDE_SLICE_ATTR_ELLIPSIS_MASK = "ellipsis_mask"; -const std::string STRIDE_SLICE_ATTR_NEW_AXIS_MASK = "new_axis_mask"; -const std::string STRIDE_SLICE_ATTR_SHRINK_AXIS_MASK = "shrink_axis_mask"; - -/*slice*/ -const std::string SLICE_ATTR_NAME_BEGINS = "begins"; -const std::string SLICE_ATTR_NAME_SIZES = "sizes"; - -/*roialign*/ -const std::string ROIALIGN_ATTR_SPATIAL_SCALE = "spatial_scale"; -const std::string ROIALIGN_ATTR_SAMPLING_RATIO = "sampling_ratio"; -const std::string ROIALIGN_ATTR_NAME_POOLED_H = "pooled_h"; -const std::string ROIALIGN_ATTR_NAME_POOLED_W = "pooled_w"; - -/*generate_rpn_proposal*/ -const std::string GENERATE_RPN_PROPOSAL_ATTR_PRE_NMS_TOPK = "pre_nms_topk"; -const std::string GENERATE_RPN_PROPOSAL_ATTR_POST_NMS_TOPK = "post_nms_topk"; -const std::string GENERATE_RPN_PROPOSAL_ATTR_RPN_MINI_SIZE = "rpn_mini_size"; -const std::string GENERATE_RPN_PROPOSAL_ATTR_RPN_PROPOSAL_NMS_THRESH = "rpn_proposal_nms_thresh"; -const std::string GENERATE_RPN_PROPOSAL_ATTR_RPN_PROPOSAL_FILTER_THRESH = "rpn_proposal_filter_thresh"; - -/*decode_bbox*/ -const std::string DECODE_BBOX_ATTR_DECODECLIP = "decodeClip"; - -/* Cast */ -const std::string CAST_ATTR_DSTT = "DstT"; -const std::string CAST_ATTR_SRCT = "SrcT"; - -/* fastrcnnn predications*/ -const std::string FASTRCNN_PREDICTIONS_ATTR_TOPK = "fsr_topk"; -const std::string FASTRCNN_PREDICTIONS_ATTR_SCORE_THRESHOLD = "fsr_score_thres"; -const std::string 
FASTRCNN_PREDICTIONS_ATTR_NMS_THRESHOLD = "fsr_nms_thres"; -const std::string FASTRCNN_PREDICTIONS_ATTR_NUM_CLASSES = "fsr_num_classes"; - -/* REORG*/ -const std::string REORG_ATTR_STRIDE = "stride"; -const std::string REORG_ATTR_REVERSE = "reverse"; - -/* MERGE*/ -const std::string MERGE_DEAD_INDEX = "merge_dead_index"; -const std::string MERGE_PRENODE_FLAG = "merge_prenode_flag"; -const std::string TO_BE_OUTPUT = "to_be_output"; - -/*Concatv2*/ -const std::string CONCAT_V2_ATTR_TIDX = "Tidx"; -const std::string CONCAT_V2_ATTR_N = "N"; - -/* SUM*/ -const std::string SUM_ATTR_TIDX = "Tidx"; -const std::string SUM_ATTR_AXIS = "axis"; -const std::string SUM_ATTR_KEEP_DIMS = "keep_dims"; - -/*ResizeBilinear*/ -const std::string RESIZE_BILINEAR_ATTR_MODE = "mode"; -const std::string RESIZE_BILINEAR_ATTR_ALIGN_CORNERS = "align_corners"; -const std::string RESIZE_BILINEAR_ATTR_HEIGHT = "height"; -const std::string RESIZE_BILINEAR_ATTR_WIDTH = "width"; -const std::string RESIZE_BILINEAR_ATTR_ZOOM_FACTOR = "zoom_factor"; -const std::string RESIZE_BILINEAR_ATTR_SHRINK_FACTOR = "shrink_factor"; -const std::string RESIZE_BILINEAR_ATTR_PAD_BEGIN = "pad_begin"; -const std::string RESIZE_BILINEAR_ATTR_PAD_END = "pad_end"; -const std::string RESIZE_BILINEAR_ATTR_ALPHA = "alpha"; -const std::string RESIZE_BILINEAR_ATTR_BETA = "beta"; - -/*RetinaNet*/ -const std::string RETINANET_FILTER_BACKGROUND_TRUE = "retina_conv_filter_background"; -const std::string RETINANET_ANCHOR_FUSION = "retina_anchor_fusion"; - -/*MatMul*/ -const std::string MATMUL_TRANSPOSE_X = "transposeX"; -const std::string MATMUL_TRANSPOSE_W = "transposeW"; -const std::string MATMUL_HAS_BIAS = "has_bias"; -const std::string MATMUL_ATTR_IS_TRAINING = "matmul_is_training"; - -/*Flatten*/ -const std::string FLATTEN_START_AXIS = "start_axis"; -const std::string FLATTEN_END_AXIS = "end_axis"; - -/*reshape*/ -const std::string RESHAPE_ATTR_AXIS = "axis"; -const std::string RESHAPE_ATTR_NUM_AXES = "num_axes"; -const 
std::string RESHAPE_ATTR_FORMAT = "format"; -const std::string RESHAPE_ATTR_SHAPE = "shape"; -const std::string RESHAPE_ATTR_ALPHA = "alpha"; -const std::string RESHAPE_ATTR_BETA = "beta"; - -/*frameoworkop*/ -const std::string T_IN_DATATYPE = "t_in_datatype"; -const std::string T_OUT_DATATYPE = "t_out_datatype"; -const std::string ATTR_NAME_OUT_N = "out_n"; -const std::string ATTR_NAME_OUT_C = "out_c"; -const std::string ATTR_NAME_OUT_H = "out_h"; -const std::string ATTR_NAME_OUT_W = "out_w"; -const std::string ATTR_PAD_DEPTH_CONV = "pad_depth_conv"; -const std::string ATTR_PAD_CONV = "pad_conv"; - -const std::string ATTR_NAME_BEFORE_PAD = "before_pad"; -const std::string ANN_MEAN_KEEPDIMS = "AnnMeanKeepDims"; -const std::string PAD_ATTR_PADDINGDS = "paddings"; -const std::string PAD_ATTR_CONSTANT_VALUE = "padvalue"; - -/*ConvGradFilter*/ -const std::string CONV_GRAD_FILTER_OUTPUT_SHAPE = "conv_grad_filter_output_shape"; - -/*ConvGradInput*/ -const std::string CONV_GRAD_INPUT_OUTPUT_SHAPE = "conv_grad_input_output_shape"; - -/*Rnn*/ -const std::string RNN_MODE_STATIC = "rnn_static"; -const std::string MUTI_RNN = "multi_rnn"; -const std::string CNN_RNN = "cnn_rnn"; -const std::string CELL_MODE = "mode"; -const std::string LSTM_CELL = "lstm_cell"; -const std::string GRU_CELL = "gru_cell"; -const std::string RNN_HT = "ht"; -const std::string RNN_XT_HT = "xt_ht"; -const std::string RNN_BATCH_SIZE = "batch_size"; -const std::string LSTM_CELL_CLIP = "lstm_cell_clip"; -const std::string LSTM_PROJ_CLIP = "lstm_proj_clip"; -const std::string LSTM_ACTIVATE = "lstm_activate"; -const std::string LSTM_OUT_MAP = "lstm_out_map"; -const std::string LSTM_OUT_MODE = "lstm_out_mode"; -const std::string LSTM_STATE_OUT_MODE = "lstm_state_out_mode"; -const std::string LSTM_TIME_MAJOR = "lstm_time_major"; -const std::string LSTM_IS_INPUT_PRE_PROCESS = "lstm_is_input_pre_process"; - -/*Upsample*/ -const std::string UPSAMPLE_ATTR_NAME_SCALE = "scale"; - -/*PadV2*/ -const std::string 
PADV2_ATTR_NAME_MODE = "mode"; -const std::string PADV2_ATTR_NAME_PADS = "paddings"; -const std::string PADV2_ATTR_NAME_T = "T"; -const std::string PADV2_ATTR_NAME_PAD_FORMAT = "pad_format"; -const std::string PADV2_ATTR_NAME_CONST_VALUE = "const_value"; - -/*MirrorPad*/ -const std::string MIRRORPAD_ATTR_NAME_MODE = "mode"; -const std::string MIRRORPAD_ATTR_NAME_PADS = "paddings"; -const std::string MIRRORPAD_ATTR_NAME_PAD_FORMAT = "pad_format"; -const std::string MIRRORPAD_ATTR_NAME_CONST_VALUE = "const_value"; - -/* filler */ -const std::string FILLER_TYPE = "filler_type"; -const std::string FILLER_VALUE = "filler_value"; - -/*shufflechannel*/ -const std::string SHUFFLE_CHANNEL_GROUP = "group"; - -/*TopKV2*/ -const std::string TOPKV2_ATTR_K = "k"; - -/*Calibaration*/ -const std::string STRIDE_H_INDEX = "STRIDE_H_INDEX"; -const std::string STRIDE_W_INDEX = "STRIDE_W_INDEX"; -const std::string PAD_TOP_INDEX = "PAD_TOP_INDEX"; -const std::string PAD_BOTTOM_INDEX = "PAD_BOTTOM_INDEX"; -const std::string PAD_RIGHT_INDEX = "PAD_RIGHT_INDEX"; -const std::string PAD_LEFT_INDEX = "PAD_LEFT_INDEX"; - -const std::string ATTR_NAME_IS_CONST = "attr_name_is_const"; - -const std::string ATTR_NAME_GROUP = "group"; -const std::string ATTR_NAME_DILATION_SIZE = "dilation_size"; -const std::string ATTR_NAME_EPSILON = "epsilon"; -const std::string ATTR_NAME_POOLING_MODE = "mode"; -const std::string ATTR_NAME_CLASS_NUM = "class_num"; -/** - * model - */ -const std::string ATTR_MODEL_TARGET_TYPE = "target_type"; - -const std::string ATTR_MODEL_STREAM_NUM = "stream_num"; - -const std::string ATTR_MODEL_EVENT_NUM = "event_num"; - -const std::string ATTR_MODEL_LABEL_NUM = "label_num"; - -const std::string ATTR_MODEL_MEMORY_SIZE = "memory_size"; - -const std::string ATTR_MODEL_WEIGHT_SIZE = "weight_size"; - -const std::string ATTR_MODEL_TASK_GEN_BASE_ADDR = "task_gen_base_addr"; - -const std::string ATTR_MODEL_TASK_GEN_WEIGHT_ADDR = "task_gen_weight_addr"; - -/** - * Public attribute - */ 
-const std::string ATTR_NAME_IMPLY_TYPE = "imply_type"; - -const std::string ATTR_NAME_BYTE_SIZE = "op_byte_size"; - -const std::string ATTR_NAME_FUSION_INFERENCE_ID = "fusion_inference_id"; - -const std::string ATTR_NAME_FUSION_OPDEF = "fusion_opdef"; - -const std::string ATTR_NAME_FUSION_SCOPE = "fusion_scope"; - -const std::string ATTR_NAME_OPATTR = "opattr"; - -const std::string ATTR_NAME_RELUFLAG = "relu_flag"; - -const std::string ATTR_NAME_SEQLEN_INDEX = "seqlen_index"; - -const std::string ATTR_NAME_X_INDEX = "x_index"; - -const std::string ATTR_NAME_CONT_INDEX = "cont_index"; - -const std::string ATTR_NAME_XSTATIC_INDEX = "xstatic_index"; - -const std::string TARGET_TYPE_MINI = "MINI"; - -const std::string TARGET_TYPE_TINY = "TINY"; - -const std::string TARGET_TYPE_LITE = "LITE"; - -/*l2_normalize*/ -const std::string L2_NORMALIZE_ATTR_AXIS = "axis"; -const std::string L2_NORMALIZE_ATTR_EPS = "eps"; - -const std::string POOL_PARAMA_ATTR_WINDOW = "window"; -const std::string POOL_PARAMA_ATTR_CEIL_MODE = "ceil_mode"; -const std::string POOL_PARAMA_ATTR_DATA_MODE = "data_mode"; -const std::string POOL_PARAMA_ATTR_GLOBAL_POOLING = "global_pooling"; -const std::string POOL_PARAMA_ATTR_NAN_OP = "nan_opt"; -const std::string POOL_PARAMA_ATTR_PAD_MOD = "pad_mode"; - -/*HCOM*/ -const std::string HCOM_ATTR_ROOT_RANK = "root_rank"; -const std::string HCOM_ATTR_RANK_SIZE = "rank_size"; - -const std::string HCOM_ATTR_REDUCE_TYPE = "reduction"; -const std::string HCOM_ATTR_GROUP = "group"; -const std::string HCOM_ATTR_SR_TAG = "sr_tag"; -const std::string HCOM_ATTR_SRC_RANK = "src_rank"; -const std::string HCOM_ATTR_DEST_RANK = "dest_rank"; -const std::string HCOM_ATTR_FUSION = "fusion"; -const std::string HCOM_ATTR_SHAPE = "shape"; -const std::string HCOM_ATTR_DATA_TYPE = "dtype"; - -/*SpaceToDepth/DepthToSpace*/ -const std::string ATTR_NAME_BLOCK_SIZE = "block_size"; - -/*SparseSoftmaxCrossEntropyWithLogits*/ -const std::string SPARSE_SOFT_MAX_ATTR_TLABLES = 
"Tlabels"; - -/*MaxPoolGradWithArgmax*/ -const std::string MAX_POOL_GRAD_OUTPUT_SHAPE = "max_pool_grad_output_shape"; - -/*AvgPoolGrad*/ -const std::string AVG_POOL_GRAD_OUTPUT_SHAPE = "avg_pool_grad_output_shape"; - -/*Pad*/ -const std::string ATTR_PAD_FORMAT = "attr_pad_format"; - -/*Varible*/ -const std::string VAR_ATTR_FORMAT = "_var_format"; -const std::string VAR_ATTR_NAME = "var_name"; -const std::string VAR_ATTR_FRACTALZ_FORMAT = "FZ"; -const std::string VAR_ATTR_4D_FORMAT = "4D"; -const std::string VAR_ATTR_5D_FORMAT = "5D"; -const std::string VAR_ATTR_DATA_TYPE = "data_format"; -const std::string VAR_ATTR_VAR_IN_NAME = "var_in_name"; -const std::string VAR_ATTR_VAR_IN_INDEX = "var_in_index"; -const std::string VAR_ATTR_VAR_OUT_INDEX = "var_out_index"; -const std::string VAR_ATTR_SHAPE = "shape"; -const std::string HALF_VAR_NAME_END = "_fp16"; -const std::string VAR_ATTR_INITED = "var_is_inited"; - -const std::string VAR_ATTR_CONTAINER = "container"; -const std::string VAR_ATTR_SHARED_NAME = "shared_name"; -const std::string VAR_ATTR_DTYPE = "dtype"; - -const std::string VAR_ATTR_SRC_VAR_NAME = "_src_var_name"; -const std::string VAR_ATTR_VAR_IS_SAVE = "_var_is_save"; -const std::string VAR_ATTR_VAR_IS_RESTORE = "_var_is_restore"; -const std::string VAR_ATTR_VAR_IS_BROADCAST = "_var_is_broadcast"; -const std::string REF_VAR_SRC_VAR_NAME = "ref_var_src_var_name"; -const std::string REF_VAR_PRE_PEER_OUT_INDEX = "ref_var_pre_peer_out_index"; - -/*Assign*/ -const std::string ASSIGN_VALIDATE_SHAPE = "validate_shape"; - -/* space2bacth batch2space */ -const std::string BATCH_SPACE_ATTR_BLOCK = "block"; -const std::string BATCH_SPACE_ATTR_PADDING = "padding"; - -/*depth_to_space space_to_depth*/ -const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE = "block_size"; - -/*FakeQuantWithMinMaxVars*/ -const std::string FakeQuantWithMinMaxVars_ATTR_MAX = "max"; -const std::string FakeQuantWithMinMaxVars_ATTR_MIN = "min"; - -/*mobilenet_ssd_conv_fusion*/ -const std::string 
SSD_BOXPREDICTOR_BOXES_FUSION = "ssd_boxpredictor_boxes_fusion"; -const std::string SSD_BOXPREDICTOR_SCORES_FUSION = "ssd_boxpredictor_scores_fusion"; -const std::string SSD_BOXPREDICTOR_FUSION_BOX_TYPE_NUM = "ssd_boxpredictor_fusion_box_type_num"; - -/*lsh project*/ -const std::string LSH_PROJ_TYPE = "lsh_project_type"; - -/* log time stamp */ -const std::string LOG_TIME_STAMP_LOGID = "logid"; -const std::string LOG_TIME_STAMP_NOTIFY = "notify"; - -/*ShapeN*/ -const std::string SHAPEN_ATTR_N = "N"; -const std::string SHAPEN_ATTR_IN_TYPE = "in_type"; -const std::string SHAPEN_ATTR_OUT_TYPE = "dtype"; - -/* control flow */ -const std::string ATTR_NAME_ITERATORS_PER_LOOP = "iterations_per_loop"; -const std::string ATTR_NAME_TRUE_BRANCH_STREAM = "true_branch_stream"; -const std::string ATTR_NAME_FLOW_CTRL_NODE_FLAG = "is_flow_ctrl_node"; - -/* GatherV2 attr def */ -const std::string GATHERV2_ATTR_NAME_TAXIS = "Taxis"; -const std::string GATHERV2_ATTR_NAME_TINDICES = "Tindices"; -const std::string GATHERV2_ATTR_NAME_TPARAMS = "Tparams"; - -/* Reshape attr def */ -const std::string RESHAPE_ATTR_NAME_INPUT_DESC = "input_desc_reshape"; -const std::string RESHAPE_ATTR_NAME_OUTPUT_DESC = "output_desc_reshape"; - -/* axis attr def */ -const std::string ATTR_NAME_AXIS_ORG_OP = "axis_org_op"; - -const std::string ATTR_NAME_LINK_WITH_SPARE = "link_with_sparse"; - -const std::string ATTR_NAME_NET_OUTPUT_FORMAT = "net_output_format"; -const std::string ATTR_NAME_NET_OUTPUT_DATATYPE = "net_output_datatype"; - -/* For constant folding */ -const std::string ATTR_NO_NEED_CONSTANT_FOLDING = "no_need_constant_folding"; -} // namespace domi diff --git a/src/ge/common/op/attr_value_util.cc b/src/ge/common/op/attr_value_util.cc index 77d81076..5d74aa1d 100644 --- a/src/ge/common/op/attr_value_util.cc +++ b/src/ge/common/op/attr_value_util.cc @@ -18,7 +18,7 @@ #include "framework/common/debug/log.h" #include "framework/common/util.h" -namespace domi { +namespace ge { #define 
DEFINE_SET_ATTR_VALUE_ONE(ARG_TYPE, FIELD) \ FMK_FUNC_DEV_VISIBILITY void SetAttrDef(ARG_TYPE value, AttrDef *out) { \ GE_CHECK_NOTNULL_JUST_RETURN(out); \ @@ -312,4 +312,4 @@ DEFINE_GET_ATTR_LIST_SIZE(const std::string &, uint32_t, u); DEFINE_GET_ATTR_LIST_SIZE(const std::string &, float, f); DEFINE_GET_ATTR_LIST_SIZE(const std::string &, double, f); DEFINE_GET_ATTR_LIST_SIZE(const std::string &, bool, b); -} // namespace domi +} // namespace ge diff --git a/src/ge/common/op/ge_op_utils.cc b/src/ge/common/op/ge_op_utils.cc index b8a17514..bba1afe8 100644 --- a/src/ge/common/op/ge_op_utils.cc +++ b/src/ge/common/op/ge_op_utils.cc @@ -25,16 +25,15 @@ #include "framework/common/debug/log.h" #include "framework/common/fmk_error_codes.h" #include "framework/common/ge_inner_error_codes.h" -#include "framework/common/op/attr_define.h" #include "framework/common/op/attr_value_util.h" #include "framework/common/util.h" #include "graph/anchor.h" +#include "graph/debug/ge_attr_define.h" #include "graph/utils/op_desc_utils.h" #include "graph/utils/tensor_utils.h" #include "graph/utils/type_utils.h" #include "mmpa/mmpa_api.h" -using ge::fp16_t; using std::vector; namespace ge { @@ -69,6 +68,8 @@ const uint32_t FOR_LIMIT_INPUT = 1; const uint32_t FOR_DELTA_INPUT = 2; const uint32_t FOR_DATA_INPUT = 3; +const int NORMAL_TENSOR_SIZE = 4; + // Get the value of key from attr #define AIPP_GET_ATTR_VALUE(KEY, ATTR_TYPE) \ if (aipp_attr.GetItem(#KEY).GetValue(KEY) != SUCCESS) { \ @@ -110,7 +111,7 @@ const uint32_t FOR_DATA_INPUT = 3; #define AIPP_CONVERT_LIST_FLOAT(KEY, REQUIRED) AIPP_CONVERT_LIST_FORMAT(KEY, float, REQUIRED, GeAttrValue::FLOAT) FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status -OpUtils::ConvertAippParams(const GeAttrValue::NamedAttrs &aipp_attr, domi::AippOpParams *aipp_params) { +OpUtils::ConvertAippParams(const GeAttrValue::NAMED_ATTRS &aipp_attr, domi::AippOpParams *aipp_params) { GE_CHECK_NOTNULL(aipp_params); AIPP_CONVERT_FORMAT_EX(aipp_mode, 
domi::AippOpParams::AippMode, int32_t, GeAttrValue::INT); @@ -177,7 +178,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OpUtils::TransferDim(con for (auto dim_temp : dim) { new_dim_list.push_back(dim_temp); } - if (input_shape_size > domi::DIM_DEFAULT_SIZE) { + if (input_shape_size > DIM_DEFAULT_SIZE) { dim_vector = dim; GELOGI("Dim_vector size is %zu, do not to transfer dim", input_shape_size); return SUCCESS; diff --git a/src/ge/common/profiling/profiling_manager.cc b/src/ge/common/profiling/profiling_manager.cc index 603bdfb1..8422ebf6 100644 --- a/src/ge/common/profiling/profiling_manager.cc +++ b/src/ge/common/profiling/profiling_manager.cc @@ -50,6 +50,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ProfilingManager &ProfilingMana FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ge::Status ProfilingManager::Init(const Options &options) { #ifdef DAVINCI_SUPPORT_PROFILING + vector().swap(device_id_); device_id_.push_back(options.device_id); job_id_ = options.job_id; @@ -58,7 +59,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ge::Status ProfilingManager::In GELOGI("Profiling json config from acl:%s", recv_profiling_config_.c_str()); ret = InitFromAclCfg(recv_profiling_config_); } else { - ret = InitFromEnv(options); + ret = InitFromOptions(options); } if (ret != SUCCESS) { GELOGE(ret, "Failed to init profiling."); @@ -67,8 +68,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ge::Status ProfilingManager::In if (is_profiling_) { // register Framework to profiling - const ProfilingEngineImpl engine_0; - int result = Msprof::Engine::RegisterEngine("Framework", &engine_0); + int result = Msprof::Engine::Init(GE_PROFILING_MODULE, &engine_); if (result != 0) { GELOGE(FAILED, "Register profiling engine failed."); return FAILED; @@ -172,21 +172,30 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ge::Status ProfilingManager::In return ge::SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ge::Status 
ProfilingManager::InitFromEnv(const Options &options) { +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ge::Status ProfilingManager::InitFromOptions(const Options &options) { #ifdef DAVINCI_SUPPORT_PROFILING - const char *is_profiling = std::getenv("PROFILING_MODE"); + // enable profiling support two ways: env and front end + const char *profiling_mode = std::getenv("PROFILING_MODE"); const char *prof_options = std::getenv("PROFILING_OPTIONS"); - if ((is_profiling == nullptr) || (strcmp("true", is_profiling) != 0) || (prof_options == nullptr)) { - // default training trace on + if ((profiling_mode == nullptr) || (strcmp("true", profiling_mode) != 0) || (prof_options == nullptr)) { is_profiling_ = false; - return SUCCESS; } else { std::string prof_options_str = std::string(prof_options); - profiling_opts_ = domi::StringUtils::Split(prof_options_str, ':'); + profiling_opts_ = StringUtils::Split(prof_options_str, ':'); is_profiling_ = true; + GELOGI("The profiling in env is %s, %s", profiling_mode, prof_options); + } + if (!is_profiling_) { + const std::string enable_profiling = "1"; + if (options.profiling_mode != enable_profiling || options.profiling_options.empty()) { + is_profiling_ = false; + return SUCCESS; + } else { + profiling_opts_ = StringUtils::Split(options.profiling_options, ':'); + is_profiling_ = true; + GELOGI("The profiling in options is %s, %s", options.profiling_mode.c_str(), options.profiling_options.c_str()); + } } - GELOGI("The profiling in options is %s, %s", is_profiling, prof_options); - // features:'training_trace', 'task_trace' or 'op_trace' etc if (!profiling_opts_.empty()) { if (profiling_opts_[0] == "op_trace") { @@ -314,122 +323,119 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::StopProf } FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::ProfilingTaskDescInfo( - const std::vector &task_desc_info) { + const std::vector &task_desc_info, const int32_t &device_id) { #ifdef 
DAVINCI_SUPPORT_PROFILING Msprof::Engine::Reporter *reporter = PluginImpl::GetPluginReporter(); if (reporter == nullptr) { GELOGI("Profiling report is nullptr!"); return; } - std::string data; - for (size_t i = 0; i < device_id_.size(); ++i) { - for (const auto &task : task_desc_info) { - std::string op_name = task.op_name; - uint32_t block_dim = task.block_dim; - uint32_t task_id = task.task_id; - uint32_t stream_id = task.stream_id; - data = op_name.append(" ").append(std::to_string(block_dim) - .append(" ") - .append(std::to_string(task_id)) - .append(" ") - .append(std::to_string(stream_id)) - .append("\n")); - - Msprof::Engine::ReporterData reporter_data{}; - reporter_data.deviceId = device_id_[i]; - reporter_data.data = (unsigned char *)data.c_str(); - reporter_data.dataLen = data.size(); - int ret = memcpy_s(reporter_data.tag, MSPROF_ENGINE_MAX_TAG_LEN + 1, "task_desc_info", sizeof("task_desc_info")); - if (ret != EOK) { - GELOGE(ret, "Report data tag of task_desc_info memcpy error!"); - return; - } - ret = reporter->Report(&reporter_data); - if (ret != SUCCESS) { - GELOGE(ret, "Reporter data of task_desc_info fail!"); - return; - } + std::string data; + for (const auto &task : task_desc_info) { + std::string op_name = task.op_name; + uint32_t block_dim = task.block_dim; + uint32_t task_id = task.task_id; + uint32_t stream_id = task.stream_id; + data = op_name.append(" ").append(std::to_string(block_dim) + .append(" ") + .append(std::to_string(task_id)) + .append(" ") + .append(std::to_string(stream_id)) + .append("\n")); + + Msprof::Engine::ReporterData reporter_data{}; + reporter_data.deviceId = device_id; + reporter_data.data = (unsigned char *)data.c_str(); + reporter_data.dataLen = data.size(); + int ret = memcpy_s(reporter_data.tag, MSPROF_ENGINE_MAX_TAG_LEN + 1, "task_desc_info", sizeof("task_desc_info")); + if (ret != EOK) { + GELOGE(ret, "Report data tag of task_desc_info memcpy error!"); + return; } - data.clear(); + ret = 
reporter->Report(&reporter_data); + if (ret != SUCCESS) { + GELOGE(ret, "Reporter data of task_desc_info fail!"); + return; + } } + + data.clear(); #endif } FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::ProfilingGraphDescInfo( - const std::vector &compute_graph_desc_info) { + const std::vector &compute_graph_desc_info, const int32_t &device_id) { #ifdef DAVINCI_SUPPORT_PROFILING Msprof::Engine::Reporter *reporter = PluginImpl::GetPluginReporter(); GE_IF_BOOL_EXEC(reporter == nullptr, GELOGI("Profiling report is nullptr!"); return;); std::string data; - for (size_t idx = 0; idx < device_id_.size(); ++idx) { - for (const auto &graph : compute_graph_desc_info) { - data.append("op_name:").append(graph.op_name).append(" op_type:").append(graph.op_type); - for (size_t i = 0; i < graph.input_format.size(); ++i) { - data.append(" input_id:") - .append(std::to_string(i)) - .append(" input_format:") - .append(std::to_string(graph.input_format.at(i))) - .append(" input_data_type:") - .append(std::to_string(graph.input_data_type.at(i))) - .append(" input_shape:\""); - size_t input_shape_len = graph.input_shape.at(i).size(); - if (input_shape_len == 0) { - data.append(""); - } else if (input_shape_len == 1) { - data.append(std::to_string(graph.input_shape.at(i).at(0))); - } else { - for (size_t j = 0; j < input_shape_len - 1; ++j) { - data.append(std::to_string(graph.input_shape.at(i).at(j))).append(","); - } - data.append(std::to_string(graph.input_shape.at(i).at(input_shape_len - 1))); + for (const auto &graph : compute_graph_desc_info) { + data.append("op_name:").append(graph.op_name).append(" op_type:").append(graph.op_type); + for (size_t i = 0; i < graph.input_format.size(); ++i) { + data.append(" input_id:") + .append(std::to_string(i)) + .append(" input_format:") + .append(std::to_string(graph.input_format.at(i))) + .append(" input_data_type:") + .append(std::to_string(graph.input_data_type.at(i))) + .append(" input_shape:\""); + size_t 
input_shape_len = graph.input_shape.at(i).size(); + if (input_shape_len == 0) { + data.append(""); + } else if (input_shape_len == 1) { + data.append(std::to_string(graph.input_shape.at(i).at(0))); + } else { + for (size_t j = 0; j < input_shape_len - 1; ++j) { + data.append(std::to_string(graph.input_shape.at(i).at(j))).append(","); } - - data.append("\""); + data.append(std::to_string(graph.input_shape.at(i).at(input_shape_len - 1))); } - for (size_t i = 0; i < graph.output_format.size(); ++i) { - data.append(" output_id:") - .append(std::to_string(i)) - .append(" output_format:") - .append(std::to_string(graph.output_format.at(i))) - .append(" output_data_type:") - .append(std::to_string(graph.output_data_type.at(i))) - .append(" output_shape:\""); - size_t output_shape_len = graph.output_shape.at(i).size(); - if (output_shape_len == 0) { - data.append(""); - } else if (output_shape_len == 1) { - data.append(std::to_string(graph.output_shape.at(i).at(0))); - } else { - for (size_t j = 0; j < output_shape_len - 1; ++j) { - data.append(std::to_string(graph.output_shape.at(i).at(j))).append(","); - } - data.append(std::to_string(graph.output_shape.at(i).at(output_shape_len - 1))); + data.append("\""); + } + + for (size_t i = 0; i < graph.output_format.size(); ++i) { + data.append(" output_id:") + .append(std::to_string(i)) + .append(" output_format:") + .append(std::to_string(graph.output_format.at(i))) + .append(" output_data_type:") + .append(std::to_string(graph.output_data_type.at(i))) + .append(" output_shape:\""); + size_t output_shape_len = graph.output_shape.at(i).size(); + if (output_shape_len == 0) { + data.append(""); + } else if (output_shape_len == 1) { + data.append(std::to_string(graph.output_shape.at(i).at(0))); + } else { + for (size_t j = 0; j < output_shape_len - 1; ++j) { + data.append(std::to_string(graph.output_shape.at(i).at(j))).append(","); } - data.append("\""); + data.append(std::to_string(graph.output_shape.at(i).at(output_shape_len - 
1))); } + data.append("\""); + } - data.append("\n"); + data.append("\n"); - Msprof::Engine::ReporterData reporter_data{}; - Report(idx, data, *reporter, reporter_data); + Msprof::Engine::ReporterData reporter_data{}; + Report(device_id, data, *reporter, reporter_data); - data.clear(); - } + data.clear(); } #endif } FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::Report( - const size_t &idx, const string &data, Msprof::Engine::Reporter &reporter, + const int32_t &device_id, const string &data, Msprof::Engine::Reporter &reporter, Msprof::Engine::ReporterData &reporter_data) { #ifdef DAVINCI_SUPPORT_PROFILING size_t index = data.size() / kReportMaxLen; if (index >= 1) { - reporter_data.deviceId = device_id_[idx]; + reporter_data.deviceId = device_id; int ret = memcpy_s(reporter_data.tag, MSPROF_ENGINE_MAX_TAG_LEN + 1, "graph_desc_info", sizeof("graph_desc_info")); GE_IF_BOOL_EXEC(ret != EOK, GELOGE(ret, "Report data tag of graph_desc_info memcpy error!"); return;); for (size_t i = 0; i < index; ++i) { @@ -445,7 +451,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::Report( GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(ret, "Reporter data of graph_desc_info fail!"); return;); } } else { - reporter_data.deviceId = device_id_[idx]; + reporter_data.deviceId = device_id; reporter_data.data = (unsigned char *)data.c_str(); reporter_data.dataLen = data.size(); int ret = memcpy_s(reporter_data.tag, MSPROF_ENGINE_MAX_TAG_LEN + 1, "graph_desc_info", sizeof("graph_desc_info")); @@ -457,13 +463,36 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::Report( #endif } +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::PluginUnInit(const std::string &module) const { +#ifdef DAVINCI_SUPPORT_PROFILING + int ret = Msprof::Engine::UnInit(module); + if (ret != SUCCESS) { + GELOGE(ret, "profiling plugin uninit failed, ret:%d", ret); + } +#endif +} + FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY 
void ProfilingManager::ReportProfilingData( const std::vector &task_desc_info, const std::vector &compute_graph_desc_info) { #ifdef DAVINCI_SUPPORT_PROFILING + int32_t device_id = 0; + rtError_t rt_ret = rtGetDevice(&device_id); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(rt_ret, "runtime get device_id failed, current device_id:%d", device_id); + return; + } + GELOGI("current device_id:%d", device_id); + + auto ret = std::find(device_id_.begin(), device_id_.end(), device_id); + if (ret == device_id_.end()) { + GELOGE(FAILED, "get valid device_id failed, profiling report failed."); + return; + } + GELOGI("start ProfilingTaskDescInfo."); - ProfilingTaskDescInfo(task_desc_info); + ProfilingTaskDescInfo(task_desc_info, device_id); GELOGI("start ProfilingGraphDescInfo."); - ProfilingGraphDescInfo(compute_graph_desc_info); + ProfilingGraphDescInfo(compute_graph_desc_info, device_id); GELOGI("Report profiling data for GE end."); #endif } @@ -495,7 +524,7 @@ int PluginImpl::UnInit() { Msprof::Engine::PluginIntf *ProfilingEngineImpl::CreatePlugin() { GELOGI(" Create Plugin"); - return new (std::nothrow) PluginImpl("Framework"); + return new (std::nothrow) PluginImpl(GE_PROFILING_MODULE); } int ProfilingEngineImpl::ReleasePlugin(Msprof::Engine::PluginIntf *plugin) { diff --git a/src/ge/common/profiling/profiling_manager.h b/src/ge/common/profiling/profiling_manager.h index e56f514f..2dc0b407 100644 --- a/src/ge/common/profiling/profiling_manager.h +++ b/src/ge/common/profiling/profiling_manager.h @@ -32,13 +32,39 @@ using std::string; using std::vector; namespace ge { +const std::string GE_PROFILING_MODULE = "Framework"; +// register Plugin +class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY PluginImpl : public Msprof::Engine::PluginIntf { + public: + explicit PluginImpl(const std::string &module); + ~PluginImpl() {} + + int Init(const Msprof::Engine::Reporter *reporter); + int UnInit(); + static Msprof::Engine::Reporter *GetPluginReporter() { return reporter_; } + + 
private: + static Msprof::Engine::Reporter *reporter_; + std::string module_; +}; + +// register Engine +class ProfilingEngineImpl : public Msprof::Engine::EngineIntf { + public: + ProfilingEngineImpl() {} + ~ProfilingEngineImpl() {} + + Msprof::Engine::PluginIntf *CreatePlugin(); + int ReleasePlugin(Msprof::Engine::PluginIntf *plugin); +}; + class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ProfilingManager { public: ProfilingManager(); virtual ~ProfilingManager(); static ProfilingManager &Instance(); ge::Status Init(const Options &options); - ge::Status InitFromEnv(const Options &options); + ge::Status InitFromOptions(const Options &options); ge::Status InitFromAclCfg(const std::string &config); ge::Status StartProfiling(int32_t iter, int32_t device_id); void StopProfiling(); @@ -46,16 +72,16 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ProfilingManager { bool ProfilingLoadFlag() const { return is_load_; } bool ProfilingOn() const { return is_profiling_; } int32_t GetOpTraceIterNum() const { return op_trace_iter_num_; } - void ReportProfilingData(const std::vector &task_desc_info, const std::vector &compute_graph_desc_info); - - void Report(const size_t &idx, const string &data, Msprof::Engine::Reporter &reporter, + void Report(const int32_t &device_id, const string &data, Msprof::Engine::Reporter &reporter, Msprof::Engine::ReporterData &reporter_data); - void ProfilingTaskDescInfo(const std::vector &task_desc_info); - void ProfilingGraphDescInfo(const std::vector &compute_graph_desc_info); + void ProfilingTaskDescInfo(const std::vector &task_desc_info, const int32_t &device_id); + void ProfilingGraphDescInfo(const std::vector &compute_graph_desc_info, + const int32_t &device_id); void SetProfilingConfig(const string &profiling_cfg); vector GetProfilingDeviceId() const { return device_id_; } + void PluginUnInit(const std::string &module) const; private: bool is_profiling_ = false; @@ -70,35 +96,7 @@ class FMK_FUNC_HOST_VISIBILITY 
FMK_FUNC_DEV_VISIBILITY ProfilingManager { string recv_profiling_config_; string send_profiling_config_; string system_trace_conf_; -}; - -/** - * @brief register Plugin - */ -class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY PluginImpl : public Msprof::Engine::PluginIntf { - public: - explicit PluginImpl(const std::string &module); - ~PluginImpl() {} - - int Init(const Msprof::Engine::Reporter *reporter); - int UnInit(); - static Msprof::Engine::Reporter *GetPluginReporter() { return reporter_; } - - private: - static Msprof::Engine::Reporter *reporter_; - std::string module_; -}; - -/** - * @brief register Engine - */ -class ProfilingEngineImpl : public Msprof::Engine::EngineIntf { - public: - ProfilingEngineImpl() {} - ~ProfilingEngineImpl() {} - - Msprof::Engine::PluginIntf *CreatePlugin(); - int ReleasePlugin(Msprof::Engine::PluginIntf *plugin); + const ProfilingEngineImpl engine_; }; } // namespace ge #endif // GE_COMMON_PROFILING_PROFILING_MANAGER_H_ diff --git a/src/ge/common/properties_manager.cc b/src/ge/common/properties_manager.cc index e44fc4eb..7321af9f 100644 --- a/src/ge/common/properties_manager.cc +++ b/src/ge/common/properties_manager.cc @@ -59,7 +59,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool PropertiesManager::Init(co // Load file contents bool PropertiesManager::LoadFileContent(const std::string &file_path) { // Normalize the path - string resolved_file_path = domi::RealPath(file_path.c_str()); + string resolved_file_path = RealPath(file_path.c_str()); if (resolved_file_path.empty()) { DOMI_LOGE("Invalid input file path [%s], make sure that the file path is correct.", file_path.c_str()); return false; @@ -67,7 +67,7 @@ bool PropertiesManager::LoadFileContent(const std::string &file_path) { std::ifstream fs(resolved_file_path, std::ifstream::in); if (!fs.is_open()) { - GELOGW("Open %s failed.", file_path.c_str()); + GELOGE(PARAM_INVALID, "Open %s failed.", file_path.c_str()); return false; } @@ -75,7 +75,7 @@ bool 
PropertiesManager::LoadFileContent(const std::string &file_path) { while (getline(fs, line)) { // line not with \n if (!ParseLine(line)) { - GELOGW("Parse line failed. content is [%s].", line.c_str()); + GELOGE(PARAM_INVALID, "Parse line failed. content is [%s].", line.c_str()); fs.close(); return false; } @@ -96,16 +96,17 @@ bool PropertiesManager::ParseLine(const std::string &line) { } if (!temp.empty()) { - std::string::size_type pos = temp.find_first_of(delimiter); // Must be divided by "=" + std::string::size_type pos = temp.find_first_of(delimiter); if (pos == std::string::npos) { - GELOGW("Incorrect line [%s]", line.c_str()); + GELOGE(PARAM_INVALID, "Incorrect line [%s], it must include [%s].Perhaps you use illegal chinese symbol", + line.c_str(), delimiter.c_str()); return false; } std::string map_key = Trim(temp.substr(0, pos)); std::string value = Trim(temp.substr(pos + 1)); if (map_key.empty() || value.empty()) { - GELOGW("Map_key or value empty. %s", line.c_str()); + GELOGE(PARAM_INVALID, "Map_key or value empty. 
%s", line.c_str()); return false; } @@ -273,4 +274,13 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string PropertiesManager:: return this->dump_step_; } +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void PropertiesManager::SetDumpMode(const std::string &dump_mode) { + std::lock_guard lock(dump_mutex_); + this->dump_mode_ = dump_mode; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string PropertiesManager::GetDumpMode() { + std::lock_guard lock(dump_mutex_); + return this->dump_mode_; +} } // namespace ge diff --git a/src/ge/common/properties_manager.h b/src/ge/common/properties_manager.h index 100b83f0..eb43820c 100644 --- a/src/ge/common/properties_manager.h +++ b/src/ge/common/properties_manager.h @@ -94,6 +94,8 @@ class PropertiesManager { std::string GetDumpOutputPath(); void SetDumpStep(const std::string &dump_step); std::string GetDumpStep(); + void SetDumpMode(const std::string &dump_mode); + std::string GetDumpMode(); private: // Private construct, destructor @@ -120,6 +122,7 @@ class PropertiesManager { std::string output_mode_; std::string output_path_; std::string dump_step_; + std::string dump_mode_; std::map> model_dump_properties_map_; // model_dump_layers_map_ std::mutex dump_mutex_; }; diff --git a/src/ge/common/thread_pool.cc b/src/ge/common/thread_pool.cc index a52d4938..700892f2 100644 --- a/src/ge/common/thread_pool.cc +++ b/src/ge/common/thread_pool.cc @@ -62,9 +62,8 @@ void ThreadPool::ThreadFunc(ThreadPool *thread_pool) { std::function task; { std::unique_lock lock{thread_pool->m_lock_}; - thread_pool->cond_var_.wait(lock, [thread_pool] { - return thread_pool->is_stoped_.load() || !thread_pool->tasks_.empty(); - }); + thread_pool->cond_var_.wait( + lock, [thread_pool] { return thread_pool->is_stoped_.load() || !thread_pool->tasks_.empty(); }); if (thread_pool->is_stoped_ && thread_pool->tasks_.empty()) { return; } diff --git a/src/ge/common/types.cc b/src/ge/common/types.cc index 8b4e3ed4..26668c70 100644 --- 
a/src/ge/common/types.cc +++ b/src/ge/common/types.cc @@ -15,7 +15,6 @@ */ #include "framework/common/types.h" - #include "graph/types.h" namespace ge { @@ -25,16 +24,13 @@ const std::string DUMP_ALL_MODEL = "ALL_MODEL_NEED_DUMP_AND_IT_IS_NOT_A_MODEL_NA const std::string DUMP_STATUS = "status"; const std::string DUMP_LAYER = "layer"; const std::string DUMP_FILE_PATH = "path"; -} // namespace ge +const std::string DUMP_MODE = "dump_mode"; -namespace domi { const int DEFAULT_FORMAT = static_cast(ge::FORMAT_NCHW); -/** - * @brief Supported public property names - */ -const std::string PROP_OME_START_TIME = "ome_start_time"; /**< start time */ -const std::string PROP_OME_DUMP_PATH = "ome_dump_path"; /**< dump path */ -const std::string PROP_OME_LOG_PATH = "ome_log_path"; /**< log path */ +// Supported public property names +const std::string PROP_OME_START_TIME = "ome_start_time"; // start time +const std::string PROP_OME_DUMP_PATH = "ome_dump_path"; // dump path +const std::string PROP_OME_LOG_PATH = "ome_log_path"; // log path // Profile related constant const uint32_t CCE_PROFILE_ON = 0; @@ -354,6 +350,7 @@ REGISTER_OPTYPE_DEFINE(RESOURCEAPPLYMOMENTUM, "ResourceApplyMomentum"); REGISTER_OPTYPE_DEFINE(SGD, "SGD"); REGISTER_OPTYPE_DEFINE(NOOP, "NoOp"); REGISTER_OPTYPE_DEFINE(READVARIABLEOP, "ReadVariableOp"); +REGISTER_OPTYPE_DEFINE(PARALLELCONCATSTART, "_ParallelConcatStart"); REGISTER_OPTYPE_DEFINE(CONSTANTOP, "Constant"); REGISTER_OPTYPE_DEFINE(DEPTHWISECONV2DBACKPROPFILTER, "DepthwiseConv2dNativeBackpropFilter"); REGISTER_OPTYPE_DEFINE(DEPTHWISECONV2DBACKPORPINPUT, "DepthwiseConv2dNativeBackpropInput"); @@ -387,6 +384,7 @@ REGISTER_OPTYPE_DEFINE(STREAMSWITCH, "StreamSwitch"); REGISTER_OPTYPE_DEFINE(STREAMSWITCHN, "StreamSwitchN"); REGISTER_OPTYPE_DEFINE(STREAMACTIVE, "StreamActive"); REGISTER_OPTYPE_DEFINE(MEMCPYASYNC, "MemcpyAsync"); +REGISTER_OPTYPE_DEFINE(MEMCPYADDRASYNC, "MemcpyAddrAsync"); REGISTER_OPTYPE_DEFINE(STREAMMERGE, "StreamMerge"); 
REGISTER_OPTYPE_DEFINE(ENDGRAPH, "EndGraph"); REGISTER_OPTYPE_DEFINE(SEND, "Send"); @@ -394,6 +392,7 @@ REGISTER_OPTYPE_DEFINE(RECV, "Recv"); REGISTER_OPTYPE_DEFINE(LABELSET, "LabelSet"); REGISTER_OPTYPE_DEFINE(LABELGOTO, "LabelGoto"); +REGISTER_OPTYPE_DEFINE(LABELGOTOEX, "LabelGotoEx"); REGISTER_OPTYPE_DEFINE(LABELSWITCH, "LabelSwitch"); REGISTER_OPTYPE_DEFINE(LABELSWITCHBYINDEX, "LabelSwitchByIndex"); @@ -455,6 +454,8 @@ REGISTER_OPTYPE_DEFINE(DEPTHWISEWEIGHT6D24D, "depthwise_weight_6d_2_4d"); REGISTER_OPTYPE_DEFINE(SQRTGRAD, "SqrtGrad"); REGISTER_OPTYPE_DEFINE(SIGMOIDGRAD, "SigmoidGrad"); +REGISTER_OPTYPE_DEFINE(TRANSSHAPE, "TransShape"); + const std::string MODEL_ATTR_TASKS = "tasks"; const std::string MODEL_ATTR_TASK_GEN_BASE_ADDR = "task_gen_base_addr"; const std::string MODEL_ATTR_TASK_GEN_WEIGHT_ADDR = "task_gen_weight_addr"; @@ -469,315 +470,315 @@ const uint64_t ALLOC_MEMORY_MAX_SIZE = 8589934592; // Max size of 8 GB. const uint64_t ALLOC_MEMORY_MAX_SIZE = 536870912; // Max size of 512M. 
#endif -/** - * @brief Magic number of model file - */ +/// +///@brief Magic number of model file +/// const uint32_t MODEL_FILE_MAGIC_NUM = 0x444F4D49; // magic number -/** - * @brief Model head length - */ +/// +///@brief Model head length +/// const uint32_t MODEL_FILE_HEAD_LEN = 256; -/** - * @ingroup domi_omg - * @brief Input node type - */ +/// +///@ingroup domi_omg +///@brief Input node type +/// const std::string INPUT_TYPE = "Input"; -/** - * @ingroup domi_omg - * @brief AIPP label, label AIPP conv operator - */ +/// +///@ingroup domi_omg +///@brief AIPP label, label AIPP conv operator +/// const std::string AIPP_CONV_FLAG = "Aipp_Conv_Flag"; -/** - * @ingroup domi_omg - * @brief AIPP label, label aipp data operator - */ +/// +///@ingroup domi_omg +///@brief AIPP label, label aipp data operator +/// const std::string AIPP_DATA_FLAG = "Aipp_Data_Flag"; -/** - * @ingroup domi_omg - * @brief Record the w dimension of model input corresponding to dynamic AIPP - */ +/// +///@ingroup domi_omg +///@brief Record the w dimension of model input corresponding to dynamic AIPP +/// const std::string AIPP_RELATED_DATA_DIM_W = "aipp_related_data_dim_w"; -/** - * @ingroup domi_omg - * @brief Record the H dimension of model input corresponding to dynamic AIPP - */ +/// +///@ingroup domi_omg +///@brief Record the H dimension of model input corresponding to dynamic AIPP +/// const std::string AIPP_RELATED_DATA_DIM_H = "aipp_related_data_dim_h"; -/** - * @ingroup domi_omg - * @brief The tag of the data operator. Mark this input to the dynamic AIPP operator - */ +/// +///@ingroup domi_omg +///@brief The tag of the data operator. 
Mark this input to the dynamic AIPP operator +/// const std::string INPUT_TO_DYNAMIC_AIPP = "input_to_dynamic_aipp"; -/** - * @ingroup domi_omg - * @brief DATA node type - */ +/// +///@ingroup domi_omg +///@brief DATA node type +/// const std::string DATA_TYPE = "Data"; -/** - * @ingroup domi_omg - * @brief DATA node type - */ +/// +///@ingroup domi_omg +///@brief DATA node type +/// const std::string AIPP_DATA_TYPE = "AippData"; -/** - * @ingroup domi_omg - * @brief Frame operator type - */ +/// +///@ingroup domi_omg +///@brief Frame operator type +/// const std::string FRAMEWORK_OP_TYPE = "FrameworkOp"; -/** - * @ingroup domi_omg - * @brief Data node type - */ +/// +///@ingroup domi_omg +///@brief Data node type +/// const std::string ANN_DATA_TYPE = "AnnData"; const std::string ANN_NETOUTPUT_TYPE = "AnnNetOutput"; const std::string ANN_DEPTHCONV_TYPE = "AnnDepthConv"; const std::string ANN_CONV_TYPE = "AnnConvolution"; const std::string ANN_FC_TYPE = "AnnFullConnection"; -/** - * @ingroup domi_omg - * @brief Convolution node type - */ +/// +///@ingroup domi_omg +///@brief Convolution node type +/// const std::string NODE_NAME_NET_OUTPUT = "Node_Output"; const std::string NODE_NAME_END_GRAPH = "Node_EndGraph"; -/** - * @ingroup domi_omg - * @brief Convolution node type - */ +/// +///@ingroup domi_omg +///@brief Convolution node type +/// const std::string OP_TYPE_CONVOLUTION = "Convolution"; -/** - * @ingroup domi_omg - * @brief Add convolution node name to AIPP - */ +/// +///@ingroup domi_omg +///@brief Add convolution node name to AIPP +/// const std::string AIPP_CONV_OP_NAME = "aipp_conv_op"; -/** - * @ingroup domi_omg - * @brief Operator configuration item separator - */ +/// +///@ingroup domi_omg +///@brief Operator configuration item separator +/// const std::string OP_CONF_DELIMITER = ":"; -/** - * @ingroup domi_omg - * @brief attr value name - */ +/// +///@ingroup domi_omg +///@brief attr value name +/// const std::string ATTR_NAME_VALUE1 = "value1"; -/** 
- * @ingroup domi_omg - * @brief attr value name, 6d_2_4d C - */ +/// +///@ingroup domi_omg +///@brief attr value name, 6d_2_4d C +/// const std::string ATTR_NAME_INPUT_CVALUE = "input_cvalue"; -/** - * @ingroup domi_omg - * @brief alpha default value - */ +/// +///@ingroup domi_omg +///@brief alpha default value +/// const float ALPHA_DEFAULT_VALUE = 1.0; -/** - * @ingroup domi_omg - * @brief beta default value - */ +/// +///@ingroup domi_omg +///@brief beta default value +/// const float BETA_DEFAULT_VALUE = 0.0; -/** - * @ingroup domi_omg - * @brief coef default value - */ +/// +///@ingroup domi_omg +///@brief coef default value +/// const float COEF_DEFAULT_VALUE = 0.0; -/** - * @ingroup domi_omg - * @brief Relu6 coef value - */ +/// +///@ingroup domi_omg +///@brief Relu6 coef value +/// const float RELU6_COEF = 6.0; -/** - * @ingroup domi_omg - * @brief stride default value - */ +/// +///@ingroup domi_omg +///@brief stride default value +/// const uint32_t STRIDE_DEFAULT_VALUE = 1; -/** - * @ingroup domi_omg - * @brief pad default value - */ +/// +///@ingroup domi_omg +///@brief pad default value +/// const uint32_t PAD_DEFAULT_VALUE = 0; -/** - * @ingroup domi_omg - * @brief dilation default value - */ +/// +///@ingroup domi_omg +///@brief dilation default value +/// const int DILATION_DEFAULT_VALUE = 1; -/** - * @ingroup domi_omg - * @brief kernel default value - */ +/// +///@ingroup domi_omg +///@brief kernel default value +/// const uint32_t KERNEL_DEFAULT_VALUE = 0; -/** - * @ingroup domi_omg - * @brief defaule convolution group size - */ +/// +///@ingroup domi_omg +///@brief defaule convolution group size +/// const uint32_t DEFAULT_CONV_GROUP = 1; -/** - * @ingroup domi_omg - * @brief Default deconvolution adj - */ +/// +///@ingroup domi_omg +///@brief Default deconvolution adj +/// const uint32_t DEFAULT_DECONV_ADJ = 0; -/** - * @ingroup domi_omg - * @brief Represents value 1 - */ +/// +///@ingroup domi_omg +///@brief Represents value 1 +/// const 
uint32_t NUM_ONE = 1; -/** - * @ingroup domi_omg - * @brief spatial dim size default value - */ +/// +///@ingroup domi_omg +///@brief spatial dim size default value +/// const int32_t SPATIAL_DIM_DEFAULT_SIZE = 2; -/** - * @ingroup domi_omg - * @brief dim extended default value - */ +/// +///@ingroup domi_omg +///@brief dim extended default value +/// const int32_t DIM_DEFAULT_VALUE = 1; -/** - * @ingroup domi_omg - * @brief The first weight list in opdef is filter - */ +/// +///@ingroup domi_omg +///@brief The first weight list in opdef is filter +/// const int32_t WEIGHT_FILTER_INDEX = 0; -/** - * @ingroup domi_omg - * @brief The second weight list in opdef is bias - */ +/// +///@ingroup domi_omg +///@brief The second weight list in opdef is bias +/// const int32_t WEIGHT_BIAS_INDEX = 1; const int32_t TENSOR_ND_SUPPORT_SIZE = 8; -/** - * @ingroup domi_omg - * @brief NCHW index default value - */ +/// +///@ingroup domi_omg +///@brief NCHW index default value +/// const uint32_t NCHW_DIM_N = 0; const uint32_t NCHW_DIM_C = 1; const uint32_t NCHW_DIM_H = 2; const uint32_t NCHW_DIM_W = 3; -/** - * @ingroup domi_omg - * @brief KCHW index default value - */ +/// +///@ingroup domi_omg +///@brief KCHW index default value +/// const uint32_t KCHW_DIM_K = 0; const uint32_t KCHW_DIM_C = 1; const uint32_t KCHW_DIM_H = 2; const uint32_t KCHW_DIM_W = 3; -/** - * @ingroup domi_omg - * @brief HWCK index default value - */ +/// +///@ingroup domi_omg +///@brief HWCK index default value +/// const uint32_t HWCK_DIM_H = 0; const uint32_t HWCK_DIM_W = 1; const uint32_t HWCK_DIM_C = 2; const uint32_t HWCK_DIM_K = 3; -/** - * @ingroup domi_omg - * @brief NHWC index default value - */ +/// +///@ingroup domi_omg +///@brief NHWC index default value +/// const uint32_t NHWC_DIM_N = 0; const uint32_t NHWC_DIM_H = 1; const uint32_t NHWC_DIM_W = 2; const uint32_t NHWC_DIM_C = 3; -/** - * @ingroup domi_omg - * @brief CHWN index default value - */ +/// +///@ingroup domi_omg +///@brief CHWN index 
default value +/// const uint32_t CHWN_DIM_N = 3; const uint32_t CHWN_DIM_C = 0; const uint32_t CHWN_DIM_H = 1; const uint32_t CHWN_DIM_W = 2; -/** - * @ingroup domi_omg - * @brief CHW index default value - */ +/// +///@ingroup domi_omg +///@brief CHW index default value +/// const uint32_t CHW_DIM_C = 0; const uint32_t CHW_DIM_H = 1; const uint32_t CHW_DIM_W = 2; -/** - * @ingroup domi_omg - * @brief HWC index default value - */ +/// +///@ingroup domi_omg +///@brief HWC index default value +/// const uint32_t HWC_DIM_H = 0; const uint32_t HWC_DIM_W = 1; const uint32_t HWC_DIM_C = 2; -/** - * @ingroup domi_omg - * @brief Pad index default value - */ +/// +///@ingroup domi_omg +///@brief Pad index default value +/// const uint32_t PAD_H_HEAD = 0; const uint32_t PAD_H_TAIL = 1; const uint32_t PAD_W_HEAD = 2; const uint32_t PAD_W_TAIL = 3; -/** - * @ingroup domi_omg - * @brief window index default value - */ +/// +///@ingroup domi_omg +///@brief window index default value +/// const uint32_t WINDOW_H = 0; const uint32_t WINDOW_W = 1; -/** - * @ingroup domi_omg - * @brief stride index default value - */ +/// +///@ingroup domi_omg +///@brief stride index default value +/// const uint32_t STRIDE_H = 0; const uint32_t STRIDE_W = 1; -/** - * @ingroup domi_omg - * @brief dilation index default value - */ +/// +///@ingroup domi_omg +///@brief dilation index default value +/// const uint32_t DILATION_H = 0; const uint32_t DILATION_W = 1; -/** - * @ingroup domi_omg - * @brief the num of XRBG channel - */ +/// +///@ingroup domi_omg +///@brief the num of XRBG channel +/// const uint32_t XRGB_CHN_NUM = 4; -/** - * @ingroup domi_omg - * @brief global pooling default value - */ +/// +///@ingroup domi_omg +///@brief global pooling default value +/// const bool DEFAULT_GLOBAL_POOLING = false; -const uint32_t MODEL_VERSION = 0x10000000; /**< Model version 1.0 */ +const uint32_t MODEL_VERSION = 0x10000000; ///< Model version 1.0/// // Eltwise's input size const int 
ELTWISE_MIN_INPUT_SIZE = 2; -/* flowctrl */ +// flowctrl const std::string NODE_NAME_STREAM_SWITCH = "IteratorCtrl_StreamSwitch"; const std::string NODE_NAME_STREAM_ACTIVE = "IteratorCtrl_StreamActive"; const std::string NODE_NAME_FLOWCTRL_LOOP_PER_ITER = "npu_runconfig/iterations_per_loop"; @@ -792,4 +793,4 @@ const uint32_t STREAM_SWITCH_INPUT_NUM = 2; const std::string NODE_NAME_GLOBAL_STEP = "ge_global_step"; const std::string NODE_NAME_GLOBAL_STEP_ASSIGNADD = "global_step_assignadd"; -}; // namespace domi +}; // namespace ge diff --git a/src/ge/common/util.cc b/src/ge/common/util.cc index 44a8586d..b53a1c43 100644 --- a/src/ge/common/util.cc +++ b/src/ge/common/util.cc @@ -27,12 +27,13 @@ #include #include -#include "google/protobuf/io/coded_stream.h" -#include "google/protobuf/io/zero_copy_stream_impl.h" -#include "framework/common/fmk_types.h" +#include "external/ge/ge_api_error_codes.h" +#include "common/util/error_manager/error_manager.h" #include "framework/common/debug/ge_log.h" +#include "framework/common/fmk_types.h" #include "framework/common/ge_inner_error_codes.h" -#include "external/ge/ge_api_error_codes.h" +#include "google/protobuf/io/coded_stream.h" +#include "google/protobuf/io/zero_copy_stream_impl.h" #include "mmpa/mmpa_api.h" using google::protobuf::io::CodedInputStream; @@ -57,7 +58,7 @@ const int kWarningThreshold = 536870912 * 2; // 536870912 represent 512M const int kMaxFileSizeLimit = INT_MAX; } // namespace -namespace domi { +namespace ge { static bool ReadProtoFromCodedInputStream(CodedInputStream &coded_stream, Message *proto) { GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(proto == nullptr, return false, "incorrect parameter. 
nullptr == proto"); @@ -111,13 +112,18 @@ long GetFileLength(const std::string &input_file) { GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), return -1, "input_file path '%s' not valid", input_file.c_str()); unsigned long long file_length = 0; - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(mmGetFileSize(input_file.c_str(), &file_length) != EN_OK, return -1, - "open file failed."); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(mmGetFileSize(input_file.c_str(), &file_length) != EN_OK, + ErrorManager::GetInstance().ATCReportErrMessage("E10037"); + return -1, "open file failed."); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file_length == 0), return -1, "file length == 0, not valid."); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file_length == 0), ErrorManager::GetInstance().ATCReportErrMessage("E10038"); + return -1, "file length is 0, not valid."); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(file_length > kMaxFileSizeLimit, return -1, "file size %lld is out of limit: %d.", - file_length, kMaxFileSizeLimit); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + file_length > kMaxFileSizeLimit, + ErrorManager::GetInstance().ATCReportErrMessage("E10039", {"filesize", "maxlen"}, + {std::to_string(file_length), std::to_string(kMaxFileSizeLimit)}); + return -1, "file size %lld is out of limit: %d.", file_length, kMaxFileSizeLimit); return static_cast(file_length); } @@ -196,7 +202,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int CreateDirectory(const std:: GE_CHK_BOOL_EXEC(!directory_path.empty(), return -1, "directory path is empty."); auto dir_path_len = directory_path.length(); if (dir_path_len >= PATH_MAX) { - GELOGE(ge::FAILED, "Directory path is too long."); + GELOGW("Directory path is too long."); return -1; } char tmp_dir_path[PATH_MAX] = {0}; @@ -207,7 +213,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int CreateDirectory(const std:: int32_t ret = mmMkdir(tmp_dir_path, S_IRUSR | S_IWUSR | S_IXUSR); // 700 if (ret != 0) { if (errno != EEXIST) { - GELOGE(ge::FAILED, "Cannot create directory %s. 
Make sure that the directory exists and writable.", + GELOGW("Cannot create directory %s. Make sure that the directory exists and writable.", directory_path.c_str()); return ret; } @@ -218,8 +224,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int CreateDirectory(const std:: int32_t ret = mmMkdir(const_cast(directory_path.c_str()), S_IRUSR | S_IWUSR | S_IXUSR); // 700 if (ret != 0) { if (errno != EEXIST) { - GELOGE(ge::FAILED, "Cannot create directory %s. Make sure that the directory exists and writable.", - directory_path.c_str()); + GELOGW("Cannot create directory %s. Make sure that the directory exists and writable.", directory_path.c_str()); return ret; } } @@ -247,21 +252,27 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromText(const ch "incorrect parameter. nullptr == file || nullptr == message"); std::string real_path = RealPath(file); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), return false, "proto file path '%s' not valid", file); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), + ErrorManager::GetInstance().ATCReportErrMessage("E10036", {"realpath"}, {file}); + return false, "proto file real path '%s' not valid", file); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetFileLength(real_path) == -1, return false, "file size not valid."); std::ifstream fs(real_path.c_str(), std::ifstream::in); if (!fs.is_open()) { + ErrorManager::GetInstance().ATCReportErrMessage("E10040", {"protofile"}, {file}); GELOGE(ge::FAILED, "Fail to open proto file '%s'.", file); return false; } google::protobuf::io::IstreamInputStream input(&fs); bool ret = google::protobuf::TextFormat::Parse(&input, message); - GE_IF_BOOL_EXEC( - !ret, GELOGE(ret, "Call [google::protobuf::TextFormat::Parse] func ret fail, please check your text file.")); + GE_IF_BOOL_EXEC(!ret, ErrorManager::GetInstance().ATCReportErrMessage("E10041", {"protofile"}, {file}); + GELOGE(ret, + "Parse file[%s] through [google::protobuf::TextFormat::Parse] failed, " + "please check whether the file 
is a valid protobuf format file.", + file)); fs.close(); return ret; @@ -336,10 +347,20 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string RealPath(const char return res; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInputPathValid(const std::string &file_path) { +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInputPathValid(const std::string &file_path, + const std::string &atc_param) { // The specified path is empty + std::map args_map; if (file_path.empty()) { - GELOGE(ge::FAILED, "Path is empty."); + ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {atc_param}); + GELOGW("Input parameter's value is empty."); + return false; + } + std::string real_path = RealPath(file_path.c_str()); + // Unable to get absolute path (does not exist or does not have permission to access) + if (real_path.empty()) { + ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"path", "errmsg"}, {file_path.c_str(), strerror(errno)}); + GELOGW("Can not get real path for %s, %s", file_path.c_str(), strerror(errno)); return false; } @@ -350,52 +371,54 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInputPathValid(const std::string mode = "^(/+|./+|(../+)+|)(../|([.]?[\u4e00-\u9fa5A-Za-z0-9_.-]+)/+)*[\u4e00-\u9fa5A-Za-z0-9_+.-]+$"; GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( - !ValidateStr(file_path, mode), return false, - "input [%s] is illegal. path can only contains 'a-z' 'A-Z' '0-9' '-' '.' '_' and chinese; filename can " - "only contains 'a-z' 'A-Z' '0-9' '_' '.' 
'+' '-' and chinese", - file_path.c_str()); - - std::string real_path = RealPath(file_path.c_str()); - // Unable to get absolute path (does not exist or does not have permission to access) - if (real_path.empty()) { - GELOGE(ge::FAILED, "Can not get real path for %s, %s", file_path.c_str(), strerror(errno)); - return false; - } + !ValidateStr(real_path, mode), + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "path"}, {atc_param, real_path}); + return false, + "Input parameter's value[%s] is illegal. The path[%s] can only contains 'a-z' 'A-Z' '0-9' '-' '.' '_' " + "and chinese character.", + atc_param.c_str(), real_path.c_str()); // The absolute path points to a file that is not readable if (access(real_path.c_str(), R_OK) != 0) { - GELOGE(ge::FAILED, "Can not read file in %s, %s", file_path.c_str(), strerror(errno)); + ErrorManager::GetInstance().ATCReportErrMessage("E10003", {"path", "errmsg"}, {file_path.c_str(), strerror(errno)}); + GELOGW("Read path[%s] failed, %s", file_path.c_str(), strerror(errno)); return false; } return true; } -FMK_FUNC_HOST_VISIBILITY bool CheckOutputPathValid(const std::string &file_path) { +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckOutputPathValid(const std::string &file_path, + const std::string &atc_param) { // The specified path is empty if (file_path.empty()) { - GELOGE(ge::FAILED, "Path is empty."); + ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {atc_param}); + GELOGW("Input parameter's value is empty."); return false; } - // A regular matching expression to verify the validity of the input file path - // ^(/|./|(../)+|)([.]?[\u4e00-\u9fa5A-Za-z0-9_-]+/)*[\u4e00-\u9fa5A-Za-z0-9_+.-]+$ - // Path section:Support upper and lower case letters, numbers dots(.) chinese and underscores - // File name section:Support upper and lower case letters, numbers, underscores chinese and dots(.) 
- std::string mode = "^(/+|./+|(../+)+|)(../|([.]?[\u4e00-\u9fa5A-Za-z0-9_.-]+)/+)*[\u4e00-\u9fa5A-Za-z0-9_+.-]+$"; - - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( - !ValidateStr(file_path, mode), return false, - "output [%s] is illegal. path can only contains 'a-z' 'A-Z' '0-9' '-' '.' '_' and chinese; filename can " - "only contains 'a-z' 'A-Z' '0-9' '_' '.' '+' '-' and chinese", - file_path.c_str()); - std::string real_path = RealPath(file_path.c_str()); // Can get absolute path (file exists) if (!real_path.empty()) { + // A regular matching expression to verify the validity of the input file path + // ^(/|./|(../)+|)([.]?[\u4e00-\u9fa5A-Za-z0-9_-]+/)*[\u4e00-\u9fa5A-Za-z0-9_+.-]+$ + // Path section:Support upper and lower case letters, numbers dots(.) chinese and underscores + // File name section:Support upper and lower case letters, numbers, underscores chinese and dots(.) + std::string mode = "^(/+|./+|(../+)+|)(../|([.]?[\u4e00-\u9fa5A-Za-z0-9_.-]+)/+)*[\u4e00-\u9fa5A-Za-z0-9_+.-]+$"; + + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + !ValidateStr(real_path, mode), + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "path"}, {atc_param, real_path}); + return false, + "Input parameter's value[%s] is illegal. The path[%s] can only contains 'a-z' 'A-Z' '0-9' '-' '.' 
'_' " + "and chinese character.", + atc_param.c_str(), real_path.c_str()); + // File is not readable or writable - if (access(real_path.c_str(), R_OK | W_OK | F_OK) != 0) { - GELOGE(ge::FAILED, "Path[ %s ] exists, but can not be write, %s", file_path.c_str(), strerror(errno)); + if (access(real_path.c_str(), W_OK | F_OK) != 0) { + ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"path", "errmsg"}, + {real_path.c_str(), strerror(errno)}); + GELOGW("Write file failed, path[%s], %s", real_path.c_str(), strerror(errno)); return false; } } else { @@ -413,7 +436,8 @@ FMK_FUNC_HOST_VISIBILITY bool CheckOutputPathValid(const std::string &file_path) std::string prefix_path = std::string(file_path).substr(0, static_cast(path_split_pos)); // Determine whether the specified path is valid by creating the path if (CreateDirectory(prefix_path) != 0) { - GELOGE(ge::FAILED, "Can not create prefix path for path[ %s ].", file_path.c_str()); + ErrorManager::GetInstance().ATCReportErrMessage("E10005", {"path"}, {file_path}); + GELOGW("Can not create prefix path for path[%s].", file_path.c_str()); return false; } } @@ -436,4 +460,4 @@ FMK_FUNC_HOST_VISIBILITY bool ValidateStr(const std::string &str, const std::str return true; #endif } -} // namespace domi +} // namespace ge diff --git a/src/ge/executor/CMakeLists.txt b/src/ge/executor/CMakeLists.txt index 221b1045..8512904c 100755 --- a/src/ge/executor/CMakeLists.txt +++ b/src/ge/executor/CMakeLists.txt @@ -36,7 +36,6 @@ file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "../graph/load/new_model_manager/davinci_model.cc" "../graph/load/new_model_manager/davinci_model_parser.cc" "../graph/load/new_model_manager/model_manager.cc" - "../graph/load/new_model_manager/model_output.cc" "../graph/load/new_model_manager/model_utils.cc" "../graph/load/new_model_manager/task_info/end_graph_task_info.cc" "../graph/load/new_model_manager/task_info/event_record_task_info.cc" @@ -45,8 +44,10 @@ file(GLOB SRC_LIST RELATIVE 
${CMAKE_CURRENT_LIST_DIR} "../graph/load/new_model_manager/task_info/fusion_stop_task_info.cc" "../graph/load/new_model_manager/task_info/kernel_ex_task_info.cc" "../graph/load/new_model_manager/task_info/kernel_task_info.cc" - "../graph/load/new_model_manager/task_info/label_goto_task_info.cc" + "../graph/load/new_model_manager/task_info/label_goto_ex_task_info.cc" "../graph/load/new_model_manager/task_info/label_set_task_info.cc" + "../graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc" + "../graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc" "../graph/load/new_model_manager/task_info/memcpy_async_task_info.cc" "../graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" "../graph/load/new_model_manager/task_info/stream_active_task_info.cc" @@ -56,6 +57,7 @@ file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "../graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc" "../graph/load/new_model_manager/task_info/task_info.cc" "../graph/load/new_model_manager/tbe_handle_store.cc" + "../graph/load/new_model_manager/zero_copy_task.cc" "../graph/load/output/output.cc" "../graph/manager/graph_manager_utils.cc" "../graph/manager/graph_mem_allocator.cc" @@ -68,6 +70,7 @@ file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "../single_op/single_op_manager.cc" "../single_op/single_op_model.cc" "../single_op/stream_resource.cc" + "../single_op/task/aicpu_task_builder.cc" "../single_op/task/build_task_utils.cc" "../single_op/task/op_task.cc" "../single_op/task/tbe_task_builder.cc" diff --git a/src/ge/executor/ge_executor.cc b/src/ge/executor/ge_executor.cc index 120187cc..92529598 100644 --- a/src/ge/executor/ge_executor.cc +++ b/src/ge/executor/ge_executor.cc @@ -165,12 +165,13 @@ namespace ge { bool GeExecutor::isInit_ = false; class ModelListenerAdapter : public ModelListener { public: - domi::Status OnComputeDone(uint32_t model_id, uint32_t dataIndex, uint32_t resultCode) { + domi::Status 
OnComputeDone(uint32_t model_id, uint32_t dataIndex, uint32_t resultCode, + std::vector &outputs) { if (listener == nullptr) { GELOGE(ge::FAILED, "listener is null."); return FAILED; } - return listener->OnComputeDone(model_id, dataIndex, resultCode); + return listener->OnComputeDone(model_id, dataIndex, resultCode, outputs); } std::shared_ptr listener; @@ -193,15 +194,8 @@ Status GeExecutor::Initialize() { } // Start profiling - int32_t device_id = 0; - rtError_t rt_ret = rtGetDevice(&device_id); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(rt_ret, "runtime get device_id failed, current device_id:%d", device_id); - return FAILED; - } - GELOGI("current device_id:%d", device_id); Options profiling_options; - profiling_options.device_id = device_id; + profiling_options.device_id = 0; profiling_options.job_id = ""; ProfilingManager::Instance().Init(profiling_options); @@ -218,7 +212,11 @@ Status GeExecutor::Finalize() { } // Stop profiling - ProfilingManager::Instance().StopProfiling(); + if (ProfilingManager::Instance().ProfilingOn()) { + ProfilingManager::Instance().StopProfiling(); + ProfilingManager::Instance().PluginUnInit(GE_PROFILING_MODULE); + } + GELOGI("Uninit ge_executor over."); return ge::SUCCESS; } @@ -352,7 +350,7 @@ Status GeExecutor::LoadModelOffline(uint32_t &model_id, const std::string &path, return GE_EXEC_NOT_INIT; } - string filePath = domi::RealPath(path.c_str()); + string filePath = RealPath(path.c_str()); if (filePath.empty()) { GELOGE(ge::FAILED, "fileath is invalid. 
please check your text file '%s'.", path.c_str()); return ge::FAILED; @@ -402,10 +400,10 @@ Status GeExecutor::UnloadModel(uint32_t model_id) { GELOGE(GE_EXEC_NOT_INIT, "not inited yet!"); return GE_EXEC_NOT_INIT; } - - // stop profiling - if (!ProfilingManager::Instance().ProfilingOpTraceOn() && ProfilingManager::Instance().ProfilingOn()) { - ProfilingManager::Instance().StopProfiling(); + Status ret = GraphLoader::DestroyAicpuSessionForInfer(model_id); + if (ret != SUCCESS) { + GELOGE(ret, "[GraphLoader] DestroyAicpuSessionForInfer failed."); + return FAILED; } return GraphLoader::UnloadModel(model_id); } @@ -565,7 +563,7 @@ Status GeExecutor::LoadDataFromFile(const std::string &path, ModelData &model_da return GE_EXEC_NOT_INIT; } - string filePath = domi::RealPath(path.c_str()); + string filePath = RealPath(path.c_str()); if (filePath.empty()) { GELOGE(ge::FAILED, "filePath is invalid. please check your text file '%s'.", path.c_str()); return ge::FAILED; diff --git a/src/ge/ge_local_engine/CMakeLists.txt b/src/ge/ge_local_engine/CMakeLists.txt index 80a3c335..e685c301 100755 --- a/src/ge/ge_local_engine/CMakeLists.txt +++ b/src/ge/ge_local_engine/CMakeLists.txt @@ -15,13 +15,14 @@ # libge_local_engine.so # add all proto files, generate corresponding .h and .cc files -file(GLOB_RECURSE PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} +file(GLOB PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "../../proto/task.proto" ) -file(GLOB_RECURSE SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} +file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "engine/ge_local_engine.cc" "ops_kernel_store/*.cc" + "ops_kernel_store/op/*.cc" ) ge_protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) @@ -34,6 +35,7 @@ include_directories(${GE_SOURCE_DIR}/inc/external) include_directories(${GE_SOURCE_DIR}/inc/external/graph) include_directories(${GE_SOURCE_DIR}/inc/framework) include_directories(${GE_SOURCE_DIR}/inc/graph) +include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib) 
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) include_directories(${CMAKE_BINARY_DIR}) include_directories(${CMAKE_BINARY_DIR}/proto/ge) diff --git a/src/ge/ge_local_engine/engine/host_cpu_engine.cc b/src/ge/ge_local_engine/engine/host_cpu_engine.cc index c4fe9ea1..9ee616ac 100644 --- a/src/ge/ge_local_engine/engine/host_cpu_engine.cc +++ b/src/ge/ge_local_engine/engine/host_cpu_engine.cc @@ -237,7 +237,7 @@ Status HostCpuEngine::LoadLib(const std::string &lib_path) { } Status HostCpuEngine::GetRealPath(std::string &path) { - std::string real_path = domi::RealPath(path.c_str()); + std::string real_path = RealPath(path.c_str()); if (real_path.empty()) { GELOGW("File path %s is invalid.", path.c_str()); return INTERNAL_ERROR; diff --git a/src/ge/ge_local_engine/engine/host_cpu_engine.h b/src/ge/ge_local_engine/engine/host_cpu_engine.h index 88985f87..1987138d 100644 --- a/src/ge/ge_local_engine/engine/host_cpu_engine.h +++ b/src/ge/ge_local_engine/engine/host_cpu_engine.h @@ -21,7 +21,7 @@ #include "framework/common/ge_inner_error_codes.h" #include "graph/node.h" #include "graph/operator.h" -#include "register/register.h" +#include "inc/register/register.h" namespace ge { class HostCpuEngine { diff --git a/src/ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.cc b/src/ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.cc index 4eae65c5..cde6640f 100644 --- a/src/ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.cc +++ b/src/ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.cc @@ -17,8 +17,6 @@ #include "ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.h" #include #include "common/constant/constant.h" -#include "framework/common/debug/ge_log.h" -#include "common/ge_inner_error_codes.h" #include "common/ge/ge_util.h" #include "common/ge_inner_error_codes.h" #include "framework/common/debug/ge_log.h" diff --git a/src/ge/ge_local_engine/ops_kernel_store/op/ge_deleted_op.cc 
b/src/ge/ge_local_engine/ops_kernel_store/op/ge_deleted_op.cc index 6a327bb8..0f33ae2a 100644 --- a/src/ge/ge_local_engine/ops_kernel_store/op/ge_deleted_op.cc +++ b/src/ge/ge_local_engine/ops_kernel_store/op/ge_deleted_op.cc @@ -56,10 +56,10 @@ REGISTER_OP_CREATOR(IsVariableInitialized, GeDeletedOp); REGISTER_OP_CREATOR(PlaceholderV2, GeDeletedOp); REGISTER_OP_CREATOR(Placeholder, GeDeletedOp); REGISTER_OP_CREATOR(End, GeDeletedOp); -REGISTER_OP_CREATOR(Merge, GeDeletedOp); REGISTER_OP_CREATOR(Switch, GeDeletedOp); REGISTER_OP_CREATOR(SwitchN, GeDeletedOp); REGISTER_OP_CREATOR(RefMerge, GeDeletedOp); REGISTER_OP_CREATOR(RefSwitch, GeDeletedOp); +REGISTER_OP_CREATOR(TransShape, GeDeletedOp); } // namespace ge_local } // namespace ge diff --git a/src/ge/ge_local_engine/ops_kernel_store/op/no_op.cc b/src/ge/ge_local_engine/ops_kernel_store/op/no_op.cc index 58777e53..d595be8d 100644 --- a/src/ge/ge_local_engine/ops_kernel_store/op/no_op.cc +++ b/src/ge/ge_local_engine/ops_kernel_store/op/no_op.cc @@ -41,10 +41,10 @@ REGISTER_OP_CREATOR(Constant, NoOp); REGISTER_OP_CREATOR(Const, NoOp); -REGISTER_OP_CREATOR(NetOutput, NoOp); - REGISTER_OP_CREATOR(ControlTrigger, NoOp); +REGISTER_OP_CREATOR(Merge, NoOp); + // Functional Op. 
REGISTER_OP_CREATOR(If, NoOp); REGISTER_OP_CREATOR(_If, NoOp); diff --git a/src/ge/ge_runtime/runtime_model.cc b/src/ge/ge_runtime/runtime_model.cc index b60df61d..329c6683 100644 --- a/src/ge/ge_runtime/runtime_model.cc +++ b/src/ge/ge_runtime/runtime_model.cc @@ -18,10 +18,10 @@ #include #include "./model_context.h" #include "./task/task.h" -#include "framework/common/debug/ge_log.h" #include "common/ge_inner_error_codes.h" #include "common/types.h" #include "common/util.h" +#include "framework/common/debug/ge_log.h" #include "framework/common/op/op_parser_util.h" #include "graph/types.h" #include "task/task_factory.h" @@ -202,7 +202,8 @@ bool RuntimeModel::LoadTask() { } uint32_t task_id = 0; - rtError_t rt_ret = rtModelGetTaskId(rt_model_handle_, &task_id); + uint32_t stream_id = 0; + rtError_t rt_ret = rtModelGetTaskId(rt_model_handle_, &task_id, &stream_id); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X.", rt_ret); return false; diff --git a/src/ge/generator/ge_generator.cc b/src/ge/generator/ge_generator.cc index 3cc067c5..d4a33eec 100644 --- a/src/ge/generator/ge_generator.cc +++ b/src/ge/generator/ge_generator.cc @@ -28,12 +28,6 @@ #include "graph/utils/graph_utils.h" #include "model/ge_model.h" -using domi::DATA; -using domi::ModelHelper; -using domi::NETOUTPUT; -using domi::NODE_NAME_NET_OUTPUT; -using domi::SaveParam; -using ge::ModelBufferData; using std::map; using std::string; using std::vector; @@ -106,7 +100,7 @@ static void GetOpsProtoPath(string &opsproto_path) { const char *path_env = std::getenv("ASCEND_OPP_PATH"); if (path_env != nullptr) { string path = path_env; - string file_path = domi::RealPath(path.c_str()); + string file_path = RealPath(path.c_str()); if (file_path.empty()) { GELOGE(FAILED, "File path %s is invalid.", path.c_str()); return; @@ -132,6 +126,8 @@ class GeGenerator::Impl { Status SaveParams(GeModelPtr &ge_model, const string &type, const map &attrs, const vector &inputs, const vector 
&outputs); + Status GenerateInfershapeGraph(const Graph &graph, GraphId &graph_id); + GraphManager graph_manager_; SaveParam save_param_; bool is_offline_ = true; @@ -148,7 +144,7 @@ Status GeGenerator::Initialize(const map &options) { GELOGI("opsproto_path is %s", opsproto_path.c_str()); OpsProtoManager *manager = OpsProtoManager::Instance(); map option_tmp; - option_tmp.insert(std::pair(string("ge.opsProtoLibPath"), opsproto_path)); + option_tmp.emplace(std::pair(string("ge.opsProtoLibPath"), opsproto_path)); (void)manager->Initialize(option_tmp); Status ret = impl_->graph_manager_.Initialize(options); @@ -200,6 +196,22 @@ Status GeGenerator::GenerateOnlineModel(const Graph &graph, const vectorGenerateInfershapeGraph(graph, graph_id); + if (ret != SUCCESS) { + GELOGE(ret, "Dump infershape json failed"); + if (impl_->graph_manager_.Finalize() != SUCCESS) { + GELOGE(FAILED, "graph_manager finalize fail."); + } + return ret; + } + GELOGI("GenerateInfershapeJson success."); + return SUCCESS; +} + Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_prefix, const vector &inputs, ModelBufferData &model, bool is_offline) { GraphId graph_id; @@ -260,10 +272,11 @@ Status GeGenerator::BuildSingleOpModel(OpDescPtr &op_desc, const vector op_attrs = op_desc->GetAllAttrs(); + OpDescPtr op_desc_tmp = AttrUtils::CloneOpDesc(op_desc); + GE_CHECK_NOTNULL(op_desc_tmp); // 1. Create ComputeGraph. - string name = domi::CurrentTimeInStr() + "_" + model_file_name; + string name = ge::CurrentTimeInStr() + "_" + model_file_name; ge::ComputeGraphPtr compute_graph = MakeShared(name); if (compute_graph == nullptr) { return INTERNAL_ERROR; @@ -273,11 +286,6 @@ Status GeGenerator::BuildSingleOpModel(OpDescPtr &op_desc, const vectorAddNode(op_desc); GE_CHECK_NOTNULL_EXEC(op_node, return INTERNAL_ERROR); - - // 3. Create InputData node. 
- int64_t in_size = static_cast(op_desc->GetInputsSize()); - GE_CHK_BOOL_EXEC(AttrUtils::SetInt(op_desc, ATTR_NAME_N, in_size), return FAILED, "Op[%s] Set N fail", - op_desc->GetName().c_str()); int32_t arg_index = 0; if (inputs.empty()) { for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) { @@ -301,7 +309,7 @@ Status GeGenerator::BuildSingleOpModel(OpDescPtr &op_desc, const vectorDump(); Graph graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph); - GELOGI("ATC parser success."); + GELOGI("ATC parser success in single op schedule."); GraphId graph_id; vector ge_models; @@ -309,7 +317,9 @@ Status GeGenerator::BuildSingleOpModel(OpDescPtr &op_desc, const vectorBuildModel(graph, inputs, graph_id, ge_models)); if (!ge_models.empty()) { - GE_CHK_STATUS_RET_NOLOG(impl_->SaveParams(ge_models[0], op_desc->GetType(), op_attrs, inputs, outputs)); + map op_attrs = op_desc_tmp->GetAllAttrs(); + GELOGI("The opType in op_desc_tmp is: %s", op_desc_tmp->GetType().c_str()); + GE_CHK_STATUS_RET_NOLOG(impl_->SaveParams(ge_models[0], op_desc_tmp->GetType(), op_attrs, inputs, outputs)); } ModelBufferData model_buff; GE_CHK_STATUS_RET_NOLOG(impl_->SaveModel(model_file_name, ge_models, model_buff)); @@ -367,4 +377,27 @@ Status GeGenerator::Impl::BuildModel(const Graph &graph, const vector return SUCCESS; } + +Status GeGenerator::Impl::GenerateInfershapeGraph(const Graph &graph, GraphId &graph_id) { + static GraphId id = 0; + const std::map options; + Status ret = graph_manager_.AddGraph(id, graph, options); + if (ret != SUCCESS) { + GELOGE(GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED, "graphManager AddGraph failed, id: %u", id); + graph_manager_.Finalize(); + return GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED; + } + + ret = graph_manager_.GenerateInfershapeGraph(id); + if (ret != SUCCESS) { + GELOGE(GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED, "graphManager BuildGraph failed, id: %u", id); + return GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED; + } + + graph_id 
= id; + id += 1; + + return SUCCESS; +} + } // namespace ge diff --git a/src/ge/generator/generator_api.cc b/src/ge/generator/generator_api.cc index 094baab8..3f92f1a2 100644 --- a/src/ge/generator/generator_api.cc +++ b/src/ge/generator/generator_api.cc @@ -116,7 +116,7 @@ Status_t OpTaskGernerator(const char *op_type, const OpTensor_t *in_tensor, int CHECK_PARAM_NOT_NULL(om_file); const std::string om_file_name(om_file); - std::string op_name = std::string(op_type) + "_" + std::to_string(domi::GetCurrentTimestap()); + std::string op_name = std::string(op_type) + "_" + std::to_string(ge::GetCurrentTimestap()); ge::OpDescPtr op_desc = ge::MakeShared(op_name, op_type); if (op_desc == nullptr) { return ge::FAILED; diff --git a/src/ge/graph/build/graph_builder.cc b/src/ge/graph/build/graph_builder.cc index de222c8c..8e556ff2 100644 --- a/src/ge/graph/build/graph_builder.cc +++ b/src/ge/graph/build/graph_builder.cc @@ -18,18 +18,15 @@ #include "common/ge/ge_util.h" #include "common/helper/model_helper.h" #include "common/opskernel/ops_kernel_info_types.h" -#include "graph/build/stream_graph_optimizer.h" #include "graph/build/run_context.h" +#include "graph/build/stream_graph_optimizer.h" #include "graph/manager/graph_var_manager.h" #include "graph/utils/node_utils.h" #include "graph/utils/type_utils.h" #include "init/gelib.h" #include "model/ge_model.h" -using domi::ATTR_MODEL_MEMORY_SIZE; -using domi::ATTR_MODEL_WEIGHT_SIZE; using domi::BuildMode; -using domi::DATA; namespace { const int32_t kInvalidPerfLevel = -1; @@ -101,8 +98,10 @@ Status GraphBuilder::Build(ComputeGraphPtr &comp_graph, std::vectorGetAllSubgraphs()) { - GraphUtils::DumpGEGraphToOnnx(*graph, "SubgraphGetTask"); - } - GE_TIMESTAMP_START(GetTaskInfo); - ret = GetTaskInfo(builder, model_ptr, comp_graph, subgraph_ptr_list, session_id); + ret = GetTaskInfo(builder, model_ptr, comp_graph, subgraph_map, session_id); GE_TIMESTAMP_END(GetTaskInfo, "GraphBuilder::GetTaskInfo"); 
GraphUtils::DumpGEGraph(comp_graph, "AfterGetTask"); @@ -147,6 +142,11 @@ Status GraphBuilder::Build(ComputeGraphPtr &comp_graph, std::vectorGetAllSubgraphs()) { + GraphUtils::DumpGEGraphToOnnx(*graph, "SubgraphGetTask"); + } + ge_model_ptr = MakeShared(); if (ge_model_ptr == nullptr) { return MEMALLOC_FAILED; @@ -158,7 +158,7 @@ Status GraphBuilder::Build(ComputeGraphPtr &comp_graph, std::vector &subgraph_ptr_list, + ComputeGraphPtr &comp_graph, Graph2SubGraphInfoList &subgraph_map, uint64_t session_id) { GE_CHECK_NOTNULL(model_ptr); GE_CHECK_NOTNULL(comp_graph); @@ -173,7 +173,8 @@ Status GraphBuilder::GetTaskInfo(const ge::ModelBuilder &builder, const ModelPtr GELOGE(INTERNAL_ERROR, "Get weight memory size fail."); return INTERNAL_ERROR; } - auto *get_mem_base = reinterpret_cast(ge::VarManager::Instance(0)->GetVarMemMaxSize()); + auto *get_mem_base = + reinterpret_cast(reinterpret_cast(ge::VarManager::Instance(0)->GetVarMemMaxSize())); uint8_t *get_weight_mem_base = get_mem_base; if (weight_size > 0) { get_weight_mem_base = get_mem_base + memory_size; @@ -193,7 +194,7 @@ Status GraphBuilder::GetTaskInfo(const ge::ModelBuilder &builder, const ModelPtr } StreamGraphOptimizer stream_optimizer; - ret = stream_optimizer.OptimizeStreamedSubGraph(comp_graph, subgraph_ptr_list, run_context.GetRunContext()); + ret = stream_optimizer.OptimizeStreamedSubGraph(comp_graph, subgraph_map, run_context.GetRunContext()); if (ret != SUCCESS) { GELOGE(ret, "Optimize streamed subGraph fail."); return ret; @@ -202,7 +203,8 @@ Status GraphBuilder::GetTaskInfo(const ge::ModelBuilder &builder, const ModelPtr GraphUtils::DumpGEGraph(comp_graph, "AfterOptimizeStreamedSubGraph"); GraphUtils::DumpGEGraphToOnnx(*comp_graph, "AfterOptimizeStreamedSubGraph"); - auto *get_var_mem_base = reinterpret_cast(ge::VarManager::Instance(0)->GetVarMemLogicBase()); + auto *get_var_mem_base = + reinterpret_cast(reinterpret_cast(ge::VarManager::Instance(0)->GetVarMemLogicBase())); uint64_t var_size = 
(ge::VarManager::Instance(session_id)->GetVarMemSize(RT_MEMORY_HBM) > 0) ? ge::VarManager::Instance(0)->GetVarMemMaxSize() : 0; diff --git a/src/ge/graph/build/graph_builder.h b/src/ge/graph/build/graph_builder.h index c1c4f7b6..d0bf26e6 100644 --- a/src/ge/graph/build/graph_builder.h +++ b/src/ge/graph/build/graph_builder.h @@ -53,7 +53,7 @@ class GraphBuilder { private: Status CalcOpParam(const ge::ComputeGraphPtr &graph); Status GetTaskInfo(const ge::ModelBuilder &builder, const ModelPtr &model_ptr, ComputeGraphPtr &comp_graph, - std::vector &subgraph_ptr_list, uint64_t session_id = INVALID_SESSION_ID); + Graph2SubGraphInfoList &subgraph_map, uint64_t session_id = INVALID_SESSION_ID); Status SetInputSize(const ge::NodePtr &node_ptr); Status UpdateDataInputSize(const ge::NodePtr &node_ptr); Status SecondPartition(ge::ComputeGraphPtr &comp_graph, vector &subgraph_ptr_list); diff --git a/src/ge/graph/build/logical_stream_allocator.cc b/src/ge/graph/build/logical_stream_allocator.cc index d57d5ac5..ff33e3b7 100644 --- a/src/ge/graph/build/logical_stream_allocator.cc +++ b/src/ge/graph/build/logical_stream_allocator.cc @@ -15,23 +15,20 @@ */ #include "graph/build/logical_stream_allocator.h" +#include #include "common/ge/ge_util.h" -#include "common/op/attr_define.h" #include "framework/common/debug/ge_log.h" #include "framework/common/fmk_error_codes.h" #include "framework/common/types.h" +#include "graph/debug/ge_attr_define.h" #include "graph/utils/graph_utils.h" using std::map; +using std::queue; using std::set; using std::string; using std::vector; -using domi::ATTR_NAME_STREAM_LABEL; -using domi::CONSTANT; -using domi::CONSTANTOP; -using domi::HCOMALLREDUCE; - namespace { const char *const kAICPUEngineName = "DNN_VM_AICPU"; const char *const kAttrNameParentOpType = "parentOpType"; @@ -75,7 +72,7 @@ bool LogicalStreamPass::HasNonConstInputNode(const Subgraph &subgraph) const { return false; } -Status AssignByLabelPass::Run(ComputeGraphPtr whole_graph, const 
vector &subgraphs, Context &context) { +Status AssignByLabelPass::Run(ComputeGraphPtr graph, const vector &subgraphs, Context &context) { bool changed = false; int64_t &next_stream = context.next_stream; map label_streams; @@ -102,7 +99,7 @@ Status AssignByLabelPass::Run(ComputeGraphPtr whole_graph, const vector &subgraphs, Context &context) { +Status IndependentStreamPass::Run(ComputeGraphPtr graph, const vector &subgraphs, Context &context) { bool changed = false; int64_t &next_stream = context.next_stream; @@ -134,8 +131,7 @@ Status IndependentStreamPass::Run(ComputeGraphPtr whole_graph, const vector &subgraphs, - Context &context) { +Status AssignByDependencyPass::Run(ComputeGraphPtr graph, const vector &subgraphs, Context &context) { bool changed = false; if (IsHeadNodeExceeded(subgraphs)) { int64_t &next_stream = context.next_stream; @@ -303,7 +299,7 @@ int64_t AssignByDependencyPass::AssignNewStream(SubgraphPtr subgraph) { subgraph->stream_id = stream_id; engine_next_streams_[engine_name] = stream_id + 1; - assigned_subgraphs_.emplace(subgraph); + assigned_subgraphs_.emplace_back(subgraph); if ((stream_id + 1) > engine_stream_num_[engine_name]) { engine_stream_num_[engine_name] = stream_id + 1; @@ -316,6 +312,15 @@ int64_t AssignByDependencyPass::AssignNewStream(SubgraphPtr subgraph) { } void AssignByDependencyPass::UpdateAssignedSubgraphs(Context &context) { + // If the default stream is valid, the first assigned stream will reuse the default stream id + // and other streams use new id. To ensure that the id of the new stream is continuous, + // we first subtract one from next_stream. + int64_t to_be_updated_stream = kInvalidStream; + if (context.default_stream != kInvalidStream) { + context.next_stream--; + to_be_updated_stream = context.next_stream; + } + // Update the starting stream id for each engine. 
int64_t &next_stream = context.next_stream; map engine_start_streams; @@ -325,10 +330,16 @@ void AssignByDependencyPass::UpdateAssignedSubgraphs(Context &context) { next_stream += stream_count; } - // Update the subgraphs assigned by the engine. + // Update the subgraph streams assigned by engine. for (auto &subgraph : assigned_subgraphs_) { subgraph->stream_id += engine_start_streams[subgraph->engine_conf.id]; - GELOGI("Stream of subgraph %s has been updated to %ld.", subgraph->name.c_str(), subgraph->stream_id); + if (subgraph->stream_id == to_be_updated_stream) { + subgraph->stream_id = context.default_stream; + GELOGI("Subgraph %s of engine %s reuses default stream %ld.", subgraph->name.c_str(), + subgraph->engine_conf.id.c_str(), context.default_stream); + } else { + GELOGI("Stream of subgraph %s has been updated to %ld.", subgraph->name.c_str(), subgraph->stream_id); + } } } @@ -342,7 +353,30 @@ void AssignByDependencyPass::UpdateReusedSubgraphs() { } } -Status NodeStreamUpdatePass::Run(ComputeGraphPtr whole_graph, const vector &subgraphs, Context &context) { +Status SingleStreamPass::Run(ComputeGraphPtr graph, const vector &subgraphs, Context &context) { + // context.default_stream can be kInvalidStream only when graph is the root graph. 
+ int64_t new_stream = context.default_stream; + if (new_stream == kInvalidStream) { + new_stream = context.next_stream; + ++context.next_stream; + } + + for (const SubgraphPtr &subgraph : subgraphs) { + if (!HasAssignedStream(*subgraph)) { + const string &stream_label = subgraph->subgraph_info.GetStreamLabel(); + if (!stream_label.empty()) { + GELOGE(INTERNAL_ERROR, "Stream labels are not supported (subgraph: %s, stream label: %s).", + subgraph->name.c_str(), stream_label.c_str()); + return INTERNAL_ERROR; + } + subgraph->stream_id = new_stream; + } + } + + return SUCCESS; +} + +Status NodeStreamUpdatePass::Run(ComputeGraphPtr graph, const vector &subgraphs, Context &context) { // Check if all subgraphs have been assigned a stream. for (const SubgraphPtr &subgraph : subgraphs) { const string &engine_name = subgraph->engine_conf.id; @@ -358,7 +392,7 @@ Status NodeStreamUpdatePass::Run(ComputeGraphPtr whole_graph, const vectorGetDirectNode()) { + for (NodePtr &node : graph->GetDirectNode()) { GE_CHECK_NOTNULL(node->GetOpDesc()); node->GetOpDesc()->SetStreamId(kInvalidStream); } @@ -375,81 +409,18 @@ Status NodeStreamUpdatePass::Run(ComputeGraphPtr whole_graph, const vectorGetName().c_str(), node->GetType().c_str(), subgraph->name.c_str(), engine_name.c_str()); } else { node->GetOpDesc()->SetStreamId(stream_id); + GELOGD("Node %s of type %s in subgraph %s is assigned stream %ld (engine: %s).", node->GetName().c_str(), + node->GetType().c_str(), subgraph->name.c_str(), stream_id, engine_name.c_str()); } } } // Update stream id for nodes belong to skipped engine subgraph - GE_CHK_STATUS_RET(UpdateForSkippedEngine(whole_graph, subgraphs)); - - RefreshContinuousStreams(whole_graph, context); + GE_CHK_STATUS_RET(UpdateForSkippedEngine(graph, subgraphs)); return SUCCESS; } -Status AllReduceParallelPass::Run(ComputeGraphPtr whole_graph, const vector &subgraphs, Context &context) { - if (!context.hcom_parallel) { - return NOT_CHANGED; - } - - GELOGI("AllReduceParallelPass is 
enabled."); - GraphUtils::DumpGEGraph(whole_graph, "BeforeAllReduceParallel"); - - // All successors of HcomAllReduce. - set all_reduce_succs; - - for (const NodePtr &node : whole_graph->GetDirectNode()) { - if (node->GetType() != HCOMALLREDUCE || node->GetInDataNodes().size() <= 1) { - continue; - } - - string reduce_stream_label; - GE_CHECK_NOTNULL(node->GetOpDesc()); - // ATTR_NAME_STREAM_LABEL is optional. - (void)AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, reduce_stream_label); - - set cur_nodes = {node}; - while (!cur_nodes.empty()) { - set all_out_data_nodes; - for (auto &curr_node : cur_nodes) { - for (const NodePtr &out_node : curr_node->GetOutDataNodes()) { - string out_stream_label; - GE_CHECK_NOTNULL(out_node->GetOpDesc()); - // ATTR_NAME_STREAM_LABEL is optional. - (void)AttrUtils::GetStr(out_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, out_stream_label); - if (out_stream_label == reduce_stream_label) { - all_reduce_succs.emplace(out_node); - all_out_data_nodes.emplace(out_node); - } - } - } - cur_nodes = all_out_data_nodes; - } - } - - map old_stream_to_new; - for (const NodePtr &node : all_reduce_succs) { - GE_CHECK_NOTNULL(node->GetOpDesc()); - auto old_stream = node->GetOpDesc()->GetStreamId(); - if (old_stream != kInvalidStream) { - int64_t new_stream = kInvalidStream; - auto iter = old_stream_to_new.find(old_stream); - if (iter != old_stream_to_new.end()) { - new_stream = iter->second; - } else { - new_stream = context.next_stream; - context.next_stream++; - old_stream_to_new.emplace(old_stream, new_stream); - } - - GELOGI("Stream of node %s has been updated from %ld to %ld.", node->GetName().c_str(), old_stream, new_stream); - node->GetOpDesc()->SetStreamId(new_stream); - } - } - - return !all_reduce_succs.empty() ? 
SUCCESS : NOT_CHANGED; -} - int64_t NodeStreamUpdatePass::GetSingleInoutStream(const NodePtr &node) const { set stream_ids; @@ -460,6 +431,7 @@ int64_t NodeStreamUpdatePass::GetSingleInoutStream(const NodePtr &node) const { stream_ids.insert(stream_id); } } + for (const auto &out_node : node->GetOutAllNodes()) { GE_CHECK_NOTNULL_EXEC(out_node->GetOpDesc(), return kInvalidStream); int64_t stream_id = out_node->GetOpDesc()->GetStreamId(); @@ -467,9 +439,10 @@ int64_t NodeStreamUpdatePass::GetSingleInoutStream(const NodePtr &node) const { stream_ids.insert(stream_id); } } + if (stream_ids.size() == 1) { int64_t stream_id = *(stream_ids.begin()); - GELOGI("Node %s of type %s: its all input and output nodes are in same stream[%ld].", node->GetName().c_str(), + GELOGI("The stream of all input and output nodes of node %s (type: %s) is %ld.", node->GetName().c_str(), node->GetType().c_str(), stream_id); return stream_id; } @@ -477,43 +450,46 @@ int64_t NodeStreamUpdatePass::GetSingleInoutStream(const NodePtr &node) const { return kInvalidStream; } -Status NodeStreamUpdatePass::UpdateForSkippedEngine(const ComputeGraphPtr &whole_graph, +Status NodeStreamUpdatePass::UpdateForSkippedEngine(const ComputeGraphPtr &graph, const vector &subgraphs) { - set nodes_to_be_updated; + set ops_without_label; - // Check if sub graph is engine skipped and without stream label or not + // Check if subgraph is engine skipped and without stream label or not for (const SubgraphPtr &subgraph : subgraphs) { - if (IsEngineSkip(*subgraph) && !HasStreamLabel(*subgraph)) { - auto graph = subgraph->subgraph_info.GetSubGraph(); - for (NodePtr &node : graph->GetDirectNode()) { + if (IsEngineSkip(*subgraph)) { + auto compute_graph = subgraph->subgraph_info.GetSubGraph(); + for (NodePtr &node : compute_graph->GetDirectNode()) { auto op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); auto stream_id = op_desc->GetStreamId(); - if (stream_id != kInvalidStream) { - nodes_to_be_updated.insert(op_desc); + 
if (stream_id != kInvalidStream && !HasStreamLabel(*subgraph)) { + ops_without_label.emplace(op_desc); } } } } // Try reassign the stream id - for (ge::NodePtr &node : whole_graph->GetDirectNode()) { + for (ge::NodePtr &node : graph->GetDirectNode()) { auto op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); int64_t stream_id = op_desc->GetStreamId(); - if (nodes_to_be_updated.find(op_desc) != nodes_to_be_updated.end()) { - if (AreAllPredStreamsInvalid(node)) { + if (ops_without_label.find(op_desc) != ops_without_label.end()) { + if (AreAllPredStreamsInvalid(node) && op_desc->GetSubgraphInstanceNames().empty()) { op_desc->SetStreamId(kInvalidStream); - } else { + GELOGI("Node %s of type %s reassign to stream %ld from stream %ld.", node->GetName().c_str(), + node->GetType().c_str(), kInvalidStream, stream_id); + } else if (!node->GetOutAllNodes().empty()) { int64_t inout_stream = GetSingleInoutStream(node); if (inout_stream != kInvalidStream) { op_desc->SetStreamId(inout_stream); - GELOGI("Node %s of type %s reassign to stream id[%ld] from stream id[%ld].", node->GetName().c_str(), + GELOGI("Node %s of type %s reassign to stream %ld from stream %ld.", node->GetName().c_str(), node->GetType().c_str(), inout_stream, stream_id); } } } } + return SUCCESS; } @@ -530,51 +506,79 @@ bool NodeStreamUpdatePass::AreAllPredStreamsInvalid(const NodePtr &node) const { return true; } -void NodeStreamUpdatePass::RefreshContinuousStreams(ComputeGraphPtr whole_graph, Context &context) const { - int64_t stream_num = context.next_stream; - vector stream_has_node(stream_num); +Status AllReduceParallelPass::Run(ComputeGraphPtr graph, const vector &subgraphs, Context &context) { + if (!context.enable_hcom_parallel) { + return NOT_CHANGED; + } - for (const NodePtr &node : whole_graph->GetDirectNode()) { - if (node != nullptr) { - auto op_desc = node->GetOpDesc(); - if (op_desc != nullptr) { - int64_t stream_id = op_desc->GetStreamId(); - if (stream_id != kInvalidStream && stream_id < 
stream_num) { - stream_has_node[stream_id] = true; - } - } + GELOGI("AllReduceParallelPass is enabled."); + GraphUtils::DumpGEGraph(graph, "BeforeAllReduceParallel"); + + // All successors of HcomAllReduce. + set all_reduce_succs; + + for (const NodePtr &node : graph->GetDirectNode()) { + if (node->GetType() != HCOMALLREDUCE || node->GetInDataNodes().size() <= 1) { + continue; } - } - context.next_stream = 0; - vector old_to_new_streams(stream_num, kInvalidStream); - for (size_t old_stream = 0; old_stream < stream_has_node.size(); ++old_stream) { - if (stream_has_node[old_stream]) { - old_to_new_streams[old_stream] = context.next_stream; - ++context.next_stream; + string reduce_stream_label; + GE_CHECK_NOTNULL(node->GetOpDesc()); + (void)AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, reduce_stream_label); + + set cur_nodes = {node}; + while (!cur_nodes.empty()) { + set all_out_data_nodes; + for (auto &curr_node : cur_nodes) { + for (const NodePtr &out_node : curr_node->GetOutDataNodes()) { + string out_stream_label; + GE_CHECK_NOTNULL(out_node->GetOpDesc()); + (void)AttrUtils::GetStr(out_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, out_stream_label); + if (out_stream_label == reduce_stream_label) { + all_reduce_succs.emplace(out_node); + all_out_data_nodes.emplace(out_node); + } + } + } + cur_nodes = all_out_data_nodes; } } - for (const NodePtr &node : whole_graph->GetDirectNode()) { - auto op_desc = node->GetOpDesc(); - if (op_desc != nullptr) { - int64_t stream_id = op_desc->GetStreamId(); - if (stream_id != kInvalidStream && stream_id < stream_num) { - op_desc->SetStreamId(old_to_new_streams[stream_id]); + map old_stream_to_new; + for (const NodePtr &node : all_reduce_succs) { + GE_CHECK_NOTNULL(node->GetOpDesc()); + auto old_stream = node->GetOpDesc()->GetStreamId(); + if (old_stream != kInvalidStream) { + int64_t new_stream = kInvalidStream; + auto iter = old_stream_to_new.find(old_stream); + if (iter != old_stream_to_new.end()) { + new_stream = 
iter->second; + } else { + new_stream = context.next_stream; + context.next_stream++; + old_stream_to_new.emplace(old_stream, new_stream); } + + GELOGI("Stream of node %s has been updated from %ld to %ld.", node->GetName().c_str(), old_stream, new_stream); + node->GetOpDesc()->SetStreamId(new_stream); } } + + return !all_reduce_succs.empty() ? SUCCESS : NOT_CHANGED; } LogicalStreamAllocator::LogicalStreamAllocator(const map &scheduler_confs, - const map &max_parallel_num, bool hcom_parallel) - : scheduler_confs_(scheduler_confs), max_parallel_num_(max_parallel_num) { - context_.hcom_parallel = hcom_parallel; -} + const map &max_parallel_num) + : scheduler_confs_(scheduler_confs), max_parallel_num_(max_parallel_num) {} + +void LogicalStreamAllocator::EnableSingleStream(bool enable) { context_.enable_single_stream = enable; } + +void LogicalStreamAllocator::EnableHcomParallel(bool enable) { context_.enable_hcom_parallel = enable; } -Status LogicalStreamAllocator::Assign(const ComputeGraphPtr &whole_graph, const vector &subgraph_infos, +Status LogicalStreamAllocator::Assign(const ComputeGraphPtr &root_graph, const Graph2SubGraphInfoList &subgraph_map, int64_t &stream_num) { - GE_CHECK_NOTNULL(whole_graph); + GE_CHECK_NOTNULL(root_graph); + map engine_confs; GE_TIMESTAMP_START(InitEngineConfs); for (const auto &item : scheduler_confs_) { @@ -588,16 +592,64 @@ Status LogicalStreamAllocator::Assign(const ComputeGraphPtr &whole_graph, const } GE_TIMESTAMP_END(InitEngineConfs, "GraphBuilder::AssignStreamInitEngineConfs"); + Status status = DoAssign(root_graph, subgraph_map, engine_confs); + if (status != SUCCESS) { + GELOGE(status, "Assign streams failed."); + return status; + } + + vector subgraphs = root_graph->GetAllSubgraphs(); + for (const ComputeGraphPtr &subgraph : subgraphs) { + Status status = DoAssign(subgraph, subgraph_map, engine_confs); + if (status != SUCCESS) { + GELOGE(status, "Assign streams failed."); + return status; + } + } + + 
RefreshContinuousStreams(root_graph); + + stream_num = context_.next_stream; + GELOGI("Assigned logical stream num: %ld.", stream_num); + + return SUCCESS; +} + +Status LogicalStreamAllocator::DoAssign(const ComputeGraphPtr &graph, const Graph2SubGraphInfoList &subgraph_map, + const map &engine_confs) { + GE_CHECK_NOTNULL(graph); + + NodePtr parent_node = graph->GetParentNode(); + if (parent_node == nullptr || parent_node->GetOpDesc() == nullptr) { + context_.default_stream = kInvalidStream; + } else { + context_.default_stream = parent_node->GetOpDesc()->GetStreamId(); + } + + auto iter = subgraph_map.find(graph); + if (iter == subgraph_map.end()) { + GELOGE(FAILED, "Graph %s not found.", graph->GetName().c_str()); + return FAILED; + } + + const vector &subgraph_info_list = iter->second; vector subgraphs; GE_TIMESTAMP_START(ConvertSubgraphs); - Status status = ConvertSubgraphs(subgraph_infos, engine_confs, subgraphs); + Status status = ConvertSubgraphs(subgraph_info_list, engine_confs, subgraphs); GE_TIMESTAMP_END(ConvertSubgraphs, "GraphBuilder::AssignStreamConvertSubgraphs"); if (status != SUCCESS) { GELOGE(status, "Create subgraphs failed."); return status; } - return RunPasses(whole_graph, subgraphs, stream_num); + GELOGI("Subgraphs of graph %s:", graph->GetName().c_str()); + for (const auto &subgraph : subgraphs) { + if (subgraph != nullptr) { + GELOGI("subgraph: %s", subgraph->name.c_str()); + } + } + + return RunPasses(graph, subgraphs); } Status LogicalStreamAllocator::ConvertSubgraphs(const vector &subgraph_infos, @@ -636,19 +688,24 @@ Status LogicalStreamAllocator::ConvertSubgraphs(const vector &s return SUCCESS; } -Status LogicalStreamAllocator::RunPasses(const ComputeGraphPtr &whole_graph, const vector &subgraphs, - int64_t &stream_num) { +Status LogicalStreamAllocator::RunPasses(const ComputeGraphPtr &graph, const vector &subgraphs) { vector passes; - passes.emplace_back(MakeShared()); - passes.emplace_back(MakeShared()); - 
passes.emplace_back(MakeShared()); - passes.emplace_back(MakeShared()); - passes.emplace_back(MakeShared()); + + if (context_.enable_single_stream) { + passes.emplace_back(MakeShared()); + passes.emplace_back(MakeShared()); + } else { + passes.emplace_back(MakeShared()); + passes.emplace_back(MakeShared()); + passes.emplace_back(MakeShared()); + passes.emplace_back(MakeShared()); + passes.emplace_back(MakeShared()); + } for (auto &pass : passes) { GE_CHECK_NOTNULL(pass); - Status status = pass->Run(whole_graph, subgraphs, context_); + Status status = pass->Run(graph, subgraphs, context_); if (status == SUCCESS) { GELOGI("Stream pass %s return SUCCESS.", pass->GetName().c_str()); } else if (status == NOT_CHANGED) { @@ -659,9 +716,42 @@ Status LogicalStreamAllocator::RunPasses(const ComputeGraphPtr &whole_graph, con } } - stream_num = context_.next_stream; - GELOGI("Assigned logical stream num: %ld.", stream_num); - return SUCCESS; } + +void LogicalStreamAllocator::RefreshContinuousStreams(const ComputeGraphPtr &graph) { + int64_t stream_num = context_.next_stream; + vector stream_has_node(stream_num); + + for (const NodePtr &node : graph->GetAllNodes()) { + if (node != nullptr) { + auto op_desc = node->GetOpDesc(); + if (op_desc != nullptr) { + int64_t stream_id = op_desc->GetStreamId(); + if (stream_id != kInvalidStream && stream_id < stream_num) { + stream_has_node[stream_id] = true; + } + } + } + } + + context_.next_stream = 0; + vector old_to_new_streams(stream_num, kInvalidStream); + for (size_t old_stream = 0; old_stream < stream_has_node.size(); ++old_stream) { + if (stream_has_node[old_stream]) { + old_to_new_streams[old_stream] = context_.next_stream; + ++context_.next_stream; + } + } + + for (const NodePtr &node : graph->GetAllNodes()) { + auto op_desc = node->GetOpDesc(); + if (op_desc != nullptr) { + int64_t stream_id = op_desc->GetStreamId(); + if (stream_id != kInvalidStream && stream_id < stream_num) { + 
op_desc->SetStreamId(old_to_new_streams[stream_id]); + } + } + } +} } // namespace ge diff --git a/src/ge/graph/build/logical_stream_allocator.h b/src/ge/graph/build/logical_stream_allocator.h index 2265a0f3..71946630 100644 --- a/src/ge/graph/build/logical_stream_allocator.h +++ b/src/ge/graph/build/logical_stream_allocator.h @@ -60,9 +60,10 @@ class LogicalStreamPass { }; struct Context { - // Next stream id. + int64_t default_stream = kInvalidStream; int64_t next_stream = 0; - bool hcom_parallel = false; + bool enable_single_stream = false; + bool enable_hcom_parallel = false; }; explicit LogicalStreamPass(const std::string &name); @@ -71,7 +72,7 @@ class LogicalStreamPass { virtual ~LogicalStreamPass() = default; const std::string &GetName() const; - virtual Status Run(ComputeGraphPtr whole_graph, const std::vector &subgraphs, Context &context) = 0; + virtual Status Run(ComputeGraphPtr graph, const std::vector &subgraphs, Context &context) = 0; protected: bool IsEngineSkip(const Subgraph &subgraph) const; @@ -93,21 +94,21 @@ using LogicalStreamPassPtr = std::shared_ptr; class AssignByLabelPass : public LogicalStreamPass { public: STREAM_PASS_DEFAULT_FUNC(AssignByLabelPass); - Status Run(ComputeGraphPtr whole_graph, const std::vector &subgraphs, Context &context) override; + Status Run(ComputeGraphPtr graph, const std::vector &subgraphs, Context &context) override; }; // Engines such as hccl require independent Stream. class IndependentStreamPass : public LogicalStreamPass { public: STREAM_PASS_DEFAULT_FUNC(IndependentStreamPass); - Status Run(ComputeGraphPtr whole_graph, const std::vector &subgraphs, Context &context) override; + Status Run(ComputeGraphPtr graph, const std::vector &subgraphs, Context &context) override; }; // Reuse streams or assign new streams based on dependencies. 
class AssignByDependencyPass : public LogicalStreamPass { public: STREAM_PASS_DEFAULT_FUNC(AssignByDependencyPass); - Status Run(ComputeGraphPtr whole_graph, const std::vector &subgraphs, Context &context) override; + Status Run(ComputeGraphPtr graph, const std::vector &subgraphs, Context &context) override; private: void InitEndSubgraphMap(const std::vector &subgraphs, std::map &end_subgraph_map); @@ -132,17 +133,24 @@ class AssignByDependencyPass : public LogicalStreamPass { std::map engine_stream_num_; // Subgraphs of assign stream by engine - std::set assigned_subgraphs_; + std::vector assigned_subgraphs_; // std::vector> reused_subgraphs_; }; +// All nodes in the graph are assigned the same stream. +class SingleStreamPass : public LogicalStreamPass { + public: + STREAM_PASS_DEFAULT_FUNC(SingleStreamPass); + Status Run(ComputeGraphPtr graph, const std::vector &subgraphs, Context &context) override; +}; + // Update the stream of subgraphs to nodes. class NodeStreamUpdatePass : public LogicalStreamPass { public: STREAM_PASS_DEFAULT_FUNC(NodeStreamUpdatePass); - Status Run(ComputeGraphPtr whole_graph, const std::vector &subgraphs, Context &context) override; + Status Run(ComputeGraphPtr graph, const std::vector &subgraphs, Context &context) override; private: /// Optimize for case like: @@ -150,19 +158,18 @@ class NodeStreamUpdatePass : public LogicalStreamPass { /// To case: /// NodeA(stream1) -> Const(stream1) -> NodeB(stream1) /// Which could reduce event number (Const could be other type which belong to skipped engine subgraph) - Status UpdateForSkippedEngine(const ComputeGraphPtr &whole_graph, const std::vector &subgraphs); + Status UpdateForSkippedEngine(const ComputeGraphPtr &graph, const std::vector &subgraphs); int64_t GetSingleInoutStream(const NodePtr &node) const; - // Judge if all predecessors' streams of node are INVALID_STREAM + // Judge if all predecessors' streams of node are kInvalidStream bool AreAllPredStreamsInvalid(const NodePtr &node) const; 
- void RefreshContinuousStreams(ComputeGraphPtr whole_graph, Context &context) const; }; // AllReduce and backward operators execute in parallel. class AllReduceParallelPass : public LogicalStreamPass { public: STREAM_PASS_DEFAULT_FUNC(AllReduceParallelPass); - Status Run(ComputeGraphPtr whole_graph, const std::vector &subgraphs, Context &context) override; + Status Run(ComputeGraphPtr graph, const std::vector &subgraphs, Context &context) override; }; // Assign logical streams which is not limited by the number of tasks. @@ -173,18 +180,24 @@ class LogicalStreamAllocator { public: LogicalStreamAllocator(const std::map &scheduler_confs, - const std::map &max_parallel_num, bool hcom_parallel = false); + const std::map &max_parallel_num); LogicalStreamAllocator(const LogicalStreamAllocator &) = delete; LogicalStreamAllocator &operator=(const LogicalStreamAllocator &) = delete; ~LogicalStreamAllocator() = default; - Status Assign(const ComputeGraphPtr &whole_graph, const std::vector &subgraphs, int64_t &stream_num); + void EnableSingleStream(bool enable); + void EnableHcomParallel(bool hcom_parallel); + + Status Assign(const ComputeGraphPtr &root_graph, const Graph2SubGraphInfoList &subgraph_map, int64_t &stream_num); private: + Status DoAssign(const ComputeGraphPtr &graph, const Graph2SubGraphInfoList &subgraph_map, + const map &engine_confs); Status ConvertSubgraphs(const std::vector &subgraph_infos, const std::map &engine_confs, std::vector &subgraphs); - Status RunPasses(const ComputeGraphPtr &whole_graph, const std::vector &subgraphs, int64_t &stream_num); + Status RunPasses(const ComputeGraphPtr &graph, const std::vector &subgraphs); + void RefreshContinuousStreams(const ComputeGraphPtr &graph); const std::map &scheduler_confs_; const std::map &max_parallel_num_; diff --git a/src/ge/graph/build/memory/binary_block_mem_assigner.cc b/src/ge/graph/build/memory/binary_block_mem_assigner.cc index 67c04ef6..8668e81e 100644 --- 
a/src/ge/graph/build/memory/binary_block_mem_assigner.cc +++ b/src/ge/graph/build/memory/binary_block_mem_assigner.cc @@ -100,13 +100,13 @@ Status BinaryBlockMemAssigner::GetMemoryRanges(vector &range_ceils) { GELOGD("Origin ranges:"); for (auto &v : ranges) { - GELOGD("__%s", domi::ToString(v).c_str()); + GELOGD("__%s", ToString(v).c_str()); } PlanRanges(range_number_limit, ranges); GELOGD("Origin ranges:"); for (auto &v : ranges) { - GELOGD("__%s", domi::ToString(v).c_str()); + GELOGD("__%s", ToString(v).c_str()); } for (auto &range : ranges) { @@ -115,7 +115,7 @@ Status BinaryBlockMemAssigner::GetMemoryRanges(vector &range_ceils) { range_ceils.push_back(range.back()); } } - GELOGI("Range ceils: %s", domi::ToString(range_ceils).c_str()); + GELOGI("Range ceils: %s", ToString(range_ceils).c_str()); return SUCCESS; } diff --git a/src/ge/graph/build/memory/block_mem_assigner.cc b/src/ge/graph/build/memory/block_mem_assigner.cc index e0fd3d9b..73d3ee98 100644 --- a/src/ge/graph/build/memory/block_mem_assigner.cc +++ b/src/ge/graph/build/memory/block_mem_assigner.cc @@ -29,7 +29,6 @@ #include "graph/utils/op_desc_utils.h" #include "graph/utils/tensor_utils.h" -#include "common/op/attr_define.h" #include "graph/debug/ge_attr_define.h" #include "graph/optimize/common/params.h" @@ -47,29 +46,6 @@ const int kReuseMaxCharNum = 2000; } // namespace namespace ge { -using domi::AIPP_DATA_TYPE; -using domi::AIPPDATA; -using domi::ANN_DATA_TYPE; -using domi::APPLYMOMENTUM; -using domi::ASSIGN; -using domi::ASSIGNADD; -using domi::ASSIGNSUB; -using domi::CONSTANT; -using domi::CONSTANTOP; -using domi::DATA; -using domi::DATA_TYPE; -using domi::ENTER; -using domi::FASTRCNNPREDICTIONS; -using domi::HCOMALLREDUCE; -using domi::HCOMBROADCAST; -using domi::MULTISHAPE; -using domi::NETOUTPUT; -using domi::NEXTITERATION; -using domi::PROPOSAL; -using domi::REFENTER; -using domi::REFNEXTITERATION; -using domi::VARIABLE; -using domi::ZEROSLIKE; using std::map; using std::pair; using 
std::string; @@ -89,6 +65,9 @@ void MemoryBlock::Resize() { block_size = (block_size + MEM_ALIGN_SIZE - 1) / MEM_ALIGN_SIZE * MEM_ALIGN_SIZE; } block_size_ = block_size; + if (last_continuous_block_) { + block_size_ += MEM_ALIGN_SIZE; + } } } @@ -158,12 +137,15 @@ string ToString(ge::NodeTypeIndex &x) { string MemoryBlock::String() { stringstream ss; ss << "Block size: " << Size() << " from " << HeadOffset() << " to " << TailOffset() << ""; - ss << "real_size_list: " << domi::ToString(real_size_list_) << ""; + ss << "real_size_list: " << ToString(real_size_list_) << ""; ss << "ref_count: " << ref_count_ << ""; ss << "members: "; for (auto x : NodeTypeIndexList()) { ss << "__node: " << ToString(x) << ""; } + for (const auto &symbol : SymbolList()) { + ss << "__symbol: " << symbol << ""; + } return ss.str(); } @@ -177,29 +159,56 @@ BlockMemAssigner::~BlockMemAssigner() { } void BlockMemAssigner::GetOutAndWorkSpaceMem(vector &all_memory_size) { - vector temp; + if (GraphUtils::GetRefMapping(compute_graph_, symbol_to_anchors_, anchor_to_symbol_) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Get ref-mapping for graph %s failed.", compute_graph_->GetName().c_str()); + return; + } - for (const NodePtr &n : compute_graph_->GetDirectNode()) { + vector temp; + for (const NodePtr &n : compute_graph_->GetAllNodes()) { auto node_op_desc = n->GetOpDesc(); GE_IF_BOOL_EXEC(node_op_desc == nullptr, continue); - for (const auto &output_desc : node_op_desc->GetAllOutputsDescPtr()) { + + if (node_op_desc->GetType() == ATOMICADDRCLEAN) { + atomic_addr_clean_id_ = node_op_desc->GetId(); + } + + for (auto &out_anchor : n->GetAllOutDataAnchors()) { + GeTensorDesc output_desc = node_op_desc->GetOutputDesc(out_anchor->GetIdx()); bool reuse_input = false; - GE_IF_BOOL_EXEC(ge::TensorUtils::GetReuseInput(*output_desc, reuse_input) != SUCCESS, + GE_IF_BOOL_EXEC(ge::TensorUtils::GetReuseInput(output_desc, reuse_input) != SUCCESS, GELOGI("Get reuse_input failed")); if (!reuse_input) { int64_t size = 0; 
- GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(*output_desc, size) != SUCCESS, GELOGI("Get size failed")); - all_memory_size.emplace_back(size); + GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(output_desc, size) != SUCCESS, GELOGI("Get size failed")); + if (anchor_to_symbol_.empty()) { + all_memory_size.emplace_back(size); + } else { + auto iter1 = anchor_to_symbol_.find(NodeIndexIO(n, out_anchor->GetIdx(), kOut).ToString()); + if (iter1 == anchor_to_symbol_.end()) { + continue; + } + std::string symbol = iter1->second; + auto iter2 = symbol_size_.find(symbol); + if (iter2 == symbol_size_.end()) { + symbol_size_[symbol] = size; + } else if (size > static_cast(iter2->second)) { + iter2->second = size; + } + } } } - temp.clear(); GetNodeWorkSpaceSize(n, temp); all_memory_size.insert(all_memory_size.end(), temp.begin(), temp.end()); } + GELOGI("The last atomic_addr_clean node id: %ld", atomic_addr_clean_id_); + for (auto &pair : symbol_size_) { + all_memory_size.emplace_back(pair.second); + } sort(all_memory_size.begin(), all_memory_size.end()); - GELOGI("All memory size: %s", domi::ToString(all_memory_size).c_str()); + GELOGI("All memory size: %s", ToString(all_memory_size).c_str()); for (auto iter = all_memory_size.begin(); iter != all_memory_size.end();) { if (*iter == 0) { @@ -208,7 +217,11 @@ void BlockMemAssigner::GetOutAndWorkSpaceMem(vector &all_memory_size) { ++iter; } } + + InitReuseFlag(); + PrintSymbolMap(); } + /// /// @ingroup domi /// @brief decide memory size based on actual input memory size @@ -259,18 +272,43 @@ void ReduceReusableBlockCount(const MemoryBlock &mem_block, map &reusable_block_counts, const MemoryBlock &reusable_block, - size_t block_size) { + size_t block_size, size_t real_size, bool continuous, int64_t atomic_addr_clean_id) { bool can_reuse = false; + + // If node is before atomic_addr_clean node, the continus memory can't be reused. 
+ if (!reusable_block.NodeTypeIndexList().empty()) { + auto node = reusable_block.NodeTypeIndexList()[0].node; + if (node != nullptr) { + auto op_desc = node->GetOpDesc(); + if (op_desc != nullptr) { + if ((op_desc->GetId() < atomic_addr_clean_id) && continuous) { + return false; + } + } + } + } + + // continuous memory case:only real_size is maximum can be reused and only one continuous memory in one block + if (continuous || reusable_block.continuous_block_) { + auto it = std::max_element(std::begin(reusable_block.RealSizeList()), std::end(reusable_block.RealSizeList())); + if (it != std::end(reusable_block.RealSizeList())) { + GE_IF_BOOL_EXEC((continuous && reusable_block.continuous_block_) || (continuous && (real_size < *it)) || + (reusable_block.continuous_block_ && (real_size > *it)), + GELOGD("Conflict current block size:%zu continuous:%d, reuse block max size:%zu continuous:%d", + real_size, continuous, *it, reusable_block.continuous_block_); + return false;); + } + } if (reusable_block.Size() == block_size) { can_reuse = true; } else { string key = std::to_string(reusable_block.Size()); key += "_" + std::to_string(reusable_block.stream_id_); auto it = reusable_block_counts.find(key); - if ((it != reusable_block_counts.end() && (it->second > kReuseMaxCount)) && (reusable_block.Size() > block_size)) { + GE_IF_BOOL_EXEC( + (it != reusable_block_counts.end() && (it->second > kReuseMaxCount)) && (reusable_block.Size() > block_size), can_reuse = true; - GELOGD("Less size mem reuse, reuse block size:%zu, current block size:%zu", reusable_block.Size(), block_size); - } + GELOGD("Less size mem reuse, reuse block size:%zu, current block size:%zu", reusable_block.Size(), block_size);); } return can_reuse; } @@ -283,9 +321,186 @@ bool CanReuseByStream(const std::unordered_set &reuse_stream, MemoryBlo return can_reuse; } +bool BlockMemAssigner::IsOutNodeSetContinuousInput(const NodePtr &n, uint32_t out_index, std::string &peer_name, + uint32_t &peer_input_index) { + if 
(n == nullptr || n->GetAllOutDataAnchors().size() <= 0) { + return false; + } + if (static_cast(out_index) < n->GetAllOutDataAnchors().size()) { + auto out_anchor = n->GetOutDataAnchor(out_index); + GE_IF_BOOL_EXEC(out_anchor == nullptr, + GELOGE(FAILED, "Node[%s] output[%u] anchor is null.", n->GetName().c_str(), out_index); + return false;); + for (auto const &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) { + GE_IF_BOOL_EXEC(peer_in_anchor == nullptr, + GELOGE(FAILED, "Node[%s] output[%u] peer_in_anchor 0 is null.", n->GetName().c_str(), out_index); + return false;); + auto peer_node = peer_in_anchor->GetOwnerNode(); + GE_IF_BOOL_EXEC(peer_node == nullptr, + GELOGE(FAILED, "Node[%s] output[%u] node is null.", n->GetName().c_str(), out_index); + return false;); + + // Get the continuous input type of the node, default is false + bool is_input_continuous = false; + auto peer_in_node_desc = peer_node->GetOpDesc(); + GE_IF_BOOL_EXEC(peer_in_node_desc == nullptr, + GELOGE(FAILED, "Node[%s] output[%u] nodedesc is null.", n->GetName().c_str(), out_index); + return false;); + + // If GetBool fail, is_input_continuous is false. + (void)ge::AttrUtils::GetBool(peer_in_node_desc, ATTR_NAME_CONTINUOUS_INPUT, is_input_continuous); + if (is_input_continuous) { + if (n->GetOwnerComputeGraph() != nullptr) { + string graph_name = n->GetOwnerComputeGraph()->GetName(); + GELOGI("%s name[%s] output[%u] node[%s] set input[%d] continuous, input size[%u].", graph_name.c_str(), + n->GetName().c_str(), out_index, peer_in_node_desc->GetName().c_str(), peer_in_anchor->GetIdx(), + peer_node->GetAllInDataAnchorsSize()); + // Only set attr one times. 
+ if (node_continuous_input_blocks_[peer_in_node_desc->GetName()].size() == 0) { + (void)ge::AttrUtils::SetBool(peer_in_node_desc, ATTR_NAME_CONTINUOUS_INPUT_ALLOC, true); + node_continuous_input_counts_[peer_in_node_desc->GetName()] = peer_node->GetAllInDataAnchorsSize(); + } + peer_input_index = peer_in_anchor->GetIdx(); + peer_name = peer_in_node_desc->GetName(); + return true; + } + } + } + } + return false; +} + +/// +/// @ingroup GE +/// @brief Check pre_reuse flag & post_reuse glag for each symbol +/// @return void +/// +void BlockMemAssigner::InitReuseFlag() { + static const std::set kPreReuseTypes = {ge::DATA_TYPE, ge::AIPP_DATA_TYPE, ge::ANN_DATA_TYPE, + ge::NETOUTPUT, ge::PROPOSAL, ge::ZEROSLIKE, + ge::CONSTANT, ge::CONSTANTOP}; + static const std::set kPostReuseTypes = {ge::DATA_TYPE, ge::AIPP_DATA_TYPE, ge::ENTER, + ge::REFENTER, ge::NEXTITERATION, ge::REFNEXTITERATION}; + for (auto &pair : symbol_to_anchors_) { + std::string symbol = pair.first; + bool pre_reuse_flag = true; + bool post_reuse_flag = true; + for (auto &node_index_io : pair.second) { + if (node_index_io.io_type == kIn) { + continue; + } + + OutDataAnchorPtr out_anchor = node_index_io.node->GetOutDataAnchor(node_index_io.index); + if (out_anchor == nullptr) { + continue; + } + + bool out_flg = false; + if (node_index_io.node->GetOutDataNodes().empty()) { + out_flg = true; + } + for (auto &in_anchor : out_anchor->GetPeerInDataAnchors()) { + if (IsDirectOutputNode(in_anchor->GetOwnerNode(), in_anchor->GetIdx())) { + out_flg = true; + break; + } + } + std::string type = out_anchor->GetOwnerNode()->GetType(); + pre_reuse_flag = pre_reuse_flag && !out_flg && (kPreReuseTypes.count(type) == 0); + post_reuse_flag = post_reuse_flag && (kPostReuseTypes.count(type) == 0); + if (!pre_reuse_flag && !post_reuse_flag) { + break; + } + } + pre_reuse_flag_[symbol] = pre_reuse_flag; + post_reuse_flag_[symbol] = post_reuse_flag; + } +} + +/// +/// @ingroup GE +/// @brief get pre_reuse flag +/// @param [in] 
node +/// @param [in] out_index +/// @return bool +/// +bool BlockMemAssigner::IsPreReuse(const NodePtr &node, uint32_t out_index) const { + OutDataAnchorPtr out_data_anchor = nullptr; + if (static_cast(out_index) < node->GetAllOutDataAnchors().size()) { + out_data_anchor = node->GetOutDataAnchor(out_index); + } + if (out_data_anchor == nullptr) { + return false; + } + NodeIndexIO cur_node_index_io(out_data_anchor->GetOwnerNode(), out_data_anchor->GetIdx(), kOut); + auto iter1 = anchor_to_symbol_.find(cur_node_index_io.ToString()); + if (iter1 == anchor_to_symbol_.end()) { + return false; + } + + std::string symbol = iter1->second; + auto iter2 = pre_reuse_flag_.find(symbol); + if (iter2 == pre_reuse_flag_.end()) { + return false; + } + return iter2->second; +} + +/// +/// @ingroup GE +/// @brief get post_reuse flag +/// @param [in] mem_block +/// @return bool +/// +bool BlockMemAssigner::IsPostReuse(const MemoryBlock *mem_block) const { + if (mem_block == nullptr) { + return false; + } + for (auto &symbol : mem_block->SymbolList()) { + auto iter = post_reuse_flag_.find(symbol); + if (iter == post_reuse_flag_.end()) { + continue; + } + if (!iter->second) { + return false; + } + } + return true; +} + +/// +/// @ingroup GE +/// @brief check if symbol of cur node_index_io has block +/// @param [in] node_index_io +/// @return bool +/// +bool BlockMemAssigner::IsSymbolExist(const NodeIndexIO &node_index_io) { + auto iter = anchor_to_symbol_.find(node_index_io.ToString()); + if (iter == anchor_to_symbol_.end()) { + return false; + } + std::string symbol = iter->second; + return symbol_blocks_.find(symbol) != symbol_blocks_.end(); +} + +/// +/// @ingroup GE +/// @brief Print symbol +/// @return void +/// +void BlockMemAssigner::PrintSymbolMap() { + for (auto &pair : symbol_to_anchors_) { + GELOGD("symbol=%s, max_size=%zu, pre_reuse=%s, post_reuse=%s", pair.first.c_str(), symbol_size_[pair.first], + pre_reuse_flag_[pair.first] ? 
"true" : "false", post_reuse_flag_[pair.first] ? "true" : "false"); + for (auto &node_index_io : pair.second) { + GELOGD("anchor:%s", node_index_io.ToString().c_str()); + } + } +} + MemoryBlock *BlockMemAssigner::ApplyMemory(size_t block_size, size_t real_size, MemoryType mem_type, const NodePtr &n, uint32_t out_index, const vector &workspace_reuse_flag, - const bool is_op_reuse_mem) { + const bool is_op_reuse_mem, const bool continuous) { GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(n == nullptr, return nullptr, "Input parameter n is null."); auto node_op_desc = n->GetOpDesc(); GE_IF_BOOL_EXEC(node_op_desc == nullptr, return nullptr); @@ -293,52 +508,36 @@ MemoryBlock *BlockMemAssigner::ApplyMemory(size_t block_size, size_t real_size, string ge_disable_reuse_mem_env = "0"; (void)ge::GetContext().GetOption(kDisableReuseMemory, ge_disable_reuse_mem_env); if (ge_disable_reuse_mem_env != "1") { - int64_t convergence_label; - bool reuse_mem_flag = - ((workspace_reuse_flag.size() > out_index) && (workspace_reuse_flag[out_index] == false)) ? 
false : true; - if (!ge::AttrUtils::GetInt(node_op_desc, kL2FusionDynamicConvergeOp, convergence_label)) { - bool out_flg = false; - GE_IF_BOOL_EXEC(n->GetOutDataNodes().empty(), out_flg = true); - if (static_cast(out_index) < n->GetAllOutDataAnchors().size()) { - for (auto in_anchor : n->GetOutDataAnchor(out_index)->GetPeerInDataAnchors()) { - if (IsDirectOutputNode(in_anchor->GetOwnerNode(), in_anchor->GetIdx())) { - out_flg = true; - break; - } + bool reuse_mem_flag = !((workspace_reuse_flag.size() > out_index) && !workspace_reuse_flag[out_index]); + bool is_reuse_memory = !node_op_desc->HasAttr(kL2FusionDynamicConvergeOp) && reuse_mem_flag && is_op_reuse_mem && + (IsPreReuse(n, out_index)); + auto stream_id = node_op_desc->GetStreamId(); + auto map_iter = reusable_streams_map_.find(stream_id); + if (is_reuse_memory && map_iter != reusable_streams_map_.end()) { + for (auto it = reusable_blocks_.begin(); it != reusable_blocks_.end(); ++it) { + MemoryBlock *reusable_block = *it; + if (!IsPostReuse(reusable_block)) { + continue; } - auto op_type = node_op_desc->GetType(); - bool is_reuse_memory = !out_flg && reuse_mem_flag && (op_type != DATA_TYPE) && (op_type != AIPP_DATA_TYPE) && - (op_type != CONSTANT) && (op_type != NETOUTPUT) && (op_type != PROPOSAL) && - (op_type != ANN_DATA_TYPE) && (op_type != ZEROSLIKE) && (op_type != CONSTANTOP) && - is_op_reuse_mem; - - auto stream_id = node_op_desc->GetStreamId(); - auto map_iter = reusable_streams_map_.find(stream_id); - if (is_reuse_memory && map_iter != reusable_streams_map_.end()) { - for (auto it = reusable_blocks_.begin(); it != reusable_blocks_.end(); ++it) { - MemoryBlock *reusable_block = *it; - bool is_data = false; - for (auto node_type : reusable_block->NodeTypeIndexList()) { - GE_IF_BOOL_EXEC(node_type.node != nullptr, string type = node_type.node->GetType(); - bool flag = (type == DATA_TYPE) || (type == ENTER) || (type == REFENTER) || - (type == AIPP_DATA_TYPE) || (type == NEXTITERATION) || - (type == 
REFNEXTITERATION); - GE_IF_BOOL_EXEC(flag, is_data = true; break;);); - } - GE_IF_BOOL_EXEC(is_data == true, continue); - - // A node can reuse blocks of the same stream and preorder streams - if (CanReuseBySize(reusable_block_counts_, *reusable_block, block_size) && - CanReuseByStream(map_iter->second, *reusable_block)) { - GELOGD("Cross stream mem reuse, target stream:%ld, current stream:%ld", reusable_block->stream_id_, - stream_id); - reusable_block->AddNodeTypeIndex({n, mem_type, out_index}, real_size); - reusable_block->ref_count_++; - ReduceReusableBlockCount(*reusable_block, reusable_block_counts_); - reusable_blocks_.erase(it); - return reusable_block; + + // A node can reuse blocks of the same stream and preorder streams + auto id = GetAtomicAddrCleanId(); + if (CanReuseBySize(reusable_block_counts_, *reusable_block, block_size, real_size, continuous, id) && + CanReuseByStream(map_iter->second, *reusable_block)) { + GELOGD("Cross stream mem reuse, target stream:%ld, current stream:%ld", reusable_block->stream_id_, + stream_id); + reusable_block->AddNodeTypeIndex({n, mem_type, out_index}, real_size); + if (mem_type == kOutput) { + auto iter = anchor_to_symbol_.find(NodeIndexIO(n, out_index, kOut).ToString()); + if (iter != anchor_to_symbol_.end()) { + reusable_block->AddSymbol(iter->second); } } + reusable_block->continuous_block_ = continuous; + reusable_block->ref_count_++; + ReduceReusableBlockCount(*reusable_block, reusable_block_counts_); + reusable_blocks_.erase(it); + return reusable_block; } } } @@ -347,46 +546,55 @@ MemoryBlock *BlockMemAssigner::ApplyMemory(size_t block_size, size_t real_size, auto block = new (std::nothrow) MemoryBlock(block_size, is_op_reuse_mem); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(block == nullptr, return nullptr, "new an object failed."); + // Data and netoutput need zero copy block + if ((node_op_desc->GetType() == DATA_TYPE && !continuous) || (node_op_desc->GetType() == NETOUTPUT)) { + block->is_zero_copy_ = true; + } + 
block->Init(real_size, mem_type, n, out_index); block->stream_id_ = node_op_desc->GetStreamId(); block->ref_count_++; + block->continuous_block_ = continuous; + if (mem_type == kOutput) { + auto iter = anchor_to_symbol_.find(NodeIndexIO(n, out_index, kOut).ToString()); + if (iter != anchor_to_symbol_.end()) { + block->AddSymbol(iter->second); + } + } memory_blocks_.emplace_back(block); return block; } MemoryBlock *BlockMemAssigner::ApplyOutMemory(const NodePtr &n, uint32_t index, const vector &ranges, - const bool is_op_reuse_mem) { + const bool is_op_reuse_mem, const bool continuous) { GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(n == nullptr, return nullptr, "input node is null."); auto node_op_desc = n->GetOpDesc(); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(node_op_desc == nullptr, return nullptr, "node_op_desc is null."); MemoryBlock *block = nullptr; - bool reuse_input = false; - uint32_t reuse_input_index = 0; + NodeIndexIO node_index_io = NodeIndexIO(n, index, kOut); int64_t size = 0; auto output_op_desc = node_op_desc->GetOutputDescPtr(index); if (output_op_desc != nullptr) { - GE_IF_BOOL_EXEC(ge::TensorUtils::GetReuseInput(*output_op_desc, reuse_input) != SUCCESS, - GELOGI("Get reuse_input failed")); - GE_IF_BOOL_EXEC(ge::TensorUtils::GetReuseInputIndex(*output_op_desc, reuse_input_index) != SUCCESS, - GELOGI("Get reuse_input_index failed")); GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(*output_op_desc, size) != SUCCESS, GELOGI("Get size failed")); } - if (reuse_input) { - auto in_data_anchor = n->GetInDataAnchor(reuse_input_index); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(in_data_anchor == nullptr, return nullptr, "In data anchor is null."); - auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(peer_out_anchor == nullptr, return nullptr, "Peer out data anchor is null."); - auto reuse_src_node = peer_out_anchor->GetOwnerNode(); - auto reuse_src_node_output_index = static_cast(peer_out_anchor->GetIdx()); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( - 
(node_out_blocks_.empty() || (node_out_blocks_[reuse_src_node->GetName()].size() <= reuse_src_node_output_index)), - return nullptr, "node_out_block of node_out_block[reuse_src_node->Name()] is empty!"); - block = node_out_blocks_[reuse_src_node->GetName()][reuse_src_node_output_index]; + if (IsSymbolExist(node_index_io)) { + std::string symbol = anchor_to_symbol_[node_index_io.ToString()]; + block = symbol_blocks_[symbol]; + block->AddNodeTypeIndex({n, kOutput, index}, size); + block->ref_count_++; } else { - auto block_size = GetBlockSize(size, ranges); + int64_t max_size = size; + auto iter1 = anchor_to_symbol_.find(node_index_io.ToString()); + if (iter1 != anchor_to_symbol_.end()) { + auto iter2 = symbol_size_.find(iter1->second); + if (iter2 != symbol_size_.end()) { + max_size = iter2->second; + } + } + auto block_size = GetBlockSize(max_size, ranges); vector workspace_reuse_flag; - block = ApplyMemory(block_size, size, kOutput, n, index, workspace_reuse_flag, is_op_reuse_mem); + block = ApplyMemory(block_size, size, kOutput, n, index, workspace_reuse_flag, is_op_reuse_mem, continuous); } GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(block == nullptr, return nullptr, "Block is nullptr."); int out_count_reuse_input = block->ref_count_; @@ -404,6 +612,7 @@ MemoryBlock *BlockMemAssigner::ApplyOutMemory(const NodePtr &n, uint32_t index, out_count++; } } + bool reuse_input = false; for (const auto &in_anchor : out_data_anchor->GetPeerInDataAnchors()) { auto owner_node = in_anchor->GetOwnerNode(); GE_IF_BOOL_EXEC(owner_node == nullptr, continue); @@ -470,6 +679,31 @@ bool IsReferencePreviousNodeOutputMemory(const ge::NodePtr &node, uint32_t outpu return false; } +// atomic out memory will be reassigned +bool IsAtomicOutputMemory(const ge::NodePtr &node, uint32_t output_index, bool is_atomic, + bool out_node_set_continuous_input) { + auto op_desc = node->GetOpDesc(); + if (op_desc == nullptr) { + return false; + } + vector atomic_output_index; + // If GetListInt fail, 
atomic_output_index is empty. + (void)ge::AttrUtils::GetListInt(op_desc, ATOMIC_ATTR_OUTPUT_INDEX, atomic_output_index); + if (!out_node_set_continuous_input && is_atomic) { + for (auto &index : atomic_output_index) { + if (static_cast(index) == output_index) { + if (node->GetOwnerComputeGraph() != nullptr) { + string graph_name = node->GetOwnerComputeGraph()->GetName(); + GELOGD("[IMAS]Atomic no assign %s name[%s] output[%d] streamid[%ld].", graph_name.c_str(), + op_desc->GetName().c_str(), index, op_desc->GetStreamId()); + } + return true; + } + } + } + return false; +} + void BlockMemAssigner::ReleaseMemory(MemoryBlock *to_release, vector &reusable_memory) { GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(to_release == nullptr, return, "Input parameter to_release is null."); GE_CHK_TRUE_EXEC_INFO(to_release->ref_count_ <= 0, return, "Release memory"); @@ -571,10 +805,12 @@ Status BlockMemAssigner::AssignOutputMemoryWithReuse(const NodePtr &node, vector GELOGI("Assign memory node[%s], output size[%d], output memory type size[%d]", node_op_desc->GetName().c_str(), node_op_desc->GetOutputsSize(), memorys_type.size()); if (has_mem_type_attr && (memorys_type.size() != node_op_desc->GetOutputsSize())) { - GELOGE(INTERNAL_ERROR, "L1fusion: node[%s], output memory size err[outputsize:%zu, memorysize:%zu]", + GELOGE(INTERNAL_ERROR, "fusion: node[%s], output memory size err[outputsize:%zu, memorysize:%zu]", node_op_desc->GetName().c_str(), node_op_desc->GetOutputsSize(), memorys_type.size()); return INTERNAL_ERROR; } + + is_op_reuse_mem_ = true; if (op_reuse_env_valid_ == true) { vector::iterator it_name = std::find(op_no_reuse_mem_vec_.begin(), op_no_reuse_mem_vec_.end(), node_op_desc->GetName()); @@ -584,6 +820,9 @@ Status BlockMemAssigner::AssignOutputMemoryWithReuse(const NodePtr &node, vector is_op_reuse_mem_ = false;); } + bool is_atomic = false; + // If GetBool fail, is_atomic is false. 
+ (void)ge::AttrUtils::GetBool(node_op_desc, ATOMIC_ATTR_IS_ATOMIC_NODE, is_atomic); // Allocate memory for the current node and release node memory of the same size in the workspace GE_IF_BOOL_EXEC(ge_disable_reuse_mem_env_ != "1", ReleaseMemorys(stream_workspace_blocks_[stream_id], reusable_blocks_);) @@ -593,19 +832,43 @@ Status BlockMemAssigner::AssignOutputMemoryWithReuse(const NodePtr &node, vector if (output_op_desc != nullptr) { GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(*output_op_desc, size) != SUCCESS, GELOGI("Get size failed")); } - // l1 fusion: l1 type's size not means malloc HBM memory - if (has_mem_type_attr && memorys_type[i] != RT_MEMORY_HBM) { - GELOGI("L1fusion: node[%s], output[%s], output memory type [%d]", node_op_desc->GetName().c_str(), + // fusion: other type's size not means malloc HBM memory + bool l1_flag = has_mem_type_attr && memorys_type[i] == RT_MEMORY_L1; + if (l1_flag) { + GELOGI("fusion: node[%s], output[%s], output memory type [%d]", node_op_desc->GetName().c_str(), node_op_desc->GetOutputNameByIndex(i).c_str(), memorys_type[i]); size = 0; } - if ((size == 0) || CheckIsZeroMemNodeType(node->GetType()) || IsReferencePreviousNodeOutputMemory(node, i)) { + std::string peer_name; + uint32_t peer_input_index = 0; + bool out_node_set_continuous_input = false; + bool no_need_assign_memory = + ((size == 0) || CheckIsZeroMemNodeType(node->GetType()) || IsReferencePreviousNodeOutputMemory(node, i)); + if (!no_need_assign_memory) { + out_node_set_continuous_input = IsOutNodeSetContinuousInput(node, i, peer_name, peer_input_index); + no_need_assign_memory = IsAtomicOutputMemory(node, i, is_atomic, out_node_set_continuous_input); + } + if (no_need_assign_memory) { zero_memory_list_.emplace_back(node, kOutput, i); continue; } - MemoryBlock *mem_block = ApplyOutMemory(node, i, ranges, is_op_reuse_mem_); + bool reuse_mem = is_op_reuse_mem_; + // atomic can't be reused + if (is_op_reuse_mem_ && out_node_set_continuous_input && is_atomic) { + 
reuse_mem = false; + } + MemoryBlock *mem_block = ApplyOutMemory(node, i, ranges, reuse_mem, out_node_set_continuous_input); if (mem_block != nullptr) { node_out_blocks_[node->GetName()].emplace_back(mem_block); + if (out_node_set_continuous_input) { + node_continuous_input_blocks_[peer_name][peer_input_index] = mem_block; + } + NodeIndexIO node_index_io(node, i, kOut); + auto iter = anchor_to_symbol_.find(node_index_io.ToString()); + if (iter == anchor_to_symbol_.end()) { + continue; + } + symbol_blocks_[iter->second] = mem_block; } } return SUCCESS; @@ -621,15 +884,14 @@ void BlockMemAssigner::AssignMemoryWithReuse(vector &ranges) { // Init reusable streams map InitReusableStreamMap(); - (void)ge::GetContext().GetOption("ge.exec.disableReuseMemory", ge_disable_reuse_mem_env_); - + (void)ge::GetContext().GetOption(kDisableReuseMemory, ge_disable_reuse_mem_env_); GEEVENT("Reuse memory %s", ge_disable_reuse_mem_env_ == "1" ? "close" : "open"); string op_no_reuse_mem_str; const char *op_no_reuse_mem = std::getenv(OP_NO_REUSE_MEM); GE_IF_BOOL_EXEC(op_no_reuse_mem != nullptr, op_no_reuse_mem_str = string(op_no_reuse_mem); CheckAndGetOpReuseEnv(op_no_reuse_mem_str, op_no_reuse_mem_vec_, op_reuse_env_valid_);); - for (NodePtr &n : compute_graph_->GetDirectNode()) { + for (NodePtr &n : compute_graph_->GetAllNodes()) { auto node_op_desc = n->GetOpDesc(); GE_IF_BOOL_EXEC(node_op_desc == nullptr, continue); int64_t stream_id = node_op_desc->GetStreamId(); @@ -651,16 +913,17 @@ void BlockMemAssigner::AssignMemoryWithReuse(vector &ranges) { temp.size(), workspace_memory_type.size()); if (has_workspace_mem_type_attr && (temp.size() != workspace_memory_type.size())) { - GELOGE(INTERNAL_ERROR, "L1fusion: node[%s], workspace_memory size err![v_temp:%zu, workspace:%zu]", temp.size(), - workspace_memory_type.size()); + GELOGE(INTERNAL_ERROR, "fusion: node[%s], workspace_memory size err![v_temp:%zu, workspace:%zu]", + n->GetName().c_str(), temp.size(), workspace_memory_type.size()); 
return; } for (size_t i = 0; i < temp.size(); i++) { - // l1 fusion: l1 type's size not means malloc HBM memory + // fusion: other type's size not means malloc HBM memory bool workspace_skip_flag = false; - if (has_workspace_mem_type_attr && workspace_memory_type[i] != RT_MEMORY_HBM) { - GELOGI("L1fusion: node[%s]workspace index[%d] is l1 type, add to zero_memory_list, workspace memory type [%ld]", - node_op_desc->GetName().c_str(), i, workspace_memory_type[i]); + if (has_workspace_mem_type_attr && workspace_memory_type[i] == RT_MEMORY_L1) { + GELOGI( + "fusion: node[%s]workspace index[%d] is not hbm type, add to zero_memory_list, workspace memory type [%ld]", + node_op_desc->GetName().c_str(), i, workspace_memory_type[i]); workspace_skip_flag = true; } if (temp[i] == 0 || workspace_skip_flag) { @@ -669,7 +932,7 @@ void BlockMemAssigner::AssignMemoryWithReuse(vector &ranges) { } MemoryBlock *mem_block = ApplyMemory(GetBlockSize(static_cast(temp[i]), ranges), static_cast(temp[i]), kWorkspace, n, - static_cast(i), workspace_reuse_flag, is_op_reuse_mem_); + static_cast(i), workspace_reuse_flag, is_op_reuse_mem_, false); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(mem_block == nullptr, continue, "failed to apply memory block."); CheckWorkspaceReuse(workspace_reuse_flag, i, stream_id, mem_block); } @@ -683,6 +946,7 @@ void BlockMemAssigner::AssignMemoryWithReuse(vector &ranges) { } GE_IF_BOOL_EXEC(!(ge_disable_reuse_mem_env_ == "1"), MergeDynamicBatchBlocks();) + AssignContinuousBlocks(); ResizeMemoryBlocks(); GELOGD("Memory blocks after resize:"); @@ -733,6 +997,9 @@ void MergeBlocks(std::vector &dest, std::vector &s return; } if (dest[i] != nullptr && src[i] != nullptr) { + for (auto &symbol : src[i]->SymbolList()) { + dest[i]->AddSymbol(symbol); + } for (size_t j = 0; j < src[i]->NodeTypeIndexList().size(); ++j) { dest[i]->AddNodeTypeIndex(src[i]->NodeTypeIndexList()[j], src[i]->RealSizeList()[j]); src[i]->deleted_block_ = true; @@ -774,6 +1041,80 @@ void 
BlockMemAssigner::MergeDynamicBatchBlocks() { } } +// asending order +static bool CompareBlockIndex(MemoryBlock *left, MemoryBlock *right) { + if (left == nullptr || right == nullptr) { + return false; + } + if (left->input_index_ < right->input_index_) { + return true; + } + return false; +} +/// +/// @ingroup domi +/// @brief order blocks by continuous input index +/// @param [in] blocks need be processed +/// @param [in] input blocks need continuous +/// @param [out] blocks after continuous order +/// @param [in/out] blocks ordered +/// +void ReAssignContinuousBlocks(const std::vector &org_blocks, + const std::map block_map, + std::vector &dest_blocks, std::vector &continuous_blocks) { + for (auto &memory_block : org_blocks) { + if (memory_block == nullptr || memory_block->deleted_block_) { + continue; + } + if (block_map.find(memory_block) != block_map.end()) { + continue; + } + dest_blocks.emplace_back(memory_block); + } + + // add continuous block + std::sort(continuous_blocks.begin(), continuous_blocks.end(), CompareBlockIndex); + size_t count = 0; + for (auto &memory_block : continuous_blocks) { + GE_IF_BOOL_EXEC(memory_block == nullptr, continue); + + GELOGI("Block continuous input index:%d", memory_block->input_index_); + count++; + if (count == continuous_blocks.size()) { + memory_block->last_continuous_block_ = true; + } + dest_blocks.emplace_back(memory_block); + } +} + +void BlockMemAssigner::AssignContinuousBlocks() { + for (auto &block_map : node_continuous_input_blocks_) { + std::vector dest_memory_blocks; + std::map continuous_block_map; + std::vector continuous_blocks; + auto it = node_continuous_input_counts_.find(block_map.first); + GE_IF_BOOL_EXEC(it == node_continuous_input_counts_.end(), continue); + GELOGI("Node:%s continuous input block count:%zu input count:%u", block_map.first.c_str(), block_map.second.size(), + it->second); + GE_IF_BOOL_EXEC(it->second != block_map.second.size(), continue); + + for (auto &it : block_map.second) { + if 
(it.second != nullptr) { + continuous_block_map[it.second] = it.first; + it.second->input_index_ = it.first; + continuous_blocks.emplace_back(it.second); + } + } + if (continuous_block_map.size() != continuous_blocks.size()) { + GELOGW("Node:%s continuous input map size:%zu vector size:%zu", block_map.first.c_str(), + continuous_block_map.size(), continuous_blocks.size()); + continue; + } + ReAssignContinuousBlocks(memory_blocks_, continuous_block_map, dest_memory_blocks, continuous_blocks); + memory_blocks_.swap(dest_memory_blocks); + } +} + /// /// @ingroup domi_omg /// @brief traverse memory size, resize, calculate offset @@ -781,13 +1122,14 @@ void BlockMemAssigner::MergeDynamicBatchBlocks() { /// void BlockMemAssigner::ResizeMemoryBlocks() { for (auto &memory_block : memory_blocks_) { - if (memory_block == nullptr || memory_block->deleted_block_) { + if (memory_block == nullptr || memory_block->deleted_block_ || memory_block->is_zero_copy_) { continue; } memory_block->Resize(); memory_block->SetHeadOffset(mem_offset_); mem_offset_ += memory_block->Size(); memory_block->SetTailOffset(mem_offset_ - 1); + GELOGI("mem_offset_ exclude zero_copy_memory is %zu.", mem_offset_); } } @@ -822,8 +1164,8 @@ void SetOffsetSize(const NodeTypeIndex &node_type_index, int64_t offset, size_t output_list.at(node_type_index.index) = offset; } } else { - // l1 fusion: keep the original offset value from op_desc - bool set_out_offset = (!has_mem_type_attr) || (memorys_type[node_type_index.index] == RT_MEMORY_HBM); + // fusion: keep the original other type offset value from op_desc + bool set_out_offset = (!has_mem_type_attr) || (memorys_type[node_type_index.index] != RT_MEMORY_L1); if (set_out_offset) { output_list.at(node_type_index.index) = offset; } @@ -841,9 +1183,9 @@ void SetOffsetSize(const NodeTypeIndex &node_type_index, int64_t offset, size_t vector workspace_memory_type; bool has_workspace_mem_type_attr = ge::AttrUtils::GetListInt(op_desc, TVM_ATTR_NAME_WORKSPACE_TYPE, 
workspace_memory_type); - // l1 fusion: keep the original offset value from op_desc + // fusion: keep the original other type offset value from op_desc bool set_workspace_offset = - (!has_workspace_mem_type_attr) || (workspace_memory_type[node_type_index.index] == RT_MEMORY_HBM); + (!has_workspace_mem_type_attr) || (workspace_memory_type[node_type_index.index] != RT_MEMORY_L1); if (set_workspace_offset) { workspace_list.at(node_type_index.index) = offset; } @@ -854,11 +1196,16 @@ void SetOffsetSize(const NodeTypeIndex &node_type_index, int64_t offset, size_t } } -void BlockMemAssigner::SetOpMemOffset() { +void BlockMemAssigner::SetOpMemOffset(bool is_zero_copy) { for (MemoryBlock *memory_block : memory_blocks_) { if (memory_block == nullptr || memory_block->deleted_block_) { continue; } + + if ((is_zero_copy && !memory_block->is_zero_copy_) || (!is_zero_copy && memory_block->is_zero_copy_)) { + continue; + } + size_t index = 0; size_t real_size = 0; auto real_size_list_size = memory_block->RealSizeList().size(); @@ -870,8 +1217,11 @@ void BlockMemAssigner::SetOpMemOffset() { index++; } } - for (const NodeTypeIndex &node_type_index : zero_memory_list_) { - SetOffsetSize(node_type_index, 0, 0, 0); + + if (!is_zero_copy) { + for (const NodeTypeIndex &node_type_index : zero_memory_list_) { + SetOffsetSize(node_type_index, 0, 0, 0); + } } } @@ -884,7 +1234,7 @@ Status BlockMemAssigner::Assign() { GE_IF_BOOL_EXEC(ranges.empty(), return SUCCESS); AssignMemoryWithReuse(ranges); - SetOpMemOffset(); + SetOpMemOffset(false); return SUCCESS; } @@ -925,7 +1275,7 @@ void BlockMemAssigner::InitReusableStreamMap() { void BlockMemAssigner::FindHeadAndTailNodesForStream(map> &stream_head_tail_node_map, unordered_map &stream_mem_map) { - for (const auto &n : compute_graph_->GetDirectNode()) { + for (const auto &n : compute_graph_->GetAllNodes()) { GE_IF_BOOL_EXEC(n->GetOpDesc() == nullptr, GELOGW("Op desc is nullptr"); continue); auto stream_id = n->GetOpDesc()->GetStreamId(); // 
traverse to find streams's first and last node. @@ -961,11 +1311,19 @@ void BlockMemAssigner::FindDependentStream(map> } NodePtr pre_node = it1.second.second; NodePtr post_node = it2.second.first; + std::vector out_nodes; + // Direct link out_node for (const auto &out_node : pre_node->GetOutNodes()) { if ((out_node->GetOpDesc() == nullptr) || (post_node->GetOpDesc() == nullptr) || (pre_node->GetOpDesc() == nullptr)) { continue; } + out_nodes.emplace_back(out_node); + } + + FindDependentStreamBetweenGraphs(pre_node, out_nodes); + + for (auto &out_node : out_nodes) { if (out_node->GetOpDesc()->GetId() == post_node->GetOpDesc()->GetId()) { stream_dependency_map[pre_node->GetOpDesc()->GetStreamId()].insert(post_node->GetOpDesc()->GetStreamId()); } @@ -974,6 +1332,46 @@ void BlockMemAssigner::FindDependentStream(map> } } +/// +/// @ingroup GE +/// @brief Find dependent link between parent_graph and sub_graph +/// @param [in] pre_node +/// @param [out] out_nodes +/// @return void +/// @author +/// +void BlockMemAssigner::FindDependentStreamBetweenGraphs(const NodePtr &pre_node, std::vector &out_nodes) { + if ((pre_node == nullptr) || (pre_node->GetOpDesc() == nullptr)) { + return; + } + + // FunctionOp & subgraph input + std::vector subgraph_names = pre_node->GetOpDesc()->GetSubgraphInstanceNames(); + for (auto &subgraph_name : subgraph_names) { + ComputeGraphPtr subgraph = compute_graph_->GetSubgraph(subgraph_name); + if (subgraph == nullptr) { + continue; + } + for (auto &node : subgraph->GetDirectNode()) { + OpDescPtr op_desc = node->GetOpDesc(); + if (op_desc == nullptr) { + continue; + } + if (op_desc->HasAttr(ATTR_NAME_PARENT_NODE_INDEX)) { + out_nodes.emplace_back(node); + } + } + } + + // subgraph output & parent_node output + if (NodeUtils::IsSubgraphOutput(pre_node)) { + NodePtr parent_node = pre_node->GetOwnerComputeGraph()->GetParentNode(); + for (const auto &out_node : parent_node->GetOutNodes()) { + out_nodes.emplace_back(out_node); + } + } +} + bool 
BlockMemAssigner::CheckIsZeroMemNodeType(const string &node_type) const { return (node_type == VARIABLE) || (node_type == CONSTANT) || (node_type == MULTISHAPE) || (node_type == HCOMBROADCAST) || (node_type == HCOMALLREDUCE) || (node_type == CONSTANTOP) || diff --git a/src/ge/graph/build/memory/block_mem_assigner.h b/src/ge/graph/build/memory/block_mem_assigner.h index d0cb5339..97e69431 100644 --- a/src/ge/graph/build/memory/block_mem_assigner.h +++ b/src/ge/graph/build/memory/block_mem_assigner.h @@ -28,6 +28,7 @@ #include "common/util.h" #include "graph/build/memory/mem_assigner.h" #include "graph/compute_graph.h" +#include "graph/utils/graph_utils.h" namespace ge { enum MemoryType { kOutput, kWorkspace }; @@ -48,6 +49,10 @@ class MemoryBlock { stream_id_(0), deleted_block_(false), reuse_mem_(reuse_mem), + input_index_(0), + continuous_block_(false), + last_continuous_block_(false), + is_zero_copy_(false), block_size_(block_size), head_offset_(0), tail_offset_(0) {} @@ -56,7 +61,10 @@ class MemoryBlock { MemoryBlock &operator=(const MemoryBlock &) = delete; - ~MemoryBlock() { node_type_index_list_.clear(); } + ~MemoryBlock() { + node_type_index_list_.clear(); + symbol_list_.clear(); + } void Init(size_t real_size, MemoryType type, const ge::NodePtr &node, uint32_t out_index) { real_size_list_.emplace_back(real_size); @@ -77,7 +85,10 @@ class MemoryBlock { real_size_list_.emplace_back(real_size); } + void AddSymbol(const std::string &symbol) { symbol_list_.emplace_back(symbol); } + const std::vector &NodeTypeIndexList() const { return node_type_index_list_; } + const std::vector &SymbolList() const { return symbol_list_; } const std::vector &RealSizeList() const { return real_size_list_; } void Resize(); @@ -90,6 +101,10 @@ class MemoryBlock { int64_t stream_id_; bool deleted_block_; bool reuse_mem_; + uint32_t input_index_; + bool continuous_block_; + bool last_continuous_block_; + bool is_zero_copy_; private: size_t block_size_; @@ -97,6 +112,7 @@ class 
MemoryBlock { size_t head_offset_; size_t tail_offset_; std::vector node_type_index_list_; + std::vector symbol_list_; }; class BlockMemAssigner : public MemAssigner { @@ -111,7 +127,11 @@ class BlockMemAssigner : public MemAssigner { Status Assign() override; - size_t GetMemOffset() const { return mem_offset_; } + size_t GetMemOffset() const { return mem_offset_; }; + + int64_t GetAtomicAddrCleanId() const { return atomic_addr_clean_id_; }; + + std::vector GetMemoryBlocks() const { return memory_blocks_; }; /// /// @ingroup domi @@ -128,7 +148,7 @@ class BlockMemAssigner : public MemAssigner { /// void AssignMemoryWithReuse(std::vector &ranges); - void SetOpMemOffset(); + void SetOpMemOffset(bool is_zero_copy); protected: /// @@ -173,6 +193,16 @@ class BlockMemAssigner : public MemAssigner { void FindDependentStream(std::map> &stream_head_tail_node_map, std::map> &stream_dependency_map); + /// + /// @ingroup GE + /// @brief Find dependent link between parent_graph and sub_graph + /// @param [in] pre_node + /// @param [out] out_nodes + /// @return void + /// @author + /// + void FindDependentStreamBetweenGraphs(const NodePtr &pre_node, std::vector &out_nodes); + /// /// @ingroup GE /// @brief Determine whether it is the type of zero memory node. 
@@ -182,6 +212,45 @@ class BlockMemAssigner : public MemAssigner { /// bool CheckIsZeroMemNodeType(const std::string &node_type) const; + /// + /// @ingroup GE + /// @brief Check pre_reuse flag & post_reuse glag for each symbol + /// @return void + /// + void InitReuseFlag(); + + /// + /// @ingroup GE + /// @brief get pre_reuse flag + /// @param [in] node + /// @param [in] out_index + /// @return bool + /// + bool IsPreReuse(const NodePtr &node, uint32_t out_index) const; + + /// + /// @ingroup GE + /// @brief get post_reuse flag + /// @param [in] mem_block + /// @return bool + /// + bool IsPostReuse(const MemoryBlock *mem_block) const; + + /// + /// @ingroup GE + /// @brief check if symbol of cur node_index_io has block + /// @param [in] node_index_io + /// @return bool + /// + bool IsSymbolExist(const NodeIndexIO &node_index_io); + + /// + /// @ingroup GE + /// @brief Print symbol + /// @return void + /// + void PrintSymbolMap(); + size_t mem_offset_; ge::ComputeGraphPtr compute_graph_; @@ -190,6 +259,13 @@ class BlockMemAssigner : public MemAssigner { std::vector zero_memory_list_; + // ref mapping + std::map> symbol_to_anchors_; + std::map anchor_to_symbol_; + std::map pre_reuse_flag_; + std::map post_reuse_flag_; + std::map symbol_size_; + private: /// /// @ingroup GE @@ -201,7 +277,7 @@ class BlockMemAssigner : public MemAssigner { /// @author /// MemoryBlock *ApplyOutMemory(const ge::NodePtr &n, uint32_t index, const std::vector &ranges, - const bool is_op_reuse_mem); + const bool is_op_reuse_mem, const bool continuous); Status AssignOutputMemoryWithReuse(const NodePtr &node, vector &ranges); /// @@ -218,7 +294,7 @@ class BlockMemAssigner : public MemAssigner { /// MemoryBlock *ApplyMemory(size_t block_size, size_t real_size, MemoryType mem_type, const ge::NodePtr &n, uint32_t out_index, const std::vector &workspace_reuse_flag, - const bool is_op_reuse_mem); + const bool is_op_reuse_mem, const bool continuous); /// /// @ingroup GE @@ -273,6 +349,11 @@ class 
BlockMemAssigner : public MemAssigner { /// void MergeDynamicBatchBlocks(); + void AssignContinuousBlocks(); + + bool IsOutNodeSetContinuousInput(const NodePtr &n, uint32_t out_index, std::string &peer_name, + uint32_t &peer_input_index); + std::vector reusable_blocks_; std::map reusable_block_counts_; @@ -281,6 +362,12 @@ class BlockMemAssigner : public MemAssigner { std::unordered_map> node_out_blocks_; + std::unordered_map symbol_blocks_; + + std::unordered_map> node_continuous_input_blocks_; + + std::unordered_map node_continuous_input_counts_; + // save stream_id and reusable stream_ids std::unordered_map> reusable_streams_map_; @@ -292,6 +379,8 @@ class BlockMemAssigner : public MemAssigner { std::string ge_disable_reuse_mem_env_ = "0"; bool is_op_reuse_mem_ = true; + + int64_t atomic_addr_clean_id_ = 0; }; } // namespace ge #endif // GE_GRAPH_BUILD_MEMORY_BLOCK_MEM_ASSIGNER_H_ diff --git a/src/ge/graph/build/memory/graph_mem_assigner.cc b/src/ge/graph/build/memory/graph_mem_assigner.cc index 7fc07f42..c3078bec 100644 --- a/src/ge/graph/build/memory/graph_mem_assigner.cc +++ b/src/ge/graph/build/memory/graph_mem_assigner.cc @@ -18,10 +18,10 @@ #include #include #include "common/math/math_util.h" -#include "common/op/attr_define.h" #include "framework/common/debug/ge_log.h" #include "graph/build/memory/hybrid_mem_assigner.h" #include "graph/build/memory/var_mem_assign_util.h" +#include "graph/build/memory/block_mem_assigner.h" #include "graph/common/omg_util.h" #include "graph/debug/ge_attr_define.h" #include "graph/ge_attr_value.h" @@ -29,22 +29,15 @@ #include "graph/utils/tensor_utils.h" #include "graph/utils/type_utils.h" -using domi::AIPP_DATA_TYPE; -using domi::ATOMICADDRCLEAN; -using domi::ATTR_NAME_AUTOMIC_ADD_MEM_SIZE; -using domi::ATTR_NAME_AUTOMIC_ADD_START; -using domi::CONCAT; -using domi::CONSTANTOP; -using domi::DATA_TYPE; -using domi::HCOMBROADCAST; -using domi::LABELSWITCHBYINDEX; -using domi::NODE_NAME_NET_OUTPUT; -using domi::STREAMMERGE; 
-using domi::VARIABLE; - namespace { const int kDataOutputIndex = 0; const int kAllInputAddrIsAtomic = -1; +const int kVirtualInputNodeMemoryReuse = 0; +const int kVirtualOutputNodeMemoryReuse = 1; +const size_t kVirtualInputNodeOutputSize = 1; +const size_t kVirtualOutputNodeInputSize = 1; +const size_t kVirtualNodeDataIndex = 0; +const char *const kMbatchNodeNameFlag = "_ascend_mbatch_batch_"; } // namespace namespace ge { Status VariableMemoryAssigner::Assign() { @@ -69,12 +62,12 @@ Status VariableMemoryAssigner::AssignVarAttr2Nodes() { } Status GraphMemoryAssigner::AssignMemory() { - ge::HybridMemAssigner mem_assigner(compute_graph_); - if (mem_assigner.Assign() != ge::SUCCESS) { + ge::HybridMemAssignerPtr mem_assigner(new (std::nothrow) HybridMemAssigner(compute_graph_)); + if (mem_assigner->Assign() != ge::SUCCESS) { GELOGE(ge::FAILED, "Memory assigner failed"); return ge::FAILED; } - MemoryOffset memory_offset(RT_MEMORY_HBM, mem_assigner.GetMemOffset()); + MemoryOffset memory_offset(RT_MEMORY_HBM, mem_assigner->GetMemOffset()); memory_offset_.push_back(memory_offset); auto session_id = compute_graph_->GetSessionID(); @@ -91,6 +84,9 @@ Status GraphMemoryAssigner::AssignMemory() { } int64_t var_size_assign = ge::VarManager::Instance(session_id)->GetVarMemSize(RT_MEMORY_HBM) - var_size_before_assign; GELOGI("GraphMemoryAssigner::AssignMemory variable size = %ld", var_size_assign); + + mem_assigner_ = std::move(mem_assigner); + return ge::SUCCESS; } @@ -149,6 +145,65 @@ ge::Status GraphMemoryAssigner::CalculateTensorRealSizeAndOutSize(const ge::Cons return SUCCESS; } +Status GraphMemoryAssigner::GetMaxBatchLabel(const map> &mem_reuse_virtual_nodes_map, + int32_t mem_reuse_model, string &max_batch_label) { + for (auto &i_map : mem_reuse_virtual_nodes_map) { + vector virtual_nodes_list = i_map.second; + vector max_shape_dims; + size_t max_batch_dim = 0; + bool max_batch_dim_find = false; + for (size_t i = 0; i < virtual_nodes_list.size(); ++i) { + 
GE_CHECK_NOTNULL(virtual_nodes_list[i]); + OpDescPtr op_desc = virtual_nodes_list[i]->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + + ge::ConstGeTensorDescPtr input_output_desc; + if (mem_reuse_model == kVirtualInputNodeMemoryReuse) { + input_output_desc = op_desc->GetOutputDescPtr(kVirtualNodeDataIndex); + } else if (mem_reuse_model == kVirtualOutputNodeMemoryReuse) { + input_output_desc = op_desc->GetInputDescPtr(kVirtualNodeDataIndex); + } else { + GELOGE(FAILED, "Invalid parameter memory reuse model, which is: %d.", mem_reuse_model); + return FAILED; + } + GE_CHECK_NOTNULL(input_output_desc); + + if (i == 0) { + // All ops must have ATTR_NAME_BATCH_LABEL, no need to check return value. + (void)ge::AttrUtils::GetStr(op_desc, ATTR_NAME_BATCH_LABEL, max_batch_label); + max_shape_dims = input_output_desc->GetShape().GetDims(); + } else { + vector current_shape_dims = input_output_desc->GetShape().GetDims(); + if (current_shape_dims.size() != max_shape_dims.size()) { + GELOGE(FAILED, "The shape size of several nodes between multiple batches does not match."); + return FAILED; + } + for (size_t j = 0; j < current_shape_dims.size(); ++j) { + if (current_shape_dims[j] == max_shape_dims[j]) { + continue; + } + if (max_batch_dim_find && max_batch_dim != j) { + GELOGE(FAILED, "The shape of several nodes between multiple batches does not match."); + return FAILED; + } + max_batch_dim_find = true; + max_batch_dim = j; + if (current_shape_dims[j] > max_shape_dims[j]) { + max_shape_dims[j] = current_shape_dims[j]; + // All ops must have ATTR_NAME_BATCH_LABEL, no need to check return value. + (void)ge::AttrUtils::GetStr(op_desc, ATTR_NAME_BATCH_LABEL, max_batch_label); + } + // Only compare the first different dim in shape. + break; + } + } + } + // In every element of virtual_input_nodes_map, the label of the max batch node is the same. 
+ break; + } + return SUCCESS; +} + Status GraphMemoryAssigner::ReAssignMemory(bool is_loop_graph, size_t &mem_offset) { if (memory_offset_.empty()) { GELOGE(FAILED, "memory_offset_ is empty."); @@ -163,8 +218,6 @@ Status GraphMemoryAssigner::ReAssignMemory(bool is_loop_graph, size_t &mem_offse GE_CHK_STATUS_RET(ReAssignReuseAndNoPaddingContinuousOutputMemory(), "ReAssignReuseAndNoPaddingContinuousOutputMemory Failed!"); - GE_CHK_STATUS_RET(ReAssignMergeMemory(), "ReAssignMergeMemory Failed!"); - GE_CHK_STATUS_RET(ReAssignAtomicMemory(is_loop_graph), "ReAssignAtomicMemory Failed!"); mem_offset = memory_offset_[0].mem_offset_; @@ -177,26 +230,53 @@ Status GraphMemoryAssigner::ReAssignMemory(bool is_loop_graph, size_t &mem_offse return SUCCESS; } +Status GraphMemoryAssigner::AssignZeroCopyMemory(size_t &mem_offset, size_t &zero_mem_copy_size) { + BlockMemAssignerPtr priority_assigner = std::move(mem_assigner_->GetPriorityAssinger()); + GE_IF_BOOL_EXEC(priority_assigner == nullptr, GELOGE(FAILED, "Get priority_assigner failed."); return ge::FAILED;); + + size_t mem_offset_tmp = mem_offset; + + // set offset for zero copy block + for (auto &memory_block : priority_assigner->GetMemoryBlocks()) { + if (memory_block == nullptr || memory_block->deleted_block_ || !memory_block->is_zero_copy_) { + continue; + } + memory_block->Resize(); + memory_block->SetHeadOffset(mem_offset); + mem_offset += memory_block->Size(); + memory_block->SetTailOffset(mem_offset - 1); + GELOGI("mem_offset_ include zero_copy_memory is %zu.", mem_offset); + } + + // set offset for zero copy nodes + priority_assigner->SetOpMemOffset(true); + + zero_mem_copy_size = mem_offset - mem_offset_tmp; + GELOGI("max_mem_offset:%zu, mem_offset:%zu, zero_mem_copy_size:%zu.", mem_offset, mem_offset_tmp, zero_mem_copy_size); + + return SUCCESS; +} + Status GraphMemoryAssigner::ReAssignContinuousMemory(bool is_loop_graph) { GELOGI("Begin to reassign continuous memory"); Status ret; - for (auto &node : 
compute_graph_->GetDirectNode()) { + for (auto &node : compute_graph_->GetAllNodes()) { // Get the continuous input type of the node, default is false bool is_input_continuous = false; GE_CHECK_NOTNULL(node->GetOpDesc()); // If GetBool fail, is_input_continuous is false. (void)ge::AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_CONTINUOUS_INPUT, is_input_continuous); - int64_t mem_clean_start = memory_offset_[0].mem_offset_; + // Assign continuous input memory if (is_input_continuous) { - ret = AssignContinuousInputMemory(node); + int64_t mem_clean_start = 0; + int64_t mem_clean_size = 0; + ret = AssignContinuousInputMemory(node, mem_clean_start, mem_clean_size); if (ret != ge::SUCCESS) { GELOGE(ret, "Assign continuous input memory failed!"); return ret; } - memory_offset_[0].mem_offset_ += MEM_ALIGN_SIZE; - // Clean up atomic address, eg, hcom node vector input_indexes; // If GetListInt fail, input_indexes is empty. @@ -211,7 +291,6 @@ Status GraphMemoryAssigner::ReAssignContinuousMemory(bool is_loop_graph) { } else if (is_loop_graph) { GE_CHK_STATUS_RET(SetLoopGraphAtomicAttr(node, mem_clean_start)); } else { - int64_t mem_clean_size = memory_offset_[0].mem_offset_ - mem_clean_start; GE_CHK_STATUS_RET(SetAtomicCleanAttr(nullptr, mem_clean_start, mem_clean_size), "SetAtomicCleanAttr failed."); } } @@ -230,13 +309,7 @@ Status GraphMemoryAssigner::ReAssignContinuousMemory(bool is_loop_graph) { // If the output is ref type and refers to the ref of an input, the name of the output // and the input are the same. 
Ge encounters ref type, finds matching relationship according // to the names of input and output, and allocates the same memory address, eg: HCOMBroadcast - if (is_ref) { - ret = AssignReferenceMemory(node); - if (ret != ge::SUCCESS) { - GELOGE(ret, "Assign reference memory failed!"); - return ret; - } - } else if (is_output_continuous) { // Assign continuous output memory + if (!is_ref && is_output_continuous) { // Assign continuous output memory ret = AssignContinuousOutputMemory(node); if (ret != ge::SUCCESS) { GELOGE(ret, "Assign reference memory failed!"); @@ -249,14 +322,16 @@ Status GraphMemoryAssigner::ReAssignContinuousMemory(bool is_loop_graph) { return ge::SUCCESS; } -Status GraphMemoryAssigner::AssignContinuousInputMemory(const ge::NodePtr &node) { +Status GraphMemoryAssigner::AssignContinuousInputMemory(const ge::NodePtr &node, int64_t &continuous_mem_start, + int64_t &continuous_mem_size) { GELOGI("Current node %s needs continuous input.", node->GetName().c_str()); + continuous_mem_start = memory_offset_[0].mem_offset_; + bool continuous_input_alloc = false; + (void)ge::AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_CONTINUOUS_INPUT_ALLOC, continuous_input_alloc); for (auto &in_data_anchor : node->GetAllInDataAnchors()) { auto peer_out_data_anchor = in_data_anchor->GetPeerOutAnchor(); + GE_IF_BOOL_EXEC(peer_out_data_anchor == nullptr, continue); - if (peer_out_data_anchor == nullptr) { - continue; - } auto peer_op_desc = peer_out_data_anchor->GetOwnerNode()->GetOpDesc(); GE_IF_BOOL_EXEC(peer_op_desc == nullptr, continue); bool is_peer_output_continuous = false; @@ -267,28 +342,48 @@ Status GraphMemoryAssigner::AssignContinuousInputMemory(const ge::NodePtr &node) // continuous output of the previous node is the same, we can support it. If size != 1, there may be // conflict between the two, we can not support it. 
auto peer_output_size = peer_op_desc->GetOutputsSize(); - if (is_peer_output_continuous && (peer_output_size != 1)) { - GELOGE(PARAM_INVALID, - "Current node %s requires continuous input, while the previous node %s requires " - "continuous output. There may be conflict between the two. This node is not supported now.", - node->GetOpDesc()->GetName().c_str(), peer_op_desc->GetName().c_str()); - return PARAM_INVALID; - } + GE_IF_BOOL_EXEC(is_peer_output_continuous && (peer_output_size != 1), + GELOGE(PARAM_INVALID, + "Current node %s requires continuous input, while the previous node %s requires " + "continuous output. There may be conflict between the two. This node is not supported now.", + node->GetOpDesc()->GetName().c_str(), peer_op_desc->GetName().c_str()); + return PARAM_INVALID;); bool is_peer_reference = false; // If GetBool fail, is_peer_reference is false. (void)AttrUtils::GetBool(peer_op_desc, ATTR_NAME_REFERENCE, is_peer_reference); - - if (is_peer_reference) { - GELOGE(PARAM_INVALID, - "Current node %s requires continuous input, while the previous node %s requires " - "reference. There may be conflict between the two. This node is not supported now.", - node->GetOpDesc()->GetName().c_str(), peer_op_desc->GetName().c_str()); - return PARAM_INVALID; - } + GE_IF_BOOL_EXEC(is_peer_reference, + GELOGE(PARAM_INVALID, + "Current node %s requires continuous input, while the previous node %s requires " + "reference. There may be conflict between the two. 
This node is not supported now.", + node->GetOpDesc()->GetName().c_str(), peer_op_desc->GetName().c_str()); + return PARAM_INVALID;); vector output_list = peer_op_desc->GetOutputOffset(); if (peer_out_data_anchor->GetIdx() < static_cast(output_list.size())) { + if (continuous_input_alloc) { + if (in_data_anchor->GetIdx() == 0) { + continuous_mem_start = output_list.at(peer_out_data_anchor->GetIdx()); + } + // can not use else if, incase only one input + if (in_data_anchor->GetIdx() == static_cast(node->GetAllInDataAnchors().size()) - 1) { + int64_t tensor_desc_size = 0; + Status ret = ge::TensorUtils::GetSize(*(peer_op_desc->GetOutputDescPtr(peer_out_data_anchor->GetIdx())), + tensor_desc_size); + GE_IF_BOOL_EXEC(ret != ge::SUCCESS, GELOGE(FAILED, "GetSize failed."); return FAILED;); + + tensor_desc_size = (tensor_desc_size + MEM_ALIGN_SIZE - 1) / MEM_ALIGN_SIZE * MEM_ALIGN_SIZE; + continuous_mem_size = + output_list.at(peer_out_data_anchor->GetIdx()) - continuous_mem_start + tensor_desc_size + MEM_ALIGN_SIZE; + } + GELOGI( + "[IMAS]Check Continuous input : Set %s name[%s] output[%d] offset to [%zu] stream_id[%ld] size[%zu] " + "real_size[%u].", + node->GetOwnerComputeGraph()->GetName().c_str(), peer_op_desc->GetName().c_str(), + peer_out_data_anchor->GetIdx(), output_list.at(peer_out_data_anchor->GetIdx()), peer_op_desc->GetStreamId(), + 0, 0); + continue; + } output_list.at(peer_out_data_anchor->GetIdx()) = memory_offset_[0].mem_offset_; } else { GELOGE(FAILED, "index : %d is out of range.", peer_out_data_anchor->GetIdx()); @@ -296,25 +391,24 @@ Status GraphMemoryAssigner::AssignContinuousInputMemory(const ge::NodePtr &node) } peer_op_desc->SetOutputOffset(output_list); size_t pre_mem_offset = memory_offset_[0].mem_offset_; - std::vector offsets_for_l1_fusion = {}; + std::vector offsets_for_fusion = {}; bool has_offset_attr = - AttrUtils::GetListInt(peer_op_desc, ATTR_NAME_OUTPUT_OFFSET_FOR_L1_FUSION, offsets_for_l1_fusion); + AttrUtils::GetListInt(peer_op_desc, 
ATTR_NAME_OUTPUT_OFFSET_FOR_BUFFER_FUSION, offsets_for_fusion); int64_t tensor_desc_size = 0; if (has_offset_attr) { - if (peer_out_data_anchor->GetIdx() < static_cast(offsets_for_l1_fusion.size())) { - auto offset_for_l1_fusion = offsets_for_l1_fusion[peer_out_data_anchor->GetIdx()]; - memory_offset_[0].mem_offset_ += offset_for_l1_fusion; + if (peer_out_data_anchor->GetIdx() < static_cast(offsets_for_fusion.size())) { + auto offset_for_fusion = offsets_for_fusion[peer_out_data_anchor->GetIdx()]; + memory_offset_[0].mem_offset_ += offset_for_fusion; } else { - GELOGE(FAILED, "l1 fusion: peer node %s index : %d is out of range.", peer_op_desc->GetName().c_str(), + GELOGE(FAILED, "fusion: peer node %s index : %d is out of range.", peer_op_desc->GetName().c_str(), peer_out_data_anchor->GetIdx()); return FAILED; } } else { - if (TensorUtils::GetSize(*(peer_op_desc->GetOutputDescPtr(peer_out_data_anchor->GetIdx())), tensor_desc_size) != - SUCCESS) { - GELOGE(FAILED, "GetSize failed."); - return FAILED; - } + Status ret = + TensorUtils::GetSize(*(peer_op_desc->GetOutputDescPtr(peer_out_data_anchor->GetIdx())), tensor_desc_size); + GE_IF_BOOL_EXEC(ret != ge::SUCCESS, GELOGE(FAILED, "GetSize failed."); return FAILED;); + memory_offset_[0].mem_offset_ += tensor_desc_size; } @@ -331,6 +425,10 @@ Status GraphMemoryAssigner::AssignContinuousInputMemory(const ge::NodePtr &node) pre_mem_offset, peer_op_desc->GetStreamId(), (memory_offset_[0].mem_offset_ - pre_mem_offset), tensor_desc_size); } + memory_offset_[0].mem_offset_ += MEM_ALIGN_SIZE; + if (!continuous_input_alloc) { + continuous_mem_size = memory_offset_[0].mem_offset_ - continuous_mem_start; + } return SUCCESS; } @@ -371,7 +469,70 @@ Status GraphMemoryAssigner::AssignContinuousOutputMemory(const ge::NodePtr &node return ge::SUCCESS; } +Status GraphMemoryAssigner::ReAssignVirtualInputNodeMemory(NodePtr node, size_t &mem_offset_reuse) { + OpDescPtr op_desc = node->GetOpDesc(); + vector output_list = 
op_desc->GetOutputOffset(); + if (output_list.empty()) { + GELOGE(FAILED, "Outputoffset is empty node name:%s", node->GetName().c_str()); + return FAILED; + } + output_list.at(0) = mem_offset_reuse; + op_desc->SetOutputOffset(output_list); + GELOGI("Set virtual input node %s output offset to %zu.", op_desc->GetName().c_str(), mem_offset_reuse); + + int64_t attr_dim_index; + bool get_attr_dim_flag = ge::AttrUtils::GetInt(op_desc, ATTR_NAME_REUSE_INPUT_ON_DIM_INDEX, attr_dim_index); + if (!get_attr_dim_flag) { + GELOGE(FAILED, "Get attr _reuse_input_on_dim_index failed."); + return FAILED; + } + + size_t extra_memory_size = 0; + for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { + auto peer_out_data_anchor = in_data_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(peer_out_data_anchor); + auto peer_op_desc = peer_out_data_anchor->GetOwnerNode()->GetOpDesc(); + GE_CHECK_NOTNULL(peer_op_desc); + vector output_offsets = peer_op_desc->GetOutputOffset(); + if (peer_out_data_anchor->GetIdx() >= static_cast(output_offsets.size())) { + GELOGE(ge::FAILED, "Index : %d is out of range.", peer_out_data_anchor->GetIdx()); + return ge::FAILED; + } + output_offsets.at(peer_out_data_anchor->GetIdx()) = mem_offset_reuse; + peer_op_desc->SetOutputOffset(output_offsets); + size_t pre_mem_offset = mem_offset_reuse; + + // Calculate tensor real size of each piece of data and out size of complete data + ge::ConstGeTensorDescPtr output_desc = peer_op_desc->GetOutputDescPtr(peer_out_data_anchor->GetIdx()); + GE_CHECK_NOTNULL(output_desc); + int64_t output_mem_size; + int64_t batch_dim_num = 1; + int64_t out_size; + if (CalculateTensorRealSizeAndOutSize(output_desc, attr_dim_index, output_mem_size, batch_dim_num, out_size) != + SUCCESS) { + GELOGE(FAILED, "CalculateTensorRealSizeAndOutSize failed for node %s output [%d].", + peer_op_desc->GetName().c_str(), peer_out_data_anchor->GetIdx()); + return FAILED; + } + + mem_offset_reuse += output_mem_size; + extra_memory_size = 
extra_memory_size + out_size - output_mem_size; + + GELOGI( + "[IMAS]Virtual node optimize: set %s name[%s] output[%d] offset to [%zu] stream_id[%ld] size[%ld] " + "real_size[%ld].", + node->GetOwnerComputeGraph()->GetName().c_str(), peer_op_desc->GetName().c_str(), peer_out_data_anchor->GetIdx(), + pre_mem_offset, peer_op_desc->GetStreamId(), out_size, output_mem_size); + } + mem_offset_reuse += extra_memory_size; + size_t after_mem_offset = mem_offset_reuse; + GELOGI("After reassign virtual input node[name: %s, type: %s] memory, memory offset = %zu.", + op_desc->GetName().c_str(), op_desc->GetType().c_str(), after_mem_offset); + return SUCCESS; +} + Status GraphMemoryAssigner::ReAssignReuseAndNoPaddingContinuousInputMemory() { + map> mem_reuse_virtual_input_nodes_map; for (const auto &n : compute_graph_->GetAllNodes()) { OpDescPtr op_desc = n->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); @@ -383,68 +544,128 @@ Status GraphMemoryAssigner::ReAssignReuseAndNoPaddingContinuousInputMemory() { GE_IF_BOOL_EXEC(!get_reuse_flag, continue); if (attr_reuse && attr_continuous) { - vector output_list = op_desc->GetOutputOffset(); - if (output_list.empty()) { - GELOGE(FAILED, "Outputoffset is empty node name:%s", n->GetName().c_str()); - return FAILED; - } - output_list.at(0) = memory_offset_[0].mem_offset_; - op_desc->SetOutputOffset(output_list); - GELOGI("Set node %s output offset to %zu.", op_desc->GetName().c_str(), memory_offset_[0].mem_offset_); - - int64_t attr_dim_index; - bool get_attr_dim_flag = ge::AttrUtils::GetInt(op_desc, ATTR_NAME_REUSE_INPUT_ON_DIM_INDEX, attr_dim_index); - if (!get_attr_dim_flag) { - GELOGE(FAILED, "Get attr _reuse_input_on_dim_index failed."); + if (op_desc->GetOutputsSize() != kVirtualInputNodeOutputSize) { + // When current virtual node has several outputs, can't directly determine which input is the tensor for reuse. 
+ GELOGE(FAILED, "Only one output is supported, current virtual node %s has %zu inputs.", n->GetName().c_str(), + op_desc->GetOutputsSize()); return FAILED; } - size_t extra_memory_size = 0; - for (const auto &in_data_anchor : n->GetAllInDataAnchors()) { - auto peer_out_data_anchor = in_data_anchor->GetPeerOutAnchor(); - GE_CHECK_NOTNULL(peer_out_data_anchor); - auto peer_op_desc = peer_out_data_anchor->GetOwnerNode()->GetOpDesc(); - GE_CHECK_NOTNULL(peer_op_desc); - vector output_offsets = peer_op_desc->GetOutputOffset(); - if (peer_out_data_anchor->GetIdx() >= static_cast(output_offsets.size())) { - GELOGE(ge::FAILED, "Index : %d is out of range.", peer_out_data_anchor->GetIdx()); - return ge::FAILED; + GELOGD("Start to reassign memory for virtual input node, memory offset = %zu.", memory_offset_[0].mem_offset_); + string batch_label_string; + // Not all ops have ATTR_NAME_BATCH_LABEL, no need to check return value, only check out parameter + (void)ge::AttrUtils::GetStr(op_desc, ATTR_NAME_BATCH_LABEL, batch_label_string); + if (batch_label_string.empty()) { + size_t node_mem_offset = memory_offset_[0].mem_offset_; + // No ATTR_NAME_BATCH_LABEL, no need to reuse memory. 
+ Status status = ReAssignVirtualInputNodeMemory(n, node_mem_offset); + if (status != SUCCESS) { + GELOGE(FAILED, "Reassign memory of virtual input node failed, node name: %s.", n->GetName().c_str()); + return FAILED; } - output_offsets.at(peer_out_data_anchor->GetIdx()) = memory_offset_[0].mem_offset_; - peer_op_desc->SetOutputOffset(output_offsets); - size_t pre_mem_offset = memory_offset_[0].mem_offset_; - - // calculate tensor real size of each piece of data and out size of complete data - ge::ConstGeTensorDescPtr output_desc = peer_op_desc->GetOutputDescPtr(peer_out_data_anchor->GetIdx()); - GE_CHECK_NOTNULL(output_desc); - int64_t output_mem_size; - int64_t batch_dim_num = 1; - int64_t out_size; - if (CalculateTensorRealSizeAndOutSize(output_desc, attr_dim_index, output_mem_size, batch_dim_num, out_size) != - SUCCESS) { - GELOGE(FAILED, "CalculateTensorRealSizeAndOutSize failed for node %s output [%d].", - peer_op_desc->GetName().c_str(), peer_out_data_anchor->GetIdx()); + + memory_offset_[0].mem_offset_ = node_mem_offset; + AlignMemOffset(MEM_ALIGN_SIZE); + GELOGD("After reassign memory for virtual input node, align memory = %zu.", memory_offset_[0].mem_offset_); + } else { + // Has ATTR_NAME_BATCH_LABEL, for dynamic multi-batch node, need to reuse memory. 
+ string current_node_full_name = op_desc->GetName(); + size_t pos = current_node_full_name.find(kMbatchNodeNameFlag); + if (pos == string::npos) { + GELOGE(FAILED, "Cannot find key string [%s] of multi-batch in name of virtual input node, node name: %s.", + kMbatchNodeNameFlag, n->GetName().c_str()); return FAILED; } + string fixed_name = current_node_full_name.substr(0, pos); + vector parallel_virtual_input_nodes; + if (mem_reuse_virtual_input_nodes_map.count(fixed_name) != 0) { + parallel_virtual_input_nodes = mem_reuse_virtual_input_nodes_map[fixed_name]; + } + parallel_virtual_input_nodes.emplace_back(n); + mem_reuse_virtual_input_nodes_map[fixed_name] = parallel_virtual_input_nodes; + } + } + } + + int32_t mem_reuse_model = 0; + if (ReAssignVirtualNodesMemory(mem_reuse_virtual_input_nodes_map, mem_reuse_model) != SUCCESS) { + GELOGE(FAILED, "Reassign memory of virtual input nodes failed."); + return FAILED; + } + return SUCCESS; +} - memory_offset_[0].mem_offset_ += output_mem_size; - extra_memory_size = extra_memory_size + out_size - output_mem_size; +Status GraphMemoryAssigner::ReAssignVirtualOutputNodeMemory(NodePtr node, size_t &mem_offset_reuse) { + OpDescPtr op_desc = node->GetOpDesc(); + + // 1. 
set memory of to be reused input tensor + auto in_data_anchor_list = node->GetAllInDataAnchors(); + auto peer_out_data_anchor = in_data_anchor_list.at(0)->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(peer_out_data_anchor); + auto peer_op_desc = peer_out_data_anchor->GetOwnerNode()->GetOpDesc(); + GE_CHECK_NOTNULL(peer_op_desc); + vector in_node_output_offsets = peer_op_desc->GetOutputOffset(); + if (peer_out_data_anchor->GetIdx() >= static_cast(in_node_output_offsets.size())) { + GELOGE(FAILED, "Index : %d is out of range.", peer_out_data_anchor->GetIdx()); + return FAILED; + } + in_node_output_offsets.at(peer_out_data_anchor->GetIdx()) = mem_offset_reuse; + peer_op_desc->SetOutputOffset(in_node_output_offsets); + GELOGI("Set virtual output node %s input data offset to %zu.", op_desc->GetName().c_str(), mem_offset_reuse); - GELOGI( - "[IMAS]Virtual node optimize : set %s name[%s] output[%d] offset to [%zu] stream_id[%ld] size[%ld] " - "real_size[%ld].", - n->GetOwnerComputeGraph()->GetName().c_str(), peer_op_desc->GetName().c_str(), peer_out_data_anchor->GetIdx(), - pre_mem_offset, peer_op_desc->GetStreamId(), out_size, output_mem_size); - } - memory_offset_[0].mem_offset_ += extra_memory_size; - GELOGI("After reassign virtual input node[name:%s, type:%s] memory, memory offset = %zu.", - op_desc->GetName().c_str(), op_desc->GetType().c_str(), memory_offset_[0].mem_offset_); + // 2. 
set memory of output tensor + vector output_list = op_desc->GetOutputOffset(); + if (output_list.empty()) { + GELOGE(FAILED, "Outputoffset is empty, node name: %s", node->GetName().c_str()); + return FAILED; + } + if (op_desc->GetOutputsSize() > output_list.size()) { + GELOGE(FAILED, "The size %zu of op_desc is more than output_list's size %zu.", op_desc->GetOutputsSize(), + output_list.size()); + return FAILED; + } + int64_t attr_dim_index; + bool get_attr_dim_flag = ge::AttrUtils::GetInt(op_desc, ATTR_NAME_REUSE_INPUT_ON_DIM_INDEX, attr_dim_index); + if (!get_attr_dim_flag) { + GELOGE(FAILED, "Get attr _reuse_input_on_dim_index failed."); + return FAILED; + } + + size_t extra_memory_size = 0; + for (auto &out_data_anchor : node->GetAllOutDataAnchors()) { + output_list[out_data_anchor->GetIdx()] = mem_offset_reuse; + size_t pre_mem_offset = mem_offset_reuse; + + // calculate tensor real size of each piece of data and out size of complete data + ge::ConstGeTensorDescPtr output_desc = op_desc->GetOutputDescPtr(out_data_anchor->GetIdx()); + GE_CHECK_NOTNULL(output_desc); + int64_t output_mem_size; + int64_t batch_dim_num = 1; + int64_t out_size; + if (CalculateTensorRealSizeAndOutSize(output_desc, attr_dim_index, output_mem_size, batch_dim_num, out_size) != + SUCCESS) { + GELOGE(FAILED, "CalculateTensorRealSizeAndOutSize failed for node %s output [%d].", op_desc->GetName().c_str(), + out_data_anchor->GetIdx()); + return FAILED; } + + mem_offset_reuse += output_mem_size; + extra_memory_size = extra_memory_size + out_size - output_mem_size; + + GELOGI("[IMAS]Virtual node optimize: set %s name[%s] output[%d] offset to [%zu], size[%ld], real_size[%ld].", + node->GetOwnerComputeGraph()->GetName().c_str(), op_desc->GetName().c_str(), out_data_anchor->GetIdx(), + pre_mem_offset, out_size, output_mem_size); } + op_desc->SetOutputOffset(output_list); + mem_offset_reuse += extra_memory_size; + size_t after_mem_offset = mem_offset_reuse; + GELOGI("After reassign virtual output 
node[name: %s, type: %s] memory, memory offset = %zu.", + op_desc->GetName().c_str(), op_desc->GetType().c_str(), after_mem_offset); return SUCCESS; } Status GraphMemoryAssigner::ReAssignReuseAndNoPaddingContinuousOutputMemory() { + map> mem_reuse_virtual_output_nodes_map; for (const auto &n : compute_graph_->GetAllNodes()) { OpDescPtr op_desc = n->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); @@ -457,151 +678,129 @@ Status GraphMemoryAssigner::ReAssignReuseAndNoPaddingContinuousOutputMemory() { if (attr_reuse && attr_continuous) { auto in_data_anchor_list = n->GetAllInDataAnchors(); - if (in_data_anchor_list.size() != 1) { - // When current node has several inputs, can't directly determine which input is the tensor for reuse. - GELOGE(FAILED, "Only one input is supported, current node %s has %zu inputs.", n->GetName().c_str(), + if (in_data_anchor_list.size() != kVirtualOutputNodeInputSize) { + // When current virtual node has several inputs, can't directly determine which input is the tensor for reuse. + GELOGE(FAILED, "Only one input is supported, current virtual node %s has %zu inputs.", n->GetName().c_str(), in_data_anchor_list.size()); return FAILED; } - // 1. set memory of to be reused input tensor - auto peer_out_data_anchor = in_data_anchor_list.at(0)->GetPeerOutAnchor(); - GE_CHECK_NOTNULL(peer_out_data_anchor); - auto peer_op_desc = peer_out_data_anchor->GetOwnerNode()->GetOpDesc(); - GE_CHECK_NOTNULL(peer_op_desc); - vector in_node_output_offsets = peer_op_desc->GetOutputOffset(); - if (peer_out_data_anchor->GetIdx() >= static_cast(in_node_output_offsets.size())) { - GELOGE(FAILED, "Index : %d is out of range.", peer_out_data_anchor->GetIdx()); - return FAILED; - } - in_node_output_offsets.at(peer_out_data_anchor->GetIdx()) = memory_offset_[0].mem_offset_; - peer_op_desc->SetOutputOffset(in_node_output_offsets); - GELOGI("Set node %s input data offset to %zu.", op_desc->GetName().c_str(), memory_offset_[0].mem_offset_); - - // 2. 
set memory of output tensor - vector output_list = op_desc->GetOutputOffset(); - if (output_list.empty()) { - GELOGE(FAILED, "Outputoffset is empty, node name: %s", n->GetName().c_str()); - return FAILED; - } - if (op_desc->GetOutputsSize() > output_list.size()) { - GELOGE(FAILED, "The size %zu of op_desc is more than output_list's size %zu.", op_desc->GetOutputsSize(), - output_list.size()); - return FAILED; - } - int64_t attr_dim_index; - bool get_attr_dim_flag = ge::AttrUtils::GetInt(op_desc, ATTR_NAME_REUSE_INPUT_ON_DIM_INDEX, attr_dim_index); - if (!get_attr_dim_flag) { - GELOGE(FAILED, "Get attr _reuse_input_on_dim_index failed."); - return FAILED; - } - - size_t extra_memory_size = 0; - for (auto &out_data_anchor : n->GetAllOutDataAnchors()) { - output_list[out_data_anchor->GetIdx()] = memory_offset_[0].mem_offset_; - size_t pre_mem_offset = memory_offset_[0].mem_offset_; - - // calculate tensor real size of each piece of data and out size of complete data - ge::ConstGeTensorDescPtr output_desc = op_desc->GetOutputDescPtr(out_data_anchor->GetIdx()); - GE_CHECK_NOTNULL(output_desc); - int64_t output_mem_size; - int64_t batch_dim_num = 1; - int64_t out_size; - if (CalculateTensorRealSizeAndOutSize(output_desc, attr_dim_index, output_mem_size, batch_dim_num, out_size) != - SUCCESS) { - GELOGE(FAILED, "CalculateTensorRealSizeAndOutSize failed for node %s output [%d].", - op_desc->GetName().c_str(), out_data_anchor->GetIdx()); + GELOGD("Start to reassign memory for virtual output node, memory offset = %zu.", memory_offset_[0].mem_offset_); + string batch_label_string; + // Not all ops have ATTR_NAME_BATCH_LABEL, no need to check return value, only check out parameter + (void)ge::AttrUtils::GetStr(op_desc, ATTR_NAME_BATCH_LABEL, batch_label_string); + if (batch_label_string.empty()) { + size_t node_mem_offset = memory_offset_[0].mem_offset_; + // No ATTR_NAME_BATCH_LABEL, no need to reuse memory. 
+ Status status = ReAssignVirtualOutputNodeMemory(n, node_mem_offset); + if (status != SUCCESS) { + GELOGE(FAILED, "Reassign memory of virtual output node failed, node name: %s.", n->GetName().c_str()); return FAILED; } - - memory_offset_[0].mem_offset_ += output_mem_size; - extra_memory_size = extra_memory_size + out_size - output_mem_size; - - GELOGI("[IMAS]Virtual node optimize : set %s name[%s] output[%d] offset to [%zu], size[%ld], real_size[%ld].", - n->GetOwnerComputeGraph()->GetName().c_str(), op_desc->GetName().c_str(), out_data_anchor->GetIdx(), - pre_mem_offset, out_size, output_mem_size); + memory_offset_[0].mem_offset_ = node_mem_offset; + AlignMemOffset(MEM_ALIGN_SIZE); + GELOGD("After reassign memory for virtual output node, align memory = %zu.", memory_offset_[0].mem_offset_); + } else { + // Has ATTR_NAME_BATCH_LABEL, for dynamic multi-batch node, need to reuse memory. + string current_node_full_name = op_desc->GetName(); + size_t pos = current_node_full_name.find(kMbatchNodeNameFlag); + if (pos == string::npos) { + GELOGE(FAILED, "Cannot find key string [%s] of multi-batch in name of virtual output node, node name: %s.", + kMbatchNodeNameFlag, n->GetName().c_str()); + return FAILED; + } + string fixed_name = current_node_full_name.substr(0, pos); + vector parallel_virtual_output_nodes; + if (mem_reuse_virtual_output_nodes_map.count(fixed_name) != 0) { + parallel_virtual_output_nodes = mem_reuse_virtual_output_nodes_map[fixed_name]; + } + parallel_virtual_output_nodes.emplace_back(n); + mem_reuse_virtual_output_nodes_map[fixed_name] = parallel_virtual_output_nodes; } - op_desc->SetOutputOffset(output_list); - memory_offset_[0].mem_offset_ += extra_memory_size; - GELOGI("After reassign virtual output node[name:%s, type:%s] memory, memory offset = %zu.", - op_desc->GetName().c_str(), op_desc->GetType().c_str(), memory_offset_[0].mem_offset_); } } + + int32_t mem_reuse_model = 1; + if (ReAssignVirtualNodesMemory(mem_reuse_virtual_output_nodes_map, 
mem_reuse_model) != SUCCESS) { + GELOGE(FAILED, "Reassign memory of virtual output nodes failed."); + return FAILED; + } return SUCCESS; } -Status GraphMemoryAssigner::ReAssignMergeMemory() { - for (const ge::NodePtr &n : compute_graph_->GetDirectNode()) { - GE_IF_BOOL_EXEC(n->GetOpDesc() == nullptr, continue); - string node_type; - GE_CHK_STATUS_RET(GetOriginalType(n, node_type), "Get node type fail."); - if (node_type != STREAMMERGE) { - continue; - } - - vector> input_node_list; - for (const auto &in_anchor : n->GetAllInDataAnchors()) { - ge::OutDataAnchorPtr out_anchor = in_anchor->GetPeerOutAnchor(); - if (out_anchor == nullptr) { - std::string in_name; - GE_IF_BOOL_EXEC(ge::AttrUtils::GetStr(n->GetOpDesc(), ATTR_NAME_NEXT_ITERATION, in_name) && !in_name.empty(), { - ge::NodePtr in_node = compute_graph_->FindNode(in_name); - GE_CHECK_NOTNULL(in_node); - input_node_list.emplace_back(std::make_pair(0, in_node)); - }); - continue; - } - ge::NodePtr src_node = out_anchor->GetOwnerNode(); - input_node_list.emplace_back(std::make_pair(out_anchor->GetIdx(), src_node)); - } - - int64_t data_output_offset = -1; - int64_t max_output_size = -1; - for (auto &iter : input_node_list) { - int index = iter.first; - NodePtr src_node = iter.second; - GE_CHECK_NOTNULL(src_node->GetOpDesc()); - int64_t tmp_output_size = src_node->GetOpDesc()->GetOutputDesc(index).GetShape().GetShapeSize(); - if ((data_output_offset == -1) || (tmp_output_size > max_output_size)) { - vector output_list = src_node->GetOpDesc()->GetOutputOffset(); - int output_size = static_cast(output_list.size()); - if (index >= output_size) { - GELOGE(INTERNAL_ERROR, "out_anchor[%d] >= output_list[%d]", index, output_size); - return INTERNAL_ERROR; +Status GraphMemoryAssigner::ReAssignVirtualNodesMemory(map> &mem_reuse_nodes_map, + int32_t mem_reuse_model) { + // Find max batch label value + string max_batch_label; + if (GetMaxBatchLabel(mem_reuse_nodes_map, mem_reuse_model, max_batch_label) != SUCCESS) { + 
GELOGE(FAILED, "Get max batch label failed."); + return FAILED; + } + GELOGI("The batch label of max batch virtual nodes is %s.", max_batch_label.c_str()); + + // Assign memory of max batch nodes that have the same batch label. + GELOGD("Start to reassign memory for max batch virtual nodes, memory offset = %zu.", memory_offset_[0].mem_offset_); + vector nodes_mem_offset_list; + for (auto &i_map : mem_reuse_nodes_map) { + size_t max_batch_node_mem_offset = memory_offset_[0].mem_offset_; + nodes_mem_offset_list.emplace_back(max_batch_node_mem_offset); + + vector virtual_nodes_list = i_map.second; + for (auto &i_node : virtual_nodes_list) { + // Op_desc is not nullptr, it has been checked. + OpDescPtr op_desc = i_node->GetOpDesc(); + string batch_label_string; + // All ops must have ATTR_NAME_BATCH_LABEL, no need to check return value. + (void)ge::AttrUtils::GetStr(op_desc, ATTR_NAME_BATCH_LABEL, batch_label_string); + if (batch_label_string == max_batch_label) { + Status status = SUCCESS; + if (mem_reuse_model == kVirtualInputNodeMemoryReuse) { + status = ReAssignVirtualInputNodeMemory(i_node, max_batch_node_mem_offset); + } else if (mem_reuse_model == kVirtualOutputNodeMemoryReuse) { + status = ReAssignVirtualOutputNodeMemory(i_node, max_batch_node_mem_offset); + } else { + GELOGE(FAILED, "Invalid parameter memory reuse model, which is: %d.", mem_reuse_model); + return FAILED; } - data_output_offset = output_list[index]; - max_output_size = tmp_output_size; + if (status != SUCCESS) { + GELOGE(FAILED, "Reassign memory of virtual node failed, node name: %s.", i_node->GetName().c_str()); + return FAILED; + } + memory_offset_[0].mem_offset_ = max_batch_node_mem_offset; + AlignMemOffset(MEM_ALIGN_SIZE); + GELOGD("After reassign memory for virtual node, align memory = %zu.", memory_offset_[0].mem_offset_); + // Only assign memory of max batch nodes. 
+ break; } - GELOGD("merge=%s, input=%s, size=%ld, offset=%ld, max_size=%ld", n->GetName().c_str(), - src_node->GetName().c_str(), tmp_output_size, data_output_offset, max_output_size); } + } - vector input_list; - for (auto &iter : input_node_list) { - int index = iter.first; - NodePtr src_node = iter.second; - GE_CHECK_NOTNULL(src_node->GetOpDesc()); - vector output_list = src_node->GetOpDesc()->GetOutputOffset(); - int output_size = static_cast(output_list.size()); - if (index >= output_size) { - GELOGE(INTERNAL_ERROR, "out_anchor[%d] >= output_list[%d]", index, output_size); - return INTERNAL_ERROR; + // Assign memory of remaining nodes that have the same fixed_name. + GELOGD("Start to reassign memory for remaining batch virtual nodes, memory offset = %zu.", + memory_offset_[0].mem_offset_); + size_t memory_reuse_index = 0; + for (auto &i_map : mem_reuse_nodes_map) { + vector virtual_nodes_list = i_map.second; + for (auto &i_node : virtual_nodes_list) { + size_t remaining_batch_node_mem_offset = nodes_mem_offset_list[memory_reuse_index]; + Status status = SUCCESS; + if (mem_reuse_model == kVirtualInputNodeMemoryReuse) { + status = ReAssignVirtualInputNodeMemory(i_node, remaining_batch_node_mem_offset); + } else if (mem_reuse_model == kVirtualOutputNodeMemoryReuse) { + status = ReAssignVirtualOutputNodeMemory(i_node, remaining_batch_node_mem_offset); + } else { + GELOGE(FAILED, "Invalid parameter memory reuse model, which is: %d.", mem_reuse_model); + return FAILED; } - output_list[index] = data_output_offset; - src_node->GetOpDesc()->SetOutputOffset(output_list); - GELOGI( - "[IMAS]ReAssignMergeMemory : Set %s name[%s] output[%d] offset to [%ld] stream_id[%ld] size[%ld] " - "real_size[%ld].", - n->GetOwnerComputeGraph()->GetName().c_str(), src_node->GetOpDesc()->GetName().c_str(), index, - data_output_offset, src_node->GetOpDesc()->GetStreamId(), max_output_size, max_output_size); - input_list.emplace_back(data_output_offset); + if (status != SUCCESS) { + 
GELOGE(FAILED, "Reassign memory of virtual node failed, node name: %s.", i_node->GetName().c_str()); + return FAILED; + } } - - n->GetOpDesc()->SetInputOffset(input_list); + memory_reuse_index++; } - GELOGI("After reassign merge memory, memoffset = %zu.", memory_offset_[0].mem_offset_); return SUCCESS; } @@ -614,7 +813,7 @@ Status GraphMemoryAssigner::ReAssignAtomicMemory(bool is_loop_graph) { int64_t atomic_mem_start = static_cast(memory_offset_[0].mem_offset_); GELOGI("Begin to reAssign atomic memory, atomic initial address mem_offset = %zu!", memory_offset_[0].mem_offset_); - for (auto &node : compute_graph_->GetDirectNode()) { + for (auto &node : compute_graph_->GetAllNodes()) { auto node_op_desc = node->GetOpDesc(); if (node_op_desc == nullptr) { continue; @@ -687,161 +886,63 @@ Status GraphMemoryAssigner::ReAssignAtomicMemory(bool is_loop_graph) { return SUCCESS; } -Status GraphMemoryAssigner::AssignSubgraphInputsMemory() { - GE_CHECK_NOTNULL(compute_graph_); - for (ComputeGraphPtr &graph : compute_graph_->GetAllSubgraphs()) { - GE_CHECK_NOTNULL(graph); - const NodePtr &parent_node = graph->GetParentNode(); - GE_CHECK_NOTNULL(parent_node); - const OpDescPtr &parent_desc = parent_node->GetOpDesc(); - GE_CHECK_NOTNULL(parent_desc); - - const vector input_offsets = parent_desc->GetInputOffset(); - GELOGI("SubGraph: %s graph input size: %u, parent input size: %zu, parent input offset: %zu.", - graph->GetName().c_str(), graph->GetInputSize(), parent_desc->GetInputsSize(), input_offsets.size()); - if (parent_desc->GetInputsSize() < graph->GetInputSize()) { - GELOGE(FAILED, "SubGraph: %s Input size: %u is grater than parent input size: %zu.", graph->GetName().c_str(), - graph->GetInputSize(), parent_desc->GetInputsSize()); - return FAILED; +Status GraphMemoryAssigner::AssignReferenceMemory() { + for (auto &node : compute_graph_->GetDirectNode()) { + // Get the reference type of the node, default is false + bool is_ref = false; + // If GetBool fail, is_ref is false. 
+ (void)ge::AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_REFERENCE, is_ref); + if (!is_ref) { + continue; } - for (NodePtr &node : graph->GetDirectNode()) { - GE_CHECK_NOTNULL(node); - GE_CHECK_NOTNULL(node->GetOpDesc()); - if (node->GetType() != DATA_TYPE) { - continue; - } - - // Find functional node input anchor. - uint32_t parent_index = 0; - if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { - GELOGE(FAILED, "Node: %s get attr %s failed", node->GetName().c_str(), ATTR_NAME_PARENT_NODE_INDEX.c_str()); - return FAILED; - } - - GELOGI("SubGraph: %s Parent input index: %u.", graph->GetName().c_str(), parent_index); - if (parent_index >= input_offsets.size()) { - GELOGE(FAILED, "SubGraph: %s Parent input size: %zu, parent index: %u.", graph->GetName().c_str(), parent_index, - input_offsets.size()); - return FAILED; - } - - // Find subgraph data input anchor. - OutDataAnchorPtr out_anchor = node->GetOutDataAnchor(kDataOutputIndex); - GE_CHECK_NOTNULL(out_anchor); + GELOGI("Current node %s needs to support the reference relationship between output and input.", + node->GetName().c_str()); - for (InDataAnchorPtr &peer_anchor : out_anchor->GetPeerInDataAnchors()) { - GE_CHECK_NOTNULL(peer_anchor); - const NodePtr &peer_node = peer_anchor->GetOwnerNode(); - GE_CHECK_NOTNULL(peer_node); + auto out_op_desc = node->GetOpDesc(); + GE_IF_BOOL_EXEC(out_op_desc == nullptr, GELOGE(ge::FAILED, "out_op_desc is null."); return ge::FAILED); + vector output_list = out_op_desc->GetOutputOffset(); - vector input_offset = peer_node->GetOpDesc()->GetInputOffset(); - if (peer_anchor->GetIdx() < 0 || input_offset.size() <= static_cast(peer_anchor->GetIdx())) { - GELOGE(FAILED, "SubGraph: %s Node: %s invalid anchor index: %d.", graph->GetName().c_str(), - peer_node->GetName().c_str(), peer_anchor->GetIdx()); - return FAILED; - } - - input_offset[peer_anchor->GetIdx()] = input_offsets[parent_index]; - 
peer_node->GetOpDesc()->SetInputOffset(input_offset); - } + if (out_op_desc->GetOutputsSize() > output_list.size()) { + GELOGE(ge::FAILED, "The size %zu of node output desc is more than output_list's size %zu.", + out_op_desc->GetOutputsSize(), output_list.size()); + return ge::FAILED; } - } - - return SUCCESS; -} -Status GraphMemoryAssigner::AssignSubgraphOutputsMemory() { - GE_CHECK_NOTNULL(compute_graph_); - for (ComputeGraphPtr &graph : compute_graph_->GetAllSubgraphs()) { - GE_CHECK_NOTNULL(graph); - const NodePtr &parent_node = graph->GetParentNode(); - GE_CHECK_NOTNULL(parent_node); - - const NodePtr &net_output_node = graph->FindNode(NODE_NAME_NET_OUTPUT); - GE_CHECK_NOTNULL(net_output_node); - const OpDescPtr &net_output_desc = net_output_node->GetOpDesc(); - GE_CHECK_NOTNULL(net_output_desc); - - const vector input_offsets = net_output_desc->GetInputOffset(); - for (size_t i = 0; i < input_offsets.size(); ++i) { - uint32_t parent_index = 0; - if (!AttrUtils::GetInt(net_output_desc->GetInputDesc(i), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { - GELOGW("SubGraph: %s input tensor %zu attr %s not found.", graph->GetName().c_str(), i, - ATTR_NAME_PARENT_NODE_INDEX.c_str()); - continue; - } - - const OutDataAnchorPtr &out_anchor = parent_node->GetOutDataAnchor(parent_index); - GE_CHECK_NOTNULL(out_anchor); - for (InDataAnchorPtr &peer_anchor : out_anchor->GetPeerInDataAnchors()) { - GE_CHECK_NOTNULL(peer_anchor); - const NodePtr &peer_node = peer_anchor->GetOwnerNode(); - GE_CHECK_NOTNULL(peer_node); - - vector input_offset = peer_node->GetOpDesc()->GetInputOffset(); - if (peer_anchor->GetIdx() < 0 || input_offset.size() <= static_cast(peer_anchor->GetIdx())) { - GELOGE(FAILED, "SubGraph: %s Node: %s invalid anchor index: %d.", graph->GetName().c_str(), - peer_node->GetName().c_str(), peer_anchor->GetIdx()); - return FAILED; - } + map input_name_index; + for (const auto &input_name : out_op_desc->GetAllInputNames()) { + int index = 
out_op_desc->GetInputIndexByName(input_name); + input_name_index.emplace(input_name, index); + } - input_offset[peer_anchor->GetIdx()] = input_offsets[i]; - peer_node->GetOpDesc()->SetInputOffset(input_offset); + for (auto &out_data_anchor : node->GetAllOutDataAnchors()) { + string out_data_anchor_name = out_op_desc->GetOutputNameByIndex(out_data_anchor->GetIdx()); + auto iter = input_name_index.find(out_data_anchor_name); + if (iter != input_name_index.end()) { + int index = iter->second; + GELOGI("Reference memory: input anchor index = %d, input anchor name = %s, output anchor name = %s.", index, + iter->first.c_str(), out_data_anchor_name.c_str()); + GE_CHECK_NOTNULL(node->GetInDataAnchor(index)); + auto peer_out_anchor = node->GetInDataAnchor(index)->GetPeerOutAnchor(); + GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); + int peer_out_anchor_index = peer_out_anchor->GetIdx(); + auto peer_out_node = peer_out_anchor->GetOwnerNode(); + auto peer_out_op_desc = peer_out_node->GetOpDesc(); + GE_CHECK_NOTNULL(peer_out_op_desc); + output_list[out_data_anchor->GetIdx()] = peer_out_op_desc->GetOutputOffset()[peer_out_anchor_index]; + GELOGI("Reference output : Set %s name[%s] output[%d] offset to [%ld] stream_id[%ld]", + node->GetOwnerComputeGraph()->GetName().c_str(), peer_out_op_desc->GetName().c_str(), + out_data_anchor->GetIdx(), output_list[out_data_anchor->GetIdx()], peer_out_op_desc->GetStreamId()); + } else { + GELOGI("Reference output : origin %s name[%s] output[%d] offset is [%ld] stream_id[%ld]", + node->GetOwnerComputeGraph()->GetName().c_str(), out_op_desc->GetName().c_str(), + out_data_anchor->GetIdx(), output_list[out_data_anchor->GetIdx()], out_op_desc->GetStreamId()); } } - } - - return SUCCESS; -} - -Status GraphMemoryAssigner::AssignReferenceMemory(const ge::NodePtr &node) { - GELOGI("Current node %s needs to support the reference relationship between output and input.", - node->GetName().c_str()); - - auto out_op_desc = node->GetOpDesc(); - 
GE_IF_BOOL_EXEC(out_op_desc == nullptr, GELOGE(ge::FAILED, "out_op_desc is null."); return ge::FAILED); - vector output_list = out_op_desc->GetOutputOffset(); - - if (out_op_desc->GetOutputsSize() > output_list.size()) { - GELOGE(ge::FAILED, "The size %zu of node output desc is more than output_list's size %zu.", - out_op_desc->GetOutputsSize(), output_list.size()); - return ge::FAILED; - } - map input_name_index; - for (const auto &input_name : out_op_desc->GetAllInputNames()) { - int index = out_op_desc->GetInputIndexByName(input_name); - input_name_index.emplace(input_name, index); - } - - for (auto &out_data_anchor : node->GetAllOutDataAnchors()) { - string out_data_anchor_name = out_op_desc->GetOutputNameByIndex(out_data_anchor->GetIdx()); - auto iter = input_name_index.find(out_data_anchor_name); - if (iter != input_name_index.end()) { - int index = iter->second; - GELOGI("Reference memory: input anchor index = %d, input anchor name = %s, output anchor name = %s.", index, - iter->first.c_str(), out_data_anchor_name.c_str()); - GE_CHECK_NOTNULL(node->GetInDataAnchor(index)); - auto peer_out_anchor = node->GetInDataAnchor(index)->GetPeerOutAnchor(); - GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); - int peer_out_anchor_index = peer_out_anchor->GetIdx(); - auto peer_out_node = peer_out_anchor->GetOwnerNode(); - auto peer_out_op_desc = peer_out_node->GetOpDesc(); - GE_CHECK_NOTNULL(peer_out_op_desc); - output_list[out_data_anchor->GetIdx()] = peer_out_op_desc->GetOutputOffset()[peer_out_anchor_index]; - GELOGI("Reference output : Set %s name[%s] output[%d] offset to [%ld] stream_id[%ld]", - node->GetOwnerComputeGraph()->GetName().c_str(), peer_out_op_desc->GetName().c_str(), - out_data_anchor->GetIdx(), output_list[out_data_anchor->GetIdx()], peer_out_op_desc->GetStreamId()); - } else { - GELOGI("Reference output : origin %s name[%s] output[%d] offset is [%ld] stream_id[%ld]", - node->GetOwnerComputeGraph()->GetName().c_str(), 
out_op_desc->GetName().c_str(), out_data_anchor->GetIdx(), - output_list[out_data_anchor->GetIdx()], out_op_desc->GetStreamId()); - } + out_op_desc->SetOutputOffset(output_list); } - out_op_desc->SetOutputOffset(output_list); - return ge::SUCCESS; } @@ -914,8 +1015,9 @@ Status GraphMemoryAssigner::AssignAtomicOutputMemory(const ge::NodePtr &node) { // If you have already assigned an atomic address, skip it, and you don't need to reassign it. if (is_assigned_mem) { GELOGI( - "[IMAS]Atomic output : we have assigned atomic memory as the input of next node in " - "ReAssignContinuousMemory function."); + "Node %s atomic output : we have assigned atomic memory as the input of next node in " + "ReAssignContinuousMemory function.", + op_desc->GetName().c_str()); continue; } @@ -1015,7 +1117,7 @@ Status GraphMemoryAssigner::AssignFusionAtomicWorkspaceMemory(const ge::OpDescPt } Status GraphMemoryAssigner::CheckOffset() { - for (const ge::NodePtr &node : compute_graph_->GetDirectNode()) { + for (const ge::NodePtr &node : compute_graph_->GetAllNodes()) { GE_CHECK_NOTNULL(node->GetOpDesc()); vector input_list = node->GetOpDesc()->GetInputOffset(); for (auto input : input_list) { @@ -1049,7 +1151,7 @@ ge::Status GraphMemoryAssigner::SetInputOffset() { } GEEVENT("[IMAS]AfterAssignMemory : %s memoffset[%zu]", compute_graph_->GetName().c_str(), memory_offset_[0].mem_offset_); - for (const ge::NodePtr &node : compute_graph_->GetDirectNode()) { + for (const ge::NodePtr &node : compute_graph_->GetAllNodes()) { if (UpdateOpInputOffset(node) != ge::SUCCESS) { GELOGE(ge::FAILED, "Update op input offset failed"); return ge::FAILED; @@ -1058,6 +1160,24 @@ ge::Status GraphMemoryAssigner::SetInputOffset() { return ge::SUCCESS; } +ge::Status GraphMemoryAssigner::UpdateConstArgsOffset(const NodePtr &node, vector &input_list) const { + uint32_t parent_index = 0; + if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { + return SUCCESS; + } + + // Subgraph Data 
Node, check for constant input. + std::string op_type; + NodePtr in_node = NodeUtils::GetParentInput(node); + if (!NodeUtils::GetConstOpType(in_node, op_type)) { + return SUCCESS; // not constant input. + } + + vector const_input_list = in_node->GetOpDesc()->GetOutputOffset(); + node->GetOpDesc()->SetOutputOffset(const_input_list); // Set Data output same as const output. + return SUCCESS; +} + ge::Status GraphMemoryAssigner::UpdateOpInputOffset(const NodePtr &node, vector &input_list) const { vector origin_input_list; vector memory_type; @@ -1084,17 +1204,17 @@ ge::Status GraphMemoryAssigner::UpdateOpInputOffset(const NodePtr &node, vector< auto mem_type_size = memory_type.size(); if ((input_size != mem_type_size) || (input_size != ori_input_offset_list_size)) { GELOGE(ge::FAILED, - "L1fusion: input_size[%zu] diff from memory_type_size[%zu]" + "fusion: node[%s] input_size[%zu] diff from memory_type_size[%zu]" " from ori_input_offset_list_size[%lu]", - input_size, mem_type_size, ori_input_offset_list_size); + tmp_op_desc->GetName().c_str(), input_size, mem_type_size, ori_input_offset_list_size); return ge::FAILED; } - // l1 keep orignal inputoffest + // not hbm keep orignal inputoffest // hbm inputoffset = original inputoffset + outputoffset - input_list.emplace_back(memory_type[input_index] != RT_MEMORY_HBM + input_list.emplace_back(memory_type[input_index] == RT_MEMORY_L1 ? 
origin_input_list[input_index] : origin_input_list[input_index] + output_list.at(peer_out_anchor->GetIdx())); - GELOGI("L1 fuison: node[%s] input[%d] is set from node[%s] out index[%d] offset[%ld]", + GELOGI("fuison: node[%s] input[%d] is set from node[%s] out index[%d] offset[%ld]", tmp_op_desc->GetName().c_str(), input_index, peer_out_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), peer_out_anchor->GetIdx(), input_list.back()); @@ -1110,6 +1230,7 @@ ge::Status GraphMemoryAssigner::UpdateOpInputOffset(const NodePtr &node, vector< } ge::Status GraphMemoryAssigner::UpdateOpInputOffset(const NodePtr &node) const { + GE_CHECK_NOTNULL(node->GetOpDesc()); vector input_list; if (node->GetType() == HCOMBROADCAST) { for (const auto &anchor : node->GetAllInDataAnchors()) { @@ -1140,13 +1261,20 @@ ge::Status GraphMemoryAssigner::UpdateOpInputOffset(const NodePtr &node) const { } } } + } else if (node->GetType() == DATA) { + if (UpdateConstArgsOffset(node, input_list) != SUCCESS) { + GELOGE(FAILED, "Update data: %s args offset failed.", node->GetName().c_str()); + return FAILED; + } } else { - GE_CHK_STATUS_EXEC(UpdateOpInputOffset(node, input_list), GELOGE(FAILED, "UpdateOpInputOffset fail."); - return ge::FAILED); + if (UpdateOpInputOffset(node, input_list) != SUCCESS) { + GELOGE(FAILED, "Update node: %s input offset failed.", node->GetName().c_str()); + return FAILED; + } } - GE_CHECK_NOTNULL(node->GetOpDesc()); + node->GetOpDesc()->SetInputOffset(input_list); - return ge::SUCCESS; + return SUCCESS; } Status GraphMemoryAssigner::SetLoopGraphAtomicAttr(const ge::NodePtr &node, int64_t atomic_mem_start) { @@ -1181,7 +1309,7 @@ Status GraphMemoryAssigner::SetLoopGraphAtomicAttr(const ge::NodePtr &node, int6 ge::Status GraphMemoryAssigner::SetAtomicCleanAttr(const NodePtr &n, int64_t atomic_mem_start, int64_t atomic_mem_size) { - for (ge::NodePtr &node : compute_graph_->GetDirectNode()) { + for (ge::NodePtr &node : compute_graph_->GetAllNodes()) { auto node_op_desc = 
node->GetOpDesc(); GE_IF_BOOL_EXEC(node_op_desc == nullptr, continue); diff --git a/src/ge/graph/build/memory/graph_mem_assigner.h b/src/ge/graph/build/memory/graph_mem_assigner.h index e46d4f8b..67008918 100644 --- a/src/ge/graph/build/memory/graph_mem_assigner.h +++ b/src/ge/graph/build/memory/graph_mem_assigner.h @@ -26,6 +26,7 @@ #include "framework/common/ge_inner_error_codes.h" #include "graph/node.h" #include "runtime/mem.h" +#include "graph/build/memory/hybrid_mem_assigner.h" namespace ge { struct MemoryOffset { @@ -67,10 +68,13 @@ class VariableMemoryAssigner { }; using VariableMemoryAssignerPtr = std::shared_ptr; +using BlockMemAssignerPtr = std::shared_ptr; +using HybridMemAssignerPtr = std::shared_ptr; class GraphMemoryAssigner { public: - explicit GraphMemoryAssigner(ge::ComputeGraphPtr compute_graph) : compute_graph_(std::move(compute_graph)) {} + explicit GraphMemoryAssigner(ge::ComputeGraphPtr compute_graph) + : compute_graph_(std::move(compute_graph)), mem_assigner_(nullptr) {} GraphMemoryAssigner(const GraphMemoryAssigner &) = delete; @@ -93,18 +97,18 @@ class GraphMemoryAssigner { /// ge::Status AssignVarAttr2Nodes(); - ge::Status AssignSubgraphInputsMemory(); - - ge::Status AssignSubgraphOutputsMemory(); - ge::Status ReAssignMemory(bool is_loop_graph, size_t &mem_offset); + ge::Status AssignZeroCopyMemory(size_t &mem_offset, size_t &zero_mem_copy_size); + ge::Status SetInputOffset(); ge::Status UpdateOpInputOffset(const NodePtr &node) const; ge::Status CheckOffset(); + ge::Status AssignReferenceMemory(); + private: /// /// @ingroup ge_graph @@ -117,19 +121,25 @@ class GraphMemoryAssigner { ge::Status ReAssignReuseAndNoPaddingContinuousOutputMemory(); + ge::Status ReAssignVirtualInputNodeMemory(NodePtr node, size_t &mem_offset_reuse); + + ge::Status ReAssignVirtualOutputNodeMemory(NodePtr node, size_t &mem_offset_reuse); + + ge::Status ReAssignVirtualNodesMemory(map> &mem_reuse_nodes_map, int32_t mem_reuse_model); + + ge::Status 
GetMaxBatchLabel(const map> &mem_reuse_virtual_nodes_map, int32_t mem_reuse_model, + string &max_batch_label); + ge::Status CalculateTensorRealSizeAndOutSize(const ge::ConstGeTensorDescPtr &output_desc, int64_t dim_index, int64_t &output_mem_size, int64_t &batch_dim_num, int64_t &out_size); - ge::Status ReAssignMergeMemory(); - ge::Status ReAssignAtomicMemory(bool is_loop_graph); - ge::Status AssignContinuousInputMemory(const ge::NodePtr &node); + ge::Status AssignContinuousInputMemory(const ge::NodePtr &node, int64_t &continuous_mem_start, + int64_t &continuous_mem_size); ge::Status AssignContinuousOutputMemory(const ge::NodePtr &node); - ge::Status AssignReferenceMemory(const ge::NodePtr &node); - /// /// @brief check the input of node whether support atomic attr /// @param node @@ -158,8 +168,11 @@ class GraphMemoryAssigner { ge::Status UpdateOpInputOffset(const NodePtr &node, vector &input_list) const; + ge::Status UpdateConstArgsOffset(const NodePtr &node, vector &input_list) const; + MemoryOffsetList memory_offset_; ge::ComputeGraphPtr compute_graph_; + HybridMemAssignerPtr mem_assigner_; }; } // namespace ge diff --git a/src/ge/graph/build/memory/hybrid_mem_assigner.cc b/src/ge/graph/build/memory/hybrid_mem_assigner.cc index 6165494c..925d742a 100644 --- a/src/ge/graph/build/memory/hybrid_mem_assigner.cc +++ b/src/ge/graph/build/memory/hybrid_mem_assigner.cc @@ -23,7 +23,7 @@ namespace ge { HybridMemAssigner::HybridMemAssigner(ge::ComputeGraphPtr compute_graph) - : mem_offset_(0), compute_graph_(std::move(compute_graph)) {} + : mem_offset_(0), compute_graph_(std::move(compute_graph)), priority_assigner_(nullptr) {} Status HybridMemAssigner::AssignMemory(std::unique_ptr &block_assigner, size_t &mem_size) { vector ranges; @@ -64,8 +64,10 @@ Status HybridMemAssigner::Assign() { priority_assigner = std::move(max_assigner); } - priority_assigner->SetOpMemOffset(); + priority_assigner->SetOpMemOffset(false); mem_offset_ = priority_assigner->GetMemOffset(); + 
priority_assigner_ = std::move(priority_assigner); + return SUCCESS; } } // namespace ge diff --git a/src/ge/graph/build/memory/hybrid_mem_assigner.h b/src/ge/graph/build/memory/hybrid_mem_assigner.h index 1e78c278..db3741d4 100644 --- a/src/ge/graph/build/memory/hybrid_mem_assigner.h +++ b/src/ge/graph/build/memory/hybrid_mem_assigner.h @@ -19,12 +19,15 @@ #include #include "graph/build/memory/mem_assigner.h" +#include "graph/build/memory/block_mem_assigner.h" #include "graph/compute_graph.h" #include "common/types.h" #include "common/util.h" namespace ge { +using BlockMemAssignerPtr = std::shared_ptr; + class BlockMemAssigner; class HybridMemAssigner : public MemAssigner { @@ -41,12 +44,16 @@ class HybridMemAssigner : public MemAssigner { size_t GetMemOffset() const { return mem_offset_; } + BlockMemAssignerPtr GetPriorityAssinger() const { return priority_assigner_; } + private: Status AssignMemory(std::unique_ptr &block_assigner, size_t &mem_size); size_t mem_offset_; ge::ComputeGraphPtr compute_graph_; + + BlockMemAssignerPtr priority_assigner_; }; } // namespace ge #endif // GE_GRAPH_BUILD_MEMORY_HYBRID_MEM_ASSIGNER_H_ diff --git a/src/ge/graph/build/memory/memory_assigner.cc b/src/ge/graph/build/memory/memory_assigner.cc index 75ab01b4..e36f082e 100644 --- a/src/ge/graph/build/memory/memory_assigner.cc +++ b/src/ge/graph/build/memory/memory_assigner.cc @@ -20,7 +20,7 @@ #include "graph/build/memory/graph_mem_assigner.h" namespace ge { -Status MemoryAssigner::AssignMemory(bool is_loop_graph, size_t &mem_offset) { +Status MemoryAssigner::AssignMemory(bool is_loop_graph, size_t &mem_offset, size_t &zero_copy_mem_size) { GraphMemoryAssigner graph_mem_assigner(compute_graph_); if (graph_mem_assigner.AssignMemory() != ge::SUCCESS) { @@ -34,27 +34,28 @@ Status MemoryAssigner::AssignMemory(bool is_loop_graph, size_t &mem_offset) { return ge::FAILED; } + // Assign memory (block and offset) for zero copy nodes + if (graph_mem_assigner.AssignZeroCopyMemory(mem_offset, 
zero_copy_mem_size) != ge::SUCCESS) { + GELOGE(ge::FAILED, "Zero copy memory assigner failed"); + return ge::FAILED; + } + + // Assign memory for reference + if (graph_mem_assigner.AssignReferenceMemory() != ge::SUCCESS) { + GELOGE(ge::FAILED, "Assign reference memory failed!"); + return ge::FAILED; + } + // Must do variable attr assign after all the memory assigned if (graph_mem_assigner.AssignVarAttr2Nodes() != SUCCESS) { GELOGE(FAILED, "Variable Memory assigner failed"); return FAILED; } - if (graph_mem_assigner.SetInputOffset() != ge::SUCCESS) { GELOGE(ge::FAILED, "SetInputOffset Fail!"); return ge::FAILED; } - if (graph_mem_assigner.AssignSubgraphInputsMemory() != SUCCESS) { - GELOGE(FAILED, "Assign subgraph inputs memory failed"); - return FAILED; - } - - if (graph_mem_assigner.AssignSubgraphOutputsMemory() != SUCCESS) { - GELOGE(FAILED, "Assign subgraph inputs memory failed"); - return FAILED; - } - if (graph_mem_assigner.CheckOffset() != SUCCESS) { GELOGE(FAILED, "CheckOffset Fail!"); return FAILED; diff --git a/src/ge/graph/build/memory/var_mem_assign_util.cc b/src/ge/graph/build/memory/var_mem_assign_util.cc index ff5f9798..a71e09b2 100644 --- a/src/ge/graph/build/memory/var_mem_assign_util.cc +++ b/src/ge/graph/build/memory/var_mem_assign_util.cc @@ -16,7 +16,6 @@ #include "graph/build/memory/var_mem_assign_util.h" #include -#include "common/op/attr_define.h" #include "common/types.h" #include "framework/common/debug/ge_log.h" #include "graph/common/transop_util.h" @@ -50,11 +49,11 @@ Status VarMemAssignUtil::AssignMemory2VariableNode(ge::ComputeGraphPtr &compute_ Status VarMemAssignUtil::AssignStaticMemory2Node(ge::ComputeGraphPtr &compute_graph) { GE_IF_BOOL_EXEC(compute_graph == nullptr, return FAILED); - for (const ge::NodePtr &n : compute_graph->GetDirectNode()) { - GE_IF_BOOL_EXEC((n->GetType() != domi::VARIABLE) && (n->GetType() != domi::CONSTANTOP), continue); + for (const ge::NodePtr &n : compute_graph->GetAllNodes()) { + 
GE_IF_BOOL_EXEC((n->GetType() != VARIABLE) && (n->GetType() != CONSTANTOP), continue); string ref_var_src_var_name; GE_CHECK_NOTNULL(n->GetOpDesc()); - GE_IF_BOOL_EXEC(ge::AttrUtils::GetStr(n->GetOpDesc(), domi::REF_VAR_SRC_VAR_NAME, ref_var_src_var_name), continue); + GE_IF_BOOL_EXEC(ge::AttrUtils::GetStr(n->GetOpDesc(), REF_VAR_SRC_VAR_NAME, ref_var_src_var_name), continue); string node_name = n->GetName(); GE_IF_BOOL_EXEC(n->GetOpDesc()->GetAllOutputsDesc().empty(), GELOGE(FAILED, "node:%s has no OutputDesc.", n->GetName().c_str()); @@ -64,7 +63,7 @@ Status VarMemAssignUtil::AssignStaticMemory2Node(ge::ComputeGraphPtr &compute_gr if (!VarManager::Instance(compute_graph->GetSessionID())->IsVarExist(node_name, *tensor_desc)) { GE_CHK_STATUS_RET( VarManager::Instance(compute_graph->GetSessionID())->AssignVarMem(node_name, *tensor_desc, RT_MEMORY_HBM)); - GE_IF_BOOL_EXEC(n->GetType() == domi::VARIABLE, + GE_IF_BOOL_EXEC(n->GetType() == VARIABLE, GE_CHK_STATUS_RET(AssignData2Fp32Var(n, compute_graph->GetSessionID()))); GE_CHK_STATUS_RET(VarManager::Instance(compute_graph->GetSessionID()) ->SetAllocatedGraphId(node_name, compute_graph->GetGraphID())); @@ -85,7 +84,7 @@ Status VarMemAssignUtil::AssignStaticMemory2Node(ge::ComputeGraphPtr &compute_gr Status VarMemAssignUtil::AssignData2Fp32Var(const ge::NodePtr &node, uint64_t session_id) { string src_var_name; GE_CHECK_NOTNULL(node->GetOpDesc()); - if (ge::AttrUtils::GetStr(node->GetOpDesc(), domi::VAR_ATTR_SRC_VAR_NAME, src_var_name)) { + if (ge::AttrUtils::GetStr(node->GetOpDesc(), VAR_ATTR_SRC_VAR_NAME, src_var_name)) { ge::GeTensorDesc cur_tensor_desc; uint8_t *dev_ptr = nullptr; rtMemType_t memory_type = RT_MEMORY_HBM; @@ -99,12 +98,11 @@ Status VarMemAssignUtil::AssignData2Fp32Var(const ge::NodePtr &node, uint64_t se } Status VarMemAssignUtil::AssignVarAttr2Nodes(ge::ComputeGraphPtr &compute_graph) { - for (const ge::NodePtr &node : compute_graph->GetDirectNode()) { - GE_IF_BOOL_EXEC(node->GetType() != 
domi::VARIABLE, continue); + for (const ge::NodePtr &node : compute_graph->GetAllNodes()) { + GE_IF_BOOL_EXEC(node->GetType() != VARIABLE, continue); string ref_var_src_var_name; GE_CHECK_NOTNULL(node->GetOpDesc()); - GE_IF_BOOL_EXEC(ge::AttrUtils::GetStr(node->GetOpDesc(), domi::REF_VAR_SRC_VAR_NAME, ref_var_src_var_name), - continue); + GE_IF_BOOL_EXEC(ge::AttrUtils::GetStr(node->GetOpDesc(), REF_VAR_SRC_VAR_NAME, ref_var_src_var_name), continue); GE_CHK_STATUS_RET(DealVariableNode(compute_graph->GetGraphID(), node, compute_graph->GetSessionID())); } return SUCCESS; @@ -142,8 +140,7 @@ Status VarMemAssignUtil::DealExportVariableNode(const ge::NodePtr &node, const g GE_IF_BOOL_EXEC(var_out_anchor == nullptr, return FAILED); for (const ge::InDataAnchorPtr &dst_in_var_anchor : var_out_anchor->GetPeerInDataAnchors()) { ge::NodePtr dst_node = dst_in_var_anchor->GetOwnerNode(); - if ((dst_node->GetType() == domi::ASSIGN) || (dst_node->GetType() == domi::ASSIGNADD) || - (dst_node->GetType() == domi::ASSIGNSUB)) { + if ((dst_node->GetType() == ASSIGN) || (dst_node->GetType() == ASSIGNADD) || (dst_node->GetType() == ASSIGNSUB)) { if (dst_in_var_anchor == dst_node->GetInDataAnchor(0)) { GE_CHK_STATUS_RET(DealExportVariableNode(dst_node, var_node, session_id)); } @@ -211,20 +208,19 @@ Status VarMemAssignUtil::DealVariableNode(uint32_t graph_id, const ge::NodePtr & for (const ge::OutDataAnchorPtr &var_out_data_anchor : node->GetAllOutDataAnchors()) { for (const ge::InDataAnchorPtr &dst_in_data_anchor : var_out_data_anchor->GetPeerInDataAnchors()) { ge::NodePtr dst_node = dst_in_data_anchor->GetOwnerNode(); - if (dst_node->GetType() == domi::HCOMBROADCAST) { + if (dst_node->GetType() == HCOMBROADCAST) { GE_CHK_STATUS_RET(DealBroadCastNode(graph_id, dst_node, dst_in_data_anchor, node, session_id)); continue; } - if ((dst_node->GetType() == domi::ASSIGN) || (dst_node->GetType() == domi::ASSIGNADD) || - (dst_node->GetType() == domi::ASSIGNSUB)) { + if ((dst_node->GetType() == 
ASSIGN) || (dst_node->GetType() == ASSIGNADD) || (dst_node->GetType() == ASSIGNSUB)) { if (dst_in_data_anchor == dst_node->GetInDataAnchor(0)) { GE_CHK_STATUS_RET(DealExportVariableNode(dst_node, node, session_id)); } } auto dst_type = dst_node->GetType(); - bool is_trans_node = (dst_type == domi::TRANSDATA) || (dst_type == domi::CAST) || (dst_type == domi::TRANSPOSE) || - (dst_type == domi::PERMUTE); + bool is_trans_node = + (dst_type == TRANSDATA) || (dst_type == CAST) || (dst_type == TRANSPOSE) || (dst_type == PERMUTE); if (is_trans_node) { NodePtr final_trans_node = GetFinalTransNode(dst_node); GE_CHK_STATUS_RET(DealTransNode(final_trans_node)); @@ -241,8 +237,8 @@ ge::NodePtr VarMemAssignUtil::GetFinalTransNode(const ge::NodePtr &trans_node) { for (const auto &dst_in_anchor : trans_out_data_anchor->GetPeerInDataAnchors()) { NodePtr dst_node = dst_in_anchor->GetOwnerNode(); auto dst_type = dst_node->GetType(); - bool is_trans_node = (dst_type == domi::TRANSDATA) || (dst_type == domi::CAST) || (dst_type == domi::TRANSPOSE) || - (dst_type == domi::PERMUTE); + bool is_trans_node = + (dst_type == TRANSDATA) || (dst_type == CAST) || (dst_type == TRANSPOSE) || (dst_type == PERMUTE); if (is_trans_node && (dst_in_anchor->GetIdx() == 0)) { final_ref_node = GetFinalTransNode(dst_node); } @@ -256,8 +252,7 @@ Status VarMemAssignUtil::DealTransNode(const ge::NodePtr &final_trans_node) { GE_IF_BOOL_EXEC(final_trans_out_anchor == nullptr, return SUCCESS); for (const ge::InDataAnchorPtr &dst_in_var_anchor : final_trans_out_anchor->GetPeerInDataAnchors()) { ge::NodePtr dst_node = dst_in_var_anchor->GetOwnerNode(); - if ((dst_node->GetType() == domi::ASSIGN) || (dst_node->GetType() == domi::ASSIGNADD) || - (dst_node->GetType() == domi::ASSIGNSUB)) { + if ((dst_node->GetType() == ASSIGN) || (dst_node->GetType() == ASSIGNADD) || (dst_node->GetType() == ASSIGNSUB)) { GE_CHK_STATUS_RET(DealExportTransNode(dst_node, final_trans_node)); } } @@ -269,8 +264,7 @@ Status 
VarMemAssignUtil::DealExportTransNode(const ge::NodePtr &node, const ge:: GE_CHECK_NOTNULL(node_out_anchor); for (const ge::InDataAnchorPtr &dst_in_var_anchor : node_out_anchor->GetPeerInDataAnchors()) { ge::NodePtr dst_node = dst_in_var_anchor->GetOwnerNode(); - if ((dst_node->GetType() == domi::ASSIGN) || (dst_node->GetType() == domi::ASSIGNADD) || - (dst_node->GetType() == domi::ASSIGNSUB)) { + if ((dst_node->GetType() == ASSIGN) || (dst_node->GetType() == ASSIGNADD) || (dst_node->GetType() == ASSIGNSUB)) { GE_CHK_STATUS_RET(DealExportTransNode(dst_node, final_trans_node)); } } @@ -303,10 +297,10 @@ Status VarMemAssignUtil::SetOutTransNodeToAssign(const ge::NodePtr &node, const } Status VarMemAssignUtil::AssignMemory2HasRefAttrNode(ge::ComputeGraphPtr &compute_graph) { - for (const ge::NodePtr &n : compute_graph->GetDirectNode()) { + for (const ge::NodePtr &n : compute_graph->GetAllNodes()) { string ref_var_src_var_name; GE_CHECK_NOTNULL(n->GetOpDesc()); - bool is_ref = ge::AttrUtils::GetStr(n->GetOpDesc(), domi::REF_VAR_SRC_VAR_NAME, ref_var_src_var_name); + bool is_ref = ge::AttrUtils::GetStr(n->GetOpDesc(), REF_VAR_SRC_VAR_NAME, ref_var_src_var_name); GE_IF_BOOL_EXEC(is_ref, GE_CHK_STATUS_RET(AssignData2VarRef(n, ref_var_src_var_name, compute_graph->GetSessionID()))); } @@ -329,7 +323,7 @@ Status VarMemAssignUtil::AssignData2VarRef(const ge::NodePtr &has_ref_attr_node, GE_CHECK_SIZE(ref_attr_node_output_list.size()); int out_index = 0; - bool is_get = ge::AttrUtils::GetInt(var_ref_src_var->GetOpDesc(), domi::REF_VAR_PRE_PEER_OUT_INDEX, out_index); + bool is_get = ge::AttrUtils::GetInt(var_ref_src_var->GetOpDesc(), REF_VAR_PRE_PEER_OUT_INDEX, out_index); if (!is_get) { GELOGI("%s failed to get attr [REF_VAR_PRE_PEER_OUT_INDEX]", var_ref_src_var->GetName().c_str()); } diff --git a/src/ge/graph/build/model_builder.cc b/src/ge/graph/build/model_builder.cc index 750cc90b..4c3f3ffd 100644 --- a/src/ge/graph/build/model_builder.cc +++ 
b/src/ge/graph/build/model_builder.cc @@ -17,6 +17,7 @@ #include "graph/build/model_builder.h" #include #include +#include #include "common/ge/ge_util.h" #include "framework/common/debug/ge_log.h" #include "graph/anchor.h" @@ -27,6 +28,7 @@ #include "graph/common/omg_util.h" #include "graph/debug/ge_attr_define.h" #include "graph/ge_attr_value.h" +#include "graph/ge_context.h" #include "graph/ge_error_codes.h" #include "graph/manager/graph_mem_allocator.h" #include "graph/manager/graph_var_manager.h" @@ -38,49 +40,11 @@ #include "graph/utils/op_desc_utils.h" #include "graph/utils/tensor_utils.h" #include "graph/utils/type_utils.h" -#include "graph/ge_context.h" #include "init/gelib.h" #include "memory/memory_assigner.h" #include "omg/version.h" #include "register/op_registry.h" -using domi::AIPP_CONV_FLAG; -using domi::AIPP_DATA_FLAG; -using domi::AIPP_DATA_TYPE; -using domi::AippOpParams; -using domi::ATOMICADDRCLEAN; -using domi::ATTR_NAME_AUTOMIC_ADD_MEM_SIZE; -using domi::ATTR_NAME_AUTOMIC_ADD_START; -using domi::CAST; -using domi::CHW_DIM_H; -using domi::CHW_DIM_W; -using domi::CONCAT; -using domi::CONSTANT; -using domi::CONSTANTOP; -using domi::CONVOLUTION; -using domi::DATA; -using domi::DATA_TYPE; -using domi::DEFAULT_FORMAT; -using domi::DIM_DEFAULT_SIZE; -using domi::DOMI_TENSOR_NC1HWC0; -using domi::HWC_DIM_H; -using domi::HWC_DIM_W; -using domi::LOOPCOND; -using domi::MODEL_ATTR_TASKS; -using domi::ModelTaskDef; -using domi::NCHW_DIM_H; -using domi::NCHW_DIM_N; -using domi::NCHW_DIM_W; -using domi::NETOUTPUT; -using domi::NHWC_DIM_H; -using domi::NHWC_DIM_W; -using domi::PlatformVersionManager; -using domi::STREAMMERGE; -using domi::VARIABLE; -using domi::XRGB_CHN_NUM; -using ge::FAILED; -using ge::PARAM_INVALID; -using ge::SUCCESS; using std::map; using std::set; using std::string; @@ -102,25 +66,25 @@ const char *const kVectorCore = "VectorCore"; const char *const kCoreType = "ge.engineType"; const std::string kEnableL1Fusion = "ge.l1Fusion"; -const 
set adjust_layer_type_ = {CONVOLUTION}; +const set adjust_layer_type_ = {ge::CONVOLUTION}; bool IsGeLocalOp(const ge::ConstOpDescPtr &op_desc) { auto type = op_desc->GetType(); - if (type == CONSTANTOP) { + if (type == ge::CONSTANTOP) { // constant op just has one output ge::GeTensorDesc output_desc = op_desc->GetOutputDesc(0); return !(output_desc.GetDataType() == ge::DT_STRING); } - const set ge_local_set = {domi::STREAMMERGE, domi::MEMCPYASYNC, domi::STREAMACTIVE, domi::STREAMSWITCH, - domi::VARIABLE, domi::NOOP, domi::CONSTANT, domi::ENTER, - domi::REFENTER, domi::LOOPCOND, domi::NEXTITERATION, domi::REFNEXTITERATION, - domi::EXIT, domi::REFEXIT}; + const set ge_local_set = {ge::STREAMMERGE, ge::MEMCPYASYNC, ge::STREAMACTIVE, ge::STREAMSWITCH, + ge::VARIABLE, ge::NOOP, ge::CONSTANT, ge::ENTER, + ge::REFENTER, ge::LOOPCOND, ge::NEXTITERATION, ge::REFNEXTITERATION, + ge::EXIT, ge::REFEXIT, ge::MERGE, ge::MEMCPYADDRASYNC}; return (ge_local_set.find(type) != ge_local_set.end()); } } // namespace namespace ge { -ModelBuilder::ModelBuilder(ge::ComputeGraphPtr compute_graph, const vector &subgraphs, +ModelBuilder::ModelBuilder(ge::ComputeGraphPtr compute_graph, const Graph2SubGraphInfoList &subgraphs, const map &stream_max_parallel_num, bool hcom_parallel, int mode) : mem_offset_(0), weight_offset_(kWeightsStartOffset), @@ -133,6 +97,7 @@ ModelBuilder::ModelBuilder(ge::ComputeGraphPtr compute_graph, const vector &is_input_const) { + GELOGI("SetIsInputConst const: %s", op_desc->GetName().c_str()); + for (size_t i = is_input_const.size(); i <= index; ++i) { + is_input_const.push_back(false); + } + is_input_const[index] = true; + + vector weights = OpDescUtils::MutableWeights(src_node); + if (weights.empty()) { + GELOGW("SetInputIsConst weights is empty"); + return false; + } + GeTensorPtr weight = weights[0]; + GE_IF_BOOL_EXEC(weight == nullptr, return true); + GeTensorDesc &tensor_desc = weight->MutableTensorDesc(); + int64_t data_offset = 0; + if 
(TensorUtils::GetDataOffset(tensor_desc, data_offset) != GRAPH_SUCCESS) { + GELOGW("Get Offset from weight failed"); + return false; + } + auto input_tensor = op_desc->MutableInputDesc(static_cast(index)); + if (input_tensor == nullptr) { + GELOGW("Get input_tensor failed"); + return false; + } + TensorUtils::SetDataOffset(*input_tensor, data_offset); + return true; +} + void ModelBuilder::SetInputIsConst(const ge::NodePtr &n) { auto node_op_desc = n->GetOpDesc(); - if (node_op_desc == nullptr) { - GELOGW("node_op_desc is nullptr!"); - return; - } + GE_CHECK_NOTNULL_JUST_RETURN(node_op_desc); + auto is_input_const = node_op_desc->GetIsInputConst(); // must set all true input_const to false @@ -190,39 +183,35 @@ void ModelBuilder::SetInputIsConst(const ge::NodePtr &n) { GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); const auto &src_node = peer_out_anchor->GetOwnerNode(); if (src_node->GetType() == CONSTANT) { - GELOGI("SetIsInputConst const"); - for (size_t i = is_input_const.size(); i <= index; ++i) { - is_input_const.push_back(false); - } - is_input_const[index] = true; - - vector weights = OpDescUtils::MutableWeights(src_node); - if (weights.empty()) { - GELOGW("SetInputIsConst weights is empty"); + if (!SetInputConst(node_op_desc, src_node, index, is_input_const)) { return; } - GeTensorPtr weight = weights[0]; - GE_IF_BOOL_EXEC(weight == nullptr, continue); - GeTensorDesc &tensor_desc = weight->MutableTensorDesc(); - int64_t data_offset = 0; - if (TensorUtils::GetDataOffset(tensor_desc, data_offset) != GRAPH_SUCCESS) { - GELOGW("Get Offset from weight failed"); - return; - } - auto input_tensor = node_op_desc->MutableInputDesc(static_cast(index)); - if (input_tensor == nullptr) { - GELOGW("Get input_tensor failed"); - return; - } - TensorUtils::SetDataOffset(*input_tensor, data_offset); } else if (src_node->GetType() == CONSTANTOP) { if ((index < is_input_const.size()) && is_input_const[index]) { is_input_const[index] = false; } + } else if 
(src_node->GetType() == DATA) { + uint32_t parent_index = 0; + if (!AttrUtils::GetInt(src_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { + continue; + } + + // Subgraph Data Node, check for constant input. + std::string op_type; + const NodePtr in_node = NodeUtils::GetParentInput(src_node); + if (!NodeUtils::GetConstOpType(in_node, op_type)) { + continue; // not constant input. + } + + if (op_type == CONSTANT) { + if (!SetInputConst(node_op_desc, in_node, index, is_input_const)) { + return; + } + } } } - std::string input_const_info = domi::ToString(is_input_const); + std::string input_const_info = ToString(is_input_const); GELOGD("update opdesc:%s InputConst:%s", node_op_desc->GetName().c_str(), input_const_info.c_str()); node_op_desc->SetIsInputConst(is_input_const); } @@ -252,13 +241,32 @@ Status ModelBuilder::SetInputOutputDesc() { Status ret; GELOGI("Start to SetInputOutputDesc."); - for (const ge::NodePtr &n : compute_graph_->GetDirectNode()) { + for (const ge::NodePtr &n : compute_graph_->GetAllNodes()) { auto node_op_desc = n->GetOpDesc(); GE_IF_BOOL_EXEC(node_op_desc == nullptr, continue); if (!is_loop_graph_ && node_op_desc->GetType() == LOOPCOND) { is_loop_graph_ = true; } + // if user set input node format ND, the expected node for data and netoutput format is ND in + // final graph. 
+ if ((domi::GetContext().format == domi::DOMI_TENSOR_ND) && + ((node_op_desc->GetType() == DATA_TYPE) || (node_op_desc->GetType() == NETOUTPUT))) { + GELOGI("The node [%s] format should be set ND.", node_op_desc->GetName().c_str()); + auto inputDescsPtr = node_op_desc->GetAllInputsDescPtr(); + auto outputDescsPtr = node_op_desc->GetAllOutputsDescPtr(); + ge::Format format = ge::FORMAT_ND; + for (auto &inputDescPtr : inputDescsPtr) { + GE_CHECK_NOTNULL(inputDescPtr); + inputDescPtr->SetFormat(format); + inputDescPtr->SetOriginFormat(format); + } + for (auto &outputDescPtr : outputDescsPtr) { + GE_CHECK_NOTNULL(outputDescPtr); + outputDescPtr->SetFormat(format); + outputDescPtr->SetOriginFormat(format); + } + } if (node_op_desc->GetType() == DATA_TYPE || node_op_desc->GetType() == AIPP_DATA_TYPE) { GELOGD("Data node: %s.", n->GetName().c_str()); @@ -282,7 +290,7 @@ Status ModelBuilder::SetInputOutputDesc() { } void ModelBuilder::AddNodeInputProperty() { - for (const ge::NodePtr &node : compute_graph_->GetDirectNode()) { + for (const ge::NodePtr &node : compute_graph_->GetAllNodes()) { auto node_op_desc = node->GetOpDesc(); GE_IF_BOOL_EXEC(node_op_desc == nullptr, GELOGW("node_op_desc is nullptr!"); return ); vector src_name_list; @@ -309,7 +317,7 @@ void ModelBuilder::AddNodeInputProperty() { node_op_desc->SetSrcIndex(src_index_list); } - for (const ge::NodePtr &node : compute_graph_->GetDirectNode()) { + for (const ge::NodePtr &node : compute_graph_->GetAllNodes()) { auto node_op_desc = node->GetOpDesc(); GE_IF_BOOL_EXEC(node_op_desc == nullptr, GELOGW("node_op_desc is nullptr!"); return ); GE_IF_BOOL_EXEC(node_op_desc->GetType() == NETOUTPUT, continue); @@ -347,7 +355,7 @@ void ModelBuilder::AddNodeInputProperty() { Status ModelBuilder::AdjustInputTensorFlag() { GELOGI("Start to AdjustInputTensorFlag."); - for (const ge::NodePtr &n : compute_graph_->GetDirectNode()) { + for (const ge::NodePtr &n : compute_graph_->GetAllNodes()) { if ((n->GetType() == DATA_TYPE) || 
(n->GetType() == AIPP_DATA_TYPE)) { GELOGD("Data node: %s.", n->GetName().c_str()); for (const auto &anchor : n->GetAllOutDataAnchors()) { @@ -369,11 +377,11 @@ Status ModelBuilder::AdjustInputTensorFlag() { return SUCCESS; } void ModelBuilder::InitL1FusionOption() { - string is_l1_fusion_enable = "false"; - graphStatus ret = ge::GetContext().GetOption(kEnableL1Fusion, is_l1_fusion_enable); + string buffer_optimize = "off_optimize"; + graphStatus ret = ge::GetContext().GetOption(BUFFER_OPTIMIZE, buffer_optimize); if (ret == GRAPH_SUCCESS) { - is_l1_fusion_enable_ = is_l1_fusion_enable == "true"; - GELOGD("The value of %s is %s.", kEnableL1Fusion.c_str(), is_l1_fusion_enable.c_str()); + is_l1_fusion_enable_ = (buffer_optimize == "l1_optimize"); + GELOGD("The value of %s is %s.", BUFFER_OPTIMIZE.c_str(), buffer_optimize.c_str()); } else { GELOGW("The value of %s is empty.", kEnableL1Fusion.c_str()); } @@ -386,18 +394,27 @@ Status ModelBuilder::BuildModelDef(ge::Model &model) { GE_CHK_BOOL_EXEC(ge::AttrUtils::SetInt(&model, ATTR_MODEL_MEMORY_SIZE, max_mem_offset_), GELOGE(FAILED, "SetInt of ATTR_MODEL_MEMORY_SIZE failed."); return FAILED); - GE_CHK_BOOL_EXEC(ge::AttrUtils::SetInt(&model, ATTR_MODEL_STREAM_NUM, stream_num_), - GELOGE(FAILED, "SetInt of ATTR_MODEL_STREAM_NUM failed."); - return FAILED); GE_CHK_BOOL_EXEC(ge::AttrUtils::SetInt(&model, ATTR_MODEL_WEIGHT_SIZE, weight_offset_), GELOGE(FAILED, "SetInt of ATTR_MODEL_WEIGHT_SIZE failed."); return FAILED); + GE_CHK_BOOL_EXEC(ge::AttrUtils::SetInt(&model, ATTR_MODEL_STREAM_NUM, stream_num_), + GELOGE(FAILED, "SetInt of ATTR_MODEL_STREAM_NUM failed."); + return FAILED); GE_CHK_BOOL_EXEC(ge::AttrUtils::SetInt(&model, ATTR_MODEL_EVENT_NUM, event_num_), GELOGE(FAILED, "SetInt of ATTR_MODEL_EVENT_NUM failed."); return FAILED); + GE_CHK_BOOL_EXEC(ge::AttrUtils::SetListInt(&model, ATTR_MODEL_HUGE_STREAM_LIST, huge_streams_), + GELOGE(FAILED, "SetInt of ATTR_MODEL_HUGE_STREAM_LIST failed."); + return FAILED); 
GE_CHK_BOOL_EXEC(ge::AttrUtils::SetInt(&model, ATTR_MODEL_LABEL_NUM, label_num_), GELOGE(FAILED, "SetInt of ATTR_MODEL_LABEL_NUM failed."); return FAILED); + GE_CHK_BOOL_EXEC(ge::AttrUtils::SetInt(&model, ATTR_MODEL_ZERO_COPY_MEMORY_SIZE, zero_copy_mem_size_), + GELOGE(FAILED, "SetInt of ATTR_MODEL_ZERO_COPY_MEMORY_SIZE failed."); + return FAILED); + + GELOGI("For model, max_mem_offset_: %zu, zero_copy_mem_size_: %zu", max_mem_offset_, zero_copy_mem_size_); + string ge_core_type; Status ret = ge::GetContext().GetOption(kCoreType, ge_core_type); if (ret != SUCCESS) { @@ -428,7 +445,7 @@ Status ModelBuilder::BuildModelDef(ge::Model &model) { } void ModelBuilder::ClearOriginalFormat() { - for (const ge::NodePtr &n : compute_graph_->GetDirectNode()) { + for (const ge::NodePtr &n : compute_graph_->GetAllNodes()) { auto node_op_desc = n->GetOpDesc(); if (node_op_desc != nullptr) { if (node_op_desc->HasAttr(ATTR_NAME_FORMAT)) { @@ -518,11 +535,17 @@ Status ModelBuilder::SaveDataToModel(ge::Model &model, ge::GeModel &ge_model) { ge_model.SetWeight(weight_buffer_); // Add TBE Kernels - for (const ge::NodePtr &n : compute_graph_->GetDirectNode()) { + std::set name_set; + for (const ge::NodePtr &n : compute_graph_->GetAllNodes()) { auto node_op_desc = n->GetOpDesc(); GE_IF_BOOL_EXEC(node_op_desc == nullptr, continue); TBEKernelPtr tbe_kernel = node_op_desc->TryGetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, TBEKernelPtr()); GE_IF_BOOL_EXEC(tbe_kernel == nullptr, continue); + if (name_set.count(tbe_kernel->GetName()) > 0) { + GELOGE(FAILED, "tbe_kernel name %s can't be the same", tbe_kernel->GetName().c_str()); + return FAILED; + } + name_set.insert(tbe_kernel->GetName()); tbe_kernel_store_.AddTBEKernel(tbe_kernel); GELOGD("Add tbe kernel bin %s", tbe_kernel->GetName().c_str()); } @@ -539,7 +562,7 @@ Status ModelBuilder::SaveDataToModel(ge::Model &model, ge::GeModel &ge_model) { return INTERNAL_ERROR; } int byte_size = static_cast(task_def_bytes.GetSize()); - std::shared_ptr task = 
ge::MakeShared(); + std::shared_ptr task = ge::MakeShared(); GE_CHECK_NOTNULL(task); GE_CHK_BOOL_EXEC(ReadProtoFromArray(task_def_bytes.GetData(), byte_size, task.get()), return INTERNAL_ERROR, "ReadProtoFromArray failed."); @@ -585,12 +608,6 @@ Status ModelBuilder::PreBuildModel() { Status ModelBuilder::BuildModelForGetTask(ge::Model &model) { GE_CHK_STATUS_RET(AdjustInputTensorFlag(), "AdjustInputTensorFlag failed!"); - // Assign functional op labels. - GE_TIMESTAMP_START(AssignFunctionalLabels); - LabelAllocator label_allocator(compute_graph_); - GE_CHK_STATUS_RET(label_allocator.AssignFunctionalLabels(label_num_), "Assign label failed."); - GE_TIMESTAMP_END(AssignFunctionalLabels, "ModelBuilder::AssignFunctionalLabels"); - // Assign logical streams. StreamAllocator stream_allocator(compute_graph_, subgraphs_); GE_TIMESTAMP_START(AssignLogicalStreams); @@ -598,9 +615,16 @@ Status ModelBuilder::BuildModelForGetTask(ge::Model &model) { "Assign logical streams failed."); GE_TIMESTAMP_END(AssignLogicalStreams, "GraphBuilder::AssignLogicalStreams"); + // Assign functional op labels. + GE_TIMESTAMP_START(AssignFunctionalLabels); + LabelAllocator label_allocator(compute_graph_); + GE_CHK_STATUS_RET(label_allocator.AssignFunctionalLabels(label_num_), "Assign label failed."); + GE_TIMESTAMP_END(AssignFunctionalLabels, "ModelBuilder::AssignFunctionalLabels"); + GE_TIMESTAMP_START(AssignMemory); MemoryAssigner mem_assigner(compute_graph_); - GE_CHK_STATUS_RET(mem_assigner.AssignMemory(is_loop_graph_, mem_offset_), "Assign Memory Failed!"); + GE_CHK_STATUS_RET(mem_assigner.AssignMemory(is_loop_graph_, mem_offset_, zero_copy_mem_size_), + "Assign Memory Failed!"); GE_TIMESTAMP_END(AssignMemory, "GraphBuilder::AssignMemory"); // Compile single op in graph build stage @@ -611,6 +635,7 @@ Status ModelBuilder::BuildModelForGetTask(ge::Model &model) { // Refresh real streams and insert event nodes. 
GE_TIMESTAMP_START(RefreshRealStream); GE_CHK_STATUS_RET(stream_allocator.RefreshRealStream(stream_num_, event_num_), "RefreshRealStream failed."); + huge_streams_ = stream_allocator.GetHugeStreams(); GE_TIMESTAMP_END(RefreshRealStream, "GraphBuilder::RefreshRealStream"); GE_TIMESTAMP_START(MergeWeights); diff --git a/src/ge/graph/build/model_builder.h b/src/ge/graph/build/model_builder.h index 4bf03bdc..8f0d69b4 100644 --- a/src/ge/graph/build/model_builder.h +++ b/src/ge/graph/build/model_builder.h @@ -37,7 +37,7 @@ namespace ge { class ModelBuilder { public: - ModelBuilder(ge::ComputeGraphPtr whole_graph, const std::vector &subgraphs, + ModelBuilder(ge::ComputeGraphPtr whole_graph, const Graph2SubGraphInfoList &subgraphs, const std::map &stream_max_parallel_num, bool hcom_parallel, int mode = static_cast(domi::BuildMode::GEN_TASK_WITHOUT_FUSION)); @@ -61,6 +61,8 @@ class ModelBuilder { Status MergeWeights(); private: + bool SetInputConst(const OpDescPtr &op_desc, const NodePtr &src_node, size_t index, vector &is_input_const); + void SetInputIsConst(const ge::NodePtr &n); void SetModelVersion(ge::Model &model); @@ -85,11 +87,11 @@ class ModelBuilder { ge::ComputeGraphPtr compute_graph_; - const std::vector &subgraphs_; + const Graph2SubGraphInfoList &subgraphs_; int64_t stream_num_; - int64_t event_num_; + vector huge_streams_; uint32_t label_num_; @@ -100,6 +102,7 @@ class ModelBuilder { int build_mode_; size_t max_mem_offset_; + size_t zero_copy_mem_size_; TBEKernelStore tbe_kernel_store_; diff --git a/src/ge/graph/build/run_context.cc b/src/ge/graph/build/run_context.cc index d0fab3bd..f2a41271 100644 --- a/src/ge/graph/build/run_context.cc +++ b/src/ge/graph/build/run_context.cc @@ -17,7 +17,6 @@ #include "graph/build/run_context.h" #include "common/util.h" -#include "framework/common/op/attr_define.h" #include "framework/common/debug/ge_log.h" #include "graph/debug/ge_attr_define.h" @@ -170,7 +169,6 @@ Status RunContextUtil::CreateRunContext(Model &model, 
const ComputeGraphPtr &gra run_context_ = {rt_model_, nullptr, session_id, data_mem_size_, data_mem_base_, weight_mem_size_, weight_mem_base_, buffer, stream_list_, event_list_, label_list_}; - return SUCCESS; } diff --git a/src/ge/graph/build/stream_allocator.cc b/src/ge/graph/build/stream_allocator.cc index baa5e400..d1efa221 100644 --- a/src/ge/graph/build/stream_allocator.cc +++ b/src/ge/graph/build/stream_allocator.cc @@ -17,40 +17,60 @@ #include "graph/build/stream_allocator.h" #include #include "common/ge/ge_util.h" -#include "common/op/attr_define.h" #include "framework/common/debug/ge_log.h" #include "framework/common/fmk_error_codes.h" #include "framework/common/types.h" +#include "graph/ge_context.h" +#include "graph/common/omg_util.h" #include "graph/debug/ge_attr_define.h" #include "graph/utils/graph_utils.h" #include "init/gelib.h" - #include "graph/build/logical_stream_allocator.h" -using domi::ATTR_NAME_STREAM_LABEL; -using domi::HCOMALLGATHER; -using domi::HCOMALLREDUCE; -using domi::HCOMBROADCAST; -using domi::HCOMREDUCESCATTER; -using domi::RECV; -using domi::SEND; -using domi::STREAMACTIVE; -using domi::STREAMSWITCH; using std::map; using std::set; using std::string; using std::vector; namespace { -const int64_t kMaxNodeNumInNormalStream = 350; -const int64_t kMaxNodeNumInHcomStream = 5; - const uint32_t kMaxSwitchStreamNum = 1; +const int64_t kTaskNumPerNormalNode = 3; +const int64_t kTaskNumPerHcclNode = 200; +const char *const kTrueStr = "true"; +const char *const kFalseStr = "false"; + +inline bool HasContinuousStreamLabel(const ge::OpDescPtr &op_desc, std::string &continuous_stream_label) { + if (ge::AttrUtils::GetStr(op_desc, ge::ATTR_NAME_CONTINUOUS_STREAM_LABEL, continuous_stream_label)) { + GELOGD("node[%s] get continuous_stream_label %s", op_desc->GetName().c_str(), continuous_stream_label.c_str()); + return true; + } + return false; +} + +bool IsHcclOp(const string &op_type) { + const set hccl_op_types({ge::HCOMBROADCAST, 
ge::HCOMALLGATHER, ge::HCOMALLREDUCE, ge::HCOMREDUCESCATTER}); + return hccl_op_types.find(op_type) != hccl_op_types.end(); +} } // namespace namespace ge { +StreamAllocator::StreamAllocator(ComputeGraphPtr whole_graph, const Graph2SubGraphInfoList &subgraphs) + : whole_graph_(std::move(whole_graph)), subgraphs_(subgraphs) { + string single_stream_str; + (void)GetContext().GetOption(ENABLE_SINGLE_STREAM, single_stream_str); + + const set stream_options = {"", kTrueStr, kFalseStr}; + if (stream_options.find(single_stream_str) == stream_options.end()) { + GELOGW("The value %s of the %s option is invalid, it should be true or false.", single_stream_str.c_str(), + ENABLE_SINGLE_STREAM); + } + + enable_single_stream_ = (single_stream_str == kTrueStr) ? true : false; + GELOGI("Enable single stream: %s.", enable_single_stream_ ? kTrueStr : kFalseStr); +} + Status StreamAllocator::AssignLogicalStreams(const std::map &max_parallel_num, bool hcom_parallel) { - GELOGI("AssignLogicalStreams start."); + GELOGI("Assign logical streams start."); GE_CHECK_NOTNULL(whole_graph_); GraphUtils::DumpGEGraph(whole_graph_, "BeforeAssignedLogicalStreams"); GraphUtils::DumpGEGraphToOnnx(*whole_graph_, "BeforeAssignedLogicalStreams"); @@ -62,8 +82,10 @@ Status StreamAllocator::AssignLogicalStreams(const std::map &m } const map &scheduler_confs = gelib->DNNEngineManagerObj().GetSchedulers(); + LogicalStreamAllocator logical_allocator(scheduler_confs, max_parallel_num); + logical_allocator.EnableSingleStream(enable_single_stream_); + logical_allocator.EnableHcomParallel(hcom_parallel); - LogicalStreamAllocator logical_allocator(scheduler_confs, max_parallel_num, hcom_parallel); Status status = logical_allocator.Assign(whole_graph_, subgraphs_, stream_num_); if (status != SUCCESS) { GELOGE(status, "Assign logical streams failed."); @@ -72,7 +94,7 @@ Status StreamAllocator::AssignLogicalStreams(const std::map &m GraphUtils::DumpGEGraph(whole_graph_, "AfterAssignedLogicalStreams"); 
GraphUtils::DumpGEGraphToOnnx(*whole_graph_, "AfterAssignedLogicalStreams"); - GELOGI("AssignLogicalStreams success."); + GELOGI("Assign logical streams success."); return SUCCESS; } @@ -82,9 +104,16 @@ Status StreamAllocator::AssignLogicalStreams(const std::map &m Status StreamAllocator::RefreshRealStream(int64_t &stream_num, int64_t &event_num) { GELOGI("RefreshRealStream start."); GE_CHECK_NOTNULL(whole_graph_); - Status status = ActiveStreamsBySpecificLabels(); + + Status status = AssignSingleStream(); + if (status != SUCCESS) { + GELOGE(status, "AssignSingleStream failed!"); + return status; + } + + status = SetActiveStreamsByLabel(); if (status != SUCCESS) { - GELOGE(status, "ActiveStreams failed!"); + GELOGE(status, "SetActiveStreamsByLabel failed!"); return status; } @@ -100,15 +129,16 @@ Status StreamAllocator::RefreshRealStream(int64_t &stream_num, int64_t &event_nu return status; } - status = SplitStreams(); + vector> split_streams(stream_num_); + status = SplitStreams(split_streams); if (status != SUCCESS) { GELOGE(status, "SplitStreams failed!"); return status; } - status = ActiveStreamsForLoop(); + status = UpdateActiveStreams(split_streams); if (status != SUCCESS) { - GELOGE(status, "ActiveStreamsForLoop failed!"); + GELOGE(status, "UpdateActiveStreams failed!"); return status; } @@ -146,7 +176,7 @@ Status StreamAllocator::RefreshRealStream(int64_t &stream_num, int64_t &event_nu GELOGI("None of nodes need to assign stream, stream num is 0, it will cause error, so change it to 1"); stream_num_ = 1; } - GELOGI("stream_num_: %ld, event_num_: %u.", stream_num_, event_num_); + GELOGI("stream num: %ld, event num: %u.", stream_num_, event_num_); GELOGI("RefreshRealStream successfully."); stream_num = stream_num_; @@ -155,10 +185,57 @@ Status StreamAllocator::RefreshRealStream(int64_t &stream_num, int64_t &event_nu return SUCCESS; } +Status StreamAllocator::AssignSingleStream() { + if (!enable_single_stream_) { + return SUCCESS; + } + + if (stream_num_ > 1) 
{ + GELOGE(FAILED, "The number of ts streams is %ld, only one is supported.", stream_num_); + return FAILED; + } + + int64_t task_count = 0; + for (const NodePtr &node : whole_graph_->GetAllNodes()) { + string op_type = node->GetType(); + if (IsHcclOp(op_type)) { + task_count += kTaskNumPerHcclNode; + } else { + task_count += kTaskNumPerNormalNode; + } + } + + uint32_t max_normal_stream_count = 0; + uint32_t max_normal_task_count = 0; + Status status = GetMaxStreamAndTask(false, max_normal_stream_count, max_normal_task_count); + if (status != SUCCESS) { + GELOGE(status, "Get max task count of normal stream failed."); + return status; + } + + if (task_count > static_cast(max_normal_task_count)) { + uint32_t max_huge_stream_count = 0; + uint32_t max_huge_task_count = 0; + Status status = GetMaxStreamAndTask(true, max_huge_stream_count, max_huge_task_count); + if (status == SUCCESS) { + int64_t huge_stream = 0; + GELOGI("Use huge stream %ld.", huge_stream); + huge_streams_.emplace_back(huge_stream); + } else { + GELOGW( + "The estimated task count %ld is greater than the max count of normal stream," + " but the huge stream is not supported.", + task_count); + } + } + + return SUCCESS; +} + // Split the stream according to the maximum number of nodes in the stream. 
-Status StreamAllocator::SplitStreams() { - if (stream_num_ == 0) { - GELOGI("stream_num_ is 0"); +Status StreamAllocator::SplitStreams(vector> &split_streams) { + if (enable_single_stream_ || stream_num_ == 0) { + GELOGI("The single stream option is enabled or the number of streams is 0, no need to split streams."); return SUCCESS; } @@ -168,8 +245,10 @@ Status StreamAllocator::SplitStreams() { vector stream_node_num_vec(stream_num_); vector added_stream_num_vec(stream_num_); vector new_stream_id_vec(stream_num_); + map stream_continuous_2_node_num_map; + map> stream_continuous_2_nodes_map; + map> stream_2_nodes_map; vector pre_node_vec(stream_num_); - vector> split_streams(stream_num_); int64_t last_stream_id = stream_num_ - 1; for (auto i = 0; i <= last_stream_id; i++) { @@ -179,9 +258,16 @@ Status StreamAllocator::SplitStreams() { pre_node_vec[i] = nullptr; } + uint32_t max_stream_count = 0; + uint32_t max_task_count = 0; + GE_CHK_STATUS_RET(GetMaxStreamAndTask(false, max_stream_count, max_task_count), + "Get max stream and task count failed."); + for (const auto &cur_node : whole_graph_->GetAllNodes()) { - GE_CHECK_NOTNULL(cur_node->GetOpDesc()); - int64_t stream_id = cur_node->GetOpDesc()->GetStreamId(); + GE_CHECK_NOTNULL(cur_node); + auto op_desc = cur_node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + int64_t stream_id = op_desc->GetStreamId(); if (stream_id == kInvalidStream) { continue; } @@ -190,15 +276,20 @@ Status StreamAllocator::SplitStreams() { return FAILED; } stream_node_num_vec[stream_id]++; - + stream_2_nodes_map[stream_id].push_back(cur_node); // The maximum number of tasks per stream. 
- int64_t max_node_num_one_stream = kMaxNodeNumInNormalStream; - const string op_type = cur_node->GetType(); - if ((op_type == HCOMBROADCAST) || (op_type == HCOMALLGATHER) || (op_type == HCOMALLREDUCE) || - (op_type == HCOMREDUCESCATTER)) { - max_node_num_one_stream = kMaxNodeNumInHcomStream; + int64_t max_node_num_one_stream = GetMaxNodeNumPerStream(cur_node, max_task_count); + std::string continuous_stream_label; + if (HasContinuousStreamLabel(op_desc, continuous_stream_label)) { + stream_continuous_2_node_num_map[continuous_stream_label]++; + // return error + if (stream_continuous_2_node_num_map[continuous_stream_label] > max_node_num_one_stream) { + GELOGE(FAILED, "SplitStreams:node[%s] stream_id[%ld] continuous stream label[%s] unsatisfied ", + op_desc->GetName().c_str(), stream_id, continuous_stream_label.c_str()); + return FAILED; + } + stream_continuous_2_nodes_map[continuous_stream_label].push_back(cur_node); } - // Split the stream if it exceeds the maximum number of nodes in the stream. 
if (stream_node_num_vec[stream_id] > max_node_num_one_stream) { last_stream_id++; @@ -206,27 +297,50 @@ Status StreamAllocator::SplitStreams() { "stream_node_num_vec[%ld]= %ld > max_node_num_one_stream : %ld, " "It's time to split the stream, split newly-added stream id is %ld", stream_id, stream_node_num_vec[stream_id], max_node_num_one_stream, last_stream_id); - + NodePtr pre_node = pre_node_vec[stream_id]; stream_node_num_vec[stream_id] = 1; + // try spilt a new stream and move same continuous stream label nodes from this stream + bool not_use_cur = false; + NodePtr not_cur = nullptr; + std::string cur_continuous_stream_label; + if (HasContinuousStreamLabel(op_desc, cur_continuous_stream_label)) { + // get stored nodes + auto nodes = stream_continuous_2_nodes_map[cur_continuous_stream_label]; + GE_RETURN_WITH_LOG_IF_FALSE(!nodes.empty(), "split stream with continuous stream label %s failed", + cur_continuous_stream_label.c_str()); + for (const auto &node : nodes) { + auto stored_op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(stored_op_desc); + stored_op_desc->SetStreamId(last_stream_id); + stream_node_num_vec[stream_id]++; + } + not_use_cur = true; + not_cur = nodes.front(); + GE_CHECK_NOTNULL(not_cur); + GELOGI("split from first node %s with continuous stream label %s", not_cur->GetName().c_str(), + cur_continuous_stream_label.c_str()); + auto iter = std::find(stream_2_nodes_map[stream_id].begin(), stream_2_nodes_map[stream_id].end(), not_cur); + GE_RETURN_WITH_LOG_IF_FALSE( + (iter != stream_2_nodes_map[stream_id].end()) && (iter != stream_2_nodes_map[stream_id].begin()), + "split stream with continuous stream label %s failed", cur_continuous_stream_label.c_str()); + iter--; + pre_node = *iter; + } + added_stream_num_vec[stream_id]++; new_stream_id_vec[stream_id] = last_stream_id; split_streams[stream_id].emplace(last_stream_id); // Add the send/recv event to the first and last nodes of the split stream. 
- NodePtr pre_node = pre_node_vec[stream_id]; if (pre_node != nullptr) { - GELOGI("Add send event %u for node %s", event_num_, pre_node->GetName().c_str()); - GELOGI("Add recv event %u for node %s", event_num_, cur_node->GetName().c_str()); - AddSendEventId(pre_node, event_num_); - AddRecvEventId(cur_node, event_num_); - ++event_num_; + GE_CHK_STATUS_RET(AddEventId(pre_node, not_cur, cur_node, not_use_cur), "AddEventId failed."); } } /// If the split stream num is greater than 1, the node behind the same /// stream must reset the new stream id. if (added_stream_num_vec[stream_id] >= 1) { - cur_node->GetOpDesc()->SetStreamId(new_stream_id_vec[stream_id]); + op_desc->SetStreamId(new_stream_id_vec[stream_id]); } pre_node_vec[stream_id] = cur_node; @@ -235,40 +349,187 @@ Status StreamAllocator::SplitStreams() { if (last_stream_id >= 0) { stream_num_ = last_stream_id + 1; } - return UpdateActiveStreams(split_streams); + return SUCCESS; } -Status StreamAllocator::UpdateActiveStreams(vector> &split_streams) { - for (const auto &node : whole_graph_->GetAllNodes()) { - vector active_streams; - GE_CHECK_NOTNULL(node->GetOpDesc()); - if (AttrUtils::GetListInt(node->GetOpDesc(), ATTR_NAME_ACTIVE_STREAM_LIST, active_streams)) { - vector new_active_streams = active_streams; - for (const uint32_t logical_stream : active_streams) { - if (static_cast(logical_stream) >= split_streams.size()) { - GELOGE(FAILED, "logical stream is out of range."); - return FAILED; - } - const set &new_split_streams = split_streams[logical_stream]; - if (!new_split_streams.empty()) { - for (int64_t split_stream : new_split_streams) { - specific_activated_streams_.emplace(split_stream); - new_active_streams.emplace_back(static_cast(split_stream)); +Status StreamAllocator::AddEventId(const NodePtr &pre_node, const NodePtr ¬_cur, const NodePtr &cur_node, + bool not_use_cur) { + GELOGI("Add send event %u for node %s", event_num_, pre_node->GetName().c_str()); + AddSendEventId(pre_node, event_num_); + if 
(not_use_cur) { + GE_CHECK_NOTNULL(not_cur); + GELOGI("Add recv event %u for node %s", event_num_, not_cur->GetName().c_str()); + AddRecvEventId(not_cur, event_num_); + } else { + GELOGI("Add recv event %u for node %s", event_num_, cur_node->GetName().c_str()); + AddRecvEventId(cur_node, event_num_); + } + ++event_num_; + return SUCCESS; +} + +Status StreamAllocator::GetMaxStreamAndTask(bool huge_stream, uint32_t &max_stream_count, uint32_t &max_task_count) { + const char *buffer_optimize_on = std::getenv("BUFFER_OPTIMIZE_ON"); + if (buffer_optimize_on != nullptr) { + rtError_t ret = rtSetPlatformType(PLATFORM_MINI_V1); + if (ret != RT_ERROR_NONE) { + GELOGE(FAILED, "Get max stream and task count by rts failed."); + return FAILED; + } + } + + uint32_t stream_type = RT_NORMAL_STREAM; + if (huge_stream) { + stream_type = RT_HUGE_STREAM; + } + rtError_t ret = rtGetMaxStreamAndTask(stream_type, &max_stream_count, &max_task_count); + if (ret != RT_ERROR_NONE) { + GELOGE(FAILED, "Get max stream and task count by rts failed."); + return FAILED; + } + GELOGI("Allowed max stream count: %u, max task count per stream: %u.", max_stream_count, max_task_count); + + return SUCCESS; +} + +int64_t StreamAllocator::GetMaxNodeNumPerStream(const NodePtr &node, uint32_t max_task_count) { + int64_t max_node_num_one_stream = static_cast(max_task_count); + string op_type = node->GetType(); + if (IsHcclOp(op_type)) { + max_node_num_one_stream /= kTaskNumPerHcclNode; + } else { + max_node_num_one_stream /= kTaskNumPerNormalNode; + } + if (max_node_num_one_stream == 0) { + max_node_num_one_stream = 1; + } + + return max_node_num_one_stream; +} + +Status StreamAllocator::UpdateActiveStreams(const vector> &split_streams) { + UpdateLabelStreams(split_streams); + + for (auto &node : whole_graph_->GetAllNodes()) { + if ((node->GetType() == STREAMSWITCH) || (node->GetType() == STREAMSWITCHN)) { + if (InsertActiveNodesAfterSwitch(node) != SUCCESS) { + GELOGE(FAILED, "Insert active nodes after 
switch node failed."); + return FAILED; + } + } else { + vector active_streams; + GE_CHECK_NOTNULL(node->GetOpDesc()); + if (AttrUtils::GetListInt(node->GetOpDesc(), ATTR_NAME_ACTIVE_STREAM_LIST, active_streams)) { + vector new_active_streams = active_streams; + for (const uint32_t logical_stream : active_streams) { + if (static_cast(logical_stream) >= split_streams.size()) { + GELOGE(FAILED, "logical stream is out of range."); + return FAILED; + } + const set &new_split_streams = split_streams[logical_stream]; + if (!new_split_streams.empty()) { + for (int64_t split_stream : new_split_streams) { + new_active_streams.emplace_back(static_cast(split_stream)); + } } } + if (!AttrUtils::SetListInt(node->GetOpDesc(), ATTR_NAME_ACTIVE_STREAM_LIST, new_active_streams)) { + GELOGE(FAILED, "Set active streams for node %s failed.", node->GetName().c_str()); + return FAILED; + } } - if (!AttrUtils::SetListInt(node->GetOpDesc(), ATTR_NAME_ACTIVE_STREAM_LIST, new_active_streams)) { - GELOGE(FAILED, "UpdateActiveStreams failed, node name : (%s).", node->GetName().c_str()); - return FAILED; + } + } + + Status status = SetActiveStreamsForSubgraph(); + if (status != SUCCESS) { + GELOGE(status, "SetActiveStreamsForSubgraph failed!"); + return status; + } + + status = SetActiveStreamsForLoop(); + if (status != SUCCESS) { + GELOGE(status, "SetActiveStreamsForLoop failed!"); + return status; + } + + return SUCCESS; +} + +void StreamAllocator::UpdateLabelStreams(const vector> &split_streams) { + for (size_t i = 0; i < split_streams.size(); i++) { + auto &streams = split_streams[i]; + if (streams.empty()) { + continue; + } + if (specific_activated_streams_.count(static_cast(i)) > 0) { + specific_activated_streams_.insert(streams.begin(), streams.end()); + } + for (auto &labeled_stream : labeled_streams_) { + if (labeled_stream.second.count(static_cast(i)) > 0) { + labeled_stream.second.insert(streams.begin(), streams.end()); + break; } } } +} + +Status 
StreamAllocator::SetActiveStreamsForSubgraph() { + for (auto &subgraph : whole_graph_->GetAllSubgraphs()) { + GE_CHECK_NOTNULL(subgraph); + NodePtr first_active_node = nullptr; + + // Get all streams in subgraph. + set subgraph_streams; + for (auto &node : subgraph->GetDirectNode()) { + OpDescPtr op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + int64_t stream_id = op_desc->GetStreamId(); + if (stream_id != kInvalidStream) { + subgraph_streams.emplace(stream_id); + } + if (first_active_node == nullptr && node->GetType() == STREAMACTIVE) { + first_active_node = node; + } + } + + if (first_active_node == nullptr) { + continue; + } + + // Set active streams for StreamActive. + subgraph_streams.erase(first_active_node->GetOpDesc()->GetStreamId()); + + vector active_streams; + for (int64_t active_stream : subgraph_streams) { + active_streams.emplace_back(static_cast(active_stream)); + specific_activated_streams_.emplace(active_stream); + } + + if (!AttrUtils::SetListInt(first_active_node->GetOpDesc(), ATTR_NAME_ACTIVE_STREAM_LIST, active_streams)) { + GELOGE(FAILED, "Set active streams for node %s failed.", first_active_node->GetName().c_str()); + return FAILED; + } + + // Remove all events after StreamActive. 
+ vector send_events; + GetSendEventIdList(first_active_node, send_events); + + for (const auto &event_id : send_events) { + NodePtr recv_node = GetNodeFromRecvEventId(event_id); + GE_CHECK_NOTNULL(recv_node); + + RmvSendEventId(first_active_node, event_id); + RmvRecvEventId(recv_node, event_id); + GELOGI("Remove event %u between node %s and node %s", event_id, first_active_node->GetName().c_str(), + recv_node->GetName().c_str()); + } + } + return SUCCESS; } -Status StreamAllocator::ActiveStreamsBySpecificLabels() { - // > - map> labeled_streams; +Status StreamAllocator::SetActiveStreamsByLabel() { for (const auto &node : whole_graph_->GetAllNodes()) { OpDescPtr op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); @@ -276,7 +537,7 @@ Status StreamAllocator::ActiveStreamsBySpecificLabels() { if (AttrUtils::GetStr(op_desc, ATTR_NAME_STREAM_LABEL, stream_label) && !stream_label.empty()) { int64_t stream_id = op_desc->GetStreamId(); if (stream_id != kInvalidStream) { - labeled_streams[stream_label].emplace(stream_id); + labeled_streams_[stream_label].emplace(stream_id); } } } @@ -292,7 +553,7 @@ Status StreamAllocator::ActiveStreamsBySpecificLabels() { vector activated_stream_list; for (string &activated_label : activated_label_list) { specific_activated_labels_[activated_label].emplace(node); - for (int64_t activated_stream : labeled_streams[activated_label]) { + for (int64_t activated_stream : labeled_streams_[activated_label]) { activated_stream_list.push_back(static_cast(activated_stream)); specific_activated_streams_.emplace(activated_stream); specific_activated_streams_nodes_map_[activated_stream].emplace(node); @@ -307,7 +568,7 @@ Status StreamAllocator::ActiveStreamsBySpecificLabels() { return SUCCESS; } -Status StreamAllocator::ActiveStreamsForLoop() { +Status StreamAllocator::SetActiveStreamsForLoop() { vector loop_active_streams; for (int64_t stream_id = 0; stream_id < stream_num_; stream_id++) { if (specific_activated_streams_.count(stream_id) == 0) { @@ 
-404,6 +665,11 @@ Status StreamAllocator::InsertOneEventInTwoNodes(const NodePtr &cur_node, const return SUCCESS; } + if (next_stream_id == kInvalidStream) { + GELOGE(FAILED, "Stream id of next_node %s should not be %ld", next_node->GetName().c_str(), kInvalidStream); + return FAILED; + } + // No event needs to be inserted between the active node and the activated stream. string next_node_label; if (AttrUtils::GetStr(next_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, next_node_label) && !next_node_label.empty()) { @@ -683,7 +949,7 @@ Status StreamAllocator::InsertSyncEventNodes() { return FAILED); (void)AttrUtils::SetListStr(op_desc_ptr, ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, std::move(std::vector())); - NodePtr recv_node = whole_graph_->AddNode(op_desc_ptr); + NodePtr recv_node = node->GetOwnerComputeGraph()->AddNode(op_desc_ptr); GE_CHECK_NOTNULL(recv_node); GE_CHECK_NOTNULL(recv_node->GetOutControlAnchor()); Status status = GraphUtils::AddEdge(recv_node->GetOutControlAnchor(), node->GetInControlAnchor()); @@ -711,7 +977,7 @@ Status StreamAllocator::InsertSyncEventNodes() { return FAILED); (void)AttrUtils::SetListStr(op_desc_ptr, ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, std::move(std::vector())); - NodePtr send_node = whole_graph_->AddNode(op_desc_ptr); + NodePtr send_node = node->GetOwnerComputeGraph()->AddNode(op_desc_ptr); GE_CHECK_NOTNULL(send_node); GE_CHECK_NOTNULL(send_node->GetInControlAnchor()); Status status = GraphUtils::AddEdge(node->GetOutControlAnchor(), send_node->GetInControlAnchor()); @@ -725,11 +991,32 @@ Status StreamAllocator::InsertSyncEventNodes() { } } + Status status = ReorderEventNodes(); + if (status != SUCCESS) { + GELOGE(status, "Graph ReorderEventNodes failed"); + return status; + } + + return SUCCESS; +} + +Status StreamAllocator::ReorderEventNodes() const { Status status = whole_graph_->InsertEventNodes(); + GraphUtils::DumpGEGraph(whole_graph_, "AfterInsertEventNodes", true); + GraphUtils::DumpGEGraphToOnnx(*whole_graph_, 
"AfterInsertEventNodes"); if (status != SUCCESS) { - GELOGE(status, "whole_graph_->InsertEventNodes failed"); + GELOGE(status, "Whole graph InsertEventNodes failed"); return status; } + for (const auto &subgraph : whole_graph_->GetAllSubgraphs()) { + status = subgraph->InsertEventNodes(); + GraphUtils::DumpGEGraph(subgraph, "AfterInsertEventNodes_Subgraph"); + GraphUtils::DumpGEGraphToOnnx(*subgraph, "AfterInsertEventNodes_Subgraph"); + if (status != SUCCESS) { + GELOGE(status, "Subgraph %s InsertEventNodes failed", subgraph->GetName().c_str()); + return status; + } + } return SUCCESS; } @@ -955,4 +1242,149 @@ Status StreamAllocator::InsertActiveEntryStream(const std::vector &act return SUCCESS; } + +Status StreamAllocator::InsertActiveNodesAfterSwitch(NodePtr &switch_node) { + vector active_nodes; + if (InsertActiveNodesAfterSwitch(switch_node, active_nodes) != SUCCESS) { + GELOGE(FAILED, "Insert active nodes after node %s failed.", switch_node->GetName().c_str()); + return FAILED; + } + if (active_nodes.empty()) { + return SUCCESS; + } + vector stream_ids; + for (auto &active_node : active_nodes) { + GE_CHECK_NOTNULL(active_node->GetOpDesc()); + active_node->GetOpDesc()->SetStreamId(stream_num_); + stream_ids.emplace_back(stream_num_); + specific_activated_streams_.emplace(stream_num_); + stream_num_++; + } + auto op_desc = switch_node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + + if (!AttrUtils::SetListInt(op_desc, ATTR_NAME_ACTIVE_STREAM_LIST, stream_ids)) { + GELOGE(FAILED, "SetListInt failed."); + return FAILED; + } + + return SUCCESS; +} + +Status StreamAllocator::InsertActiveNodesAfterSwitch(NodePtr &switch_node, vector &active_nodes) { + GE_CHECK_NOTNULL(switch_node); + OpDescPtr switch_desc = switch_node->GetOpDesc(); + GE_CHECK_NOTNULL(switch_desc); + vector ori_active_label_list; + if (!AttrUtils::GetListStr(switch_desc, ATTR_NAME_ACTIVE_LABEL_LIST, ori_active_label_list) || + ori_active_label_list.empty()) { + GELOGE(INTERNAL_ERROR, "Get active label 
list of switch %s failed.", switch_node->GetName().c_str()); + return INTERNAL_ERROR; + } + + vector active_label_list; + vector added_active_nodes; + if (AddActiveNodes(switch_node, ori_active_label_list, active_label_list, added_active_nodes) != SUCCESS) { + GELOGE(FAILED, "Add active nodes after node %s failed.", switch_node->GetName().c_str()); + return FAILED; + } + + if (SetActiveLabelList(switch_node, active_label_list) != SUCCESS) { + GELOGE(FAILED, "set active label list failed"); + return FAILED; + } + + if (added_active_nodes.empty()) { + return SUCCESS; + } + + for (auto &active_node : added_active_nodes) { + GE_CHECK_NOTNULL(switch_node->GetOutControlAnchor()); + if (switch_node->GetOutControlAnchor()->LinkTo(active_node->GetInControlAnchor()) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Link %s to %s failed.", switch_node->GetName().c_str(), active_node->GetName().c_str()); + return FAILED; + } + active_nodes.emplace_back(active_node); + } + return SUCCESS; +} + +Status StreamAllocator::AddActiveNodes(NodePtr &switch_node, const vector &ori_active_label_list, + vector &active_label_list, vector &added_active_nodes) { + size_t label_num = ori_active_label_list.size(); + for (size_t i = 0; i < label_num; i++) { + const string &active_label = ori_active_label_list[i]; + if (labeled_streams_.find(active_label) == labeled_streams_.end()) { + GELOGE(FAILED, "can not find stream label %s", active_label.c_str()); + return FAILED; + } + if (labeled_streams_[active_label].size() <= 1) { + active_label_list.emplace_back(active_label); + continue; + } + + string name = switch_node->GetName() + "_" + STREAMACTIVE + "_" + std::to_string(i); + GELOGI("Create StreamActive op %s after node %s.", name.c_str(), switch_node->GetName().c_str()); + OpDescPtr active_op_desc = MakeShared(name, STREAMACTIVE); + GE_CHECK_NOTNULL(active_op_desc); + NodePtr active_node = whole_graph_->AddNode(active_op_desc); + GE_CHECK_NOTNULL(active_node); + + for (NodePtr &node : 
switch_node->GetOutControlNodes()) { + OpDescPtr op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + string stream_label; + // If GetStr failed, stream_label is empty. + (void)AttrUtils::GetStr(op_desc, ATTR_NAME_STREAM_LABEL, stream_label); + if (stream_label != active_label) { + continue; + } + GE_CHECK_NOTNULL(switch_node->GetOutControlAnchor()); + if (switch_node->GetOutControlAnchor()->Unlink(node->GetInControlAnchor()) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Unlink %s to %s failed.", switch_node->GetName().c_str(), node->GetName().c_str()); + return FAILED; + } + GE_CHECK_NOTNULL(active_node->GetOutControlAnchor()); + if (active_node->GetOutControlAnchor()->LinkTo(node->GetInControlAnchor()) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Link %s to %s failed.", active_node->GetName().c_str(), node->GetName().c_str()); + return FAILED; + } + } + + if (SetSwitchBranchNodeLabel(active_node, name) != SUCCESS) { + GELOGE(FAILED, "Set switch branch node label failed."); + return FAILED; + } + if (SetStreamLabel(active_node, name) != SUCCESS) { + GELOGE(FAILED, "Set stream label failed."); + return FAILED; + } + if (SetActiveLabelList(active_node, {active_label}) != SUCCESS) { + GELOGE(FAILED, "Set active label list failed."); + return FAILED; + } + if (SetActiveStreamList(active_node, active_label) != SUCCESS) { + GELOGE(FAILED, "Set active stream list failed."); + return FAILED; + } + + added_active_nodes.emplace_back(active_node); + active_label_list.emplace_back(name); + } + return SUCCESS; +} + +Status StreamAllocator::SetActiveStreamList(NodePtr &active_node, const string &active_label) { + if (labeled_streams_.find(active_label) == labeled_streams_.end()) { + GELOGE(FAILED, "Can not find stream label %s.", active_label.c_str()); + return FAILED; + } + set &streams = labeled_streams_[active_label]; + vector active_streams(streams.begin(), streams.end()); + if (!AttrUtils::SetListInt(active_node->GetOpDesc(), ATTR_NAME_ACTIVE_STREAM_LIST, active_streams)) { + 
GELOGE(FAILED, "SetListInt of %s failed.", ATTR_NAME_ACTIVE_STREAM_LIST.c_str()); + return FAILED; + } + return SUCCESS; +} } // namespace ge diff --git a/src/ge/graph/build/stream_allocator.h b/src/ge/graph/build/stream_allocator.h index e3901205..ea6d08a3 100644 --- a/src/ge/graph/build/stream_allocator.h +++ b/src/ge/graph/build/stream_allocator.h @@ -30,21 +30,27 @@ namespace ge { class StreamAllocator { public: - StreamAllocator(ComputeGraphPtr whole_graph, const std::vector &subgraphs) - : whole_graph_(std::move(whole_graph)), subgraphs_(subgraphs) {} + StreamAllocator(ComputeGraphPtr whole_graph, const Graph2SubGraphInfoList &subgraphs); StreamAllocator(const StreamAllocator &) = delete; StreamAllocator &operator=(const StreamAllocator &) = delete; ~StreamAllocator() = default; Status AssignLogicalStreams(const std::map &max_parallel_num, bool hcom_parallel); Status RefreshRealStream(int64_t &stream_num, int64_t &event_num); + const vector &GetHugeStreams() const { return huge_streams_; } private: - Status SplitStreams(); - Status ActiveStreamsBySpecificLabels(); - Status UpdateActiveStreams(std::vector> &splited_streams); - Status ActiveStreamsForLoop(); + Status SplitStreams(std::vector> &split_streams); + + Status AssignSingleStream(); + Status SetActiveStreamsByLabel(); + Status UpdateActiveStreams(const std::vector> &splited_streams); + void UpdateLabelStreams(const std::vector> &split_streams); + Status SetActiveStreamsForSubgraph(); + Status SetActiveStreamsForLoop(); Status CheckStreamActived() const; + Status GetMaxStreamAndTask(bool huge_stream, uint32_t &max_stream_count, uint32_t &max_task_count); + int64_t GetMaxNodeNumPerStream(const NodePtr &node, uint32_t max_node_num_one_stream); Status InsertSyncEvents(); Status InsertOneEventInTwoNodes(const NodePtr &cur_node_ptr, const NodePtr &next_node_ptr); @@ -56,11 +62,19 @@ class StreamAllocator { Status RefreshContinuousEvents(); Status InsertSyncEventNodes(); + Status ReorderEventNodes() const; + 
+ Status InsertActiveNodesAfterSwitch(NodePtr &switch_node); + Status InsertActiveNodesAfterSwitch(NodePtr &switch_nodes, std::vector &switch_active_nodes); + Status SetActiveStreamList(NodePtr &active_node, const std::string &active_label); + Status AddActiveNodes(NodePtr &switch_node, const std::vector &ori_active_label_list, + std::vector &active_label_list, std::vector &added_active_nodes); Status AddActiveEntryStream(); Status CollectDeactiveStream(const OpDescPtr &op_desc, std::set &deactive_streams) const; Status InsertActiveEntryStream(const std::vector &active_streams, int64_t stream_id); + Status AddEventId(const NodePtr &pre_node, const NodePtr ¬_cur, const NodePtr &cur_node, bool not_use_cur); void AddSendEventId(const NodePtr &node, uint32_t event_id); void AddRecvEventId(const NodePtr &node, uint32_t event_id); void RmvSendEventId(const NodePtr &node, uint32_t event_id); @@ -75,11 +89,15 @@ class StreamAllocator { bool IsRecvNodeActivatedBySendNode(const NodePtr &send_node_ptr, const NodePtr &recv_node_ptr) const; ComputeGraphPtr whole_graph_; - const std::vector &subgraphs_; + const Graph2SubGraphInfoList &subgraphs_; int64_t stream_num_{0}; uint32_t event_num_{0}; + bool enable_single_stream_{false}; + vector huge_streams_; + // > + std::map> labeled_streams_; std::map> specific_activated_labels_; std::set specific_activated_streams_; std::map> specific_activated_streams_nodes_map_; diff --git a/src/ge/graph/build/stream_graph_optimizer.cc b/src/ge/graph/build/stream_graph_optimizer.cc index 5af54783..204a98b2 100644 --- a/src/ge/graph/build/stream_graph_optimizer.cc +++ b/src/ge/graph/build/stream_graph_optimizer.cc @@ -17,7 +17,6 @@ #include "stream_graph_optimizer.h" #include "common/util.h" #include "framework/common/debug/ge_log.h" - #include "graph/utils/node_utils.h" #include "graph/utils/tensor_utils.h" #include "init/gelib.h" @@ -25,24 +24,26 @@ using std::vector; namespace { -static const int64_t kInvalidStream = -1; +const int64_t 
kInvalidStream = -1; } // namespace namespace ge { StreamGraphOptimizer::~StreamGraphOptimizer() {} -void StreamGraphOptimizer::RefreshNodeId(const ComputeGraphPtr &comp_graph, vector &subgraph_infos) { +void StreamGraphOptimizer::RefreshNodeId(const ComputeGraphPtr &comp_graph, Graph2SubGraphInfoList &subgraph_map) { size_t node_size = comp_graph->GetDirectNodesSize(); GELOGI("Refresh placeholder and end nodeId start from node num: %zu", node_size); - for (const auto &sub_graph_info : subgraph_infos) { - ComputeGraphPtr sub_graph = sub_graph_info->GetSubGraph(); - if (sub_graph == nullptr) { - continue; - } - for (ge::NodePtr &node : sub_graph->GetDirectNode()) { - GE_CHECK_NOTNULL_EXEC(node->GetOpDesc(), return ); - if ((node->GetType() == domi::END) || (node->GetType() == domi::PLACEHOLDER)) { - node->GetOpDesc()->SetId(static_cast(node_size)); - node_size++; + for (const auto &subgraph_pair : subgraph_map) { + for (const auto &subgraph_info : subgraph_pair.second) { + ComputeGraphPtr subgraph = subgraph_info->GetSubGraph(); + if (subgraph == nullptr) { + continue; + } + for (ge::NodePtr &node : subgraph->GetDirectNode()) { + GE_CHECK_NOTNULL_EXEC(node->GetOpDesc(), return ); + if ((node->GetType() == END) || (node->GetType() == PLACEHOLDER)) { + node->GetOpDesc()->SetId(static_cast(node_size)); + node_size++; + } } } } @@ -72,67 +73,75 @@ bool StreamGraphOptimizer::IsSameStreamId(const ComputeGraphPtr &comp_graph) { } Status StreamGraphOptimizer::OptimizeStreamedSubGraph(const ComputeGraphPtr &comp_graph, - vector &subgraph_infos, + Graph2SubGraphInfoList &subgraph_map, struct RunContext &run_context) { - Status ret = SUCCESS; - GELOGI("Begin to Get optimize streamed subgraph."); + GELOGI("Optimize streamed subgraph start."); - RefreshNodeId(comp_graph, subgraph_infos); + RefreshNodeId(comp_graph, subgraph_map); std::shared_ptr instance = ge::GELib::GetInstance(); GE_CHECK_NOTNULL(instance); - for (auto &sub_graph_info : subgraph_infos) { - ComputeGraphPtr 
sub_graph = sub_graph_info->GetSubGraph(); - if (sub_graph == nullptr) { - continue; - } + for (const auto &subgraph_pair : subgraph_map) { + for (const auto &subgraph_info : subgraph_pair.second) { + ComputeGraphPtr subgraph = subgraph_info->GetSubGraph(); + GE_CHECK_NOTNULL(subgraph); - std::string engine_name = sub_graph_info->GetEngineName(); + GELOGI("Optimize subgraph %s", subgraph->GetName().c_str()); - vector graph_optimizers; - if (instance->DNNEngineManagerObj().IsEngineRegistered(engine_name)) { - instance->OpsKernelManagerObj().GetGraphOptimizerByEngine(engine_name, graph_optimizers); - GELOGI("Subgraph: %s start optimize streamed graph. engineName: %s, subgraph num: %zu, graph Optimizer num: %zu.", - sub_graph->GetName().c_str(), engine_name.c_str(), subgraph_infos.size(), graph_optimizers.size()); + std::string engine_name = subgraph_info->GetEngineName(); - auto nodes = sub_graph->GetDirectNode(); - if (nodes.empty()) { - continue; - } - if (!IsSameStreamId(sub_graph)) { - GELOGI("There are more than one stream in subgraph %s", sub_graph->GetName().c_str()); - continue; - } - OpDescPtr op_desc = nodes.at(0)->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - int64_t stream_id = op_desc->GetStreamId(); - if (static_cast(stream_id) >= run_context.graphStreamList.size()) { - GELOGE(FAILED, "stream_id is bigger than run_context.graphStreamList.size()"); - return FAILED; - } - run_context.stream = run_context.graphStreamList[stream_id]; - GELOGD("Subgraph has same stream id, subgraph: %s, engine_name: %s, stream_id: %ld, rtstream: %lu.", - sub_graph->GetName().c_str(), engine_name.c_str(), stream_id, - static_cast(reinterpret_cast(run_context.stream))); - for (auto iter = graph_optimizers.begin(); iter != graph_optimizers.end(); ++iter) { - GE_CHECK_NOTNULL(*iter); - ret = (*iter)->OptimizeStreamGraph(*sub_graph, run_context); - if (ret != SUCCESS) { - GELOGE(ret, - "[optimizeStreamedSubGraph]: optimize streamed subgraph failed, subgraph: %s, engine_name: %s, 
graph " - "Optimizer num: %zu, ret: %u", - sub_graph->GetName().c_str(), engine_name.c_str(), graph_optimizers.size(), ret); - return ret; + vector graph_optimizers; + if (instance->DNNEngineManagerObj().IsEngineRegistered(engine_name)) { + instance->OpsKernelManagerObj().GetGraphOptimizerByEngine(engine_name, graph_optimizers); + GELOGI("Subgraph: %s start optimize streamed graph. engineName: %s, graph Optimizer num: %zu.", + subgraph->GetName().c_str(), engine_name.c_str(), graph_optimizers.size()); + + auto nodes = subgraph->GetDirectNode(); + if (nodes.empty()) { + continue; + } + + const char *buffer_optimize_on = std::getenv("BUFFER_OPTIMIZE_ON"); + if (buffer_optimize_on == nullptr) { + if (!IsSameStreamId(subgraph)) { + GELOGI("There are more than one stream in subgraph %s", subgraph->GetName().c_str()); + continue; + } + } + OpDescPtr op_desc = nodes.at(0)->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + int64_t stream_id = op_desc->GetStreamId(); + if (static_cast(stream_id) >= run_context.graphStreamList.size()) { + GELOGE(FAILED, "stream_id %ld is bigger than run_context.graphStreamList.size() %zu", stream_id, + run_context.graphStreamList.size()); + return FAILED; + } + run_context.stream = run_context.graphStreamList[stream_id]; + GELOGD("Subgraph has same stream id, subgraph: %s, engine_name: %s, stream_id: %ld, rtstream: %lu.", + subgraph->GetName().c_str(), engine_name.c_str(), stream_id, + static_cast(reinterpret_cast(run_context.stream))); + for (auto iter = graph_optimizers.begin(); iter != graph_optimizers.end(); ++iter) { + GE_CHECK_NOTNULL(*iter); + Status ret = (*iter)->OptimizeStreamGraph(*subgraph, run_context); + if (ret != SUCCESS) { + GELOGE( + ret, + "[optimizeStreamedSubGraph]: optimize streamed subgraph failed, subgraph: %s, engine_name: %s, graph " + "Optimizer num: %zu, ret: %u", + subgraph->GetName().c_str(), engine_name.c_str(), graph_optimizers.size(), ret); + return ret; + } + GELOGI( + "[optimizeStreamedSubGraph]: optimize 
streamed subgraph success, subgraph: %s, engine_name: %s, graph " + "Optimizer num: %zu!", + subgraph->GetName().c_str(), engine_name.c_str(), graph_optimizers.size()); } - GELOGI( - "[optimizeStreamedSubGraph]: optimize streamed subgraph success, subgraph: %s, engine_name: %s, graph " - "Optimizer num: %zu!", - sub_graph->GetName().c_str(), engine_name.c_str(), graph_optimizers.size()); } } } - return ret; + GELOGI("Optimize streamed subgraph success."); + return SUCCESS; } } // namespace ge diff --git a/src/ge/graph/build/stream_graph_optimizer.h b/src/ge/graph/build/stream_graph_optimizer.h index a65f95f2..3133d32d 100644 --- a/src/ge/graph/build/stream_graph_optimizer.h +++ b/src/ge/graph/build/stream_graph_optimizer.h @@ -35,11 +35,11 @@ class StreamGraphOptimizer { virtual ~StreamGraphOptimizer(); - Status OptimizeStreamedSubGraph(const ComputeGraphPtr &comp_graph, std::vector &subgraph_ptr_list, + Status OptimizeStreamedSubGraph(const ComputeGraphPtr &comp_graph, Graph2SubGraphInfoList &subgraph_map, struct RunContext &run_context); private: - void RefreshNodeId(const ComputeGraphPtr &comp_graph, std::vector &subgraph_ptr_list); + void RefreshNodeId(const ComputeGraphPtr &comp_graph, Graph2SubGraphInfoList &subgraph_map); bool IsSameStreamId(const ComputeGraphPtr &comp_graph); }; diff --git a/src/ge/graph/build/task_generator.cc b/src/ge/graph/build/task_generator.cc index 2266f137..a6bc6128 100644 --- a/src/ge/graph/build/task_generator.cc +++ b/src/ge/graph/build/task_generator.cc @@ -17,11 +17,11 @@ #include "graph/build/task_generator.h" #include #include -#include "common/util.h" #include "common/types.h" +#include "common/util.h" +#include "common/profiling/profiling_manager.h" #include "framework/common/debug/ge_log.h" -#include "framework/common/op/attr_define.h" #include "graph/debug/ge_attr_define.h" #include "graph/ge_context.h" #include "graph/manager/graph_var_manager.h" @@ -31,16 +31,11 @@ #include "graph/utils/type_utils.h" #include 
"init/gelib.h" -using domi::CONSTANT; -using domi::HCOMALLREDUCE; using domi::LogTimeStampDef; -using domi::MODEL_ATTR_TASK_GEN_BASE_ADDR; -using domi::MODEL_ATTR_TASK_GEN_WEIGHT_ADDR; -using domi::MODEL_ATTR_TASKS; using domi::ModelTaskDef; -using domi::NETOUTPUT; using domi::TaskDef; using std::map; +using std::set; using std::string; using std::vector; @@ -58,7 +53,7 @@ const uint64_t kProfilingBpEndLogid = 2; const uint64_t kProfilingArStartLogid = 3; const uint64_t kProfilingArEndLogid = 4; const uint64_t kProfilingIterEndLogid = 255; -const int64_t kMaxNodeNumInNormalStream = 350; +const int64_t kHashFactor = 100000; const int64_t kInvalidGroupId = -1; } // namespace namespace ge { @@ -188,7 +183,7 @@ Status TaskGenerator::UpdateOpIsVarAttr(const OpDescPtr &op_desc, uint64_t sessi return SUCCESS; } -Status TaskGenerator::SaveL1fusionNodes(map> &l1_fusion_nodes, ComputeGraphPtr &graph) { +Status TaskGenerator::SaveFusionNodes(map> &fusion_nodes, ComputeGraphPtr &graph) { std::map nodes_with_group_attr; for (auto &node : graph->GetAllNodes()) { OpDescPtr op_desc = node->GetOpDesc(); @@ -196,17 +191,19 @@ Status TaskGenerator::SaveL1fusionNodes(map> &l1_f int64_t group_id = kInvalidGroupId; string name = node->GetName(); string type = node->GetType(); - // For l1 fusion ddb pass, task def must be continuous. + // For fusion ddb pass, task def must be continuous. 
// Part1: store // If op_desc have this tag, store it in the map firstly, // call the elements in the map GenerateTask at last - if (ge::AttrUtils::GetInt(op_desc, ATTR_NAME_L1_FUSION_GROUP_ID, group_id)) { + // l1 and l2 is for now + if (ge::AttrUtils::GetInt(op_desc, ATTR_NAME_L1_FUSION_GROUP_ID, group_id) || + ge::AttrUtils::GetInt(op_desc, ATTR_NAME_L2_FUSION_GROUP_ID, group_id)) { auto stream_id = op_desc->GetStreamId(); - auto group_key = group_id + stream_id * kMaxNodeNumInNormalStream; - (void)ge::AttrUtils::SetInt(op_desc, ATTR_NAME_L1_FUSION_GROUP_KEY, group_key); - GELOGI("L1Fusion: store node[name:%s(%s), group id:%ld, group key:%ld, stream_id:%ld] task.", name.c_str(), + auto group_key = group_id + stream_id * kHashFactor; + (void)ge::AttrUtils::SetInt(op_desc, ATTR_NAME_FUSION_GROUP_KEY, group_key); + GELOGD("Fusion: store node[name:%s(%s), group id:%ld, group key:%ld, stream_id:%ld] task.", name.c_str(), type.c_str(), group_id, group_key, op_desc->GetStreamId()); - l1_fusion_nodes[group_key].push_back(node); + fusion_nodes[group_key].push_back(node); nodes_with_group_attr.insert({node, group_id}); } @@ -228,14 +225,12 @@ Status TaskGenerator::SaveL1fusionNodes(map> &l1_f if (call_check) { auto input_group_id = *input_group_ids.begin(); if (group_id != input_group_id) { - GELOGE(INTERNAL_ERROR, - "L1Fusion: node[name:%s(%s) with group id:%ld and diff from it's input nodes's group id:%ld ", + GELOGW("Fusion: node[name:%s(%s) with group id:%ld and diff from it's input nodes's group id:%ld ", name.c_str(), type.c_str(), group_id, input_group_id); - return INTERNAL_ERROR; } } } - GELOGI("L1Fusion: get fusion group numbers [%zu].", l1_fusion_nodes.size()); + GELOGI("Fusion: get fusion group numbers [%zu].", fusion_nodes.size()); return SUCCESS; } @@ -246,28 +241,26 @@ Status TaskGenerator::GenerateTask(RunContext &run_context, ComputeGraphPtr &gra GELOGE(GE_CLI_GE_NOT_INITIALIZED, "GenerateTask failed."); return GE_CLI_GE_NOT_INITIALIZED; } - 
GE_CHK_STATUS_RET(MarkFirstAndLastNode(graph), "MarkFirstAndLastNode failed."); - ProfilingPoint ppoint; - vector ar_ppoint; - GE_CHK_STATUS_RET(FindProfilingTaskIndex(graph, ppoint, ar_ppoint)); + GE_CHK_STATUS_RET(MarkNodeAndSetIndex(graph), "MarkNodeAndSetIndex failed."); + ProfilingPoint profiling_point; + vector all_reduce_nodes; + GE_CHK_STATUS_RET(FindProfilingTaskIndex(graph, profiling_point, all_reduce_nodes)); const OpsKernelManager &ops_kernel_manager = ge_lib->OpsKernelManagerObj(); GE_TIMESTAMP_CALLNUM_START(GenerateTask); - // map store l1 fusion nodes - map> l1_fusion_nodes; - string is_l1_fusion_enable = "false"; - graphStatus ret = ge::GetContext().GetOption("ge.l1Fusion", is_l1_fusion_enable); - if ((ret == GRAPH_SUCCESS) && (is_l1_fusion_enable == "true")) { - GE_CHK_STATUS_RET(SaveL1fusionNodes(l1_fusion_nodes, graph)); - } - std::unordered_set l1_fusion_nodes_seen; - int64_t group_id; + // map store fusion nodes + map> fusion_nodes; + const char *buffer_optimize_on = std::getenv("BUFFER_OPTIMIZE_ON"); + if (buffer_optimize_on != nullptr) { + GE_CHK_STATUS_RET(SaveFusionNodes(fusion_nodes, graph)); + } + std::unordered_set fusion_nodes_seen; + int64_t group_key; uint32_t node_index = 0; for (auto &node : graph->GetAllNodes()) { OpDescPtr op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); - op_desc->SetId(node_index); node_index++; string name = node->GetName(); string type = node->GetType(); @@ -279,17 +272,16 @@ Status TaskGenerator::GenerateTask(RunContext &run_context, ComputeGraphPtr &gra GE_CHK_STATUS_RET(UpdateOpIsVarAttr(op_desc, graph->GetSessionID())); string op_kernel_lib_name = op_desc->GetOpKernelLibName(); - // For l1 fusion ddb pass, task def must be continuous. + // For fusion ddb pass, task def must be continuous. 
// Part2: Call - auto l1_fusion_task_info = - L1FusionTaskInfo{run_context, graph, node, op_desc, node_index, ge_lib, - ops_kernel_manager, task_def_list, op_name_map, ppoint, ar_ppoint}; - GE_CHK_STATUS_RET(GenerateTaskForL1FusionNode(l1_fusion_task_info, l1_fusion_nodes, l1_fusion_nodes_seen), - "Call GenerateTaskForL1FusionNode node:%s(%s) failed", name.c_str(), type.c_str()); + auto fusion_task_info = + FusionTaskInfo{run_context, graph, node, op_desc, node_index, ge_lib, + ops_kernel_manager, task_def_list, op_name_map, profiling_point, all_reduce_nodes}; + GE_CHK_STATUS_RET(GenerateTaskForFusionNode(fusion_task_info, fusion_nodes, fusion_nodes_seen), + "Call GenerateTaskForFusionNode node:%s(%s) failed", name.c_str(), type.c_str()); // continue directly - if (ge::AttrUtils::GetInt(op_desc, ATTR_NAME_L1_FUSION_GROUP_ID, group_id)) { - GELOGI("L1Fusion not %s to generate node[name:%s(%s) task again.", op_kernel_lib_name.c_str(), name.c_str(), - type.c_str()); + if (ge::AttrUtils::GetInt(op_desc, ATTR_NAME_FUSION_GROUP_KEY, group_key)) { + GELOGI("Fusion node[name:%s, type:%s] do not need generate task again.", name.c_str(), type.c_str()); continue; } if (op_kernel_lib_name.empty()) { @@ -315,7 +307,7 @@ Status TaskGenerator::GenerateTask(RunContext &run_context, ComputeGraphPtr &gra // Profiling task size_t task_list_size_before = task_def_list.size(); - GE_CHK_STATUS_RET(InsertProfilingTaskBefore(op_desc, ppoint, ar_ppoint, node_index, task_def_list)); + GE_CHK_STATUS_RET(InsertProfilingTaskBefore(op_desc, profiling_point, all_reduce_nodes, node_index, task_def_list)); run_context.stream = run_context.graphStreamList[stream_id]; GELOGD("Call %s to generate node[name:%s(%s), id:%ld, stream_id:%ld] task.", op_kernel_lib_name.c_str(), name.c_str(), type.c_str(), op_id, stream_id); @@ -328,7 +320,7 @@ Status TaskGenerator::GenerateTask(RunContext &run_context, ComputeGraphPtr &gra return ret; } // Profiling task - 
GE_CHK_STATUS_RET(InsertProfilingTaskAfter(op_desc, ppoint, ar_ppoint, node_index, task_def_list)); + GE_CHK_STATUS_RET(InsertProfilingTaskAfter(op_desc, profiling_point, all_reduce_nodes, node_index, task_def_list)); size_t task_list_size_after = task_def_list.size(); // If tasks is reduced @@ -358,11 +350,11 @@ Status TaskGenerator::GenerateTask(RunContext &run_context, ComputeGraphPtr &gra return SUCCESS; } -Status TaskGenerator::GenerateTaskForL1FusionNode(L1FusionTaskInfo &fusion_task_info, - std::map> &l1_fusion_nodes, - std::unordered_set &l1_fusion_nodes_seen) { +Status TaskGenerator::GenerateTaskForFusionNode(FusionTaskInfo &fusion_task_info, + std::map> &fusion_nodes, + std::unordered_set &fusion_nodes_seen) { Status ret = SUCCESS; - int64_t group_id; + int64_t group_key; auto &run_context = fusion_task_info.run_context; auto &graph = fusion_task_info.graph; auto &node = fusion_task_info.node; @@ -371,15 +363,13 @@ Status TaskGenerator::GenerateTaskForL1FusionNode(L1FusionTaskInfo &fusion_task_ const auto &ops_kernel_manager = fusion_task_info.ops_kernel_manager; auto &task_def_list = fusion_task_info.task_def_list; auto &op_name_map = fusion_task_info.op_name_map; - auto &ppoint = fusion_task_info.ppoint; - auto &ar_ppoint = fusion_task_info.ar_ppoint; - auto stream_id = fusion_op_desc->GetStreamId(); - // If op_desc have this attr, call nodes with same group id in a stream together - if (ge::AttrUtils::GetInt(fusion_op_desc, ATTR_NAME_L1_FUSION_GROUP_ID, group_id) && - (l1_fusion_nodes_seen.count(node.get()) == 0)) { - auto group_key = group_id + stream_id * kMaxNodeNumInNormalStream; - GELOGI("L1Fusion: start fusion group index[%ld], nodes size[%ld].", group_key, l1_fusion_nodes[group_key].size()); - for (auto &fusion_node : l1_fusion_nodes[group_key]) { + auto &profiling_point = fusion_task_info.profiling_point; + auto &all_reduce_nodes = fusion_task_info.all_reduce_nodes; + // If op_desc have this attr, call nodes with same group key in a stream 
together + if (ge::AttrUtils::GetInt(fusion_op_desc, ATTR_NAME_FUSION_GROUP_KEY, group_key) && + (fusion_nodes_seen.count(node.get()) == 0)) { + GELOGI("Fusion: start fusion group index[%ld], nodes size[%zu].", group_key, fusion_nodes[group_key].size()); + for (auto &fusion_node : fusion_nodes[group_key]) { OpDescPtr op_desc = fusion_node->GetOpDesc(); UpdateOpIsVarAttr(op_desc, graph->GetSessionID()); @@ -387,7 +377,7 @@ Status TaskGenerator::GenerateTaskForL1FusionNode(L1FusionTaskInfo &fusion_task_ std::string fusion_node_type = fusion_node->GetType(); std::string op_kernel_lib_name = op_desc->GetOpKernelLibName(); if (op_kernel_lib_name.empty()) { - GELOGI("L1Fusion: fusion_node[name:%s(%s)] task no need to generate task.", fusion_node_name.c_str(), + GELOGI("Fusion: fusion_node[name:%s(%s)] task no need to generate task.", fusion_node_name.c_str(), fusion_node_type.c_str()); continue; } @@ -395,14 +385,14 @@ Status TaskGenerator::GenerateTaskForL1FusionNode(L1FusionTaskInfo &fusion_task_ size_t task_list_size_before = task_def_list.size(); OpsKernelInfoStorePtr kernel_info_store = ops_kernel_manager.GetOpsKernelInfoStore(op_kernel_lib_name); if (kernel_info_store == nullptr) { - GELOGE(INTERNAL_ERROR, "L1Fusion: No ops kernel store found. fusion_node:%s(%s), op_kernel_lib_name=%s.", + GELOGE(INTERNAL_ERROR, "Fusion: No ops kernel store found. 
fusion_node:%s(%s), op_kernel_lib_name=%s.", fusion_node_name.c_str(), fusion_node_type.c_str(), op_kernel_lib_name.c_str()); return INTERNAL_ERROR; } ret = UpdateAnchorStatus(fusion_node); if (ret != SUCCESS) { - GELOGE(ret, "L1Fusion: Call UpdateAnchorStatus fusion_node:%s(%s) failed", fusion_node_name.c_str(), + GELOGE(ret, "Fusion: Call UpdateAnchorStatus fusion_node:%s(%s) failed", fusion_node_name.c_str(), fusion_node_type.c_str()); return ret; } @@ -410,30 +400,30 @@ Status TaskGenerator::GenerateTaskForL1FusionNode(L1FusionTaskInfo &fusion_task_ int64_t op_id = op_desc->GetId(); int64_t stream_id = op_desc->GetStreamId(); if (stream_id < 0 || stream_id >= (int64_t)run_context.graphStreamList.size()) { - GELOGE(INTERNAL_ERROR, "L1Fusion: fusion_node[name:%s(%s), id:%ld] stream id is invalid, stream list size=%zu", + GELOGE(INTERNAL_ERROR, "Fusion: fusion_node[name:%s(%s), id:%ld] stream id is invalid, stream list size=%zu", fusion_node_name.c_str(), fusion_node_type.c_str(), op_id, run_context.graphStreamList.size()); return INTERNAL_ERROR; } // profiling task - (void)InsertProfilingTaskBefore(op_desc, ppoint, ar_ppoint, node_index, task_def_list); + (void)InsertProfilingTaskBefore(op_desc, profiling_point, all_reduce_nodes, node_index, task_def_list); run_context.stream = run_context.graphStreamList[stream_id]; - GELOGI("L1Fusion: Call %s to generate fusion_node:[fusion_node_name:%s(%s), id:%ld, stream_id:%ld] task.", + GELOGI("Fusion: Call %s to generate fusion_node:[fusion_node_name:%s(%s), id:%ld, stream_id:%ld] task.", op_kernel_lib_name.c_str(), fusion_node_name.c_str(), fusion_node_type.c_str(), op_id, stream_id); ret = kernel_info_store->GenerateTask(*fusion_node, run_context, task_def_list); if (ret != SUCCESS) { GELOGE(ret, - "L1Fusion: Call %s to generate fusion_node:[fusion_node_name:%s(%s), " + "Fusion: Call %s to generate fusion_node:[fusion_node_name:%s(%s), " "id:%ld, stream_id:%ld] task failed.", op_kernel_lib_name.c_str(), 
fusion_node_name.c_str(), fusion_node_type.c_str(), op_id, stream_id); return ret; } // profiling task - (void)InsertProfilingTaskAfter(op_desc, ppoint, ar_ppoint, node_index, task_def_list); + (void)InsertProfilingTaskAfter(op_desc, profiling_point, all_reduce_nodes, node_index, task_def_list); size_t task_list_size_after = task_def_list.size(); // if tasks is reduced if (task_list_size_after < task_list_size_before) { GELOGE(FAILED, - "L1Fusion: Call %s to generate fusion_node:[fusion_node_name:%s(%s), " + "Fusion: Call %s to generate fusion_node:[fusion_node_name:%s(%s), " "id:%ld, stream_id:%ld] task. but task num from %zu to %zu.", op_kernel_lib_name.c_str(), fusion_node_name.c_str(), fusion_node_type.c_str(), op_id, stream_id, task_list_size_before, task_list_size_after); @@ -451,13 +441,13 @@ Status TaskGenerator::GenerateTaskForL1FusionNode(L1FusionTaskInfo &fusion_task_ } GELOGI( - "L1Fusion: Call %s to generate fusion_node:[fusion_node_name:%s(%s), id:%ld, stream_id:%ld]" + "Fusion: Call %s to generate fusion_node:[fusion_node_name:%s(%s), id:%ld, stream_id:%ld]" " task finished, generate %u task(s).", op_kernel_lib_name.c_str(), fusion_node_name.c_str(), fusion_node_type.c_str(), op_id, stream_id, task_list_size_after - task_list_size_before); // record nodes which have call generate task successfully - l1_fusion_nodes_seen.insert(fusion_node.get()); + fusion_nodes_seen.insert(fusion_node.get()); node_index++; } } @@ -493,85 +483,231 @@ Status TaskGenerator::UpdateAnchorStatus(const NodePtr &node) { return SUCCESS; } -Status TaskGenerator::MarkFirstAndLastNode(ComputeGraphPtr &graph) { - std::shared_ptr ge_lib = GELib::GetInstance(); +Status TaskGenerator::MarkNodeAndSetIndex(ComputeGraphPtr &graph) { + auto ge_lib = GELib::GetInstance(); if ((ge_lib == nullptr) || !ge_lib->InitFlag()) { - GELOGE(GE_CLI_GE_NOT_INITIALIZED, "GE is not initialized or is finalized"); + GELOGE(GE_CLI_GE_NOT_INITIALIZED, "GE is not initialized or is finalized."); return 
GE_CLI_GE_NOT_INITIALIZED; } - map>> engine_stream_stat; - for (auto &node : graph->GetAllNodes()) { - GE_CHECK_NOTNULL(node->GetOpDesc()); - string op_kernel_lib_name = node->GetOpDesc()->GetOpKernelLibName(); - int64_t stream_id = node->GetOpDesc()->GetStreamId(); + const auto all_nodes = graph->GetAllNodes(); - if (op_kernel_lib_name.empty()) { - // Reset op kernel lib - (void)ge_lib->DNNEngineManagerObj().GetDNNEngineName(node->GetOpDesc()); - op_kernel_lib_name = node->GetOpDesc()->GetOpKernelLibName(); + int64_t node_index = 0; + for (auto &node : all_nodes) { + OpDescPtr op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + op_desc->SetId(node_index++); + } + + map> all_stream_ops; + for (auto &node : all_nodes) { + OpDescPtr op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + + // Reset op kernel lib name + if (op_desc->GetOpKernelLibName().empty()) { + (void)ge_lib->DNNEngineManagerObj().GetDNNEngineName(op_desc); + } + + all_stream_ops[op_desc->GetStreamId()].emplace_back(op_desc); + } + + bool is_single_stream = all_stream_ops.size() == 1; + for (const auto &stream_ops : all_stream_ops) { + Status status = MarkFirstAndLastOps(stream_ops.second, is_single_stream); + if (status != SUCCESS) { + GELOGE(status, "Mark first and last nodes failed."); + return status; + } + } + + return SUCCESS; +} + +Status TaskGenerator::MarkFirstAndLastOps(const vector &ops, bool is_single_stream) const { + vector> continuous_op_lists(1); + const set label_op_types({LABELSET, LABELGOTO, LABELGOTOEX, LABELSWITCH, LABELSWITCHBYINDEX}); + for (auto &op_desc : ops) { + string op_type = op_desc->GetType(); + if (!is_single_stream && (!op_desc->GetSubgraphInstanceNames().empty() || label_op_types.count(op_type) != 0)) { + continuous_op_lists.emplace_back(vector()); + } else { + continuous_op_lists.back().emplace_back(op_desc); + } + } + GELOGI("Number of continuous node lists is %zu.", continuous_op_lists.size()); + + for (const auto &continuous_ops : 
continuous_op_lists) { + map> first_and_last_ops; + for (auto &op_desc : continuous_ops) { + string op_kernel_lib_name = op_desc->GetOpKernelLibName(); if (op_kernel_lib_name.empty()) { - GELOGE(INTERNAL_ERROR, "node:%s(%s) get op kernel lib failed.", node->GetName().c_str(), - node->GetType().c_str()); + GELOGE(INTERNAL_ERROR, "node:%s(%s) get op kernel lib failed.", op_desc->GetName().c_str(), + op_desc->GetType().c_str()); return INTERNAL_ERROR; } - } - auto it = engine_stream_stat.find(op_kernel_lib_name); - if (it == engine_stream_stat.end()) { - map> stream_map; - std::pair node_pair(node, node); - (void)stream_map.emplace(stream_id, node_pair); - (void)engine_stream_stat.emplace(op_kernel_lib_name, stream_map); - } else { - auto stream_it = it->second.find(stream_id); - if (stream_it == it->second.end()) { - std::pair node_pair(node, node); - (void)it->second.emplace(stream_id, node_pair); + auto it = first_and_last_ops.find(op_kernel_lib_name); + if (it == first_and_last_ops.end()) { + first_and_last_ops.emplace(op_kernel_lib_name, std::make_pair(op_desc, op_desc)); } else { - stream_it->second.second = node; + it->second.second = op_desc; } } - } - for (auto &it : engine_stream_stat) { - for (auto &stream_it : it.second) { - NodePtr &first_node = stream_it.second.first; - GE_CHK_BOOL_EXEC(ge::AttrUtils::SetBool(first_node->GetOpDesc(), kIsFirstNode, true), - GELOGE(FAILED, "SetBool failed."); + for (auto &it : first_and_last_ops) { + auto &op_pair = it.second; + GE_CHK_BOOL_EXEC(ge::AttrUtils::SetBool(op_pair.first, kIsFirstNode, true), GELOGE(FAILED, "SetBool failed."); return FAILED); - NodePtr &last_node = stream_it.second.second; - GE_CHK_BOOL_EXEC(ge::AttrUtils::SetBool(last_node->GetOpDesc(), kIsLastNode, true), - GELOGE(FAILED, "SetBool failed."); + GE_CHK_BOOL_EXEC(ge::AttrUtils::SetBool(op_pair.second, kIsLastNode, true), GELOGE(FAILED, "SetBool failed."); return FAILED); } } + return SUCCESS; } -Status TaskGenerator::FindProfilingTaskIndex(const 
ComputeGraphPtr &graph, ProfilingPoint &ppoint, - vector &ar_ppoint) const { - GE_CHECK_NOTNULL(graph); - const char *is_profiling = std::getenv(kProfilingMode); - if (is_profiling == nullptr) { - return SUCCESS; +Status TaskGenerator::AutoFindFpOpIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point) const { + GELOGI("Start AutoFindFpOpIndex"); + OpDescPtr fp_op_desc = nullptr; + uint32_t current_idx = 0; + uint32_t first_fp = 0; + for (auto &node : graph->GetAllNodes()) { + OpDescPtr op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + string op_kernel_lib_name = op_desc->GetOpKernelLibName(); + if (op_kernel_lib_name.empty()) { + continue; + } + + if (op_desc->GetType() == GETNEXT || op_desc->GetType() == DATA) { + auto out_anchor = node->GetOutDataAnchor(0); + for (auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) { + GE_CHECK_NOTNULL(peer_in_anchor); + auto in_node_desc = peer_in_anchor->GetOwnerNode()->GetOpDesc(); + GE_CHECK_NOTNULL(in_node_desc); + if (fp_op_desc == nullptr || ((in_node_desc->GetId()) < (fp_op_desc->GetId()))) { + fp_op_desc = in_node_desc; + } + } + GELOGI("Find fp_op_desc is %s, id is %ld", fp_op_desc->GetName().c_str(), fp_op_desc->GetId()); + break; + } } - const char *fp_point = std::getenv(kProfilingFpPoint); - if (fp_point == nullptr) { - GELOGW("first forward profiling op name not set."); + + if (fp_op_desc == nullptr) { + GELOGW("not find fp_op_desc."); return SUCCESS; } - string fp_point_str = string(fp_point); - const char *bp_point = std::getenv(kProfilingBpPoint); - if (bp_point == nullptr) { - GELOGW("last backward profiling op name not set."); + for (auto &node : graph->GetAllNodes()) { + OpDescPtr op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + current_idx++; + if (op_desc->GetName() == fp_op_desc->GetName()) { + first_fp = current_idx; + GELOGI("First fp name is %s, idx is %u", op_desc->GetName().c_str(), first_fp); + break; + } + } + profiling_point.fp_index = first_fp; + return 
SUCCESS; +} + +Status TaskGenerator::AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point, + vector &all_reduce_nodes) const { + GELOGI("Start AutoFindBpOpIndex"); + NodePtr bp_node = nullptr; + uint32_t last_bp = 0; + uint32_t iter_end = 0; + uint32_t current_idx = 0; + for (auto &node : graph->GetAllNodes()) { + OpDescPtr op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + current_idx++; + string op_kernel_lib_name = op_desc->GetOpKernelLibName(); + if (op_kernel_lib_name.empty()) { + continue; + } + + if (op_desc->GetType() == HCOMALLREDUCE) { + bp_node = node; + all_reduce_nodes.emplace_back(current_idx); + GELOGI("Allreduce name %s, idx %u", op_desc->GetName().c_str(), current_idx); + } + if (op_desc->GetType() == NETOUTPUT) { + if (bp_node == nullptr) { + bp_node = node; + } + iter_end = current_idx; + GELOGI("Iter end name %s, idx %u", op_desc->GetName().c_str(), iter_end); + } + } + profiling_point.end_index = iter_end; + + if (bp_node == nullptr) { + GELOGW("not find bp_node."); return SUCCESS; } - string bp_point_str = string(bp_point); + OpDescPtr bp_op_desc = nullptr; + for (auto &in_anchor : bp_node->GetAllInDataAnchors()) { + auto out_anchor = in_anchor->GetPeerOutAnchor(); + if (out_anchor == nullptr || out_anchor->GetOwnerNode() == nullptr) { + continue; + } + auto out_node_desc = out_anchor->GetOwnerNode()->GetOpDesc(); + GE_CHECK_NOTNULL(out_node_desc); + if (bp_op_desc == nullptr || ((out_node_desc->GetId()) > (bp_op_desc->GetId()))) { + bp_op_desc = out_node_desc; + } + GELOGI("bp_op_desc is %s, id is %ld", bp_op_desc->GetName().c_str(), bp_op_desc->GetId()); + } + + GE_CHECK_NOTNULL(bp_op_desc); + current_idx = 0; + for (auto &node : graph->GetAllNodes()) { + OpDescPtr op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + current_idx++; + if (op_desc->GetName() == bp_op_desc->GetName()) { + last_bp = current_idx; + GELOGI("First bp name %s, idx %u", op_desc->GetName().c_str(), last_bp); + break; + } + } 
+ profiling_point.bp_index = last_bp; + return SUCCESS; +} + +Status TaskGenerator::FindFpOfEnv(const ComputeGraphPtr &graph, const std::string &fp_point_str, + ProfilingPoint &profiling_point) const { + GELOGI("Start FindFpOfEnv"); + uint32_t current_idx = 0; + uint32_t first_fp = 0; + for (auto &node : graph->GetAllNodes()) { + OpDescPtr op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(node->GetOpDesc()); + current_idx++; + string op_kernel_lib_name = op_desc->GetOpKernelLibName(); + if (op_kernel_lib_name.empty()) { + continue; + } + + if (first_fp == 0 && IsProfPoint(op_desc, fp_point_str)) { + first_fp = current_idx; + GELOGI("First fp name from env is %s, idx %u", op_desc->GetName().c_str(), first_fp); + } + } + + profiling_point.fp_index = first_fp; + return SUCCESS; +} + +Status TaskGenerator::FindBpOfEnv(const ComputeGraphPtr &graph, const std::string &bp_point_str, + ProfilingPoint &profiling_point, vector &all_reduce_nodes) const { + GELOGI("Start FindBpOfEnv"); uint32_t current_idx = 0; uint32_t iter_end = 0; uint32_t last_bp = 0; - uint32_t first_fp = 0; for (auto &node : graph->GetAllNodes()) { OpDescPtr op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(node->GetOpDesc()); @@ -585,43 +721,87 @@ Status TaskGenerator::FindProfilingTaskIndex(const ComputeGraphPtr &graph, Profi iter_end = current_idx; GELOGI("Iter end name %s, idx %u", op_desc->GetName().c_str(), iter_end); } - if (op_desc->GetType() == HCOMALLREDUCE) { - ar_ppoint.emplace_back(current_idx); + all_reduce_nodes.emplace_back(current_idx); GELOGI("Allreduce name %s, idx %u", op_desc->GetName().c_str(), current_idx); } + if (IsProfPoint(op_desc, bp_point_str)) { + last_bp = current_idx; + GELOGI("Last bp name from env is %s, idx %u", op_desc->GetName().c_str(), last_bp); + } + } - if (first_fp == 0 && IsProfPoint(op_desc, fp_point_str)) { - first_fp = current_idx; - GELOGI("First fp name %s, idx %u", op_desc->GetName().c_str(), first_fp); + profiling_point.bp_index = last_bp; + 
profiling_point.end_index = iter_end; + return SUCCESS; +} + +Status TaskGenerator::FindProfilingTaskIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point, + vector &all_reduce_nodes) const { + GELOGI("Start FindProfilingTaskIndex."); + GE_CHECK_NOTNULL(graph); + const char *profiling_mode = std::getenv(kProfilingMode); + bool is_profiling = (profiling_mode != nullptr) || ProfilingManager::Instance().ProfilingOn(); + if (!is_profiling) { + return SUCCESS; + } + + const char *fp_point = std::getenv(kProfilingFpPoint); + Status ret; + if (fp_point == nullptr) { + ret = AutoFindFpOpIndex(graph, profiling_point); + if (ret != SUCCESS) { + GELOGW("First forward profiling op_index not set and FindFpOpIndex failed."); + return SUCCESS; } + } - if (IsProfPoint(op_desc, bp_point_str)) { - last_bp = current_idx; - GELOGI("Last bp name %s, idx %u", op_desc->GetName().c_str(), last_bp); + const char *bp_point = std::getenv(kProfilingBpPoint); + if (bp_point == nullptr) { + ret = AutoFindBpOpIndex(graph, profiling_point, all_reduce_nodes); + if (ret != SUCCESS) { + GELOGW("Last backward profiling op_index not set and FindBpOpIndex failed."); + return SUCCESS; + } + } + + if (fp_point != nullptr) { + string fp_point_str = string(fp_point); + ret = FindFpOfEnv(graph, fp_point_str, profiling_point); + if (ret != SUCCESS) { + GELOGW("First backward profiling op name set but FindFpOfEnv failed."); + return SUCCESS; } } - ppoint.fp_index = first_fp; - ppoint.bp_index = last_bp; - ppoint.end_index = iter_end; + if (bp_point != nullptr) { + string bp_point_str = string(bp_point); + ret = FindBpOfEnv(graph, bp_point_str, profiling_point, all_reduce_nodes); + if (ret != SUCCESS) { + GELOGW("Last backward profiling op name set but FindBpOfEnv failed."); + return SUCCESS; + } + } + bool train_graph = graph->GetNeedIteration(); - if (ppoint.fp_index == 0 && train_graph) { + if (profiling_point.fp_index == 0 && train_graph) { GELOGE(FAILED, "First forward op name can't be found 
in graph for training trace."); } - if (ppoint.bp_index == 0 && train_graph) { + if (profiling_point.bp_index == 0 && train_graph) { GELOGE(FAILED, "Last backward op name can't be found in graph for training trace."); } return SUCCESS; } -Status TaskGenerator::InsertProfilingTaskBefore(const OpDescPtr &op_desc, const ProfilingPoint &ppoint, - vector &ar_ppoint, uint32_t node_index, +Status TaskGenerator::InsertProfilingTaskBefore(const OpDescPtr &op_desc, const ProfilingPoint &profiling_point, + vector &all_reduce_nodes, uint32_t node_index, vector &task_def_list) { - const char *is_profiling = std::getenv(kProfilingMode); - if ((is_profiling == nullptr) || (ppoint.fp_index == 0) || (ppoint.bp_index == 0) || (ppoint.end_index == 0)) { + const char *profiling_mode = std::getenv(kProfilingMode); + bool is_profiling = (profiling_mode != nullptr) || ProfilingManager::Instance().ProfilingOn(); + if (!is_profiling || (profiling_point.fp_index == 0) || (profiling_point.bp_index == 0) || + (profiling_point.end_index == 0)) { return SUCCESS; } - if (ppoint.fp_index == node_index) { + if (profiling_point.fp_index == node_index) { uint64_t jobid_log_id = ge::GetContext().TraceId(); GELOGI("The first FP operator is %s, idx %u, job_id %lu", op_desc->GetName().c_str(), node_index, jobid_log_id); @@ -645,8 +825,8 @@ Status TaskGenerator::InsertProfilingTaskBefore(const OpDescPtr &op_desc, const task_def_list.emplace_back(fp_task_def); } - for (size_t i = 0; i < ar_ppoint.size(); i++) { - if (ar_ppoint[i] != node_index) { + for (size_t i = 0; i < all_reduce_nodes.size(); i++) { + if (all_reduce_nodes[i] != node_index) { continue; } GELOGI("The start allreduce operator is %s, idx %u", op_desc->GetName().c_str(), node_index); @@ -667,15 +847,17 @@ Status TaskGenerator::InsertProfilingTaskBefore(const OpDescPtr &op_desc, const return SUCCESS; } -Status TaskGenerator::InsertProfilingTaskAfter(const OpDescPtr &op_desc, const ProfilingPoint &ppoint, - vector &ar_ppoint, uint32_t 
node_index, +Status TaskGenerator::InsertProfilingTaskAfter(const OpDescPtr &op_desc, const ProfilingPoint &profiling_point, + vector &all_reduce_nodes, uint32_t node_index, vector &task_def_list) { GE_CHECK_NOTNULL(op_desc); - const char *is_profiling = std::getenv(kProfilingMode); - if ((is_profiling == nullptr) || (ppoint.fp_index == 0) || (ppoint.bp_index == 0) || (ppoint.end_index == 0)) { + const char *profiling_mode = std::getenv(kProfilingMode); + bool is_profiling = (profiling_mode != nullptr) || ProfilingManager::Instance().ProfilingOn(); + if (!is_profiling || (profiling_point.fp_index == 0) || (profiling_point.bp_index == 0) || + (profiling_point.end_index == 0)) { return SUCCESS; } - if (ppoint.bp_index == node_index) { + if (profiling_point.bp_index == node_index) { GELOGI("The last BP operator is %s, idx %u", op_desc->GetName().c_str(), node_index); TaskDef bp_task_def; bp_task_def.set_type(RT_MODEL_TASK_PROFILER_TRACE); @@ -686,7 +868,7 @@ Status TaskGenerator::InsertProfilingTaskAfter(const OpDescPtr &op_desc, const P bp_log_def->set_notify(false); task_def_list.emplace_back(bp_task_def); } - if (ppoint.end_index == node_index) { + if (profiling_point.end_index == node_index) { GELOGI("The iteration end operator is %s, idx %u", op_desc->GetName().c_str(), node_index); TaskDef end_task_def; end_task_def.set_type(RT_MODEL_TASK_PROFILER_TRACE); @@ -698,8 +880,8 @@ Status TaskGenerator::InsertProfilingTaskAfter(const OpDescPtr &op_desc, const P task_def_list.emplace_back(end_task_def); } - for (size_t i = 0; i < ar_ppoint.size(); i++) { - if (ar_ppoint[i] != node_index) { + for (size_t i = 0; i < all_reduce_nodes.size(); i++) { + if (all_reduce_nodes[i] != node_index) { continue; } GELOGI("The end allreduce operator is %s, idx %u", op_desc->GetName().c_str(), node_index); diff --git a/src/ge/graph/build/task_generator.h b/src/ge/graph/build/task_generator.h index 1f4a1f0d..c666244b 100644 --- a/src/ge/graph/build/task_generator.h +++ 
b/src/ge/graph/build/task_generator.h @@ -38,8 +38,8 @@ struct ProfilingPoint { uint32_t bp_index = 0; uint32_t end_index = 0; }; -// Describes infos needed by generate task for l1 fusion node -struct L1FusionTaskInfo { +// Describes infos needed by generate task for fusion node +struct FusionTaskInfo { RunContext &run_context; ComputeGraphPtr &graph; NodePtr &node; @@ -49,8 +49,8 @@ struct L1FusionTaskInfo { const OpsKernelManager &ops_kernel_manager; std::vector &task_def_list; std::map &op_name_map; - ProfilingPoint &ppoint; - vector ar_ppoint; + ProfilingPoint &profiling_point; + vector all_reduce_nodes; }; class TaskGenerator { @@ -103,32 +103,44 @@ class TaskGenerator { Status AddModelTaskToModel(const domi::ModelTaskDef &model_task_def, uint64_t session_id, Model &model_def, RunContext &run_context); - // Mark first and last node according to the same stream and engine - Status MarkFirstAndLastNode(ComputeGraphPtr &graph); + Status MarkNodeAndSetIndex(ComputeGraphPtr &graph); + + // Mark first and last op according to the same stream and engine + Status MarkFirstAndLastOps(const vector &ops, bool is_single_stream) const; // profiling interface - Status FindProfilingTaskIndex(const ComputeGraphPtr &graph, ProfilingPoint &ppoint, - std::vector &ar_ppoint) const; - Status InsertProfilingTaskBefore(const OpDescPtr &op_desc, const ProfilingPoint &ppoint, - std::vector &ar_ppoint, uint32_t node_index, + Status AutoFindFpOpIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point) const; + Status AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point, + vector &all_reduce_nodes) const; + + Status FindFpOfEnv(const ComputeGraphPtr &graph, const std::string &fp_point_str, + ProfilingPoint &profiling_point) const; + Status FindBpOfEnv(const ComputeGraphPtr &graph, const std::string &bp_point_str, ProfilingPoint &profiling_point, + vector &all_reduce_nodes) const; + + Status FindProfilingTaskIndex(const ComputeGraphPtr &graph, 
ProfilingPoint &profiling_point, + std::vector &all_reduce_nodes) const; + Status InsertProfilingTaskBefore(const OpDescPtr &op_desc, const ProfilingPoint &profiling_point, + std::vector &all_reduce_nodes, uint32_t node_index, std::vector &task_def_list); - Status InsertProfilingTaskAfter(const OpDescPtr &op_desc, const ProfilingPoint &ppoint, - std::vector &ar_ppoint, uint32_t node_index, + Status InsertProfilingTaskAfter(const OpDescPtr &op_desc, const ProfilingPoint &profiling_point, + std::vector &all_reduce_nodes, uint32_t node_index, std::vector &task_def_list); static bool IsProfPoint(const OpDescPtr &op, const std::string &name); - /// call engine to generate task for l1 fusion node. - /// @param L1FusionTaskInfo - /// @param l1_fusion_nodes: nodes in graph with groud_id attr which means l1 fusion node - /// @param l1_fusion_nodes_seen: l1 fusion node has been called generate task + + /// call engine to generate task for fusion node. + /// @param FusionTaskInfo + /// @param fusion_nodes: nodes in graph with groud_id attr which means fusion node + /// @param fusion_nodes_seen: fusion node has been called generate task /// @return SUCCESS:seccess /// Other: failed /// - Status GenerateTaskForL1FusionNode(L1FusionTaskInfo &fusion_task_info, - std::map> &l1_fusion_nodes, - std::unordered_set &l1_fusion_nodes_seen); + Status GenerateTaskForFusionNode(FusionTaskInfo &fusion_task_info, + std::map> &fusion_nodes, + std::unordered_set &fusion_nodes_seen); - Status SaveL1fusionNodes(map> &l1_fusion_nodes, ComputeGraphPtr &graph); + Status SaveFusionNodes(map> &fusion_nodes, ComputeGraphPtr &graph); uint8_t *var_mem_base_ = nullptr; uint64_t var_mem_size_ = 0; diff --git a/src/ge/graph/common/ge_call_wrapper.h b/src/ge/graph/common/ge_call_wrapper.h new file mode 100644 index 00000000..a21d642e --- /dev/null +++ b/src/ge/graph/common/ge_call_wrapper.h @@ -0,0 +1,38 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, 
Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GE_CALL_WRAPPER_H_ +#define GE_GE_CALL_WRAPPER_H_ +#include "framework/common/debug/ge_log.h" + +#define RUN_WITH_TIMESTAMP_NAME(var_name, prefix, func, ...) \ + do { \ + GE_TIMESTAMP_START(var_name); \ + auto ret_inner_macro = func(__VA_ARGS__); \ + GE_TIMESTAMP_END(var_name, #prefix "::" #func) \ + if (ret_inner_macro != ge::SUCCESS) { \ + GELOGE(ret_inner_macro, "Failed to process " #prefix "_" #func); \ + return ret_inner_macro; \ + } \ + } while (0) + +#define JOIN_NAME_INNER(a, b) a##b +#define JOIN_NAME(a, b) JOIN_NAME_INNER(a, b) +#define COUNTER_NAME(a) JOIN_NAME(a, __COUNTER__) +#define GE_RUN(prefix, func, ...) 
\ + RUN_WITH_TIMESTAMP_NAME(COUNTER_NAME(ge_timestamp_##prefix), prefix, func, __VA_ARGS__) + +#endif // GE_GE_CALL_WRAPPER_H_ diff --git a/src/ge/graph/common/omg_util.cc b/src/ge/graph/common/omg_util.cc index 0a6d98d2..5c76d0a1 100644 --- a/src/ge/graph/common/omg_util.cc +++ b/src/ge/graph/common/omg_util.cc @@ -18,17 +18,10 @@ #include -#include "common/op/attr_define.h" #include "framework/common/debug/ge_log.h" #include "graph/debug/ge_attr_define.h" #include "graph/utils/graph_utils.h" -using domi::ATTR_NAME_STREAM_CYCLE_EVENT_FLAG; -using domi::ATTR_NAME_STREAM_LABEL; -using domi::FRAMEWORKOP; -using ge::AttrUtils; -using ge::OpDescPtr; - namespace ge { /// /// @brief get the Original Type of FrameworkOp @@ -61,7 +54,7 @@ Status SetStreamLabel(const ge::NodePtr &node, const std::string &label) { OpDescPtr tmp_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(tmp_desc); - if (!AttrUtils::SetStr(tmp_desc, ATTR_NAME_STREAM_LABEL, label)) { + if (!AttrUtils::SetStr(tmp_desc, ge::ATTR_NAME_STREAM_LABEL, label)) { GELOGE(FAILED, "Op: %s set ATTR_NAME_STREAM_LABEL failed", node->GetName().c_str()); return FAILED; } @@ -78,7 +71,7 @@ Status SetCycleEvent(const ge::NodePtr &node) { GE_CHECK_NOTNULL(node); OpDescPtr tmp_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(tmp_desc); - if (!AttrUtils::SetBool(tmp_desc, ATTR_NAME_STREAM_CYCLE_EVENT_FLAG, true)) { + if (!AttrUtils::SetBool(tmp_desc, ge::ATTR_NAME_STREAM_CYCLE_EVENT_FLAG, true)) { GELOGE(FAILED, "Op: %s set ATTR_NAME_STREAM_CYCLE_EVENT_FLAG failed", node->GetName().c_str()); return FAILED; } diff --git a/src/ge/graph/common/transop_util.cc b/src/ge/graph/common/transop_util.cc index 8631529e..3250929d 100644 --- a/src/ge/graph/common/transop_util.cc +++ b/src/ge/graph/common/transop_util.cc @@ -24,8 +24,8 @@ const int kInvalidTransopDataIndex = -1; namespace ge { TransOpUtil::TransOpUtil() { - transop_index_map_ = {{domi::TRANSDATA, 0}, {domi::TRANSPOSE, 0}, {domi::TRANSPOSED, 0}, {domi::RESHAPE, 0}, - {domi::REFORMAT, 
0}, {domi::CAST, 0}, {domi::SQUEEZE, 0}, {domi::EXPANDDIMS, 0}}; + transop_index_map_ = {{TRANSDATA, 0}, {TRANSPOSE, 0}, {TRANSPOSED, 0}, {RESHAPE, 0}, + {REFORMAT, 0}, {CAST, 0}, {SQUEEZE, 0}, {EXPANDDIMS, 0}}; } TransOpUtil::~TransOpUtil() {} diff --git a/src/ge/graph/execute/graph_execute.cc b/src/ge/graph/execute/graph_execute.cc index 56e31de3..6ef1f671 100644 --- a/src/ge/graph/execute/graph_execute.cc +++ b/src/ge/graph/execute/graph_execute.cc @@ -16,6 +16,7 @@ #include "graph/execute/graph_execute.h" +#include #include #include "common/ge_inner_error_codes.h" @@ -180,12 +181,11 @@ Status GraphExecutor::PrepareInputData(const std::vector &input_tensor const GeTensor *in_tensor = &input_tensor[i]; GE_CHECK_NOTNULL(in_tensor); if ((addrVec[i] != nullptr) && (in_tensor->GetData().data() != nullptr)) { - errno_t s_ret = memcpy_s(addrVec[i], bufferSizeVec[i], in_tensor->GetData().data(), in_tensor->GetData().size()); - if (s_ret != 0) { - GELOGE(GE_GRAPH_EXECUTE_FAILED, - "[GraphExecutor] memcpy input data failed, errno: %d, dst size: %u, src size: %zu.", s_ret, - bufferSizeVec[i], in_tensor->GetData().size()); - return GE_GRAPH_EXECUTE_FAILED; + rtError_t rt_ret = rtMemcpy(addrVec[i], bufferSizeVec[i], in_tensor->GetData().data(), + in_tensor->GetData().size(), RT_MEMCPY_HOST_TO_HOST); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); + return RT_FAILED; } } @@ -275,8 +275,8 @@ Status GraphExecutor::SyncExecuteModel(uint32_t model_id, const std::vector &input_tensor, - std::vector &output_tensor) { + const std::vector &input_tensor) { GELOGI("[GraphExecutor] Start to async execute graph, graph_id=%u", graph_id); if (graph_id != last_graph_id_) { auto ret = FreeExecuteMemory(); @@ -355,7 +354,7 @@ Status GraphExecutor::ExecuteGraphAsync(GraphId graph_id, const GeModelPtr &ge_m } last_graph_id_ = graph_id; GE_CHECK_NOTNULL_EXEC(ge_model, return FAILED); - Status ret = AsyncExecuteModel(ge_model->GetModelId(), 
input_tensor, output_tensor); + Status ret = AsyncExecuteModel(ge_model->GetModelId(), input_tensor); if (ret != SUCCESS) { GELOGE(GE_GRAPH_SYNC_MODEL_FAILED, "[GraphExecutor] AsyncExecuteModel Error!"); return GE_GRAPH_SYNC_MODEL_FAILED; @@ -365,14 +364,13 @@ Status GraphExecutor::ExecuteGraphAsync(GraphId graph_id, const GeModelPtr &ge_m return SUCCESS; } -Status GraphExecutor::AsyncExecuteModel(uint32_t model_id, const std::vector &inputs, - std::vector &outputs) { +Status GraphExecutor::AsyncExecuteModel(uint32_t model_id, const std::vector &inputs) { try { auto model_manager = ge::ModelManager::GetInstance(); GE_CHECK_NOTNULL(model_manager); GELOGI("RunAsync begin.model_id %u", model_id); - Status ret = model_manager->DataInputTensor(model_id, inputs, outputs); + Status ret = model_manager->DataInputTensor(model_id, inputs); if (ret != SUCCESS) { GELOGE(ret, "RunAsync: DataInput fail"); return ret; diff --git a/src/ge/graph/execute/graph_execute.h b/src/ge/graph/execute/graph_execute.h index 5e926ae3..e7fd2084 100644 --- a/src/ge/graph/execute/graph_execute.h +++ b/src/ge/graph/execute/graph_execute.h @@ -49,8 +49,8 @@ class GraphExecutor { Status ExecuteGraph(GraphId graph_id, const GeModelPtr &ge_model, const std::vector &input_tensor, std::vector &output_tensor); - Status ExecuteGraphAsync(GraphId graph_id, const GeModelPtr &ge_model, const std::vector &input_tensor, - std::vector &output_tensor); + Status ExecuteGraphAsync(GraphId graph_id, const GeModelPtr &ge_model, + const std::vector &input_tensor); Status SetCondition(std::mutex *mutex, std::condition_variable *cond, std::shared_ptr listener); @@ -92,8 +92,7 @@ class GraphExecutor { Status SyncExecuteModel(uint32_t model_id, const std::vector &input_tensor, std::vector &output_tensor); - Status AsyncExecuteModel(uint32_t model_id, const std::vector &input_tensor, - std::vector &output_tensor); + Status AsyncExecuteModel(uint32_t model_id, const std::vector &input_tensor); void 
InitModelIdInfo(std::vector &out_model_id_info, std::vector &sub_graph_vec, uint32_t output_size); diff --git a/src/ge/graph/label/case_label_maker.cc b/src/ge/graph/label/case_label_maker.cc index 2d024499..88b7ee8b 100644 --- a/src/ge/graph/label/case_label_maker.cc +++ b/src/ge/graph/label/case_label_maker.cc @@ -23,8 +23,6 @@ #include "graph/debug/ge_attr_define.h" #include "graph/utils/graph_utils.h" -using domi::CASE; - namespace ge { constexpr uint32_t kCasePredIndex = 0; constexpr uint32_t kMinCaseBranch = 1; @@ -57,20 +55,35 @@ Status CaseOpLabelMaker::Run(uint32_t &label_index) { return SUCCESS; } + NodePtr first_label = nullptr; + ComputeGraphPtr first_graph = nullptr; std::vector switch_labels; uint32_t last_label_index = label_index++; for (uint32_t index = 0; index < graph_num; ++index) { ComputeGraphPtr graph = parent_graph_->GetSubgraph(graph_names[index]); GE_CHECK_NOTNULL(graph); - // all branch, add label node to head. + // all branch, add label and stream active nodes to head. + std::string stream_active_name = + parent_node_->GetName() + "/StreamActive_" + std::to_string(index); // rtStreamActive + NodePtr stream_active = AddStreamActive(graph, stream_active_name); + if (stream_active == nullptr) { + GELOGE(INTERNAL_ERROR, "Subgraph: %s add stream active failed.", graph->GetName().c_str()); + return FAILED; + } + uint32_t curr_label_index = label_index++; std::string label_set_name = parent_node_->GetName() + "/LabelSet_" + std::to_string(index); // rtLabelSet - if (AddLabelSetEnter(graph, label_set_name, curr_label_index) == nullptr) { + NodePtr label = AddLabelSetEnter(graph, label_set_name, curr_label_index, stream_active); + if (label == nullptr) { GELOGE(INTERNAL_ERROR, "Subgraph: %s add label set failed.", graph->GetName().c_str()); return FAILED; } switch_labels.emplace_back(curr_label_index); + if (index == 0) { // save first subgraph node for switch. 
+ first_label = label; + first_graph = graph; + } if (index + 1 < graph_num) { // middle node, add goto node to tail. @@ -90,23 +103,27 @@ Status CaseOpLabelMaker::Run(uint32_t &label_index) { } // Add Switch node for first branch. - ComputeGraphPtr first_graph = parent_graph_->GetSubgraph(graph_names[0]); + GE_CHECK_NOTNULL(first_label); GE_CHECK_NOTNULL(first_graph); - GeTensorDesc pred_desc = case_desc->GetInputDesc(kCasePredIndex); - GeTensorDesc cond_desc(GeShape(pred_desc.GetShape().GetDims()), pred_desc.GetFormat(), DT_UINT32); - // first case, add switch node to head. const std::string label_switch_name = parent_node_->GetName() + "/LabelSwitch"; // rtLabelSwitchByIndex - NodePtr switch_node = AddLabelSwitchEnter(first_graph, label_switch_name, cond_desc, switch_labels); + const GeTensorDesc &pred_desc = case_desc->GetInputDesc(kCasePredIndex); + NodePtr switch_node = AddLabelSwitchEnter(first_graph, label_switch_name, pred_desc, switch_labels); if (switch_node == nullptr) { GELOGE(INTERNAL_ERROR, "Subgraph: %s add label switch failed.", first_graph->GetName().c_str()); return FAILED; } + // Link control edge to then branch head. + if (GraphUtils::AddEdge(switch_node->GetOutControlAnchor(), first_label->GetInControlAnchor()) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "LabelSwitchByIndex: Add ctrl edge to %s failed.", first_label->GetName().c_str()); + return FAILED; + } + uint32_t parent_index = 0; // Case cond input is first. 
const std::string data_name = parent_node_->GetName() + "/SwitchIndexData"; - if (AddLabelSwitchIndex(first_graph, data_name, cond_desc, switch_node, parent_index) == nullptr) { + if (AddLabelSwitchIndex(first_graph, data_name, pred_desc, switch_node, parent_index) == nullptr) { GELOGE(INTERNAL_ERROR, "Subgraph: %s add switch input failed.", first_graph->GetName().c_str()); return FAILED; } diff --git a/src/ge/graph/label/case_label_maker.h b/src/ge/graph/label/case_label_maker.h index 3c43911c..2e3b584b 100644 --- a/src/ge/graph/label/case_label_maker.h +++ b/src/ge/graph/label/case_label_maker.h @@ -20,65 +20,71 @@ #include "graph/node.h" #include "graph/label/label_maker.h" /******************************************************************************* - +-----------+ - | Node | - +-----------+ - | Node | - +-----------+ - | Case | - +-----------+ + +------------+ + | Node | + +------------+ + | Node | + +------------+ + | Case | + +------------+ +-----------+ - | Node | +-----------+ - +-----------+ /|SwitchByIdx| - | Node | A +-----------+ - +-----------+ / \|LabelSet(0)| - | Case | | +-----------+ - +-----------+ | | c | - | Node | | +-----------+ - +-----------+ | | a | - | Node | | +-----------+ - +-----------+ | | s | - | Node | | +-----------+ - +-----------+ | | e | - | +-----------+ - ====> | | LabelGoto |\ - V +-----------+ \ - |\ \ - | \ +-----------+ | - +-----------+ +-----------+ +-----------+ | \|LabelSet(1)| | - | c | | c | | c | | +-----------+ | - +-----------+ +-----------+ +-----------+ | | c | | - | a | | a | | a | | +-----------+ | - +-----------+ +-----------+ +-----------+ | | a | | - | s | | s | | s | | +-----------+ | - +-----------+ +-----------+ +-----------+ | | s | | - | e | | e | | e | | +-----------+ | - +-----------+ +-----------+ +-----------+ | | e | | - | +-----------+ V - | | LabelGoto |\ | - V +-----------+ \ | - \ \| - \ +-----------+ | - \|LabelSet(2)| | - +-----------+ | - | c | | - +-----------+ | - | a | | - 
+-----------+ | - | s | | - +-----------+ V - | e | / - +-----------+ / - | LabelSet |/ - +-----------+ + | Node | +------------+ + +-----------+ /|SwitchByIdx | + | Node | A +------------+ + +-----------+ / \|LabelSet(0) | + | Case | | +------------+ + +-----------+ | |StreamActive| + | Node | | +------------+ + +-----------+ | | c | + | Node | | +------------+ + +-----------+ | | a | + | Node | | +------------+ + +-----------+ | | s | + | Node | | +------------+ + +-----------+ | | e | + | +------------+ + ====> | | LabelGoto |\ + V +------------+ \ + |\ \ + | \ +------------+ | + +-----------+ +-----------+ +-----------+ | \|LabelSet(1) | | + | c | | c | | c | | +------------+ | + +-----------+ +-----------+ +-----------+ | |StreamActive| | + | a | | a | | a | | +------------+ | + +-----------+ +-----------+ +-----------+ | | c | | + | s | | s | | s | | +------------+ | + +-----------+ +-----------+ +-----------+ | | a | | + | e | | e | | e | | +------------+ | + +-----------+ +-----------+ +-----------+ | | s | | + | +------------+ | + | | e | | + | +------------+ V + | | LabelGoto |\ | + V +------------+ \ | + \ \| + \ +------------+ | + \|LabelSet(2) | | + +------------+ | + |StreamActive| | + +------------+ | + | c | | + +------------+ | + | a | | + +------------+ | + | s | | + +------------+ V + | e | / + +------------+ / + | LabelSet |/ + +------------+ - +-----------+ - | Node | - +-----------+ - | Node | - +-----------+ - | Node | - +-----------+ + +------------+ + | Node | + +------------+ + | Node | + +------------+ + | Node | + +------------+ *******************************************************************************/ namespace ge { diff --git a/src/ge/graph/label/if_label_maker.cc b/src/ge/graph/label/if_label_maker.cc index 142cf625..62722e7c 100644 --- a/src/ge/graph/label/if_label_maker.cc +++ b/src/ge/graph/label/if_label_maker.cc @@ -23,24 +23,11 @@ #include "graph/debug/ge_attr_define.h" #include "graph/utils/graph_utils.h" -using 
domi::_IF; -using domi::IF; -using domi::STATELESSIF; - namespace ge { constexpr uint8_t kIfPredIndex = 0; constexpr uint8_t kThenBranchIndex = 0; constexpr uint8_t kElseBranchIndex = 1; -// All ---> Node ---> If ---> Node ---> -// | -// V -// { Data ---> Node ---> Node ---> NetOutput } -// -// All ---> Node ---> If ---> Node ---> -// \ / -// { Node ---> Node } - /** * @ingroup ge * @brief Make label node to functional call. @@ -70,13 +57,22 @@ Status IfOpLabelMaker::Run(uint32_t &label_index) { const uint32_t then_enter_index = label_index++; const uint32_t else_enter_index = label_index++; const uint32_t else_leave_index = label_index++; - const std::string then_enter_name = parent_node_->GetName() + "/LabelSwitch"; // rtLabelSwitchByIndex - const std::string then_label_name = parent_node_->GetName() + "/ThenLabelSet"; // rtLabelSet(0) - const std::string then_leave_name = parent_node_->GetName() + "/LabelGoto"; // rtLabelGoto - const std::string else_enter_name = parent_node_->GetName() + "/ElseLabelSet"; // rtLabelSet(1) - const std::string else_leave_name = parent_node_->GetName() + "/LeaveLabelSet"; // rtLabelSet + const std::string then_enter_name = parent_node_->GetName() + "/LabelSwitch"; // rtLabelSwitchByIndex + const std::string then_label_name = parent_node_->GetName() + "/ThenLabelSet"; // rtLabelSet(0) + const std::string then_active_name = parent_node_->GetName() + "/ThenStreamActive"; // rtStreamActive + const std::string then_leave_name = parent_node_->GetName() + "/LabelGoto"; // rtLabelGoto + const std::string else_enter_name = parent_node_->GetName() + "/ElseLabelSet"; // rtLabelSet(1) + const std::string else_active_name = parent_node_->GetName() + "/ElseStreamActive"; // rtStreamActive + const std::string else_leave_name = parent_node_->GetName() + "/LeaveLabelSet"; // rtLabelSet + + NodePtr then_stream_active = AddStreamActive(then_sub_graph, then_active_name); + if (then_stream_active == nullptr) { + GELOGE(INTERNAL_ERROR, "Subgraph: %s add 
stream active failed.", then_sub_graph->GetName().c_str()); + return FAILED; + } - if (AddLabelSetEnter(then_sub_graph, then_label_name, then_enter_index) == nullptr) { + NodePtr then_enter_label = AddLabelSetEnter(then_sub_graph, then_label_name, then_enter_index, then_stream_active); + if (then_enter_label == nullptr) { GELOGE(INTERNAL_ERROR, "Subgraph: %s add label set failed.", then_sub_graph->GetName().c_str()); return FAILED; } @@ -86,7 +82,13 @@ Status IfOpLabelMaker::Run(uint32_t &label_index) { return FAILED; } - if (AddLabelSetEnter(else_sub_graph, else_enter_name, else_enter_index) == nullptr) { + NodePtr else_stream_active = AddStreamActive(else_sub_graph, else_active_name); + if (else_stream_active == nullptr) { + GELOGE(INTERNAL_ERROR, "Subgraph: %s add stream active failed.", else_sub_graph->GetName().c_str()); + return FAILED; + } + + if (AddLabelSetEnter(else_sub_graph, else_enter_name, else_enter_index, else_stream_active) == nullptr) { GELOGE(INTERNAL_ERROR, "Subgraph: %s add label set failed.", else_sub_graph->GetName().c_str()); return FAILED; } @@ -99,17 +101,22 @@ Status IfOpLabelMaker::Run(uint32_t &label_index) { // true ==> 1 ==> switch_labels[1] ==> then_enter_index const std::vector switch_labels = {else_enter_index, then_enter_index}; - GeTensorDesc pred_desc = if_desc->GetInputDesc(kIfPredIndex); - GeTensorDesc cond_desc(GeShape(pred_desc.GetShape().GetDims()), pred_desc.GetFormat(), DT_UINT32); - NodePtr switch_node = AddLabelSwitchEnter(then_sub_graph, then_enter_name, cond_desc, switch_labels); + const GeTensorDesc &pred_desc = if_desc->GetInputDesc(kIfPredIndex); + NodePtr switch_node = AddLabelSwitchEnter(then_sub_graph, then_enter_name, pred_desc, switch_labels); if (switch_node == nullptr) { GELOGE(INTERNAL_ERROR, "Subgraph: %s add label switch failed.", then_sub_graph->GetName().c_str()); return FAILED; } + // Link control edge to then branch head. 
+ if (GraphUtils::AddEdge(switch_node->GetOutControlAnchor(), then_enter_label->GetInControlAnchor()) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "LabelSwitchByIndex: Add ctrl edge to %s failed.", then_enter_label->GetName().c_str()); + return FAILED; + } + uint32_t parent_index = 0; // If cond input is first. const std::string data_name = parent_node_->GetName() + "/SwitchIndexData"; - if (AddLabelSwitchIndex(then_sub_graph, data_name, cond_desc, switch_node, parent_index) == nullptr) { + if (AddLabelSwitchIndex(then_sub_graph, data_name, pred_desc, switch_node, parent_index) == nullptr) { GELOGE(INTERNAL_ERROR, "Subgraph: %s add switch input failed.", then_sub_graph->GetName().c_str()); return FAILED; } diff --git a/src/ge/graph/label/if_label_maker.h b/src/ge/graph/label/if_label_maker.h index 1ee41819..9ffe8fca 100644 --- a/src/ge/graph/label/if_label_maker.h +++ b/src/ge/graph/label/if_label_maker.h @@ -20,51 +20,55 @@ #include "graph/node.h" #include "graph/label/label_maker.h" /******************************************************************************* - +-----------+ - | Node | - +-----------+ - | Node | - +-----------+ - | If | - +-----------+ + +------------+ + | Node | + +------------+ + | Node | + +------------+ + | If | + +------------+ +-----------+ - | Node | +-----------+ - +-----------+ /|SwitchByIdx| - | Node | A +-----------+ - +-----------+ / \|LabelSet(1)| - | If | | +-----------+ - +-----------+ | | t | - | Node | | +-----------+ - +-----------+ | | h | - | Node | | +-----------+ - +-----------+ | | e | - | Node | | +-----------+ - +-----------+ | | n | - | +-----------+ - ====> | | LabelGoto |\ - V +-----------+ \ - +-----------+ +-----------+ \ \ - | t | | e | \ +-----------+ | - +-----------+ +-----------+ \|LabelSet(0)| | - | h | | l | +-----------+ | - +-----------+ +-----------+ | e | | - | e | | s | +-----------+ | - +-----------+ +-----------+ | l | | - | n | | e | +-----------+ | - +-----------+ +-----------+ | s | | - +-----------+ V 
- | e | / - +-----------+ / - | LabelSet |/ - +-----------+ + | Node | +------------+ + +-----------+ /|SwitchByIdx | + | Node | A +------------+ + +-----------+ / \|LabelSet(1) | + | If | | +------------+ + +-----------+ | |StreamActive| + | Node | | +------------+ + +-----------+ | | t | + | Node | | +------------+ + +-----------+ | | h | + | Node | | +------------+ + +-----------+ | | e | + | Node | | +------------+ + +-----------+ | | n | + | +------------+ + ====> | | LabelGoto |\ + V +------------+ \ + +-----------+ +-----------+ \ \ + | t | | e | \ +------------+ | + +-----------+ +-----------+ \|LabelSet(0) | | + | h | | l | +------------+ | + +-----------+ +-----------+ |StreamActive| | + | e | | s | +------------+ | + +-----------+ +-----------+ | e | | + | n | | e | +------------+ | + +-----------+ +-----------+ | l | | + +------------+ | + | s | | + +------------+ V + | e | / + +------------+ / + | LabelSet |/ + +------------+ - +-----------+ - | Node | - +-----------+ - | Node | - +-----------+ - | Node | - +-----------+ + +------------+ + | Node | + +------------+ + | Node | + +------------+ + | Node | + +------------+ *******************************************************************************/ namespace ge { diff --git a/src/ge/graph/label/label_maker.cc b/src/ge/graph/label/label_maker.cc index d3701f07..0c3e0adf 100644 --- a/src/ge/graph/label/label_maker.cc +++ b/src/ge/graph/label/label_maker.cc @@ -23,12 +23,149 @@ #include "graph/debug/ge_attr_define.h" #include "graph/utils/graph_utils.h" -using domi::DATA; -using domi::LABELGOTO; -using domi::LABELSET; -using domi::LABELSWITCHBYINDEX; +namespace { +const int64_t kInvalidStreamId = -1; +} // namespace namespace ge { +/** + * @ingroup ge + * @brief Set stream id for head node. + * @param [in] graph: graph for add node. + * @param [in] op_desc: OpDesc for set logical stream id. 
+ * @return: void + */ +void LabelMaker::SetStreamIdEnter(const ComputeGraphPtr &graph, const OpDescPtr &op_desc) { + int64_t stream_id = kInvalidStreamId; + const auto &node_list = graph->GetDirectNode(); + for (size_t i = 0; i < node_list.size(); ++i) { + const auto &node = node_list.at(i); + GE_CHECK_NOTNULL_EXEC(node, continue); + + stream_id = node->GetOpDesc()->GetStreamId(); + if (stream_id != kInvalidStreamId) { + break; + } + } + + GELOGI("SetStreamId: Node %s assign stream is %ld.", op_desc->GetName().c_str(), stream_id); + op_desc->SetStreamId(stream_id); +} + +/** + * @ingroup ge + * @brief Set stream id for tail node. + * @param [in] graph: graph for add node. + * @param [in] op_desc: OpDesc for set logical stream id. + * @return: void + */ +void LabelMaker::SetStreamIdLeave(const ComputeGraphPtr &graph, const OpDescPtr &op_desc) { + int64_t stream_id = kInvalidStreamId; + const auto &node_list = graph->GetDirectNode(); + for (size_t i = node_list.size(); i > 0; --i) { + const auto &node = node_list.at(i - 1); // i from list size, need shift 1. + GE_CHECK_NOTNULL_EXEC(node, continue); + + stream_id = node->GetOpDesc()->GetStreamId(); + if (stream_id != kInvalidStreamId) { + break; + } + } + + GELOGI("SetStreamId: Node %s assign stream is %ld.", op_desc->GetName().c_str(), stream_id); + op_desc->SetStreamId(stream_id); +} + +/** + * @ingroup ge + * @brief Set stream id for parent node. + * @param [in] graph: graph for add node. + * @param [in] op_desc: OpDesc for set logical stream id. + * @return: void + */ +void LabelMaker::SetStreamIdOwner(const ComputeGraphPtr &graph, const OpDescPtr &op_desc) { + int64_t stream_id = kInvalidStreamId; + const auto &node = graph->GetParentNode(); + if (node != nullptr) { + stream_id = node->GetOpDesc()->GetStreamId(); + } + + GELOGI("SetStreamId: Node %s assign stream is %ld.", op_desc->GetName().c_str(), stream_id); + op_desc->SetStreamId(stream_id); +} + +/** + * @ingroup ge + * @brief Link Node to Graph head. 
+ * @param [in] graph: graph for add node. + * @param [in] lb_node: Node for set link to head. + * @return: SUCCESS / FAILED + */ +Status LabelMaker::AddCtrlLink2Data(const ComputeGraphPtr &graph, const NodePtr &node) { + GE_CHECK_NOTNULL(graph); + GE_CHECK_NOTNULL(node); + + std::set linked_nodes; + for (const NodePtr &n : graph->GetDirectNode()) { + GE_CHECK_NOTNULL(n); + if (n->GetType() != DATA) { + continue; + } + + // Link control edge to graph head. + for (const NodePtr &out_node : n->GetOutAllNodes()) { + if (linked_nodes.count(out_node) > 0) { + continue; + } + + (void)linked_nodes.insert(out_node); + if (GraphUtils::AddEdge(node->GetOutControlAnchor(), out_node->GetInControlAnchor()) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Add ctrl edge from %s to %s failed.", node->GetName().c_str(), + out_node->GetName().c_str()); + return FAILED; + } + } + } + + return SUCCESS; +} + +/** + * @ingroup ge + * @brief Add StreamActive node at graph front. + * @param [in] graph: graph for add node. + * @param [in] name: stream active node name. 
+ * @return: NodePtr for success / nullptr for fail + */ +NodePtr LabelMaker::AddStreamActive(const ComputeGraphPtr &graph, const std::string &name) { + GE_CHECK_NOTNULL_EXEC(graph, return nullptr); + + const auto &node_list = graph->GetDirectNode(); + if (node_list.empty()) { + GELOGE(INTERNAL_ERROR, "LabelSet: Graph %s node is empty.", graph->GetName().c_str()); + return nullptr; + } + + OpDescPtr op_desc = MakeShared(name, STREAMACTIVE); + GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); + SetStreamIdOwner(graph, op_desc); + + GELOGI("StreamActive: Create node %s.", op_desc->GetName().c_str()); + vector active_streams; + (void)AttrUtils::SetStr(op_desc, ATTR_NAME_SWITCH_BRANCH_NODE_LABEL, op_desc->GetName()); + (void)AttrUtils::SetListInt(op_desc, ATTR_NAME_ACTIVE_STREAM_LIST, active_streams); + NodePtr stream_active = graph->AddNodeFront(op_desc); + GE_CHECK_NOTNULL_EXEC(stream_active, return nullptr); + + // Link control edge to graph head. + if (AddCtrlLink2Data(graph, stream_active) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Add ctrl edge for graph %s failed.", graph->GetName().c_str()); + return nullptr; + } + + return stream_active; +} + /** * @ingroup ge * @brief Add LabelSet node at graph front. @@ -37,31 +174,29 @@ namespace ge { * @param [in] index: label id for set. 
* @return: NodePtr for success / nullptr for fail */ -NodePtr LabelMaker::AddLabelSetEnter(const ComputeGraphPtr &graph, const std::string &name, uint32_t index) { +NodePtr LabelMaker::AddLabelSetEnter(const ComputeGraphPtr &graph, const std::string &name, uint32_t index, + NodePtr &stream_active) { GE_CHECK_NOTNULL_EXEC(graph, return nullptr); - GE_CHECK_NOTNULL_EXEC(parent_node_, return nullptr); - GE_CHECK_NOTNULL_EXEC(parent_graph_, return nullptr); + GE_CHECK_NOTNULL_EXEC(stream_active, return nullptr); const auto &node_list = graph->GetDirectNode(); - auto it = node_list.begin(); - if (it == node_list.end()) { + if (node_list.empty()) { GELOGE(INTERNAL_ERROR, "LabelSet: Graph %s node is empty.", graph->GetName().c_str()); return nullptr; } - const NodePtr &node = *it; - GE_CHECK_NOTNULL_EXEC(node, return nullptr); OpDescPtr op_desc = MakeShared(name, LABELSET); GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); + SetStreamIdOwner(graph, op_desc); GELOGI("LabelSet: Create node %s.", op_desc->GetName().c_str()); (void)AttrUtils::SetInt(op_desc, ATTR_NAME_LABEL_SWITCH_INDEX, index); NodePtr label_set = graph->AddNodeFront(op_desc); GE_CHECK_NOTNULL_EXEC(label_set, return nullptr); - // Link control edge to graph head. 
- if (GraphUtils::AddEdge(label_set->GetOutControlAnchor(), node->GetInControlAnchor()) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "LabelSet: Add ctrl edge to %s failed.", node->GetName().c_str()); + if (GraphUtils::AddEdge(label_set->GetOutControlAnchor(), stream_active->GetInControlAnchor()) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Add ctrl edge from %s to %s failed.", label_set->GetName().c_str(), + stream_active->GetName().c_str()); return nullptr; } @@ -78,8 +213,6 @@ NodePtr LabelMaker::AddLabelSetEnter(const ComputeGraphPtr &graph, const std::st */ NodePtr LabelMaker::AddLabelSetLeave(const ComputeGraphPtr &graph, const std::string &name, uint32_t index) { GE_CHECK_NOTNULL_EXEC(graph, return nullptr); - GE_CHECK_NOTNULL_EXEC(parent_node_, return nullptr); - GE_CHECK_NOTNULL_EXEC(parent_graph_, return nullptr); const auto &node_list = graph->GetDirectNode(); auto it = node_list.end(); @@ -93,10 +226,11 @@ NodePtr LabelMaker::AddLabelSetLeave(const ComputeGraphPtr &graph, const std::st OpDescPtr op_desc = MakeShared(name, LABELSET); GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); + SetStreamIdOwner(graph, op_desc); GELOGI("LabelSet: Create node %s.", op_desc->GetName().c_str()); (void)AttrUtils::SetInt(op_desc, ATTR_NAME_LABEL_SWITCH_INDEX, index); - NodePtr label_set = graph->AddNodeFront(op_desc); + NodePtr label_set = graph->AddNode(op_desc); GE_CHECK_NOTNULL_EXEC(label_set, return nullptr); // Link control edge to graph tail. 
@@ -118,8 +252,6 @@ NodePtr LabelMaker::AddLabelSetLeave(const ComputeGraphPtr &graph, const std::st */ NodePtr LabelMaker::AddLabelGotoEnter(const ComputeGraphPtr &graph, const std::string &name, uint32_t index) { GE_CHECK_NOTNULL_EXEC(graph, return nullptr); - GE_CHECK_NOTNULL_EXEC(parent_node_, return nullptr); - GE_CHECK_NOTNULL_EXEC(parent_graph_, return nullptr); const auto &node_list = graph->GetDirectNode(); auto it = node_list.begin(); @@ -127,20 +259,16 @@ NodePtr LabelMaker::AddLabelGotoEnter(const ComputeGraphPtr &graph, const std::s GELOGE(INTERNAL_ERROR, "LabelGoto: Graph %s node is empty.", graph->GetName().c_str()); return nullptr; } - const NodePtr &node = *it; - GE_CHECK_NOTNULL_EXEC(node, return nullptr); - OpDescPtr op_desc = MakeShared(name, LABELGOTO); + OpDescPtr op_desc = MakeShared(name, LABELGOTOEX); GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); + SetStreamIdOwner(graph, op_desc); GELOGI("LabelGoto: Create node %s.", op_desc->GetName().c_str()); (void)AttrUtils::SetInt(op_desc, ATTR_NAME_LABEL_SWITCH_INDEX, index); NodePtr label_goto = graph->AddNodeFront(op_desc); - GE_CHECK_NOTNULL_EXEC(label_goto, return nullptr); - - // Link control edge to graph head. 
- if (GraphUtils::AddEdge(label_goto->GetOutControlAnchor(), node->GetInControlAnchor()) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "LabelGoto: Add ctrl edge to %s failed.", node->GetName().c_str()); + if (label_goto == nullptr) { + GELOGE(INTERNAL_ERROR, "LabelGoto: Add to graph %s failed.", graph->GetName().c_str()); return nullptr; } @@ -157,8 +285,6 @@ NodePtr LabelMaker::AddLabelGotoEnter(const ComputeGraphPtr &graph, const std::s */ NodePtr LabelMaker::AddLabelGotoLeave(const ComputeGraphPtr &graph, const std::string &name, uint32_t index) { GE_CHECK_NOTNULL_EXEC(graph, return nullptr); - GE_CHECK_NOTNULL_EXEC(parent_node_, return nullptr); - GE_CHECK_NOTNULL_EXEC(parent_graph_, return nullptr); const auto &node_list = graph->GetDirectNode(); auto it = node_list.end(); @@ -170,13 +296,15 @@ NodePtr LabelMaker::AddLabelGotoLeave(const ComputeGraphPtr &graph, const std::s const NodePtr &node = *it; GE_CHECK_NOTNULL_EXEC(node, return nullptr); - OpDescPtr op_desc = MakeShared(name, LABELGOTO); + OpDescPtr op_desc = MakeShared(name, LABELGOTOEX); GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); + SetStreamIdLeave(graph, op_desc); GELOGI("LabelGoto: Create node %s.", op_desc->GetName().c_str()); (void)AttrUtils::SetInt(op_desc, ATTR_NAME_LABEL_SWITCH_INDEX, index); NodePtr label_goto = graph->AddNode(op_desc); GE_CHECK_NOTNULL_EXEC(label_goto, return nullptr); + SetStreamIdOwner(graph, op_desc); // Link control edge to graph tail. 
if (GraphUtils::AddEdge(node->GetOutControlAnchor(), label_goto->GetInControlAnchor()) != SUCCESS) { @@ -199,8 +327,6 @@ NodePtr LabelMaker::AddLabelGotoLeave(const ComputeGraphPtr &graph, const std::s NodePtr LabelMaker::AddLabelSwitchEnter(const ComputeGraphPtr &graph, const std::string &name, const GeTensorDesc &desc, const std::vector &labels) { GE_CHECK_NOTNULL_EXEC(graph, return nullptr); - GE_CHECK_NOTNULL_EXEC(parent_node_, return nullptr); - GE_CHECK_NOTNULL_EXEC(parent_graph_, return nullptr); const auto &node_list = graph->GetDirectNode(); auto it = node_list.begin(); @@ -208,11 +334,10 @@ NodePtr LabelMaker::AddLabelSwitchEnter(const ComputeGraphPtr &graph, const std: GELOGE(INTERNAL_ERROR, "LabelSwitchByIndex: Graph %s node is empty.", graph->GetName().c_str()); return nullptr; } - const NodePtr &node = *it; - GE_CHECK_NOTNULL_EXEC(node, return nullptr); OpDescPtr op_desc = MakeShared(name, LABELSWITCHBYINDEX); GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); + SetStreamIdOwner(graph, op_desc); GELOGI("LabelSwitchByIndex: Create node %s.", op_desc->GetName().c_str()); if (op_desc->AddInputDesc(desc) != GRAPH_SUCCESS) { @@ -226,11 +351,8 @@ NodePtr LabelMaker::AddLabelSwitchEnter(const ComputeGraphPtr &graph, const std: } NodePtr label_switch = graph->AddNodeFront(op_desc); - GE_CHECK_NOTNULL_EXEC(label_switch, return nullptr); - - // Link control edge to graph head. 
- if (GraphUtils::AddEdge(label_switch->GetOutControlAnchor(), node->GetInControlAnchor()) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "LabelSwitchByIndex: Add ctrl edge to %s failed.", node->GetName().c_str()); + if (label_switch == nullptr) { + GELOGE(INTERNAL_ERROR, "LabelSwitchByIndex: Add to graph %s failed.", graph->GetName().c_str()); return nullptr; } @@ -249,8 +371,6 @@ NodePtr LabelMaker::AddLabelSwitchEnter(const ComputeGraphPtr &graph, const std: NodePtr LabelMaker::AddLabelSwitchLeave(const ComputeGraphPtr &graph, const std::string &name, const GeTensorDesc &desc, const std::vector &labels) { GE_CHECK_NOTNULL_EXEC(graph, return nullptr); - GE_CHECK_NOTNULL_EXEC(parent_node_, return nullptr); - GE_CHECK_NOTNULL_EXEC(parent_graph_, return nullptr); const auto &node_list = graph->GetDirectNode(); auto it = node_list.end(); @@ -264,6 +384,7 @@ NodePtr LabelMaker::AddLabelSwitchLeave(const ComputeGraphPtr &graph, const std: OpDescPtr op_desc = MakeShared(name, LABELSWITCHBYINDEX); GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); + SetStreamIdOwner(graph, op_desc); GELOGI("LabelSwitchByIndex: Create node %s.", op_desc->GetName().c_str()); if (op_desc->AddInputDesc(desc) != GRAPH_SUCCESS) { @@ -301,13 +422,16 @@ NodePtr LabelMaker::AddLabelSwitchLeave(const ComputeGraphPtr &graph, const std: NodePtr LabelMaker::AddLabelSwitchIndex(const ComputeGraphPtr &graph, const std::string &name, const GeTensorDesc &desc, const NodePtr &sw_node, uint32_t parent_index) { GE_CHECK_NOTNULL_EXEC(graph, return nullptr); - GE_CHECK_NOTNULL_EXEC(parent_node_, return nullptr); - GE_CHECK_NOTNULL_EXEC(parent_graph_, return nullptr); OpDescPtr op_desc = MakeShared(name, DATA); GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); + op_desc->SetStreamId(kInvalidStreamId); GELOGI("Data: Create node %s.", op_desc->GetName().c_str()); + if (op_desc->AddInputDesc(desc) != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "LabelSwitchByIndex: Add data input desc failed."); + return nullptr; + } if 
(op_desc->AddOutputDesc(desc) != GRAPH_SUCCESS) { GELOGE(INTERNAL_ERROR, "LabelSwitchByIndex: Add data output desc failed."); return nullptr; diff --git a/src/ge/graph/label/label_maker.h b/src/ge/graph/label/label_maker.h index d5878bc9..f77c3dc9 100644 --- a/src/ge/graph/label/label_maker.h +++ b/src/ge/graph/label/label_maker.h @@ -35,7 +35,10 @@ class LabelMaker { virtual Status Run(uint32_t &label_index) = 0; - NodePtr AddLabelSetEnter(const ComputeGraphPtr &graph, const std::string &name, uint32_t index); + NodePtr AddStreamActive(const ComputeGraphPtr &graph, const std::string &name); + + NodePtr AddLabelSetEnter(const ComputeGraphPtr &graph, const std::string &name, uint32_t index, + NodePtr &stream_active); NodePtr AddLabelSetLeave(const ComputeGraphPtr &graph, const std::string &name, uint32_t index); NodePtr AddLabelGotoEnter(const ComputeGraphPtr &graph, const std::string &name, uint32_t index); @@ -55,6 +58,12 @@ class LabelMaker { protected: NodePtr parent_node_; ComputeGraphPtr parent_graph_; + + private: + Status AddCtrlLink2Data(const ComputeGraphPtr &graph, const NodePtr &node); + void SetStreamIdEnter(const ComputeGraphPtr &graph, const OpDescPtr &op_desc); + void SetStreamIdLeave(const ComputeGraphPtr &graph, const OpDescPtr &op_desc); + void SetStreamIdOwner(const ComputeGraphPtr &graph, const OpDescPtr &op_desc); }; } // namespace ge #endif // GE_GRAPH_PASSES_LABEL_MAKER_H_ diff --git a/src/ge/graph/label/partitioned_call_label_maker.cc b/src/ge/graph/label/partitioned_call_label_maker.cc index da36431e..39c88717 100644 --- a/src/ge/graph/label/partitioned_call_label_maker.cc +++ b/src/ge/graph/label/partitioned_call_label_maker.cc @@ -22,9 +22,6 @@ #include "graph/debug/ge_attr_define.h" #include "graph/utils/graph_utils.h" -using domi::PARTITIONEDCALL; -using domi::STATEFULPARTITIONEDCALL; - namespace ge { constexpr int32_t kSubGraphIndex = 0; diff --git a/src/ge/graph/label/while_label_maker.cc b/src/ge/graph/label/while_label_maker.cc 
index e2a6ddbd..55b5dfb2 100644 --- a/src/ge/graph/label/while_label_maker.cc +++ b/src/ge/graph/label/while_label_maker.cc @@ -23,10 +23,6 @@ #include "graph/debug/ge_attr_define.h" #include "graph/utils/graph_utils.h" -using domi::_WHILE; -using domi::STATELESSWHILE; -using domi::WHILE; - namespace ge { constexpr uint8_t kCondOutputNum = 1; constexpr uint8_t kCondOutputIndex = 0; @@ -62,18 +58,32 @@ Status WhileOpLabelMaker::Run(uint32_t &label_index) { const uint32_t cond_enter_index = label_index++; const uint32_t body_enter_index = label_index++; const uint32_t body_leave_index = label_index++; - const std::string cond_enter_name = parent_node_->GetName() + "/CondLabelSet"; // rtLabelSet - const std::string cond_leave_name = parent_node_->GetName() + "/LabelSwitch"; // rtLabelSwitchByIndex - const std::string body_enter_name = parent_node_->GetName() + "/EnterLabelSet"; // rtLabelSet - const std::string goto_leave_name = parent_node_->GetName() + "/LabelGoto"; // rtLabelGoto - const std::string body_leave_name = parent_node_->GetName() + "/LeaveLabelSet"; // rtLabelSet + const std::string cond_enter_name = parent_node_->GetName() + "/CondLabelSet"; // rtLabelSet + const std::string cond_active_name = parent_node_->GetName() + "/CondStreamActive"; // rtStreamActive + const std::string cond_leave_name = parent_node_->GetName() + "/LabelSwitch"; // rtLabelSwitchByIndex + const std::string body_enter_name = parent_node_->GetName() + "/EnterLabelSet"; // rtLabelSet + const std::string body_active_name = parent_node_->GetName() + "/EnterStreamActive"; // rtStreamActive + const std::string goto_leave_name = parent_node_->GetName() + "/LabelGoto"; // rtLabelGoto + const std::string body_leave_name = parent_node_->GetName() + "/LeaveLabelSet"; // rtLabelSet + + NodePtr cond_stream_active = AddStreamActive(cond_graph, cond_active_name); + if (cond_stream_active == nullptr) { + GELOGE(INTERNAL_ERROR, "Subgraph: %s add stream active failed.", 
cond_graph->GetName().c_str()); + return FAILED; + } - if (AddLabelSetEnter(cond_graph, cond_enter_name, cond_enter_index) == nullptr) { + if (AddLabelSetEnter(cond_graph, cond_enter_name, cond_enter_index, cond_stream_active) == nullptr) { GELOGE(INTERNAL_ERROR, "Subgraph: %s add label set failed.", cond_graph->GetName().c_str()); return FAILED; } - if (AddLabelSetEnter(body_graph, body_enter_name, body_enter_index) == nullptr) { + NodePtr body_stream_active = AddStreamActive(body_graph, body_active_name); + if (body_stream_active == nullptr) { + GELOGE(INTERNAL_ERROR, "Subgraph: %s add stream active failed.", body_graph->GetName().c_str()); + return FAILED; + } + + if (AddLabelSetEnter(body_graph, body_enter_name, body_enter_index, body_stream_active) == nullptr) { GELOGE(INTERNAL_ERROR, "Subgraph: %s add label set failed.", body_graph->GetName().c_str()); return FAILED; } @@ -88,7 +98,7 @@ Status WhileOpLabelMaker::Run(uint32_t &label_index) { return FAILED; } - NodePtr cond_out_node = cond_graph->FindNode(domi::NODE_NAME_NET_OUTPUT); + NodePtr cond_out_node = cond_graph->FindNode(NODE_NAME_NET_OUTPUT); GE_CHECK_NOTNULL(cond_out_node); OpDescPtr cond_out_desc = cond_out_node->GetOpDesc(); GE_CHECK_NOTNULL(cond_out_desc); diff --git a/src/ge/graph/label/while_label_maker.h b/src/ge/graph/label/while_label_maker.h index ea7787a2..42e6a490 100644 --- a/src/ge/graph/label/while_label_maker.h +++ b/src/ge/graph/label/while_label_maker.h @@ -20,51 +20,55 @@ #include "graph/node.h" #include "graph/label/label_maker.h" /******************************************************************************* - +-----------+ - | Node | - +-----------+ - | Node | - +-----------+ - | While | - +-----------+ + +------------+ + | Node | + +------------+ + | Node | + +------------+ + | While | + +------------+ +-----------+ - | Node | +-----------+ - +-----------+ | LabelSet |\ - | Node | +-----------+ \ - +-----------+ | c | \ - | While | +-----------+ A - +-----------+ | o | | - | 
Node | +-----------+ | - +-----------+ | n | | - | Node | +-----------+ | - +-----------+ | d | | - | Node | +-----------+ | - +-----------+ /|SwitchByIdx| | - / +-----------+ | - ====> / | - | \ +-----------+ | - | \|LabelSet(1)| | - | +-----------+ | - +-----------+ +-----------+ | | b | | - | c | | b | | +-----------+ | - +-----------+ +-----------+ | | o | | - | o | | o | | +-----------+ | - +-----------+ +-----------+ | | d | | - | n | | d | | +-----------+ | - +-----------+ +-----------+ | | y | / - | d | | y | V +-----------+ / - +-----------+ +-----------+ \ | LabelGoto |/ - \ +-----------+ - \|LabelSet(0)| - +-----------+ + | Node | +------------+ + +-----------+ | LabelSet |\ + | Node | +------------+ \ + +-----------+ |StreamActive| \ + | Node | +------------+ A + +-----------+ | c | | + | While | +------------+ | + +-----------+ | o | | + | Node | +------------+ | + +-----------+ | n | | + | Node | +------------+ | + +-----------+ | d | | + | Node | +------------+ | + +-----------+ /|SwitchByIdx | | + / +------------+ | + ====> / | + | \ +------------+ | + | \|LabelSet(1) | | + | +------------+ | + | |StreamActive| | + | +------------+ | + +-----------+ +-----------+ | | b | | + | c | | b | | +------------+ | + +-----------+ +-----------+ | | o | | + | o | | o | | +------------+ | + +-----------+ +-----------+ | | d | | + | n | | d | | +------------+ | + +-----------+ +-----------+ | | y | / + | d | | y | V +------------+ / + +-----------+ +-----------+ \ | LabelGoto |/ + \ +------------+ + \|LabelSet(0) | + +------------+ - +-----------+ - | Node | - +-----------+ - | Node | - +-----------+ - | Node | - +-----------+ + +------------+ + | Node | + +------------+ + | Node | + +------------+ + | Node | + +------------+ *******************************************************************************/ namespace ge { diff --git a/src/ge/graph/load/graph_loader.cc b/src/ge/graph/load/graph_loader.cc index 5f1704af..87db7f3d 100644 --- 
a/src/ge/graph/load/graph_loader.cc +++ b/src/ge/graph/load/graph_loader.cc @@ -33,63 +33,6 @@ GraphLoader::GraphLoader() = default; GraphLoader::~GraphLoader() = default; -Status GraphLoader::LoadGraph(const std::shared_ptr &ge_model_ptr, - const std::shared_ptr &model_listener, ModelIdInfo &model_id_info) { - if (ge_model_ptr == nullptr) { - GELOGE(GE_GRAPH_PARAM_NULLPTR, "[LoadGraph] GE load graph model_ptr is nullptr."); - return GE_GRAPH_PARAM_NULLPTR; - } - - if (model_listener == nullptr) { - GELOGE(GE_GRAPH_PARAM_NULLPTR, "[LoadGraph] GE load graph model_listener is nullptr."); - return GE_GRAPH_PARAM_NULLPTR; - } - - std::shared_ptr model_ptr; - if (ModelHelper::TransGeModelToModel(ge_model_ptr, model_ptr) != SUCCESS) { - GELOGE(GE_GRAPH_PARAM_NULLPTR, "[LoadGraph] GE load graph ge_model_ptr trans to ModelPtr failed."); - return GE_GRAPH_PARAM_NULLPTR; - } - GELOGI("[LoadGraph] GE load graph via new ome begin."); - Status ret = LoadModelOnline(model_id_info.model_id, model_ptr, model_listener); - if (ret != SUCCESS) { - GELOGE(ret, "[LoadGraph] GE load graph LoadGraph() return fail. err: %u", ret); - return ret; - } - GELOGI("[LoadGraph] GE load graph success. 
modelId: %u", model_id_info.model_id); - return ret; -} - -Status GraphLoader::LoadGraphAsync(const std::shared_ptr &ge_model_ptr, - const std::shared_ptr &model_async_listener, - ModelIdInfo &model_id_info) { - if (ge_model_ptr == nullptr) { - GELOGE(GE_GRAPH_PARAM_NULLPTR, "[LoadGraphAsync] GE load graph model_ptr is nullptr."); - return GE_GRAPH_PARAM_NULLPTR; - } - - if (model_async_listener == nullptr) { - GELOGE(GE_GRAPH_PARAM_NULLPTR, "[LoadGraphAsync] GE load graph model_listener is nullptr."); - return GE_GRAPH_PARAM_NULLPTR; - } - - std::shared_ptr model_ptr; - if (ModelHelper::TransGeModelToModel(ge_model_ptr, model_ptr) != SUCCESS) { - GELOGE(GE_GRAPH_PARAM_NULLPTR, "[LoadGraph] GE load graph ge_model_ptr trans to ModelPtr failed."); - return GE_GRAPH_PARAM_NULLPTR; - } - - GELOGI("[LoadGraphAsync] GE load graph begin."); - Status ret = LoadModelOnline(model_id_info.model_id, model_ptr, model_async_listener); - if (ret != SUCCESS) { - GELOGE(ret, "[LoadGraphAsync] GE load graph LoadGraphAsync() return fail. err: %u", ret); - return ret; - } - - GELOGI("[LoadGraphAsync] GE load graph success. 
modelId: %u", model_id_info.model_id); - return ret; -} - Status GraphLoader::UnloadModel(uint32_t model_id) { auto model_manager = ModelManager::GetInstance(); GE_CHECK_NOTNULL(model_manager); @@ -110,73 +53,56 @@ Status GraphLoader::UnloadModel(uint32_t model_id) { return SUCCESS; } -Status GraphLoader::LoadModelOnline(uint32_t &model_id, std::shared_ptr &model, +Status GraphLoader::LoadModelOnline(uint32_t &model_id, const std::shared_ptr &ge_model_ptr, const std::shared_ptr &listener) { + GELOGI("Load model online begin."); rtError_t rt_ret = rtSetDevice(GetContext().DeviceId()); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); CsaInteract::GetInstance().WriteErrorCode(rt_ret, ERROR_MODULE_RUNTIME, JOBSUBSTATE_GRAPH_LOAD); return RT_FAILED; } + if (ge_model_ptr == nullptr) { + GELOGE(GE_GRAPH_PARAM_NULLPTR, "[LoadGraph] GE load graph model_ptr is nullptr."); + return GE_GRAPH_PARAM_NULLPTR; + } + model_id = ge_model_ptr->GetModelId(); - try { - GELOGI("Load begin, model_id:%u.", model_id); - auto model_manager = ModelManager::GetInstance(); - GE_CHECK_NOTNULL(model_manager); - Status ret = model_manager->LoadModelOnline(model_id, model, listener); - if (ret != SUCCESS) { - GELOGE(ret, "LoadModel: Load failed. 
ret = %u", ret); - CsaInteract::GetInstance().WriteErrorCode(ret, ERROR_MODULE_FMK, JOBSUBSTATE_GRAPH_LOAD); - - rt_ret = rtDeviceReset(GetContext().DeviceId()); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - } - return ret; - } - - ret = model_manager->Start(model_id); - if (ret != SUCCESS) { - if (model_manager->Unload(model_id) != SUCCESS) { - GELOGE(ret, "LoadModel: Unload failed while trying to unload after a failed start."); - } - - rt_ret = rtDeviceReset(GetContext().DeviceId()); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - } - - GELOGE(ret, "LoadModel: Start failed."); - CsaInteract::GetInstance().WriteErrorCode(ret, ERROR_MODULE_FMK, JOBSUBSTATE_GRAPH_EXEC); - return ret; - } + auto model_manager = ModelManager::GetInstance(); + GE_CHECK_NOTNULL(model_manager); + Status ret = model_manager->LoadModelOnline(model_id, ge_model_ptr, listener); + if (ret != SUCCESS) { + GELOGE(ret, "LoadModel: Load failed. ret = %u", ret); + CsaInteract::GetInstance().WriteErrorCode(ret, ERROR_MODULE_FMK, JOBSUBSTATE_GRAPH_LOAD); - GELOGI("Load model success, model_id:%u.", model_id); - } catch (std::bad_alloc &) { rt_ret = rtDeviceReset(GetContext().DeviceId()); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); } + return ret; + } + + ret = model_manager->Start(model_id); + if (ret != SUCCESS) { + if (model_manager->Unload(model_id) != SUCCESS) { + GELOGE(ret, "LoadModel: Unload failed while trying to unload after a failed start."); + } - GELOGE(MEMALLOC_FAILED, "Load model failed, bad memory allocation occur !"); - CsaInteract::GetInstance().WriteErrorCode(FAILED, ERROR_MODULE_FMK, JOBSUBSTATE_GRAPH_LOAD); - return MEMALLOC_FAILED; - } catch (...) 
{ rt_ret = rtDeviceReset(GetContext().DeviceId()); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); } - GELOGE(FAILED, "Load failed, some exceptions occur !"); - CsaInteract::GetInstance().WriteErrorCode(FAILED, ERROR_MODULE_FMK, JOBSUBSTATE_GRAPH_LOAD); - return FAILED; + GELOGE(ret, "LoadModel: Start failed."); + CsaInteract::GetInstance().WriteErrorCode(ret, ERROR_MODULE_FMK, JOBSUBSTATE_GRAPH_EXEC); + return ret; } - rt_ret = rtDeviceReset(GetContext().DeviceId()); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); return RT_FAILED; } + GELOGI("Load model online success, model_id:%u.", model_id); return SUCCESS; } @@ -196,13 +122,13 @@ Status GraphLoader::LoadDataFromFile(const std::string &path, const std::string ModelData &model_data) { Status ret; try { - if (!domi::CheckInputPathValid(path)) { + if (!CheckInputPathValid(path)) { GELOGE(PARAM_INVALID, "model path is invalid: %s", path.c_str()); return PARAM_INVALID; } GELOGI("Load model begin, model path is: %s", path.c_str()); - if (!key_path.empty() && !domi::CheckInputPathValid(key_path)) { + if (!key_path.empty() && !CheckInputPathValid(key_path)) { GELOGE(PARAM_INVALID, "decrypt_key path is invalid: %s", key_path.c_str()); return PARAM_INVALID; } @@ -439,4 +365,15 @@ Status GraphLoader::DestroyAicpuKernel(uint64_t session_id, uint32_t model_id) { } return SUCCESS; } + +Status GraphLoader::DestroyAicpuSessionForInfer(uint32_t model_id) { + auto model_manager = ModelManager::GetInstance(); + GE_CHECK_NOTNULL(model_manager); + Status ret = model_manager->DestroyAicpuSessionForInfer(model_id); + if (ret != SUCCESS) { + GELOGE(ret, "Destroy aicpu serrion for infer failed."); + return ret; + } + return SUCCESS; +} } // namespace ge diff --git a/src/ge/graph/load/graph_loader.h b/src/ge/graph/load/graph_loader.h index 946e39ec..5fe37a36 100644 --- a/src/ge/graph/load/graph_loader.h +++ b/src/ge/graph/load/graph_loader.h @@ -40,12 
+40,6 @@ class GraphLoader { GraphLoader &operator=(const GraphLoader &in) = delete; - Status LoadGraph(const std::shared_ptr &ge_model_ptr, - const std::shared_ptr &model_listener, ModelIdInfo &model_id_info); - - Status LoadGraphAsync(const std::shared_ptr &ge_model_ptr, - const std::shared_ptr &model_async_listener, ModelIdInfo &model_id_info); - static Status UnloadModel(uint32_t model_id); static Status GetMaxUsedMemory(uint32_t model_id, uint64_t &max_size); @@ -75,8 +69,9 @@ class GraphLoader { static Status DestroyAicpuKernel(uint64_t session_id, uint32_t model_id); - private: - static Status LoadModelOnline(uint32_t &model_id, std::shared_ptr &model, + static Status DestroyAicpuSessionForInfer(uint32_t model_id); + + static Status LoadModelOnline(uint32_t &model_id, const std::shared_ptr &model, const std::shared_ptr &listener); }; } // namespace ge diff --git a/src/ge/graph/load/new_model_manager/cpu_queue_schedule.cc b/src/ge/graph/load/new_model_manager/cpu_queue_schedule.cc index c3de44c9..06111015 100644 --- a/src/ge/graph/load/new_model_manager/cpu_queue_schedule.cc +++ b/src/ge/graph/load/new_model_manager/cpu_queue_schedule.cc @@ -20,11 +20,11 @@ namespace { const uint32_t kCoreDim = 1; // for rtCpuKernelLaunch const char *const kCpuTaskModelEnqueue = "modelEnqueue"; -const char *const kCpuTaskPrepareInput = "modelPrepareInput"; const char *const kCpuTaskWaitEndGraph = "modelWaitEndGraph"; -const char *const kCpuTaskPrepareOutput = "modelPrepareOutput"; +const char *const kCpuTaskPrepareOutput = "bufferPrepareOutput"; const char *const kCpuTaskModelDequeue = "modelDequeue"; const char *const kCpuTaskModelRepeat = "modelRepeat"; +const char *const kCpuTaskZeroCopy = "zeroCpy"; } // namespace namespace ge { @@ -93,19 +93,19 @@ Status CpuTaskModelDequeue::Distribute() { /// /// @ingroup ge -/// @brief definiteness queue schedule, bind output queue to task. -/// @param [in] addr: NetOutput Op input tensor address. 
-/// @param [in] size: NetOutput Op input tensor size. -/// @param [in] in_mbuf: input mbuf addr for input data. +/// @brief definiteness queue schedule, zero copy. +/// @param [in] mbuf_list: input/output mbuf addr list for input/output data. +/// @param [in] outside_addrs: model input/output memory addr /// @return: 0 for success / others for failed /// -Status CpuTaskPrepareInput::Init(uintptr_t addr, uint32_t size, uintptr_t in_mbuf) { +Status CpuTaskZeroCopy::Init(std::vector &mbuf_list, + std::map> &outside_addrs) { if ((args_ != nullptr) || (args_size_ > 0)) { GELOGE(FAILED, "Task already initialized, size: %u", args_size_); return FAILED; } - args_size_ = sizeof(PrepareInputInfo); + args_size_ = sizeof(AddrMapInfo); rtError_t status = rtMalloc(&args_, args_size_, RT_MEMORY_HBM); if (status != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt malloc failed, status: 0x%x", status); @@ -113,36 +113,99 @@ Status CpuTaskPrepareInput::Init(uintptr_t addr, uint32_t size, uintptr_t in_mbu } GE_PRINT_DYNAMIC_MEMORY(rtMalloc, "args data.", args_size_) - PrepareInputInfo prepare; - prepare.in_mbuf = in_mbuf; - prepare.mbuf_offset = 0; - prepare.data_size = size; - prepare.data_addr = addr; - status = rtMemcpy(args_, args_size_, &prepare, args_size_, RT_MEMCPY_HOST_TO_DEVICE); + AddrMapInfo addr_map_info; + for (const auto &addrs : outside_addrs) { + addr_map_info.addr_num += addrs.second.size(); + } + GELOGI("addr_map_info.addr_num is %u", addr_map_info.addr_num); + + // init src_addrs/dst_addrs + size_t index = 0; + vector src_addrs; + vector dst_addrs; + for (const auto &addrs : outside_addrs) { + for (size_t i = 0; i < addrs.second.size(); ++i) { + src_addrs.push_back(mbuf_list.at(index)); + dst_addrs.push_back(reinterpret_cast(reinterpret_cast(addrs.second.at(i)))); + } + index++; + } + + // malloc mem for src_addrs/dst_addrs, and copy data of src_addrs/dst_addrs + status = rtMalloc(&src_addr_, src_addrs.size() * sizeof(uint64_t), RT_MEMORY_HBM); + if (status != 
RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt malloc failed, status: 0x%x", status); + return RT_FAILED; + } + status = rtMemcpy(src_addr_, src_addrs.size() * sizeof(uint64_t), src_addrs.data(), + src_addrs.size() * sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE); if (status != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt memcpy failed, status: 0x%x", status); return RT_FAILED; } + status = rtMalloc(&dst_addr_, dst_addrs.size() * sizeof(uint64_t), RT_MEMORY_HBM); + if (status != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt malloc failed, status: 0x%x", status); + return RT_FAILED; + } + status = rtMemcpy(dst_addr_, dst_addrs.size() * sizeof(uint64_t), dst_addrs.data(), + dst_addrs.size() * sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE); + if (status != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt memcpy failed, status: 0x%x", status); + return RT_FAILED; + } + + // src_addr_list is init to src_addr, which is the point to src_addrs + if (!src_addrs.empty() && !dst_addrs.empty()) { + addr_map_info.src_addr_list = static_cast(reinterpret_cast(src_addr_)); + addr_map_info.dst_addr_list = static_cast(reinterpret_cast(dst_addr_)); + GELOGI("src_addr_list is %lu, dst_addr_list is %lu", addr_map_info.src_addr_list, addr_map_info.dst_addr_list); + } + + status = rtMemcpy(args_, args_size_, &addr_map_info, sizeof(AddrMapInfo), RT_MEMCPY_HOST_TO_DEVICE); + if (status != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt memcpy failed, status: 0x%x", status); + return RT_FAILED; + } return SUCCESS; } -Status CpuTaskPrepareInput::Distribute() { +Status CpuTaskZeroCopy::Distribute() { if ((args_ == nullptr) || (args_size_ == 0) || (stream_ == nullptr)) { GELOGE(FAILED, "Task not initialized, distribute failed, size: %u", args_size_); return FAILED; } - rtError_t status = rtCpuKernelLaunch(nullptr, kCpuTaskPrepareInput, kCoreDim, args_, args_size_, nullptr, stream_); + rtError_t status = rtCpuKernelLaunch(nullptr, kCpuTaskZeroCopy, kCoreDim, args_, args_size_, nullptr, stream_); if (status 
!= RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt CpuKernelLaunch PrepareInput failed, status: 0x%X", status); + GELOGE(RT_FAILED, "Call rt CpuKernelLaunch ZeroCopy failed, status: 0x%X", status); return RT_FAILED; } - GELOGI("Cpu kernel launch prepare input task success."); + GELOGI("Cpu kernel launch zero copy task success."); return SUCCESS; } +CpuTaskZeroCopy::~CpuTaskZeroCopy() { + if (src_addr_ == nullptr && dst_addr_ == nullptr) { + return; + } + if (src_addr_ != nullptr) { + rtError_t status = rtFree(src_addr_); + if (status != RT_ERROR_NONE) { + GELOGW("Call rt free failed, status: 0x%x", status); + } + } + if (dst_addr_ != nullptr) { + rtError_t status = rtFree(dst_addr_); + if (status != RT_ERROR_NONE) { + GELOGW("Call rt free failed, status: 0x%x", status); + } + } + src_addr_ = nullptr; + dst_addr_ = nullptr; +} /// /// @ingroup ge /// @brief definiteness queue schedule, bind output queue to task. diff --git a/src/ge/graph/load/new_model_manager/cpu_queue_schedule.h b/src/ge/graph/load/new_model_manager/cpu_queue_schedule.h index 8a9af63f..c4ae4df5 100644 --- a/src/ge/graph/load/new_model_manager/cpu_queue_schedule.h +++ b/src/ge/graph/load/new_model_manager/cpu_queue_schedule.h @@ -47,6 +47,13 @@ struct PrepareOutputInfo { uintptr_t out_mbuf; // output mbuf addr }; +// For AICPU task "modelZeroCopy" +struct AddrMapInfo { + uint32_t addr_num = 0; + uint64_t src_addr_list; + uint64_t dst_addr_list; +}; + /// /// @ingroup ge /// @brief CpuTask base, inherit from TaskInfo used for manage. @@ -78,17 +85,21 @@ class CpuTaskModelDequeue : public CpuTaskInfo { /// /// @ingroup ge -/// @brief definiteness queue schedule, bind output queue to task. +/// @brief definiteness queue schedule, zero copy. 
/// -class CpuTaskPrepareInput : public CpuTaskInfo { +class CpuTaskZeroCopy : public CpuTaskInfo { public: - explicit CpuTaskPrepareInput(rtStream_t stream) : CpuTaskInfo(stream) {} - ~CpuTaskPrepareInput() override {} + explicit CpuTaskZeroCopy(rtStream_t stream) : CpuTaskInfo(stream) {} + ~CpuTaskZeroCopy() override; Status Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) override { return SUCCESS; } - Status Init(uintptr_t addr, uint32_t size, uintptr_t in_mbuf); + Status Init(std::vector &mbuf_list, std::map> &outside_addrs); Status Distribute() override; + + private: + void *src_addr_ = nullptr; + void *dst_addr_ = nullptr; }; /// diff --git a/src/ge/graph/load/new_model_manager/data_dumper.cc b/src/ge/graph/load/new_model_manager/data_dumper.cc index 85bbd5bc..db675132 100644 --- a/src/ge/graph/load/new_model_manager/data_dumper.cc +++ b/src/ge/graph/load/new_model_manager/data_dumper.cc @@ -15,22 +15,26 @@ */ #include "graph/load/new_model_manager/data_dumper.h" +#include +#include +#include #include "common/properties_manager.h" #include "framework/common/debug/ge_log.h" #include "framework/common/util.h" #include "graph/anchor.h" #include "graph/debug/ge_attr_define.h" +#include "graph/load/new_model_manager/model_utils.h" #include "graph/utils/attr_utils.h" -#include "model_utils.h" #include "proto/ge_ir.pb.h" #include "proto/op_mapping_info.pb.h" #include "runtime/mem.h" -using domi::ENDGRAPH; - namespace { const uint32_t kAicpuLoadFlag = 1; const uint32_t kAicpuUnloadFlag = 0; +const char *const kDumpOutput = "output"; +const char *const kDumpInput = "input"; +const char *const kDumpAll = "all"; } // namespace static int32_t GetIrDataType(ge::DataType data_type) { @@ -122,14 +126,20 @@ void DataDumper::SaveDumpInput(const std::shared_ptr &node) { } } -void DataDumper::SaveDumpTask(uint32_t task_id, const std::shared_ptr &op_desc, uintptr_t args) { +void DataDumper::SaveEndGraphId(uint32_t task_id, uint32_t stream_id) { + 
end_graph_task_id_ = task_id; + end_graph_stream_id_ = stream_id; +} + +void DataDumper::SaveDumpTask(uint32_t task_id, uint32_t stream_id, const std::shared_ptr &op_desc, + uintptr_t args) { if (op_desc == nullptr) { GELOGE(PARAM_INVALID, "Opdesc is nullptr"); return; } - GELOGI("Save dump task %s, id: %u.", op_desc->GetName().c_str(), task_id); - op_list_.push_back({task_id, op_desc, args, true}); + GELOGI("Save dump task %s, task id: %u, stream id: %u", op_desc->GetName().c_str(), task_id, stream_id); + op_list_.push_back({task_id, stream_id, op_desc, args, true}); for (auto iter = input_map_.equal_range(op_desc->GetName()); iter.first != iter.second; ++iter.first) { InnerInputMapping &inner_input_mapping = iter.first->second; @@ -149,7 +159,7 @@ void DataDumper::SaveDumpTask(uint32_t task_id, const std::shared_ptr &o uintptr_t data_addr = args - sizeof(void *) * op_desc->GetInputOffset().size() + sizeof(void *) * static_cast(inner_input_mapping.input_anchor_index); GELOGI("Save input dump task %s, id: %u.", data_op->GetName().c_str(), task_id); - op_list_.push_back({task_id, data_op, data_addr, false, inner_input_mapping.input_anchor_index, + op_list_.push_back({task_id, stream_id, data_op, data_addr, false, inner_input_mapping.input_anchor_index, inner_input_mapping.output_anchor_index, input_tensor->GetShape().GetDims()}); } } @@ -178,99 +188,107 @@ static void SetOpMappingLoopAddr(uintptr_t step_id, uintptr_t loop_per_iter, uin } } -Status DataDumper::LoadDumpInfo() { - PrintCheckLog(); +Status DataDumper::DumpOutput(const InnerDumpInfo &inner_dump_info, aicpu::dump::Task &task) { + GELOGI("Start dump output"); + if (inner_dump_info.is_task) { + // tbe or aicpu op + const auto &output_descs = inner_dump_info.op->GetAllOutputsDesc(); + const std::vector output_addrs = ModelUtils::GetOutputDataAddrs(runtime_param_, inner_dump_info.op, false); + if (output_descs.size() != output_addrs.size()) { + GELOGE(PARAM_INVALID, "Invalid output desc addrs size %zu, op %s 
has %zu output desc.", output_addrs.size(), + inner_dump_info.op->GetName().c_str(), output_descs.size()); + return PARAM_INVALID; + } - if (op_list_.empty()) { + for (size_t i = 0; i < output_descs.size(); ++i) { + aicpu::dump::Output output; + output.set_data_type(static_cast(GetIrDataType(output_descs.at(i).GetDataType()))); + output.set_format(static_cast(output_descs.at(i).GetFormat())); + + for (auto dim : output_descs.at(i).GetShape().GetDims()) { + output.mutable_shape()->add_dim(dim); + } + + std::string origin_name; + int32_t origin_output_index = -1; + (void)AttrUtils::GetStr(&output_descs.at(i), ATTR_NAME_DATA_DUMP_ORIGIN_NAME, origin_name); + (void)AttrUtils::GetInt(&output_descs.at(i), ATTR_NAME_DATA_DUMP_ORIGIN_OUTPUT_INDEX, origin_output_index); + output.set_original_name(origin_name); + output.set_original_output_index(origin_output_index); + output.set_original_output_format(static_cast(output_descs.at(i).GetOriginFormat())); + output.set_original_output_data_type(static_cast(output_descs.at(i).GetOriginDataType())); + // due to lhisi virtual addr bug, cannot use args now + output.set_address(static_cast(reinterpret_cast(output_addrs[i]))); + + task.mutable_output()->Add(std::move(output)); + } return SUCCESS; } - aicpu::dump::OpMappingInfo op_mapping_info; - op_mapping_info.set_dump_path(PropertiesManager::Instance().GetDumpOutputPath() + std::to_string(device_id_) + "/"); - op_mapping_info.set_model_name(model_name_); - op_mapping_info.set_model_id(model_id_); - op_mapping_info.set_flag(kAicpuLoadFlag); - op_mapping_info.set_dump_step(PropertiesManager::Instance().GetDumpStep()); - SetOpMappingLoopAddr(global_step_, loop_per_iter_, loop_cond_, op_mapping_info); - GELOGD("Dump step in load dump info is %s", PropertiesManager::Instance().GetDumpStep().c_str()); + // else data, const or variable op + aicpu::dump::Output output; + auto output_tensor = inner_dump_info.op->GetOutputDescPtr(inner_dump_info.output_anchor_index); + const std::vector 
output_addrs = ModelUtils::GetOutputDataAddrs(runtime_param_, inner_dump_info.op, false); + if (output_tensor == nullptr) { + GELOGE(PARAM_INVALID, "output_tensor is null, index: %d, size: %zu.", inner_dump_info.output_anchor_index, + inner_dump_info.op->GetOutputsSize()); + return PARAM_INVALID; + } - for (const auto &op_iter : op_list_) { - aicpu::dump::Task task; - auto op_desc = op_iter.op; - task.set_end_graph(op_desc->GetType() == ENDGRAPH); - task.set_task_id(op_iter.task_id); - task.mutable_op()->set_op_name(op_desc->GetName()); - task.mutable_op()->set_op_type(op_desc->GetType()); + output.set_data_type(static_cast(GetIrDataType(output_tensor->GetDataType()))); + output.set_format(static_cast(output_tensor->GetFormat())); - if (op_iter.is_task) { - // tbe or aicpu op - const auto &output_descs = op_iter.op->GetAllOutputsDesc(); - const std::vector output_addrs = ModelUtils::GetOutputDataAddrs(runtime_param_, op_iter.op, false); - if (output_descs.size() != output_addrs.size()) { - GELOGE(PARAM_INVALID, "Invalid output desc addrs size %zu, op %s has %zu output desc.", output_addrs.size(), - op_iter.op->GetName().c_str(), output_descs.size()); - return PARAM_INVALID; - } + for (auto dim : inner_dump_info.dims) { + output.mutable_shape()->add_dim(dim); + } - for (size_t i = 0; i < output_descs.size(); ++i) { - aicpu::dump::Output output; - output.set_data_type(static_cast(GetIrDataType(output_descs.at(i).GetDataType()))); - output.set_format(static_cast(output_descs.at(i).GetFormat())); + std::string origin_name; + int32_t origin_output_index = -1; + (void)AttrUtils::GetStr(output_tensor, ATTR_NAME_DATA_DUMP_ORIGIN_NAME, origin_name); + (void)AttrUtils::GetInt(output_tensor, ATTR_NAME_DATA_DUMP_ORIGIN_OUTPUT_INDEX, origin_output_index); + output.set_original_name(origin_name); + output.set_original_output_index(origin_output_index); + output.set_original_output_format(static_cast(output_tensor->GetOriginFormat())); + 
output.set_original_output_data_type(static_cast(output_tensor->GetOriginDataType())); + // due to lhisi virtual addr bug, cannot use args now + if (inner_dump_info.output_anchor_index >= static_cast(output_addrs.size())) { + GELOGE(FAILED, "Index is out of range."); + return FAILED; + } + output.set_address( + static_cast(reinterpret_cast(output_addrs[inner_dump_info.output_anchor_index]))); - for (auto dim : output_descs.at(i).GetShape().GetDims()) { - output.mutable_shape()->add_dim(dim); - } + task.mutable_output()->Add(std::move(output)); - std::string origin_name; - int32_t origin_output_index = -1; - (void)AttrUtils::GetStr(&output_descs.at(i), ATTR_NAME_DATA_DUMP_ORIGIN_NAME, origin_name); - (void)AttrUtils::GetInt(&output_descs.at(i), ATTR_NAME_DATA_DUMP_ORIGIN_OUTPUT_INDEX, origin_output_index); - output.set_original_name(origin_name); - output.set_original_output_index(origin_output_index); - output.set_original_output_format(static_cast(output_descs.at(i).GetOriginFormat())); - output.set_original_output_data_type(static_cast(output_descs.at(i).GetOriginDataType())); - // due to lhisi virtual addr bug, cannot use args now - output.set_address(static_cast(reinterpret_cast(output_addrs[i]))); - - task.mutable_output()->Add(std::move(output)); - } - op_mapping_info.mutable_task()->Add(std::move(task)); - continue; - } + return SUCCESS; +} - // else data, const or variable op - aicpu::dump::Output output; - auto output_tensor = op_iter.op->GetOutputDescPtr(op_iter.output_anchor_index); - const std::vector output_addrs = ModelUtils::GetOutputDataAddrs(runtime_param_, op_iter.op, false); - if (output_tensor == nullptr) { - GELOGE(PARAM_INVALID, "output_tensor is null, index: %d, size: %zu.", op_iter.output_anchor_index, - op_iter.op->GetOutputsSize()); - return PARAM_INVALID; - } +Status DataDumper::DumpInput(const InnerDumpInfo &inner_dump_info, aicpu::dump::Task &task) { + GELOGI("Start dump input"); + const auto &input_descs = 
inner_dump_info.op->GetAllInputsDesc(); + const std::vector input_addrs = ModelUtils::GetInputDataAddrs(runtime_param_, inner_dump_info.op, false); + if (input_descs.size() != input_addrs.size()) { + GELOGE(PARAM_INVALID, "Invalid input desc addrs size %zu, op %s has %zu input desc.", input_addrs.size(), + inner_dump_info.op->GetName().c_str(), input_descs.size()); + return PARAM_INVALID; + } - output.set_data_type(static_cast(GetIrDataType(output_tensor->GetDataType()))); - output.set_format(static_cast(output_tensor->GetFormat())); + for (size_t i = 0; i < input_descs.size(); ++i) { + aicpu::dump::Input input; + input.set_data_type(static_cast(GetIrDataType(input_descs.at(i).GetDataType()))); + input.set_format(static_cast(input_descs.at(i).GetFormat())); - for (auto dim : op_iter.dims) { - output.mutable_shape()->add_dim(dim); + for (auto dim : input_descs.at(i).GetShape().GetDims()) { + input.mutable_shape()->add_dim(dim); } - std::string origin_name; - int32_t origin_output_index = -1; - (void)AttrUtils::GetStr(output_tensor, ATTR_NAME_DATA_DUMP_ORIGIN_NAME, origin_name); - (void)AttrUtils::GetInt(output_tensor, ATTR_NAME_DATA_DUMP_ORIGIN_OUTPUT_INDEX, origin_output_index); - output.set_original_name(origin_name); - output.set_original_output_index(origin_output_index); - output.set_original_output_format(static_cast(output_tensor->GetOriginFormat())); - output.set_original_output_data_type(static_cast(output_tensor->GetOriginDataType())); - // due to lhisi virtual addr bug, cannot use args now - output.set_address(static_cast(reinterpret_cast(output_addrs[op_iter.output_anchor_index]))); - - task.mutable_output()->Add(std::move(output)); - - op_mapping_info.mutable_task()->Add(std::move(task)); + input.set_address(static_cast(reinterpret_cast(input_addrs[i]))); + task.mutable_input()->Add(std::move(input)); } + return SUCCESS; +} +Status DataDumper::ExecuteLoadDumpInfo(aicpu::dump::OpMappingInfo &op_mapping_info) { std::string proto_str; size_t proto_size = 
op_mapping_info.ByteSizeLong(); bool ret = op_mapping_info.SerializeToString(&proto_str); @@ -308,23 +326,7 @@ Status DataDumper::LoadDumpInfo() { return SUCCESS; } -Status DataDumper::UnloadDumpInfo() { - if (!load_flag_) { - GELOGI("No need to UnloadDumpInfo."); - load_flag_ = false; - return SUCCESS; - } - - GELOGI("UnloadDumpInfo start."); - aicpu::dump::OpMappingInfo op_mapping_info; - op_mapping_info.set_model_id(model_id_); - op_mapping_info.set_flag(kAicpuUnloadFlag); - - for (const auto &op_iter : op_list_) { - aicpu::dump::Task task; - task.set_task_id(op_iter.task_id); - op_mapping_info.mutable_task()->Add(std::move(task)); - } +Status DataDumper::ExecuteUnLoadDumpInfo(aicpu::dump::OpMappingInfo &op_mapping_info) { std::string proto_str; size_t proto_size = op_mapping_info.ByteSizeLong(); bool ret = op_mapping_info.SerializeToString(&proto_str); @@ -360,6 +362,117 @@ Status DataDumper::UnloadDumpInfo() { GELOGI("UnloadDumpInfo success, proto size: %zu.", proto_size); return SUCCESS; } +Status DataDumper::LoadDumpInfo() { + PrintCheckLog(); + + if (op_list_.empty()) { + return SUCCESS; + } + + aicpu::dump::OpMappingInfo op_mapping_info; + op_mapping_info.set_dump_path(PropertiesManager::Instance().GetDumpOutputPath() + std::to_string(device_id_) + "/"); + op_mapping_info.set_model_name(model_name_); + op_mapping_info.set_model_id(model_id_); + op_mapping_info.set_flag(kAicpuLoadFlag); + op_mapping_info.set_dump_step(PropertiesManager::Instance().GetDumpStep()); + SetOpMappingLoopAddr(global_step_, loop_per_iter_, loop_cond_, op_mapping_info); + GELOGD("Dump step in load dump info is %s", PropertiesManager::Instance().GetDumpStep().c_str()); + + for (const auto &op_iter : op_list_) { + aicpu::dump::Task task; + auto op_desc = op_iter.op; + task.set_end_graph(false); + task.set_task_id(op_iter.task_id); + task.set_stream_id(op_iter.stream_id); + task.mutable_op()->set_op_name(op_desc->GetName()); + task.mutable_op()->set_op_type(op_desc->GetType()); + + if 
(PropertiesManager::Instance().GetDumpMode() == kDumpOutput) { + if (DumpOutput(op_iter, task) != SUCCESS) { + GELOGE(FAILED, "Dump output failed"); + return FAILED; + } + op_mapping_info.mutable_task()->Add(std::move(task)); + continue; + } + if (PropertiesManager::Instance().GetDumpMode() == kDumpInput) { + if (op_iter.is_task) { + if (DumpInput(op_iter, task) != SUCCESS) { + GELOGE(FAILED, "Dump input failed"); + return FAILED; + } + } + op_mapping_info.mutable_task()->Add(std::move(task)); + continue; + } + if (PropertiesManager::Instance().GetDumpMode() == kDumpAll) { + auto ret = DumpOutput(op_iter, task); + if (ret != SUCCESS) { + GELOGE(FAILED, "Dump output failed when in dumping all"); + return FAILED; + } + if (op_iter.is_task) { + ret = DumpInput(op_iter, task); + if (ret != SUCCESS) { + GELOGE(FAILED, "Dump input failed when in dumping all"); + return FAILED; + } + } + op_mapping_info.mutable_task()->Add(std::move(task)); + continue; + } + } + + SetEndGraphIdToAicpu(end_graph_task_id_, end_graph_stream_id_, op_mapping_info); + + auto ret = ExecuteLoadDumpInfo(op_mapping_info); + if (ret != SUCCESS) { + GELOGE(FAILED, "Execute load dump info failed"); + return FAILED; + } + return SUCCESS; +} + +void DataDumper::SetEndGraphIdToAicpu(uint32_t task_id, uint32_t stream_id, + aicpu::dump::OpMappingInfo &op_mapping_info) { + if (PropertiesManager::Instance().GetDumpMode() == kDumpOutput || + PropertiesManager::Instance().GetDumpMode() == kDumpInput || + PropertiesManager::Instance().GetDumpMode() == kDumpAll) { + GELOGI("add end_graph_info to aicpu, task_id is %u, stream_id is %u", end_graph_task_id_, end_graph_stream_id_); + aicpu::dump::Task task; + task.set_end_graph(true); + task.set_task_id(end_graph_task_id_); + task.set_stream_id(end_graph_stream_id_); + task.mutable_op()->set_op_name(NODE_NAME_END_GRAPH); + task.mutable_op()->set_op_type(ENDGRAPH); + op_mapping_info.mutable_task()->Add(std::move(task)); + } +} + +Status DataDumper::UnloadDumpInfo() { 
+ if (!load_flag_) { + GELOGI("No need to UnloadDumpInfo."); + load_flag_ = false; + return SUCCESS; + } + + GELOGI("UnloadDumpInfo start."); + aicpu::dump::OpMappingInfo op_mapping_info; + op_mapping_info.set_model_id(model_id_); + op_mapping_info.set_flag(kAicpuUnloadFlag); + + for (const auto &op_iter : op_list_) { + aicpu::dump::Task task; + task.set_task_id(op_iter.task_id); + op_mapping_info.mutable_task()->Add(std::move(task)); + } + auto ret = ExecuteUnLoadDumpInfo(op_mapping_info); + if (ret != SUCCESS) { + GELOGE(FAILED, "Execute unload dump info failed"); + return FAILED; + } + return SUCCESS; +} void DataDumper::PrintCheckLog() { std::set model_list = PropertiesManager::Instance().GetAllDumpModel(); diff --git a/src/ge/graph/load/new_model_manager/data_dumper.h b/src/ge/graph/load/new_model_manager/data_dumper.h index 823f7079..efcc989a 100644 --- a/src/ge/graph/load/new_model_manager/data_dumper.h +++ b/src/ge/graph/load/new_model_manager/data_dumper.h @@ -17,11 +17,16 @@ #ifndef GE_GRAPH_LOAD_NEW_MODEL_MANAGER_DATA_DUMPER_H_ #define GE_GRAPH_LOAD_NEW_MODEL_MANAGER_DATA_DUMPER_H_ -#include +#include #include +#include +#include #include "framework/common/ge_inner_error_codes.h" #include "graph/node.h" +#include "proto/ge_ir.pb.h" +#include "proto/op_mapping_info.pb.h" +#include "runtime/mem.h" #include "task_info/task_info.h" namespace ge { @@ -44,19 +49,28 @@ class DataDumper { ~DataDumper(); void SetModelName(const std::string &model_name) { model_name_ = model_name; } + void SetModelId(uint32_t model_id) { model_id_ = model_id; } + void SetMemory(const RuntimeParam &runtime_param) { runtime_param_ = runtime_param; } + void SetDeviceId(uint32_t device_id) { device_id_ = device_id; } + void SetLoopAddr(void *global_step, void *loop_per_iter, void *loop_cond); void SaveDumpInput(const std::shared_ptr &node); + // args is device memory stored first output addr - void SaveDumpTask(uint32_t task_id, const std::shared_ptr &op_desc, uintptr_t args); + void 
SaveDumpTask(uint32_t task_id, uint32_t stream_id, const std::shared_ptr &op_desc, uintptr_t args); + void SaveEndGraphId(uint32_t task_id, uint32_t stream_id); + Status LoadDumpInfo(); + Status UnloadDumpInfo(); private: void ReleaseDevMem(void **ptr) noexcept; + void PrintCheckLog(); std::string model_name_; @@ -69,16 +83,24 @@ class DataDumper { struct InnerInputMapping; std::vector op_list_; + uint32_t end_graph_task_id_ = 0; + uint32_t end_graph_stream_id_ = 0; std::multimap input_map_; bool load_flag_; uint32_t device_id_; uintptr_t global_step_; uintptr_t loop_per_iter_; uintptr_t loop_cond_; -}; + Status DumpOutput(const InnerDumpInfo &inner_dump_info, aicpu::dump::Task &task); + Status DumpInput(const InnerDumpInfo &inner_dump_info, aicpu::dump::Task &task); + Status ExecuteLoadDumpInfo(aicpu::dump::OpMappingInfo &op_mapping_info); + void SetEndGraphIdToAicpu(uint32_t task_id, uint32_t stream_id, aicpu::dump::OpMappingInfo &op_mapping_info); + Status ExecuteUnLoadDumpInfo(aicpu::dump::OpMappingInfo &op_mapping_info); +}; struct DataDumper::InnerDumpInfo { uint32_t task_id; + uint32_t stream_id; std::shared_ptr op; uintptr_t args; bool is_task; diff --git a/src/ge/graph/load/new_model_manager/davinci_model.cc b/src/ge/graph/load/new_model_manager/davinci_model.cc index 64a106ef..33a4fcf4 100644 --- a/src/ge/graph/load/new_model_manager/davinci_model.cc +++ b/src/ge/graph/load/new_model_manager/davinci_model.cc @@ -22,13 +22,14 @@ #include #include #include +#include #include +#include #include "common/debug/log.h" #include "common/formats/formats.h" #include "common/formats/utils/formats_trans_utils.h" #include "common/math/math_util.h" -#include "common/op/attr_define.h" #include "common/op/ge_op_utils.h" #include "common/profiling/profiling_manager.h" #include "common/properties_manager.h" @@ -51,7 +52,6 @@ #include "graph/utils/type_utils.h" #include "init/gelib.h" #include "mmpa/mmpa_api.h" -#include "model_output.h" #include "omm/csa_interact.h" 
#include "runtime/base.h" #include "runtime/dev.h" @@ -77,6 +77,7 @@ namespace { const uint32_t kDataIndex = 0; const uint32_t kTrueBranchStreamNum = 1; const uint32_t kThreadNum = 16; +const uint32_t kAddrLen = sizeof(void *); const int kDecimal = 10; const int kBytes = 8; const uint32_t kDataMemAlignSizeCompare = 64; @@ -202,7 +203,7 @@ Status TransVarOnHost(uint8_t *var_data, const VarTransRoad &trans_road, formats formats::TransResult result_last_time{}; bool use_init_data = true; for (const auto &trans_info : trans_road) { - if (trans_info.node_type == domi::RESHAPE || trans_info.node_type == domi::REFORMAT) { + if (trans_info.node_type == RESHAPE || trans_info.node_type == REFORMAT) { GELOGD("Skip to trans variable data on the reshape/reformat node"); continue; } @@ -215,7 +216,7 @@ Status TransVarOnHost(uint8_t *var_data, const VarTransRoad &trans_road, formats } formats::TransResult tmp_result{}; - if (trans_info.node_type == domi::TRANSDATA) { + if (trans_info.node_type == TRANSDATA) { auto src_format = trans_info.input.GetFormat(); auto src_shape = trans_info.input.GetShape().GetDims(); auto dst_format = trans_info.output.GetFormat(); @@ -235,9 +236,9 @@ Status TransVarOnHost(uint8_t *var_data, const VarTransRoad &trans_road, formats TypeUtils::DataTypeToSerialString(data_type).c_str(), ret); return ret; } - } else if (trans_info.node_type == domi::CAST) { + } else if (trans_info.node_type == CAST) { auto input_shape = trans_info.input.GetShape(); - auto src_data_size = input_shape.GetShapeSize(); + auto src_data_size = input_shape.GetShapeSize() == 0 ? 
1 : input_shape.GetShapeSize(); auto src_data_type = trans_info.input.GetDataType(); auto dst_data_type = trans_info.output.GetDataType(); GELOGD("Trans data type from %s to %s, input shape %s, data size %ld", @@ -301,7 +302,7 @@ Status TransVarData(const NodePtr &var, const VarTransRoad &trans_road, uint64_t GE_CHECK_NOTNULL(var); bool need_trans = false; for (auto &road : trans_road) { - if (road.node_type != domi::RESHAPE && road.node_type != domi::REFORMAT) { + if (road.node_type != RESHAPE && road.node_type != REFORMAT) { need_trans = true; break; } @@ -351,25 +352,8 @@ Status TransVarData(const NodePtr &var, const VarTransRoad &trans_road, uint64_t return SUCCESS; } -bool CheckDynamicBatchZeroCopyAddr(const void *addr, const vector &dynamic_input_addrs, - const vector &fix_input_addrs) { - if (fix_input_addrs.empty()) { - if (!dynamic_input_addrs.empty() && - std::find(dynamic_input_addrs.begin(), dynamic_input_addrs.end(), addr) == dynamic_input_addrs.end()) { - return false; - } - } else { - if (!dynamic_input_addrs.empty() && - std::find(dynamic_input_addrs.begin(), dynamic_input_addrs.end(), addr) == dynamic_input_addrs.end() && - std::find(fix_input_addrs.begin(), fix_input_addrs.end(), addr) == fix_input_addrs.end()) { - return false; - } - } - return true; -} - inline bool IsDataOp(const std::string &node_type) { - return node_type == domi::DATA_TYPE || node_type == domi::AIPP_DATA_TYPE || node_type == domi::ANN_DATA_TYPE; + return node_type == DATA_TYPE || node_type == AIPP_DATA_TYPE || node_type == ANN_DATA_TYPE; } inline bool IsCallDumpInputOp(const OpDescPtr &op_desc) { bool skip_task_generate = false; @@ -377,10 +361,38 @@ inline bool IsCallDumpInputOp(const OpDescPtr &op_desc) { return skip_task_generate; } -} // namespace +void CreateInputDimsInfo(const OpDescPtr &op_desc, Format format, InputOutputDescInfo &input) { + uint32_t n, c, h, w; + n = format == FORMAT_NHWC ? NHWC_DIM_N : NCHW_DIM_N; + c = format == FORMAT_NHWC ? 
NHWC_DIM_C : NCHW_DIM_C; + h = format == FORMAT_NHWC ? NHWC_DIM_H : NCHW_DIM_H; + w = format == FORMAT_NHWC ? NHWC_DIM_W : NCHW_DIM_W; -domi::SysMode DavinciModel::mode_ = domi::INFERENCE; -std::mutex DavinciModel::mutex_mode_; + if (!op_desc->HasAttr(ATTR_MBATCH_ORIGIN_INPUT_DIMS)) { + if (op_desc->GetInputDescPtr(0)->GetShape().GetDimNum() == static_cast(NORMAL_TENSOR_SIZE)) { + input.shape_info.num = op_desc->GetInputDescPtr(0)->GetShape().GetDim(n); + input.shape_info.height = op_desc->GetInputDescPtr(0)->GetShape().GetDim(h); + input.shape_info.width = op_desc->GetInputDescPtr(0)->GetShape().GetDim(w); + input.shape_info.channel = op_desc->GetInputDescPtr(0)->GetShape().GetDim(c); + } + for (size_t k = 0; k < op_desc->GetInputDescPtr(0)->GetShape().GetDimNum(); k++) { + input.shape_info.dims.push_back(op_desc->GetInputDescPtr(0)->GetShape().GetDim(k)); + } + } else { + vector origin_input_dims; + (void)AttrUtils::GetListInt(op_desc, ATTR_MBATCH_ORIGIN_INPUT_DIMS, origin_input_dims); + if (origin_input_dims.size() == static_cast(NORMAL_TENSOR_SIZE)) { + input.shape_info.num = origin_input_dims[n]; + input.shape_info.height = origin_input_dims[h]; + input.shape_info.width = origin_input_dims[w]; + input.shape_info.channel = origin_input_dims[c]; + } + for (size_t k = 0; k < origin_input_dims.size(); ++k) { + input.shape_info.dims.push_back(origin_input_dims[k]); + } + } +} +} // namespace std::mutex DavinciModel::tvm_bin_mutex_; std::set DavinciModel::tvm_bin_kernel_; @@ -408,14 +420,11 @@ DavinciModel::DavinciModel(int32_t priority, const std::shared_ptrGetAllNodes()) { OpDescPtr op_desc = node->GetOpDesc(); GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGW("Node OpDesc is nullptr"); continue); - GE_IF_BOOL_EXEC(((op_desc->GetType() == domi::HCOMBROADCAST) || (op_desc->GetType() == domi::HCOMALLGATHER) || - (op_desc->GetType() == domi::HCOMALLREDUCE) || (op_desc->GetType() == domi::HCOMSEND) || - (op_desc->GetType() == domi::HCOMRECEIVE) || (op_desc->GetType() == 
domi::HCOMREDUCESCATTER)), + GE_IF_BOOL_EXEC(((op_desc->GetType() == HCOMBROADCAST) || (op_desc->GetType() == HCOMALLGATHER) || + (op_desc->GetType() == HCOMALLREDUCE) || (op_desc->GetType() == HCOMSEND) || + (op_desc->GetType() == HCOMRECEIVE) || (op_desc->GetType() == HCOMREDUCESCATTER)), uint32_t stream_id = static_cast(op_desc->GetStreamId()); (void)hcom_streams_.emplace(stream_id); GELOGD("hcom stream: %u.", stream_id); continue); @@ -639,46 +650,57 @@ void DavinciModel::CheckHasHcomOp() { } } +/// +/// @ingroup ge +/// @brief Make active stream list and bind to model. +/// @return: 0 for success / others for fail +/// +Status DavinciModel::BindModelStream() { + // Stream not in active_stream_indication_ is active stream. + if (!input_queue_ids_.empty() || !output_queue_ids_.empty()) { + // Asynchronous Queue, need add S0, deactive all model stream. + for (size_t i = 0; i < stream_list_.size(); ++i) { + if (active_stream_indication_.count(i) == 0) { + active_stream_list_.push_back(stream_list_[i]); + active_stream_indication_.insert(i); // deactive all model stream. + } + } + } else { + for (size_t i = 0; i < stream_list_.size(); ++i) { + if (active_stream_indication_.count(i) == 0) { + active_stream_list_.push_back(stream_list_[i]); + } + } + } + + for (size_t i = 0; i < stream_list_.size(); ++i) { + if (active_stream_indication_.count(i) > 0) { + GELOGI("rtModelBindStream[%zu]", i); + GE_CHK_RT_RET(rtModelBindStream(rt_model_handle_, stream_list_[i], RT_INVALID_FLAG)); + } else { + // bind rt_model_handel to all streams that relates to op + GE_CHK_RT_RET(rtModelBindStream(rt_model_handle_, stream_list_[i], RT_HEAD_STREAM)); + } + } + + return SUCCESS; +} + Status DavinciModel::DoTaskSink() { // task sink is supported as model_task_def is set if (model_task_def_) { GELOGI("do task_sink."); + GE_CHK_STATUS_RET(BindModelStream(), "Bind model stream failed."); - // will adjust stream indication, load fist. 
- GE_CHK_STATUS_RET(LoadWithQueue(), "LoadWithQueue failed."); - - for (size_t i = 0; i < stream_list_.size(); i++) { - GE_IF_BOOL_EXEC(active_stream_indication_.count(i) > 0, GELOGI("rtModelBindStream[%zu]", i); - GE_CHK_RT_RET(rtModelBindStream(rt_model_handle_, stream_list_[i], RT_INVALID_FLAG)); continue;); - // bind rt_model_handel to all streams that relates to op - GE_CHK_RT_RET(rtModelBindStream(rt_model_handle_, stream_list_[i], 0)); - } + GE_CHK_STATUS_RET(InitTaskInfo(*model_task_def_.get()), "InitTaskInfo failed."); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(InitTaskInfo(*model_task_def_.get()) != SUCCESS, return FAILED, - "InitTaskInfo failed."); + GE_CHK_STATUS_RET(LoadWithQueue(), "LoadWithQueue failed."); GE_CHK_STATUS_RET(DistributeTask(), "Distribute failed."); GE_CHK_RT_RET(rtModelLoadComplete(rt_model_handle_)); } - for (const auto &addrs : input_outside_addrs_) { - const auto &used_list = addrs.second; - if (used_list.empty()) { - GELOGI("Not sinked data found, disable input zero copy."); - input_use_zero_copy_ = false; - break; - } - } - - for (const auto &addrs : output_outside_addrs_) { - const auto &used_list = addrs.second; - if (used_list.empty()) { - GELOGI("Not sinked data found, disable output zero copy."); - output_use_zero_copy_ = false; - break; - } - } return SUCCESS; } @@ -714,16 +736,26 @@ Status DavinciModel::Init(void *dev_ptr, size_t mem_size, void *weight_ptr, size GELOGD("The value of ge.l1Fusion in ge_model_ is %d.", is_l1_fusion_enable_); CheckHasHcomOp(); + vector huge_stream_list; + (void)ge::AttrUtils::GetListInt(ge_model_, ATTR_MODEL_HUGE_STREAM_LIST, huge_stream_list); + std::set huge_streams(huge_stream_list.begin(), huge_stream_list.end()); + for (uint32_t i = 0; i < StreamNum(); i++) { rtStream_t stream = nullptr; GE_MAKE_GUARD_RTSTREAM(stream); + uint32_t stream_flags = RT_STREAM_PERSISTENT; + if (huge_streams.find(i) != huge_streams.end()) { + GELOGI("Stream %u is huge stream.", i); + stream_flags |= RT_STREAM_HUGE; + } + 
if (hcom_streams_.find(i) != hcom_streams_.end()) { - GE_CHK_RT_RET(rtStreamCreateWithFlags(&stream, priority_, RT_STREAM_PERSISTENT | RT_STREAM_FORCE_COPY)); + GE_CHK_RT_RET(rtStreamCreateWithFlags(&stream, priority_, stream_flags | RT_STREAM_FORCE_COPY)); } else if (aicpu_streams_.find(i) != aicpu_streams_.end()) { - GE_CHK_RT_RET(rtStreamCreateWithFlags(&stream, priority_, RT_STREAM_PERSISTENT | RT_STREAM_AICPU)); + GE_CHK_RT_RET(rtStreamCreateWithFlags(&stream, priority_, stream_flags | RT_STREAM_AICPU)); } else { - GE_CHK_RT_RET(rtStreamCreateWithFlags(&stream, priority_, RT_STREAM_PERSISTENT)); + GE_CHK_RT_RET(rtStreamCreateWithFlags(&stream, priority_, stream_flags)); } GE_DISMISS_GUARD(stream); @@ -737,28 +769,22 @@ Status DavinciModel::Init(void *dev_ptr, size_t mem_size, void *weight_ptr, size event_list_.push_back(rt_event); } - for (uint32_t i = 0; i < LabelNum(); i++) { - rtLabel_t rt_label; - GE_CHK_RT_RET(rtLabelCreate(&rt_label)); - GE_CHK_BOOL_RET_STATUS(rt_label != nullptr, FAILED, "rt_label is nullptr."); - label_list_.push_back(rt_label); - } + label_list_.resize(LabelNum(), nullptr); // create model_handle to load model GE_CHK_RT_RET(rtModelCreate(&rt_model_handle_, 0)); GE_CHK_RT_RET(rtModelGetId(rt_model_handle_, &runtime_model_id_)); Graph graph = ge_model_->GetGraph(); - auto compute_graph = GraphUtils::GetComputeGraph(graph); - compute_graph_ = compute_graph; - GE_CHK_BOOL_RET_STATUS(compute_graph != nullptr, INTERNAL_ERROR, "Get compute graph is nullptr."); + compute_graph_ = GraphUtils::GetComputeGraph(graph); + GE_CHK_BOOL_RET_STATUS(compute_graph_ != nullptr, INTERNAL_ERROR, "Get compute graph is nullptr."); - runtime_param_.graph_id = GetGraphID(compute_graph->GetName()); + runtime_param_.graph_id = GetGraphID(compute_graph_->GetName()); GE_TIMESTAMP_START(TransAllVarData); - GE_CHK_STATUS_RET(TransAllVarData(compute_graph, runtime_param_.graph_id), "TransAllVarData failed."); + GE_CHK_STATUS_RET(TransAllVarData(compute_graph_, 
runtime_param_.graph_id), "TransAllVarData failed."); GE_TIMESTAMP_END(TransAllVarData, "GraphLoader::TransAllVarData"); - GE_CHK_STATUS_RET(CopyVarData(compute_graph), "copy var data failed."); + GE_CHK_STATUS_RET(CopyVarData(compute_graph_), "copy var data failed."); GE_TIMESTAMP_START(InitModelMem); GE_CHK_STATUS_RET_NOLOG(InitModelMem(dev_ptr, mem_size, weight_ptr, weight_size)); @@ -767,14 +793,14 @@ Status DavinciModel::Init(void *dev_ptr, size_t mem_size, void *weight_ptr, size data_inputer_ = new (std::nothrow) DataInputer(); GE_CHK_BOOL_RET_STATUS(data_inputer_ != nullptr, INTERNAL_ERROR, "data_inputer_ is nullptr."); - for (const ge::NodePtr &node : compute_graph->GetDirectNode()) { + for (const ge::NodePtr &node : compute_graph_->GetDirectNode()) { GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue); GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != VARIABLE, continue); GE_IF_BOOL_EXEC(IsBroadCastOpData(node), - (void)ge::AttrUtils::SetStr(node->GetOpDesc(), domi::VAR_ATTR_VAR_IS_BROADCAST, "var_is_restore");); + (void)ge::AttrUtils::SetStr(node->GetOpDesc(), VAR_ATTR_VAR_IS_BROADCAST, "var_is_restore");); } // for profiling - op_name_map_ = compute_graph->GetGraphOpName(); + op_name_map_ = compute_graph_->GetGraphOpName(); vector op_name; GE_IF_BOOL_EXEC(ge::AttrUtils::GetListStr(ge_model_, ATTR_MODEL_TASK_INDEX_OP_NAME, op_name), @@ -786,7 +812,7 @@ Status DavinciModel::Init(void *dev_ptr, size_t mem_size, void *weight_ptr, size GELOGI("infer profiling: op_name_size(%zu)", op_name.size()); } - if (InitNodes(compute_graph) != SUCCESS) { + if (InitNodes(compute_graph_) != SUCCESS) { return FAILED; } @@ -817,11 +843,16 @@ Status DavinciModel::Init(void *dev_ptr, size_t mem_size, void *weight_ptr, size /// Status DavinciModel::InitNodes(const ComputeGraphPtr &compute_graph) { uint32_t data_op_index = 0; - std::map> input_data_info; - GE_TIMESTAMP_CALLNUM_START(LoadTBEKernelBinToOpDesc); GE_TIMESTAMP_CALLNUM_START(InitTbeHandle); + typedef Status 
(DavinciModel::*OpDescCall)(const OpDescPtr &); + static std::map op_desc_handle = { + {VARIABLE, &DavinciModel::InitVariable}, {CONSTANTOP, &DavinciModel::InitConstant}, + {STREAMACTIVE, &DavinciModel::InitStreamActive}, {STREAMSWITCH, &DavinciModel::InitStreamSwitch}, + {STREAMSWITCHN, &DavinciModel::InitStreamSwitchN}, {LABELSET, &DavinciModel::InitLabelSet}, + }; + auto nodes = compute_graph->GetAllNodes(); const TBEKernelStore &tbekernel_store = ge_model_->GetTBEKernelStore(); for (size_t i = 0; i < nodes.size(); i++) { @@ -839,7 +870,7 @@ Status DavinciModel::InitNodes(const ComputeGraphPtr &compute_graph) { GE_TIMESTAMP_ADD(LoadTBEKernelBinToOpDesc); if (IsDataOp(op_desc->GetType())) { - if (InitDataOp(node, data_op_index, input_data_info) != SUCCESS) { + if (InitDataOp(node, data_op_index) != SUCCESS) { GELOGE(PARAM_INVALID, "Data init failed, Name: %s", op_desc->GetName().c_str()); return PARAM_INVALID; } @@ -853,32 +884,23 @@ Status DavinciModel::InitNodes(const ComputeGraphPtr &compute_graph) { continue; } - if (op_desc->GetType() == VARIABLE) { - variable_op_list_.push_back(op_desc); - continue; - } - if (op_desc->GetType() == NETOUTPUT) { - if (InitNetOutput(op_desc) != SUCCESS) { + if (InitNetOutput(node) != SUCCESS) { GELOGE(PARAM_INVALID, "NetOutput init failed, Name: %s", op_desc->GetName().c_str()); return PARAM_INVALID; } continue; } - // Initialize constant op, only applies to training, ignoring inference constant op - if (op_desc->GetType() == CONSTANTOP) { - if (InitConstant(op_desc) != SUCCESS) { - GELOGE(PARAM_INVALID, "Constant init failed. 
%s", op_desc->GetName().c_str()); + auto it = op_desc_handle.find(op_desc->GetType()); + if (it != op_desc_handle.end()) { + if ((this->*it->second)(op_desc) != SUCCESS) { + GELOGE(PARAM_INVALID, "NetOutput init failed, Name: %s", op_desc->GetName().c_str()); return PARAM_INVALID; } continue; } - if (op_desc->GetType() == ENDGRAPH) { - end_graph_op_ = op_desc; - } - GE_TIMESTAMP_RESTART(InitTbeHandle); uint32_t run_mode = static_cast(domi::ImplyType::INVALID); if (AttrUtils::GetInt(op_desc, ATTR_NAME_IMPLY_TYPE, run_mode) && @@ -897,17 +919,11 @@ Status DavinciModel::InitNodes(const ComputeGraphPtr &compute_graph) { } } GE_TIMESTAMP_ADD(InitTbeHandle); - - if (MarkActiveStream(op_desc) != SUCCESS) { - GELOGE(PARAM_INVALID, "MarkActiveStream failed, node:%s, opIndex:%zu", op_desc->GetName().c_str(), i); - return PARAM_INVALID; - } } - Status ret = CombineDataInfo(input_data_info); GE_TIMESTAMP_CALLNUM_END(LoadTBEKernelBinToOpDesc, "GraphLoader::LoadTBEKernelBinToOpDesc."); GE_TIMESTAMP_CALLNUM_END(InitTbeHandle, "GraphLoader::InitTbeHandle."); - return ret; + return SUCCESS; } /// @ingroup ge @@ -916,8 +932,7 @@ Status DavinciModel::InitNodes(const ComputeGraphPtr &compute_graph) { /// @param [in/out] data_op_index: NetOutput addr size info. /// @param [in/out] input_data_info: Data index and addr info {index, {size, addr}}. /// @return Status -Status DavinciModel::InitDataOp(const NodePtr &node, uint32_t &data_op_index, - std::map> &input_data_info) { +Status DavinciModel::InitDataOp(const NodePtr &node, uint32_t &data_op_index) { // op_desc Checked by Init: Data, valid. auto op_desc = node->GetOpDesc(); uint32_t parent_index = 0; // Ignore subgraph Data Node. @@ -939,20 +954,19 @@ Status DavinciModel::InitDataOp(const NodePtr &node, uint32_t &data_op_index, // Make information for copy input data. 
const vector output_size_list = ModelUtils::GetOutputSize(op_desc); - const vector output_addr_list = ModelUtils::GetOutputDataAddrs(runtime_param_, op_desc); - if (output_size_list.empty() || output_addr_list.empty() || (output_size_list.size() != output_addr_list.size())) { + const vector virtual_addr_list = ModelUtils::GetOutputDataAddrs(runtime_param_, op_desc, false); + if (output_size_list.empty() || virtual_addr_list.empty() || (output_size_list.size() != virtual_addr_list.size())) { GELOGE(PARAM_INVALID, "Data[%s] init failed: Output size is %zu, Output addr is %zu", op_desc->GetName().c_str(), - output_size_list.size(), output_addr_list.size()); + output_size_list.size(), virtual_addr_list.size()); return PARAM_INVALID; } auto data_index = data_op_index; if (AttrUtils::GetInt(op_desc, ATTR_NAME_INDEX, data_index)) { - GELOGI("ge_train:get new index %u, old %u", data_index, data_op_index); + GELOGI("ge_train: get new index %u, old %u", data_index, data_op_index); } - - input_data_info[data_index] = {output_size_list[kDataIndex], output_addr_list[kDataIndex]}; - SetInputOutsideAddr(output_addr_list); + input_data_info_[data_index] = {output_size_list[kDataIndex], virtual_addr_list[kDataIndex]}; + SetInputOutsideAddr(virtual_addr_list); data_op_index++; if (InitInputZeroCopyNodes(node) != SUCCESS) { GELOGE(PARAM_INVALID, "Input zero copy nodes init failed!"); @@ -996,62 +1010,134 @@ Status DavinciModel::InitInputZeroCopyNodes(const NodePtr &node) { /// @ingroup ge /// @brief NetOutput Op Initialize. -/// @param [in] op_desc: NetOutput Op descriptor. +/// @param [in] NodePtr: NetOutput Op. /// @return Status -Status DavinciModel::InitNetOutput(const OpDescPtr &op_desc) { - // op_desc Checked by Init: NetOutput, valid. - uint32_t parent_index = 0; // Ignore subgraph NetOutput Node. 
- if (AttrUtils::GetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { +Status DavinciModel::InitNetOutput(const NodePtr &node) { + // node->GetOpDesc Checked by Init: NetOutput, valid. + auto op_desc = node->GetOpDesc(); + ComputeGraphPtr owner_graph = node->GetOwnerComputeGraph(); + GE_CHECK_NOTNULL(owner_graph); + if (owner_graph->GetParentGraph() != nullptr) { GELOGI("Skip subgraph NetOutput node: %s.", op_desc->GetName().c_str()); + op_list_.erase(op_desc->GetId()); return SUCCESS; } output_op_list_.push_back(op_desc); - std::vector output_size_list; // useless, just for check. - if (ModelUtils::GetOutputSize(op_desc, output_size_list, output_memory_size_list_) != SUCCESS) { - GELOGE(PARAM_INVALID, "Get output size failed: %s", op_desc->GetName().c_str()); - return PARAM_INVALID; - } // Make information for copy output data. const vector input_size_list = ModelUtils::GetInputSize(op_desc); - const vector input_addr_list = ModelUtils::GetInputDataAddrs(runtime_param_, op_desc); - if (input_size_list.empty() && input_addr_list.empty()) { + const vector virtual_addr_list = ModelUtils::GetInputDataAddrs(runtime_param_, op_desc, false); + if (input_size_list.empty() && virtual_addr_list.empty()) { GELOGI("NetOutput[%s] is empty.", op_desc->GetName().c_str()); return SUCCESS; } - if (input_size_list.empty() || input_size_list.size() != input_addr_list.size() || - input_size_list.size() != output_size_list.size()) { - GELOGE(PARAM_INVALID, "NetOutput[%s] init failed: Input size is %zu, Input addr is %zu, Output size is %zu", - op_desc->GetName().c_str(), input_size_list.size(), input_addr_list.size(), output_size_list.size()); + if (input_size_list.empty() || input_size_list.size() != virtual_addr_list.size()) { + GELOGE(PARAM_INVALID, "NetOutput[%s] init failed: Input size is %zu, Input addr is %zu", op_desc->GetName().c_str(), + input_size_list.size(), virtual_addr_list.size()); return PARAM_INVALID; } - output_size_list_.insert(output_size_list_.end(), 
input_size_list.begin(), input_size_list.end()); - output_addr_list_.insert(output_addr_list_.end(), input_addr_list.begin(), input_addr_list.end()); - SetOutputOutsideAddr(input_addr_list); + size_t num = output_data_info_.size(); + for (size_t idx = 0; idx < input_size_list.size(); ++idx) { + output_data_info_[num + idx] = {input_size_list[idx], virtual_addr_list[idx]}; + } + + SetOutputOutsideAddr(virtual_addr_list); + if (InitOutputZeroCopyNodes(node) != SUCCESS) { + GELOGE(PARAM_INVALID, "Output zero copy nodes init failed!"); + return PARAM_INVALID; + } return SUCCESS; } +/// /// @ingroup ge -/// @brief Make Input and Output addr for feature use. -/// @param [in] input_data_info: Data index and addr info {index, {size, addr}}. +/// @brief output zero copy node Initialize. +/// @param [in] NodePtr: netoutput Op or merge op. /// @return Status -Status DavinciModel::CombineDataInfo(const std::map> &input_data_info) { - input_size_list_.resize(data_op_list_.size()); - input_addr_list_.resize(data_op_list_.size()); - for (size_t index = 0; index < data_op_list_.size(); ++index) { - auto it = input_data_info.find(index); - if (it == input_data_info.end()) { - GELOGE(PARAM_INVALID, "Data init failed: index %zu, Data Op size is %zu, Input addr is %zu", index, - data_op_list_.size(), input_data_info.size()); - return INTERNAL_ERROR; +/// +Status DavinciModel::InitOutputZeroCopyNodes(const NodePtr &node) { + for (auto &in_data_anchor : node->GetAllInDataAnchors()) { + auto peer_out_data_anchor = in_data_anchor->GetPeerOutAnchor(); + if (peer_out_data_anchor == nullptr) { + continue; + } + auto node = peer_out_data_anchor->GetOwnerNode(); + auto op_desc = node->GetOpDesc(); + if (op_desc == nullptr) { + GELOGE(FAILED, "Op desc is nullptr"); + return FAILED; + } + + // Merge node output multiplexed input, upstream nodes need to be considered in multiple batch scenarios + if (node->GetType() == MERGE) { + if (InitOutputZeroCopyNodes(node) != SUCCESS) { + 
GELOGE(PARAM_INVALID, "Output merge zero copy nodes init failed!"); + return PARAM_INVALID; + } + } + + string batch_label; + (void)ge::AttrUtils::GetStr(op_desc, ATTR_NAME_BATCH_LABEL, batch_label); + if (batch_label.empty()) { + batch_label = kDefaultBatchLable; + } + if (zero_copy_op_id_batch_label_.find(op_desc->GetId()) == zero_copy_op_id_batch_label_.end()) { + zero_copy_op_id_batch_label_.emplace(pair(op_desc->GetId(), batch_label)); + GELOGD("Init Output zero copy nodes success, op name:%s, op id: %ld, batch label: %s.", + op_desc->GetName().c_str(), op_desc->GetId(), batch_label.c_str()); } - input_size_list_[index] = it->second.first; - input_addr_list_[index] = it->second.second; } + return SUCCESS; +} - GELOGI("Data init success, input size %zu, output size %zu", input_size_list_.size(), output_size_list_.size()); +/// @ingroup ge +/// @brief LabelSet Op Initialize. +/// @param [in] op_desc: LabelSet Op descriptor. +/// @return Status +Status DavinciModel::InitLabelSet(const OpDescPtr &op_desc) { + uint32_t label_index = 0; + if (!AttrUtils::GetInt(op_desc, ATTR_NAME_LABEL_SWITCH_INDEX, label_index)) { + GELOGE(INTERNAL_ERROR, "InitLabelSet: %s attr [%s] not exist.", op_desc->GetName().c_str(), + ATTR_NAME_LABEL_SWITCH_INDEX.c_str()); + return INTERNAL_ERROR; + } + if (label_index >= LabelNum()) { + GELOGE(INTERNAL_ERROR, "InitLabelSet: label index: %u >= label size: %zu.", label_index, LabelNum()); + return INTERNAL_ERROR; + } + if (label_id_indication_.count(label_index) > 0) { + GELOGE(INTERNAL_ERROR, "InitLabelSet: %s label index: %u already used.", op_desc->GetName().c_str(), label_index); + return INTERNAL_ERROR; + } + + rtStream_t stream = nullptr; + uint32_t stream_id = static_cast(op_desc->GetStreamId()); + if (stream_list_.size() == 1) { + stream = stream_list_[0]; + } else if (stream_list_.size() > stream_id) { + stream = stream_list_[stream_id]; + } else { + GELOGE(INTERNAL_ERROR, "InitLabelSet: stream index: %u >= stream size: %zu.", 
stream_id, stream_list_.size()); + return INTERNAL_ERROR; + } + + rtLabel_t rt_label = nullptr; + rtError_t rt_error = rtLabelCreateEx(&rt_label, stream); + if (rt_error != RT_ERROR_NONE || rt_label == nullptr) { + GELOGE(INTERNAL_ERROR, "InitLabelSet: %s create label failed, error=0x%x.", op_desc->GetName().c_str(), rt_error); + return INTERNAL_ERROR; + } + + GELOGI("InitLabelSet: label[%u]=%p stream[%u]=%p.", label_index, rt_label, stream_id, stream); + label_id_indication_.insert(label_index); + label_list_[label_index] = rt_label; + return SUCCESS; +} + +Status DavinciModel::InitVariable(const OpDescPtr &op_desc) { + variable_op_list_.push_back(op_desc); return SUCCESS; } @@ -1084,31 +1170,34 @@ Status DavinciModel::LoadWithQueue() { return SUCCESS; } - if (input_queue_ids_.size() != data_op_list_.size()) { + if (input_queue_ids_.size() != input_data_info_.size()) { GELOGE(PARAM_INVALID, "Input queue ids not match model: input_queue=%zu input_data=%zu", input_queue_ids_.size(), - data_op_list_.size()); + input_data_info_.size()); return PARAM_INVALID; } - if (output_queue_ids_.size() != output_size_list_.size()) { + if (output_queue_ids_.size() != output_data_info_.size()) { GELOGE(PARAM_INVALID, "Output queue ids not match model: output_queue=%zu output_data=%zu", - output_queue_ids_.size(), output_size_list_.size()); + output_queue_ids_.size(), output_data_info_.size()); return PARAM_INVALID; } // create stream instance which rt_model_handel is running on, this is S0. GE_CHK_RT_RET(rtStreamCreateWithFlags(&rt_model_stream_, priority_, RT_STREAM_AICPU)); is_inner_model_stream_ = true; - GE_CHK_RT_RET(rtModelBindStream(rt_model_handle_, rt_model_stream_, 0)); + GE_CHK_RT_RET(rtModelBindStream(rt_model_handle_, rt_model_stream_, RT_HEAD_STREAM)); // Binding input_queue and Data Op. 
GE_CHK_STATUS_RET(BindInputQueue(), "Launch bind input queue failed."); - - GE_CHK_STATUS_RET(BindActiveStream(), "Launch active entry stream failed."); - GE_CHK_STATUS_RET(CpuWaitEndGraph(), "Launch wait end graph failed."); + GE_CHK_STATUS_RET(CpuTaskModelZeroCopy(input_mbuf_list_, input_outside_addrs_), "Launch zero copy failed."); // Binding output_queue and NetOutput Op. GE_CHK_STATUS_RET(BindOutputQueue(), "Launch bind output queue failed."); + GE_CHK_STATUS_RET(CpuTaskModelZeroCopy(output_mbuf_list_, output_outside_addrs_), "Launch zero copy failed."); + + GE_CHK_STATUS_RET(CpuActiveStream(active_stream_list_), "Launch active entry stream failed."); + GE_CHK_STATUS_RET(CpuWaitEndGraph(), "Launch wait end graph failed."); + GE_CHK_STATUS_RET(BindEnqueue(), "Launch enqueue failed."); GE_CHK_STATUS_RET(CpuModelRepeat(), "Launch model repeat failed."); return SUCCESS; @@ -1120,9 +1209,15 @@ Status DavinciModel::LoadWithQueue() { Status DavinciModel::BindInputQueue() { // Caller checked: input_queue_ids_.size() == input_size_list_.size() != input_addr_list_.size() for (size_t i = 0; i < input_queue_ids_.size(); ++i) { + auto it = input_data_info_.find(i); + if (it == input_data_info_.end()) { + GELOGE(FAILED, "Input not match: tensor num=%zu, Queue id index=%zu", input_data_info_.size(), i); + return FAILED; + } + uint32_t queue_id = input_queue_ids_[i]; - uint32_t data_size = input_size_list_[i]; - uintptr_t data_addr = reinterpret_cast(input_addr_list_[i]); + uint32_t data_size = static_cast(it->second.first); + uintptr_t data_addr = reinterpret_cast(it->second.second); GELOGI("BindInputToQueue: graph_%u index[%zu] queue id[%u] output addr[0x%lx] output size[%u]", runtime_param_.graph_id, i, queue_id, data_addr, data_size); @@ -1130,7 +1225,7 @@ Status DavinciModel::BindInputQueue() { return INTERNAL_ERROR; } - if (CpuModelDequeue(queue_id, data_addr, data_size) != SUCCESS) { + if (CpuModelDequeue(queue_id) != SUCCESS) { return INTERNAL_ERROR; } } @@ -1138,58 
+1233,13 @@ Status DavinciModel::BindInputQueue() { return SUCCESS; } -/// @ingroup ge -/// @brief queue schedule, bind output queue to NetOutput input address. -/// @return: 0 for success / others for failed -Status DavinciModel::BindOutputQueue() { - // Caller checked: input_queue_ids_.size() == input_size_list_.size() != input_addr_list_.size() - for (size_t i = 0; i < output_queue_ids_.size(); ++i) { - uint32_t queue_id = output_queue_ids_[i]; - uint32_t data_size = output_size_list_[i]; - uintptr_t data_addr = reinterpret_cast(output_addr_list_[i]); - GELOGI("BindOutputToQueue: graph_%u index[%zu] queue id[%u] input addr[0x%lx] input size[%u]", - runtime_param_.graph_id, i, queue_id, data_addr, data_size); - - if (rtModelBindQueue(rt_model_handle_, queue_id, RT_MODEL_OUTPUT_QUEUE) != RT_ERROR_NONE) { - return INTERNAL_ERROR; - } - - if (CpuModelEnqueue(queue_id, data_addr, data_size) != SUCCESS) { - return INTERNAL_ERROR; - } - } - - return SUCCESS; -} - -/// @ingroup ge -/// @brief queue schedule, active stream will schedule by S0. -/// @return: 0 for success / others for failed -Status DavinciModel::BindActiveStream() { - // Stream not in active_stream_indication_ is active stream. - std::vector active_stream_list; - for (size_t i = 0; i < stream_list_.size(); ++i) { - if (active_stream_indication_.count(i) == 0) { - active_stream_list.push_back(stream_list_[i]); - active_stream_indication_.insert(i); // deactive all model stream. - } - } - - // Active stream add to active entry, will active by S0. - if (CpuActiveStream(active_stream_list) != SUCCESS) { - return INTERNAL_ERROR; - } - - return SUCCESS; -} - /// @ingroup ge /// @brief definiteness queue schedule, bind input queue to task. /// @param [in] queue_id: input queue id from user. /// @param [in] addr: Data Op output tensor address. /// @param [in] size: Data Op output tensor size. 
/// @return: 0 for success / others for failed -Status DavinciModel::CpuModelDequeue(uint32_t queue_id, uintptr_t addr, uint32_t size) { +Status DavinciModel::CpuModelDequeue(uint32_t queue_id) { GELOGI("Set CpuKernel model dequeue task enter."); std::shared_ptr dequeue_task = MakeShared(rt_model_stream_); if (dequeue_task == nullptr) { @@ -1203,20 +1253,55 @@ Status DavinciModel::CpuModelDequeue(uint32_t queue_id, uintptr_t addr, uint32_t return FAILED; } - std::shared_ptr prepare_input = MakeShared(rt_model_stream_); - if (dequeue_task == nullptr) { - GELOGE(FAILED, "Make CpuTaskPrepareInput task failed."); + cpu_task_list_.push_back(dequeue_task); + input_mbuf_list_.push_back(in_mbuf); + GELOGI("Set CpuKernel model dequeue task success."); + return SUCCESS; +} + +Status DavinciModel::CpuTaskModelZeroCopy(std::vector &mbuf_list, + std::map> &outside_addrs) { + GELOGI("Set CpuKernel model zero_copy task enter."); + std::shared_ptr zero_copy = MakeShared(rt_model_stream_); + if (zero_copy == nullptr) { + GELOGE(FAILED, "Make CpuTaskZeroCopy task failed."); return FAILED; } - if (prepare_input->Init(addr, size, in_mbuf) != SUCCESS) { + if (zero_copy->Init(mbuf_list, outside_addrs) != SUCCESS) { return FAILED; } + cpu_task_list_.push_back(zero_copy); + GELOGI("Set CpuKernel model zero_copy task success."); + return SUCCESS; +} + +/// @ingroup ge +/// @brief queue schedule, bind output queue to NetOutput input address. 
+/// @return: 0 for success / others for failed +Status DavinciModel::BindOutputQueue() { + // Caller checked: input_queue_ids_.size() == input_size_list_.size() != input_addr_list_.size() + for (size_t i = 0; i < output_queue_ids_.size(); ++i) { + auto it = output_data_info_.find(i); + if (it == output_data_info_.end()) { + GELOGE(FAILED, "Output not match: tensor num=%zu, Queue id index=%zu", output_data_info_.size(), i); + return FAILED; + } + + uint32_t queue_id = output_queue_ids_[i]; + uint32_t data_size = static_cast(it->second.first); + uintptr_t data_addr = reinterpret_cast(it->second.second); + GELOGI("BindOutputToQueue: graph_%u index[%zu] queue id[%u] input addr[0x%lx] input size[%u]", + runtime_param_.graph_id, i, queue_id, data_addr, data_size); + + if (rtModelBindQueue(rt_model_handle_, queue_id, RT_MODEL_OUTPUT_QUEUE) != RT_ERROR_NONE) { + return INTERNAL_ERROR; + } + if (CpuModelPrepareOutput(data_addr, data_size) != SUCCESS) { + return INTERNAL_ERROR; + } + } - cpu_task_list_.push_back(dequeue_task); - cpu_task_list_.push_back(prepare_input); - input_mbuf_list_.push_back(in_mbuf); - GELOGI("Set CpuKernel model dequeue task success."); return SUCCESS; } @@ -1226,7 +1311,7 @@ Status DavinciModel::CpuModelDequeue(uint32_t queue_id, uintptr_t addr, uint32_t /// @param [in] addr: NetOutput Op input tensor address. /// @param [in] size: NetOutput Op input tensor size. 
/// @return: 0 for success / others for failed -Status DavinciModel::CpuModelEnqueue(uint32_t queue_id, uintptr_t addr, uint32_t size) { +Status DavinciModel::CpuModelPrepareOutput(uintptr_t addr, uint32_t size) { GELOGI("Set CpuKernel model enqueue task enter."); if (input_mbuf_list_.empty()) { GELOGE(FAILED, "Need input mbuf for fill output mbuf head info."); @@ -1240,22 +1325,11 @@ Status DavinciModel::CpuModelEnqueue(uint32_t queue_id, uintptr_t addr, uint32_t } uintptr_t out_mbuf = 0; - if (prepare_output->Init(addr, size, input_mbuf_list_[0], out_mbuf) != SUCCESS) { - return FAILED; - } - - std::shared_ptr model_enqueue = MakeShared(rt_model_stream_); - if (model_enqueue == nullptr) { - GELOGE(FAILED, "Make CpuTaskModelEnqueue task failed."); - return FAILED; - } - - if (model_enqueue->Init(queue_id, out_mbuf) != SUCCESS) { + if (prepare_output->Init(addr, size, input_mbuf_list_.back(), out_mbuf) != SUCCESS) { return FAILED; } cpu_task_list_.push_back(prepare_output); - cpu_task_list_.push_back(model_enqueue); output_mbuf_list_.push_back(out_mbuf); GELOGI("Set CpuKernel model enqueue task success."); return SUCCESS; @@ -1307,6 +1381,38 @@ Status DavinciModel::CpuWaitEndGraph() { return SUCCESS; } +Status DavinciModel::BindEnqueue() { + for (size_t i = 0; i < output_queue_ids_.size(); ++i) { + auto it = output_data_info_.find(i); + if (it == output_data_info_.end()) { + GELOGE(FAILED, "Output not match: tensor num=%zu, Queue id index=%zu", output_data_info_.size(), i); + return FAILED; + } + + uint32_t queue_id = output_queue_ids_[i]; + if (CpuModelEnqueue(queue_id, output_mbuf_list_[i]) != SUCCESS) { + return INTERNAL_ERROR; + } + } + return SUCCESS; +} + +Status DavinciModel::CpuModelEnqueue(uint32_t queue_id, uintptr_t out_mbuf) { + GELOGI("Set CpuKernel model enqueue task enter."); + std::shared_ptr model_enqueue = MakeShared(rt_model_stream_); + if (model_enqueue == nullptr) { + GELOGE(FAILED, "Make CpuTaskModelEnqueue task failed."); + return FAILED; + } 
+ + if (model_enqueue->Init(queue_id, out_mbuf) != SUCCESS) { + return FAILED; + } + cpu_task_list_.push_back(model_enqueue); + GELOGI("Set CpuKernel model enqueue task enter."); + return SUCCESS; +} + /// @ingroup ge /// @brief definiteness queue schedule, repeat run model. /// @return: 0 for success / others for failed @@ -1327,67 +1433,18 @@ Status DavinciModel::CpuModelRepeat() { return SUCCESS; } -/// @ingroup domi_ome -/// @brief get sys mode -/// @return SysMode required system mode -/// @author -domi::SysMode DavinciModel::GetSysMode() { - std::unique_lock lock(mutex_mode_); - return mode_; -} - -/// @ingroup domi_ome -/// @brief set sys mode -/// @param [in] mode to be set -/// @return Status mode set result -/// @author -Status DavinciModel::SetSysMode(domi::SysMode mode) { - GE_CHK_BOOL_RET_STATUS(mode < domi::RESERVED, PARAM_INVALID, "DavinciModel::SetSysMode Para Error"); - - std::unique_lock lock(mutex_mode_); - mode_ = mode; - return SUCCESS; -} - Status DavinciModel::GetInputOutputDescInfo(vector &input_desc, vector &output_desc) { if ((data_op_list_.empty()) || (data_op_list_[0]->GetInputsSize()) != 1) { GELOGI("data_op_list_ is empty or input_desc size is not 1."); } else { - std::vector input_formats; - GE_CHK_STATUS_RET(GetInputDescInfo(input_desc, input_formats), "get input desc info failed."); - } - - std::vector outputFormats; - GE_CHK_STATUS_RET(GetOutputDescInfo(output_desc, outputFormats), "get output desc info failed."); - - return SUCCESS; -} - -Status DavinciModel::GetInputOutputDescInfoForZeroCopy(vector &input_desc, - vector &output_desc) { - if ((data_op_list_.empty()) || (data_op_list_[0]->GetInputsSize()) != 1) { - GELOGE(FAILED, "OP List Pointer is null or input_desc size is not 1."); - return FAILED; - } - - std::vector input_formats; - GE_CHK_STATUS_RET(GetInputDescInfo(input_desc, input_formats), "get input desc info failed."); - std::vector outputFormats; - GE_CHK_STATUS_RET(GetOutputDescInfo(output_desc, outputFormats), "get 
output desc info failed."); - - GE_CHK_BOOL_RET_STATUS(output_desc.size() == output_memory_size_list_.size(), INTERNAL_ERROR, - "output_desc size[%zu] not equal output_size_list_[%zu] size!", output_desc.size(), - output_memory_size_list_.size()); - - /// For function zero copy,the memory should be aligned by 512 bytes. - /// And, because of the cce op limit, size should be lager than the real shape size. The memory should be padded by 32 - /// bytes. - /// *size equals to ((tensorDesc->dataSize + 2 * 32 - 1) / 32) * 32; - for (size_t i = 0; i < output_memory_size_list_.size(); i++) { - output_desc[i].size = output_memory_size_list_[i]; + std::vector input_formats; + GE_CHK_STATUS_RET(GetInputDescInfo(input_desc, input_formats), "get input desc info failed."); } + std::vector outputFormats; + GE_CHK_STATUS_RET(GetOutputDescInfo(output_desc, outputFormats), "get output desc info failed."); + return SUCCESS; } @@ -1408,7 +1465,7 @@ Status DavinciModel::GetInputOutputDescInfo(vector &input_d } /// -/// @ingroup domi_ome +/// @ingroup ge /// @brief Get dynamic batch_info /// @param [out] batch_info /// @return execute result @@ -1421,7 +1478,7 @@ Status DavinciModel::GetDynamicBatchInfo(std::vector> &batc return FAILED; } - if (op_desc->GetType() != domi::STREAMSWITCHN) { + if (op_desc->GetType() != STREAMSWITCHN) { continue; } @@ -1478,26 +1535,11 @@ Status DavinciModel::GetInputOutputDescInfoForZeroCopy(vector &input_desc, std::vector &formats) { for (std::size_t index = 0; index < data_op_list_.size(); ++index) { InputOutputDescInfo input; - uint32_t n, c, h, w; GE_CHECK_NOTNULL(data_op_list_[index]); GE_CHECK_NOTNULL(data_op_list_[index]->GetInputDescPtr(0)); - Format format = data_op_list_[index]->GetInputDescPtr(0)->GetFormat(); - n = format == FORMAT_NHWC ? domi::NHWC_DIM_N : domi::NCHW_DIM_N; - c = format == FORMAT_NHWC ? domi::NHWC_DIM_C : domi::NCHW_DIM_C; - h = format == FORMAT_NHWC ? domi::NHWC_DIM_H : domi::NCHW_DIM_H; - w = format == FORMAT_NHWC ? 
domi::NHWC_DIM_W : domi::NCHW_DIM_W; - - if (data_op_list_[index]->GetInputDescPtr(0)->GetShape().GetDimNum() == - static_cast(domi::NORMAL_TENSOR_SIZE)) { - input.shape_info.num = data_op_list_[index]->GetInputDescPtr(0)->GetShape().GetDim(n); - input.shape_info.height = data_op_list_[index]->GetInputDescPtr(0)->GetShape().GetDim(h); - input.shape_info.width = data_op_list_[index]->GetInputDescPtr(0)->GetShape().GetDim(w); - input.shape_info.channel = data_op_list_[index]->GetInputDescPtr(0)->GetShape().GetDim(c); - } - for (size_t k = 0; k < data_op_list_[index]->GetInputDescPtr(0)->GetShape().GetDimNum(); k++) { - input.shape_info.dims.push_back(data_op_list_[index]->GetInputDescPtr(0)->GetShape().GetDim(k)); - } + Format format = data_op_list_[index]->GetInputDescPtr(0)->GetFormat(); + CreateInputDimsInfo(data_op_list_[index], format, input); input.data_type = data_op_list_[index]->GetInputDescPtr(0)->GetDataType(); input.name = data_op_list_[index]->GetName(); int64_t input_size = 0; @@ -1525,11 +1567,11 @@ void DavinciModel::CreateOutput(uint32_t index, OpDescPtr &op_desc, InputOutputD for (size_t i = 0; i < shape.GetDimNum() && i < (sizeof(dims) / sizeof(dims[0])); i++) { dims[i] = shape.GetDim(i); } - } else { // FOR FORMAT_NHWC or FORMAT_NCHW - dims[0] = shape.GetDim(format == FORMAT_NHWC ? domi::NHWC_DIM_N : domi::NCHW_DIM_N); // 0: first dim - dims[1] = shape.GetDim(format == FORMAT_NHWC ? domi::NHWC_DIM_C : domi::NCHW_DIM_C); // 1: second dim - dims[2] = shape.GetDim(format == FORMAT_NHWC ? domi::NHWC_DIM_H : domi::NCHW_DIM_H); // 2: third dim - dims[3] = shape.GetDim(format == FORMAT_NHWC ? domi::NHWC_DIM_W : domi::NCHW_DIM_W); // 3: forth dim + } else { // FOR FORMAT_NHWC or FORMAT_NCHW + dims[0] = shape.GetDim(format == FORMAT_NHWC ? NHWC_DIM_N : NCHW_DIM_N); // 0: first dim + dims[1] = shape.GetDim(format == FORMAT_NHWC ? NHWC_DIM_C : NCHW_DIM_C); // 1: second dim + dims[2] = shape.GetDim(format == FORMAT_NHWC ? 
NHWC_DIM_H : NCHW_DIM_H); // 2: third dim + dims[3] = shape.GetDim(format == FORMAT_NHWC ? NHWC_DIM_W : NCHW_DIM_W); // 3: forth dim } output.shape_info.num = dims[0]; // 0: first dim output.shape_info.channel = dims[1]; // 1: second dim @@ -1559,21 +1601,12 @@ void DavinciModel::CreateOutput(uint32_t index, OpDescPtr &op_desc, InputOutputD } Status DavinciModel::GetOutputDescInfo(vector &output_desc, std::vector &formats) { + GELOGI("Output node size: %zu", output_op_list_.size()); for (size_t i = 0; i < output_op_list_.size(); i++) { auto &op_desc = output_op_list_[i]; - uint32_t out_size = static_cast(op_desc->GetOutputsSize()); + uint32_t out_size = static_cast(op_desc->GetInputsSize()); for (uint32_t index = 0; index < out_size; index++) { - bool is_output = false; - GE_IF_BOOL_EXEC(op_desc->GetOutputDescPtr(index) == nullptr, - GELOGE(INTERNAL_ERROR, "OpDesc GetOutputDescPtr is nullptr"); - return INTERNAL_ERROR); - GE_CHK_STATUS(TensorUtils::GetOutputTensor(*op_desc->GetOutputDescPtr(index), is_output), - "get output tensor failed."); - if (!is_output) { - continue; - } - string output_name; InputOutputDescInfo output; uint32_t format_result; @@ -1603,25 +1636,40 @@ ge::Format DavinciModel::GetFormat() { return data_op_list_[0]->GetInputDescPtr(0)->GetFormat(); } -Status DavinciModel::CopyInputData(const InputData ¤t_data, bool device_data) { - Status ret = SUCCESS; - uint32_t data_op_index = 0; +Status DavinciModel::CopyInputData(const InputData &input_data, bool device_data) { + rtMemcpyKind_t kind = device_data ? 
RT_MEMCPY_DEVICE_TO_DEVICE : RT_MEMCPY_HOST_TO_DEVICE; + const std::vector &blobs = input_data.blobs; + for (const auto &data : input_data_info_) { + if (data.first >= blobs.size()) { + GELOGE(FAILED, "Blobs not match: blobs=%zu, tensor=%zu, index=%u, size=%ld", blobs.size(), + input_data_info_.size(), data.first, data.second.first); + return FAILED; + } - for (auto op_desc : data_op_list_) { - ret = CopyInputDataToModel(current_data.blobs, data_op_index, device_data); + const DataBuffer &data_buf = blobs[data.first]; + void *mem_addr = data.second.second; + uint32_t mem_size = static_cast(data.second.first); + GE_CHK_BOOL_RET_STATUS(mem_size >= data_buf.length, PARAM_INVALID, + "input data size(%u) does not match model required size(%u), ret failed.", data_buf.length, + mem_size); - GE_CHK_BOOL_EXEC(ret == SUCCESS, break, "Copy input data to model ret failed, index:%u, model id:%u", - current_data.index, current_data.model_id); - data_op_index++; + GELOGI("[IMAS]CopyPlainData memcpy graph_%u type[F] input[%u] memaddr[%p] mem_size[%u] datasize[%u]", + runtime_param_.graph_id, data.first, mem_addr, mem_size, data_buf.length); + if (data_buf.length == 0) { + GELOGW("No data need to memcpy!"); + return SUCCESS; + } + GE_CHK_RT_RET(rtMemcpy(mem_addr, mem_size, data_buf.data, data_buf.length, kind)); } - return ret; + + return SUCCESS; } Status DavinciModel::SyncVarData() { GELOGI("Sync var data, model id:%u", model_id_); Status ret = SUCCESS; - OpDescPtr global_step = GetVariableOp(domi::NODE_NAME_GLOBAL_STEP); + OpDescPtr global_step = GetVariableOp(NODE_NAME_GLOBAL_STEP); if (global_step != nullptr) { auto v_output_size = ModelUtils::GetOutputSize(global_step); auto v_output_addr = ModelUtils::GetOutputDataAddrs(runtime_param_, global_step); @@ -1652,10 +1700,9 @@ inline int64_t SumSize(const vector &size_list) { return sum_size; } -Status DavinciModel::SinkModelProfile(std::shared_ptr &model) { - GE_CHECK_NOTNULL(model); +Status DavinciModel::SinkModelProfile() { // 
not support non-sink model - GE_CHK_BOOL_EXEC(model->model_task_def_ != nullptr, return SUCCESS); + GE_CHK_BOOL_EXEC(this->model_task_def_ != nullptr, return SUCCESS); // profiling plugin must be registered Msprof::Engine::Reporter *reporter = PluginImpl::GetPluginReporter(); @@ -1669,51 +1716,48 @@ Status DavinciModel::SinkModelProfile(std::shared_ptr &model) { Msprof::Engine::ReporterData reporter_data{}; // report model data tag name std::string tag_name; - tag_name.append("model_load_info_").append(std::to_string(model->Id())); + tag_name.append("model_load_info_").append(std::to_string(this->Id())); GE_CHK_BOOL_EXEC(memcpy_s(reporter_data.tag, MSPROF_ENGINE_MAX_TAG_LEN, tag_name.c_str(), tag_name.size()) == EOK, return FAILED, "Sink model tag memcpy error."); // Model Header - string name = model->Name(); + string name = this->Name(); int32_t name_len = name.size(); reporter_data.deviceId = device_id_; reporter_data.data = (unsigned char *)&name_len; reporter_data.dataLen = sizeof(int32_t); GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.", - model->Id()); + this->Id()); reporter_data.data = (unsigned char *)name.c_str(); reporter_data.dataLen = name.size(); GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.", - model->Id()); + this->Id()); - uint32_t model_id = model->Id(); + uint32_t model_id = this->Id(); reporter_data.data = (unsigned char *)&model_id; reporter_data.dataLen = sizeof(uint32_t); GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.", - model->Id()); + this->Id()); // Load Start/End Time - int64_t start_time = model->GetLoadBeginTime(); + int64_t start_time = this->GetLoadBeginTime(); reporter_data.data = (unsigned char *)&start_time; reporter_data.dataLen = sizeof(int64_t); GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data 
fail, model id:%u.", - model->Id()); + this->Id()); - int64_t end_time = model->GetLoadEndTime(); + int64_t end_time = this->GetLoadEndTime(); reporter_data.data = (unsigned char *)&end_time; reporter_data.dataLen = sizeof(int64_t); GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.", - model->Id()); - - auto task_list = model->GetTaskList(); - auto op_list = model->GetOpList(); + this->Id()); - int32_t task_num = task_list.size(); + int32_t task_num = task_list_.size(); std::multimap op_id_map; std::set task_id_set; for (int32_t i = 0; i < task_num; i++) { - auto task = task_list[i]; + auto task = task_list_[i]; auto fusion_op_info = task->GetFusionOpInfo(); // when type is RT_MODEL_TASK_KERNEL, ctx is not null @@ -1740,7 +1784,7 @@ Status DavinciModel::SinkModelProfile(std::shared_ptr &model) { using CIT = std::multimap::const_iterator; using Range = std::pair; for (int32_t i = 0; i < task_num; i++) { - auto task = task_list[i]; + auto task = task_list_[i]; auto fusion_op_info = task->GetFusionOpInfo(); if (fusion_op_info != nullptr && fusion_op_info->original_op_names.size() > 0) { uint32_t task_id = task->GetTaskID(); @@ -1764,18 +1808,18 @@ Status DavinciModel::SinkModelProfile(std::shared_ptr &model) { reporter_data.data = (unsigned char *)&fusion_op_name_len; reporter_data.dataLen = sizeof(int32_t); GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.", - model->Id()); + this->Id()); reporter_data.data = (unsigned char *)fusion_op_name.c_str(); reporter_data.dataLen = fusion_op_name_len; GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.", - model->Id()); + this->Id()); // original op name before fusion reporter_data.data = (unsigned char *)&op_num; reporter_data.dataLen = sizeof(int32_t); GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data 
fail, model id:%u.", - model->Id()); + this->Id()); for (uint32_t k = 0; k < op_num; k++) { std::string op_name = fusion_op_info->original_op_names[k]; @@ -1783,25 +1827,25 @@ Status DavinciModel::SinkModelProfile(std::shared_ptr &model) { reporter_data.data = (unsigned char *)&op_name_len; reporter_data.dataLen = sizeof(int32_t); GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.", - model->Id()); + this->Id()); reporter_data.data = (unsigned char *)op_name.c_str(); reporter_data.dataLen = op_name_len; GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.", - model->Id()); + this->Id()); } // stream id info - uint32_t streamId = fusion_op_info->stream_id; + uint32_t streamId = task->GetStreamId(); reporter_data.data = (unsigned char *)&streamId; reporter_data.dataLen = sizeof(int32_t); GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.", - model->Id()); + this->Id()); // memory info struct memoryInfo memory_info; uint32_t op_index = fusion_op_info->op_index; - auto iter = op_list.find(op_index); - GE_CHK_BOOL_EXEC(iter != op_list.end(), return FAILED, "index is out of range, index: %u", op_index); + auto iter = op_list_.find(op_index); + GE_CHK_BOOL_EXEC(iter != op_list_.end(), return FAILED, "index is out of range, index: %u", op_index); auto op_desc = iter->second; memory_info.input_size = SumSize(ModelUtils::GetInputSize(op_desc)); memory_info.output_size = SumSize(ModelUtils::GetOutputSize(op_desc)); @@ -1812,13 +1856,13 @@ Status DavinciModel::SinkModelProfile(std::shared_ptr &model) { reporter_data.data = (unsigned char *)&memory_info; reporter_data.dataLen = sizeof(struct memoryInfo); GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.", - model->Id()); + this->Id()); // task info reporter_data.data = (unsigned char 
*)&task_count; reporter_data.dataLen = sizeof(uint32_t); GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.", - model->Id()); + this->Id()); Range task_range = op_id_map.equal_range(op_id); for (CIT idx = task_range.first; idx != task_range.second; ++idx) { @@ -1826,7 +1870,7 @@ Status DavinciModel::SinkModelProfile(std::shared_ptr &model) { reporter_data.data = (unsigned char *)&task_id; reporter_data.dataLen = sizeof(uint32_t); GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.", - model->Id()); + this->Id()); } } } @@ -1932,143 +1976,9 @@ void DavinciModel::SetProfileTime(ModelProcStage stage, int64_t endTime) { } return; } -/// -/// @ingroup domi_ome -/// @brief copy input data to Model's firat OP. Address already malloced when Load -/// @copy need datatype transfer: FLOAT to FP16, 4D to 5D; -/// @param [in] data data pointer to be copy -/// @return Status result -/// @author -/// -Status DavinciModel::CopyInputDataToModel(const std::vector &data, uint32_t data_op_index, - bool device_data) { - GE_CHK_BOOL_RET_STATUS(!data_op_list_.empty(), PARAM_INVALID, "data_op_list_ is empty!"); - - GE_CHK_BOOL_RET_STATUS(data_op_list_.size() == data.size(), PARAM_INVALID, - "The input data list size (%zu) does not match the model input list size (%zu)", data.size(), - data_op_list_.size()); - - GE_CHK_BOOL_RET_STATUS(data_op_index < data_op_list_.size(), PARAM_INVALID, - "input data op index(%zu) is invalid, exceeds input op size(%zu)", data_op_index, - data_op_list_.size()); - - /// input datatype conversion, converting FLOAT to FP16, 4D to 5D at the same time. - /// Choose respective mode in API parameters. 
- auto op_def = data_op_list_[data_op_index]; - GE_CHK_BOOL_EXEC(op_def != nullptr, return PARAM_INVALID, "op_def is null!"); - - auto data_index = data_op_index; - if (AttrUtils::GetInt(op_def, "index", data_index)) { - GELOGI("ge_train:get new index %u , old %u", data_index, data_op_index); - } - - GE_CHK_BOOL_EXEC(data_index < data.size(), return PARAM_INVALID, "index:%u >= size:%zu", data_index, data.size()); - GE_CHK_BOOL_RET_STATUS(op_def->GetInputsSize() == 1 && op_def->GetOutputsSize() == 1, PARAM_INVALID, - "Data Op has invalid input_desc_size(%zu) or output_desc_size(%zu)", op_def->GetInputsSize(), - op_def->GetOutputsSize()); - - // float to float16 - bool need_trans_flag = ModelUtils::IsInputTensorNeedTrans(data_op_list_[data_op_index], 0); - - int64_t output_size = 0; - GE_CHK_STATUS(TensorUtils::GetSize(*op_def->GetOutputDescPtr(0), output_size), "get output size failed."); - GE_CHK_BOOL_RET_STATUS(output_size >= data[data_index].length, PARAM_INVALID, - "input data size(%u) does not match model required size(%zu), ret failed.", - data[data_index].length, output_size); - - vector outputs = op_def->GetOutputOffset(); - if (device_data) { - return CopyPlainData(data, data_index, data_op_index, outputs, RT_MEMCPY_DEVICE_TO_DEVICE); - } else if (need_trans_flag) { - return CopyTransData(data, data_index, data_op_index, outputs); - } else { - return CopyPlainData(data, data_index, data_op_index, outputs, RT_MEMCPY_HOST_TO_DEVICE); - } -} - -Status DavinciModel::CopyTransData(const std::vector &data, uint32_t data_index, uint32_t data_op_index, - const std::vector &outputs) { - GE_CHECK_VECTOR_NOT_EMPTY(outputs); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(outputs[0] == -1, return PARAM_INVALID, "output offset is -1"); - GE_CHK_BOOL_EXEC(data_index < data.size(), return PARAM_INVALID, "index:%u >= size:%zu", data_index, data.size()); - - auto input_tensor_desc = data_op_input_tensor_desc_map_[data_op_list_[data_op_index]->GetName()]; - auto output_tensor_desc = 
data_op_output_tensor_desc_map_[data_op_list_[data_op_index]->GetName()]; - - uint8_t *src_data = reinterpret_cast(data[data_index].data); - - formats::TransResult tmp_result{}; - auto input_shape = input_tensor_desc->GetShape(); - auto src_data_size = input_shape.GetShapeSize(); - auto src_data_type = input_tensor_desc->GetDataType(); - auto dst_data_type = output_tensor_desc->GetDataType(); - GELOGD("Trans data type from %s to %s, input shape %s, data size %zu", - TypeUtils::DataTypeToSerialString(src_data_type).c_str(), - TypeUtils::DataTypeToSerialString(dst_data_type).c_str(), formats::ShapeToString(input_shape).c_str(), - src_data_size); - auto ret = - formats::TransDataType({src_data, static_cast(src_data_size), src_data_type, dst_data_type}, tmp_result); - if (ret != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Failed to trans data type from %s to %s, input shape %s, data size %zu, error code %u", - TypeUtils::DataTypeToSerialString(src_data_type).c_str(), - TypeUtils::DataTypeToSerialString(dst_data_type).c_str(), formats::ShapeToString(input_shape).c_str(), - src_data_size, ret); - return ret; - } - - void *mem_addr = mem_base_ + outputs[0]; - auto rt_ret = rtMemcpy(mem_addr, static_cast(runtime_param_.mem_size - outputs[0]), - reinterpret_cast(tmp_result.data.get()), static_cast(tmp_result.length), - RT_MEMCPY_HOST_TO_DEVICE); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Failed to copy memory to device, size %zu", tmp_result.length); - return RT_FAILED; - } - GELOGI("[IMAS]CopyTransData memcpy graph_%u type[F] name[%s] output[%d] memaddr[%p] datasize[%zu]", - runtime_param_.graph_id, data_op_list_[data_op_index]->GetName().c_str(), 0, mem_addr, tmp_result.length); - return SUCCESS; -} - -Status DavinciModel::CopyPlainData(const std::vector &data, uint32_t data_index, uint32_t data_op_index, - const std::vector &outputs, rtMemcpyKind_t kind) { - GE_CHK_BOOL_EXEC(data_index < data.size(), return PARAM_INVALID, "index:%u >= size:%zu", data_index, 
data.size()); - bool flag = data[data_index].isDataSupportMemShare && support_mem_shared_flag_; - // if data attr support zero cpy,then update addrs info to flowtable - if (flag) { - GELOGI("No need to copy input data, user's input data buffer can be shared."); - return SUCCESS; - } - - GE_CHECK_VECTOR_NOT_EMPTY(outputs); - // P2P memory space parameters - void *host_data_addr = data[data_index].data; - uint32_t copy_size = data[data_index].length; - GELOGD("data output tensor is aipp tensor,copy data only."); - - void *data_out_addr = nullptr; - if (VarManager::Instance(session_id_)->IsVarAddr(outputs[0])) { - data_out_addr = var_mem_base_ + outputs[0] - runtime_param_.logic_var_base; - GELOGI("[IMAS]CopyPlainData memcpy graph_%u type[V] name[%s] output[%d] memaddr[%p] mem_size[%u] datasize[%u]", - runtime_param_.graph_id, data_op_list_[data_op_index]->GetName().c_str(), 0, data_out_addr, copy_size, - copy_size); - } else { - data_out_addr = mem_base_ + outputs[0]; - GELOGI("output[0]=%ld, copy_size=%u, total_size=%zu", outputs[0], copy_size, TotalMemSize()); - - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(((uint64_t)outputs[0] + (uint64_t)copy_size) > TotalMemSize(), return INTERNAL_ERROR, - "input offset add size is large than total memory."); - GELOGI("[IMAS]CopyPlainData memcpy graph_%u type[F] name[%s] output[%d] memaddr[%p] mem_size[%u] datasize[%u]", - runtime_param_.graph_id, data_op_list_[data_op_index]->GetName().c_str(), 0, data_out_addr, copy_size, - copy_size); - } - - GE_CHK_RT_RET(rtMemcpy(data_out_addr, copy_size, host_data_addr, copy_size, kind)); - - return SUCCESS; -} /// -/// @ingroup domi_ome +/// @ingroup ge /// @brief send Output Op result to upper layer /// @already malloced in ModelLoad, no need to malloc again /// @param [in] sink_op Sink Op @@ -2082,9 +1992,9 @@ Status DavinciModel::CopyOutputData(uint32_t data_id, OutputData &output_data) { } else { output_data.index = data_id; output_data.model_id = model_id_; - 
GE_CHK_BOOL_RET_STATUS(output_data.blobs.size() == output_size_list_.size(), INTERNAL_ERROR, + GE_CHK_BOOL_RET_STATUS(output_data.blobs.size() == output_data_info_.size(), INTERNAL_ERROR, "output buffer size[%zu] not equal output_size_list[%zu] size!", output_data.blobs.size(), - output_size_list_.size()); + output_data_info_.size()); // index of data in output_data uint32_t output_data_index = 0; @@ -2140,8 +2050,50 @@ Status DavinciModel::SyncDataAndDump() { return ret; } +Status DavinciModel::GenOutputTensorInfo(const OpDescPtr &op_desc, uint32_t data_index, OutputData *output_data, + std::vector &outputs) { + GE_CHECK_NOTNULL(op_desc); + GE_CHECK_NOTNULL(output_data); + if (output_data->blobs.size() > data_index) { + GELOGI("No need to generate output tensor info, model id:%u", model_id_); + return SUCCESS; + } + std::vector out_buffer_size_vec; + std::vector> shape_info_vec; + size_t input_num = op_desc->GetInputsSize(); + for (size_t i = 0; i < input_num; ++i) { + int64_t size = 0; + auto input_desc = op_desc->GetInputDescPtr(i); + GE_CHECK_NOTNULL(input_desc); + auto ret = TensorUtils::GetTensorSizeInBytes(*input_desc, size); + if (ret != GRAPH_SUCCESS) { + GELOGE(ret, "Get size from TensorDesc failed, op:%s, input index:%zu", op_desc->GetName().c_str(), i); + return ret; + } + out_buffer_size_vec.push_back(size); + shape_info_vec.push_back(input_desc->GetShape().GetDims()); + } + + GELOGI("Output blobs size:%zu, data index:%u, model id:%u", out_buffer_size_vec.size(), data_index, model_id_); + for (size_t i = 0; i < out_buffer_size_vec.size(); ++i) { + std::unique_ptr data_buf(new (std::nothrow) uint8_t[out_buffer_size_vec[i]]); + if (data_buf == nullptr) { + GELOGE(GE_GRAPH_MALLOC_FAILED, "Malloc buffer failed."); + return GE_GRAPH_MALLOC_FAILED; + } + output_data->blobs.push_back({data_buf.get(), static_cast(out_buffer_size_vec[i]), false}); + ge::OutputTensorInfo output; + output.dims = shape_info_vec[i]; + output.data = std::move(data_buf); + 
output.length = out_buffer_size_vec[i]; + outputs.emplace_back(std::move(output)); + GELOGI("Output index:%zu, data_length:%u.", i, output.length); + } + return SUCCESS; +} + /// -/// @ingroup domi_ome +/// @ingroup ge /// @brief send Output Op result to upper layer /// @already malloced in ModelLoad, no need to malloc again /// @param [in] sink_op Sink Op @@ -2151,22 +2103,23 @@ Status DavinciModel::SyncDataAndDump() { Status DavinciModel::ReturnResult(uint32_t data_id, const bool rslt_flg, const bool seq_end_flag, OutputData *output_data) { GE_CHK_BOOL_EXEC(listener_ != nullptr, return PARAM_INVALID, "listener_ is null."); + std::vector outputs; if (seq_end_flag) { GELOGW("End of sequence, model id: %u", model_id_); - GE_CHK_STATUS(listener_->OnComputeDone(model_id_, data_id, END_OF_SEQUENCE), "OnComputeDone failed"); + GE_CHK_STATUS(listener_->OnComputeDone(model_id_, data_id, END_OF_SEQUENCE, outputs), "OnComputeDone failed"); return END_OF_SEQUENCE; } // return result is not required if (!rslt_flg) { GELOGW("Compute failed, model id: %u", model_id_); - GE_CHK_STATUS(listener_->OnComputeDone(model_id_, data_id, INTERNAL_ERROR), "OnComputeDone failed."); + GE_CHK_STATUS(listener_->OnComputeDone(model_id_, data_id, INTERNAL_ERROR, outputs), "OnComputeDone failed."); return INTERNAL_ERROR; } if (output_op_list_.empty()) { GELOGW("Output tensor list is empty, model id: %u", model_id_); - GE_CHK_STATUS(listener_->OnComputeDone(model_id_, data_id, INTERNAL_ERROR), "OnComputeDone failed."); + GE_CHK_STATUS(listener_->OnComputeDone(model_id_, data_id, INTERNAL_ERROR, outputs), "OnComputeDone failed."); return INTERNAL_ERROR; } @@ -2179,22 +2132,27 @@ Status DavinciModel::ReturnResult(uint32_t data_id, const bool rslt_flg, const b // copy output data from op to designated position for (auto &op_desc : output_op_list_) { - Status ret = ModelOutput::CopyResult(this, op_desc, *output_data, data_index, support_mem_shared_flag_); + Output model_output(op_desc, this); + if 
(model_output.Init() != SUCCESS || GenOutputTensorInfo(op_desc, data_index, output_data, outputs) != SUCCESS) { + return INTERNAL_ERROR; + } + + Status ret = model_output.CopyResult(*output_data, data_index, data_index, false); if (ret != SUCCESS) { GELOGE(INTERNAL_ERROR, "CopyResult failed, op name: %s", op_desc->GetName().c_str()); - GE_CHK_STATUS(listener_->OnComputeDone(model_id_, data_id, INTERNAL_ERROR), "OnComputeDone failed"); + GE_CHK_STATUS(listener_->OnComputeDone(model_id_, data_id, INTERNAL_ERROR, outputs), "OnComputeDone failed"); return INTERNAL_ERROR; } } GE_IF_BOOL_EXEC((DumpOpInputOutput() != SUCCESS), GELOGW("dump op failed, model_id: %u", model_id_);); - GE_CHK_STATUS(listener_->OnComputeDone(model_id_, data_id, SUCCESS), "OnComputeDone failed"); + GE_CHK_STATUS(listener_->OnComputeDone(model_id_, data_id, SUCCESS, outputs), "OnComputeDone failed"); return SUCCESS; } /// -/// @ingroup domi_ome +/// @ingroup ge /// @brief return not output to upper layer for cloud case /// @return Status result /// @@ -2209,12 +2167,13 @@ Status DavinciModel::ReturnNoOutput(uint32_t data_id) { GE_IF_BOOL_EXEC((DumpOpInputOutput() != SUCCESS), GELOGW("dump op failed, model_id: %u", model_id_);); GE_CHK_BOOL_EXEC(listener_ != nullptr, return PARAM_INVALID, "listener_ is null!"); - GE_CHK_STATUS(listener_->OnComputeDone(model_id_, data_id, SUCCESS), "OnComputeDone failed."); + std::vector outputs; + GE_CHK_STATUS(listener_->OnComputeDone(model_id_, data_id, SUCCESS, outputs), "OnComputeDone failed."); return SUCCESS; } /// -/// @ingroup domi_ome +/// @ingroup ge /// @brief dump all op input and output information /// @param [in] op_list model_id /// @return Status result @@ -2248,14 +2207,14 @@ Status DavinciModel::DumpOpInputOutput() { } /// -/// @ingroup domi_ome +/// @ingroup ge /// @brief dump single op input and output information /// @param [in] dump_op model_id /// @return Status result /// Status DavinciModel::DumpSingleOpInputOutput(const OpDescPtr &op_def) 
{ GE_CHK_BOOL_EXEC(nullptr != op_def, return PARAM_INVALID, "op_def is null!"); - string op_name = domi::StringUtils::ReplaceAll(op_def->GetName(), "/", "-"); + string op_name = ge::StringUtils::ReplaceAll(op_def->GetName(), "/", "-"); GELOGI("dump op name:%s, type:%s, model_id: %u.", op_def->GetName().c_str(), op_def->GetType().c_str(), model_id_); string model_path = "./dump" + to_string(model_id_); if (mmAccess(model_path.c_str()) != EN_OK) { @@ -2272,7 +2231,7 @@ Status DavinciModel::DumpSingleOpInputOutput(const OpDescPtr &op_def) { GELOGD("DumpSingleOp[%s], input size[%zu], input memory type size[%zu]", op_def->GetName().c_str(), op_def->GetInputsSize(), v_memory_type.size()); for (size_t i = 0; i < input_addr_vec.size(); i++) { - if (has_mem_type_attr && v_memory_type[i] != RT_MEMORY_HBM) { + if (has_mem_type_attr && v_memory_type[i] == RT_MEMORY_L1) { continue; } int64_t input_size = input_size_vec.at(i); @@ -2296,7 +2255,7 @@ Status DavinciModel::DumpSingleOpInputOutput(const OpDescPtr &op_def) { op_def->GetOutputsSize(), v_memory_type.size()); if (!(op_def->GetType() == "Const")) { for (size_t i = 0; i < output_addr_vec.size(); i++) { - if (has_mem_type_attr && v_memory_type[i] != RT_MEMORY_HBM) { + if (has_mem_type_attr && v_memory_type[i] == RT_MEMORY_L1) { continue; } int64_t output_size = output_size_vec.at(i); @@ -2441,12 +2400,11 @@ void *DavinciModel::Run(DavinciModel *model) { CsaInteract::GetInstance().WriteInternalErrorCode(); GELOGI("Model run end, model id:%u", model->model_id_); - GEEVENT("Model Run thread end, model_id:%u.", model->model_id_); return nullptr; } /// -/// @ingroup domi_ome +/// @ingroup ge /// @brief call API provided by data inputer to destroy thread /// @param [in] no /// @return Status Destroy result @@ -2467,7 +2425,7 @@ Status DavinciModel::DestroyThread() { } /// -/// @ingroup domi_ome +/// @ingroup ge /// @brief create model std::thread, /// @brief start to execute Model /// @param [in] no @@ -2475,9 +2433,6 @@ Status 
DavinciModel::DestroyThread() { /// @author /// Status DavinciModel::ModelRunStart() { - GE_CHK_BOOL_RET_STATUS((domi::RESET != DavinciModel::GetSysMode()) && (domi::STOP != DavinciModel::GetSysMode()), - INTERNAL_ERROR, "Model Start FAIL in wrong sys mode!"); - GE_CHK_BOOL_RET_STATUS(data_inputer_ != nullptr, INTERNAL_ERROR, "data_inputer_ is nullptr."); LockRunFlg(); @@ -2502,16 +2457,13 @@ Status DavinciModel::ModelRunStart() { } /// -/// @ingroup domi_ome +/// @ingroup ge /// @brief call API provided by data inputer and destroy model Thread /// @param [in] no /// @return Status Destroy result /// @author /// Status DavinciModel::ModelRunStop() { - GE_CHK_BOOL_RET_STATUS((DavinciModel::GetSysMode() != domi::RESET) && (DavinciModel::GetSysMode() != domi::STOP), - INTERNAL_ERROR, "Model stop FAIL in wrong sys mode!"); - LockRunFlg(); GE_MAKE_GUARD(tmp_lock, [&] { UnlockRunFlg(); }); @@ -2554,9 +2506,6 @@ Status DavinciModel::InitTaskInfo(domi::ModelTaskDef &model_task_def) { } for (int32_t i = 0; i < model_task_def.task_size(); ++i) { - if (model_task_def.task(i).type() == static_cast(RT_MODEL_TASK_MODEL_END_GRAPH)) { - end_graph_id_ = i; - } std::future f = executor.commit( [](const domi::TaskDef &task, DavinciModel *model, rtContext_t ctx, int32_t idx) -> Status { rtError_t rt_ret = rtCtxSetCurrent(ctx); @@ -2626,7 +2575,7 @@ Status DavinciModel::DistributeTask() { } if (PropertiesManager::Instance().IsLayerNeedDump(name_, op->GetName())) { - SaveDumpTask(task->GetTaskID(), op, task->GetDumpArgs()); + SaveDumpTask(task->GetTaskID(), task->GetStreamId(), op, task->GetDumpArgs()); } } @@ -2656,16 +2605,18 @@ Status DavinciModel::DistributeTask() { } } } + // launch dump kernel to aicpu + GE_CHK_STATUS_RET(data_dumper_.LoadDumpInfo(), "Load dump info failed."); + return SUCCESS; +} + +void DavinciModel::SetEndGraphId(uint32_t task_id, uint32_t stream_id) { auto all_dump_model = PropertiesManager::Instance().GetAllDumpModel(); if 
(all_dump_model.find(ge::DUMP_ALL_MODEL) != all_dump_model.end() || all_dump_model.find(name_) != all_dump_model.end()) { - data_dumper_.SaveDumpTask(task_list_[end_graph_id_]->GetTaskID(), end_graph_op_, 0); - GELOGI("The type of op is %s and the task id is %u", end_graph_op_->GetType().c_str(), - task_list_[end_graph_id_]->GetTaskID()); + GELOGI("start save end_graph_info to dumper, task_id is %u, stream_id is %u", task_id, stream_id); + data_dumper_.SaveEndGraphId(task_id, stream_id); } - // launch dump kernel to aicpu - GE_CHK_STATUS_RET(data_dumper_.LoadDumpInfo(), "Load dump info failed."); - return SUCCESS; } /// @@ -2702,6 +2653,23 @@ void DavinciModel::SetOutputOutsideAddr(const std::vector &outside_addrs } } +/// +/// @ingroup ge +/// @brief Set disabled input zero copy addr. +/// @param [in] const void *addr: address of task +/// @return None. +/// +void DavinciModel::DisableZeroCopy(const void *addr) { + auto it = input_outside_addrs_.find(addr); + if (it == input_outside_addrs_.end()) { + return; + } + + // Data link to RTS Op directly. + std::lock_guard lock(outside_addrs_mutex_); + copy_only_addrs_.insert(addr); +} + /// /// @ingroup ge /// @brief Save outside address used info for ZeroCopy. @@ -2710,39 +2678,60 @@ void DavinciModel::SetOutputOutsideAddr(const std::vector &outside_addrs /// @param [in] const char *args_offset: arguments address save the address. /// @return None. 
/// -void DavinciModel::SetZeroCopyAddr(const OpDescPtr &op_desc, const std::vector &outside_addrs, - void *args_offset) { +void DavinciModel::SetZeroCopyAddr(const OpDescPtr &op_desc, const std::vector &outside_addrs, const void *info, + void *args, size_t size, size_t offset) { // Internal call has ensured that op_desc is not nullptr - int64_t op_id = op_desc->GetId(); size_t nums = outside_addrs.size(); + ZeroCopyTask zero_copy_task(op_desc->GetName(), static_cast(args), size); for (size_t i = 0; i < nums; ++i) { std::lock_guard lock(outside_addrs_mutex_); - auto input_iter = input_outside_addrs_.find(outside_addrs[i]); - if (input_iter != input_outside_addrs_.end()) { - input_iter->second.push_back(static_cast(args_offset) + i * sizeof(void *)); - GELOGI("SetZeroCopyAddr of input outside_addrs."); - } - auto output_iter = output_outside_addrs_.find(outside_addrs[i]); - if (output_iter != output_outside_addrs_.end()) { - output_iter->second.push_back(static_cast(args_offset) + i * sizeof(void *)); - GELOGI("SetZeroCopyAddr of output outside_addrs."); + const uintptr_t addr_val = reinterpret_cast(outside_addrs[i]); + void *args_val = static_cast(args) + offset + i * kAddrLen; + auto it = input_outside_addrs_.find(outside_addrs[i]); + if (it != input_outside_addrs_.end()) { + GE_CHK_STATUS(zero_copy_task.SetTaskArgsOffset(addr_val, offset + i * kAddrLen), "Input args invalid."); + it->second.push_back(args_val); + SetBatchLabelAddr(op_desc, reinterpret_cast(args_val)); + GELOGI("[ZCPY] %s set copy input: %zu, addr: 0x%lx, args: %p, size: %zu, offset: %zu.", + op_desc->GetName().c_str(), i, addr_val, args, size, offset + i * kAddrLen); + continue; } - // Establish a mapping between batch label and zero copy address for multi-batch scenes - if (zero_copy_op_id_batch_label_.find(op_id) == zero_copy_op_id_batch_label_.end()) { + it = output_outside_addrs_.find(outside_addrs[i]); + if (it != output_outside_addrs_.end()) { + 
GE_CHK_STATUS(zero_copy_task.SetTaskArgsOffset(addr_val, offset + i * kAddrLen), "Output args invalid."); + it->second.push_back(args_val); + SetBatchLabelAddr(op_desc, reinterpret_cast(args_val)); + GELOGI("[ZCPY] %s set copy output: %zu, args: %p, addr: 0x%lx.", op_desc->GetName().c_str(), i, args, addr_val); continue; } - std::string batch_label = zero_copy_op_id_batch_label_.find(op_id)->second; - auto iter = zero_copy_batch_label_addrs_.find(batch_label); - if (iter != zero_copy_batch_label_addrs_.end()) { - iter->second.push_back(static_cast(args_offset) + i * sizeof(void *)); - GELOGD("Set zero copy batch label and addrs success, batch label: %s", batch_label.c_str()); - } else { - std::vector addrs; - addrs.emplace_back(static_cast(args_offset) + i * sizeof(void *)); - zero_copy_batch_label_addrs_.emplace(pair>(batch_label, addrs)); - GELOGD("New added zero copy batch label and addrs success, batch label: %s", batch_label.c_str()); - } + } + + std::lock_guard lock(outside_addrs_mutex_); + if (zero_copy_task.IsTaskArgsSet()) { + zero_copy_task.SetOriginalArgs(info, offset + nums * kAddrLen); + zero_copy_tasks_.emplace_back(zero_copy_task); + } +} + +void DavinciModel::SetBatchLabelAddr(const OpDescPtr &op_desc, uintptr_t addr) { + // Establish a mapping between batch label and zero copy address for multi-batch scenes + auto it = zero_copy_op_id_batch_label_.find(op_desc->GetId()); + if (it == zero_copy_op_id_batch_label_.end()) { + return; + } + + const string &batch_label = it->second; + auto iter = zero_copy_batch_label_addrs_.find(batch_label); + if (iter != zero_copy_batch_label_addrs_.end()) { + iter->second.insert(addr); + GELOGD("[ZCPY] Set zero copy batch label and addrs success, batch label: %s, op name:%s.", batch_label.c_str(), + op_desc->GetName().c_str()); + } else { + set addrs = {addr}; + zero_copy_batch_label_addrs_.emplace(pair>(batch_label, addrs)); + GELOGD("[ZCPY] New added zero copy batch label and addrs success, batch label: %s, op 
name:%s.", + batch_label.c_str(), op_desc->GetName().c_str()); } } @@ -2751,11 +2740,11 @@ void DavinciModel::SetZeroCopyAddr(const OpDescPtr &op_desc, const std::vectorGetType() == domi::AIPP_DATA_TYPE) { + if (op_desc->GetType() == AIPP_DATA_TYPE) { GELOGI("This is dynamic aipp model."); is_dynamic_aipp = true; break; @@ -2794,24 +2783,26 @@ bool DavinciModel::CheckInputAndModelSize(const int64_t &input_size, const int64 /// /// @ingroup ge /// @brief Copy Inputs and Outputs addr to model for direct use. -/// @param [in] const domi::InputData &input_data: model input data. -/// @param [in] domi::OutputData &output_data: model output data. +/// @param [in] const InputData &input_data: model input data. +/// @param [in] OutputData &output_data: model output data. /// @param [in] bool is_dynamic_input: whether is dynamic input, true: is dynamic input; false: not is dynamic input /// @return SUCCESS handle successfully / PARAM_INVALID for failed /// -Status DavinciModel::CopyModelData(const InputData &input_data, OutputData &output_data, bool is_dynamic_input) { - if (ZeroCopyBlobs(input_addr_list_, input_size_list_, input_data.blobs, is_dynamic_input, kInputZeroCopy, - input_data.batch_label) != SUCCESS) { - GELOGE(PARAM_INVALID, "Copy input data to model failed."); +Status DavinciModel::CopyModelData(const InputData &input_data, OutputData &output_data, bool is_dynamic) { + if (UpdateIoTaskArgs(input_data_info_, true, input_data.blobs, is_dynamic, input_data.batch_label) != SUCCESS) { + GELOGE(PARAM_INVALID, "[ZCPY] Update input data to model failed."); return PARAM_INVALID; } - if (ZeroCopyBlobs(output_addr_list_, output_size_list_, output_data.blobs, is_dynamic_input, kOutputZeroCopy, - input_data.batch_label) != SUCCESS) { - GELOGE(PARAM_INVALID, "Copy output data to model failed."); + if (UpdateIoTaskArgs(output_data_info_, false, output_data.blobs, is_dynamic, input_data.batch_label) != SUCCESS) { + GELOGE(PARAM_INVALID, "[ZCPY] Update output data to model 
failed."); return PARAM_INVALID; } + for (ZeroCopyTask &task : zero_copy_tasks_) { + GE_CHK_STATUS_RET(task.DistributeParam(is_async_mode_ ? rt_model_stream_ : nullptr), "[ZCPY] Update args failed."); + } + output_data.index = input_data.index; output_data.model_id = model_id_; return SUCCESS; @@ -2820,49 +2811,53 @@ Status DavinciModel::CopyModelData(const InputData &input_data, OutputData &outp /// /// @ingroup ge /// @brief Copy Data addr to model for direct use. -/// @param [in] const vector &addrs: model input memory addr list. -/// @param [in] const vector &sizes: model input memory size list. -/// @param [in] const std::vector &blobs: user input data list. -/// @param [in] bool is_dynamic_input: whether is dynamic input, true: is dynamic input; false: not is dynamic input -/// @param [in] ZeroCopyMode zero_copy_mode: input zero copy or output zero copy -/// @param [in] string batch_label: batch label for multi-batch scenes +/// @param [in] data_info: model memory addr/size map { data_index, { tensor_size, tensor_addr } }. +/// @param [in] is_input: input data or output data +/// @param [in] blobs: user input/output data list. 
+/// @param [in] is_dynamic: whether is dynamic input, true: is dynamic input; false: not is dynamic input +/// @param [in] batch_label: batch label for multi-batch scenes /// @return SUCCESS handle successfully / others handle failed /// -Status DavinciModel::ZeroCopyBlobs(const std::vector &addr_list, const std::vector &size_list, - const std::vector &blobs, bool is_dynamic_input, - ZeroCopyMode zero_copy_mode, std::string batch_label) { - if ((blobs.size() != addr_list.size()) || (blobs.size() != size_list.size())) { - GELOGE(FAILED, "Blobs not match: blobs=%zu addr=%zu size=%zu", blobs.size(), addr_list.size(), size_list.size()); +Status DavinciModel::UpdateIoTaskArgs(const map> &data_info, bool is_input, + const vector &blobs, bool is_dynamic, const string &batch_label) { + if (blobs.size() != data_info.size()) { + GELOGE(FAILED, "Blobs not match: blobs=%zu datas=%zu", blobs.size(), data_info.size()); return FAILED; } - for (size_t idx = 0; idx < size_list.size(); ++idx) { - const DataBuffer &data_buf = blobs[idx]; - if (data_buf.data == nullptr) { - GELOGE(FAILED, "data_buf.data is nullptr, index=%zu", idx); + for (const auto &data : data_info) { + if (data.first >= blobs.size()) { // check data index. + GELOGE(FAILED, "Blobs not match: blobs=%zu, tensor=%zu, index=%u", blobs.size(), data_info.size(), data.first); return FAILED; } - GELOGI("Copy Blobs %zu: Input data length is %u, Op data size is %u.", idx, data_buf.length, size_list[idx]); + int64_t size = data.second.first; // size of tensor. + void *addr = data.second.second; // addr of tensor. - if (!CheckInputAndModelSize(data_buf.length, size_list[idx], is_dynamic_input)) { - GELOGE(FAILED, "Check input size and model size failed"); + const DataBuffer &buffer = blobs[data.first]; // index of data. 
+ if (buffer.data == nullptr) { + GELOGE(FAILED, "data_buf.data is nullptr, index=%u", data.first); return FAILED; } - if (!is_dynamic_input) { - zero_copy_batch_label_addrs_.clear(); + GELOGI("[ZCPY] Copy Blobs: %u, addr: %p, size: %ld, data: %p, length: %u.", data.first, data.second.second, + data.second.first, buffer.data, buffer.length); + if (!CheckInputAndModelSize(buffer.length, size, is_dynamic)) { + GELOGE(FAILED, "Check input size and model size failed"); + return FAILED; } - if (zero_copy_mode == kInputZeroCopy) { - if (ZeroCopyInputBlobs(addr_list[idx], size_list[idx], data_buf, zero_copy_mode, batch_label) != SUCCESS) { - GELOGE(FAILED, "Zero copy input blobs failed"); + // For input data, just copy for rts task. + if (is_input && copy_only_addrs_.count(addr) > 0) { + if (rtMemcpy(addr, size, buffer.data, buffer.length, RT_MEMCPY_DEVICE_TO_DEVICE) != RT_ERROR_NONE) { + GELOGE(FAILED, "Non-zero copy data node copy failed"); return FAILED; } + continue; } - if (zero_copy_mode == kOutputZeroCopy && !is_dynamic_input) { - if (ZeroCopyImpl(addr_list[idx], data_buf, zero_copy_mode, batch_label) != SUCCESS) { - GELOGE(FAILED, "Output zero copy data node copy failed"); + for (ZeroCopyTask &task : zero_copy_tasks_) { + uintptr_t addr_val = reinterpret_cast(addr); + if (task.UpdateTaskParam(addr_val, buffer, zero_copy_batch_label_addrs_, batch_label) != SUCCESS) { return FAILED; } } @@ -2873,103 +2868,6 @@ Status DavinciModel::ZeroCopyBlobs(const std::vector &addr_list, const s /// /// @ingroup ge -/// @brief Copy input addr to model for direct use. -/// @param [in] void *addr: model input memory addr. -/// @param [in] uint32_t size: model input memory size. -/// @param [in] const DataBuffer &data_buffer: user input data. 
-/// @param [in] bool is_dynamic_input: whether is dynamic input, true: is dynamic input; false: not is dynamic input -/// @param [in] ZeroCopyMode zero_copy_mode: input zero copy or output zero copy -/// @param [in] string batch_label: batch label for multi-batch scenes -/// @return SUCCESS handle successfully / others handle failed -/// -Status DavinciModel::ZeroCopyInputBlobs(void *addr, int64_t size, const DataBuffer &data_buffer, - ZeroCopyMode zero_copy_mode, string batch_label) { - auto iter = input_outside_addrs_.find(addr); - if (iter == input_outside_addrs_.end()) { - GELOGE(FAILED, "Can not find addr in input outside addrs"); - return FAILED; - } - const auto &used_zero_copy_list = iter->second; - if (used_zero_copy_list.empty()) { - if (rtMemcpy(addr, size, data_buffer.data, data_buffer.length, RT_MEMCPY_DEVICE_TO_DEVICE) != RT_ERROR_NONE) { - GELOGE(FAILED, "Non-zero copy data node copy failed"); - return FAILED; - } - } else { - if (ZeroCopyImpl(addr, data_buffer, zero_copy_mode, batch_label) != SUCCESS) { - GELOGE(FAILED, "Input zero copy data node copy failed"); - return FAILED; - } - } - return SUCCESS; -} - -/// -/// @ingroup ge -/// @brief Copy address to args_ space for direct use. -/// @param [in] const void *src_addr: source address of the Op. -/// @param [in] const void *dst_addr: destination address of user data. 
-/// @param [in] ZeroCopyMode zero_copy_mode: input zero copy or output zero copy -/// @param [in] string batch_label: batch label for multi-batch scenes -/// @return SUCCESS handle successfully / others handle failed -/// -Status DavinciModel::ZeroCopyImpl(const void *src_addr, const DataBuffer &data_buf, ZeroCopyMode zero_copy_mode, - std::string batch_label) { - auto dst_addr = static_cast(data_buf.data); - auto dst_size = static_cast(data_buf.length); - Status ret = ModelUtils::ConvertVirtualAddressToPhysical(dst_addr, dst_size, dst_addr); - if (ret != SUCCESS) { - GELOGE(FAILED, "Convert virtual address to physical for dst_addr failed."); - return FAILED; - } - - map>::iterator iter; - if (zero_copy_mode == kInputZeroCopy) { - iter = input_outside_addrs_.find(src_addr); - if (iter == input_outside_addrs_.end()) { - GELOGE(FAILED, "ZeroCopyImpl failed to find input outside_addrs."); - return FAILED; - } - } - - if (zero_copy_mode == kOutputZeroCopy) { - iter = output_outside_addrs_.find(src_addr); - if (iter == output_outside_addrs_.end()) { - GELOGE(FAILED, "ZeroCopyImpl failed to find output outside_addrs."); - return FAILED; - } - } - - // Used for dynamic batch/resolution scene - vector dynamic_input_addrs; - auto dynamic_input_iter = zero_copy_batch_label_addrs_.find(batch_label); - if (dynamic_input_iter != zero_copy_batch_label_addrs_.end()) { - dynamic_input_addrs = dynamic_input_iter->second; - } - vector fix_input_addrs; - auto fix_input_iter = zero_copy_batch_label_addrs_.find(kDefaultBatchLable); - if (fix_input_iter != zero_copy_batch_label_addrs_.end()) { - fix_input_addrs = fix_input_iter->second; - } - - for (auto &addr : iter->second) { - if (!CheckDynamicBatchZeroCopyAddr(addr, dynamic_input_addrs, fix_input_addrs)) { - continue; - } - __builtin_prefetch(addr); - rtError_t rt_err = rtMemcpy(addr, sizeof(void *), &dst_addr, sizeof(void *), RT_MEMCPY_HOST_TO_DEVICE); - if (rt_err != RT_ERROR_NONE) { - GELOGE(FAILED, "ZeroCopyImpl: rtMemcpy 
failed"); - return FAILED; - } - GELOGI("[IMAS]refresh in/out addr new:%p, old:%p", dst_addr, src_addr); - } - - return SUCCESS; -} - -/// -/// @ingroup domi_ome /// @brief get unique identification for op when load two or more models /// @param [in] const OpDescPtr: current op. /// @param [in] string identification: unique identification for current op. @@ -2989,7 +2887,7 @@ void DavinciModel::GetUniqueId(const OpDescPtr &op_desc, std::string &unique_ide } /// -/// @ingroup domi_ome +/// @ingroup ge /// @brief For TVM Op, avoid Addr Reuse. /// @return void* /// @@ -3011,11 +2909,11 @@ const char *DavinciModel::GetRegisterStub(const string &binfile, const string &s } /// -/// @ingroup domi_ome +/// @ingroup ge /// @brief Constant Op Init. /// @return Status /// -Status DavinciModel::InitConstant(const ConstOpDescPtr &op_desc) const { +Status DavinciModel::InitConstant(const OpDescPtr &op_desc) { auto v_weights = ModelUtils::GetWeights(op_desc); auto v_output_size = ModelUtils::GetOutputSize(op_desc); auto v_output_addr = ModelUtils::GetOutputDataAddrs(runtime_param_, op_desc); @@ -3039,13 +2937,17 @@ Status DavinciModel::InitConstant(const ConstOpDescPtr &op_desc) const { /// the logic of GetShapeSize is wrong, the scaler tensor's GetShapeSize is zero /// and that of unknown shape is zero too. /// unknown shape will not appear here, so we can use zero judge a tensor is scaler or not - int64_t elem_num = tensor_shape.GetShapeSize() == 0 ? 
1 : tensor_shape.GetShapeSize(); + int64_t elem_num = tensor_shape.GetShapeSize(); + if (elem_num == 0 && tensor_shape.GetDims().size() == 0) { + elem_num = 1; + } uint64_t *buff = reinterpret_cast(tensor->MutableData().data()); GE_CHK_BOOL_RET_STATUS(ge::CheckInt64Uint32MulOverflow(elem_num, kBytes) == SUCCESS, FAILED, "Shape size is invalid"); - int64_t offset = elem_num * kBytes; + uint64_t offset = static_cast(elem_num * kBytes); - uint64_t hbm_raw_data_base_addr = reinterpret_cast(v_output_addr[0]) + offset; + uint64_t hbm_raw_data_base_addr = + reinterpret_cast(reinterpret_cast(v_output_addr[0])) + offset; for (int64_t i = elem_num - 1; i >= 0; --i) { buff[i] = hbm_raw_data_base_addr + (buff[i] - buff[0]); } @@ -3060,7 +2962,7 @@ Status DavinciModel::InitConstant(const ConstOpDescPtr &op_desc) const { } /// -/// @ingroup domi_ome +/// @ingroup ge /// @brief TVM Op Init. /// @return Status /// @@ -3158,49 +3060,52 @@ void DavinciModel::CleanTbeHandle() { } /// -/// @ingroup domi_ome +/// @ingroup ge /// @brief insert active_stream_indication_ /// @return Status /// -Status DavinciModel::MarkActiveStream(const OpDescPtr &op_desc) { - GE_CHECK_NOTNULL(op_desc); - std::string type = op_desc->GetType(); - GE_IF_BOOL_EXEC( - type == domi::STREAMSWITCH, std::vector active_stream_list; - GE_LOGI_IF(!ge::AttrUtils::GetListInt(op_desc, ATTR_NAME_ACTIVE_STREAM_LIST, active_stream_list), - "GetInt ACTIVE_STREAM_LIST failed."); - if (active_stream_list.size() != kTrueBranchStreamNum) { - GELOGE(INTERNAL_ERROR, "Stream num of switch true branch must be %u.", kTrueBranchStreamNum); - return INTERNAL_ERROR; - } uint32_t true_stream_id = active_stream_list.front(); - active_stream_indication_.insert(true_stream_id); - GELOGI("flowctrl_op_index_map node:%s, true_stream_id=%u.", op_desc->GetName().c_str(), true_stream_id);); - GE_IF_BOOL_EXEC( - type == domi::STREAMACTIVE, if (op_desc->HasAttr(ATTR_NAME_SWITCH_BRANCH_NODE_LABEL)) { - std::vector active_stream_list; - 
GE_CHK_BOOL_EXEC(AttrUtils::GetListInt(op_desc, ATTR_NAME_ACTIVE_STREAM_LIST, active_stream_list), - return INTERNAL_ERROR, "StreamActiveOp get attr ACTIVE_STREAM failed."); - - for (size_t j = 0; j < active_stream_list.size(); ++j) { - active_stream_indication_.insert(active_stream_list[j]); - GELOGI("flowctrl_op_index_map node:%s, active_stream_id=%u.", op_desc->GetName().c_str(), - active_stream_list[j]); - } - }); - - if (type == domi::STREAMSWITCHN) { +Status DavinciModel::InitStreamActive(const OpDescPtr &op_desc) { + if (op_desc->HasAttr(ATTR_NAME_SWITCH_BRANCH_NODE_LABEL)) { std::vector active_stream_list; - if (!AttrUtils::GetListInt(op_desc, ATTR_NAME_ACTIVE_STREAM_LIST, active_stream_list)) { - GELOGE(INTERNAL_ERROR, "StreamSwitchNOp get attr ACTIVE_STREAM failed."); - return INTERNAL_ERROR; - } + GE_CHK_BOOL_EXEC(AttrUtils::GetListInt(op_desc, ATTR_NAME_ACTIVE_STREAM_LIST, active_stream_list), + return INTERNAL_ERROR, "StreamActiveOp get attr ACTIVE_STREAM failed."); for (size_t j = 0; j < active_stream_list.size(); ++j) { active_stream_indication_.insert(active_stream_list[j]); - GELOGI("StreamSwitchNOp node:%s, active_stream_id=%u.", op_desc->GetName().c_str(), active_stream_list[j]); - }; + GELOGI("flowctrl_op_index_map node:%s, active_stream_id=%u.", op_desc->GetName().c_str(), active_stream_list[j]); + } + } + + return SUCCESS; +} + +Status DavinciModel::InitStreamSwitch(const OpDescPtr &op_desc) { + std::vector active_stream_list; + GE_LOGI_IF(!ge::AttrUtils::GetListInt(op_desc, ATTR_NAME_ACTIVE_STREAM_LIST, active_stream_list), + "GetInt ACTIVE_STREAM_LIST failed."); + if (active_stream_list.size() != kTrueBranchStreamNum) { + GELOGE(INTERNAL_ERROR, "Stream num of switch true branch must be %u.", kTrueBranchStreamNum); + return INTERNAL_ERROR; + } + + uint32_t true_stream_id = active_stream_list.front(); + active_stream_indication_.insert(true_stream_id); + GELOGI("flowctrl_op_index_map node:%s, true_stream_id=%u.", op_desc->GetName().c_str(), 
true_stream_id); + + return SUCCESS; +} + +Status DavinciModel::InitStreamSwitchN(const OpDescPtr &op_desc) { + std::vector active_stream_list; + if (!AttrUtils::GetListInt(op_desc, ATTR_NAME_ACTIVE_STREAM_LIST, active_stream_list)) { + GELOGE(INTERNAL_ERROR, "StreamSwitchNOp get attr ACTIVE_STREAM failed."); + return INTERNAL_ERROR; + } + + for (size_t j = 0; j < active_stream_list.size(); ++j) { + active_stream_indication_.insert(active_stream_list[j]); + GELOGI("StreamSwitchNOp node:%s, active_stream_id=%u.", op_desc->GetName().c_str(), active_stream_list[j]); } - GELOGI("Flow control: active_stream_indication_ size = %zu.", active_stream_indication_.size()); return SUCCESS; } @@ -3212,7 +3117,7 @@ bool DavinciModel::IsBroadCastOpData(const ge::NodePtr &var_node) { GE_RT_FALSE_CHECK_NOTNULL(in_anchor); ge::NodePtr dst_node = in_anchor->GetOwnerNode(); GE_RT_FALSE_CHECK_NOTNULL(dst_node); - if (dst_node->GetType() == domi::HCOMBROADCAST) { + if (dst_node->GetType() == HCOMBROADCAST) { return true; } } @@ -3220,16 +3125,46 @@ bool DavinciModel::IsBroadCastOpData(const ge::NodePtr &var_node) { return false; } +void DavinciModel::InitZeroCopyUtil(bool is_dynamic_batch, bool &input_zero_copy, bool &output_zero_copy) { + auto dump_path = PropertiesManager::Instance().GetDumpOutputPath(); + auto enable_dump = !dump_path.empty(); + + auto dump_op_env = std::getenv("DUMP_OP"); + if (dump_op_env != nullptr) { + string dump_op_flag(dump_op_env); + if (dump_op_flag == "1") { + enable_dump = true; + } + } + + GELOGI("dump path: %s, dump_op_env: %s", dump_path.c_str(), dump_op_env); + if (!is_dynamic_batch) { + zero_copy_batch_label_addrs_.clear(); + } + + if (enable_dump) { + input_zero_copy = false; + output_zero_copy = false; + } else { + for (const auto &addrs : output_outside_addrs_) { + const auto &used_list = addrs.second; + if (used_list.empty()) { + output_zero_copy = false; + break; + } + } + } +} + /// -/// @ingroup domi_ome +/// @ingroup ge /// @brief Init model 
stream for NN model. /// @param [in] stream user input model stream. -/// @param [in] async_mode is asynchronize mode. /// @return Status /// -Status DavinciModel::InitModelStream(rtStream_t stream, bool async_mode) { +Status DavinciModel::InitModelStream(rtStream_t stream) { // asynchronize mode, use user input stream. - if (async_mode) { + if (is_async_mode_) { rt_model_stream_ = stream; is_inner_model_stream_ = false; return SUCCESS; @@ -3255,7 +3190,7 @@ Status DavinciModel::InitModelStream(rtStream_t stream, bool async_mode) { } /// -/// @ingroup domi_ome +/// @ingroup ge /// @brief ACL case, do not start new thread, return execute result. /// @param [in] stream execute model stream. /// @param [in] async_mode is asynchronize mode. @@ -3264,45 +3199,24 @@ Status DavinciModel::InitModelStream(rtStream_t stream, bool async_mode) { /// Status DavinciModel::NnExecute(rtStream_t stream, bool async_mode, const InputData &input_data, OutputData &output_data) { - GELOGI("Model Run begin, model id:%u, data index:%u, flag:%d.", model_id_, input_data.index, async_mode); - GE_CHK_STATUS(InitModelStream(stream, async_mode), "Init model stream failed."); - - GELOGI("do rtModelExecute task sink, model id:%u", input_data.model_id); - - auto enable_dump = false; - auto dump_path = PropertiesManager::Instance().GetDumpOutputPath(); - if (!dump_path.empty()) { - enable_dump = true; - } + is_async_mode_ = async_mode; + GELOGI("Model Run begin, model id:%u, data index:%u, flag:%d.", model_id_, input_data.index, is_async_mode_); + GE_CHK_STATUS_RET(InitModelStream(stream), "Init model stream failed."); - auto dump_op_env = std::getenv("DUMP_OP"); - if (dump_op_env != nullptr) { - string dump_op_flag(dump_op_env); - if (dump_op_flag == "1") { - enable_dump = true; - } - } - GELOGI("dump path: %s, dump_op_env: %s", dump_path.c_str(), dump_op_env); + bool input_use_zero_copy = true; + bool output_use_zero_copy = true; bool is_dynamic_batch = input_data.is_dynamic_batch; - if 
(is_dynamic_batch) { - input_use_zero_copy_ = true; - output_use_zero_copy_ = false; - } - - if (enable_dump) { - input_use_zero_copy_ = false; - output_use_zero_copy_ = false; - } + InitZeroCopyUtil(is_dynamic_batch, input_use_zero_copy, output_use_zero_copy); - // Asynchronous mode depends on zero copy. - if (async_mode && !input_use_zero_copy_ && !output_use_zero_copy_ && !task_list_.empty()) { - GELOGE(FAILED, "Asynchronous mode but zero copy disabled."); + // Empty task, Just copy input to output, need direct copy. + if (task_list_.empty() && (input_use_zero_copy || output_use_zero_copy)) { + GELOGE(FAILED, "Empty task, Just copy input to output, need direct copy."); return FAILED; } GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingOn(), SetProfileTime(MODEL_PRE_PROC_START)); Status ret = - input_use_zero_copy_ ? CopyModelData(input_data, output_data, is_dynamic_batch) : CopyInputData(input_data, true); + input_use_zero_copy ? CopyModelData(input_data, output_data, is_dynamic_batch) : CopyInputData(input_data, true); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, return INTERNAL_ERROR, "Copy input data to model failed."); GELOGI("current_data.index=%u", input_data.index); @@ -3317,26 +3231,28 @@ Status DavinciModel::NnExecute(rtStream_t stream, bool async_mode, const InputDa GELOGI("rtModelExecute end"); } - GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingOn(), SetProfileTime(MODEL_AFTER_PROC_START)); - ret = output_use_zero_copy_ ? SyncDataAndDump() : CopyOutputData(input_data.index, output_data); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, return INTERNAL_ERROR, "Copy Output data to user failed."); - GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingOn(), SetProfileTime(MODEL_AFTER_PROC_END)); + if (!is_async_mode_) { + GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingOn(), SetProfileTime(MODEL_AFTER_PROC_START)); + ret = output_use_zero_copy ? 
SyncDataAndDump() : CopyOutputData(input_data.index, output_data); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, return INTERNAL_ERROR, "Copy Output data to user failed."); + GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingOn(), SetProfileTime(MODEL_AFTER_PROC_END)); + } // report model time data GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingOn(), (void)SinkTimeProfile(input_data)); GELOGI("Model run end, model id:%u", model_id_); - GEEVENT("Model Run thread end, model_id:%u", model_id_); return SUCCESS; } uint8_t *DavinciModel::MallocFeatureMapMem(uint64_t data_size) { uint8_t *mem_base = nullptr; + const string purpose("feature map,used for op input and output."); if (std::getenv(kEnvGeuseStaticMemory) != nullptr) { data_size = static_cast(VarManager::Instance(0)->GetGraphMemoryMaxSize()); string memory_key = std::to_string(0) + "_f"; - mem_base = MemManager::Instance(RT_MEMORY_HBM)->MallocMemory(memory_key, data_size, GetDeviceId()); + mem_base = MemManager::Instance(RT_MEMORY_HBM)->MallocMemory(purpose, memory_key, data_size, GetDeviceId()); } else { - mem_base = MemManager::Instance(RT_MEMORY_HBM)->MallocMemory(data_size, GetDeviceId()); + mem_base = MemManager::Instance(RT_MEMORY_HBM)->MallocMemory(purpose, data_size, GetDeviceId()); } if (mem_base != nullptr) { @@ -3347,12 +3263,13 @@ uint8_t *DavinciModel::MallocFeatureMapMem(uint64_t data_size) { uint8_t *DavinciModel::MallocWeightsMem(uint32_t weights_size) { uint8_t *weights_mem_base = nullptr; + const string purpose("weights memory in inference network."); if (std::getenv(kEnvGeuseStaticMemory) != nullptr) { string weight_memory_key = std::to_string(0) + "_w"; weights_mem_base = - MemManager::Instance(RT_MEMORY_HBM)->MallocMemory(weight_memory_key, weights_size, GetDeviceId()); + MemManager::Instance(RT_MEMORY_HBM)->MallocMemory(purpose, weight_memory_key, weights_size, GetDeviceId()); } else { - weights_mem_base = MemManager::Instance(RT_MEMORY_HBM)->MallocMemory(weights_size, 
GetDeviceId()); + weights_mem_base = MemManager::Instance(RT_MEMORY_HBM)->MallocMemory(purpose, weights_size, GetDeviceId()); } return weights_mem_base; } @@ -3508,9 +3425,9 @@ void DavinciModel::SetDataDumperArgs() { return nullptr; }; - data_dumper_.SetLoopAddr(get_var_addr(GetVariableOp(domi::NODE_NAME_GLOBAL_STEP), runtime_param_), - get_var_addr(GetVariableOp(domi::NODE_NAME_FLOWCTRL_LOOP_PER_ITER), runtime_param_), - get_var_addr(GetVariableOp(domi::NODE_NAME_FLOWCTRL_LOOP_COND), runtime_param_)); + data_dumper_.SetLoopAddr(get_var_addr(GetVariableOp(NODE_NAME_GLOBAL_STEP), runtime_param_), + get_var_addr(GetVariableOp(NODE_NAME_FLOWCTRL_LOOP_PER_ITER), runtime_param_), + get_var_addr(GetVariableOp(NODE_NAME_FLOWCTRL_LOOP_COND), runtime_param_)); GELOGI("SetDataDumperArgs end."); } @@ -3525,6 +3442,21 @@ void DavinciModel::PushHcclStream(rtStream_t value) { all_hccl_stream_list_.push_back(value); } +void DavinciModel::CreateHcclFollowStream(rtStream_t stream, int64_t remain_cap) { + std::lock_guard lock(capacity_of_stream_mutex_); + capacity_of_stream_.emplace_back(make_pair(stream, remain_cap)); +}; + +void DavinciModel::ReuseHcclFollowStream(int64_t remain_cap, int64_t &index) { + std::lock_guard lock(capacity_of_stream_mutex_); + if (remain_cap == 0) { + capacity_of_stream_.erase(capacity_of_stream_.begin() + index); + } else { + capacity_of_stream_.at(index).second = remain_cap; + index++; + } +} + Status TransTensor(uint8_t *var_data, const NodePtr &var_src, const NodePtr &var_dst, formats::TransResult &result) { GE_CHECK_NOTNULL(var_src); GE_CHECK_NOTNULL(var_src->GetOpDesc()); @@ -3597,7 +3529,7 @@ Status DavinciModel::CopyVarData(ComputeGraphPtr &compute_graph) { string cp_from_node; bool copy_value = false; for (auto &node : compute_graph->GetAllNodes()) { - GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr || node->GetOpDesc()->GetType() != domi::VARIABLE, continue); + GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr || node->GetOpDesc()->GetType() != 
VARIABLE, continue); GE_IF_BOOL_EXEC(ge::AttrUtils::GetStr(node->GetOpDesc(), "_copy_from_var_node", cp_from_node), GELOGI("Get original type of cp_from_node")); if (cp_from_node.length() != 0) { diff --git a/src/ge/graph/load/new_model_manager/davinci_model.h b/src/ge/graph/load/new_model_manager/davinci_model.h index d5a7baf4..25cb0a3a 100644 --- a/src/ge/graph/load/new_model_manager/davinci_model.h +++ b/src/ge/graph/load/new_model_manager/davinci_model.h @@ -27,13 +27,14 @@ #include "common/ge_types.h" #include "common/helper/model_helper.h" #include "common/helper/om_file_helper.h" -#include "common/op/attr_define.h" #include "common/opskernel/ge_task_info.h" #include "common/types.h" #include "framework/common/util.h" +#include "graph/debug/ge_attr_define.h" #include "graph/load/new_model_manager/data_dumper.h" #include "graph/load/new_model_manager/data_inputer.h" #include "graph/load/new_model_manager/model_utils.h" +#include "graph/load/new_model_manager/zero_copy_task.h" #include "graph/model.h" #include "graph/node.h" #include "graph/op_desc.h" @@ -44,19 +45,7 @@ #include "proto/task.pb.h" #include "task_info/task_info.h" -#define WEIGHTS_ADDR_TO_CCE(var) - namespace ge { -using domi::CONSTANTOP; -using domi::ENDGRAPH; -using domi::NETOUTPUT; -using domi::VARIABLE; -using std::vector; -enum ZeroCopyMode { - kInputZeroCopy, - kOutputZeroCopy, -}; - typedef enum tagModelProcStage { MODEL_LOAD_START = 1, MODEL_LOAD_END, @@ -83,27 +72,27 @@ struct timeInfo { class DavinciModel { public: /// - /// @ingroup domi_ome + /// @ingroup ge /// @brief DavinciModel constructor /// @author /// DavinciModel(int32_t priority, const std::shared_ptr &listener); /// - /// @ingroup domi_ome + /// @ingroup ge /// @brief DavinciModel desctructor, free Parse and Init resources /// @author /// ~DavinciModel(); /// - /// @ingroup domi_ome + /// @ingroup ge /// @brief apply model to model_def_ /// Status Assign(const GeModelPtr &ge_model); /// - /// @ingroup domi_ome + /// 
@ingroup ge /// @brief DavinciModel initialization, including Stream, ccHandle, Event, DataInputer, etc /// @return execute result /// @author @@ -120,14 +109,14 @@ class DavinciModel { Status SetQueIds(const std::vector &input_queue_ids, const std::vector &output_queue_ids); /// - /// @ingroup domi_ome + /// @ingroup ge /// @brief Get DataInputer /// @return model ID /// uint32_t Id() const { return model_id_; } /// - /// @ingroup domi_ome + /// @ingroup ge /// @brief Get DataInputer /// @return model ID /// @@ -136,7 +125,7 @@ class DavinciModel { static void *Run(DavinciModel *model_pointer); /// - /// @ingroup domi_ome + /// @ingroup ge /// @brief NnExecute /// @param [in] stream execute stream /// @param [in] async_mode is asynchronize mode. @@ -146,35 +135,21 @@ class DavinciModel { Status NnExecute(rtStream_t stream, bool async_mode, const InputData &input_data, OutputData &output_data); /// - /// @ingroup domi_ome - /// @brief get sys mode - /// @return SysMode - /// - static domi::SysMode GetSysMode(); - - /// - /// @ingroup domi_ome - /// @brief set sys mode - /// @return Status - /// - static Status SetSysMode(domi::SysMode mode); - - /// - /// @ingroup domi_ome + /// @ingroup ge /// @brief lock mutex run flag /// @author /// void LockRunFlg() { mux_run_flg_.lock(); } /// - /// @ingroup domi_ome + /// @ingroup ge /// @brief unlock mutex run flag /// @author /// void UnlockRunFlg() { mux_run_flg_.unlock(); } /// - /// @ingroup domi_ome + /// @ingroup ge /// @brief get DataInputer /// @return DataInputer pointer /// @@ -195,9 +170,9 @@ class DavinciModel { // get session id uint64_t SessionId() const { return runtime_param_.session_id; } - vector GetOpDesc() { - vector opDescVector; - GE_IF_BOOL_EXEC(ge::AttrUtils::GetListOpDesc(GetGeModel(), domi::MODEL_ATTR_FUSION_MODEL_DEF, opDescVector), + vector GetOpDesc() { + vector opDescVector; + GE_IF_BOOL_EXEC(AttrUtils::GetListOpDesc(GetGeModel(), MODEL_ATTR_FUSION_MODEL_DEF, opDescVector), GELOGI("get opDesc 
of opDescVector")); return opDescVector; } @@ -240,9 +215,9 @@ class DavinciModel { const vector &GetDataList() const { return data_op_list_; } // get Op - map GetOpList() const { return op_list_; } + const map &GetOpList() const { return op_list_; } - OpDescPtr GetOpByIndex(uint32_t index) { + OpDescPtr GetOpByIndex(uint32_t index) const { if (op_list_.find(index) == op_list_.end()) { return nullptr; } @@ -264,11 +239,11 @@ class DavinciModel { std::vector GetTaskList() { return task_list_; } /// - /// @ingroup domi_ome + /// @ingroup ge /// @brief get model input and output format /// @return ccTensorFormat_t current model input and output format /// - ge::Format GetFormat(); + Format GetFormat(); rtModel_t GetRtModelHandle() { rtModel_t res = rt_model_handle_; @@ -285,17 +260,17 @@ class DavinciModel { void PushHcclStream(rtStream_t value); - bool IsBroadCastOpData(const ge::NodePtr &var_node); + bool IsBroadCastOpData(const NodePtr &var_node); /// - /// @ingroup domi_ome + /// @ingroup ge /// @brief For TVM Op, avoid Addr Reuse. /// @return void* /// static const char *GetRegisterStub(const string &tvm_binfile_key, const string &session_graph_model_id = ""); /// - /// @ingroup domi_ome + /// @ingroup ge /// @brief get model input and output desc info /// @param [out] input_shape model input size /// @param [out] output_shape model output size @@ -307,7 +282,7 @@ class DavinciModel { std::vector &inputFormats, std::vector &output_formats); /// - /// @ingroup domi_ome + /// @ingroup ge /// @brief Get dynamic batch_info /// @param [out] batch_info /// @return execute result @@ -315,14 +290,14 @@ class DavinciModel { Status GetDynamicBatchInfo(std::vector> &batch_info); /// - /// @ingroup domi_ome + /// @ingroup ge /// @brief Get model_id. 
/// @return model_id /// uint32_t GetModelId() const { return model_id_; } /// - /// @ingroup domi_ome + /// @ingroup ge /// @brief get unique identification for op when load two or more models /// @param [in] op_desc : current op. /// @param [in] string identification: unique identification for current op. @@ -331,32 +306,22 @@ class DavinciModel { void GetUniqueId(const OpDescPtr &op_desc, std::string &unique_identification); /// - /// @ingroup domi_ome + /// @ingroup ge /// @brief get model input and output desc for zero copy /// @param [out] input_shape model input size /// @param [out] output_shape model output size /// @return execute result /// - Status GetInputOutputDescInfoForZeroCopy(vector &input_desc, - vector &output_desc); - Status GetInputOutputDescInfoForZeroCopy(vector &input_desc, vector &output_desc, std::vector &inputFormats, std::vector &output_formats); - /// - /// @ingroup domi_ome - /// @brief copy input data to model - /// @return Status - /// - Status CopyInputDataToModel(const std::vector &data, uint32_t data_op_index, bool device_data); - Status ReturnResult(uint32_t data_id, const bool rslt_flg, const bool seq_end_flg, OutputData *output_data); Status ReturnNoOutput(uint32_t data_id); /// - /// @ingroup domi_ome + /// @ingroup ge /// @brief dump all op input and output information /// @param [in] op_list model_id /// @return Status @@ -364,7 +329,7 @@ class DavinciModel { Status DumpOpInputOutput(); /// - /// @ingroup domi_ome + /// @ingroup ge /// @brief dump single op input and output information /// @param [in] dump_op model_id /// @return Status @@ -374,14 +339,14 @@ class DavinciModel { Status ModelRunStart(); /// - /// @ingroup domi_ome + /// @ingroup ge /// @brief stop run model /// @return Status /// Status ModelRunStop(); /// - /// @ingroup domi_ome + /// @ingroup ge /// @brief model run flag /// @return Status /// @@ -390,47 +355,33 @@ class DavinciModel { Status GetOutputDescInfo(vector &output_desc, std::vector &formats); 
/// - /// @ingroup domi_ome + /// @ingroup ge /// @brief Set Session Id /// @return void /// void SetSessionId(uint64_t session_id) { session_id_ = session_id; } /// - /// @ingroup domi_ome + /// @ingroup ge /// @brief Get Session Id /// @return sessionID /// uint64_t GetSessionId() const { return session_id_; } /// - /// @ingroup domi_ome + /// @ingroup ge /// @brief SetDeviceId /// @return void /// void SetDeviceId(uint32_t device_id) { device_id_ = device_id; } /// - /// @ingroup domi_ome + /// @ingroup ge /// @brief Get device Id /// @return device id /// uint32_t GetDeviceId() const { return device_id_; } - /// - /// @ingroup domi_ome - /// @brief Set Train Mode - /// @return void - /// - void SetTrainMode(bool mode) { is_train_mode_ = mode; } - - /// - /// @ingroup domi_ome - /// @brief Get Train Mode - /// @return bool true - /// - bool GetTrainMode() { return is_train_mode_; } - GeModelPtr GetGeModel() { return ge_model_; } const RuntimeParam &GetRuntimeParam() { return runtime_param_; } @@ -438,15 +389,18 @@ class DavinciModel { int32_t GetDataInputTid() const { return dataInputTid; } void SetDataInputTid(int32_t data_input_tid) { dataInputTid = data_input_tid; } + void DisableZeroCopy(const void *addr); + /// - /// @ingroup domi_ome + /// @ingroup ge /// @brief Save outside address of Data or NetOutput used info for ZeroCopy. /// @param [in] const OpDescPtr &op_desc: current op desc /// @param [in] const std::vector &outside_addrs: address of task /// @param [in] const void *args_offset: arguments address save the address. /// @return None. 
/// - void SetZeroCopyAddr(const OpDescPtr &op_desc, const std::vector &outside_addrs_, void *args_offset); + void SetZeroCopyAddr(const OpDescPtr &op_desc, const std::vector &outside_addrs, const void *info, void *args, + size_t size, size_t offset); bool GetL1FusionEnableOption() { return is_l1_fusion_enable_; } @@ -456,18 +410,23 @@ class DavinciModel { int64_t GetLoadEndTime() { return load_end_time_; } - Status SinkModelProfile(std::shared_ptr &model); + Status SinkModelProfile(); Status SinkTimeProfile(const InputData ¤t_data); - void SaveDumpTask(uint32_t task_id, const std::shared_ptr &op_desc, uintptr_t args) { - data_dumper_.SaveDumpTask(task_id, op_desc, args); + void SaveDumpTask(uint32_t task_id, uint32_t stream_id, const std::shared_ptr &op_desc, uintptr_t args) { + data_dumper_.SaveDumpTask(task_id, stream_id, op_desc, args); } + void SetEndGraphId(uint32_t task_id, uint32_t stream_id); DavinciModel &operator=(const DavinciModel &model) = delete; DavinciModel(const DavinciModel &model) = delete; + const vector> &GetHcclFolowStream() { return capacity_of_stream_; } + void CreateHcclFollowStream(rtStream_t stream, int64_t remain_cap); + void ReuseHcclFollowStream(int64_t remain_cap, int64_t &index); + private: // memory address of weights uint8_t *weights_mem_base_; @@ -484,8 +443,19 @@ class DavinciModel { struct timeInfo time_info_; int32_t dataInputTid; + void InitZeroCopyUtil(bool is_dynamic_batch, bool &input_zero_copy, bool &output_zero_copy); + /// - /// @ingroup domi_ome + /// @ingroup ge + /// @brief Save Batch label Info. + /// @param [in] const OpDescPtr &op_desc + /// @param [in] uintptr_t addr: address value in args block. + /// @return None. + /// + void SetBatchLabelAddr(const OpDescPtr &op_desc, uintptr_t addr); + + /// + /// @ingroup ge /// @brief Save Data address info for ZeroCopy. /// @param [in] const std::vector &outside_addrs /// @return None. 
@@ -493,7 +463,7 @@ class DavinciModel { void SetInputOutsideAddr(const std::vector &outside_addrs); /// - /// @ingroup domi_ome + /// @ingroup ge /// @brief Save NetOutput address info for ZeroCopy. /// @param [in] const std::vector &outside_addrs /// @return None. @@ -505,69 +475,35 @@ class DavinciModel { /// @brief Copy Check input size and model op size. /// @param [in] const int64_t &input_size: input size. /// @param [in] const int64_t &op_size: model op size. - /// @param [in] is_dynamic_input: dynamic batch input flag. + /// @param [in] is_dynamic: dynamic batch input flag. /// @return true if success /// - bool CheckInputAndModelSize(const int64_t &input_size, const int64_t &op_size, bool is_dynamic_input); + bool CheckInputAndModelSize(const int64_t &input_size, const int64_t &op_size, bool is_dynamic); /// /// @ingroup ge /// @brief Copy Input/Output to model for direct use. /// @param [in] const InputData &input_data: user input data info. /// @param [in/out] OutputData &output_data: user output data info. - /// @param [in] bool is_dynamic_input: whether is dynamic input, true: is dynamic input; false: not is dynamic input + /// @param [in] bool is_dynamic: whether is dynamic input, true: is dynamic input; false: not is dynamic input /// @return SUCCESS handle successfully / others handle failed /// - Status CopyModelData(const InputData &input_data, OutputData &output_data, bool is_dynamic_input); + Status CopyModelData(const InputData &input_data, OutputData &output_data, bool is_dynamic); /// /// @ingroup ge /// @brief Copy Data addr to model for direct use. - /// @param [in] const vector &addrs: model input memory addr list. - /// @param [in] const vector &sizes: model input memory size list. - /// @param [in] const std::vector &blobs: user input data list. 
- /// @param [in] bool is_dynamic_input: whether is dynamic input, true: is dynamic input; false: not is dynamic input - /// @param [in] ZeroCopyMode zero_copy_mode: input zero copy or output zero copy - /// @param [in] string batch_label: batch label for multi-batch scenes - /// @return SUCCESS handle successfully / others handle failed - /// - Status ZeroCopyBlobs(const std::vector &addr_list, const std::vector &size_list, - const std::vector &blobs, bool is_dynamic_input, ZeroCopyMode zero_copy_mode, - string batch_label); - - /// - /// @ingroup ge - /// @brief Copy input addr to model for direct use. - /// @param [in] void *addr: model input memory addr. - /// @param [in] uint32_t size: model input memory size. - /// @param [in] const DataBuffer &data_buffer: user input data. - /// @param [in] bool is_dynamic_input: whether is dynamic input, true: is dynamic input; false: not is dynamic input - /// @param [in] ZeroCopyMode zero_copy_mode: input zero copy or output zero copy - /// @param [in] string batch_label: batch label for multi-batch scenes + /// @param [in] data_info: model memory addr/size map { data_index, { tensor_size, tensor_addr } }. + /// @param [in] is_input: input data or output data + /// @param [in] blobs: user input/output data list. + /// @param [in] is_dynamic: whether is dynamic input, true: is dynamic input; false: not is dynamic input + /// @param [in] batch_label: batch label for multi-batch scenes /// @return SUCCESS handle successfully / others handle failed /// - Status ZeroCopyInputBlobs(void *addr, int64_t size, const DataBuffer &data_buffer, ZeroCopyMode zero_copy_mode, - string batch_label); + Status UpdateIoTaskArgs(const map> &data_info, bool is_input, + const vector &blobs, bool is_dynamic, const string &batch_label); - /// - /// @ingroup ge - /// @brief Copy address to args_ space for direct use. - /// @param [in] const void *src_addr: source address of the Op. 
- /// @param [in] const void *dst_addr: destination address of user data. - /// @param [in] ZeroCopyMode zero_copy_mode: input zero copy or output zero copy - /// @param [in] string batch_label: batch label for multi-batch scenes - /// @return SUCCESS handle successfully / others handle failed - /// - Status ZeroCopyImpl(const void *src_addr, const DataBuffer &data_buf, ZeroCopyMode zero_copy_mode, - string batch_label); - - Status CopyInputData(const InputData ¤t_data, bool device_data = false); - - Status CopyTransData(const std::vector &data, uint32_t data_index, uint32_t data_op_index, - const std::vector &outputs); - - Status CopyPlainData(const std::vector &data, uint32_t data_index, uint32_t data_op_index, - const std::vector &outputs, rtMemcpyKind_t kind); + Status CopyInputData(const InputData &input_data, bool device_data = false); Status CopyOutputData(uint32_t data_id, OutputData &output_data); @@ -612,11 +548,9 @@ class DavinciModel { /// @brief Data Op Initialize. /// @param [in] NodePtr: Data Op. /// @param [in/out] data_op_index: NetOutput addr size info. - /// @param [in/out] input_data_info: Data index and addr info {index, {size, addr}}. /// @return Status /// - Status InitDataOp(const NodePtr &node, uint32_t &data_op_index, - std::map> &input_data_info); + Status InitDataOp(const NodePtr &node, uint32_t &data_op_index); /// /// @ingroup ge @@ -629,28 +563,42 @@ class DavinciModel { /// /// @ingroup ge /// @brief NetOutput Op Initialize. - /// @param [in] op_desc: NetOutput Op descriptor. + /// @param [in] NodePtr: NetOutput Op. /// @return Status /// - Status InitNetOutput(const OpDescPtr &op_desc); + Status InitNetOutput(const NodePtr &node); /// /// @ingroup ge - /// @brief Make Input and Output addr for feature use. - /// @param [in] input_data_info: Data index and addr info {index, {size, addr}}. + /// @brief output zero copy node Initialize. + /// @param [in] NodePtr: Data Op. 
/// @return Status /// - Status CombineDataInfo(const std::map> &input_data_info); + Status InitOutputZeroCopyNodes(const NodePtr &node); /// - /// @ingroup domi_ome + /// @ingroup ge /// @brief Constant Op Init. /// @return Status /// - Status InitConstant(const ConstOpDescPtr &op_desc) const; + Status InitConstant(const OpDescPtr &op_desc); + + Status InitVariable(const OpDescPtr &op_desc); + + /// @ingroup ge + /// @brief LabelSet Op Initialize. + /// @param [in] op_desc: LabelSet Op descriptor. + /// @return Status + Status InitLabelSet(const OpDescPtr &op_desc); + + Status InitStreamSwitch(const OpDescPtr &op_desc); + + Status InitStreamActive(const OpDescPtr &op_desc); + + Status InitStreamSwitchN(const OpDescPtr &op_desc); /// - /// @ingroup domi_ome + /// @ingroup ge /// @brief TVM Op Init. /// @return Status /// @@ -660,11 +608,18 @@ class DavinciModel { void CleanTbeHandle(); /// - /// @ingroup domi_ome + /// @ingroup ge + /// @brief Make active stream list and bind to model. + /// @return: 0 for success / others for fail + /// + Status BindModelStream(); + + /// + /// @ingroup ge /// @brief Init model stream for NN model. /// @return Status /// - Status InitModelStream(rtStream_t stream, bool async_mode); + Status InitModelStream(rtStream_t stream); /// /// @ingroup ge @@ -680,26 +635,16 @@ class DavinciModel { /// Status BindInputQueue(); + Status CpuTaskModelZeroCopy(std::vector &mbuf_list, + std::map> &outside_addrs); + /// /// @ingroup ge /// @brief ACL, Bind NetOutput Op addr to output queue. /// @return: 0 for success / others for fail /// Status BindOutputQueue(); - - /// - /// @ingroup ge - /// @brief ACL, Make active stream for S0. 
- /// @return: 0 for success / others for fail - /// - Status BindActiveStream(); - - /// - /// @ingroup domi_ome - /// @brief insert active_stream_indication_ - /// @return Status - /// - Status MarkActiveStream(const OpDescPtr &op_desc); + Status CpuModelPrepareOutput(uintptr_t addr, uint32_t size); /// /// @ingroup ge @@ -709,7 +654,7 @@ class DavinciModel { /// @param [in] size: Data Op output tensor size. /// @return: 0 for success / others for fail /// - Status CpuModelDequeue(uint32_t queue_id, uintptr_t addr, uint32_t size); + Status CpuModelDequeue(uint32_t queue_id); /// /// @ingroup ge @@ -736,6 +681,8 @@ class DavinciModel { /// Status CpuWaitEndGraph(); + Status BindEnqueue(); + Status CpuModelEnqueue(uint32_t queue_id, uintptr_t out_mbuf); /// /// @ingroup ge /// @brief definiteness queue schedule, repeat run model. @@ -769,6 +716,9 @@ class DavinciModel { void SetDataDumperArgs(); + Status GenOutputTensorInfo(const OpDescPtr &op_desc, uint32_t data_index, OutputData *output_data, + std::vector &outputs); + bool is_model_has_inited_; uint32_t model_id_; uint32_t runtime_model_id_; @@ -785,10 +735,8 @@ class DavinciModel { vector variable_op_list_; - vector output_size_list_; // Init by NetOutput Input Tensor - vector output_addr_list_; // Init by NetOutput Input Tensor - vector input_size_list_; // Init by Data Output Tensor - vector input_addr_list_; // Init by Data Output Tensor + std::map> input_data_info_; // Virtual address from Data output. + std::map> output_data_info_; // Virtual address from NetOutput input. 
// output op: save cce op actual needed memory size vector output_memory_size_list_; @@ -801,10 +749,6 @@ class DavinciModel { std::mutex mux_run_flg_; - static domi::SysMode mode_; - - static std::mutex mutex_mode_; - int32_t priority_; vector stream_list_; @@ -812,17 +756,25 @@ class DavinciModel { std::mutex all_hccl_stream_list_mutex_; vector all_hccl_stream_list_; + // for reuse hccl_follow_stream + std::mutex capacity_of_stream_mutex_; + std::vector> capacity_of_stream_; + vector event_list_; vector label_list_; + set label_id_indication_; std::mutex outside_addrs_mutex_; - std::map> input_outside_addrs_; - std::map> output_outside_addrs_; + std::vector zero_copy_tasks_; // Task used Data or NetOutput addr. + std::set copy_only_addrs_; // Address need copy to original place. + // {node_addr, {addr_in_task_args}} + std::map> input_outside_addrs_; // Key is virtual address from Data. + std::map> output_outside_addrs_; // Key is virtual address from NetOutput. // {op_id, batch_label} - map zero_copy_op_id_batch_label_; + std::map zero_copy_op_id_batch_label_; // {batch_label, addrs} - map> zero_copy_batch_label_addrs_; + std::map> zero_copy_batch_label_addrs_; std::vector task_list_; // rt_moodel_handle @@ -832,6 +784,8 @@ class DavinciModel { bool is_inner_model_stream_; + bool is_async_mode_; // For NN execute, Async mode use rtMemcpyAsync on rt_model_stream_. + // ACL queue schedule, save queue ids for Init. std::vector cpu_task_list_; std::vector input_queue_ids_; // input queue ids created by caller. 
@@ -843,16 +797,14 @@ class DavinciModel { std::map data_op_input_tensor_desc_map_; std::map data_op_output_tensor_desc_map_; - bool support_mem_shared_flag_; - uint64_t session_id_; uint32_t device_id_; - bool is_train_mode_; - std::mutex flowctrl_op_index_internal_map_mutex_; std::map flowctrl_op_index_internal_map_; + + std::vector active_stream_list_; std::set active_stream_indication_; std::shared_ptr model_task_def_; @@ -873,18 +825,8 @@ class DavinciModel { int64_t maxDumpOpNum_; // for data dump DataDumper data_dumper_; - bool input_use_zero_copy_; - bool output_use_zero_copy_; uint64_t iterator_count_; bool is_l1_fusion_enable_; - - uint32_t end_graph_id_; - OpDescPtr end_graph_op_; }; - -#define TIME_LOG_HEAD_FMT " OP_ID OP_NAME OP_TYPE ELAPSED TIME(ms)" -#define OP_TIME_LOG_FMT "%d_%-5d %-5d | %-20s | %-15s | %10f | %10d" -#define MODEL_TIME_LOG_FMT "******** Model %d ends, elapsed time: %f ms ********" -const size_t INPUT_OUTPUT_NAME_MAX_LEN = 256; } // namespace ge #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_DAVINCI_MODEL_H_ diff --git a/src/ge/graph/load/new_model_manager/davinci_model_parser.cc b/src/ge/graph/load/new_model_manager/davinci_model_parser.cc index 0c5d0073..b744f907 100644 --- a/src/ge/graph/load/new_model_manager/davinci_model_parser.cc +++ b/src/ge/graph/load/new_model_manager/davinci_model_parser.cc @@ -35,14 +35,14 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelInfoParser(const Mo GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, GE_CHK_RT(rtDeviceReset(0)); return ret, "Parse model failed"); - domi::ModelFileHeader *file_header = (domi::ModelFileHeader *)model.model_data; + auto *file_header = reinterpret_cast(model.model_data); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(file_header == nullptr, GE_CHK_RT(rtDeviceReset(0)); return PARAM_INVALID, "file_header is null."); model_info.version = file_header->version; model_info.is_encrypt = false; - GE_IF_BOOL_EXEC(domi::ENCRYPTED == file_header->is_encrypt, model_info.is_encrypt = 
true); + GE_IF_BOOL_EXEC(ENCRYPTED == file_header->is_encrypt, model_info.is_encrypt = true); std::shared_ptr davinci_model = std::shared_ptr(new (std::nothrow) DavinciModel(model.priority, nullptr)); diff --git a/src/ge/graph/load/new_model_manager/model_manager.cc b/src/ge/graph/load/new_model_manager/model_manager.cc index 8cf866d0..b7bd3deb 100644 --- a/src/ge/graph/load/new_model_manager/model_manager.cc +++ b/src/ge/graph/load/new_model_manager/model_manager.cc @@ -154,6 +154,20 @@ void ModelManager::DestroyAicpuSession(uint64_t session_id) { } } +ge::Status ModelManager::DestroyAicpuSessionForInfer(uint32_t model_id) { + GELOGI("Destroy aicpu session for infer, model id is %u.", model_id); + std::lock_guard lock(map_mutex_); + auto it = model_map_.find(model_id); + if (it == model_map_.end()) { + GELOGE(PARAM_INVALID, "model id %u does not exists.", model_id); + return PARAM_INVALID; + } + uint64_t session_id = it->second->GetSessionId(); + GELOGI("Destroy aicpu session for infer, session id is %u.", session_id); + DestroyAicpuSession(session_id); + return SUCCESS; +} + ge::Status ModelManager::DestroyAicpuKernel(uint64_t session_id, uint32_t model_id) { GELOGD("destroy aicpu kernel in session_id %lu, model_id %u.", session_id, model_id); std::lock_guard lock(sess_ids_mutex_); @@ -205,10 +219,12 @@ Status ModelManager::SetDevice(int32_t deviceId) const { /// @brief load model online /// @return Status run result /// -Status ModelManager::LoadModelOnline(uint32_t &model_id, shared_ptr &model, +Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptr &ge_model, std::shared_ptr listener) { GE_CHK_BOOL_RET_STATUS(listener.get() != nullptr, PARAM_INVALID, "Param incorrect, listener is null"); - GenModelId(&model_id); + if (model_id == INVALID_MODEL_ID) { + GenModelId(&model_id); + } GE_CHK_STATUS_RET(SetDevice(static_cast(GetContext().DeviceId())), "Set device failed, model id:%u.", model_id); @@ -224,9 +240,6 @@ Status 
ModelManager::LoadModelOnline(uint32_t &model_id, shared_ptr & Status ret = SUCCESS; do { - GeModelPtr ge_model; - GE_IF_BOOL_EXEC( - ModelHelper::TransModelToGeModel(model, ge_model) != SUCCESS, GELOGW("trans model to ge_model failed."); break;); GE_TIMESTAMP_START(Assign); GE_IF_BOOL_EXEC(SUCCESS != (ret = davinci_model->Assign(ge_model)), GELOGW("assign model to modeldef failed."); break;); @@ -244,7 +257,7 @@ Status ModelManager::LoadModelOnline(uint32_t &model_id, shared_ptr & davinci_model->SetProfileTime(MODEL_LOAD_START, (timespec.tv_sec * 1000 * 1000 * 1000 + timespec.tv_nsec)); // 1000 ^ 3 converts second to nanosecond davinci_model->SetProfileTime(MODEL_LOAD_END); - if (davinci_model->SinkModelProfile(davinci_model) != SUCCESS) { + if (davinci_model->SinkModelProfile() != SUCCESS) { GELOGW("Sink model profile failed."); } } @@ -271,7 +284,6 @@ Status ModelManager::DeleteModel(uint32_t id) { } (void)model_map_.erase(it); - free_model_id_.push_back(id); return SUCCESS; } @@ -301,13 +313,6 @@ Status ModelManager::UnloadModeldef(uint32_t model_id) { Status ModelManager::DataInput(const InputData &input_data, OutputData &output_data) { GELOGI("calling the DataInput"); - - domi::SysMode mode = DavinciModel::GetSysMode(); - if ((mode == domi::RESET) || (mode == domi::STOP)) { - GELOGE(domi::MODEL_NOT_READY, "System mode is reset or stop"); - return domi::MODEL_NOT_READY; - } - shared_ptr data_wrap(new (std::nothrow) InputDataWrapper()); GE_CHECK_NOTNULL(data_wrap); @@ -339,16 +344,10 @@ Status ModelManager::DataInput(const InputData &input_data, OutputData &output_d /// /// @ingroup domi_ome -/// @brief load Input and output TensorInfor for Model +/// @brief load Input and output TensorInfo for Model /// @return Status run result /// -Status ModelManager::DataInputTensor(uint32_t model_id, const std::vector &inputs, - std::vector &outputs) { - domi::SysMode mode = DavinciModel::GetSysMode(); - if ((mode == domi::RESET) || (mode == domi::STOP)) { - 
GELOGE(domi::MODEL_NOT_READY, "System mode is reset or stop"); - return domi::MODEL_NOT_READY; - } +Status ModelManager::DataInputTensor(uint32_t model_id, const std::vector &inputs) { std::shared_ptr model = GetModel(model_id); GE_CHECK_NOTNULL(model); @@ -358,31 +357,16 @@ Status ModelManager::DataInputTensor(uint32_t model_id, const std::vectorGetDataList()) { - GE_CHECK_NOTNULL(op); - GE_CHECK_GE(inputs.size(), 1); - GE_CHECK_GE(inputs.size() - 1, index); - + for (size_t i = 0; i < inputs.size(); ++i) { DataBuffer data; - data.data = inputs[index].data.data; - data.length = inputs[index].data.length; + data.data = inputs[i].data; + data.length = static_cast(inputs[i].length); input_data.blobs.push_back(data); - index++; } - CHECK_FALSE_EXEC(input_data.blobs.size() >= inputs.size(), - GELOGW("cur_inputs size = %zu, inputs size = %zu.", input_data.blobs.size(), inputs.size());); - OutputData output_data; output_data.model_id = model_id; output_data.index = 0; - for (size_t i = 0; i < outputs.size(); i++) { - DataBuffer data; - data.data = outputs[i].data.data; - data.length = outputs[i].data.length; - output_data.blobs.push_back(data); - } shared_ptr data_wrap(new (std::nothrow) InputDataWrapper()); GE_CHECK_NOTNULL(data_wrap); @@ -472,7 +456,7 @@ Status ModelManager::HandleAclProfilingCommand(const Command &command) { std::string map_key = command.cmd_params[0]; std::string value = command.cmd_params[1]; - if (map_key == domi::PROFILE_CONFIG) { + if (map_key == PROFILE_CONFIG) { ProfilingManager::Instance().SetProfilingConfig(value); } @@ -490,18 +474,17 @@ Status ModelManager::HandleProfileCommand(const Command &command) { GELOGI("Profiling mode, Command key:%s , value:%s ", map_key.c_str(), value.c_str()); - auto iter = domi::PROFILE_COMPONENT_MAP.find(map_key); - if (iter != domi::PROFILE_COMPONENT_MAP.end()) { + auto iter = PROFILE_COMPONENT_MAP.find(map_key); + if (iter != PROFILE_COMPONENT_MAP.end()) { std::string property_value = (value == "on") ? 
"1" : "0"; PropertiesManager::Instance().SetPropertyValue(iter->second, property_value); } - if ((map_key == domi::PROFILER_JOBCTX || map_key == domi::PROFILER_TARGET_PATH || - map_key == domi::RTS_PROFILE_PATH)) { + if ((map_key == PROFILER_JOBCTX || map_key == PROFILER_TARGET_PATH || map_key == RTS_PROFILE_PATH)) { PropertiesManager::Instance().SetPropertyValue(map_key, value); } - if ((map_key == domi::PROFILE_STOP_KEY) && (value == domi::PROFILE_STOP_VALUE)) { + if ((map_key == PROFILE_STOP_KEY) && (value == PROFILE_STOP_VALUE)) { rtError_t rt_ret = rtProfilerStop(); if (rt_ret != RT_ERROR_NONE) { GELOGE(PARAM_INVALID, "Call rtProfilerStop ret:%d", rt_ret); @@ -512,6 +495,19 @@ Status ModelManager::HandleProfileCommand(const Command &command) { return SUCCESS; } +static Status ParserPara(const Command &command, const string &dump_key, string &dump_value) { + auto iter = std::find(command.cmd_params.begin(), command.cmd_params.end(), dump_key); + if (iter != command.cmd_params.end()) { + ++iter; + if (iter == command.cmd_params.end()) { + GELOGE(PARAM_INVALID, "Invalid access."); + return PARAM_INVALID; + } + dump_value = *iter; + } + return SUCCESS; +} + Status ModelManager::HandleDumpCommand(const Command &command) { if (command.cmd_params.size() % kDumpCmdPairSize != 0) { GELOGE(PARAM_INVALID, "When the cmd_type is 'dump', the size of cmd_params must be a even number."); @@ -521,32 +517,22 @@ Status ModelManager::HandleDumpCommand(const Command &command) { std::string dump_status("off"); std::string dump_model(DUMP_ALL_MODEL); std::string dump_path("/"); + std::string dump_mode("output"); std::set dump_layers; - std::string dump_layer_count; - auto iter_dump_status = std::find(command.cmd_params.begin(), command.cmd_params.end(), DUMP_STATUS); - if (iter_dump_status != command.cmd_params.end()) { - ++iter_dump_status; - if (iter_dump_status == command.cmd_params.end()) { - GELOGE(PARAM_INVALID, "Invalid access."); - return PARAM_INVALID; - } - - dump_status = 
*iter_dump_status; - GELOGI("dump status = %s.", dump_status.c_str()); + auto ret = ParserPara(command, DUMP_STATUS, dump_status); + if (ret != SUCCESS) { + GELOGE(PARAM_INVALID, "parser dump status failed"); + return FAILED; } + GELOGI("dump status = %s.", dump_status.c_str()); - auto iter_dump_model = std::find(command.cmd_params.begin(), command.cmd_params.end(), DUMP_MODEL); - if (iter_dump_model != command.cmd_params.end()) { - ++iter_dump_model; - if (iter_dump_model == command.cmd_params.end()) { - GELOGE(PARAM_INVALID, "Invalid access."); - return PARAM_INVALID; - } - - dump_model = *iter_dump_model; - GELOGI("dump model = %s.", dump_model.c_str()); + ret = ParserPara(command, DUMP_MODEL, dump_model); + if (ret != SUCCESS) { + GELOGE(PARAM_INVALID, "parser dump model failed"); + return FAILED; } + GELOGI("dump status = %s.", dump_model.c_str()); if (dump_status == "off" || dump_status == "OFF") { PropertiesManager::Instance().DeleteDumpPropertyValue(dump_model); @@ -560,24 +546,37 @@ Status ModelManager::HandleDumpCommand(const Command &command) { } } - auto iter_dump_path = std::find(command.cmd_params.begin(), command.cmd_params.end(), DUMP_FILE_PATH); - if (iter_dump_path != command.cmd_params.end()) { - ++iter_dump_path; - if (iter_dump_path == command.cmd_params.end()) { - GELOGE(PARAM_INVALID, "Invalid access."); - return PARAM_INVALID; - } + ret = ParserPara(command, DUMP_FILE_PATH, dump_path); + if (ret != SUCCESS) { + GELOGE(PARAM_INVALID, "parser dump path failed"); + return FAILED; + } + if (!dump_path.empty() && dump_path[dump_path.size() - 1] != '/') { + dump_path += "/"; + } + GELOGI("dump status = %s.", dump_path.c_str()); - dump_path = *iter_dump_path; + ret = ParserPara(command, DUMP_MODE, dump_mode); + if (ret != SUCCESS) { + GELOGE(PARAM_INVALID, "parser dump mode failed"); + return FAILED; + } + GELOGI("dump mode = %s", dump_mode.c_str()); - if (!dump_path.empty() && dump_path[dump_path.size() - 1] != '/') { - dump_path += "/"; + auto 
iter_dump_mode = std::find(command.cmd_params.begin(), command.cmd_params.end(), DUMP_MODE); + if (iter_dump_mode != command.cmd_params.end()) { + ++iter_dump_mode; + if (iter_dump_mode == command.cmd_params.end()) { + GELOGE(PARAM_INVALID, "Invalid access."); + return PARAM_INVALID; } - GELOGI("dump path = %s.", dump_path.c_str()); + dump_mode = *iter_dump_mode; + GELOGI("dump mode = %s", dump_mode.c_str()); } PropertiesManager::Instance().AddDumpPropertyValue(dump_model, dump_layers); PropertiesManager::Instance().SetDumpOutputPath(dump_path); + PropertiesManager::Instance().SetDumpMode(dump_mode); return SUCCESS; } @@ -599,15 +598,6 @@ Status ModelManager::GetInputOutputDescInfo(const uint32_t model_id, vectorGetInputOutputDescInfo(input_desc, output_desc); } -Status ModelManager::GetInputOutputDescInfoForZeroCopy(const uint32_t model_id, vector &input_desc, - vector &output_desc) { - std::shared_ptr davinci_model = GetModel(model_id); - GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, PARAM_INVALID, - "GetInputOutputDescInfo Failed, Invalid Model ID %u !", model_id); - - return davinci_model->GetInputOutputDescInfoForZeroCopy(input_desc, output_desc); -} - Status ModelManager::GetInputOutputDescInfo(const uint32_t model_id, vector &input_desc, vector &output_desc, std::vector &inputFormats, std::vector &outputFormats) { @@ -677,6 +667,15 @@ Status ModelManager::LoadModelOffline(uint32_t &model_id, const ModelData &model break; } davinci_model->SetId(model_id); + + int32_t device_id = 0; + rtError_t rt_ret = rtGetDevice(&device_id); + if (rt_ret != RT_ERROR_NONE || device_id < 0) { + GELOGE(RT_FAILED, "Call rtGetDevice failed, ret = 0x%X, device_id = %d.", rt_ret, device_id); + return FAILED; + } + davinci_model->SetDeviceId(device_id); + ret = davinci_model->Init(dev_ptr, mem_size, weight_ptr, weight_size); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, break, "DavinciInit failed."); @@ -689,7 +688,7 @@ Status ModelManager::LoadModelOffline(uint32_t &model_id, 
const ModelData &model davinci_model->SetProfileTime(MODEL_LOAD_START, (timespec.tv_sec * 1000 * 1000 * 1000 + timespec.tv_nsec)); // 1000 ^ 3 converts second to nanosecond davinci_model->SetProfileTime(MODEL_LOAD_END); - if (davinci_model->SinkModelProfile(davinci_model) != SUCCESS) { + if (davinci_model->SinkModelProfile() != SUCCESS) { GELOGW("Sink model profile failed."); } } @@ -716,7 +715,7 @@ Status ModelManager::LoadModelWithQ(uint32_t &model_id, const ModelData &model_d GE_CHK_BOOL_RET_STATUS(model_data.key.empty() || access(model_data.key.c_str(), F_OK) == 0, PARAM_INVALID, "input key file path is not valid, %s", strerror(errno)); - domi::ModelHelper model_helper; + ModelHelper model_helper; Status ret = model_helper.LoadModel(model_data); if (ret != SUCCESS) { GELOGE(ret, "load model failed."); @@ -807,17 +806,17 @@ Status ModelManager::GetModelMemAndWeightSize(const ModelData &model, size_t &me Status ret = DavinciModelParser::ParseModelContent(model, model_data, model_len); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, return ret, "parse model content failed!"); - domi::OmFileLoadHelper om_file_helper; + OmFileLoadHelper om_file_helper; ret = om_file_helper.Init(model_data, model_len); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, return ret, "om file helperInit failed!"); - auto partition_table = reinterpret_cast(model_data); + auto partition_table = reinterpret_cast(model_data); if (partition_table->num == 1) { GELOGE(FAILED, "om model is error,please use executable om model"); return FAILED; } - domi::ModelPartition task_partition; - if (om_file_helper.GetModelPartition(domi::ModelPartitionType::TASK_INFO, task_partition) != SUCCESS) { + ModelPartition task_partition; + if (om_file_helper.GetModelPartition(ModelPartitionType::TASK_INFO, task_partition) != SUCCESS) { GELOGE(FAILED, "get task model partition failed."); return FAILED; } @@ -827,14 +826,14 @@ Status ModelManager::GetModelMemAndWeightSize(const ModelData &model, size_t &me return 
FAILED; } if (task_partition.size != 0) { - if (!domi::ReadProtoFromArray(task_partition.data, static_cast(task_partition.size), model_task_def.get())) { + if (!ReadProtoFromArray(task_partition.data, static_cast(task_partition.size), model_task_def.get())) { GELOGE(FAILED, "ReadProtoFromArray failed."); return FAILED; } } - domi::ModelPartition partition_weight; - ret = om_file_helper.GetModelPartition(domi::ModelPartitionType::WEIGHTS_DATA, partition_weight); + ModelPartition partition_weight; + ret = om_file_helper.GetModelPartition(ModelPartitionType::WEIGHTS_DATA, partition_weight); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, return ret, "Get weight partition failed. ret = %u", ret); mem_size = model_task_def->memory_size(); @@ -848,12 +847,6 @@ void ModelManager::GenModelId(uint32_t *id) { } std::lock_guard lock(map_mutex_); - if (free_model_id_.empty()) { - *id = ++max_model_id_; - } else { - *id = free_model_id_.back(); - free_model_id_.pop_back(); - } + *id = ++max_model_id_; } - } // namespace ge diff --git a/src/ge/graph/load/new_model_manager/model_manager.h b/src/ge/graph/load/new_model_manager/model_manager.h index 7ac4d822..ae73c1ce 100644 --- a/src/ge/graph/load/new_model_manager/model_manager.h +++ b/src/ge/graph/load/new_model_manager/model_manager.h @@ -23,17 +23,18 @@ #include #include #include +#include #include #include "cce/aicpu_engine_struct.h" -#include "common/types.h" -#include "common/ge_types.h" #include "common/ge_inner_error_codes.h" +#include "common/ge_types.h" #include "common/helper/model_helper.h" #include "common/helper/om_file_helper.h" +#include "common/types.h" +#include "ge/ge_api_types.h" +#include "graph/ge_context.h" #include "graph/model.h" #include "runtime/base.h" -#include "graph/ge_context.h" -#include "ge/ge_api_types.h" namespace ge { @@ -68,7 +69,7 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { /// @return Status run result /// @author @ /// - ge::Status LoadModelOnline(uint32_t 
&model_id, std::shared_ptr &model, + ge::Status LoadModelOnline(uint32_t &model_id, const std::shared_ptr &model, std::shared_ptr listener); /// @@ -116,8 +117,7 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { /// ge::Status DataInput(const InputData &input_data, OutputData &output_data); - ge::Status DataInputTensor(uint32_t model_id, const std::vector &inputs, - std::vector &outputs); + ge::Status DataInputTensor(uint32_t model_id, const std::vector &inputs); /// /// @ingroup domi_ome @@ -193,9 +193,6 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { /// @return SUCCESS success /// @return PARAM_INVALID parameter invalid /// - ge::Status GetInputOutputDescInfoForZeroCopy(const uint32_t model_id, std::vector &input_desc, - std::vector &output_desc); - ge::Status GetInputOutputDescInfoForZeroCopy(const uint32_t model_id, std::vector &input_desc, std::vector &output_desc, std::vector &inputFormats, @@ -221,6 +218,8 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { ge::Status CreateAicpuKernel(uint64_t session_id, uint32_t model_id, uint64_t kernel_id); + ge::Status DestroyAicpuSessionForInfer(uint32_t model_id); + private: /// /// @ingroup domi_ome @@ -250,7 +249,6 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { std::map> model_map_; std::map> model_aicpu_kernel_; - std::vector free_model_id_; uint32_t max_model_id_; std::mutex map_mutex_; std::mutex sess_ids_mutex_; diff --git a/src/ge/graph/load/new_model_manager/model_output.cc b/src/ge/graph/load/new_model_manager/model_output.cc deleted file mode 100644 index affda08a..00000000 --- a/src/ge/graph/load/new_model_manager/model_output.cc +++ /dev/null @@ -1,41 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "graph/load/new_model_manager/model_output.h" -#include -#include -#include "common/debug/log.h" -#include "common/op/ge_op_utils.h" -#include "graph/load/new_model_manager/davinci_model.h" - -#include "graph/load/output/output.h" - -namespace ge { -Status ModelOutput::CopyResult(DavinciModel *model, OpDescPtr op_desc, OutputData &rslt, uint32_t &data_index, - bool support_mem_share) { - uint32_t data_begin = data_index; - std::shared_ptr model_output = MakeShared(op_desc, model); - if (model_output == nullptr) { - return INTERNAL_ERROR; - } - - if (model_output->Init() != SUCCESS) { - return INTERNAL_ERROR; - } - - return model_output->CopyResult(rslt, data_begin, data_index, support_mem_share); -} -} // namespace ge diff --git a/src/ge/graph/load/new_model_manager/model_utils.cc b/src/ge/graph/load/new_model_manager/model_utils.cc index df11c874..c372f528 100644 --- a/src/ge/graph/load/new_model_manager/model_utils.cc +++ b/src/ge/graph/load/new_model_manager/model_utils.cc @@ -30,48 +30,6 @@ #include "graph/manager/graph_var_manager.h" namespace ge { -/// -/// @ingroup domi_ome -/// @brief Check is Output Op. 
-/// @return bool -/// -bool ModelUtils::IsOutput(ConstOpDescPtr op_desc) { - GE_CHECK_NOTNULL_EXEC(op_desc, return false); - size_t output_size = op_desc->GetOutputsSize(); - for (size_t i = 0; i < output_size; ++i) { - bool output_tensor = false; - GE_IF_BOOL_EXEC(TensorUtils::GetOutputTensor(op_desc->GetOutputDesc(i), output_tensor) != GRAPH_SUCCESS, - GELOGW("Get OutputTensor failed, name: %s, output index: %zu", op_desc->GetName().c_str(), i); - return false;); - if (output_tensor) { - return true; - } - } - - return false; -} - -/// -/// @ingroup domi_ome -/// @brief Check is the Input need trans code. -/// @return bool -/// -bool ModelUtils::IsInputTensorNeedTrans(ConstOpDescPtr op_desc, size_t tensor_index) { - GE_CHECK_NOTNULL_EXEC(op_desc, return false); - const auto &input_desc = op_desc->MutableInputDesc(static_cast(tensor_index)); - const auto &output_desc = op_desc->MutableOutputDesc(static_cast(tensor_index)); - GE_CHECK_NOTNULL_EXEC(input_desc, return false); - GE_CHECK_NOTNULL_EXEC(output_desc, return false); - - if ((output_desc->GetFormat() == FORMAT_NC1HWC0) && (output_desc->GetDataType() == DT_INT8)) { - // AIPP input, add attribute in data op to tag aipp - return false; - } - - return (input_desc->GetFormat() != output_desc->GetFormat()) || - (input_desc->GetDataType() != output_desc->GetDataType()); -} - /// /// @ingroup domi_ome /// @brief Get input size. 
@@ -85,12 +43,14 @@ vector ModelUtils::GetInputSize(ConstOpDescPtr op_desc) { const vector v_is_input_const = op_desc->GetIsInputConst(); for (size_t i = 0; i < inputs_size; ++i) { - if ((i < v_is_input_const.size()) && v_is_input_const[i] && (op_type != domi::NETOUTPUT)) { + if ((i < v_is_input_const.size()) && v_is_input_const[i] && (op_type != NETOUTPUT)) { // TBE: add weights size to input - GE_IF_BOOL_EXEC( - true, GeTensorDesc tensor_desc = op_desc->GetInputDesc(i); int64_t tensor_size = 0; - GE_CHK_STATUS(TensorUtils::GetSize(tensor_desc, tensor_size)); - if (tensor_size) { v_input_size.push_back(tensor_size); }); + GeTensorDesc tensor_desc = op_desc->GetInputDesc(i); + int64_t tensor_size = 0; + GE_CHK_STATUS(TensorUtils::GetSize(tensor_desc, tensor_size)); + if (tensor_size) { + v_input_size.push_back(tensor_size); + } continue; } @@ -227,38 +187,6 @@ vector ModelUtils::GetWeights(ConstOpDescPtr op_desc) { return v_weights; } -/// -/// @ingroup domi_ome -/// @brief Save Output tensor info to vector. 
-/// @return Status -/// -Status ModelUtils::GetOutputSize(ConstOpDescPtr op_desc, vector &output_size_list, - vector &output_memory_size_list) { - GE_CHECK_NOTNULL(op_desc); - - for (size_t i = 0; i < op_desc->GetOutputsSize(); ++i) { - bool output_tensor = false; - auto output_desc = op_desc->GetOutputDesc(i); - GE_CHK_STATUS_RET(TensorUtils::GetOutputTensor(output_desc, output_tensor), - "get OutputTensor failed, op : %s, input index : %zu", op_desc->GetName().c_str(), i); - - if (output_tensor) { - // get transferred parameters such as size - int64_t size = 0; - int64_t memory_size = 0; - graphStatus graph_status0 = TensorUtils::GetTensorSizeInBytes(output_desc, size); - graphStatus graph_status1 = TensorUtils::GetTensorMemorySizeInBytes(output_desc, memory_size); - if ((graph_status0 != GRAPH_SUCCESS) || (graph_status1 != GRAPH_SUCCESS)) { - return INTERNAL_ERROR; - } - output_size_list.push_back(size); - output_memory_size_list.push_back(memory_size); - } - } - - return SUCCESS; -} - /// /// @ingroup domi_ome /// @brief Get AiCpuOp Input descriptor. 
@@ -384,23 +312,24 @@ vector ModelUtils::GetInputDataAddrs(const RuntimeParam &model_param, Co vector v_memory_type; bool has_mem_type_attr = ge::AttrUtils::GetListInt(op_desc, ATTR_NAME_INPUT_MEM_TYPE_LIST, v_memory_type); if (has_mem_type_attr && (v_memory_type.size() != inputs_size)) { - GELOGE(PARAM_INVALID, "L1Fusion: check input size failed, op: %s, input v_memory_type size: %zu input numbers: %zu", + GELOGE(PARAM_INVALID, "Fusion: check input size failed, op: %s, input v_memory_type size: %zu input numbers: %zu", op_desc->GetName().c_str(), v_memory_type.size(), inputs_size); return v_input_data_addr; } for (size_t i = 0; i < inputs_size; ++i) { - if ((i < v_is_input_const.size()) && v_is_input_const[i] && (op_type != domi::NETOUTPUT)) { + if ((i < v_is_input_const.size()) && v_is_input_const[i] && (op_type != NETOUTPUT)) { // TBE: add weights address to input - GE_IF_BOOL_EXEC( - true, GeTensorDesc tensor_desc = op_desc->GetInputDesc(i); int64_t tensor_size = 0; - GE_CHK_STATUS(TensorUtils::GetSize(tensor_desc, tensor_size)); if (tensor_size) { - int64_t data_offset = 0; - GE_CHK_STATUS(TensorUtils::GetDataOffset(tensor_desc, data_offset)); - uint8_t *weight_addr = static_cast(weight_base + data_offset - logic_weight_base); - v_input_data_addr.push_back(weight_addr); - GELOGI("[IMAS]GetInputDataAddrs graph_%u type[C] name[%s] input[%zu] memaddr[%p]", model_param.graph_id, - op_desc->GetName().c_str(), i, weight_addr); - }); + GeTensorDesc tensor_desc = op_desc->GetInputDesc(i); + int64_t tensor_size = 0; + GE_CHK_STATUS(TensorUtils::GetSize(tensor_desc, tensor_size)); + if (tensor_size) { + int64_t data_offset = 0; + GE_CHK_STATUS(TensorUtils::GetDataOffset(tensor_desc, data_offset)); + uint8_t *weight_addr = static_cast(weight_base + data_offset - logic_weight_base); + v_input_data_addr.push_back(weight_addr); + GELOGI("[IMAS]GetInputDataAddrs graph_%u type[C] name[%s] input[%zu] memaddr[%p]", model_param.graph_id, + op_desc->GetName().c_str(), i, 
weight_addr); + } non_const_index++; continue; } @@ -424,9 +353,9 @@ vector ModelUtils::GetInputDataAddrs(const RuntimeParam &model_param, Co continue;); // feature maps uint8_t *mem_addr = nullptr; - // l1 fusion - if (has_mem_type_attr && v_memory_type[i] != RT_MEMORY_HBM) { - mem_addr = reinterpret_cast(input_offset); + // fusion + if (has_mem_type_attr && v_memory_type[i] == RT_MEMORY_L1) { + mem_addr = reinterpret_cast(reinterpret_cast(input_offset)); v_input_data_addr.push_back(mem_addr); } else { mem_addr = static_cast(mem_base + input_offset - logic_mem_base); @@ -479,7 +408,7 @@ vector ModelUtils::GetOutputDataAddrs(const RuntimeParam &model_param, C bool has_mem_type_attr = ge::AttrUtils::GetListInt(op_desc, ATTR_NAME_OUTPUT_MEM_TYPE_LIST, v_memory_type); if (has_mem_type_attr && (v_memory_type.size() != outputs_size)) { GELOGE(PARAM_INVALID, - "L1Fusion: check output size failed, op: %s, output v_memory_type size: %lu output numbers: %zu", + "Fusion: check output size failed, op: %s, output v_memory_type size: %lu output numbers: %zu", op_desc->GetName().c_str(), v_memory_type.size(), outputs_size); return v_output_data_addr; } @@ -492,9 +421,9 @@ vector ModelUtils::GetOutputDataAddrs(const RuntimeParam &model_param, C continue;); // feature maps uint8_t *mem_addr = nullptr; - // l1 fusion - if (has_mem_type_attr && v_memory_type[i] != RT_MEMORY_HBM) { - mem_addr = reinterpret_cast(v_output_offset[i]); + // fusion + if (has_mem_type_attr && v_memory_type[i] == RT_MEMORY_L1) { + mem_addr = reinterpret_cast(reinterpret_cast(v_output_offset[i])); v_output_data_addr.push_back(mem_addr); } else { mem_addr = static_cast(mem_base + v_output_offset[i] - logic_mem_base); @@ -536,10 +465,10 @@ vector ModelUtils::GetWorkspaceDataAddrs(const RuntimeParam &model_param vector v_memory_type; bool has_mem_type_attr = ge::AttrUtils::GetListInt(op_desc, TVM_ATTR_NAME_WORKSPACE_TYPE, v_memory_type); for (size_t i = 0; i < v_workspace_bytes.size(); ++i) { - if 
(has_mem_type_attr && v_memory_type[i] != RT_MEMORY_HBM) { + if (has_mem_type_attr && v_memory_type[i] == RT_MEMORY_L1) { v_workspace_data_addr.push_back(reinterpret_cast(v_workspace_offset[i])); - GELOGI("L1Fusion: op: %s, GetWorkspaceDataAddrs mem_addr[workspace index %zu]:%p", op_desc->GetName().c_str(), i, - reinterpret_cast(v_workspace_offset[i])); + GELOGI("Fusion: op: %s, GetWorkspaceDataAddrs mem_addr[workspace index %zu]:%p", op_desc->GetName().c_str(), i, + reinterpret_cast(reinterpret_cast(v_workspace_offset[i]))); } else { int64_t workspace_offset = v_workspace_offset[i]; int64_t workspace_bytes = v_workspace_bytes[i]; diff --git a/src/ge/graph/load/new_model_manager/model_utils.h b/src/ge/graph/load/new_model_manager/model_utils.h index 1a15c930..d6afd5c8 100644 --- a/src/ge/graph/load/new_model_manager/model_utils.h +++ b/src/ge/graph/load/new_model_manager/model_utils.h @@ -33,20 +33,6 @@ class ModelUtils { ModelUtils() = default; ~ModelUtils() = default; - /// - /// @ingroup domi_ome - /// @brief Check is Output Op. - /// @return bool - /// - static bool IsOutput(ConstOpDescPtr op_desc); - - /// - /// @ingroup domi_ome - /// @brief Check is the Input need trans code. - /// @return bool - /// - static bool IsInputTensorNeedTrans(ConstOpDescPtr op_desc, size_t tensor_index); - /// /// @ingroup domi_ome /// @brief Get input size. @@ -82,14 +68,6 @@ class ModelUtils { /// static vector GetWeights(ConstOpDescPtr op_desc); - /// - /// @ingroup domi_ome - /// @brief Save Output tensor info to vector. - /// @return Status - /// - static Status GetOutputSize(ConstOpDescPtr op_desc, vector &output_size_list, - vector &output_memory_size_list); - /// /// @ingroup domi_ome /// @brief Get AiCpuOp Input descriptor. 
diff --git a/src/ge/graph/load/new_model_manager/task_info/end_graph_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/end_graph_task_info.cc index 75acf548..a7b169bf 100644 --- a/src/ge/graph/load/new_model_manager/task_info/end_graph_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/end_graph_task_info.cc @@ -44,6 +44,7 @@ Status EndGraphTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin Status EndGraphTaskInfo::Distribute() { GELOGI("EndGraphTaskInfo Distribute Start."); + GE_CHECK_NOTNULL(davinci_model_); auto all_dump_model = PropertiesManager::Instance().GetAllDumpModel(); if (all_dump_model.find(ge::DUMP_ALL_MODEL) != all_dump_model.end() || all_dump_model.find(davinci_model_->Name()) != all_dump_model.end()) { @@ -63,15 +64,17 @@ Status EndGraphTaskInfo::Distribute() { } uint32_t task_id = 0; - GE_CHECK_NOTNULL(davinci_model_); - rtError_t rt_ret = rtModelGetTaskId(davinci_model_->GetRtModelHandle(), &task_id); + uint32_t stream_id = 0; + rtError_t rt_ret = rtModelGetTaskId(davinci_model_->GetRtModelHandle(), &task_id, &stream_id); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); return RT_FAILED; } task_id_ = task_id; + stream_id_ = stream_id; + davinci_model_->SetEndGraphId(task_id, stream_id); - GELOGI("EndGraphTaskInfo Distribute Success, task id is %u", task_id); + GELOGI("EndGraphTaskInfo Distribute Success, task id is %u, stream id is %u", task_id, stream_id); return SUCCESS; } diff --git a/src/ge/graph/load/new_model_manager/task_info/end_graph_task_info.h b/src/ge/graph/load/new_model_manager/task_info/end_graph_task_info.h index 1c039172..49bef082 100644 --- a/src/ge/graph/load/new_model_manager/task_info/end_graph_task_info.h +++ b/src/ge/graph/load/new_model_manager/task_info/end_graph_task_info.h @@ -16,6 +16,7 @@ #ifndef GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_END_GRAPH_TASK_INFO_H_ #define GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_END_GRAPH_TASK_INFO_H_ + 
#include "graph/load/new_model_manager/task_info/task_info.h" namespace ge { @@ -31,11 +32,13 @@ class EndGraphTaskInfo : public TaskInfo { uint32_t GetTaskID() override { return task_id_; } + uint32_t GetStreamId() override { return stream_id_; } + private: rtModel_t model_; DavinciModel *davinci_model_; uint32_t task_id_; + uint32_t stream_id_; }; - } // namespace ge #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_END_GRAPH_TASK_INFO_H_ diff --git a/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.cc index f65d05dd..77825991 100644 --- a/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.cc @@ -24,6 +24,13 @@ #include "graph/load/new_model_manager/model_utils.h" namespace ge { +namespace { +const uint32_t kMaxTaskOfStream = 200; +} + +uint32_t HcclTaskInfo::max_node_of_hccl_stream_ = 0; +std::mutex HcclTaskInfo::hccl_follow_stream_mutex_; + HcclTaskInfo::~HcclTaskInfo() { if (private_def_ != nullptr) { rtError_t ret = rtFreeHost(private_def_); @@ -38,6 +45,7 @@ HcclTaskInfo::~HcclTaskInfo() { ops_kernel_store_ = nullptr; output_data_addr_ = nullptr; workspace_addr_ = nullptr; + max_node_of_hccl_stream_ = 0; } Status HcclTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) { @@ -63,17 +71,17 @@ Status HcclTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_m std::string hccl_type = hccl_def.hccl_type(); // Get HCCL op - auto op_desc = davinci_model->GetOpList()[op_index]; + OpDescPtr op_desc = davinci_model->GetOpByIndex(op_index); GE_CHECK_NOTNULL(op_desc); Status dmrt = HcomOmeUtil::GetHcomDataType(op_desc, data_type); - if (dmrt != domi::SUCCESS) { + if (dmrt != SUCCESS) { GELOGE(FAILED, "davinci_model: GetHcomDataType fail! 
domi error: %u", dmrt); return FAILED; } - dmrt = HcomOmeUtil::GetHcomCount(op_desc, data_type, (hccl_type == domi::HCOMALLGATHER), count); - if (dmrt != domi::SUCCESS) { + dmrt = HcomOmeUtil::GetHcomCount(op_desc, data_type, (hccl_type == HCOMALLGATHER), count); + if (dmrt != SUCCESS) { GELOGE(FAILED, "davinci_model: GetHcomCount fail! domi error: %u", dmrt); return FAILED; } @@ -109,7 +117,49 @@ Status HcclTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_m GELOGI("op_desc has no attr used_stream_num!"); } - for (int64_t i = 0; i < hccl_stream_num; ++i) { + std::lock_guard lock(hccl_follow_stream_mutex_); + if (max_node_of_hccl_stream_ == 0) { + uint32_t max_stream_count; + uint32_t max_task_count; + ret = rtGetMaxStreamAndTask(RT_NORMAL_STREAM, &max_stream_count, &max_task_count); + if (ret != RT_ERROR_NONE) { + GELOGE(FAILED, "Get max stream and task count by rts failed."); + return FAILED; + } + max_node_of_hccl_stream_ = max_task_count / kMaxTaskOfStream; + } + + if (static_cast(hccl_stream_num) <= davinci_model->GetHcclFolowStream().size()) { + GELOGI("capacity of follow stream is enough to be reused."); + ReuseStream(hccl_stream_num, davinci_model); + } else { + GELOGI("need to reuse follow stream and create new follow stream."); + size_t created_stream_num = davinci_model->GetHcclFolowStream().size(); + ReuseStream(created_stream_num, davinci_model); + ret = CreateStream(hccl_stream_num - created_stream_num, davinci_model); + if (ret != SUCCESS) { + GELOGE(FAILED, "Create hccl stream failed."); + return FAILED; + } + } + + GELOGI("HcclTaskInfo Init Success, hcclStreamNum =%ld", hccl_stream_num); + return SUCCESS; +} + +void HcclTaskInfo::ReuseStream(int64_t stream_num, DavinciModel *davinci_model) { + GELOGI("Start to reuse %ld follow stream.", stream_num); + int64_t index = 0; + for (int64_t i = 0; i < stream_num; i++) { + hccl_stream_list_.emplace_back(davinci_model->GetHcclFolowStream().at(index).first); + int64_t remain_cap = 
davinci_model->GetHcclFolowStream().at(index).second - 1; + davinci_model->ReuseHcclFollowStream(remain_cap, index); + } +} + +Status HcclTaskInfo::CreateStream(int64_t stream_num, DavinciModel *davinci_model) { + GELOGI("Start to create %ld hccl stream.", stream_num); + for (int64_t i = 0; i < stream_num; ++i) { rtStream_t stream = nullptr; rtError_t rt_ret = rtStreamCreateWithFlags(&stream, davinci_model->Priority(), RT_STREAM_PERSISTENT | RT_STREAM_FORCE_COPY); @@ -126,11 +176,13 @@ Status HcclTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_m } GELOGD("hccl_stream addr is=%p", stream); - hccl_stream_list_.push_back(stream); + int64_t remain_cap = max_node_of_hccl_stream_ - 1; + davinci_model->CreateHcclFollowStream(stream, remain_cap); + + hccl_stream_list_.emplace_back(stream); davinci_model->PushHcclStream(stream); } - - GELOGI("HcclTaskInfo Init Success, hcclStreamNum =%ld", hccl_stream_num); + GELOGI("CreateStream success."); return SUCCESS; } @@ -170,28 +222,28 @@ Status HcclTaskInfo::SetAddrs(const std::string &hccl_type, const std::shared_pt output_data_addr = output_data_addr_list[0]; } - if (hccl_type == domi::HCOMBROADCAST) { + if (hccl_type == HCOMBROADCAST) { int64_t root_id; dmrt = HcomOmeUtil::GetHcomRootId(op_desc, root_id); - if (dmrt != domi::SUCCESS) { + if (dmrt != SUCCESS) { GELOGE(FAILED, "davinci_model: GetHcomRootId fail! domi error: %u", dmrt); return FAILED; } root_id_ = root_id; - } else if (hccl_type == domi::HCOMALLGATHER || hccl_type == domi::HCOMRECEIVE) { + } else if (hccl_type == HCOMALLGATHER || hccl_type == HCOMRECEIVE) { output_data_addr_ = output_data_addr; - } else if (hccl_type == domi::HCOMALLREDUCE) { + } else if (hccl_type == HCOMALLREDUCE) { dmrt = HcomOmeUtil::GetHcomOperationType(op_desc, op_type); - if (dmrt != domi::SUCCESS) { + if (dmrt != SUCCESS) { GELOGE(FAILED, "davinci_model: GetHcomOperationType fail! 
domi error: %u", dmrt); return FAILED; } output_data_addr_ = output_data_addr; op_type_ = op_type; - } else if (hccl_type == domi::HCOMREDUCESCATTER) { + } else if (hccl_type == HCOMREDUCESCATTER) { dmrt = HcomOmeUtil::GetHcomOperationType(op_desc, op_type); - if (dmrt != domi::SUCCESS) { + if (dmrt != SUCCESS) { GELOGE(FAILED, "davinci_model: GetHcomOperationType fail! domi error: %u", dmrt); return FAILED; } @@ -200,6 +252,7 @@ Status HcclTaskInfo::SetAddrs(const std::string &hccl_type, const std::shared_pt op_type_ = op_type; } + davinci_model_->DisableZeroCopy(input_data_addr_); return SUCCESS; } diff --git a/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.h b/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.h index 1a2c508f..be033fac 100644 --- a/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.h +++ b/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.h @@ -20,6 +20,7 @@ #include #include #include +#include #include "common/opskernel/ge_task_info.h" #include "graph/load/new_model_manager/task_info/task_info.h" @@ -59,6 +60,10 @@ class HcclTaskInfo : public TaskInfo { void GetPrivateDefByTaskDef(const domi::TaskDef &task); + void ReuseStream(int64_t stream_num, DavinciModel *davinci_model); + + ge::Status CreateStream(int64_t stream_num, DavinciModel *davinci_model); + DavinciModel *davinci_model_; string hccl_type_; void *input_data_addr_; @@ -74,6 +79,8 @@ class HcclTaskInfo : public TaskInfo { void *ops_kernel_store_; void *private_def_; uint32_t private_def_len_; + static std::mutex hccl_follow_stream_mutex_; + static uint32_t max_node_of_hccl_stream_; }; } // namespace ge #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_HCCL_TASK_INFO_H_ diff --git a/src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.cc index 11b32be1..085b3ab4 100644 --- a/src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.cc +++ 
b/src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.cc @@ -19,16 +19,15 @@ #include #include "cce/aicpu_engine_struct.h" -#include "framework/common/debug/ge_log.h" -#include "framework/common/fmk_error_codes.h" #include "common/ge/ge_util.h" #include "common/properties_manager.h" +#include "framework/common/debug/ge_log.h" +#include "framework/common/fmk_error_codes.h" #include "graph/attr_value.h" #include "graph/load/new_model_manager/davinci_model.h" #include "graph/load/new_model_manager/model_manager.h" namespace ge { - Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) { GELOGI("KernelExTaskInfo Init Start."); if (davinci_model == nullptr) { @@ -36,35 +35,37 @@ Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin return PARAM_INVALID; } - Status ret = SetStream(task_def.stream_id(), davinci_model->GetStreamList()); + davinci_model_ = davinci_model; + Status ret = SetStream(task_def.stream_id(), davinci_model_->GetStreamList()); if (ret != SUCCESS) { return ret; } auto kernel_ex_def = task_def.kernel_ex(); + const RuntimeParam &rts_param = davinci_model_->GetRuntimeParam(); // 1. 
Copy context from kernelExDef.private to workspace uint32_t op_index = kernel_ex_def.op_index(); - OpDescPtr op_desc = davinci_model->GetOpByIndex(op_index); + OpDescPtr op_desc = davinci_model_->GetOpByIndex(op_index); if (op_desc == nullptr) { GELOGE(INTERNAL_ERROR, "Init aicpu task info error, index is out of range!"); return INTERNAL_ERROR; } - if (CopyTaskInfo(kernel_ex_def, davinci_model->GetRuntimeParam(), op_desc) != SUCCESS) { + if (CopyTaskInfo(kernel_ex_def, rts_param, op_desc) != SUCCESS) { GELOGE(FAILED, "copy task info to workspace failed."); return FAILED; } - vector workspace_data_addrs = ModelUtils::GetWorkspaceDataAddrs(davinci_model->GetRuntimeParam(), op_desc); + const vector workspace_data_addrs = ModelUtils::GetWorkspaceDataAddrs(rts_param, op_desc); if (workspace_data_addrs.empty()) { GELOGE(FAILED, "workspace_data_addrs is empty."); return FAILED; } // 2. Reconstruct kernelExDef.args to STR_FWK_OP_KERNEL - STR_FWK_OP_KERNEL fwk_op_kernel; - if (sizeof(STR_FWK_OP_KERNEL) != kernel_ex_def.args_size()) { + STR_FWK_OP_KERNEL fwk_op_kernel = {0}; + if (sizeof(STR_FWK_OP_KERNEL) < kernel_ex_def.args_size()) { GELOGE(FAILED, "sizeof STR_FWK_OP_KERNEL is: %zu, but args_size is: %u", sizeof(STR_FWK_OP_KERNEL), kernel_ex_def.args_size()); return FAILED; @@ -78,18 +79,18 @@ Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin // 2.1 get loop cond variable for tensor array write uint64_t step_id_addr = 0; - OpDescPtr step_id_node = davinci_model->GetVariableOp(domi::NODE_NAME_GLOBAL_STEP); + OpDescPtr step_id_node = davinci_model_->GetVariableOp(NODE_NAME_GLOBAL_STEP); if (step_id_node != nullptr) { - vector v_step_id_addr = ModelUtils::GetOutputDataAddrs(davinci_model->GetRuntimeParam(), step_id_node); + vector v_step_id_addr = ModelUtils::GetOutputDataAddrs(rts_param, step_id_node); if (!v_step_id_addr.empty()) { step_id_addr = static_cast(reinterpret_cast(v_step_id_addr[0])); } } // 3. 
Set workspaceaddr, inputOutputDataAddr - uint64_t workspace_base_addr = reinterpret_cast(workspace_data_addrs[0]); - vector input_addrs = ModelUtils::GetInputDataAddrs(davinci_model->GetRuntimeParam(), op_desc); - vector output_addrs = ModelUtils::GetOutputDataAddrs(davinci_model->GetRuntimeParam(), op_desc); + uint64_t workspace_base_addr = reinterpret_cast(reinterpret_cast(workspace_data_addrs[0])); + const vector input_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc); + const vector output_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc); vector io_addrs; io_addrs.insert(io_addrs.end(), input_addrs.begin(), input_addrs.end()); io_addrs.insert(io_addrs.end(), output_addrs.begin(), output_addrs.end()); @@ -103,7 +104,7 @@ Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMemcpy to input_output_addr_ error: 0x%X", rt_ret); return FAILED;) - if (PropertiesManager::Instance().IsLayerNeedDump(davinci_model->Name(), op_desc->GetName())) { + if (PropertiesManager::Instance().IsLayerNeedDump(davinci_model_->Name(), op_desc->GetName())) { dump_flag_ = RT_KERNEL_DUMPFLAG; dump_args_ = reinterpret_cast(reinterpret_cast(input_output_addr_) + sizeof(void *) * input_addrs.size()); @@ -114,6 +115,8 @@ Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin fwk_op_kernel.fwkKernelBase.fwk_kernel.workspaceBaseAddr = workspace_base_addr; fwk_op_kernel.fwkKernelBase.fwk_kernel.inputOutputAddr = input_output_addr; fwk_op_kernel.fwkKernelBase.fwk_kernel.stepIDAddr = step_id_addr; + fwk_op_kernel.fwkKernelBase.fwk_kernel.extInfoNum = 0; + fwk_op_kernel.fwkKernelBase.fwk_kernel.extInfoAddr = 0; // 4. 
Create session auto session_id = fwk_op_kernel.fwkKernelBase.fwk_kernel.sessionID; @@ -133,10 +136,15 @@ Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin rt_ret = rtMemcpy(kernel_buf_, sizeof(STR_FWK_OP_KERNEL), static_cast(&fwk_op_kernel), sizeof(STR_FWK_OP_KERNEL), RT_MEMCPY_HOST_TO_DEVICE); GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMemcpy error, ret: Ox%X", rt_ret); return FAILED;) - davinci_model->SetZeroCopyAddr(op_desc, io_addrs, input_output_addr_); + + vector virtual_io_addrs; // use virtual address for zero copy key. + const vector virtual_in_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc, false); + const vector virtual_out_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc, false); + virtual_io_addrs.insert(virtual_io_addrs.end(), virtual_in_addrs.begin(), virtual_in_addrs.end()); + virtual_io_addrs.insert(virtual_io_addrs.end(), virtual_out_addrs.begin(), virtual_out_addrs.end()); + davinci_model_->SetZeroCopyAddr(op_desc, virtual_io_addrs, io_addrs.data(), input_output_addr_, addrs_size, 0); kernel_buf_size_ = sizeof(STR_FWK_OP_KERNEL); - davinci_model_ = davinci_model; GELOGI("KernelExTaskInfo Init Success. session id: %lu", session_id); return SUCCESS; @@ -188,8 +196,8 @@ Status KernelExTaskInfo::Distribute() { } uint32_t task_id = 0; - uint32_t stream_id = UINT32_MAX; // default value, wait for rts - rt_ret = rtModelGetTaskId(davinci_model_->GetRtModelHandle(), &task_id); + uint32_t stream_id = 0; // for profiling + rt_ret = rtModelGetTaskId(davinci_model_->GetRtModelHandle(), &task_id, &stream_id); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); return RT_FAILED; @@ -197,7 +205,7 @@ Status KernelExTaskInfo::Distribute() { task_id_ = task_id; stream_id_ = stream_id; - GELOGI("KernelExTaskInfo Distribute Success. task id: %u", task_id_); + GELOGI("KernelExTaskInfo Distribute Success. 
task id: %u, stream id: %u", task_id_, stream_id_); return SUCCESS; } diff --git a/src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.h b/src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.h index 9aab55e7..a6419f9f 100644 --- a/src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.h +++ b/src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.h @@ -25,6 +25,7 @@ class KernelExTaskInfo : public TaskInfo { public: KernelExTaskInfo() : task_id_(0), + stream_id_(0), dump_flag_(RT_KERNEL_DEFAULT), kernel_buf_size_(0), davinci_model_(nullptr), diff --git a/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.cc index 84710e41..5ed89cc6 100644 --- a/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.cc @@ -24,13 +24,13 @@ #include "common/properties_manager.h" #include "framework/common/debug/ge_log.h" #include "framework/common/l2_cache_optimize.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/debug/ge_attr_define.h" #include "graph/load/new_model_manager/davinci_model.h" #include "graph/load/new_model_manager/model_utils.h" -#include "graph/debug/ge_attr_define.h" #include "runtime/kernel.h" -#include "graph/debug/ge_attr_define.h" -#include "super_kernel/super_kernel_factory.h" #include "super_kernel/super_kernel.h" +#include "super_kernel/super_kernel_factory.h" namespace { const uint8_t kL2LoadToDdr = 1; @@ -42,11 +42,12 @@ constexpr uint32_t kSKTMaxSizeLimit = 20000; const char *kIsLastNode = "is_last_node"; const char *kIsFirstNode = "is_first_node"; const int64_t kCloseSkt = 100; +const uint32_t kAddrLen = sizeof(void *); } // namespace namespace ge { KernelTaskInfo::SuperKernelTaskInfo KernelTaskInfo::skt_info_ = { - 0, 0, 0, nullptr, nullptr, {}, {}, RT_KERNEL_DEFAULT, kInvalidGroupKey, 0, nullptr}; + 0, 0, 0, 0, nullptr, nullptr, {}, {}, 
RT_KERNEL_DEFAULT, kInvalidGroupKey, 0, nullptr}; Status KernelTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) { if (davinci_model == nullptr) { @@ -57,7 +58,7 @@ Status KernelTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci is_l1_fusion_enable_ = davinci_model_->GetL1FusionEnableOption(); GELOGD("KernelTaskInfo Init Start, ge.enableL1Fusion in davinci model is %d.", is_l1_fusion_enable_); - Status ret = SetStream(task_def.stream_id(), davinci_model->GetStreamList()); + Status ret = SetStream(task_def.stream_id(), davinci_model_->GetStreamList()); if (ret != SUCCESS) { return ret; } @@ -70,14 +71,14 @@ Status KernelTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci // get kernel_type kernel_type_ = static_cast(context.kernel_type()); // get opdesc - op_desc_ = davinci_model->GetOpByIndex(context.op_index()); + op_desc_ = davinci_model_->GetOpByIndex(context.op_index()); if (op_desc_ == nullptr) { GELOGE(INTERNAL_ERROR, "Get op_desc failed, index is out of range!"); return INTERNAL_ERROR; } (void)AttrUtils::GetBool(*op_desc_, ATTR_N_BATCH_SPILT, is_n_batch_spilt_); GELOGD("node[%s] is_n_batch_spilt %d", op_desc_->GetName().c_str(), is_n_batch_spilt_); - (void)AttrUtils::GetInt(*op_desc_, ATTR_NAME_L1_FUSION_GROUP_KEY, group_key_); + (void)AttrUtils::GetInt(*op_desc_, ATTR_NAME_FUSION_GROUP_KEY, group_key_); has_group_key_ = (group_key_ != kInvalidGroupKey); GELOGD("node[%s] has_group_key_ %ld, group key is [%ld]", op_desc_->GetName().c_str(), has_group_key_, group_key_); @@ -89,7 +90,7 @@ Status KernelTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci fusion_op_info_.op_name = op_desc_->GetName()); string session_graph_model_id; - davinci_model->GetUniqueId(op_desc_, session_graph_model_id); + davinci_model_->GetUniqueId(op_desc_, session_graph_model_id); // get bin_file_key const char *bin_file_key = DavinciModel::GetRegisterStub(op_desc_->GetName(), session_graph_model_id); // new 
aicpu kernel(rtCpuKernelLaunch) no need to check function @@ -124,17 +125,17 @@ Status KernelTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci return FAILED; } - ret = InitTVMTask(davinci_model, args_offset_tmp[0], kernel_def); + ret = InitTVMTask(args_offset_tmp[0], kernel_def); } else if (kernel_type_ == cce::ccKernelType::CUSTOMIZED) { - ret = InitAICPUCustomTask(davinci_model->GetOpList(), context.op_index(), kernel_def); + ret = InitAICPUCustomTask(context.op_index(), kernel_def); } else if (kernel_type_ == cce::ccKernelType::AI_CPU) { - ret = InitAicpuTask(davinci_model->GetOpList(), context.op_index(), kernel_def); + ret = InitAicpuTask(context.op_index(), kernel_def); } else { if (kernel_def.args().empty() || args_size_ == 0) { GELOGE(FAILED, "args is null."); return FAILED; } - ret = InitCceTask(davinci_model, kernel_def); + ret = InitCceTask(kernel_def); } GELOGD("KernelTaskInfo Init finish, result=%u.", ret); @@ -143,36 +144,40 @@ Status KernelTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci Status KernelTaskInfo::SaveSKTDumpInfo() { GE_CHECK_NOTNULL(davinci_model_); - davinci_model_->SaveDumpTask(skt_info_.last_task_id, skt_info_.last_op, skt_info_.last_dump_args); + davinci_model_->SaveDumpTask(skt_info_.last_task_id, skt_info_.last_stream_id, skt_info_.last_op, + skt_info_.last_dump_args); return SUCCESS; } void KernelTaskInfo::UpdateSKTTaskId() { uint32_t task_id = 0; + uint32_t stream_id = 0; if (davinci_model_ != nullptr) { - rtError_t rt_ret = rtModelGetTaskId(davinci_model_->GetRtModelHandle(), &task_id); + rtError_t rt_ret = rtModelGetTaskId(davinci_model_->GetRtModelHandle(), &task_id, &stream_id); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); return; } skt_info_.last_task_id = task_id; + skt_info_.last_stream_id = stream_id; skt_id_ = skt_info_.last_task_id; - GELOGI("UpdateTaskId:UpdateSKTTaskId [%u]", task_id); + + GELOGI("UpdateTaskId:UpdateSKTTaskId 
[%u],stream id [%u]", task_id, stream_id); } } void KernelTaskInfo::UpdateTaskId() { uint32_t task_id = 0; - uint32_t stream_id = UINT32_MAX; // default value, wait for rts + uint32_t stream_id = 0; // for profiling if (davinci_model_ != nullptr) { - rtError_t rt_ret = rtModelGetTaskId(davinci_model_->GetRtModelHandle(), &task_id); + rtError_t rt_ret = rtModelGetTaskId(davinci_model_->GetRtModelHandle(), &task_id, &stream_id); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); return; } task_id_ = task_id; stream_id_ = stream_id; - GELOGI("UpdateTaskId:UpdateTaskId [%u]", task_id); + GELOGI("UpdateTaskId:UpdateTaskId [%u], stream id [%u]:", task_id, stream_id); } } @@ -221,13 +226,13 @@ Status KernelTaskInfo::SuperKernelLaunch() { return RT_FAILED; } // Call the fuse API - skt::SuperKernel *superKernel; + skt::SuperKernel *superKernel = nullptr; if (factory->FuseKernels(skt_kernel_list, skt_arg_list, skt_info_.last_block_dim, superKernel) != SUCCESS) { GELOGE(RT_FAILED, "SuperKernelLaunch: fuse call failed"); return RT_FAILED; } // Launch a super kernel - if (superKernel->Launch(skt_info_.last_stream, true) != SUCCESS) { + if (superKernel->Launch(skt_info_.last_stream, RT_KERNEL_DUMPFLAG) != SUCCESS) { GELOGE(RT_FAILED, "SuperKernelLaunch: launch failed"); return RT_FAILED; } @@ -341,6 +346,7 @@ Status KernelTaskInfo::Distribute() { rtError_t rt_ret = RT_ERROR_NONE; char *skt_enable_env = getenv("SKT_ENABLE"); int64_t env_flag = (skt_enable_env != nullptr) ? 
strtol(skt_enable_env, nullptr, 10) : 0; + bool call_skt = ((env_flag != 0) || is_l1_fusion_enable_); if (kernel_type_ == cce::ccKernelType::AI_CPU) { // blockDim is reserved parameter, set to 1 rt_ret = rtCpuKernelLaunchWithFlag(reinterpret_cast(so_name_.c_str()), @@ -348,11 +354,10 @@ Status KernelTaskInfo::Distribute() { nullptr, stream_, dump_flag_); } else { /* default: not skt launch */ - bool call_skt = ((env_flag != 0) || is_l1_fusion_enable_); GELOGI( - "KernelTaskInfo Distribute Start, sktenable:%ld taskid:%u sktid:%u last_sktid:%u stubfunc_name:%s " + "KernelTaskInfo Distribute Start, sktenable:%d taskid:%u sktid:%u last_sktid:%u stubfunc_name:%s " "stubfunc:%p blockdim:%u stream:%p", - env_flag, task_id_, skt_id_, skt_info_.last_task_id, stub_func_name_.c_str(), stub_func_, block_dim_, stream_); + call_skt, task_id_, skt_id_, skt_info_.last_task_id, stub_func_name_.c_str(), stub_func_, block_dim_, stream_); // l1 fusion enable and env flag open (kCloseSkt for skt debug) if (call_skt && (env_flag != kCloseSkt)) { GE_RETURN_IF_ERROR(SuperKernelDistribute()); @@ -371,7 +376,7 @@ Status KernelTaskInfo::Distribute() { GELOGI( "KernelTaskInfo Distribute Success. 
sktenable:%d taskid:%d sktid:%d stubfunc_name:%s stubfunc:%p " "blockdim:%d stream:%p", - env_flag, task_id_, skt_id_, stub_func_name_.c_str(), stub_func_, block_dim_, stream_); + call_skt, task_id_, skt_id_, stub_func_name_.c_str(), stub_func_, block_dim_, stream_); return SUCCESS; } @@ -399,13 +404,48 @@ Status KernelTaskInfo::Release() { return SUCCESS; } -Status KernelTaskInfo::InitTVMTask(DavinciModel *davinci_model, uint16_t offset, const domi::KernelDef &kernel_def) { +Status KernelTaskInfo::UpdateL2Data(const domi::KernelDef &kernel_def) { + string sm_desc = kernel_def.sm_desc(); + if (sm_desc.empty()) { + return SUCCESS; + } + + char *sm_contrl = const_cast(sm_desc.data()); + rtL2Ctrl_t *l2_ctrl_info = reinterpret_cast(sm_contrl); + uint64_t gen_base_addr = davinci_model_->GetRtBaseAddr(); + + // There is no weight for te op now. Update L2_mirror_addr by data memory base. + uint64_t data_base_addr = (uint64_t)(uintptr_t)davinci_model_->MemBase() - (uint64_t)gen_base_addr; + const uint32_t l2_ctrl_info_data_count = 8; + for (uint32_t data_index = 0; data_index < l2_ctrl_info_data_count; ++data_index) { + if (l2_ctrl_info->data[data_index].L2_mirror_addr != 0) { + l2_ctrl_info->data[data_index].L2_mirror_addr += data_base_addr; + l2_ctrl_info->data[data_index].L2_load_to_ddr = IsL2CpToDDR(l2_ctrl_info->data[data_index].L2_load_to_ddr); + } + } + + rtError_t rt_ret = rtMemAllocManaged(&sm_desc_, sm_desc.size(), RT_MEMORY_SPM); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); + return RT_FAILED; + } + + rt_ret = rtMemcpy(sm_desc_, sm_desc.size(), sm_desc.data(), sm_desc.size(), RT_MEMCPY_HOST_TO_DEVICE); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); + return RT_FAILED; + } + + return SUCCESS; +} + +Status KernelTaskInfo::InitTVMTask(uint16_t offset, const domi::KernelDef &kernel_def) { GELOGD("Do InitTVMTask."); - GE_CHECK_NOTNULL(davinci_model); + 
GE_CHECK_NOTNULL(davinci_model_); // get tvm op desc - OpDescPtr op_desc = davinci_model->GetOpByIndex(ctx_.opIndex); + OpDescPtr op_desc = davinci_model_->GetOpByIndex(ctx_.opIndex); if (op_desc == nullptr) { - GELOGE(INTERNAL_ERROR, "InitTVMTaskInfo error, index is out of range!"); + GELOGE(INTERNAL_ERROR, "InitTVMTaskInfo error, index:%u out of range!", ctx_.opIndex); return INTERNAL_ERROR; } @@ -414,21 +454,19 @@ Status KernelTaskInfo::InitTVMTask(DavinciModel *davinci_model, uint16_t offset, // and does not need to be modified. // When inferencing, stub_func_ is different from dynamic-registration to runtime, and needs to be modified. string session_graph_model_id; - const char *bin_file_key; - davinci_model->GetUniqueId(op_desc, session_graph_model_id); - bin_file_key = DavinciModel::GetRegisterStub(op_desc->GetName(), session_graph_model_id); - rtError_t rt_ret; - rt_ret = rtQueryFunctionRegistered(const_cast(bin_file_key)); + davinci_model_->GetUniqueId(op_desc, session_graph_model_id); + const char *bin_file_key = DavinciModel::GetRegisterStub(op_desc->GetName(), session_graph_model_id); + rtError_t rt_ret = rtQueryFunctionRegistered(const_cast(bin_file_key)); if (rt_ret != RT_ERROR_NONE) { stub_func_ = const_cast(bin_file_key); } - const vector input_data_addrs = ModelUtils::GetInputDataAddrs(davinci_model->GetRuntimeParam(), op_desc); - const vector output_data_addrs = ModelUtils::GetOutputDataAddrs(davinci_model->GetRuntimeParam(), op_desc); - const vector workspace_data_addrs = - ModelUtils::GetWorkspaceDataAddrs(davinci_model->GetRuntimeParam(), op_desc); - vector tensor_device_addrs; + const RuntimeParam &rts_param = davinci_model_->GetRuntimeParam(); + const vector input_data_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc); + const vector output_data_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc); + const vector workspace_data_addrs = ModelUtils::GetWorkspaceDataAddrs(rts_param, op_desc); + vector tensor_device_addrs; 
tensor_device_addrs.insert(tensor_device_addrs.end(), input_data_addrs.begin(), input_data_addrs.end()); tensor_device_addrs.insert(tensor_device_addrs.end(), output_data_addrs.begin(), output_data_addrs.end()); tensor_device_addrs.insert(tensor_device_addrs.end(), workspace_data_addrs.begin(), workspace_data_addrs.end()); @@ -441,81 +479,67 @@ Status KernelTaskInfo::InitTVMTask(DavinciModel *davinci_model, uint16_t offset, } // copy orign args - rt_ret = rtMemcpy(args_, args_size_, static_cast(const_cast(kernel_def.args().data())), args_size_, - RT_MEMCPY_HOST_TO_DEVICE); + rt_ret = rtMemcpy(args_, args_size_, kernel_def.args().data(), args_size_, RT_MEMCPY_HOST_TO_DEVICE); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); return RT_FAILED; } + vector args_info(args_size_); + errno_t sec_ret = memcpy_s(args_info.data(), args_size_, kernel_def.args().data(), args_size_); + if (sec_ret != EOK) { + GELOGE(FAILED, "memcpy failed, ret: %d", sec_ret); + return FAILED; + } - if (args_size_ <= static_cast(offset) || - args_size_ - static_cast(offset) < static_cast(sizeof(void *) * tensor_device_addrs.size())) { + if ((args_size_ <= offset) || (args_size_ - offset < kAddrLen * tensor_device_addrs.size())) { GELOGE(FAILED, "offset >= kernelInfo.argsSize or copy content beyond applied memory."); return FAILED; } // copy args - rt_ret = rtMemcpy(static_cast(args_) + offset, sizeof(void *) * tensor_device_addrs.size(), - tensor_device_addrs.data(), sizeof(void *) * tensor_device_addrs.size(), RT_MEMCPY_HOST_TO_DEVICE); + rt_ret = rtMemcpy(static_cast(args_) + offset, args_size_ - offset, tensor_device_addrs.data(), + kAddrLen * tensor_device_addrs.size(), RT_MEMCPY_HOST_TO_DEVICE); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); return RT_FAILED; } + sec_ret = memcpy_s(args_info.data() + offset, args_size_ - offset, tensor_device_addrs.data(), + kAddrLen * tensor_device_addrs.size()); + if 
(sec_ret != EOK) { + GELOGE(FAILED, "memcpy failed, ret: %d", sec_ret); + return FAILED; + } - if (PropertiesManager::Instance().IsLayerNeedDump(davinci_model->Name(), op_desc->GetName())) { + if (PropertiesManager::Instance().IsLayerNeedDump(davinci_model_->Name(), op_desc->GetName())) { dump_flag_ = RT_KERNEL_DUMPFLAG; - dump_args_ = - reinterpret_cast(reinterpret_cast(args_) + offset + sizeof(void *) * input_data_addrs.size()); + dump_args_ = static_cast(args_) + offset + kAddrLen * input_data_addrs.size(); } - davinci_model_->SetZeroCopyAddr(op_desc, tensor_device_addrs, static_cast(args_) + offset); // update origin l2 data - string sm_desc = kernel_def.sm_desc(); - char *sm_contrl = nullptr; - rtL2Ctrl_t *l2_ctrl_info = nullptr; - if (!sm_desc.empty()) { - sm_contrl = const_cast(sm_desc.data()); - l2_ctrl_info = reinterpret_cast(sm_contrl); - - uint64_t gen_base_addr = davinci_model->GetRtBaseAddr(); - - // There is no weight for te op now. Update L2_mirror_addr by data memory base. - uint64_t data_base_addr = (uint64_t)(uintptr_t)davinci_model->MemBase() - (uint64_t)gen_base_addr; - const uint32_t l2_ctrl_info_data_count = 8; - for (uint32_t data_index = 0; data_index < l2_ctrl_info_data_count; ++data_index) { - if (l2_ctrl_info->data[data_index].L2_mirror_addr != 0) { - l2_ctrl_info->data[data_index].L2_mirror_addr += data_base_addr; - l2_ctrl_info->data[data_index].L2_load_to_ddr = IsL2CpToDDR(l2_ctrl_info->data[data_index].L2_load_to_ddr); - } - } + if (UpdateL2Data(kernel_def) != SUCCESS) { + return RT_FAILED; + } - rt_ret = rtMemAllocManaged(&sm_desc_, sm_desc.size(), RT_MEMORY_SPM); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return RT_FAILED; - } + vector virtual_io_addrs; // use virtual address for zero copy key. 
+ const vector virtual_in_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc, false); + const vector virtual_out_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc, false); + virtual_io_addrs.insert(virtual_io_addrs.end(), virtual_in_addrs.begin(), virtual_in_addrs.end()); + virtual_io_addrs.insert(virtual_io_addrs.end(), virtual_out_addrs.begin(), virtual_out_addrs.end()); + davinci_model_->SetZeroCopyAddr(op_desc, virtual_io_addrs, args_info.data(), args_, args_size_, offset); - rt_ret = rtMemcpy(sm_desc_, sm_desc.size(), sm_desc.data(), sm_desc.size(), RT_MEMCPY_HOST_TO_DEVICE); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return RT_FAILED; - } - } GELOGD("Do InitTVMTask end"); return SUCCESS; } -Status KernelTaskInfo::InitAICPUCustomTask(const std::map> &op_list, - uint32_t op_index, const domi::KernelDef &kernel_def) { +Status KernelTaskInfo::InitAICPUCustomTask(uint32_t op_index, const domi::KernelDef &kernel_def) { GELOGI("Do InitAICPUCustomTask"); - - auto iter = op_list.find(op_index); - if (iter == op_list.end()) { + OpDescPtr op_desc = davinci_model_->GetOpByIndex(op_index); + if (op_desc == nullptr) { GELOGE(INTERNAL_ERROR, "index is out of range, index: %u", op_index); return INTERNAL_ERROR; } - auto op_desc = iter->second; + const RuntimeParam &rts_param = davinci_model_->GetRuntimeParam(); const domi::KernelContext &context = kernel_def.context(); const uint32_t kCustomAicpuArgsLen = 5; @@ -534,11 +558,8 @@ Status KernelTaskInfo::InitAICPUCustomTask(const std::map(const_cast(context.args_offset().data())))[i]; } - const std::vector input_data_addrs = - ModelUtils::GetInputDataAddrs(davinci_model_->GetRuntimeParam(), op_desc); - const std::vector output_data_addrs = - ModelUtils::GetOutputDataAddrs(davinci_model_->GetRuntimeParam(), op_desc); - + const std::vector input_data_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc); + const std::vector output_data_addrs = 
ModelUtils::GetOutputDataAddrs(rts_param, op_desc); Status ret = StoreInputOutputTensor(input_data_addrs, output_data_addrs, ModelUtils::GetInputDescs(op_desc), ModelUtils::GetOutputDescs(op_desc)); @@ -549,7 +570,7 @@ Status KernelTaskInfo::InitAICPUCustomTask(const std::map(args + ctx_.argsOffset[0])) = - reinterpret_cast(custom_info_.input_descs); // arg 0 + reinterpret_cast(reinterpret_cast(custom_info_.input_descs)); // arg 0 *(reinterpret_cast(args + ctx_.argsOffset[1])) = - reinterpret_cast(custom_info_.input_addrs); // arg 1 + reinterpret_cast(reinterpret_cast(custom_info_.input_addrs)); // arg 1 *(reinterpret_cast(args + ctx_.argsOffset[2])) = - reinterpret_cast(custom_info_.output_descs); // arg 2 + reinterpret_cast(reinterpret_cast(custom_info_.output_descs)); // arg 2 *(reinterpret_cast(args + ctx_.argsOffset[3])) = - reinterpret_cast(custom_info_.output_addrs); // arg 3 + reinterpret_cast(reinterpret_cast(custom_info_.output_addrs)); // arg 3 *(reinterpret_cast(args + ctx_.argsOffset[4])) = - reinterpret_cast(custom_info_.attr_handle); // arg 4 + reinterpret_cast(reinterpret_cast(custom_info_.attr_handle)); // arg 4 rt_ret = rtMalloc(&args_, args_size_, RT_MEMORY_HBM); if (rt_ret != RT_ERROR_NONE) { @@ -606,14 +627,18 @@ Status KernelTaskInfo::InitAICPUCustomTask(const std::mapSetZeroCopyAddr(op_desc, input_data_addrs, custom_info_.input_addrs); - davinci_model_->SetZeroCopyAddr(op_desc, output_data_addrs, custom_info_.output_addrs); + const vector virtual_in_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc, false); + const vector virtual_out_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc, false); + davinci_model_->SetZeroCopyAddr(op_desc, virtual_in_addrs, input_data_addrs.data(), custom_info_.input_addrs, + virtual_in_addrs.size() * kAddrLen, 0); + davinci_model_->SetZeroCopyAddr(op_desc, virtual_out_addrs, output_data_addrs.data(), custom_info_.output_addrs, + output_data_addrs.size() * kAddrLen, 0); return SUCCESS; } -Status 
KernelTaskInfo::InitCceTask(DavinciModel *davinci_model, const domi::KernelDef &kernel_def) { +Status KernelTaskInfo::InitCceTask(const domi::KernelDef &kernel_def) { GELOGI("Do InitCCETask"); - if (davinci_model == nullptr) { + if (davinci_model_ == nullptr) { GELOGE(PARAM_INVALID, "davinci_model is null!"); return PARAM_INVALID; } @@ -639,15 +664,15 @@ Status KernelTaskInfo::InitCceTask(DavinciModel *davinci_model, const domi::Kern uint64_t sm_contrl_size = sm_desc.empty() ? 0 : sizeof(rtSmDesc_t); // Passing the memory info when the offline-model-generated to the CCE, which uses this info for address refresh - ctx_.genDataBaseAddr = davinci_model->GetRtBaseAddr(); - ctx_.genDataBaseSize = davinci_model->TotalMemSize(); - ctx_.genWeightBaseAddr = davinci_model->GetRtWeightAddr(); - ctx_.genWeightBaseSize = davinci_model->TotalWeightsMemSize(); - ctx_.genVariableBaseAddr = davinci_model->GetRtVarAddr(); - ctx_.genVariableBaseSize = davinci_model->TotalVarMemSize(); + ctx_.genDataBaseAddr = davinci_model_->GetRtBaseAddr(); + ctx_.genDataBaseSize = davinci_model_->TotalMemSize(); + ctx_.genWeightBaseAddr = davinci_model_->GetRtWeightAddr(); + ctx_.genWeightBaseSize = davinci_model_->TotalWeightsMemSize(); + ctx_.genVariableBaseAddr = davinci_model_->GetRtVarAddr(); + ctx_.genVariableBaseSize = davinci_model_->TotalVarMemSize(); ctx_.l2ctrlSize = sm_contrl_size; - if (UpdateCceArgs(sm_desc, flowtable, davinci_model, kernel_def) != SUCCESS) { + if (UpdateCceArgs(sm_desc, flowtable, kernel_def) != SUCCESS) { GELOGE(ret, "update cce args fail"); return ret; } @@ -691,14 +716,13 @@ Status KernelTaskInfo::InitCceTask(DavinciModel *davinci_model, const domi::Kern return SUCCESS; } -Status KernelTaskInfo::InitAicpuTask(const std::map &op_list, uint32_t op_index, - const domi::KernelDef &kernel_def) { +Status KernelTaskInfo::InitAicpuTask(uint32_t op_index, const domi::KernelDef &kernel_def) { GELOGI("Do InitAicpuTask"); so_name_ = kernel_def.so_name(); kernel_name_ = 
kernel_def.kernel_name(); - auto iter = op_list.find(op_index); - if (iter == op_list.end()) { + OpDescPtr op_desc = davinci_model_->GetOpByIndex(op_index); + if (op_desc == nullptr) { GELOGE(INTERNAL_ERROR, "index is out of range, index: %u", op_index); return INTERNAL_ERROR; } @@ -706,25 +730,24 @@ Status KernelTaskInfo::InitAicpuTask(const std::map &op_lis // copy args to new host memory std::unique_ptr args_addr(new (std::nothrow) uint8_t[args_size_]); GE_PRINT_DYNAMIC_MEMORY(new, "cce task physical memory.", sizeof(uint8_t) * args_size_) - errno_t sec_ret = memcpy_s(static_cast(args_addr.get()), args_size_, - static_cast(kernel_def.args().data()), args_size_); + errno_t sec_ret = memcpy_s(args_addr.get(), args_size_, kernel_def.args().data(), args_size_); if (sec_ret != EOK) { GELOGE(FAILED, "memcpy failed, ret: %d", sec_ret); return FAILED; } - OpDescPtr op_desc = iter->second; - vector input_addrs = ModelUtils::GetInputDataAddrs(davinci_model_->GetRuntimeParam(), op_desc); - vector output_addrs = ModelUtils::GetOutputDataAddrs(davinci_model_->GetRuntimeParam(), op_desc); + const RuntimeParam &rts_param = davinci_model_->GetRuntimeParam(); + + vector input_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc); + vector output_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc); vector io_addrs; io_addrs.insert(io_addrs.end(), input_addrs.begin(), input_addrs.end()); io_addrs.insert(io_addrs.end(), output_addrs.begin(), output_addrs.end()); if (!io_addrs.empty()) { // refresh io addrs - uintptr_t io_addr = - reinterpret_cast(args_addr.get()) + static_cast(sizeof(aicpu::AicpuParamHead)); - auto addrs_size = sizeof(uint64_t) * (io_addrs.size()); - sec_ret = memcpy_s(reinterpret_cast(io_addr), addrs_size, static_cast(io_addrs.data()), addrs_size); + uintptr_t io_addr = reinterpret_cast(args_addr.get()) + sizeof(aicpu::AicpuParamHead); + auto addrs_size = sizeof(uint64_t) * io_addrs.size(); + sec_ret = memcpy_s(reinterpret_cast(io_addr), addrs_size, 
io_addrs.data(), addrs_size); if (sec_ret != EOK) { GELOGE(FAILED, "memcpy failed, ret: %d", sec_ret); return FAILED; @@ -740,7 +763,7 @@ Status KernelTaskInfo::InitAicpuTask(const std::map &op_lis GE_PRINT_DYNAMIC_MEMORY(rtMalloc, "cce task physical memory.", args_size_) // copy args to device - rt_ret = rtMemcpy(args_, args_size_, static_cast(args_addr.get()), args_size_, RT_MEMCPY_HOST_TO_DEVICE); + rt_ret = rtMemcpy(args_, args_size_, args_addr.get(), args_size_, RT_MEMCPY_HOST_TO_DEVICE); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api(rtMemcpy) failed, ret: 0x%X", rt_ret); return RT_FAILED; @@ -748,11 +771,17 @@ Status KernelTaskInfo::InitAicpuTask(const std::map &op_lis if (PropertiesManager::Instance().IsLayerNeedDump(davinci_model_->Name(), op_desc->GetName())) { dump_flag_ = RT_KERNEL_DUMPFLAG; - dump_args_ = reinterpret_cast(reinterpret_cast(args_) + sizeof(aicpu::AicpuParamHead) + - sizeof(void *) * input_addrs.size()); + dump_args_ = static_cast(args_) + sizeof(aicpu::AicpuParamHead) + kAddrLen * input_addrs.size(); } - davinci_model_->SetZeroCopyAddr(op_desc, io_addrs, static_cast(args_) + sizeof(aicpu::AicpuParamHead)); + vector virtual_io_addrs; // use virtual address for zero copy key. 
+ const vector virtual_in_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc, false); + const vector virtual_out_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc, false); + virtual_io_addrs.insert(virtual_io_addrs.end(), virtual_in_addrs.begin(), virtual_in_addrs.end()); + virtual_io_addrs.insert(virtual_io_addrs.end(), virtual_out_addrs.begin(), virtual_out_addrs.end()); + davinci_model_->SetZeroCopyAddr(op_desc, virtual_io_addrs, args_addr.get(), args_, args_size_, + sizeof(aicpu::AicpuParamHead)); + return SUCCESS; } @@ -787,8 +816,8 @@ Status KernelTaskInfo::StoreInputOutputTensor(const std::vector &input_d } if (!input_data_addrs.empty()) { - rt_ret = rtMemcpy(custom_info_.input_addrs, sizeof(void *) * input_size, &input_data_addrs[0], - sizeof(void *) * input_size, RT_MEMCPY_HOST_TO_DEVICE); + rt_ret = rtMemcpy(custom_info_.input_addrs, kAddrLen * input_size, &input_data_addrs[0], kAddrLen * input_size, + RT_MEMCPY_HOST_TO_DEVICE); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); return RT_FAILED; @@ -818,8 +847,8 @@ Status KernelTaskInfo::StoreInputOutputTensor(const std::vector &input_d } if (!output_data_addrs.empty()) { - rt_ret = rtMemcpy(custom_info_.output_addrs, sizeof(void *) * output_size, &output_data_addrs[0], - sizeof(void *) * output_size, RT_MEMCPY_HOST_TO_DEVICE); + rt_ret = rtMemcpy(custom_info_.output_addrs, kAddrLen * output_size, &output_data_addrs[0], kAddrLen * output_size, + RT_MEMCPY_HOST_TO_DEVICE); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); return RT_FAILED; @@ -872,17 +901,14 @@ void KernelTaskInfo::FreeRtMem(void **ptr) { *ptr = nullptr; } -Status KernelTaskInfo::UpdateCceArgs(std::string &sm_desc, std::string &flowtable, DavinciModel *davinci_model, - const domi::KernelDef &kernel_def) { - GE_CHECK_NOTNULL(davinci_model); +Status KernelTaskInfo::UpdateCceArgs(std::string &sm_desc, std::string &flowtable, const domi::KernelDef 
&kernel_def) { + GE_CHECK_NOTNULL(davinci_model_); const domi::KernelContext &context = kernel_def.context(); - uint64_t data_base_addr = - reinterpret_cast(reinterpret_cast(davinci_model->MemBase())) - davinci_model->GetRtBaseAddr(); - uint64_t weight_base_addr = reinterpret_cast(reinterpret_cast(davinci_model->WeightsMemBase())) - - davinci_model->GetRtWeightAddr(); - uint64_t var_base_addr = reinterpret_cast(reinterpret_cast(davinci_model->VarMemBase())) - - davinci_model->GetRtVarAddr(); + uint64_t data_base_addr = reinterpret_cast(davinci_model_->MemBase()) - davinci_model_->GetRtBaseAddr(); + uint64_t weight_base_addr = + reinterpret_cast(davinci_model_->WeightsMemBase()) - davinci_model_->GetRtWeightAddr(); + uint64_t var_base_addr = reinterpret_cast(davinci_model_->VarMemBase()) - davinci_model_->GetRtVarAddr(); Status status = CceUpdateKernelArgs(context, data_base_addr, weight_base_addr, var_base_addr, sm_desc, flowtable, kernel_def); @@ -904,7 +930,7 @@ Status KernelTaskInfo::CceUpdateKernelArgs(const domi::KernelContext &context, u std::string file_name = "libcce.so"; std::string path = PluginManager::GetPath(); path.append(file_name); - string canonicalPath = domi::RealPath(path.c_str()); + string canonicalPath = RealPath(path.c_str()); if (canonicalPath.empty()) { GELOGW("failed to get realpath of %s", path.c_str()); return FAILED; @@ -977,7 +1003,7 @@ Status KernelTaskInfo::SetFlowtable(std::string &flowtable, const domi::KernelDe *(reinterpret_cast( args + (reinterpret_cast(const_cast(context.args_offset().data())))[0])) = - reinterpret_cast(flowtable_); + reinterpret_cast(reinterpret_cast(flowtable_)); } return SUCCESS; } diff --git a/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.h b/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.h index 5de622eb..234c25f4 100644 --- a/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.h +++ b/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.h @@ -88,15 
+88,13 @@ class KernelTaskInfo : public TaskInfo { FusionOpInfo fusion_op_info_; private: - Status InitTVMTask(DavinciModel *davinci_model, uint16_t offset, const domi::KernelDef &kernel_def); + Status InitTVMTask(uint16_t offset, const domi::KernelDef &kernel_def); - Status InitAICPUCustomTask(const std::map> &op_list, uint32_t op_index, - const domi::KernelDef &kernel_def); + Status InitAICPUCustomTask(uint32_t op_index, const domi::KernelDef &kernel_def); - Status InitCceTask(DavinciModel *davinci_model, const domi::KernelDef &kernel_def); + Status InitCceTask(const domi::KernelDef &kernel_def); - Status InitAicpuTask(const std::map &op_list, uint32_t op_index, - const domi::KernelDef &kernel_def); + Status InitAicpuTask(uint32_t op_index, const domi::KernelDef &kernel_def); Status StoreInputOutputTensor(const std::vector &input_data_addrs, const std::vector &output_data_addrs, @@ -105,14 +103,15 @@ class KernelTaskInfo : public TaskInfo { Status SetContext(const domi::KernelDef &kernel_def); - Status UpdateCceArgs(std::string &sm_desc, std::string &flowtable, DavinciModel *davinci_model, - const domi::KernelDef &kernel_def); + Status UpdateCceArgs(std::string &sm_desc, std::string &flowtable, const domi::KernelDef &kernel_def); Status CceUpdateKernelArgs(const domi::KernelContext &context, uint64_t &data_base_addr, uint64_t &weight_base_addr, uint64_t &var_base_addr, std::string &sm_desc, std::string &flowtable, const domi::KernelDef &kernel_def); Status SetFlowtable(std::string &flowtable, const domi::KernelDef &kernel_def); + Status UpdateL2Data(const domi::KernelDef &kernel_def); + uint8_t IsL2CpToDDR(uint8_t origain_L2_load_to_ddr); static void FreeRtMem(void **ptr); @@ -169,6 +168,7 @@ class KernelTaskInfo : public TaskInfo { uint32_t last_block_dim; uint32_t last_args_size; uint32_t last_task_id; + uint32_t last_stream_id; void *last_stream; void *last_sm_desc; std::vector kernel_list; diff --git 
a/src/ge/graph/load/new_model_manager/task_info/label_goto_ex_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/label_goto_ex_task_info.cc new file mode 100644 index 00000000..c157b1df --- /dev/null +++ b/src/ge/graph/load/new_model_manager/task_info/label_goto_ex_task_info.cc @@ -0,0 +1,70 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/load/new_model_manager/task_info/label_goto_ex_task_info.h" + +#include "graph/load/new_model_manager/davinci_model.h" +#include "graph/debug/ge_attr_define.h" + +namespace ge { +Status LabelGotoExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) { + GELOGI("LabelGotoExTaskInfo Init Start."); + GE_CHECK_NOTNULL(davinci_model); + + if (SetStream(task_def.stream_id(), davinci_model->GetStreamList()) != SUCCESS) { + return FAILED; + } + + // Get LabelGoto task def + const domi::LabelGotoExDef &label_goto = task_def.label_goto_ex(); + OpDescPtr op_desc = davinci_model->GetOpByIndex(label_goto.op_index()); + if (op_desc == nullptr) { + GELOGE(INTERNAL_ERROR, "Task op index:%u out of range!", label_goto.op_index()); + return INTERNAL_ERROR; + } + + uint32_t label_index = 0; + if (!AttrUtils::GetInt(op_desc, ATTR_NAME_LABEL_SWITCH_INDEX, label_index)) { + GELOGE(INTERNAL_ERROR, "LabelGotoExTaskInfo: %s attr [%s] not exist.", op_desc->GetName().c_str(), + ATTR_NAME_LABEL_SWITCH_INDEX.c_str()); + return 
INTERNAL_ERROR; + } + + const vector &label_list = davinci_model->GetLabelList(); + if (label_index >= label_list.size()) { + GELOGE(PARAM_INVALID, "LabelGotoExTaskInfo: Invalid label id:%u, label size:%zu", label_index, label_list.size()); + return INTERNAL_ERROR; + } + label_ = label_list[label_index]; + + GELOGI("LabelGotoExTaskInfo Init Success, label id:%u, label:%p.", label_index, label_); + return SUCCESS; +} + +Status LabelGotoExTaskInfo::Distribute() { + GELOGI("LabelGotoExTaskInfo Distribute Start."); + rtError_t rt_ret = rtLabelGotoEx(label_, stream_); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); + return RT_FAILED; + } + + GELOGI("LabelGotoExTaskInfo Distribute Success."); + return SUCCESS; +} + +REGISTER_TASK_INFO(RT_MODEL_TASK_STREAM_LABEL_GOTO, LabelGotoExTaskInfo); +} // namespace ge diff --git a/src/ge/graph/load/new_model_manager/task_info/label_goto_task_info.h b/src/ge/graph/load/new_model_manager/task_info/label_goto_ex_task_info.h similarity index 75% rename from src/ge/graph/load/new_model_manager/task_info/label_goto_task_info.h rename to src/ge/graph/load/new_model_manager/task_info/label_goto_ex_task_info.h index ac78cbe2..c8a695c9 100644 --- a/src/ge/graph/load/new_model_manager/task_info/label_goto_task_info.h +++ b/src/ge/graph/load/new_model_manager/task_info/label_goto_ex_task_info.h @@ -14,17 +14,17 @@ * limitations under the License. 
*/ -#ifndef GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_LABEL_GOTO_TASK_INFO_H_ -#define GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_LABEL_GOTO_TASK_INFO_H_ +#ifndef GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_LABEL_GOTO_EX_TASK_INFO_H_ +#define GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_LABEL_GOTO_EX_TASK_INFO_H_ #include "graph/load/new_model_manager/task_info/task_info.h" namespace ge { -class LabelGotoTaskInfo : public TaskInfo { +class LabelGotoExTaskInfo : public TaskInfo { public: - LabelGotoTaskInfo() : label_(nullptr) {} + LabelGotoExTaskInfo() : label_(nullptr) {} - ~LabelGotoTaskInfo() override { label_ = nullptr; } + ~LabelGotoExTaskInfo() override { label_ = nullptr; } Status Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; @@ -34,4 +34,4 @@ class LabelGotoTaskInfo : public TaskInfo { void *label_; }; } // namespace ge -#endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_LABEL_GOTO_TASK_INFO_H_ +#endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_LABEL_GOTO_EX_TASK_INFO_H_ diff --git a/src/ge/graph/load/new_model_manager/task_info/label_goto_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/label_goto_task_info.cc deleted file mode 100644 index 9124be9f..00000000 --- a/src/ge/graph/load/new_model_manager/task_info/label_goto_task_info.cc +++ /dev/null @@ -1,55 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "graph/load/new_model_manager/task_info/label_goto_task_info.h" - -#include "framework/common/debug/ge_log.h" -#include "graph/load/new_model_manager/davinci_model.h" - -namespace ge { -Status LabelGotoTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) { - GELOGI("LabelGotoTaskInfo Init Start."); - if (davinci_model == nullptr) { - GELOGE(PARAM_INVALID, "davinci_model is null!"); - return PARAM_INVALID; - } - - Status ret = SetStream(task_def.stream_id(), davinci_model->GetStreamList()); - if (ret != SUCCESS) { - return ret; - } - - if (!davinci_model->GetLabelList().empty()) { - label_ = davinci_model->GetLabelList().back(); - } - - return SUCCESS; -} - -Status LabelGotoTaskInfo::Distribute() { - GELOGI("LabelGotoTaskInfo Distribute Start."); - rtError_t rt_ret = rtLabelGoto(label_, stream_); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return RT_FAILED; - } - - GELOGI("LabelGotoTaskInfo Distribute Success."); - return SUCCESS; -} - -REGISTER_TASK_INFO(RT_MODEL_TASK_LABEL_GOTO, LabelGotoTaskInfo); -} // namespace ge diff --git a/src/ge/graph/load/new_model_manager/task_info/label_set_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/label_set_task_info.cc index 75679ec4..e8888eef 100644 --- a/src/ge/graph/load/new_model_manager/task_info/label_set_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/label_set_task_info.cc @@ -16,32 +16,41 @@ #include "graph/load/new_model_manager/task_info/label_set_task_info.h" -#include "framework/common/debug/ge_log.h" #include "graph/load/new_model_manager/davinci_model.h" +#include "graph/debug/ge_attr_define.h" namespace ge { Status LabelSetTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) { GELOGI("LabelSetTaskInfo Init Start."); - if (davinci_model == nullptr) { - GELOGE(PARAM_INVALID, "davinci_model is null!"); - return PARAM_INVALID; + GE_CHECK_NOTNULL(davinci_model); + + if 
(SetStream(task_def.stream_id(), davinci_model->GetStreamList()) != SUCCESS) { + return FAILED; } - Status ret = SetStream(task_def.stream_id(), davinci_model->GetStreamList()); - if (ret != SUCCESS) { - return ret; + // Get LabelSet task def + const domi::LabelSetDef &label_set = task_def.label_set(); + OpDescPtr op_desc = davinci_model->GetOpByIndex(label_set.op_index()); + if (op_desc == nullptr) { + GELOGE(INTERNAL_ERROR, "Task op index:%u out of range!", label_set.op_index()); + return INTERNAL_ERROR; } - uint32_t label_id = task_def.label_id(); - if (label_id > davinci_model->BatchNum()) { - GELOGE(PARAM_INVALID, "labelId is invalid! labelId=%u, labelListSize=%u", label_id, davinci_model->BatchNum()); - return PARAM_INVALID; + uint32_t label_index = 0; + if (!AttrUtils::GetInt(op_desc, ATTR_NAME_LABEL_SWITCH_INDEX, label_index)) { + GELOGE(INTERNAL_ERROR, "LabelSetTaskInfo: %s attr [%s] not exist.", op_desc->GetName().c_str(), + ATTR_NAME_LABEL_SWITCH_INDEX.c_str()); + return INTERNAL_ERROR; } - if (!davinci_model->GetLabelList().empty()) { - label_ = davinci_model->GetLabelList()[label_id]; + const vector &label_list = davinci_model->GetLabelList(); + if (label_index >= label_list.size()) { + GELOGE(INTERNAL_ERROR, "LabelSetTaskInfo: Invalid label id:%u, label size:%zu", label_index, label_list.size()); + return INTERNAL_ERROR; } + label_ = label_list[label_index]; + GELOGI("LabelSetTaskInfo Init success, label id:%u, label:%p.", label_index, label_); return SUCCESS; } diff --git a/src/ge/graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc new file mode 100644 index 00000000..818307eb --- /dev/null +++ b/src/ge/graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc @@ -0,0 +1,128 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this 
file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/load/new_model_manager/task_info/label_switch_by_index_task_info.h" + +#include "graph/load/new_model_manager/davinci_model.h" +#include "graph/debug/ge_attr_define.h" + +namespace ge { +constexpr uint8_t kLabelSwitchIndexNum = 1; + +LabelSwitchByIndexTaskInfo::~LabelSwitchByIndexTaskInfo() { + if (args_ != nullptr) { + rtError_t ret = rtFree(args_); + if (ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", ret); + } + } + args_ = nullptr; + index_value_ = nullptr; +} + +Status LabelSwitchByIndexTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) { + GELOGI("LabelSwitchByIndexTaskInfo Init Start."); + GE_CHECK_NOTNULL(davinci_model); + + const vector &label_list = davinci_model->GetLabelList(); + Status ret = SetStream(task_def.stream_id(), davinci_model->GetStreamList()); + if (ret != SUCCESS) { + return FAILED; + } + + // Get LabelSwitch task def + const domi::LabelSwitchByIndexDef &label_switch = task_def.label_switch_by_index(); + OpDescPtr op_desc = davinci_model->GetOpByIndex(label_switch.op_index()); + if (op_desc == nullptr) { + GELOGE(INTERNAL_ERROR, "Task op index:%u out of range!", label_switch.op_index()); + return INTERNAL_ERROR; + } + + branch_max_ = label_switch.label_max(); + + auto input_data_addr = ModelUtils::GetInputDataAddrs(davinci_model->GetRuntimeParam(), op_desc); + if (input_data_addr.size() != kLabelSwitchIndexNum) { + GELOGE(INTERNAL_ERROR, "LabelSwitchByIndexTaskInfo: %s invalid addr size: %zu, num: %u!", 
+ op_desc->GetName().c_str(), input_data_addr.size(), kLabelSwitchIndexNum); + return INTERNAL_ERROR; + } + index_value_ = input_data_addr[0]; + davinci_model->DisableZeroCopy(index_value_); + + std::vector label_idx_list; + if (!AttrUtils::GetListInt(op_desc, ATTR_NAME_LABEL_SWITCH_LIST, label_idx_list)) { + GELOGE(INTERNAL_ERROR, "LabelSwitchByIndexTaskInfo: %s Get attr %s failed.", op_desc->GetName().c_str(), + ATTR_NAME_LABEL_SWITCH_LIST.c_str()); + return INTERNAL_ERROR; + } + + if (label_idx_list.empty() || label_idx_list.size() != branch_max_) { + GELOGE(INTERNAL_ERROR, "LabelSwitchByIndexTaskInfo: %s label index size: %zu, task branch max: %u.", + op_desc->GetName().c_str(), label_idx_list.size(), branch_max_); + return INTERNAL_ERROR; + } + + label_list_.resize(branch_max_, nullptr); + for (size_t idx = 0; idx < label_idx_list.size(); ++idx) { + uint32_t label_id = label_idx_list[idx]; + if (label_id >= label_list.size()) { + GELOGE(INTERNAL_ERROR, "LabelSwitchByIndexTaskInfo: %s index: %zu, label index: %u, model label size: %zu.", + op_desc->GetName().c_str(), idx, label_id, label_list.size()); + return INTERNAL_ERROR; + } + GE_CHECK_NOTNULL(label_list[label_id]); + + label_list_[idx] = label_list[label_id]; + } + + args_size_ = branch_max_ * sizeof(rtLabelDevInfo); + rtError_t rt_ret = rtMalloc(&args_, args_size_, RT_MEMORY_HBM); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); + return RT_FAILED; + } + + rt_ret = rtLabelListCpy(label_list_.data(), label_list_.size(), args_, args_size_); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); + return RT_FAILED; + } + + GELOGI("LabelSwitchByIndexTaskInfo Init success, branch max: %u.", branch_max_); + return SUCCESS; +} + +Status LabelSwitchByIndexTaskInfo::Distribute() { + GELOGI("LabelSwitchByIndexTaskInfo Distribute Start, branch max: %u", branch_max_); + GE_CHECK_NOTNULL(args_); + GE_CHECK_NOTNULL(index_value_); + 
if (branch_max_ == 0 || args_size_ == 0) { + GELOGE(PARAM_INVALID, "branch max: %u, args size: %u invalid.", branch_max_, args_size_); + return PARAM_INVALID; + } + + rtError_t rt_ret = rtLabelSwitchByIndex(index_value_, branch_max_, args_, stream_); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); + return RT_FAILED; + } + + GELOGI("LabelSwitchByIndexTaskInfo Distribute Success."); + return SUCCESS; +} + +REGISTER_TASK_INFO(RT_MODEL_TASK_STREAM_LABEL_SWITCH_BY_INDEX, LabelSwitchByIndexTaskInfo); +} // namespace ge diff --git a/src/ge/graph/load/new_model_manager/task_info/label_switch_by_index_task_info.h b/src/ge/graph/load/new_model_manager/task_info/label_switch_by_index_task_info.h new file mode 100644 index 00000000..1a644736 --- /dev/null +++ b/src/ge/graph/load/new_model_manager/task_info/label_switch_by_index_task_info.h @@ -0,0 +1,42 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_LABEL_SWITCH_BY_INDEX_TASK_INFO_H_ +#define GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_LABEL_SWITCH_BY_INDEX_TASK_INFO_H_ + +#include "graph/load/new_model_manager/task_info/task_info.h" + +namespace ge { +class LabelSwitchByIndexTaskInfo : public TaskInfo { + public: + LabelSwitchByIndexTaskInfo() : index_value_(nullptr), branch_max_(0), args_(nullptr), args_size_(0) {} + + ~LabelSwitchByIndexTaskInfo() override; + + Status Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; + + Status Distribute() override; + + private: + void *index_value_; // switch index input. + uint32_t branch_max_; // max branch count. + void *args_; // label info memory. + uint32_t args_size_; // label info length. + + std::vector label_list_; +}; +} // namespace ge +#endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_LABEL_SWITCH_BY_INDEX_TASK_INFO_H_ \ No newline at end of file diff --git a/src/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc new file mode 100644 index 00000000..e9d99189 --- /dev/null +++ b/src/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc @@ -0,0 +1,151 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.h" + +#include "framework/common/debug/ge_log.h" +#include "graph/load/new_model_manager/davinci_model.h" + +namespace ge { +Status MemcpyAddrAsyncTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) { + GELOGI("MemcpyAddrAsyncTaskInfo Init Start."); + if (davinci_model == nullptr) { + GELOGE(PARAM_INVALID, "davinci_model is null!"); + return PARAM_INVALID; + } + + Status ret = SetStream(task_def.stream_id(), davinci_model->GetStreamList()); + if (ret != SUCCESS) { + return ret; + } + + auto memcpy_async_def = task_def.memcpy_async(); + uint32_t op_index = memcpy_async_def.op_index(); + OpDescPtr op_desc = davinci_model->GetOpByIndex(op_index); + if (op_desc == nullptr) { + GELOGE(INTERNAL_ERROR, "Init MemcpyAddrAsyncTaskInfo error, index is out of range!"); + return INTERNAL_ERROR; + } + + uint64_t logic_dst = memcpy_async_def.dst(); + uint64_t logic_src = memcpy_async_def.src(); + + dst_max_ = memcpy_async_def.dst_max(); + + uint64_t update_base_addr = 0; + ret = GetUpdateBaseAddr(davinci_model, logic_src, update_base_addr); + if (ret != SUCCESS) { + return ret; + } + src_ = reinterpret_cast(update_base_addr + logic_src); + if (src_ == nullptr) { + GELOGE(PARAM_INVALID, "src_ is null!"); + return PARAM_INVALID; + } + + uint64_t mem_base = reinterpret_cast(davinci_model->MemBase()); + uint64_t logic_mem_base = davinci_model->GetRtBaseAddr(); + dst_ = reinterpret_cast(reinterpret_cast(mem_base + (logic_dst - logic_mem_base))); + if (dst_ == nullptr) { + GELOGE(PARAM_INVALID, "dst_ is null!"); + return PARAM_INVALID; + } + + vector io_addrs; + io_addrs.emplace_back(src_); + io_addrs.emplace_back(dst_); + + count_ = memcpy_async_def.count(); + kind_ = memcpy_async_def.kind(); + + // malloc args memory + size_t args_size = sizeof(void *) * io_addrs.size(); + rtError_t rt_ret = rtMalloc(&args_, args_size, RT_MEMORY_HBM); + if (rt_ret != RT_ERROR_NONE) { + 
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); + return RT_FAILED; + } + + // copy orign src/dst + GELOGI("src_args:%p, destMax:%zu, src_:%p, dst_args:%p, dst_:%p, count=%zu", args_, args_size, src_, + static_cast(args_) + args_size, dst_, io_addrs.size()); + rt_ret = rtMemcpy(args_, args_size, io_addrs.data(), args_size, RT_MEMCPY_HOST_TO_DEVICE); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api for src failed, ret: 0x%X", rt_ret); + return RT_FAILED; + } + + // Just dest addr need zero copy. + davinci_model->SetZeroCopyAddr(op_desc, {dst_}, io_addrs.data(), args_, args_size, sizeof(void *)); + + GELOGI("InitMemcpyAddrAsyncTaskInfo, logic_src:%p, logic_dst:%p, src:%p, dst:%p, src_args:%p, dst_args:%p", + reinterpret_cast(reinterpret_cast(logic_src)), + reinterpret_cast(reinterpret_cast(logic_dst)), src_, dst_, args_, + reinterpret_cast(reinterpret_cast(args_) + args_size)); + + return SUCCESS; +} + +Status MemcpyAddrAsyncTaskInfo::Distribute() { + GELOGI("MemcpyAddrAsyncTaskInfo Distribute Start."); + GELOGI("Distribute MemcpyAddrAsync, dst_max:%lu, count:%lu, kind:%u.", dst_max_, count_, kind_); + + rtError_t rt_ret = rtMemcpyAsync(reinterpret_cast(reinterpret_cast(args_) + sizeof(void *)), + dst_max_, args_, count_, static_cast(kind_), stream_); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); + return RT_FAILED; + } + + return SUCCESS; +} + +Status MemcpyAddrAsyncTaskInfo::GetUpdateBaseAddr(DavinciModel *davinci_model, uint64_t update_addr, + uint64_t &base_addr) { + GE_CHECK_NOTNULL(davinci_model); + uint64_t data_base_addr = + static_cast(reinterpret_cast(davinci_model->MemBase())) - davinci_model->GetRtBaseAddr(); + uint64_t weight_base_addr = static_cast(reinterpret_cast(davinci_model->WeightsMemBase())) - + davinci_model->GetRtWeightAddr(); + uint64_t var_base_addr = + static_cast(reinterpret_cast(davinci_model->VarMemBase())) - davinci_model->GetRtVarAddr(); + + uint64_t 
data_base_addr_start = davinci_model->GetRtBaseAddr(); + uint64_t data_base_addr_end = davinci_model->GetRtBaseAddr() + davinci_model->TotalMemSize(); + uint64_t wight_base_addr_start = davinci_model->GetRtWeightAddr(); + uint64_t wight_base_addr_end = davinci_model->GetRtWeightAddr() + davinci_model->TotalWeightsMemSize(); + uint64_t varible_base_addr_start = davinci_model->GetRtVarAddr(); + uint64_t varible_base_addr_end = davinci_model->GetRtVarAddr() + davinci_model->TotalVarMemSize(); + + if ((data_base_addr_start <= update_addr) && (update_addr <= data_base_addr_end)) { + base_addr = data_base_addr; + GELOGI("The update_addr is data address."); + } else if ((wight_base_addr_start <= update_addr) && (update_addr <= wight_base_addr_end)) { + base_addr = weight_base_addr; + GELOGI("The update_addr is weight address."); + } else if ((varible_base_addr_start <= update_addr) && (update_addr <= varible_base_addr_end)) { + base_addr = var_base_addr; + GELOGI("The update_addr is variable address."); + } else if (update_addr != 0) { + base_addr = 0; + GELOGE(PARAM_INVALID, "The update_addr is abnormal."); + return PARAM_INVALID; + } + return SUCCESS; +} + +REGISTER_TASK_INFO(RT_MODEL_TASK_MEMCPY_ADDR_ASYNC, MemcpyAddrAsyncTaskInfo); +} // namespace ge diff --git a/src/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.h b/src/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.h new file mode 100644 index 00000000..9252e43a --- /dev/null +++ b/src/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.h @@ -0,0 +1,55 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_MEMCPY_ADDR_ASYNC_TASK_INFO_H_ +#define GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_MEMCPY_ADDR_ASYNC_TASK_INFO_H_ +#include "graph/load/new_model_manager/task_info/task_info.h" + +namespace ge { +class MemcpyAddrAsyncTaskInfo : public TaskInfo { + public: + MemcpyAddrAsyncTaskInfo() : dst_(nullptr), dst_max_(0), src_(nullptr), args_(nullptr), count_(0), kind_(0) {} + + ~MemcpyAddrAsyncTaskInfo() override { + src_ = nullptr; + dst_ = nullptr; + + if (args_ != nullptr) { + rtError_t ret = rtFree(args_); + if (ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", ret); + } + } + + args_ = nullptr; + } + + Status Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; + + Status Distribute() override; + + private: + Status GetUpdateBaseAddr(DavinciModel *davinci_model, uint64_t update_addr, uint64_t &base_addr); + + void *dst_; + uint64_t dst_max_; + void *src_; + void *args_; + uint64_t count_; + uint32_t kind_; +}; +} // namespace ge +#endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_MEMCPY_ADDR_ASYNC_TASK_INFO_H_ diff --git a/src/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.cc index f2621c52..82eabe69 100644 --- a/src/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.cc @@ -44,6 +44,7 @@ Status MemcpyAsyncTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *da 
return ret; } src_ = reinterpret_cast(update_base_addr + logic_src); + davinci_model->DisableZeroCopy(src_); uint64_t mem_base = reinterpret_cast(davinci_model->MemBase()); uint64_t logic_mem_base = davinci_model->GetRtBaseAddr(); @@ -52,7 +53,8 @@ Status MemcpyAsyncTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *da count_ = memcpy_async_def.count(); kind_ = memcpy_async_def.kind(); GELOGI("MemcpyAsyncTaskInfo Init Success, logic_src:%p, logic_dst:%p, src:%p, dst:%p", - reinterpret_cast(logic_src), reinterpret_cast(logic_dst), src_, dst_); + reinterpret_cast(reinterpret_cast(logic_src)), + reinterpret_cast(reinterpret_cast(logic_dst)), src_, dst_); return SUCCESS; } diff --git a/src/ge/graph/load/new_model_manager/task_info/stream_active_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/stream_active_task_info.cc index 21c80c83..c30cad09 100644 --- a/src/ge/graph/load/new_model_manager/task_info/stream_active_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/stream_active_task_info.cc @@ -41,7 +41,7 @@ Status StreamActiveTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *d uint32_t internal_index = davinci_model->GetFlowctrlIndex(op_index); // get StreamActive op - auto op_desc = davinci_model->GetOpList()[op_index]; + OpDescPtr op_desc = davinci_model->GetOpByIndex(op_index); GE_CHECK_NOTNULL(op_desc); std::vector active_stream_index_list; if (!AttrUtils::GetListInt(op_desc, ATTR_NAME_ACTIVE_STREAM_LIST, active_stream_index_list)) { diff --git a/src/ge/graph/load/new_model_manager/task_info/stream_switch_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/stream_switch_task_info.cc index 4e37ab64..a1d2f143 100644 --- a/src/ge/graph/load/new_model_manager/task_info/stream_switch_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/stream_switch_task_info.cc @@ -44,10 +44,10 @@ Status StreamSwitchTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *d uint32_t op_index = 
stream_switch_def.op_index(); // get StreamSwitch op - auto op_desc = davinci_model->GetOpList()[op_index]; + OpDescPtr op_desc = davinci_model->GetOpByIndex(op_index); GE_CHECK_NOTNULL(op_desc); auto input_data_addr = ModelUtils::GetInputDataAddrs(davinci_model->GetRuntimeParam(), op_desc); - if (!input_data_addr.empty() && input_data_addr.size() >= domi::STREAM_SWITCH_INPUT_NUM) { + if (!input_data_addr.empty() && input_data_addr.size() >= STREAM_SWITCH_INPUT_NUM) { input_ptr_ = input_data_addr[0]; value_ptr_ = input_data_addr[1]; } @@ -60,9 +60,9 @@ Status StreamSwitchTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *d cond_ = static_cast(cond); size_t input_size = op_desc->GetInputsSize(); - if (input_data_addr.size() != domi::STREAM_SWITCH_INPUT_NUM || input_size != domi::STREAM_SWITCH_INPUT_NUM) { - GELOGE(INTERNAL_ERROR, "Input num should be %u inputAddr size:%zu, inputDesc size:%zu.", - domi::STREAM_SWITCH_INPUT_NUM, input_data_addr.size(), input_size); + if (input_data_addr.size() != STREAM_SWITCH_INPUT_NUM || input_size != STREAM_SWITCH_INPUT_NUM) { + GELOGE(INTERNAL_ERROR, "Input num should be %u. 
inputAddr size:%zu, inputDesc size:%zu.", STREAM_SWITCH_INPUT_NUM, + input_data_addr.size(), input_size); return INTERNAL_ERROR; } @@ -86,6 +86,8 @@ Status StreamSwitchTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *d true_stream_ = davinci_model->GetStreamList()[true_stream_index]; true_stream_id_ = stream_switch_def.true_stream_id(); + davinci_model->DisableZeroCopy(input_ptr_); + davinci_model->DisableZeroCopy(value_ptr_); if (op_desc->HasAttr(ATTR_NAME_SWITCH_DATA_TYPE)) { int64_t data_type = 0; diff --git a/src/ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.cc index f4f62df0..29b107bd 100644 --- a/src/ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.cc @@ -40,6 +40,12 @@ Status StreamSwitchNTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel * } auto stream_switchn_def = task_def.stream_switch_n(); + OpDescPtr op_desc = davinci_model->GetOpByIndex(stream_switchn_def.op_index()); + if (op_desc == nullptr) { + GELOGE(FAILED, "Index is out of range, index: %u", stream_switchn_def.op_index()); + return FAILED; + } + // set size_ input_size_ = stream_switchn_def.size(); if (input_size_ != kDynamicBtachParamNum && input_size_ != kDynamicResolutionParamNum) { @@ -59,21 +65,6 @@ Status StreamSwitchNTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel * } value_ptr_ = &value_list_[0]; - uint32_t op_index = stream_switchn_def.op_index(); - - // get StreamSwitchN op - auto op_list = davinci_model->GetOpList(); - auto iter = op_list.find(op_index); - if (iter == op_list.end()) { - GELOGE(FAILED, "Index is out of range, index: %u", op_index); - return FAILED; - } - OpDescPtr op_desc = iter->second; - if (op_desc == nullptr) { - GELOGE(FAILED, "SwitchN op is nullptr."); - return FAILED; - } - // set element_size_ if (!AttrUtils::GetInt(op_desc, 
ATTR_NAME_BATCH_NUM, element_size_)) { GELOGE(FAILED, "Get ATTR_NAME_BATCH_NUM of switchN op failed."); @@ -92,6 +83,7 @@ Status StreamSwitchNTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel * return FAILED; } input_ptr_ = input_data_addr[0]; + davinci_model->DisableZeroCopy(input_ptr_); GELOGI("StreamSwitchNTaskInfo Init Success, inputSize:%u, elementSize:%d, trueStreamID:%ld.", input_size_, element_size_, op_desc->GetStreamId()); diff --git a/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc b/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc index 38dbd8b3..b8fc77ac 100644 --- a/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc +++ b/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc @@ -19,17 +19,17 @@ namespace ge { namespace skt { -Status SuperKernel::Launch(rtStream_t stream, bool dump_flag) { +Status SuperKernel::Launch(rtStream_t stream, uint32_t dump_flag) { const void *func_stub_ = this->GetFuncStub(); - const void *args[] = {this->GetNavTablePtr(), (const void *)this->GetNavTableSize()}; + const void *args[] = {this->GetNavTablePtr(), + reinterpret_cast(reinterpret_cast(this->GetNavTableSize()))}; - void *device_args_addr = nullptr; - rtError_t rt_ret = rtMalloc((void **)&(device_args_addr), sizeof(args), RT_MEMORY_HBM); + rtError_t rt_ret = rtMalloc((void **)&(device_args_addr_), sizeof(args), RT_MEMORY_HBM); GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMalloc failied. error: 0x%X", rt_ret); return FAILED;) - rt_ret = rtMemcpy((void *)device_args_addr, sizeof(args), (void *)args, sizeof(args), RT_MEMCPY_HOST_TO_DEVICE); + rt_ret = rtMemcpy((void *)device_args_addr_, sizeof(args), (void *)args, sizeof(args), RT_MEMCPY_HOST_TO_DEVICE); GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMemcpy failied. 
error: 0x%X", rt_ret); return FAILED;) - rt_ret = rtKernelLaunchWithFlag((void *const)func_stub_, block_dim_, device_args_addr, sizeof(args), NULL, stream, + rt_ret = rtKernelLaunchWithFlag((void *const)func_stub_, block_dim_, device_args_addr_, sizeof(args), NULL, stream, dump_flag); GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtKernelLaunchWithFlag failied. error: 0x%X", rt_ret); return FAILED;) diff --git a/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.h b/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.h index b662d97b..1c31acd1 100644 --- a/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.h +++ b/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.h @@ -25,6 +25,7 @@ namespace ge { namespace skt { class SuperKernel { private: + void *device_args_addr_ = nullptr; const void *func_stub_; void *dev_nav_table_; uint64_t nav_table_size_; @@ -33,8 +34,18 @@ class SuperKernel { public: SuperKernel(const void *stub, void *ptr, uint64_t sz, uint32_t dim) : func_stub_(stub), dev_nav_table_(ptr), nav_table_size_(sz), block_dim_(dim) {} - ~SuperKernel() {} - Status Launch(rtStream_t stream, bool dump_flag); + ~SuperKernel() { + // free memory when all releasing + if (device_args_addr_ != nullptr) { + GE_CHK_RT(rtFree(device_args_addr_)); + GELOGI("SKT: super_kernel args addr free."); + } + if (dev_nav_table_ != nullptr) { + GE_CHK_RT(rtFree(dev_nav_table_)); + GELOGI("SKT: super_kernel args addr free."); + } + } + Status Launch(rtStream_t stream, uint32_t dump_flag); const void *GetFuncStub() const { return func_stub_; } const void *GetNavTablePtr() const { return dev_nav_table_; } uint64_t GetNavTableSize() const { return nav_table_size_; } diff --git a/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc b/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc index ab3f68f1..4c430ff9 100644 --- 
a/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc +++ b/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc @@ -26,30 +26,36 @@ SuperKernelFactory &SuperKernelFactory::GetInstance() { Status SuperKernelFactory::Init() { if (!is_init_) { + std::string skt_bin = "libcce_aicore.so"; + handle_ = dlopen(skt_bin.c_str(), RTLD_NOW | RTLD_GLOBAL); + if (handle_ == nullptr) { + GELOGE(FAILED, "SKT: open skt lib failed, please check LD_LIBRARY_PATH."); + } rtError_t rt_ret; rt_ret = rtGetFunctionByName(this->sk_stub_name_.c_str(), &this->func_stub_); GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtGetFunctionByName " - "failied. stub_func: %s", + "failed. stub_func: %s, please export LD_LIBRARY_PATH for " + "libcce_aicore.so", this->sk_stub_name_.c_str()); return FAILED;) rt_ret = rtGetAddrByFun(this->func_stub_, &this->func_ptr_); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtGetAddrByFun failied. error: 0x%X", rt_ret); + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtGetAddrByFun failed. error: 0x%X", rt_ret); return FAILED;) if (this->use_physical_address_ != nullptr) { void *skt_func = nullptr; rt_ret = rtKernelConfigTransArg(this->func_ptr_, sizeof(uint64_t), 0, &skt_func); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtKernelConfigTransArg failied. error: 0x%X", rt_ret); + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtKernelConfigTransArg failed. 
error: 0x%X", rt_ret); return FAILED;) GELOGD( "SKT: fuseKernels super_kernel_template subFunc %p, device func " "address %p, device physic PC %p", - (uint64_t)this->func_stub_, (uint64_t)this->func_ptr_, (uint64_t)skt_func); + this->func_stub_, this->func_ptr_, skt_func); } else { GELOGD( "SKT: fuseKernels super_kernel_template subFunc %p, device func " "address %p", - (uint64_t)this->func_stub_, (uint64_t)this->func_ptr_); + this->func_stub_, this->func_ptr_); } } is_init_ = true; @@ -94,63 +100,66 @@ Status SuperKernelFactory::FuseKernels(const std::vector &stub_func_list uint64_t nav_table_size = 2 * stub_func_list.size() * sizeof(int64_t); rtError_t rt_ret; + void *hbm_nav_table_addr = nullptr; if (this->use_physical_address_ != nullptr) { for (unsigned i = 0; i < stub_func_list.size(); i++) { void *sub_device_func = nullptr; rt_ret = rtGetAddrByFun(stub_func_list[i], &sub_device_func); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtGetAddrByFun failied. error: 0x%X", rt_ret); + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtGetAddrByFun failed. error: 0x%X", rt_ret); return FAILED;) void *sub_device_func_pys = nullptr; void *args_addr_pys = nullptr; rt_ret = rtKernelConfigTransArg(sub_device_func, sizeof(uint64_t), 0, &sub_device_func_pys); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtKernelConfigTransArg failied. error: 0x%X", rt_ret); + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtKernelConfigTransArg failed. error: 0x%X", rt_ret); return FAILED;) rt_ret = rtKernelConfigTransArg(args_addr_list[i], sizeof(uint64_t), 0, &args_addr_pys); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtKernelConfigTransArg failied. error: 0x%X", rt_ret); + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtKernelConfigTransArg failed. 
error: 0x%X", rt_ret); return FAILED;) GELOGD( "SKT: fuseKernels subFunc %p, device func address %p, device " "physic func address %p", - stub_func_list[i], (uint64_t)sub_device_func, (uint64_t)sub_device_func_pys); - nav_table[i * 2] = (uint64_t)sub_device_func_pys / 4; - GELOGD("SKT: CALL offet %p", nav_table[i * 2]); - nav_table[i * 2 + 1] = (uint64_t)args_addr_pys; - GELOGD("SKT: fuseKernels args base address %p", nav_table[i * 2 + 1]); + stub_func_list[i], sub_device_func, sub_device_func_pys); + // store two uint64_t address + // address divided by 4 because of 32bits encoding, call offset will *4 when calculating + nav_table[i * 2] = reinterpret_cast(reinterpret_cast(sub_device_func_pys)) / 4; + GELOGD("SKT: CALL offset %lu", nav_table[i * 2]); + nav_table[i * 2 + 1] = reinterpret_cast(reinterpret_cast(args_addr_pys)); + + GELOGD("SKT: fuseKernels args base address %lu", nav_table[i * 2 + 1]); } - void *hbm_nav_table_addr = nullptr; void *hbm_nav_table_addr_pys = nullptr; rt_ret = rtMalloc((void **)&hbm_nav_table_addr, nav_table_size, RT_MEMORY_HBM); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMalloc failied. error: 0x%X", rt_ret); return FAILED;) + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMalloc failed. error: 0x%X", rt_ret); return FAILED;) rt_ret = rtMemcpy((void *)hbm_nav_table_addr, nav_table_size, (void *)nav_table, nav_table_size, RT_MEMCPY_HOST_TO_DEVICE); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMemcpy failied. error: 0x%X", rt_ret); return FAILED;) + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMemcpy failed. error: 0x%X", rt_ret); return FAILED;) rt_ret = rtKernelConfigTransArg(hbm_nav_table_addr, sizeof(uint64_t), 0, &hbm_nav_table_addr_pys); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtKernelConfigTransArg failied. error: 0x%X", rt_ret); + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtKernelConfigTransArg failed. 
error: 0x%X", rt_ret); return FAILED;) - GELOGD("SKT: hbm_nav_table_addr %p, hbm_nav_table_addr_pys %p", (uint64_t)hbm_nav_table_addr, - (uint64_t)hbm_nav_table_addr_pys); + GELOGD("SKT: hbm_nav_table_addr %p, hbm_nav_table_addr_pys %p", hbm_nav_table_addr, hbm_nav_table_addr_pys); // Create the necessary metadata for the super kernel h = new SuperKernel(this->func_stub_, hbm_nav_table_addr_pys, nav_table_size, block_dim); } else { for (unsigned i = 0; i < stub_func_list.size(); i++) { void *sub_device_func = nullptr; rt_ret = rtGetAddrByFun(stub_func_list[i], &sub_device_func); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtGetAddrByFun failied. error: 0x%X", rt_ret); + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtGetAddrByFun failed. error: 0x%X", rt_ret); return FAILED;) - GELOGD("SKT: fuseKernels subFunc %p, device func address %p", stub_func_list[i], (uint64_t)sub_device_func); - nav_table[i * 2] = (uint64_t)sub_device_func / 4; - GELOGD("SKT: CALL offet %p", nav_table[i * 2]); - nav_table[i * 2 + 1] = (uint64_t)args_addr_list[i]; - GELOGD("SKT: fuseKernels args base address %p", nav_table[i * 2 + 1]); + GELOGD("SKT: fuseKernels subFunc %p, device func address %p", stub_func_list[i], sub_device_func); + // store two uint64_t address + // address divided by 4 because of 32bits encoding, call offset will *4 when calculating + nav_table[i * 2] = reinterpret_cast(reinterpret_cast(sub_device_func)) / 4; + GELOGD("SKT: CALL offet %lu", nav_table[i * 2]); + nav_table[i * 2 + 1] = reinterpret_cast(reinterpret_cast(args_addr_list[i])); + GELOGD("SKT: fuseKernels args base address %lu", nav_table[i * 2 + 1]); } - void *hbm_nav_table_addr = nullptr; rt_ret = rtMalloc((void **)&hbm_nav_table_addr, nav_table_size, RT_MEMORY_HBM); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMalloc failied. error: 0x%X", rt_ret); return FAILED;) + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMalloc failed. 
error: 0x%X", rt_ret); return FAILED;) rt_ret = rtMemcpy((void *)hbm_nav_table_addr, nav_table_size, (void *)nav_table, nav_table_size, RT_MEMCPY_HOST_TO_DEVICE); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMemcpy failied. error: 0x%X", rt_ret); return FAILED;) + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMemcpy failed. error: 0x%X", rt_ret); return FAILED;) // Create the necessary metadata for the super kernel h = new SuperKernel(this->func_stub_, hbm_nav_table_addr, nav_table_size, block_dim); } diff --git a/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.h b/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.h index 7b59d4bf..d8b7ff26 100644 --- a/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.h +++ b/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.h @@ -27,16 +27,24 @@ class SuperKernelFactory { private: void *func_stub_ = nullptr; void *func_ptr_ = nullptr; + void *handle_ = nullptr; std::string sk_stub_name_ = "_Z21super_kernel_templatePmm"; const char *use_physical_address_ = getenv("GE_USE_PHYSICAL_ADDRESS"); bool is_init_ = false; SuperKernelFactory(){}; + ~SuperKernelFactory() { + if (handle_ != nullptr) { + GELOGI("SKT: SKT LIB PATH release."); + if (dlclose(handle_) != 0) { + GELOGW("failed to close handle, message: %s", dlerror()); + } + } + }; public: SuperKernelFactory(SuperKernelFactory const &) = delete; void operator=(SuperKernelFactory const &) = delete; static SuperKernelFactory &GetInstance(); - SuperKernelFactory(const std::string &sk_stub_name_, const std::string &bin_file); Status Init(); Status Uninitialize(); Status FuseKernels(const std::vector &stub_func_list, const std::vector &args_addr_list, diff --git a/src/ge/graph/load/new_model_manager/task_info/task_info_factory.h b/src/ge/graph/load/new_model_manager/task_info/task_info_factory.h index 86fda23e..b6954016 100644 --- 
a/src/ge/graph/load/new_model_manager/task_info/task_info_factory.h +++ b/src/ge/graph/load/new_model_manager/task_info/task_info_factory.h @@ -86,6 +86,5 @@ class TaskInfoFactory { return ptr; \ } \ TaskInfoFactory::Registerar g_##type##_Task_Info_Creator(type, Creator_##type##_Task_Info); -}; // namespace ge +}; // namespace ge #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_TASK_INFO_FACTORY_H_ - diff --git a/src/ge/graph/load/new_model_manager/zero_copy_task.cc b/src/ge/graph/load/new_model_manager/zero_copy_task.cc new file mode 100644 index 00000000..42734a87 --- /dev/null +++ b/src/ge/graph/load/new_model_manager/zero_copy_task.cc @@ -0,0 +1,179 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/load/new_model_manager/zero_copy_task.h" + +#include "graph/load/new_model_manager/model_utils.h" +#include "framework/common/debug/ge_log.h" +#include "framework/common/util.h" + +namespace ge { +const char *const kDefaultBatchLable = "Batch_default"; + +ZeroCopyTask::ZeroCopyTask(const string &name, uint8_t *args, size_t size) + : name_(name), args_addr_(args), args_size_(size), is_updated_(false) {} + +ZeroCopyTask::~ZeroCopyTask() { args_addr_ = nullptr; } + +/** + * @ingroup ge + * @brief Set Task zero copy addr info. + * @param [in] addr: task addr value. + * @param [in] offset: saved offset in task args. 
+ * @return: 0 SUCCESS / others FAILED + */ +Status ZeroCopyTask::SetTaskArgsOffset(uintptr_t addr, size_t offset) { + if (offset + sizeof(uintptr_t) > args_size_) { + GELOGE(FAILED, "[ZCPY] %s set task args failed, args size: %zu, offset: %zu", name_.c_str(), args_size_, offset); + return FAILED; // unexpected error, need fix. + } + + auto it = task_addr_offset_.find(addr); + if (it == task_addr_offset_.end()) { + task_addr_offset_[addr] = {offset}; + } else { + it->second.push_back(offset); + } + + GELOGI("[ZCPY] %s set task, addr: 0x%lx, args: %p, size: %zu, offset: %zu", name_.c_str(), addr, args_addr_, + args_size_, offset); + return SUCCESS; +} + +/** + * @ingroup ge + * @brief Save orignal data of task args. + * @param [in] info: task args orignal data. + * @param [in] size: args size. + * @return: void + */ +void ZeroCopyTask::SetOriginalArgs(const void *info, size_t size) { + GE_CHECK_NOTNULL_JUST_RETURN(info); + const uint8_t *data = static_cast(info); + args_info_.assign(data, data + size); + + GELOGI("[ZCPY] %s set info, args: %p, args size: %zu, info size: %zu", name_.c_str(), args_addr_, args_size_, size); +} + +/** + * @ingroup ge + * @brief Check is dynamic batch node. + * @param [in] addr: virtual address value from Op. + * @param [in] data: data buffer from user. + * @param [in] batch_addrs: dynamic batch addr info. + * @param [in] batch_label: batch label. 
+ * @return: true / false + */ +bool ZeroCopyTask::CheckDynamicBatch(const map> &batch_addrs, const string &batch_label, + uintptr_t addr) { + // Used for dynamic batch / resolution scene + set dynamic_input_addrs; + auto dynamic_input_iter = batch_addrs.find(batch_label); + if (dynamic_input_iter != batch_addrs.end()) { + dynamic_input_addrs = dynamic_input_iter->second; + } + + set fix_input_addrs; + auto fix_input_iter = batch_addrs.find(kDefaultBatchLable); + if (fix_input_iter != batch_addrs.end()) { + fix_input_addrs = fix_input_iter->second; + } + + if (fix_input_addrs.empty()) { + if (!dynamic_input_addrs.empty() && dynamic_input_addrs.find(addr) == dynamic_input_addrs.end()) { + return false; + } + } else { + if (!dynamic_input_addrs.empty() && dynamic_input_addrs.find(addr) == dynamic_input_addrs.end() && + fix_input_addrs.find(addr) == fix_input_addrs.end()) { + return false; + } + } + + return true; +} + +/** + * @ingroup ge + * @brief Set user data addr to Task param. + * @param [in] addr: virtual address value from Op. + * @param [in] data: data buffer from user. + * @param [in] batch_addrs: dynamic batch addr info. + * @param [in] batch_label: batch label. 
+ * @return: void + */ +Status ZeroCopyTask::UpdateTaskParam(uintptr_t addr, const DataBuffer &data, + const map> &batch_addrs, const string &batch_label) { + for (auto pair : task_addr_offset_) { + if (pair.first != addr) { + continue; + } + + uint8_t *args_info = args_info_.data(); + for (auto offset : pair.second) { + if (!CheckDynamicBatch(batch_addrs, batch_label, reinterpret_cast(args_addr_ + offset))) { + continue; + } + + auto dst_addr = static_cast(data.data); + auto dst_size = static_cast(data.length); + if (ModelUtils::ConvertVirtualAddressToPhysical(dst_addr, dst_size, dst_addr) != SUCCESS) { + GELOGE(FAILED, "[ZCPY] Convert virtual address to physical for dst_addr failed."); + return FAILED; + } + + GELOGI("[ZCPY] %s update task, args: %p, size: %zu, offset: %zu, addr: 0x%lx, length: %u", name_.c_str(), + args_addr_, args_size_, offset, addr, data.length); + *(uintptr_t *)(args_info + offset) = reinterpret_cast(dst_addr); + is_updated_ = true; + } + } + + return SUCCESS; +} + +/** + * @ingroup ge + * @brief Update task param to device. + * @param [in] stream: Stream for asychronous update. 
+ * @return: 0 SUCCESS / others FAILED + */ +Status ZeroCopyTask::DistributeParam(rtStream_t stream) { + if (!is_updated_) { + return SUCCESS; + } + + is_updated_ = false; + GE_CHECK_NOTNULL(args_addr_); + rtError_t rt_err = RT_ERROR_NONE; + if (stream != nullptr) { + rt_err = + rtMemcpyAsync(args_addr_, args_size_, args_info_.data(), args_info_.size(), RT_MEMCPY_HOST_TO_DEVICE_EX, stream); + } else { + __builtin_prefetch(args_addr_); + rt_err = rtMemcpy(args_addr_, args_size_, args_info_.data(), args_info_.size(), RT_MEMCPY_HOST_TO_DEVICE); + } + + if (rt_err != RT_ERROR_NONE) { + GELOGE(FAILED, "[ZCPY] %s distribute task param failed, error=0x%x", name_.c_str(), rt_err); + return FAILED; + } + + GELOGI("[ZCPY] %s refresh task args success, args: %p, size: %zu, args_info_: %p, length: %zu", name_.c_str(), + args_addr_, args_size_, args_info_.data(), args_info_.size()); + return SUCCESS; +} +} // namespace ge diff --git a/src/ge/graph/load/new_model_manager/zero_copy_task.h b/src/ge/graph/load/new_model_manager/zero_copy_task.h new file mode 100644 index 00000000..9d3f5b03 --- /dev/null +++ b/src/ge/graph/load/new_model_manager/zero_copy_task.h @@ -0,0 +1,100 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef GE_GRAPH_LOAD_NEW_MODEL_MANAGER_ZERO_COPY_TASK_H_ +#define GE_GRAPH_LOAD_NEW_MODEL_MANAGER_ZERO_COPY_TASK_H_ + +#include +#include +#include +#include + +#include "external/ge/ge_api_error_codes.h" +#include "framework/common/ge_types.h" +#include "runtime/mem.h" + +using std::map; +using std::set; +using std::string; +using std::vector; + +namespace ge { +class ZeroCopyTask { + public: + ZeroCopyTask(const string &name, uint8_t *args, size_t size); + ~ZeroCopyTask(); + + /** + * @ingroup ge + * @brief Set Task zero copy addr info. + * @param [in] addr: task addr value. + * @param [in] offset: saved offset in task args. + * @return: 0 SUCCESS / others FAILED + */ + ge::Status SetTaskArgsOffset(uintptr_t addr, size_t offset); + + /** + * @ingroup ge + * @brief Is need zero copy. + * @return: true / false + */ + bool IsTaskArgsSet() const { return !task_addr_offset_.empty(); } + + /** + * @ingroup ge + * @brief Save orignal data of task args. + * @param [in] info: task args orignal data. + * @param [in] size: args size. + * @return: void + */ + void SetOriginalArgs(const void *info, size_t size); + + /** + * @ingroup ge + * @brief Set user data addr to Task param. + * @param [in] addr: virtual address value from Op. + * @param [in] data: data buffer from user. + * @param [in] batch_addrs: dynamic batch addr info. + * @param [in] batch_label: batch label. + * @return: 0 SUCCESS / others FAILED + */ + ge::Status UpdateTaskParam(uintptr_t addr, const DataBuffer &data, const map> &batch_addrs, + const string &batch_label); + + /** + * @ingroup ge + * @brief Update task param to device. + * @param [in] stream: Stream for asychronous update. 
+ * @return: 0 SUCCESS / others FAILED + */ + ge::Status DistributeParam(rtStream_t stream); + + protected: + bool CheckDynamicBatch(const map> &batch_addrs, const string &batch_label, uintptr_t addr); + + private: + const string name_; + + uint8_t *args_addr_; + const size_t args_size_; + vector args_info_; + bool is_updated_; + + //
+ map> task_addr_offset_; +}; +} // namespace ge +#endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_ZERO_COPY_TASK_H_ \ No newline at end of file diff --git a/src/ge/graph/load/output/output.h b/src/ge/graph/load/output/output.h index 4a3b0db2..d93b8de9 100644 --- a/src/ge/graph/load/output/output.h +++ b/src/ge/graph/load/output/output.h @@ -21,15 +21,14 @@ #include #include "common/debug/log.h" -#include "common/op/attr_define.h" #include "common/op/attr_value_util.h" #include "common/op/ge_op_utils.h" -#include "common/op/op_parser_util.h" #include "common/types.h" #include "common/util.h" #include "common/ge_types.h" #include "graph/load/new_model_manager/davinci_model.h" #include "graph/op_desc.h" +#include "graph/debug/ge_attr_define.h" namespace ge { using std::string; diff --git a/src/ge/graph/manager/graph_manager.cc b/src/ge/graph/manager/graph_manager.cc index 765b2302..514a90ce 100644 --- a/src/ge/graph/manager/graph_manager.cc +++ b/src/ge/graph/manager/graph_manager.cc @@ -33,45 +33,51 @@ #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" #include "framework/common/ge_types.h" +#include "graph/common/ge_call_wrapper.h" #include "graph/common/transop_util.h" #include "graph/debug/ge_attr_define.h" #include "graph/ge_context.h" #include "graph/ge_global_options.h" #include "graph/ge_local_context.h" #include "graph/manager/graph_mem_allocator.h" +#include "graph/manager/util/rt_context_util.h" +#include "graph/passes/addn_pass.h" #include "graph/passes/atomic_addr_clean_pass.h" +#include "graph/passes/cast_remove_pass.h" +#include "graph/passes/common_subexpression_elimination_pass.h" #include "graph/passes/compile_nodes_pass.h" #include "graph/passes/constant_folding_pass.h" -#include "graph/passes/control_op_attr_pass.h" +#include "graph/passes/constant_fuse_same_pass.h" #include "graph/passes/dimension_adjust_pass.h" +#include "graph/passes/flow_ctrl_pass.h" +#include "graph/passes/hccl_memcpy_pass.h" #include 
"graph/passes/identify_reference_pass.h" +#include "graph/passes/iterator_op_pass.h" #include "graph/passes/link_gen_mask_nodes_pass.h" +#include "graph/passes/merge_pass.h" #include "graph/passes/multi_batch_pass.h" #include "graph/passes/permute_pass.h" +#include "graph/passes/prune_pass.h" +#include "graph/passes/replace_with_empty_const_pass.h" #include "graph/passes/reshape_remove_pass.h" #include "graph/passes/same_transdata_breadth_fusion_pass.h" +#include "graph/passes/subgraph_pass.h" +#include "graph/passes/switch_logic_remove_pass.h" +#include "graph/passes/switch_pass.h" #include "graph/passes/transop_breadth_fusion_pass.h" #include "graph/passes/transop_depth_fusion_pass.h" #include "graph/passes/transop_nearby_allreduce_fusion_pass.h" +#include "graph/passes/transop_symmetry_elimination_pass.h" #include "graph/passes/transop_without_reshape_fusion_pass.h" -#include "graph/passes/cast_remove_pass.h" #include "graph/passes/transpose_transdata_pass.h" #include "graph/passes/variable_op_pass.h" #include "graph/passes/variable_prepare_op_pass.h" #include "graph/passes/variable_ref_delete_op_pass.h" +#include "graph/partition/dynamic_shape_partition.h" #include "graph/utils/tensor_adapter.h" #include "inc/pass_manager.h" #include "init/gelib.h" -using domi::ASSIGN; -using domi::ATTR_MODEL_MEMORY_SIZE; -using domi::ATTR_MODEL_WEIGHT_SIZE; -using domi::ATTR_NAME_SESSION_GRAPH_ID; -using domi::CONSTANT; -using domi::CONSTANTOP; -using domi::HCOMBROADCAST; -using domi::VARIABLE; - namespace { const char *const kSummary = "Summary"; const char *const kSave = "Save"; @@ -125,6 +131,7 @@ Status GraphManager::Initialize(const std::map &options) { } graph_map_.clear(); + cache_helper_map_.clear(); init_flag_ = true; thread_run_flag_ = true; @@ -188,6 +195,7 @@ Status GraphManager::Finalize() { } } graph_map_.clear(); + cache_helper_map_.clear(); // graph context if (graph_context_ != nullptr) { @@ -222,6 +230,9 @@ Status GraphManager::AddGraph(const GraphId 
&graph_id, const Graph &graph, if (!AttrUtils::SetStr(*compute_graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id)) { GELOGW("Set attribute of compute graph failed."); } + for (auto &subgraph : compute_graph->GetAllSubgraphs()) { + (void)AttrUtils::SetStr(*subgraph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id); + } GELOGW("Get graph session_graph_id attr failed, set session id to default value: [0]"); } @@ -312,6 +323,50 @@ Status GraphManager::SetSubgraph(uint64_t session_id, ComputeGraphPtr compute_gr return SUCCESS; } +#define GM_RUN_AND_DUMP(name, func, ...) \ + do { \ + GE_RUN(GraphManager, func, __VA_ARGS__); \ + GraphUtils::DumpGEGraph(compute_graph, "PreRunAfter" name); \ + GraphUtils::DumpGEGraphToOnnx(*compute_graph, "PreRunAfter" name); \ + GELOGI("Run %s on graph %s(%u) success.", name, compute_graph->GetName().c_str(), graph_node->GetGraphId()); \ + } while (0) +Status GraphManager::PreRunDynShape(const GraphNodePtr &graph_node, const std::vector &inputs, + vector &ge_models, GeModelPtr &ge_model, uint64_t session_id) { + GE_CHECK_NOTNULL(graph_node); + GE_CHECK_NOTNULL(graph_node->GetGraph()); + auto compute_graph = GraphUtils::GetComputeGraph(*graph_node->GetGraph()); + GE_CHECK_NOTNULL(compute_graph); + + GEEVENT("PreRun start, graph node size %zu, session id %lu, graph id %u, graph name %s", + compute_graph->GetDirectNodesSize(), session_id, compute_graph->GetGraphID(), + compute_graph->GetName().c_str()); + GraphUtils::DumpGEGraph(compute_graph, "PreRunBegin"); + GraphUtils::DumpGEGraphToOnnx(*compute_graph, "PreRunBegin"); + + GM_RUN_AND_DUMP("OptimizeGraphPrepare", graph_optimize_.OptimizeOriginalGraphForQuantize, compute_graph); + GM_RUN_AND_DUMP("HandleSummaryOp", graph_optimize_.HandleSummaryOp, compute_graph); + GM_RUN_AND_DUMP("Prepare", graph_preparer_.PrepareDynShape, graph_node->GetGraph(), inputs, compute_graph, + session_id); + // original graph optimization and running format inference + GM_RUN_AND_DUMP("OptimizeOriginalGraph", 
graph_optimize_.OptimizeOriginalGraph, compute_graph); + GM_RUN_AND_DUMP("Optimize1", OptimizeStage1, compute_graph); + GM_RUN_AND_DUMP("InferShape2", compute_graph->InferShapeInNeed); + GM_RUN_AND_DUMP("OptimizeSubgraph", OptimizeSubgraph, graph_node, compute_graph, session_id); + GM_RUN_AND_DUMP("Optimize2", OptimizeStage2, compute_graph); + GM_RUN_AND_DUMP("Build", Build, graph_node, compute_graph, ge_models, ge_model, session_id); + + // when set incre build, save om model and var manager + auto save_ret = SaveCacheAfterBuild(graph_node->GetGraphId(), compute_graph, ge_model); + if (save_ret != SUCCESS) { + GELOGW("Fail to save cache."); + } + // release rts generate context + RtContextUtil::GetInstance().DestroyrtContexts(); + GEEVENT("[GEPERFTRACE] GE PreRun End"); + return SUCCESS; +} +#undef RUN_AND_DUMP + Status GraphManager::PreRun(const GraphNodePtr &graph_node, const std::vector &inputs, vector &ge_models, GeModelPtr &ge_model, uint64_t session_id) { GELOGI("Ready For PreRun Start session_id = %lu.", session_id); @@ -323,6 +378,10 @@ Status GraphManager::PreRun(const GraphNodePtr &graph_node, const std::vectorGetAllSubgraphs()) { + GraphUtils::DumpGEGraph(graph, "BeforeSummaryHandle"); + GraphUtils::DumpGEGraphToOnnx(*graph, "BeforeSummaryHandleSubgraph"); + } GEEVENT("PreRun start, graph node size is %zu", compute_graph->GetDirectNodesSize()); // optimize the summary op in graph: store the summary name and replace the summary ops with net_output op. 
GE_TIMESTAMP_START(HandleSummaryOp); @@ -330,7 +389,7 @@ Status GraphManager::PreRun(const GraphNodePtr &graph_node, const std::vectorGetGraph(), inputs, compute_graph, session_id); + ret = graph_preparer_.Prepare(graph_node->GetGraph(), inputs, compute_graph, var_acc_ctrl_, session_id); if (ret != SUCCESS) { GELOGE(ret, "ATC RunGraph input compute graph is NULL"); return ret; @@ -339,6 +398,9 @@ Status GraphManager::PreRun(const GraphNodePtr &graph_node, const std::vectorSetSessionID(session_id); GraphUtils::DumpGEGraph(compute_graph, "OptimizeOriginalGraphAfter"); GraphUtils::DumpGEGraphToOnnx(*compute_graph, "OptimizeOriginalGraphAfter"); + for (auto graph : compute_graph->GetAllSubgraphs()) { + GraphUtils::DumpGEGraphToOnnx(*graph, "OptimizeOriginalGraphAfterSubgraph"); + } GE_TIMESTAMP_START(InferShape); // Origin graph infershape @@ -346,94 +408,43 @@ Status GraphManager::PreRun(const GraphNodePtr &graph_node, const std::vector merged_sub_graph_list; - - GE_TIMESTAMP_START(MergeSubgraph); - ret = MergeSubGraph(merged_compute_graph, compute_graph); + ret = OptimizeSubgraph(graph_node, compute_graph, session_id); if (ret != SUCCESS) { - GELOGE(ret, "Merge SubGraph Failed"); return ret; } - merged_compute_graph->SetSessionID(session_id); - merged_compute_graph->SetGraphID(graph_node->GetGraphId()); - GraphUtils::DumpGEGraph(merged_compute_graph, "mergedComputeGraph"); - GraphUtils::DumpGEGraphToOnnx(*merged_compute_graph, "mergedComputeGraph"); - for (auto &sub_graph : merged_compute_graph->GetAllSubgraphs()) { - string subgraph_name = "mergedComputeGraph" + sub_graph->GetName(); - sub_graph->SetSessionID(session_id); - sub_graph->SetGraphID(graph_node->GetGraphId()); - GraphUtils::DumpGEGraph(merged_compute_graph, subgraph_name); - GraphUtils::DumpGEGraphToOnnx(*merged_compute_graph, subgraph_name); - } - GE_TIMESTAMP_END(MergeSubgraph, "GraphManager::MergeSubGraph"); - std::shared_ptr instance_ge = ge::GELib::GetInstance(); if (instance_ge != nullptr && 
instance_ge->InitFlag()) { // optimize after merge subgraph GE_TIMESTAMP_START(OptimizeAfterMergeSubgraph); - ret = OptimizeAfterMergeSubGraph(merged_compute_graph); + const char *buffer_optimize_on = std::getenv("BUFFER_OPTIMIZE_ON"); + if (buffer_optimize_on != nullptr) { + ret = NewOptimizeAfterMergeSubGraph(compute_graph); + } else { + ret = OptimizeAfterMergeSubGraph(compute_graph); + } if (ret != SUCCESS) { GELOGE(ret, "Optimize after merge subgraph failed."); return ret; } GE_TIMESTAMP_END(OptimizeAfterMergeSubgraph, "GraphManager::OptimizeAfterMergeSubGraph"); } - GraphUtils::DumpGEGraph(merged_compute_graph, "OptimizeMergeSubGraphAfter"); - GraphUtils::DumpGEGraphToOnnx(*merged_compute_graph, "OptimizeMergeSubGraphAfter"); - // build - if (merged_compute_graph != nullptr) { - std::string graph_name = merged_compute_graph->GetName(); - graph_name.append("_"); - graph_name.append(std::to_string(graph_node->GetGraphId())); - merged_compute_graph->SetName(graph_name); - } - std::vector sub_graph_list; - ret = graph_builder_.Build(merged_compute_graph, sub_graph_list, ge_model, session_id); + GraphUtils::DumpGEGraph(compute_graph, "OptimizeMergeSubGraphAfter"); + GraphUtils::DumpGEGraphToOnnx(*compute_graph, "OptimizeMergeSubGraphAfter"); + + ret = Build(graph_node, compute_graph, ge_models, ge_model, session_id); if (ret != SUCCESS) { - GELOGE(ret, "SubGraph build Failed."); return ret; } - bool is_always_dump = false; - PropertiesManager &properties_manager = PropertiesManager::Instance(); - if (!properties_manager.GetDumpOutputPath().empty()) { - is_always_dump = true; - } - - GraphUtils::DumpGEGraph(merged_compute_graph, "Build", is_always_dump); - GraphUtils::DumpGEGraphToOnnx(*merged_compute_graph, "Build"); - - // set modelptr to subgraph - for (const auto &sub_graph_info : sub_graph_list) { - sub_graph_info->SetGeModelPtr(ge_model); + // when set incre build, save om model and var manager + auto save_ret = SaveCacheAfterBuild(graph_node->GetGraphId(), 
compute_graph, ge_model); + if (save_ret != SUCCESS) { + GELOGW("Fail to save cache."); } - - ge_models.push_back(ge_model); - - GE_IF_BOOL_EXEC(sub_graph_list.empty(), GELOGE(FAILED, "Input graph must have at least one calculation op Node"); - return FAILED;); - sub_graph_list[0]->SetSubGraph(merged_compute_graph); - // set subgraphlist to graphnode - graph_node->SetSubGraph(sub_graph_list); + // release rts generate context + RtContextUtil::GetInstance().DestroyrtContexts(); GE_TIMESTAMP_END(PreRun, "GraphManager::PreRun"); GEEVENT("[GEPERFTRACE] GE PreRun End"); return ret; @@ -452,10 +463,14 @@ Status GraphManager::StartForRunGraph(const GraphNodePtr &graph_node, const std: return PARAM_INVALID; } GeModelPtr ge_model = nullptr; - ret = PreRun(graph_node, inputs, ge_models, ge_model, session_id); + // check need incre build. + ret = IncreBuild(graph_node, ge_model); if (ret != SUCCESS) { - GELOGE(ret, "PreRun Failed."); - return ret; + ret = PreRun(graph_node, inputs, ge_models, ge_model, session_id); + if (ret != SUCCESS) { + GELOGE(ret, "PreRun Failed."); + return ret; + } } ret = LoadGraph(ge_model, graph_node); if (ret != SUCCESS) { @@ -483,10 +498,10 @@ Status GraphManager::LoadGraph(const GeModelPtr &ge_model, const GraphNodePtr &g if (getenv(kEnvGeuseStaticMemory) != nullptr) { GELOGI("[LoadGraph] GE_USE_STATIC_MEMORY is seted."); } else { - GE_CHK_STATUS_RET(CheckAndReleaseMemory(ge_model, graph_node)) + GE_CHK_STATUS_RET(CheckAndReleaseMemory(ge_model, graph_node)); } GE_TIMESTAMP_START(LoadGraph); - Status ret = graph_loader_.LoadGraph(ge_model, model_listener, model_id_info); + Status ret = GraphLoader::LoadModelOnline(model_id_info.model_id, ge_model, model_listener); GE_TIMESTAMP_END(LoadGraph, "GraphManager::LoadGraph"); if (ret != SUCCESS) { GELOGE(ret, "[StartForRunGraph] LoadGraph Failed"); @@ -500,6 +515,90 @@ Status GraphManager::LoadGraph(const GeModelPtr &ge_model, const GraphNodePtr &g return SUCCESS; } +Status 
GraphManager::LoadFromCache(const GraphNodePtr &graph_node, const ModelCacheHelperPtr &cache_helper, + GeModelPtr &ge_model) { + auto graph_id = graph_node->GetGraphId(); + auto ret = cache_helper->LoadOmModelFromCache(ge_model); + if (ret != SUCCESS) { + GELOGW("Fail to load om model from cache."); + if (cache_helper->ClearCache(graph_id) != SUCCESS) { + GELOGW("Fail to clear cache of graph %u.", graph_id); + } + return FAILED; + } + ret = cache_helper->RecoverVarManagerFromCache(); + if (ret != SUCCESS) { + GELOGW("Fail to recover VarManager from cache."); + if (cache_helper->ClearCache(graph_id) != SUCCESS) { + GELOGW("Fail to clear cache of graph %u.", graph_id); + } + return FAILED; + } + ComputeGraphPtr compute_graph_in_model = GraphUtils::GetComputeGraph(ge_model->GetGraph()); + if (compute_graph_in_model == nullptr) { + GELOGW("Error occurred when get compute graph from om, abandon."); + return FAILED; + } else { + graph_node->SetComputeGraph(compute_graph_in_model); + graph_node->SetGeModel(ge_model); + GELOGI("Load model and graph form cache om file."); + } + return SUCCESS; +} + +Status GraphManager::SaveCacheBeforeBuild(uint32_t graph_id, const ModelCacheHelperPtr &cache_helper) { + auto ret = cache_helper->SaveCacheInfoToCache(); + if (ret != SUCCESS) { + GELOGW("Fail to save cache info of graph[%d] to cache.", graph_id); + return FAILED; + } + ret = cache_helper->SaveVarManagerToCache(true); + if (ret != SUCCESS) { + GELOGW("Fail to save var manager to cache."); + cache_helper->ClearCache(graph_id); + return FAILED; + } + GELOGI("Cache files have been saved."); + return SUCCESS; +} + +Status GraphManager::SaveCacheAfterBuild(uint32_t graph_id, ge::ComputeGraphPtr graph, GeModelPtr &ge_model) { + std::shared_ptr instance_ptr = ge::GELib::GetInstance(); + if ((instance_ptr == nullptr) || !instance_ptr->InitFlag()) { + GELOGW("GELib not initialized."); + return FAILED; + } + + if (instance_ptr->IsIncreBuild()) { + auto iter = 
cache_helper_map_.find(graph_id); + if (iter == cache_helper_map_.end()) { + GELOGW("Can not find ModelCacheHelper of graph[%u]", graph_id); + return FAILED; + } else { + ModelCacheHelperPtr cache_helper = iter->second; + auto ret = cache_helper->RefreshComputeGraph(graph); + if (ret != SUCCESS) { + cache_helper->ClearCache(graph_id); + GELOGW("Fail to refresh cache helper's compute graph"); + return FAILED; + } + ret = cache_helper->SaveVarManagerToCache(false); + if (ret != SUCCESS) { + cache_helper->ClearCache(graph_id); + GELOGW("Fail to save VarManager to cache"); + return FAILED; + } + ret = cache_helper->SaveOmModelToCache(ge_model); + if (ret != SUCCESS) { + cache_helper->ClearCache(graph_id); + GELOGW("Fail to save om model to cache"); + return FAILED; + } + } + } + return SUCCESS; +} + Status GraphManager::InnerRunGraph(GraphNodePtr &graph_node, const GraphId &graph_id, const std::vector &inputs, std::vector &outputs) { Status ret = graph_executor_.SetCondition(&sync_run_mutex_, &condition_, graph_run_listener_); @@ -510,7 +609,7 @@ Status GraphManager::InnerRunGraph(GraphNodePtr &graph_node, const GraphId &grap } if (GetTrainFlag()) { - GE_CHK_STATUS_RET(graph_executor_.SetGraphContext(GetGraphContext())) + GE_CHK_STATUS_RET(graph_executor_.SetGraphContext(GetGraphContext())); graph_executor_.SetTrainFlag(options_.train_graph_flag); } ret = graph_executor_.ExecuteGraph(graph_id, graph_node->GetGeModel(), inputs, outputs); @@ -559,6 +658,9 @@ Status GraphManager::RunGraph(const GraphId &graph_id, const std::vector ge_models; if (options_.local_fmk_op_flag) { @@ -591,7 +693,7 @@ Status GraphManager::RunGraph(const GraphId &graph_id, const std::vectorGetSubGraph(); if (IsCheckpointGraph(checkPointGraph)) { - ret = CheckpointHandle(graph_id, outputs); + ret = CheckpointHandle(graph_id, checkPointGraph, outputs); if (ret != SUCCESS) { GELOGE(ret, "[RunGraph] CheckpointHandle failed!"); } @@ -603,6 +705,31 @@ Status GraphManager::RunGraph(const GraphId 
&graph_id, const std::vectorGetGraph()); + if (ret != SUCCESS) { + GELOGE(ret, "ATC dump infershape json failed"); + return ret; + } + + GELOGI("[DumpInfershapeJson] Dump infershape json success, graph_id=%u.", graph_id); + return ret; +} + Status GraphManager::BuildGraph(const GraphId &graph_id, const std::vector &inputs, std::vector &models) { GELOGI("[BuildGraph] start to build graph, graph_id=%u.", graph_id); @@ -675,6 +802,19 @@ Status GraphManager::SaveParams(ge::GeModel &model, const std::string &type, con return SUCCESS; } +void GraphManager::RemoveModelCacheHelper(const GraphId &graph_id) { + auto iter = cache_helper_map_.find(graph_id); + if (iter != cache_helper_map_.end()) { + cache_helper_map_.erase(iter); + } else { + GELOGW("[GraphManager] cache helper does not exist, graph_id = %u", graph_id); + } +} + +bool GraphManager::CheckModelLoad(const GeModelPtr &ge_model, bool load_flag) { + return ((ge_model != nullptr) && (ge_model->GetModelId() != INVALID_MODEL_ID) && load_flag); +} + Status GraphManager::RemoveGraph(const GraphId &graph_id) { auto it = graph_map_.find(graph_id); if (it == graph_map_.end()) { @@ -724,8 +864,11 @@ Status GraphManager::RemoveGraph(const GraphId &graph_id) { } var_acc_ctrl_.RemoveGraph(graph_id); graph_map_.erase(it); + + RemoveModelCacheHelper(graph_id); + auto ge_model = graph_node->GetGeModel(); - if (ge_model != nullptr) { + if (CheckModelLoad(ge_model, graph_node->GetLoadFlag())) { GELOGI("Unload model %u.", ge_model->GetModelId()); rt_ret = rtSetDevice(GetContext().DeviceId()); if (rt_ret != RT_ERROR_NONE) { @@ -1114,21 +1257,15 @@ Status GraphManager::SummaryHandle(const GraphId &graph_id, std::vector &outputs) { +Status GraphManager::CheckpointHandle(const GraphId &graph_id, const ComputeGraphPtr &compute_graph, + const std::vector &outputs) { GELOGI("[GraphManager] CheckpointHandle, outputsSize=%zu.", outputs.size()); std::vector outputs_desc = graph_executor_.GetOutputsDesc(); GELOGI("[GraphManager] 
CheckpointHandle, outputsDescSize=%zu.", outputs_desc.size()); - // find graph - GraphNodePtr graph_node = nullptr; - Status ret = GetGraphNode(graph_id, graph_node); - if (ret != SUCCESS) { - GELOGE(ret, "[CheckpointHandle] graph not exist, graph_id = %u.", graph_id); - return ret; - } - ComputeGraphPtr compute_graph_ptr = GraphUtils::GetComputeGraph(*(graph_node->GetGraph())); + std::map save_results; NodePtr netoutput = nullptr; - for (const auto &node : compute_graph_ptr->GetDirectNode()) { + for (const auto &node : compute_graph->GetDirectNode()) { if (node->GetType() == kNetOutput) { netoutput = node; break; @@ -1256,6 +1393,8 @@ bool GraphManager::CheckTransOpForCheckpointGraph(NodePtr &node) { return true; } +static inline bool CheckConstanOpForCheckpointGraph(NodePtr &node) { return node->GetOutDataNodes().empty(); } + bool GraphManager::IsCheckpointGraph(ComputeGraphPtr &compute_graph) { if (compute_graph == nullptr) { GELOGE(GE_GRAPH_PARAM_NULLPTR, "[IsCheckpointGraph] computeGraph is nullptr."); @@ -1276,6 +1415,10 @@ bool GraphManager::IsCheckpointGraph(ComputeGraphPtr &compute_graph) { if (!CheckTransOpForCheckpointGraph(node)) { return false; } + } else if (op->GetType() == CONSTANTOP) { + if (!CheckConstanOpForCheckpointGraph(node)) { + return false; + } } else if (op->GetType() != kSend && op->GetType() != kRecv) { GELOGI("this node is not allow in checkpoint sub graph, node_type: %s, node_name: %s.", op->GetType().c_str(), op->GetName().c_str()); @@ -1302,7 +1445,7 @@ bool GraphManager::IsBroadCastOpData(const ge::NodePtr &var_node) { } void GraphManager::AdjustBroadCastOpData(const ge::NodePtr &var_node) { - if (!ge::AttrUtils::SetStr(var_node->GetOpDesc(), domi::VAR_ATTR_VAR_IS_BROADCAST, "var_is_restore")) { + if (!ge::AttrUtils::SetStr(var_node->GetOpDesc(), VAR_ATTR_VAR_IS_BROADCAST, "var_is_restore")) { GELOGW("set var_is_restore failed"); } } @@ -1320,7 +1463,7 @@ bool GraphManager::IsAssignOpData(const ge::NodePtr &var_node) { } void 
GraphManager::AdjustAssignOpData(const ge::NodePtr &var_node) { - if (!ge::AttrUtils::SetStr(var_node->GetOpDesc(), domi::VAR_ATTR_VAR_IS_RESTORE, "var_is_restore")) { + if (!ge::AttrUtils::SetStr(var_node->GetOpDesc(), VAR_ATTR_VAR_IS_RESTORE, "var_is_restore")) { GELOGW("SetStr var_is_restore failed"); } } @@ -1382,45 +1525,128 @@ Status GraphManager::RemoveIsolatedConst(ge::ComputeGraphPtr &compute_graph) { return SUCCESS; } -Status GraphManager::OptimizeAfterMergeSubGraph(ge::ComputeGraphPtr &compute_graph) { - GELOGI("Start optimize after merge sub graph."); +Status GraphManager::NewOptimizeAfterMergeSubGraph(ge::ComputeGraphPtr &compute_graph) { + GELOGD("NewOptimizeAfterMergeSubGraph in"); - GEPass ge_passes_for_shape(compute_graph); - NamesToPass names_to_passes_for_shape; - IdentifyReferencePass identify_reference_pass; - names_to_passes_for_shape.emplace_back("IdentifyReferencePass", &identify_reference_pass); - CastRemovePass cast_remove_pass; - names_to_passes_for_shape.emplace_back("CastRemovePass", &cast_remove_pass); - TransposeTransDataPass transpose_transdata_pass; - names_to_passes_for_shape.emplace_back("TransposeTransDataPass", &transpose_transdata_pass); - GE_TIMESTAMP_START(ge_passes_for_shape); - Status ret = ge_passes_for_shape.Run(names_to_passes_for_shape); - GE_TIMESTAMP_END(ge_passes_for_shape, "GraphManager::GePassesForShape"); + GEPass ge_passes(compute_graph); + NamesToPass names_to_passes; + ConstantFoldingPass constant_folding_pass; + names_to_passes.emplace_back("ConstantFoldingPass", &constant_folding_pass); + GE_TIMESTAMP_START(names_to_passes); + auto ret = ge_passes.Run(names_to_passes); + GE_TIMESTAMP_END(names_to_passes, "GraphManager::ge_passes"); + if (ret != SUCCESS) { + GELOGE(ret, "Run ge_passes optimize for OptimizeAfterMergeSubGraph failed, ret:%d.", ret); + return ret; + } + + ret = RemoveIsolatedConst(compute_graph); if (ret != SUCCESS) { - GELOGE(ret, "Run ge_passes_for_shape optimize for OptimizeAfterMergeSubGraph 
failed, ret:%d.", ret); + GELOGE(ret, "Remove isolated Constant failed, ret:%d.", ret); return ret; } + PassManager passes; + GE_CHK_STATUS_RET(passes.AddPass(new (std::nothrow) MultiBatchPass)); + GE_CHK_STATUS_RET(passes.AddPass(new (std::nothrow) CompileNodesPass)); + GE_CHK_STATUS_RET(passes.AddPass(new (std::nothrow) AtomicAddrCleanPass)); + + GE_TIMESTAMP_START(passes); + ret = passes.Run(compute_graph); + GE_TIMESTAMP_END(passes, "GraphManager::passes"); + if (ret != SUCCESS && ret != NOT_CHANGED) { + GELOGE(ret, "Run passes optimize for OptimizeAfterMergeSubGraph failed"); + return ret; + } + + ret = compute_graph->TopologicalSorting(); + if (ret != SUCCESS) { + GELOGE(ret, "Graph topological sort failed, ret:%d.", ret); + return ret; + } + return SUCCESS; +} + +Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) { string options = "default"; if (GetContext().GetOption("ge.exec.variable_acc", options) != SUCCESS) { GELOGI("get ge.exec.variable_acc failed. set default value."); } PassManager after_merge_passes; - GE_CHK_STATUS_RET(after_merge_passes.AddPass(new (std::nothrow) PermutePass)) - GE_CHK_STATUS_RET(after_merge_passes.AddPass(new (std::nothrow) VariablePrepareOpPass)) + GE_CHK_STATUS_RET(after_merge_passes.AddPass(new (std::nothrow) ConstantFuseSamePass)); + GE_CHK_STATUS_RET(after_merge_passes.AddPass(new (std::nothrow) VariablePrepareOpPass)); + GE_CHK_STATUS_RET(after_merge_passes.AddPass(new (std::nothrow) IteratorOpPass)); + GE_CHK_STATUS_RET(after_merge_passes.AddPass(new (std::nothrow) CommonSubexpressionEliminationPass)); + GE_CHK_STATUS_RET(after_merge_passes.AddPass(new (std::nothrow) PermutePass)); + GE_CHK_STATUS_RET(after_merge_passes.AddPass(new (std::nothrow) VariablePrepareOpPass)); GE_IF_BOOL_EXEC(options == "default" || options == "1", GELOGI("turn on variable accelerator"); - GE_CHK_STATUS_RET(after_merge_passes.AddPass(new (std::nothrow) VariableOpPass(&var_acc_ctrl_)))) - 
GE_CHK_STATUS_RET(after_merge_passes.AddPass(new (std::nothrow) TransOpDepthFusionPass)) - GE_CHK_STATUS_RET(after_merge_passes.AddPass(new (std::nothrow) TransOpBreadthFusionPass)) - GE_CHK_STATUS_RET(after_merge_passes.AddPass(new (std::nothrow) VariableRefDeleteOpPass)) - GE_CHK_STATUS_RET(after_merge_passes.AddPass(new (std::nothrow) SameTransdataBreadthFusionPass)) - GE_CHK_STATUS_RET(after_merge_passes.AddPass(new (std::nothrow) TransOpWithoutReshapeFusionPass)) - GE_CHK_STATUS_RET(after_merge_passes.AddPass(new (std::nothrow) AtomicAddrCleanPass)) + GE_CHK_STATUS_RET(after_merge_passes.AddPass(new (std::nothrow) VariableOpPass(&var_acc_ctrl_)))); + GE_CHK_STATUS_RET(after_merge_passes.AddPass(new (std::nothrow) TransOpDepthFusionPass)); + GE_CHK_STATUS_RET(after_merge_passes.AddPass(new (std::nothrow) TransOpBreadthFusionPass)); + GE_CHK_STATUS_RET(after_merge_passes.AddPass(new (std::nothrow) SameTransdataBreadthFusionPass)); + GE_CHK_STATUS_RET(after_merge_passes.AddPass(new (std::nothrow) TransOpWithoutReshapeFusionPass)); + + GE_TIMESTAMP_START(after_merge_passes); + auto ret = after_merge_passes.Run(compute_graph); + GE_TIMESTAMP_END(after_merge_passes, "GraphManager::AfterMergePasses"); + if (ret != SUCCESS && ret != NOT_CHANGED) { + GELOGE(ret, "Run passes after merge sub graph failed, ret:%d.", ret); + return ret; + } + + GEPass ge_passes(compute_graph); + NamesToPass names_to_passes; + TransOpNearbyAllreduceFusionPass trans_op_nearby_allreduce_fusion_pass; + ReshapeRemovePass reshape_remove_pass; + ConstantFoldingPass constant_folding_pass; + DimensionAdjustPass dimension_adjust_pass; + AddNPass addn_pass; + SwitchPass switch_pass; + SwitchLogicRemovePass switch_logic_remove_pass; + MergePass merge_pass; + IdentifyReferencePass identify_reference_pass; + CastRemovePass cast_remove_pass; + TransposeTransDataPass transpose_transdata_pass; + names_to_passes.emplace_back("AddNPass", &addn_pass); + names_to_passes.emplace_back("SwitchPass", 
&switch_pass); + names_to_passes.emplace_back("SwitchLogicRemovePass", &switch_logic_remove_pass); + names_to_passes.emplace_back("MergePass", &merge_pass); + names_to_passes.emplace_back("IdentifyReferencePass", &identify_reference_pass); + names_to_passes.emplace_back("CastRemovePass", &cast_remove_pass); + names_to_passes.emplace_back("TransposeTransDataPass", &transpose_transdata_pass); + names_to_passes.emplace_back("TransOpNearbyAllreduceFusionPass", &trans_op_nearby_allreduce_fusion_pass); + names_to_passes.emplace_back("ReshapeRemovePass", &reshape_remove_pass); + names_to_passes.emplace_back("ConstantFoldingPass", &constant_folding_pass); + names_to_passes.emplace_back("DimensionAdjustPass", &dimension_adjust_pass); + GE_TIMESTAMP_START(names_to_passes); + ret = ge_passes.Run(names_to_passes); + GE_TIMESTAMP_END(names_to_passes, "GraphManager::MergedGraphNameToPasses"); + if (ret != SUCCESS) { + GELOGE(ret, "Run ge_passes optimize for OptimizeAfterMergeSubGraph failed, ret:%d.", ret); + return ret; + } + + PassManager graph_pass; + try { + (void)graph_pass.AddPass(new PrunePass); + } catch (std::bad_alloc &e) { + GELOGE(INTERNAL_ERROR, "Add pass failed, bad memory allocation occurs."); + return INTERNAL_ERROR; + } + + return SUCCESS; +} +Status GraphManager::OptimizeStage2(ge::ComputeGraphPtr &compute_graph) { + GELOGI("Start optimize after merge sub graph."); + + PassManager after_merge_passes; + GE_CHK_STATUS_RET(after_merge_passes.AddPass(new (std::nothrow) VariableRefDeleteOpPass)); + GE_CHK_STATUS_RET(after_merge_passes.AddPass(new (std::nothrow) AtomicAddrCleanPass)); GE_CHK_STATUS_RET( - after_merge_passes.AddPass(new (std::nothrow) LinkGenMaskNodesPass(options_.stream_max_parallel_num))) + after_merge_passes.AddPass(new (std::nothrow) LinkGenMaskNodesPass(options_.stream_max_parallel_num))); GE_TIMESTAMP_START(after_merge_passes); - ret = after_merge_passes.Run(compute_graph); + auto ret = after_merge_passes.Run(compute_graph); 
GE_TIMESTAMP_END(after_merge_passes, "GraphManager::AfterMergePasses"); if (ret != SUCCESS && ret != NOT_CHANGED) { GELOGE(ret, "Run passes after merge sub graph failed, ret:%d.", ret); @@ -1443,14 +1669,8 @@ Status GraphManager::OptimizeAfterMergeSubGraph(ge::ComputeGraphPtr &compute_gra GEPass ge_passes(compute_graph); NamesToPass names_to_passes; - TransOpNearbyAllreduceFusionPass trans_op_nearby_allreduce_fusion_pass; - names_to_passes.emplace_back("ReshapeRemovePass", &trans_op_nearby_allreduce_fusion_pass); - ReshapeRemovePass reshape_remove_pass; - names_to_passes.emplace_back("ReshapeRemovePass", &reshape_remove_pass); ConstantFoldingPass constant_folding_pass; names_to_passes.emplace_back("ConstantFoldingPass", &constant_folding_pass); - DimensionAdjustPass dimension_adjust_pass; - names_to_passes.emplace_back("DimensionAdjustPass", &dimension_adjust_pass); GE_TIMESTAMP_START(names_to_passes); ret = ge_passes.Run(names_to_passes); GE_TIMESTAMP_END(names_to_passes, "GraphManager::MergedGraphNameToPasses"); @@ -1466,15 +1686,18 @@ Status GraphManager::OptimizeAfterMergeSubGraph(ge::ComputeGraphPtr &compute_gra } PassManager pass_for_control_attr_optimize; - GE_CHK_STATUS_RET(pass_for_control_attr_optimize.AddPass(new (std::nothrow) MultiBatchPass)) - GE_CHK_STATUS_RET(pass_for_control_attr_optimize.AddPass(new (std::nothrow) ControlOpAttrPass)) - GE_CHK_STATUS_RET(pass_for_control_attr_optimize.AddPass(new (std::nothrow) CompileNodesPass)) + GE_CHK_STATUS_RET(pass_for_control_attr_optimize.AddPass(new (std::nothrow) HcclMemcpyPass)); + if (options_.train_graph_flag) { + GE_CHK_STATUS_RET(pass_for_control_attr_optimize.AddPass(new (std::nothrow) FlowCtrlPass)); + } + GE_CHK_STATUS_RET(pass_for_control_attr_optimize.AddPass(new (std::nothrow) MultiBatchPass)); + GE_CHK_STATUS_RET(pass_for_control_attr_optimize.AddPass(new (std::nothrow) CompileNodesPass)); GE_TIMESTAMP_START(pass_for_control_attr_optimize); ret = 
pass_for_control_attr_optimize.Run(compute_graph); GE_TIMESTAMP_END(pass_for_control_attr_optimize, "GraphManager::ControlAttrOptimize"); if (ret != SUCCESS && ret != NOT_CHANGED) { - GELOGE(ret, "Run ControlOpAttrPass failed"); + GELOGE(ret, "Run passes when optimize stage 2 failed"); return ret; } @@ -1488,6 +1711,112 @@ Status GraphManager::OptimizeAfterMergeSubGraph(ge::ComputeGraphPtr &compute_gra return SUCCESS; } +Status GraphManager::OptimizeAfterMergeSubGraph(ge::ComputeGraphPtr &compute_graph) { + GELOGI("Start optimize after merge sub graph."); + + GEPass ge_passes_for_shape(compute_graph); + NamesToPass names_to_passes_for_shape; + IdentifyReferencePass identify_reference_pass; + names_to_passes_for_shape.emplace_back("IdentifyReferencePass", &identify_reference_pass); + CastRemovePass cast_remove_pass; + names_to_passes_for_shape.emplace_back("CastRemovePass", &cast_remove_pass); + TransposeTransDataPass transpose_transdata_pass; + names_to_passes_for_shape.emplace_back("TransposeTransDataPass", &transpose_transdata_pass); + GE_TIMESTAMP_START(ge_passes_for_shape); + Status ret = ge_passes_for_shape.Run(names_to_passes_for_shape); + GE_TIMESTAMP_END(ge_passes_for_shape, "GraphManager::GePassesForShape"); + GE_CHK_STATUS_RET(ret, "Run ge_passes_for_shape optimize for OptimizeAfterMergeSubGraph failed, ret:%d.", ret); + + string options = "default"; + if (GetContext().GetOption("ge.exec.variable_acc", options) != SUCCESS) { + GELOGI("get ge.exec.variable_acc failed. 
set default value."); + } + PassManager after_merge_passes; + GE_CHK_STATUS_RET(after_merge_passes.AddPass(new (std::nothrow) PermutePass)); + GE_IF_BOOL_EXEC(options == "default" || options == "1", GELOGI("turn on variable accelerator"); + GE_CHK_STATUS_RET(after_merge_passes.AddPass(new (std::nothrow) VariableOpPass(&var_acc_ctrl_)))); + ret = after_merge_passes.Run(compute_graph); + if (ret != SUCCESS && ret != NOT_CHANGED) { + GELOGE(ret, "Run passes after merge sub graph failed, ret:%d.", ret); + return ret; + } + + // reshape remove + symmetry_elimination_pass to replace transop depth fusion pass + GEPass ge_passes_symmetry(compute_graph); + NamesToPass names_to_passes_for_symmetry; + ReshapeRemovePass reshape_remove_pass; + names_to_passes_for_symmetry.emplace_back("ReshapeRemovePass", &reshape_remove_pass); + TransOpSymmetryEliminationPass symmetry_elimination_pass; + names_to_passes_for_symmetry.emplace_back("TransOpSymmetryEliminationPass", &symmetry_elimination_pass); + ret = ge_passes_symmetry.Run(names_to_passes_for_symmetry); + GE_CHK_STATUS_RET(ret, "Run ge_passes optimize for OptimizeAfterMergeSubGraph failed, ret:%d.", ret); + + PassManager after_merge_fusion_passes; + GE_CHK_STATUS_RET(after_merge_fusion_passes.AddPass(new (std::nothrow) TransOpBreadthFusionPass)); + GE_CHK_STATUS_RET(after_merge_fusion_passes.AddPass(new (std::nothrow) VariableRefDeleteOpPass)); + GE_CHK_STATUS_RET(after_merge_fusion_passes.AddPass(new (std::nothrow) SameTransdataBreadthFusionPass)); + GE_CHK_STATUS_RET(after_merge_fusion_passes.AddPass(new (std::nothrow) TransOpWithoutReshapeFusionPass)); + GE_CHK_STATUS_RET(after_merge_fusion_passes.AddPass(new (std::nothrow) AtomicAddrCleanPass)); + GE_CHK_STATUS_RET( + after_merge_fusion_passes.AddPass(new (std::nothrow) LinkGenMaskNodesPass(options_.stream_max_parallel_num))); + GE_TIMESTAMP_START(after_merge_fusion_passes); + ret = after_merge_fusion_passes.Run(compute_graph); + GE_TIMESTAMP_END(after_merge_fusion_passes, 
"GraphManager::AfterMergePasses"); + if (ret != SUCCESS && ret != NOT_CHANGED) { + GELOGE(ret, "Run passes after merge sub graph failed, ret:%d.", ret); + return ret; + } + + // add variable attr for hccl broadcast,need to be removed after variable pass online + for (const ge::NodePtr &node : compute_graph->GetDirectNode()) { + if (node->GetOpDesc()->GetType() != VARIABLE) { + continue; + } + + if (IsBroadCastOpData(node)) { + AdjustBroadCastOpData(node); + } + if (IsAssignOpData(node)) { + AdjustAssignOpData(node); + } + } + + GEPass ge_passes(compute_graph); + NamesToPass names_to_passes; + TransOpNearbyAllreduceFusionPass trans_op_nearby_allreduce_fusion_pass; + names_to_passes.emplace_back("TransOpNearbyAllreduceFusionPass", &trans_op_nearby_allreduce_fusion_pass); + names_to_passes_for_shape.emplace_back("ReshapeRemovePass", &reshape_remove_pass); + ConstantFoldingPass constant_folding_pass; + names_to_passes.emplace_back("ConstantFoldingPass", &constant_folding_pass); + DimensionAdjustPass dimension_adjust_pass; + names_to_passes.emplace_back("DimensionAdjustPass", &dimension_adjust_pass); + GE_TIMESTAMP_START(names_to_passes); + ret = ge_passes.Run(names_to_passes); + GE_TIMESTAMP_END(names_to_passes, "GraphManager::MergedGraphNameToPasses"); + GE_CHK_STATUS_RET(ret, "Run ge_passes optimize for OptimizeAfterMergeSubGraph failed, ret:%d.", ret); + + ret = RemoveIsolatedConst(compute_graph); + GE_CHK_STATUS_RET(ret, "Remove isolated Constant failed, ret:%d.", ret); + + PassManager pass_for_optimize; + GE_CHK_STATUS_RET(pass_for_optimize.AddPass(new (std::nothrow) SubgraphPass)); + GE_CHK_STATUS_RET(pass_for_optimize.AddPass(new (std::nothrow) MultiBatchPass)); + GE_CHK_STATUS_RET(pass_for_optimize.AddPass(new (std::nothrow) CompileNodesPass)); + GE_TIMESTAMP_START(pass_for_optimize); + ret = pass_for_optimize.Run(compute_graph); + GE_TIMESTAMP_END(pass_for_optimize, "GraphManager::OptimizePass"); + if (ret != SUCCESS && ret != NOT_CHANGED) { + GELOGE(ret, "Run 
optimize pass failed"); + return ret; + } + + ret = compute_graph->TopologicalSorting(); + GE_CHK_STATUS_RET(ret, "Graph topological sort failed, ret:%d.", ret); + + GELOGI("End optimize after merge sub graph."); + return SUCCESS; +} + Status GraphManager::LoadGraphAsync(const GeModelPtr &ge_model, const GraphNodePtr &graph_node) { GELOGI("[LoadGraphAsync] run_graph_flag[%d], graph_id[%u]", options_.run_graph_flag, graph_node->GetGraphId()); if (options_.run_graph_flag && ge_model != nullptr) { @@ -1496,11 +1825,11 @@ Status GraphManager::LoadGraphAsync(const GeModelPtr &ge_model, const GraphNodeP if (getenv(kEnvGeuseStaticMemory) != nullptr) { GELOGI("[LoadGraphAsync] GE_USE_STATIC_MEMORY is seted."); } else { - GE_CHK_STATUS_RET(CheckAndReleaseMemory(ge_model, graph_node)) + GE_CHK_STATUS_RET(CheckAndReleaseMemory(ge_model, graph_node)); } GE_TIMESTAMP_START(LoadGraph); GE_CHECK_NOTNULL(graph_node->graph_run_async_listener_); - Status ret = graph_loader_.LoadGraphAsync(ge_model, graph_node->graph_run_async_listener_, model_id_info); + Status ret = GraphLoader::LoadModelOnline(model_id_info.model_id, ge_model, graph_node->graph_run_async_listener_); GE_TIMESTAMP_END(LoadGraph, "GraphManager::LoadGraphAsync"); if (ret != SUCCESS) { GELOGE(ret, "[LoadGraphAsync] LoadGraphAsync Failed"); @@ -1621,14 +1950,11 @@ Status GraphManager::ProcessSubGraphWithMultiThreads(GraphManager *graph_manager } // run graph async on session -Status GraphManager::RunGraphAsync(const GraphId &graph_id, const std::vector &inputs, - std::vector &outputs, uint64_t session_id, - std::function callback) { - GELOGI("[GraphManager] Start to run graph async, graph_id=%u, inputsSize=%zu, outputsSize=%zu.", graph_id, - inputs.size(), outputs.size()); - - bool ret = - prerun_args_q_.Push(PreRunArgs({graph_id, inputs, outputs, session_id, GetThreadLocalContext(), callback})); +Status GraphManager::RunGraphAsync(const GraphId &graph_id, const std::vector &inputs, + uint64_t session_id, 
RunAsyncCallback callback) { + GELOGI("[GraphManager] Start to run graph async, graph_id=%u, inputsSize=%zu.", graph_id, inputs.size()); + + bool ret = prerun_args_q_.Push(PreRunArgs({graph_id, inputs, session_id, GetThreadLocalContext(), callback})); if (!ret) { GELOGE(FAILED, "[GraphManager] Run graph async failed, graph_id=%u.", graph_id); return FAILED; @@ -1638,6 +1964,51 @@ Status GraphManager::RunGraphAsync(const GraphId &graph_id, const std::vector instance_ptr = ge::GELib::GetInstance(); + if (instance_ptr != nullptr && instance_ptr->IsIncreBuild()) { + auto iter = cache_helper_map_.find(graph_id); + if (iter == cache_helper_map_.end()) { + ModelCacheHelperPtr cache_helper = MakeShared(session_id, graph_id, compute_graph); + if (cache_helper != nullptr) { + cache_helper_map_.emplace(std::make_pair(graph_id, cache_helper)); + } else { + GELOGW("Cache helper make shared failed, graph_id = %u.", graph_id); + } + } + } +} + +Status GraphManager::IncreBuild(const GraphNodePtr &graph_node, GeModelPtr &ge_model) { + std::shared_ptr instance_ptr = ge::GELib::GetInstance(); + if (instance_ptr == nullptr || !instance_ptr->IsIncreBuild()) { + return FAILED; + } + const uint32_t graph_id = graph_node->GetGraphId(); + auto iter = cache_helper_map_.find(graph_id); + if (iter == cache_helper_map_.end()) { + GELOGW("Can not find ModelCacheHelper of graph[%u]", graph_id); + return FAILED; + } + ModelCacheHelperPtr cache_helper = iter->second; + if (cache_helper->IsModelCacheHit()) { + GEEVENT("Model cache hit."); + Status ret = LoadFromCache(graph_node, cache_helper, ge_model); + if (ret == SUCCESS) { + return SUCCESS; + } else { + GELOGW("Error occurred when load from cache, abandon."); + } + } else { + GEEVENT("Model cache miss."); + } + if (SaveCacheBeforeBuild(graph_node->GetGraphId(), cache_helper) != SUCCESS) { + GELOGW("Error occurred when save cache."); + } + return FAILED; +} + void GraphManager::PreRunThread(GraphManager *graph_manager) { if (prctl(PR_SET_NAME, 
("GE_PreRun")) != 0) { GELOGW("Set thread name failed."); @@ -1653,12 +2024,12 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) { std::vector ge_inputs; for (auto const &input : args.input_tensor) { std::vector input_dims; - std::transform(input.shapeInfo.dims.begin(), input.shapeInfo.dims.end(), std::back_inserter(input_dims), - [](uint32_t x) -> int64_t { return static_cast(x); }); + std::transform(input.dims.begin(), input.dims.end(), std::back_inserter(input_dims), + [](int64_t x) -> int64_t { return x; }); GeShape input_shape(input_dims); GeTensorDesc input_tensor_desc; input_tensor_desc.SetShape(input_shape); - input_tensor_desc.SetDataType(static_cast(input.dataType)); + input_tensor_desc.SetDataType(static_cast(input.data_type)); ge_inputs.emplace_back(input_tensor_desc); } // find graph @@ -1691,6 +2062,8 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) { return; } } + // when set incre build, save cache helper. + graph_manager->AddModelCacheHelperToMap(args.graph_id, args.session_id, compute_graph_tmp); std::vector ge_models; @@ -1713,12 +2086,15 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) { return; } - ret = graph_manager->PreRun(graph_node, ge_inputs, ge_models, ge_model, args.session_id); - if (ret != SUCCESS) { - graph_node->SetRunFlag(false); - ReturnError(graph_manager, args.callback, ret, "PreRun failed, thread exit."); - graph_node->Unlock(); - return; + // check need incre build. 
+ if (graph_manager->IncreBuild(graph_node, ge_model) != SUCCESS) { + ret = graph_manager->PreRun(graph_node, ge_inputs, ge_models, ge_model, args.session_id); + if (ret != SUCCESS) { + graph_node->SetRunFlag(false); + ReturnError(graph_manager, args.callback, ret, "PreRun Failed, thread exit.."); + graph_node->Unlock(); + return; + } } graph_node->SetBuildFlag(true); graph_manager->var_acc_ctrl_.SetGraphBuildEnd(graph_node->GetGraphId()); @@ -1726,8 +2102,8 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) { ge_model = graph_node->GetGeModel(); } - graph_manager->run_args_q_.Push(RunArgs({graph_node, args.graph_id, args.input_tensor, args.output_tensor, ge_model, - GetThreadLocalContext(), args.callback})); + graph_manager->run_args_q_.Push( + RunArgs({graph_node, args.graph_id, args.input_tensor, ge_model, GetThreadLocalContext(), args.callback})); GELOGI("Loop end."); } } @@ -1770,8 +2146,8 @@ void GraphManager::RunThread(GraphManager *graph_manager) { graph_manager->graph_executor_.SetTrainFlag(graph_manager->options_.train_graph_flag); } - ret = graph_manager->graph_executor_.ExecuteGraphAsync(args.graph_id, args.graph_node->GetGeModel(), - args.input_tensor, args.output_tensor); + ret = + graph_manager->graph_executor_.ExecuteGraphAsync(args.graph_id, args.graph_node->GetGeModel(), args.input_tensor); args.graph_node->SetRunFlag(false); args.graph_node->Unlock(); if (ret != SUCCESS) { @@ -1793,15 +2169,15 @@ void GraphManager::StopQueue(GraphManager *graph_manager) { graph_manager->run_args_q_.Stop(); } -void GraphManager::ReturnError(GraphManager *graph_manager, std::function callback, Status ret, - const string &log) { +void GraphManager::ReturnError(GraphManager *graph_manager, RunAsyncCallback callback, Status ret, const string &log) { if (graph_manager == nullptr) { return; } GELOGE(ret, "%s.", log.c_str()); StopQueue(graph_manager); - callback(ret); + std::vector outputs; + callback(ret, outputs); } bool 
GraphManager::IsGraphNeedRebuild(uint32_t graph_id) { @@ -1838,4 +2214,86 @@ const map *GraphManager::GetGraphOptions(uint32_t grap } return &(graph_node->GetOptions()); } +Status GraphManager::OptimizeSubgraph(const GraphNodePtr &graph_node, ComputeGraphPtr &compute_graph, + uint64_t session_id) { + // graph partition + // all sub graph list of root graph and sub graph + GE_TIMESTAMP_START(GraphPartition); + auto ret = graph_partitioner_.Partition(compute_graph, GraphPartitioner::kPartitioning); + if (ret != SUCCESS) { + GELOGE(ret, "Graph partition Failed"); + return ret; + } + GE_TIMESTAMP_END(GraphPartition, "GraphPartitioner::Partition1"); + GE_TIMESTAMP_START(SetSubgraph); + ret = SetSubgraph(session_id, compute_graph); + if (ret != SUCCESS) { + GELOGE(ret, "Graph set subgraph Failed"); + return ret; + } + GE_TIMESTAMP_END(SetSubgraph, "SetSubGraph"); + + ComputeGraphPtr merged_compute_graph = nullptr; + std::vector merged_sub_graph_list; + + GE_TIMESTAMP_START(MergeSubgraph); + ret = MergeSubGraph(merged_compute_graph, compute_graph); + if (ret != SUCCESS) { + GELOGE(ret, "Merge SubGraph Failed"); + return ret; + } + GE_CHECK_NOTNULL(merged_compute_graph); + merged_compute_graph->SetSessionID(session_id); + merged_compute_graph->SetGraphID(graph_node->GetGraphId()); + GraphUtils::DumpGEGraph(merged_compute_graph, "mergedComputeGraph"); + GraphUtils::DumpGEGraphToOnnx(*merged_compute_graph, "mergedComputeGraph"); + for (auto &sub_graph : merged_compute_graph->GetAllSubgraphs()) { + sub_graph->SetSessionID(session_id); + sub_graph->SetGraphID(graph_node->GetGraphId()); + GraphUtils::DumpGEGraph(sub_graph, "mergedComputeGraph_subgraph"); + GraphUtils::DumpGEGraphToOnnx(*sub_graph, "mergedComputeGraph_subgraph"); + } + GE_TIMESTAMP_END(MergeSubgraph, "GraphManager::MergeSubGraph"); + compute_graph = merged_compute_graph; + return SUCCESS; +} +Status GraphManager::Build(const GraphNodePtr &graph_node, ComputeGraphPtr &compute_graph, + vector &ge_models, 
GeModelPtr &ge_model, uint64_t session_id) { + // build + if (compute_graph != nullptr) { + std::string graph_name = compute_graph->GetName(); + graph_name.append("_"); + graph_name.append(std::to_string(graph_node->GetGraphId())); + compute_graph->SetName(graph_name); + } + std::vector sub_graph_list; + auto ret = graph_builder_.Build(compute_graph, sub_graph_list, ge_model, session_id); + if (ret != SUCCESS) { + GELOGE(ret, "SubGraph build Failed."); + return ret; + } + + bool is_always_dump = false; + PropertiesManager &properties_manager = PropertiesManager::Instance(); + if (!properties_manager.GetDumpOutputPath().empty()) { + is_always_dump = true; + } + + GraphUtils::DumpGEGraph(compute_graph, "Build", is_always_dump); + GraphUtils::DumpGEGraphToOnnx(*compute_graph, "Build"); + + // set modelptr to subgraph + for (const auto &sub_graph_info : sub_graph_list) { + sub_graph_info->SetGeModelPtr(ge_model); + } + + ge_models.push_back(ge_model); + + GE_IF_BOOL_EXEC(sub_graph_list.empty(), GELOGE(FAILED, "Input graph must have at least one calculation op Node"); + return FAILED;); + sub_graph_list[0]->SetSubGraph(compute_graph); + // set subgraphlist to graphnode + graph_node->SetSubGraph(sub_graph_list); + return SUCCESS; +} } // namespace ge diff --git a/src/ge/graph/manager/graph_manager.h b/src/ge/graph/manager/graph_manager.h index 5a296b91..d13a2929 100644 --- a/src/ge/graph/manager/graph_manager.h +++ b/src/ge/graph/manager/graph_manager.h @@ -27,6 +27,7 @@ #include "common/blocking_queue.h" #include "common/ge_inner_error_codes.h" +#include "common/helper/model_cache_helper.h" #include "external/graph/types.h" #include "ge/ge_api_types.h" #include "graph/build/graph_builder.h" @@ -128,12 +129,11 @@ class GraphManager { /// @brief run graph async on session with specific session id /// @param [in] graph_id graph id /// @param [in] inputs input data - /// @param [out] outputs output data /// @param [out] callback: callback while run graph async finish /// 
@return Status result of function /// - Status RunGraphAsync(const GraphId &graph_id, const std::vector &inputs, - std::vector &outputs, uint64_t session_id, std::function callback); + Status RunGraphAsync(const GraphId &graph_id, const std::vector &inputs, uint64_t session_id, + RunAsyncCallback callback); /// /// @ingroup ge_graph @@ -149,26 +149,26 @@ class GraphManager { bool IsGraphNeedRebuild(uint32_t graph_id); + Status GenerateInfershapeGraph(GraphId &graph_id); + const std::map *GetGraphOptions(uint32_t graph_id); private: struct PreRunArgs { GraphId graph_id; - std::vector input_tensor; - std::vector output_tensor; + std::vector input_tensor; uint64_t session_id; GEThreadLocalContext context; - std::function callback; + RunAsyncCallback callback; }; struct RunArgs { GraphNodePtr graph_node; GraphId graph_id; - std::vector input_tensor; - std::vector output_tensor; + std::vector input_tensor; GeModelPtr ge_model; GEThreadLocalContext context; - std::function callback; + RunAsyncCallback callback; }; Status GetGraphNode(const GraphId &graph_id, GraphNodePtr &out); @@ -180,6 +180,14 @@ class GraphManager { Status PreRun(const GraphNodePtr &graph_node, const std::vector &inputs, vector &ge_models, GeModelPtr &ge_model, uint64_t session_id = INVALID_SESSION_ID); + Status PreRunDynShape(const GraphNodePtr &graph_node, const std::vector &inputs, + vector &ge_models, GeModelPtr &ge_model, uint64_t session_id = INVALID_SESSION_ID); + + Status OptimizeSubgraph(const GraphNodePtr &graph_node, ComputeGraphPtr &compute_graph, uint64_t session_id); + + Status Build(const GraphNodePtr &graph_node, ComputeGraphPtr &compute_graph, vector &ge_models, + GeModelPtr &ge_model, uint64_t session_id); + Status StartForRunGraph(const GraphNodePtr &graph_node, const std::vector &inputs, vector &ge_models, uint64_t session_id = INVALID_SESSION_ID); @@ -211,7 +219,8 @@ class GraphManager { Status SummaryHandle(const GraphId &graph_id, std::vector &outputs); - Status 
CheckpointHandle(const GraphId &graph_id, const std::vector &outputs); + Status CheckpointHandle(const GraphId &graph_id, const ComputeGraphPtr &compute_graph, + const std::vector &outputs); // call the callback function of ME to push summary result data to ME Status PushSummaryData2ME(const GraphId &graph_id, const std::map &summary_data); @@ -250,21 +259,33 @@ class GraphManager { Status RemoveIsolatedConst(ge::ComputeGraphPtr &compute_graph); + Status OptimizeStage1(ComputeGraphPtr &compute_graph); + Status OptimizeStage2(ComputeGraphPtr &compute_graph); Status OptimizeAfterMergeSubGraph(ge::ComputeGraphPtr &compute_graph); + Status NewOptimizeAfterMergeSubGraph(ge::ComputeGraphPtr &compute_graph); + Status LoadGraphAsync(const GeModelPtr &ge_model, const GraphNodePtr &graph_node); Status CheckAndReleaseMemory(const GeModelPtr &ge_model, const GraphNodePtr &graph_node); + bool CheckModelLoad(const GeModelPtr &ge_model, bool load_flag); + Status LoadGraph(const GeModelPtr &ge_model, const GraphNodePtr &graph_node); bool IsGraphNeedBuild(const GraphNodePtr &graph_node); + Status LoadFromCache(const GraphNodePtr &graph_node, const ModelCacheHelperPtr &cache_helper, GeModelPtr &ge_model); + Status SaveCacheBeforeBuild(uint32_t graph_id, const ModelCacheHelperPtr &cache_helper); + Status SaveCacheAfterBuild(uint32_t graph_id, ComputeGraphPtr graph, GeModelPtr &ge_model); + void AddModelCacheHelperToMap(const GraphId &graph_id, uint64_t session_id, ComputeGraphPtr &compute_graph); + Status IncreBuild(const GraphNodePtr &graph_node, GeModelPtr &ge_model); + void RemoveModelCacheHelper(const GraphId &graph_id); + static void PreRunThread(GraphManager *graph_manager); static void RunThread(GraphManager *graph_manager); static void StopQueue(GraphManager *graph_manager); - static void ReturnError(GraphManager *graph_manager, std::function callback, Status ret, - const string &log); + static void ReturnError(GraphManager *graph_manager, RunAsyncCallback callback, Status 
ret, const string &log); std::atomic_bool thread_run_flag_; BlockingQueue prerun_args_q_{}; @@ -274,6 +295,8 @@ class GraphManager { std::map graph_map_; + std::map cache_helper_map_; + // for run graph synchronous return std::mutex sync_run_mutex_; std::condition_variable condition_; diff --git a/src/ge/graph/manager/graph_manager_utils.cc b/src/ge/graph/manager/graph_manager_utils.cc index a340ce35..dd5c5fbb 100644 --- a/src/ge/graph/manager/graph_manager_utils.cc +++ b/src/ge/graph/manager/graph_manager_utils.cc @@ -21,8 +21,8 @@ #include "framework/common/debug/ge_log.h" #include "common/ge/ge_util.h" -#include "common/op/attr_define.h" #include "common/string_util.h" +#include "graph/debug/ge_attr_define.h" #include "graph/compute_graph.h" #include "graph/op_desc.h" #include "graph/optimize/common/params.h" @@ -99,7 +99,8 @@ Status SubGraphInfo::FreeInOutBuffer() { GraphModelListener::GraphModelListener(std::mutex &mutex, std::condition_variable &cond) : result_code_(0), is_finished_(false), mutex_(mutex), condition_(cond) {} -Status GraphModelListener::OnComputeDone(uint32_t model_id, uint32_t task_id, uint32_t result) { +Status GraphModelListener::OnComputeDone(uint32_t model_id, uint32_t task_id, uint32_t result, + std::vector &outputs) { GELOGI( "[GraphManager] graph compute call back, model_id:%u, task_id:%u, " "resultCode:%u.", @@ -129,15 +130,16 @@ Status GraphModelListener::ResetResult() { return SUCCESS; } -void RunAsyncListener::SetCallback(const std::function &callback) { +void RunAsyncListener::SetCallback(const RunAsyncCallback &callback) { sem_.Push(0); callback_ = callback; } -Status RunAsyncListener::OnComputeDone(uint32_t model_id, uint32_t task_id, uint32_t result) { +Status RunAsyncListener::OnComputeDone(uint32_t model_id, uint32_t task_id, uint32_t result, + std::vector &outputs) { GELOGI("[GraphManager] run graph async call back, modelId:%u, taskId:%u, resultCode:%u.", model_id, task_id, result); GE_CHECK_NOTNULL(callback_); - 
callback_(result); + callback_(result, outputs); uint8_t unused; sem_.Pop(unused); return SUCCESS; @@ -148,7 +150,7 @@ bool HasCalcOp(const ComputeGraphPtr &graph) { return false; } - static const std::set calc_op_type = {domi::CONVOLUTION, domi::DECONVOLUTION, domi::FULL_CONNECTION}; + static const std::set calc_op_type = {CONVOLUTION, DECONVOLUTION, FULL_CONNECTION}; for (const auto &node : graph->GetAllNodes()) { OpDescPtr op_desc = node->GetOpDesc(); @@ -167,15 +169,15 @@ Status ParseOutNodes(const string &out_nodes) { domi::GetContext().out_nodes_map.clear(); domi::GetContext().user_out_nodes.clear(); - vector nodes_v = domi::StringUtils::Split(out_nodes, ';'); + vector nodes_v = StringUtils::Split(out_nodes, ';'); for (const string &node : nodes_v) { - vector key_value_v = domi::StringUtils::Split(node, ':'); + vector key_value_v = StringUtils::Split(node, ':'); if (key_value_v.size() != 2) { // must contain 2 items GELOGE(GE_GRAPH_PARAM_NULLPTR, "Invalid outNodes: %s", node.c_str()); return GE_GRAPH_PARAM_NULLPTR; } auto iter = domi::GetContext().out_nodes_map.find(key_value_v[0]); - int32_t index = std::stoi(domi::StringUtils::Trim(key_value_v[1])); + int32_t index = std::stoi(StringUtils::Trim(key_value_v[1])); if (iter != domi::GetContext().out_nodes_map.end()) { iter->second.emplace_back(index); } else { diff --git a/src/ge/graph/manager/graph_manager_utils.h b/src/ge/graph/manager/graph_manager_utils.h index ca33aba1..b595e182 100644 --- a/src/ge/graph/manager/graph_manager_utils.h +++ b/src/ge/graph/manager/graph_manager_utils.h @@ -37,6 +37,7 @@ #include "graph/model.h" #include "model/ge_model.h" #include "register/register_fmk_types.h" +#include "external/ge/ge_api_types.h" namespace ge { // state for graph task in life cycle @@ -122,13 +123,14 @@ class RunAsyncListener : public ge::ModelListener { ~RunAsyncListener() = default; - void SetCallback(const std::function &callback); + void SetCallback(const RunAsyncCallback &callback); // callback - 
Status OnComputeDone(uint32_t model_id, uint32_t task_id, uint32_t result) override; + Status OnComputeDone(uint32_t model_id, uint32_t task_id, uint32_t result, + std::vector &outputs) override; private: - std::function callback_; + RunAsyncCallback callback_; BlockingQueue sem_; }; @@ -190,7 +192,8 @@ class GraphModelListener : public ge::ModelListener { ~GraphModelListener() = default; // callback - Status OnComputeDone(uint32_t model_id, uint32_t task_id, uint32_t result) override; + Status OnComputeDone(uint32_t model_id, uint32_t task_id, uint32_t result, + std::vector &outputs) override; Status ResetResult(); diff --git a/src/ge/graph/manager/graph_mem_allocator.cc b/src/ge/graph/manager/graph_mem_allocator.cc index f01a0b4b..95773f11 100644 --- a/src/ge/graph/manager/graph_mem_allocator.cc +++ b/src/ge/graph/manager/graph_mem_allocator.cc @@ -47,7 +47,7 @@ void MemoryAllocator::Finalize(uint32_t device_id) { memory_base_map_.clear(); } -uint8_t *MemoryAllocator::MallocMemory(uint64_t memory_size, uint32_t device_id) const { +uint8_t *MemoryAllocator::MallocMemory(const string &purpose, uint64_t memory_size, uint32_t device_id) const { uint8_t *memory_addr = nullptr; if (rtMalloc(reinterpret_cast(&memory_addr), memory_size, memory_type_) != RT_ERROR_NONE) { @@ -60,7 +60,7 @@ uint8_t *MemoryAllocator::MallocMemory(uint64_t memory_size, uint32_t device_id) } GELOGI("MemoryAllocator::MallocMemory device_id = %u, size= %lu", device_id, memory_size); - GE_PRINT_DYNAMIC_MEMORY(rtMalloc, "malloc function.", memory_size) + GE_PRINT_DYNAMIC_MEMORY(rtMalloc, purpose.c_str(), memory_size) return memory_addr; } @@ -74,14 +74,15 @@ Status MemoryAllocator::FreeMemory(uint8_t *memory_addr, uint32_t device_id) con return ge::SUCCESS; } -uint8_t *MemoryAllocator::MallocMemory(const string &memory_key, uint64_t memory_size, uint32_t device_id) { +uint8_t *MemoryAllocator::MallocMemory(const string &purpose, const string &memory_key, uint64_t memory_size, + uint32_t device_id) 
{ auto it = memory_base_map_.find(memory_key); if (it != memory_base_map_.end()) { it->second.memory_used_num_++; return it->second.memory_addr_; } - uint8_t *memory_addr = MallocMemory(memory_size, device_id); + uint8_t *memory_addr = MallocMemory(purpose, memory_size, device_id); if (memory_addr == nullptr) { GELOGE(ge::INTERNAL_ERROR, diff --git a/src/ge/graph/manager/graph_mem_allocator.h b/src/ge/graph/manager/graph_mem_allocator.h index fa4bf42f..9622e07a 100644 --- a/src/ge/graph/manager/graph_mem_allocator.h +++ b/src/ge/graph/manager/graph_mem_allocator.h @@ -83,11 +83,12 @@ class MemoryAllocator { /// /// @ingroup ge_graph /// @brief malloc memory + /// @param [in] purpose memory usage /// @param [in] size memory size /// @param [in] device_id device id /// @return memory address /// - uint8_t *MallocMemory(uint64_t memory_size, uint32_t device_id = 0) const; + uint8_t *MallocMemory(const string &purpose, uint64_t memory_size, uint32_t device_id = 0) const; /// /// @ingroup ge_graph @@ -101,12 +102,13 @@ class MemoryAllocator { /// /// @ingroup ge_graph /// @brief malloc memory + /// @param [in] purpose memory usage /// @param [in] memory_key memory key /// @param [in] size memory size /// @param [in] device_id device id /// @return memory address /// - uint8_t *MallocMemory(const string &memory_key, uint64_t memory_size, uint32_t device_id = 0); + uint8_t *MallocMemory(const string &purpose, const string &memory_key, uint64_t memory_size, uint32_t device_id = 0); /// /// @ingroup ge_graph diff --git a/src/ge/graph/manager/graph_var_manager.cc b/src/ge/graph/manager/graph_var_manager.cc index 5b76a597..813e9256 100644 --- a/src/ge/graph/manager/graph_var_manager.cc +++ b/src/ge/graph/manager/graph_var_manager.cc @@ -19,11 +19,11 @@ #include #include "common/l2_cache_optimize.h" -#include "common/op/attr_define.h" #include "common/types.h" #include "framework/common/debug/ge_log.h" #include "framework/common/debug/log.h" #include "ge/ge_api_types.h" 
+#include "graph/debug/ge_attr_define.h" #include "graph/manager/graph_mem_allocator.h" #include "graph/manager/trans_var_data_utils.h" #include "graph/utils/attr_utils.h" @@ -64,6 +64,10 @@ ge::Status VarResource::GetVarAddr(const std::string &var_name, const ge::GeTens return SUCCESS; } +void VarResource::GetAllVarAddrMgr(std::unordered_map &var_addr_mgr_map) { + var_addr_mgr_map = var_addr_mgr_map_; +} + void VarResource::SetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *dev_ptr, rtMemType_t memory_type) { std::string var_key = VarKey(var_name, tensor_desc); @@ -170,6 +174,14 @@ void VarResource::SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastInfo &b var_broad_cast_info_[graph_id][broad_cast_info.var_name] = broad_cast_info; } +ge::Status VarResource::GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info) { + if (var_broad_cast_info_.count(graph_id) == 0 || var_broad_cast_info_[graph_id].count(var_name) == 0) { + return FAILED; + } + broad_cast_info = var_broad_cast_info_[graph_id][var_name]; + return SUCCESS; +} + ge::Status VarResource::SyncVarData2BroadCast(uint32_t graph_id, const std::string &var_name, const ge::ConstOpDescPtr &var_op_desc, uint8_t *base_ptr) { if (var_op_desc == nullptr) { @@ -192,7 +204,7 @@ ge::Status VarResource::SyncBroadCastData2Var(uint32_t graph_id, const std::stri GELOGI("SyncBroadCastData2Var var_name: %s", var_name.c_str()); GE_CHECK_NOTNULL(var_op_desc); string var_is_broadcast; - bool is_broadcast = AttrUtils::GetStr(var_op_desc, domi::VAR_ATTR_VAR_IS_BROADCAST, var_is_broadcast); + bool is_broadcast = AttrUtils::GetStr(var_op_desc, VAR_ATTR_VAR_IS_BROADCAST, var_is_broadcast); if (!is_broadcast) { return SUCCESS; } @@ -210,7 +222,7 @@ ge::Status VarResource::SyncVarData(uint32_t graph_id, const std::string &var_na const ge::ConstOpDescPtr &var_op_desc, uint8_t *base_ptr) { GE_CHECK_NOTNULL(var_op_desc); string var_is_broadcast; - bool 
is_broadcast = AttrUtils::GetStr(var_op_desc, domi::VAR_ATTR_VAR_IS_BROADCAST, var_is_broadcast); + bool is_broadcast = AttrUtils::GetStr(var_op_desc, VAR_ATTR_VAR_IS_BROADCAST, var_is_broadcast); if (!is_broadcast) { return SUCCESS; } @@ -291,6 +303,8 @@ Status MemResource::AssignVarMem(const std::string &var_name, uint64_t size, uin int64_t MemResource::GetVarMemSize() const { return var_mem_size_; } +void MemResource::UpdateVarMemSize(int64_t mem_size) { var_mem_size_ = mem_size; }; + VarManager::VarManager(uint64_t session_id) : version_(SessionVersion::OTHER_VERSION), session_id_(session_id), @@ -367,6 +381,21 @@ ge::Status VarManager::SetVarAddr(const std::string &var_name, const ge::GeTenso return ge::SUCCESS; } +ge::Status VarManager::SaveVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *address, + rtMemType_t memory_type) { + GELOGI("VarManager::SaveVarAddr var_name = %s, data_type = %s, data_format = %s.", var_name.c_str(), + ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(), + ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str()); + + std::lock_guard lock(mutex_); + if (var_resource_ == nullptr) { + GELOGW("VarManager has not been init."); + return ge::INTERNAL_ERROR; + } + var_resource_->SaveVarAddr(var_name, tensor_desc, address, memory_type); + return ge::SUCCESS; +} + ge::Status VarManager::GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr, rtMemType_t &memory_type) { std::lock_guard lock(mutex_); @@ -392,6 +421,10 @@ ge::Status VarManager::GetVarAddr(const std::string &var_name, const ge::GeTenso return GetVarAddr(var_name, tensor_desc, dev_ptr, memory_type); } +void VarManager::GetAllVarAddrMgr(std::unordered_map &var_addr_mgr_map) { + var_resource_->GetAllVarAddrMgr(var_addr_mgr_map); +} + int64_t VarManager::GetVarMemSize(rtMemType_t memory_type) { std::lock_guard lock(mutex_); MemResource *mem_resource = nullptr; @@ -409,6 +442,30 
@@ int64_t VarManager::GetVarMemSize(rtMemType_t memory_type) { return mem_resource->GetVarMemSize(); } +Status VarManager::UpdateVarMemSize(rtMemType_t memory_type, int64_t mem_size) { + std::lock_guard lock(mutex_); + MemResource *mem_resource = nullptr; + auto iter = mem_resource_map_.find(memory_type); + if (iter == mem_resource_map_.end()) { + mem_resource = new (std::nothrow) MemResource(); + if (mem_resource == nullptr) { + GELOGE(ge::INTERNAL_ERROR, "Alloc MemResource failed, memory_type = %u.", memory_type); + return ge::INTERNAL_ERROR; + } else { + mem_resource_map_[memory_type] = mem_resource; + } + } else { + mem_resource = iter->second; + } + + if (mem_resource == nullptr) { + GELOGE(ge::INTERNAL_ERROR, "MemResource is invalid."); + return FAILED; + } + mem_resource->UpdateVarMemSize(mem_size); + return SUCCESS; +} + ge::Status VarManager::AssignVarMem(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, rtMemType_t memory_type) { std::lock_guard lock(mutex_); @@ -551,6 +608,16 @@ ge::Status VarManager::SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastIn return SUCCESS; } +ge::Status VarManager::GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info) { + std::lock_guard lock(mutex_); + + if (var_resource_ == nullptr) { + GELOGW("VarManager has not been init."); + return ge::INTERNAL_ERROR; + } + return var_resource_->GetBroadCastInfo(graph_id, var_name, broad_cast_info); +} + ge::Status VarManager::RenewCurVarDesc(const std::string &var_name, ge::OpDescPtr op_desc) { std::lock_guard lock(mutex_); GELOGD("VarManager::RenewCurVarDesc var_name = %s.", var_name.c_str()); @@ -592,7 +659,8 @@ ge::Status VarManager::MallocVarMemory(size_t memory_size) { // align 512 BYTE var_memory_size = (var_memory_size + kSessionMemAlignSize - 1) / kSessionMemAlignSize * kSessionMemAlignSize; - var_mem_base = MemManager::Instance(RT_MEMORY_HBM)->MallocMemory(memory_key, var_memory_size); + const string 
purpose("variables and constant op memory in training network."); + var_mem_base = MemManager::Instance(RT_MEMORY_HBM)->MallocMemory(purpose, memory_key, var_memory_size); if (var_mem_base == nullptr) { GELOGE(ge::INTERNAL_ERROR, "VarManager::MallocVarMemory failed " @@ -673,6 +741,7 @@ Status VarManager::SetMemoryMallocSize(const map &options) { GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "Parse graph memory manager malloc max size failed."); return ge::GE_GRAPH_OPTIONS_INVALID; } + GELOGI("The max size for graph mem is set to %zu", graph_mem_max_size_); } it = options.find(VARIABLE_MEMORY_MAX_SIZE); diff --git a/src/ge/graph/manager/graph_var_manager.h b/src/ge/graph/manager/graph_var_manager.h index a23c45b6..6229837c 100644 --- a/src/ge/graph/manager/graph_var_manager.h +++ b/src/ge/graph/manager/graph_var_manager.h @@ -101,6 +101,8 @@ class VarResource { ge::Status GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr, rtMemType_t &memory_type); + void GetAllVarAddrMgr(std::unordered_map &var_addr_mgr_map); + void SetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *dev_ptr, rtMemType_t rtMemType_t); @@ -113,6 +115,8 @@ class VarResource { void SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastInfo &broad_cast_info); + ge::Status GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info); + ge::Status SyncVarData2BroadCast(uint32_t graph_id, const std::string &var_name, const ge::ConstOpDescPtr &var_op_desc, uint8_t *base_ptr); @@ -175,6 +179,8 @@ class MemResource { int64_t GetVarMemSize() const; + void UpdateVarMemSize(int64_t mem_size); + private: uint64_t total_size_; uint64_t var_mem_size_; @@ -196,9 +202,14 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY VarManager { ge::Status SetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *dev_ptr, rtMemType_t memory_type); + ge::Status SaveVarAddr(const std::string 
&var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *address, + rtMemType_t memory_type); + ge::Status GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr, rtMemType_t &memory_type); + void GetAllVarAddrMgr(std::unordered_map &var_addr_mgr_map); + ge::Status GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr); ge::Status SyncVarData(uint32_t graph_id, const std::string &var_name, ge::ConstOpDescPtr var_op_desc, @@ -206,6 +217,8 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY VarManager { ge::Status SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastInfo &broad_cast_info); + ge::Status GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info); + ge::Status SyncBroadCastData2Var(uint32_t graph_id, const std::string &var_name, ge::ConstOpDescPtr var_op_desc, uint8_t *base_ptr); @@ -251,6 +264,8 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY VarManager { int64_t GetVarMemSize(rtMemType_t memory_type); + Status UpdateVarMemSize(rtMemType_t memory_type, int64_t mem_size); + bool IsVarExist(const std::string &var_name, const ge::GeTensorDesc &tensor_desc); bool IsVarExist(const std::string &var_name); diff --git a/src/ge/graph/manager/util/debug.cc b/src/ge/graph/manager/util/debug.cc index 3ca4642d..b2ef1c92 100644 --- a/src/ge/graph/manager/util/debug.cc +++ b/src/ge/graph/manager/util/debug.cc @@ -31,7 +31,7 @@ Debug::Debug() = default; Debug::~Debug() = default; void Debug::DumpProto(const Message &proto, const char *file) { - std::string file_path = domi::RealPath(file); + std::string file_path = RealPath(file); int fd = open(file_path.c_str(), O_WRONLY | O_CREAT | O_TRUNC, S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH); if (fd == -1) { GELOGW("Write %s failed", file_path.c_str()); diff --git a/src/ge/graph/manager/util/hcom_util.cc b/src/ge/graph/manager/util/hcom_util.cc index 6319f985..a1c4d769 100644 --- 
a/src/ge/graph/manager/util/hcom_util.cc +++ b/src/ge/graph/manager/util/hcom_util.cc @@ -23,14 +23,6 @@ #include "graph/utils/tensor_utils.h" #include "graph/utils/type_utils.h" -using domi::HCOM_ATTR_DATA_TYPE; -using domi::HCOM_ATTR_RANK_SIZE; -using domi::HCOM_ATTR_REDUCE_TYPE; -using domi::HCOM_ATTR_ROOT_RANK; -using domi::HCOM_ATTR_SHAPE; -using domi::HCOMRECEIVE; -using domi::HCOMREDUCESCATTER; - namespace ge { Status HcomOmeUtil::GetHcomDataType(const ge::ConstOpDescPtr &op_desc, hcclDataType_t &data_type) { GE_CHECK_NOTNULL(op_desc); diff --git a/src/ge/graph/manager/util/variable_accelerate_ctrl.cc b/src/ge/graph/manager/util/variable_accelerate_ctrl.cc index 3bc61e84..b62be02c 100644 --- a/src/ge/graph/manager/util/variable_accelerate_ctrl.cc +++ b/src/ge/graph/manager/util/variable_accelerate_ctrl.cc @@ -23,7 +23,7 @@ namespace ge { namespace { inline bool IsVariable(const std::string &node_type) { - return node_type == domi::VARIABLE || node_type == domi::VARIABLEV2 || node_type == domi::VARHANDLEOP; + return node_type == VARIABLE || node_type == VARIABLEV2 || node_type == VARHANDLEOP; } } // namespace diff --git a/src/ge/graph/optimize/common/params.h b/src/ge/graph/optimize/common/params.h index 403e1aa8..ee2a735b 100644 --- a/src/ge/graph/optimize/common/params.h +++ b/src/ge/graph/optimize/common/params.h @@ -22,10 +22,6 @@ #include "common/singleton.h" #include "common/types.h" -using domi::TARGET_TYPE_LTTE_8BIT; -using domi::TARGET_TYPE_MINI_8BIT; -using domi::TARGET_TYPE_TINY_8BIT; - namespace ge { class Params : public Singleton { public: diff --git a/src/ge/graph/optimize/graph_optimize.cc b/src/ge/graph/optimize/graph_optimize.cc index f1fe27b9..84cc77f9 100644 --- a/src/ge/graph/optimize/graph_optimize.cc +++ b/src/ge/graph/optimize/graph_optimize.cc @@ -26,11 +26,6 @@ #include "init/gelib.h" #include "opskernel_manager/ops_kernel_manager.h" -using domi::ATTR_NAME_FRAMEWORK_FWK_TYPE; -using domi::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE; -using 
ge::ComputeGraph; -using ge::OpDesc; - namespace { const char *const kVectorCore = "VectorCore"; const char *const kVectorEngine = "VectorEngine"; @@ -73,15 +68,14 @@ void AddNodeInputProperty(ComputeGraphPtr &compute_graph) { peer_out_anchor == nullptr, GELOGW("peer_out_anchor is nullptr! node: %s", node->GetName().c_str()); continue); ge::NodePtr src_node = peer_out_anchor->GetOwnerNode(); - src_name_list = node_op_desc->GetSrcName(); src_index_list = node_op_desc->GetSrcIndex(); src_name_list.emplace_back(src_node->GetName()); src_index_list.emplace_back(peer_out_anchor->GetIdx()); node_op_desc->SetSrcName(src_name_list); node_op_desc->SetSrcIndex(src_index_list); - GE_IF_BOOL_EXEC(!(node_op_desc->GetType() == domi::NETOUTPUT && domi::GetContext().type == domi::FMK_TYPE_T), + GE_IF_BOOL_EXEC(!(node_op_desc->GetType() == NETOUTPUT && domi::GetContext().type == domi::FMK_TYPE_T), ge::NodePtr peer_owner_node = peer_out_anchor->GetOwnerNode(); - input_name_list = node_op_desc->GetInputName(); input_name_list.emplace_back( + input_name_list.emplace_back( peer_owner_node->GetName() + (peer_out_anchor->GetIdx() == 0 ? 
"" : ": " + to_string(peer_out_anchor->GetIdx()))); node_op_desc->SetInputName(input_name_list);) @@ -160,6 +154,47 @@ Status GraphOptimize::OptimizeOriginalGraph(ComputeGraphPtr &compute_graph) { return ret; } +Status GraphOptimize::NewOptimizeOriginalGraph(ComputeGraphPtr &compute_graph) { + GELOGD("NewOptimizeOriginalGraph in"); + if (compute_graph == nullptr) { + GELOGE(GE_GRAPH_OPTIMIZE_COMPUTE_GRAPH_NULL, "[OptimizeOriginalGraph]: compute_graph is nullptr."); + return GE_GRAPH_OPTIMIZE_COMPUTE_GRAPH_NULL; + } + + Status ret = SUCCESS; + std::shared_ptr instance_ptr = ge::GELib::GetInstance(); + if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { + GELOGE(GE_CLI_GE_NOT_INITIALIZED, "OptimizeOriginalGraph failed."); + return GE_CLI_GE_NOT_INITIALIZED; + } + + std::map graph_optimizer = instance_ptr->OpsKernelManagerObj().GetAllGraphOptimizerObjs(); + GELOGI("optimize by opskernel in original graph optimize phase. num of graph_optimizer is %lu.", + graph_optimizer.size()); + string exclude_core_Type = (core_type_ == kVectorCore) ? 
kAicoreEngine : kVectorEngine; + GELOGD("[OptimizeOriginalGraph]: engine type will exclude: %s", exclude_core_Type.c_str()); + if (graph_optimizer.size() != 0) { + for (auto iter = graph_optimizer.begin(); iter != graph_optimizer.end(); ++iter) { + if (iter->first == exclude_core_Type) { + continue; + } + ret = (iter->second)->OptimizeOriginalGraph(*compute_graph); + if (ret != SUCCESS) { + GELOGE(ret, "[OptimizeOriginalGraph]: graph optimize failed, ret:%d", ret); + return ret; + } + + // call fe + ret = (iter->second)->OptimizeOriginalGraphJudgeInsert(*compute_graph); + if (ret != SUCCESS) { + GELOGE(ret, "[OptimizeOriginalGraphForInsert]: graph optimize failed, ret:%d", ret); + return ret; + } + } + } + return ret; +} + Status GraphOptimize::OptimizeOriginalGraphForQuantize(ComputeGraphPtr &compute_graph) { if (compute_graph == nullptr) { GELOGE(GE_GRAPH_OPTIMIZE_COMPUTE_GRAPH_NULL, "[OptimizeOriginalGraph]: compute_graph is nullptr."); diff --git a/src/ge/graph/optimize/graph_optimize.h b/src/ge/graph/optimize/graph_optimize.h index ceb50e3f..83b5489f 100644 --- a/src/ge/graph/optimize/graph_optimize.h +++ b/src/ge/graph/optimize/graph_optimize.h @@ -47,6 +47,9 @@ class GraphOptimize { // original graph optimize Status OptimizeOriginalGraph(ComputeGraphPtr &compute_graph); + // new original graph optimize + Status NewOptimizeOriginalGraph(ComputeGraphPtr &compute_graph); + // for fe prepare optimize in quantize scene Status OptimizeOriginalGraphForQuantize(ComputeGraphPtr &compute_graph); diff --git a/src/ge/graph/optimize/summary_optimize.cc b/src/ge/graph/optimize/summary_optimize.cc index 3347f042..8b38d602 100644 --- a/src/ge/graph/optimize/summary_optimize.cc +++ b/src/ge/graph/optimize/summary_optimize.cc @@ -23,9 +23,12 @@ #include "graph/utils/tensor_utils.h" #include "omg/omg_inner_types.h" +namespace { +const char *const kSummary = "Summary"; +const int kMaxMapSize = 10000; +} // namespace + namespace ge { -static const char *const kSummary = 
"Summary"; -static const int kMaxMapSize = 10000; Status GraphOptimize::HandleSummaryOp(ComputeGraphPtr &compute_graph) { GELOGI("[HandleSummaryOp] HandleSummaryOp start!"); if (summary_output_indexes_.size() >= kMaxMapSize) { diff --git a/src/ge/graph/partition/dynamic_shape_partition.cc b/src/ge/graph/partition/dynamic_shape_partition.cc new file mode 100644 index 00000000..bbf31e5c --- /dev/null +++ b/src/ge/graph/partition/dynamic_shape_partition.cc @@ -0,0 +1,789 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/partition/dynamic_shape_partition.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include "common/ge/ge_util.h" +#include "framework/common/debug/ge_log.h" +#include "framework/common/types.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/op_desc_utils.h" + +#define REQUIRE(cond, ...) \ + do { \ + if (!(cond)) { \ + GELOGE(FAILED, "[Dynamic shape partition]" __VA_ARGS__); \ + return FAILED; \ + } \ + } while (0) + +#define REQUIRE_NOT_NULL(cond, ...) REQUIRE(((cond) != nullptr), __VA_ARGS__) +#define REQUIRE_SUCCESS(cond, ...) REQUIRE(((cond) == SUCCESS), __VA_ARGS__) +#define REQUIRE_GRAPH_SUCCESS(cond, ...) 
REQUIRE(((cond) == GRAPH_SUCCESS), __VA_ARGS__) + +namespace { +const bool kDebugging = (std::getenv("DEBUG_DYNAMIC_PARTITION") != nullptr); +} // namespace + +#define DLOG() \ + if (kDebugging) std::cerr +namespace ge { +using Cluster = DynamicShapePartitioner::Cluster; +using ClusterPtr = std::shared_ptr; + +Status DynamicShapePartitioner::Partition() { + REQUIRE_NOT_NULL(root_graph_, "Graph is nullptr."); + DLOG() << "Start dynamic shape partition graph " << root_graph_->GetName() << std::endl; + REQUIRE_SUCCESS(MarkUnknowShapeNodes(), "Failed mark unknow shape nodes."); + if (unknown_shape_nodes_.empty()) { + DLOG() << "Skip dynamic shape partition of graph " << root_graph_->GetName() << " as all nodes are known shape." + << std::endl; + REQUIRE(AttrUtils::SetBool(*root_graph_, "_dynamic_shape_partitioned", false), + "Failed set dynamic shape partitioned flag on root graph."); + return SUCCESS; + } + REQUIRE(AttrUtils::SetBool(*root_graph_, "_dynamic_shape_partitioned", true), + "Failed set dynamic shape partitioned flag on root graph."); + DumpGraph("_Before_DSP"); + auto status = PartitionImpl(); + DLOG() << DebugString() << std::endl; + if (status != SUCCESS) { + GELOGE(status, "Failed dynamic shape partition graph: %s, status:\n %s", root_graph_->GetName().c_str(), + DebugString().c_str()); + } + DumpGraph("_After_DSP"); + DLOG() << (status == SUCCESS ? 
"Succeed" : "Failed") << " dynamic shape partition graph " << root_graph_->GetName() + << std::endl; + ClearResource(); + return status; +} + +Status DynamicShapePartitioner::PartitionImpl() { + REQUIRE_SUCCESS(root_graph_->TopologicalSorting(), "Graph topological sort failed."); + REQUIRE_SUCCESS(InitClusters(), "Failed init cluster nodes."); + REQUIRE_SUCCESS(MergeClusters(), "Failed merge clusters."); + PruneUniqueClusters(); + REQUIRE_SUCCESS(BuildPartitionFrame(), "Failed build cluster partition frame."); + REQUIRE_SUCCESS(CombinePartitionFrame(), "Failed combine cluster partition frame."); + REQUIRE_SUCCESS(BuildPartitionSubgraph(), "Failed build cluster partition subgraph."); + return SUCCESS; +} + +void DynamicShapePartitioner::PruneUniqueClusters() { + for (auto &node : root_graph_->GetDirectNode()) { + auto cluster = node_2_cluster_[node]; + if (unique_clusters_.count(cluster) != 0) { + continue; + } + unique_clusters_.insert(cluster); + } +} + +Status DynamicShapePartitioner::BuildPartitionFrame() { + for (auto cluster : unique_clusters_) { + REQUIRE_SUCCESS(cluster->BuildFrame(), "Failed build frame of cluster[%lu].", cluster->Id()); + } + return SUCCESS; +} + +Status DynamicShapePartitioner::CombinePartitionFrame() { + for (auto cluster : unique_clusters_) { + REQUIRE_SUCCESS(cluster->CombinePartitionFrame(), "Failed combine frame of cluster[%lu].", cluster->Id()); + } + return SUCCESS; +} + +Status DynamicShapePartitioner::BuildPartitionSubgraph() { + for (auto cluster : unique_clusters_) { + REQUIRE_SUCCESS(cluster->BuildPartitionSubgraph(), "Failed build subgraph of cluster[%lu].", cluster->Id()); + } + return SUCCESS; +} + +std::string DynamicShapePartitioner::DebugString() { + size_t unknow = 0; + size_t know = 0; + size_t data = 0; + size_t netoutput = 0; + std::stringstream ss; + ss << "All unknow shape nodes:" << std::endl; + for (auto node : unknown_shape_nodes_) { + ss << " [" << node->GetName() << "](" << node->GetType() << ")" << std::endl; 
+ } + for (auto cluster : unique_clusters_) { + if (cluster->IsUnknowShape()) { + unknow++; + } else if (cluster->IsKnowShape()) { + know++; + } else if (cluster->IsData()) { + data++; + } else if (cluster->IsNetOutput()) { + netoutput++; + } + } + ss << "All clusters:" << unique_clusters_.size() << ", data:" << data << ", know:" << know << ", unknow:" << unknow + << ", netoutput:" << netoutput << std::endl; + for (auto cluster : unique_clusters_) { + ss << " " << cluster->DebugString() << std::endl; + } + return ss.str(); +} + +void DynamicShapePartitioner::DumpGraph(std::string suffix) { + GraphUtils::DumpGEGraphToOnnx(*root_graph_, root_graph_->GetName() + suffix); + for (auto sub_graph : root_graph_->GetAllSubgraphs()) { + GraphUtils::DumpGEGraphToOnnx(*sub_graph, sub_graph->GetName() + suffix); + } +} + +void DynamicShapePartitioner::ClearResource() { + for (auto cluster : unique_clusters_) { + cluster->Clear(); + } + node_2_cluster_.clear(); + ordered_cluster_.clear(); + unique_clusters_.clear(); + unknown_shape_nodes_.clear(); + root_graph_.reset(); +} + +Status DynamicShapePartitioner::MarkUnknowShapeNodes() { + auto graph = root_graph_; + for (auto &node : graph->GetDirectNode()) { + REQUIRE_SUCCESS(CollectSpreadUnknowShapeNodes(node), "Failed collect spread unknow shape nodes %s.", + node->GetName().c_str()); + } + return SUCCESS; +} + +Status DynamicShapePartitioner::InitClusters() { + auto graph = root_graph_; + size_t rank = 0; + for (const auto node : graph->GetDirectNode()) { + Cluster::Type type = Cluster::DATA; + if (node->GetType() == DATA) { + type = Cluster::DATA; + } else if (node->GetType() == NETOUTPUT) { + type = Cluster::NETOUTPUT; + } else if (unknown_shape_nodes_.count(node) > 0) { + type = Cluster::UNKNOW_SHAPE; + } else { + type = Cluster::KNOW_SHAPE; + } + auto cluster = MakeShared(rank++, type, node, this); + REQUIRE_NOT_NULL(cluster, "Failed new memory for cluster."); + node_2_cluster_[node] = cluster; + if (cluster->IsUnknowShape()) 
{ + ordered_cluster_.push_back(cluster); + } + // Already sorted topologically, so access to the parent cluster is safe + for (const auto &parent : node->GetInAllNodes()) { + cluster->AddInput(node_2_cluster_[parent]); + } + } + if (kDebugging) { + for (const auto node : graph->GetDirectNode()) { + DLOG() << "Make cluster for node :" << node->GetName() << ":" << node_2_cluster_[node]->DebugString() + << std::endl; + } + } + return SUCCESS; +} + +Status DynamicShapePartitioner::TopologicalSortClusters() { + ordered_cluster_.clear(); + // BFS topological sort clusters for know shape cluster + std::queue ready_clusters; + std::unordered_map cluster_pending_count; + std::unordered_set seen_clusters; + for (auto iter = node_2_cluster_.begin(); iter != node_2_cluster_.end(); iter++) { + auto cluster = iter->second; + if (seen_clusters.count(cluster) != 0) { + continue; + } + seen_clusters.insert(cluster); + auto pending_count = cluster->Inputs().size(); + if (pending_count == 0) { + ready_clusters.push(cluster); + } else { + cluster_pending_count[cluster] = pending_count; + } + } + size_t rank = 0; + while (!ready_clusters.empty()) { + auto cluster = ready_clusters.front(); + ready_clusters.pop(); + cluster->UpdateRank(rank++); + if (cluster->IsKnowShape()) { + ordered_cluster_.push_back(cluster); + } + for (auto out_cluster : cluster->Outputs()) { + if (--cluster_pending_count[out_cluster] == 0) { + ready_clusters.push(out_cluster); + } + } + } + if (rank != seen_clusters.size()) { + return FAILED; + } + return SUCCESS; +} + +namespace { +template +static std::string ToString(T vec) { + if (vec.empty()) { + return "()"; + } + std::stringstream ss; + ss << "("; + auto iter = vec.begin(); + for (size_t i = 0; i < vec.size() - 1; i++) { + ss << (*iter++)->Id() << ","; + } + ss << (*iter++)->Id() << ")."; + return ss.str(); +} +} // namespace + +Status DynamicShapePartitioner::MergeClusters() { + // Merge unknow shape clusters + for (auto cluster : ordered_cluster_) { + for 
(auto in_cluster : cluster->Inputs()) { + if (in_cluster->IsUnknowShape()) { + auto merged_clusters = cluster->MergeAllPathFrom(in_cluster); + DLOG() << "Merge all path cluster from " << in_cluster->Id() << " to " << cluster->Id() + << ToString(merged_clusters) << std::endl; + for (auto merged_cluster : merged_clusters) { + for (auto node : merged_cluster->Nodes()) { + node_2_cluster_[node] = cluster; + } + } + } + } + } + REQUIRE_SUCCESS(TopologicalSortClusters(), "Failed topological sort clusters after merge unknow shape clusters."); + // Merge know shape clusters + for (auto cluster : ordered_cluster_) { + for (auto in_cluster : cluster->Inputs()) { + if (in_cluster->IsKnowShape()) { + if (cluster->TryMerge(in_cluster)) { + DLOG() << "Success merge known shape cluster " << in_cluster->Id() << " to " << cluster->Id() << "." + << std::endl; + for (auto node : in_cluster->Nodes()) { + node_2_cluster_[node] = cluster; + } + } + } + } + } + return SUCCESS; +} + +Status DynamicShapePartitioner::CollectSpreadUnknowShapeNodes(NodePtr node) { + if (unknown_shape_nodes_.count(node) > 0) { + return SUCCESS; + } + auto opdesc = node->GetOpDesc(); + size_t anchor_index = 0; + bool is_unknow = false; + for (auto &out_tensor : opdesc->GetAllOutputsDesc()) { + if (IsUnknowShapeTensor(out_tensor)) { + DLOG() << "Collect node " << node->GetName() << " as unknown as output " << anchor_index << " is unknown" + << std::endl; + is_unknow = true; + auto anchor = node->GetOutDataAnchor(anchor_index); + for (const auto peer_anchor : anchor->GetPeerInDataAnchors()) { + if (peer_anchor != nullptr) { + DLOG() << "Collect node " << peer_anchor->GetOwnerNode()->GetName() << " as has unknown input from " + << node->GetName() << ":" << anchor_index << std::endl; + unknown_shape_nodes_.insert(peer_anchor->GetOwnerNode()); + } + } + } + anchor_index++; + } + anchor_index = 0; + for (auto &in_tensor : opdesc->GetAllInputsDesc()) { + if (IsUnknowShapeTensor(in_tensor)) { + DLOG() << "Collect node 
" << node->GetName() << " as unknown as input " << anchor_index << " is unknown" + << std::endl; + is_unknow = true; + auto anchor = node->GetInDataAnchor(anchor_index); + const auto peer_anchor = anchor->GetPeerOutAnchor(); + if (peer_anchor != nullptr) { + DLOG() << "Collect node " << peer_anchor->GetOwnerNode()->GetName() << " as has unknown output to " + << node->GetName() << ":" << anchor_index << std::endl; + unknown_shape_nodes_.insert(peer_anchor->GetOwnerNode()); + } + } + anchor_index++; + } + if (is_unknow) { + unknown_shape_nodes_.insert(node); + } else { + auto graph = root_graph_; + for (const auto &subgraph_name : opdesc->GetSubgraphInstanceNames()) { + auto subgraph = graph->GetSubgraph(subgraph_name); + REQUIRE_NOT_NULL(subgraph, "Failed get subgraph %s of node %s on root graph.", subgraph_name.c_str(), + node->GetName().c_str()); + bool is_graph_unknow = false; + REQUIRE_SUCCESS(IsUnknowShapeGraph(subgraph, is_graph_unknow), "Failed check subgraph %s shape of node %s.", + subgraph_name.c_str(), node->GetName().c_str()); + if (is_graph_unknow) { + DLOG() << "Collect node " << node->GetName() << " as its subgraph " << subgraph->GetName() << " is unknown." 
+ << std::endl; + unknown_shape_nodes_.insert(node); + break; + } + } + } + return SUCCESS; +} + +Status DynamicShapePartitioner::IsUnknowShapeNode(NodePtr node, bool &is_unknow) { + auto opdesc = node->GetOpDesc(); + auto graph = root_graph_; + for (auto &out_tensor : opdesc->GetAllOutputsDesc()) { + if (IsUnknowShapeTensor(out_tensor)) { + DLOG() << "Mark node " << node->GetName() << " unknown because unknown output " << std::endl; + is_unknow = true; + return SUCCESS; + } + } + for (auto &in_tensor : opdesc->GetAllInputsDesc()) { + if (IsUnknowShapeTensor(in_tensor)) { + DLOG() << "Mark node " << node->GetName() << " unknown because unknown intput " << std::endl; + is_unknow = true; + return SUCCESS; + } + } + for (auto &subgraph_name : opdesc->GetSubgraphInstanceNames()) { + auto subgraph = graph->GetSubgraph(subgraph_name); + REQUIRE_NOT_NULL(subgraph, "Failed get subgraph %s of node %s on root graph.", subgraph_name.c_str(), + node->GetName().c_str()); + REQUIRE_SUCCESS(IsUnknowShapeGraph(subgraph, is_unknow), "Failed check subgraph %s shape of node %s.", + subgraph_name.c_str(), node->GetName().c_str()); + if (is_unknow) { + DLOG() << "Mark node " << node->GetName() << " unknown because unknown subgraph " << std::endl; + return SUCCESS; + } + } + is_unknow = false; + return SUCCESS; +} + +Status DynamicShapePartitioner::IsUnknowShapeGraph(ComputeGraphPtr graph, bool &is_unknow) { + for (auto &node : graph->GetDirectNode()) { + REQUIRE_SUCCESS(IsUnknowShapeNode(node, is_unknow), "Failed check node %s shape on graph %s.", + node->GetName().c_str(), graph->GetName().c_str()); + if (is_unknow) { + DLOG() << "Mark graph " << graph->GetName() << " unknown because unknown node " << node->GetName() << std::endl; + return SUCCESS; + } + } + return SUCCESS; +} + +bool DynamicShapePartitioner::IsUnknowShapeTensor(GeTensorDesc &tensor) { + const static int kUnknowShape = -1; + const static int kUnknowRank = -2; + for (auto dim_size : tensor.GetShape().GetDims()) { + if 
(dim_size == kUnknowShape || dim_size == kUnknowRank) { + return true; + } + } + return false; +} + +std::string Cluster::DebugString() { + std::stringstream ss; + switch (type_) { + case DATA: + ss << "DATA"; + break; + case NETOUTPUT: + ss << "NETOUTPUT"; + break; + case UNKNOW_SHAPE: + ss << "UNKNOW"; + break; + case KNOW_SHAPE: + ss << "KNOW"; + break; + } + ss << "[" << id_ << "](size:" << nodes_.size() << ")"; + ss << "(" << min_ << "," << max_ << ")("; + for (auto cluster : in_clusters_) { + ss << cluster->id_ << ","; + } + ss << ")->("; + for (auto cluster : out_clusters_) { + ss << cluster->id_ << ","; + } + ss << ")|"; + for (auto node : nodes_) { + ss << (node->GetName() + "|"); + } + return ss.str(); +} + +size_t Cluster::Id() { return id_; } +void Cluster::UpdateRank(size_t rank) { + max_ = rank; + min_ = rank; +}; +bool Cluster::IsData() { return type_ == DATA; }; +bool Cluster::IsKnowShape() { return type_ == KNOW_SHAPE; }; +bool Cluster::IsUnknowShape() { return type_ == UNKNOW_SHAPE; }; +bool Cluster::IsNetOutput() { return type_ == NETOUTPUT; }; +bool Cluster::IsolatedConstant() { + return ((nodes_.size() == 1) && (nodes_[0]->GetType() == CONSTANTOP) && (out_clusters_.size() == 1) && + (*out_clusters_.begin())->IsUnknowShape() && in_clusters_.empty()); +} +void Cluster::AddInput(ClusterPtr in) { + in_clusters_.insert(in); + in->out_clusters_.insert(shared_from_this()); +}; +void Cluster::RemoveInput(ClusterPtr in) { + in_clusters_.erase(in); + in->out_clusters_.erase(shared_from_this()); +}; +void Cluster::AddOutput(ClusterPtr out) { + out_clusters_.insert(out); + out->in_clusters_.insert(shared_from_this()); +}; +void Cluster::RemoveOutput(ClusterPtr out) { + out_clusters_.erase(out); + out->in_clusters_.erase(shared_from_this()); +}; +void Cluster::Merge(ClusterPtr other) { + nodes_.insert(nodes_.end(), other->nodes_.begin(), other->nodes_.end()); + other->in_clusters_.erase(shared_from_this()); + other->out_clusters_.erase(shared_from_this()); 
+ in_clusters_.erase(other); + out_clusters_.erase(other); + auto in_clusters = other->in_clusters_; + for (auto cluster : in_clusters) { + cluster->RemoveOutput(other); + cluster->AddOutput(shared_from_this()); + } + auto out_clusters = other->out_clusters_; + for (auto cluster : out_clusters) { + cluster->RemoveInput(other); + cluster->AddInput(shared_from_this()); + } + if (other->max_ > max_) { + max_ = other->max_; + } + if (other->min_ < min_) { + min_ = other->min_; + } +}; +bool Cluster::TryMerge(ClusterPtr other) { + std::queue forward_reached; + forward_reached.push(other); + while (!forward_reached.empty()) { + auto current_cluster = forward_reached.front(); + forward_reached.pop(); + for (auto cluster : current_cluster->out_clusters_) { + if (cluster->max_ == max_ && current_cluster != other) { + return false; + } else if (cluster->min_ < max_) { + forward_reached.push(cluster); + } + } + } + Merge(other); + return true; +}; +std::vector Cluster::MergeAllPathFrom(ClusterPtr other) { + std::queue forward_reached_queue; + std::queue backward_reached_queue; + + std::unordered_set forward_reached_clusters; + std::unordered_set backward_reached_clusters; + std::vector path_clusters; + + if (other->out_clusters_.count(shared_from_this()) == 0) { + return path_clusters; + } + path_clusters.push_back(other); + forward_reached_queue.push(other); + backward_reached_queue.push(shared_from_this()); + while (!forward_reached_queue.empty()) { + auto current_cluster = forward_reached_queue.front(); + forward_reached_queue.pop(); + for (auto cluster : current_cluster->out_clusters_) { + if (cluster->min_ < max_ && cluster->max_ != max_ && forward_reached_clusters.count(cluster) == 0) { + forward_reached_clusters.insert(cluster); + forward_reached_queue.push(cluster); + } + } + } + while (!backward_reached_queue.empty()) { + auto current_cluster = backward_reached_queue.front(); + backward_reached_queue.pop(); + for (auto cluster : current_cluster->in_clusters_) { + if 
(cluster->max_ > other->min_ && cluster->max_ != other->max_ && + backward_reached_clusters.count(cluster) == 0) { + backward_reached_clusters.insert(cluster); + backward_reached_queue.push(cluster); + if (forward_reached_clusters.count(cluster) != 0) { + path_clusters.push_back(cluster); + } + } + } + } + for (auto cluster : path_clusters) { + Merge(cluster); + } + return path_clusters; +} +std::unordered_set Cluster::Inputs() { return in_clusters_; }; +std::unordered_set Cluster::Outputs() { return out_clusters_; }; +std::vector Cluster::Nodes() { return nodes_; }; + +void Cluster::AddFrameInput(InDataAnchorPtr anchor) { + inputs_index_[anchor] = inputs_.size(); + inputs_.push_back(anchor); +}; + +void Cluster::AddFrameOutput(OutDataAnchorPtr anchor) { + outputs_index_[anchor] = outputs_.size(); + outputs_.push_back(anchor); +}; + +InDataAnchorPtr Cluster::GetFrameInDataAnchor(InDataAnchorPtr anchor) { + return partition_node_->GetInDataAnchor(inputs_index_[anchor]); +}; + +OutDataAnchorPtr Cluster::GetFrameOutDataAnchor(OutDataAnchorPtr anchor) { + return partition_node_->GetOutDataAnchor(outputs_index_[anchor]); +}; + +InControlAnchorPtr Cluster::GetFrameInControlAnchor() { return partition_node_->GetInControlAnchor(); }; + +OutControlAnchorPtr Cluster::GetFrameOutControlAnchor() { return partition_node_->GetOutControlAnchor(); }; + +Status Cluster::BuildFrame() { + if (IsUnknowShape() || IsKnowShape()) { + return BuildPartitionFrame(); + } else { + auto node = nodes_.front(); + auto in_control_anchor = node->GetInControlAnchor(); + if (in_control_anchor != nullptr) { + for (auto peer_out_control_anchor : in_control_anchor->GetPeerOutControlAnchors()) { + auto src_cluster = partitioner_->node_2_cluster_[peer_out_control_anchor->GetOwnerNode()]; + if (src_cluster->id_ != id_) { + auto src_cluster = partitioner_->node_2_cluster_[peer_out_control_anchor->GetOwnerNode()]; + GraphUtils::RemoveEdge(peer_out_control_anchor, in_control_anchor); + 
control_inputs_.insert(src_cluster); + src_cluster->control_outputs_.insert(peer_out_control_anchor); + } + } + } + if (IsData()) { + for (auto anchor : node->GetAllOutDataAnchors()) { + AddFrameOutput(anchor); + } + } else { + for (auto anchor : node->GetAllInDataAnchors()) { + AddFrameInput(anchor); + } + } + partition_node_ = node; + } + return SUCCESS; +} + +Status Cluster::BuildPartitionFrame() { + auto graph = partitioner_->root_graph_; + bool is_unknown_shape = IsUnknowShape(); + std::string sub_graph_name = + graph->GetName() + "_sub_" + std::to_string(unique_id_) + (is_unknown_shape ? "_unknow" : "_know"); + subgraph_ = MakeShared(sub_graph_name); + REQUIRE_NOT_NULL(subgraph_, "Failed new memory for subgraph."); + auto partition_op = MakeShared("PartitionedCall_" + std::to_string(unique_id_++), "PartitionedCall"); + REQUIRE_NOT_NULL(partition_op, "Failed new memory for partition op."); + REQUIRE(AttrUtils::SetBool(partition_op, ATTR_NAME_IS_UNKNOWN_SHAPE, is_unknown_shape), + "Failed set _is_unknown_shape flag on partitioned op %s.", partition_op->GetName().c_str()); + REQUIRE_GRAPH_SUCCESS(partition_op->AddSubgraphName(subgraph_->GetName()), "Failed add subgraph name."); + REQUIRE_GRAPH_SUCCESS(partition_op->SetSubgraphInstanceName(0, subgraph_->GetName()), + "Failed set subgraph instance name."); + for (auto &node : nodes_) { + REQUIRE_NOT_NULL(subgraph_->AddNode(node), "Failed add node to subgraph."); + REQUIRE(AttrUtils::SetBool(node->GetOpDesc(), ATTR_NAME_IS_UNKNOWN_SHAPE, is_unknown_shape), + "Failed set shape flag."); + REQUIRE_GRAPH_SUCCESS(GraphUtils::RemoveJustNode(graph, node), "Failed remove root graph node."); + REQUIRE_GRAPH_SUCCESS(node->SetOwnerComputeGraph(subgraph_), "Failed set owner graph."); + for (auto anchor : node->GetAllInDataAnchors()) { + auto peer_out_anchor = anchor->GetPeerOutAnchor(); + if (peer_out_anchor == nullptr) { + continue; // Skip overhang input. 
+ } + auto src_cluster = partitioner_->node_2_cluster_[peer_out_anchor->GetOwnerNode()]; + if (src_cluster->id_ != id_) { + AddFrameInput(anchor); + REQUIRE_GRAPH_SUCCESS(partition_op->AddInputDesc(node->GetOpDesc()->GetInputDesc(anchor->GetIdx())), + "Failed add input desc."); + } + } + auto in_control_anchor = node->GetInControlAnchor(); + if (in_control_anchor != nullptr) { + for (auto peer_out_control_anchor : in_control_anchor->GetPeerOutControlAnchors()) { + if (peer_out_control_anchor == nullptr) { + continue; + } + auto src_cluster = partitioner_->node_2_cluster_[peer_out_control_anchor->GetOwnerNode()]; + if (src_cluster->id_ != id_) { + REQUIRE_GRAPH_SUCCESS( + GraphUtils::RemoveEdge(peer_out_control_anchor, in_control_anchor), + "Failed remove edge from %s:%d to %s:%d.", peer_out_control_anchor->GetOwnerNode()->GetName().c_str(), + peer_out_control_anchor->GetIdx(), node->GetName().c_str(), in_control_anchor->GetIdx()); + control_inputs_.insert(src_cluster); + src_cluster->control_outputs_.insert(peer_out_control_anchor); + } + } + } + for (auto anchor : node->GetAllOutDataAnchors()) { + auto peer_in_anchors = anchor->GetPeerInDataAnchors(); + for (auto peer_in_anchor : peer_in_anchors) { + auto src_cluster = partitioner_->node_2_cluster_[peer_in_anchor->GetOwnerNode()]; + if (src_cluster->id_ != id_) { + AddFrameOutput(anchor); + REQUIRE_GRAPH_SUCCESS(partition_op->AddOutputDesc(node->GetOpDesc()->GetOutputDesc(anchor->GetIdx())), + "Failed add output desc."); + break; + } + } + } + } + partition_node_ = graph->AddNode(partition_op); + REQUIRE_NOT_NULL(partition_node_, "Failed add partition node."); + REQUIRE_GRAPH_SUCCESS(partition_node_->SetOwnerComputeGraph(graph), "Failed set owner graph."); + subgraph_->SetParentNode(partition_node_); + subgraph_->SetParentGraph(graph); + REQUIRE_GRAPH_SUCCESS(graph->AddSubgraph(subgraph_), "Failed add subgraph to root graph."); + std::string session_graph_id; + REQUIRE(AttrUtils::GetStr(*graph, 
ATTR_NAME_SESSION_GRAPH_ID, session_graph_id), + "Failed get ATTR_NAME_SESSION_GRAPH_ID on root graph."); + REQUIRE(AttrUtils::SetStr(*subgraph_, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id), + "Failed set ATTR_NAME_SESSION_GRAPH_ID on subgraph."); + return SUCCESS; +} + +Status Cluster::CombinePartitionFrame() { + for (auto anchor : inputs_) { + auto peer_out_anchor = anchor->GetPeerOutAnchor(); + auto src_cluster = partitioner_->node_2_cluster_[peer_out_anchor->GetOwnerNode()]; + auto src_anchor = src_cluster->GetFrameOutDataAnchor(peer_out_anchor); + auto dst_anchor = GetFrameInDataAnchor(anchor); + REQUIRE_GRAPH_SUCCESS(GraphUtils::RemoveEdge(peer_out_anchor, anchor), "Failed remove edge from %s:%d to %s:%d.", + peer_out_anchor->GetOwnerNode()->GetName().c_str(), peer_out_anchor->GetIdx(), + anchor->GetOwnerNode()->GetName().c_str(), anchor->GetIdx()); + REQUIRE_GRAPH_SUCCESS(GraphUtils::AddEdge(src_anchor, dst_anchor), "Failed add edge from %s:%d to %s:%d.", + src_anchor->GetOwnerNode()->GetName().c_str(), src_anchor->GetIdx(), + dst_anchor->GetOwnerNode()->GetName().c_str(), dst_anchor->GetIdx()); + } + for (auto src_cluster : control_inputs_) { + auto src_anchor = src_cluster->GetFrameOutControlAnchor(); + auto dst_anchor = GetFrameInControlAnchor(); + REQUIRE_GRAPH_SUCCESS(GraphUtils::AddEdge(src_anchor, dst_anchor), "Failed add edge from %s:%d to %s:%d.", + src_anchor->GetOwnerNode()->GetName().c_str(), src_anchor->GetIdx(), + dst_anchor->GetOwnerNode()->GetName().c_str(), dst_anchor->GetIdx()); + } + return SUCCESS; +} + +Status Cluster::BuildPartitionSubgraph() { + if (IsData() || IsNetOutput()) { + return SUCCESS; + } + int64_t parent_node_index = 0; + for (auto anchor : inputs_) { + auto data_op = MakeShared(std::string("Data_") + std::to_string(parent_node_index), ge::DATA); + REQUIRE_NOT_NULL(data_op, "Failed new memory for data op."); + auto input_desc = anchor->GetOwnerNode()->GetOpDesc()->GetInputDesc(anchor->GetIdx()); + 
REQUIRE_GRAPH_SUCCESS(data_op->AddOutputDesc(input_desc), "Failed add output desc."); + REQUIRE(AttrUtils::SetInt(data_op, ATTR_NAME_PARENT_NODE_INDEX, parent_node_index), + "Failed set parent_node_index on subgraph data node."); + auto data_node = subgraph_->AddNode(data_op); + REQUIRE_NOT_NULL(data_node, "Failed add data node to subgraph."); + REQUIRE_GRAPH_SUCCESS(data_node->SetOwnerComputeGraph(subgraph_), "Failed set owner graph of data node."); + REQUIRE_GRAPH_SUCCESS(GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), anchor), + "Faile add data input edge to %s:%d", anchor->GetOwnerNode()->GetName().c_str(), + anchor->GetIdx()); + parent_node_index++; + } + if (outputs_.empty() && control_outputs_.empty()) { + return SUCCESS; + } + auto net_output_op = MakeShared(NODE_NAME_NET_OUTPUT, ge::NETOUTPUT); + REQUIRE_NOT_NULL(net_output_op, "Failed new memory for netoutput op."); + for (size_t i = 0; i < outputs_.size(); ++i) { + GeTensorDesc input_desc; + REQUIRE_GRAPH_SUCCESS(net_output_op->AddInputDesc(input_desc), "Failed add input desc."); + } + auto net_output_node = subgraph_->AddNode(net_output_op); + REQUIRE_NOT_NULL(net_output_node, "Failed add netoutput node to subgraph."); + REQUIRE_GRAPH_SUCCESS(net_output_node->SetOwnerComputeGraph(subgraph_), "Failed set owner graph of netoutput node."); + parent_node_index = 0; + for (auto anchor : outputs_) { + auto output_desc = anchor->GetOwnerNode()->GetOpDesc()->GetOutputDesc(anchor->GetIdx()); + REQUIRE(AttrUtils::SetInt(output_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_node_index), + "Failed set parent_node_index on subgraph netoutput's input."); + REQUIRE_GRAPH_SUCCESS(net_output_op->UpdateInputDesc(parent_node_index, output_desc), + "Failed update input desc of netoutput node."); + + REQUIRE_GRAPH_SUCCESS(GraphUtils::AddEdge(anchor, net_output_node->GetInDataAnchor(parent_node_index)), + "Faile add edge from %s:%d to netoutput node.", anchor->GetOwnerNode()->GetName().c_str(), + anchor->GetIdx()); + 
parent_node_index++; + } + for (auto anchor : control_outputs_) { + REQUIRE_GRAPH_SUCCESS(GraphUtils::AddEdge(anchor, net_output_node->GetInControlAnchor()), + "Faile add control edge from %s:%d to netoutput node.", + anchor->GetOwnerNode()->GetName().c_str(), anchor->GetIdx()); + } + return SUCCESS; +} +void Cluster::Clear() { + in_clusters_.clear(); + out_clusters_.clear(); + nodes_.clear(); + partitioner_ = nullptr; + inputs_index_.clear(); + outputs_index_.clear(); + inputs_.clear(); + outputs_.clear(); + control_inputs_.clear(); + control_outputs_.clear(); + partition_node_.reset(); + subgraph_.reset(); +} + +size_t Cluster::unique_id_ = 0; +} // namespace ge \ No newline at end of file diff --git a/src/ge/graph/partition/dynamic_shape_partition.h b/src/ge/graph/partition/dynamic_shape_partition.h new file mode 100644 index 00000000..8734d7aa --- /dev/null +++ b/src/ge/graph/partition/dynamic_shape_partition.h @@ -0,0 +1,158 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef GE_GRAPH_PARTITION_DYNAMIC_SHAPE_PARTITION_H_ +#define GE_GRAPH_PARTITION_DYNAMIC_SHAPE_PARTITION_H_ + +#include +#include +#include +#include +#include "common/ge_inner_error_codes.h" +#include "graph/compute_graph.h" + +namespace ge { +class DynamicShapePartitioner { + public: + // An cluster means set of nodes that can be merged in same partition, + // Corresponding relationship between cluster type and node: + // DATA:DATA, UNKNOW_SHAPE:unknowshape, KNOW_SHAPE:knowshape, NETOUTPUT:NETOUTPUT. + class Cluster : public std::enable_shared_from_this { + public: + enum Type { DATA, NETOUTPUT, KNOW_SHAPE, UNKNOW_SHAPE }; + explicit Cluster(size_t rank, Type type, NodePtr node, DynamicShapePartitioner *partitioner) + : id_(rank), min_(rank), max_(rank), type_(type), partitioner_(partitioner) { + nodes_.push_back(node); + } + ~Cluster() = default; + std::string DebugString(); + // Basic bean functions + size_t Id(); + void UpdateRank(size_t rank); + bool IsData(); + bool IsKnowShape(); + bool IsUnknowShape(); + bool IsNetOutput(); + std::unordered_set> Inputs(); + std::unordered_set> Outputs(); + std::vector Nodes(); + bool IsolatedConstant(); + // Cluster modify functions + void AddInput(std::shared_ptr in); + void RemoveInput(std::shared_ptr in); + void AddOutput(std::shared_ptr out); + void RemoveOutput(std::shared_ptr out); + // Merge other cluster to this cluster, Whether it leads to a ring or not + // Merge src to dst means: + // All links to src will break and link to dst instead + // All nodes of src will change its owner to dst + // Update max and min rank of dst + void Merge(std::shared_ptr other); + // Try merge other cluster to this cluster, ONLY if will not leads to a ring + bool TryMerge(std::shared_ptr other); + // Merge all clusters on path(s) from other to this + std::vector> MergeAllPathFrom(std::shared_ptr other); + // Convert cluster to functioned call functions + void AddFrameInput(InDataAnchorPtr anchor); + void 
AddFrameOutput(OutDataAnchorPtr anchor); + InDataAnchorPtr GetFrameInDataAnchor(InDataAnchorPtr anchor); + OutDataAnchorPtr GetFrameOutDataAnchor(OutDataAnchorPtr anchor); + InControlAnchorPtr GetFrameInControlAnchor(); + OutControlAnchorPtr GetFrameOutControlAnchor(); + Status BuildFrame(); + Status BuildPartitionFrame(); + Status CombinePartitionFrame(); + Status BuildPartitionSubgraph(); + // Clear resource and break circular dependency + void Clear(); + + private: + static size_t unique_id_; + size_t id_; + // Each Cluster records the maximum and minimum topological order of its node + size_t min_; // maximum topological order + size_t max_; // minimum topological order + Type type_; + std::unordered_set> in_clusters_; + std::unordered_set> out_clusters_; + std::vector nodes_; + // Fileds for build partitoned call and subgraph + DynamicShapePartitioner *partitioner_; // Not owned, the partitioner this cluster belongs to + std::unordered_map inputs_index_; + std::unordered_map outputs_index_; + std::vector inputs_; + std::vector outputs_; + std::unordered_set> control_inputs_; + std::unordered_set control_outputs_; + NodePtr partition_node_; // corresponding partitioned call node + ComputeGraphPtr subgraph_; // corresponding subgraph + }; + explicit DynamicShapePartitioner(ge::ComputeGraphPtr graph) : root_graph_(graph) {} + ~DynamicShapePartitioner() = default; + + Status Partition(); + + private: + Status PartitionImpl(); + // Collect nodes that satisfy the unknowshape rules: + // 1) The Tensor shape of any input or output is unknow shape(dim_size = -1) or unknow rank(dim_size=-2) + // 2) Subgraphs of the node has an operator that satisfies rule 1) + Status MarkUnknowShapeNodes(); + // For each node a Cluster structure, and connected according to the connection relationship of the nodes + // An cluster means set of nodes that can be merged in same partition, + // Corresponding relationship between cluster type and node: + // DATA:DATA, 
UNKNOW_SHAPE:unknowshape, KNOW_SHAPE:knowshape, NETOUTPUT:NETOUTPUT + Status InitClusters(); + // Merge clusters according to the following rules: + // 1) Iterate through the UNKNOW_SHAPE clusters, if the input is UNKNOW_SHAPE, + // merge all the clusters in the path(s) between the two clusters + // 2) Iterate through the KNOW_SHAPE clusters, if the input is KNOW_SHAPE, and + // and there's only one path between the two clusters , merge the two clusters + Status MergeClusters(); + // Topological sort clusters after merge unknow shape clusters. + Status TopologicalSortClusters(); + // Deduplicate merged clusters + void PruneUniqueClusters(); + // Establish the input-output anchors for each partition of the cluster and record links to other clusters + Status BuildPartitionFrame(); + // Establish connection between corresponding partitioned of clusters + Status CombinePartitionFrame(); + // Convert the nodes in cluster into a complete ComputeGraoh + Status BuildPartitionSubgraph(); + // Clear resource and break circular dependency + void ClearResource(); + // Debug functions + void DumpGraph(std::string suffix); + std::string DebugString(); + // Util functions + Status CollectSpreadUnknowShapeNodes(NodePtr node); + Status IsUnknowShapeGraph(ge::ComputeGraphPtr graph, bool &is_unknow); + Status IsUnknowShapeNode(ge::NodePtr node, bool &is_unknow); + bool IsUnknowShapeTensor(ge::GeTensorDesc &tensor); + ge::ComputeGraphPtr root_graph_; // The original graph to partition + std::unordered_map> node_2_cluster_; // Record nodes and the cluster it belongs to + // topological sorted clusters, this field will change with the splitting. 
+ // When partitioning UNKNOW_SHAPE cluster, it is a collection of all topological sorted UNKNOW_SHAPE clusters + // When partitioning KNOW_SHAPE cluster, it is a collection of all topological sorted KNOW_SHAPE clusters + std::vector> ordered_cluster_; + // Unique clusters left after merged clusters + std::unordered_set> unique_clusters_; + // Nodes of root_graph_ that satisfy the unknowshape rules + std::unordered_set unknown_shape_nodes_; +}; +} // namespace ge + +#endif // GE_GRAPH_PARTITION_DYNAMIC_SHAPE_PARTITION_H_ diff --git a/src/ge/graph/partition/graph_partition.cc b/src/ge/graph/partition/graph_partition.cc index f459a7c2..b408c287 100644 --- a/src/ge/graph/partition/graph_partition.cc +++ b/src/ge/graph/partition/graph_partition.cc @@ -22,8 +22,8 @@ #include #include "common/ge/ge_util.h" #include "common/op/ge_op_utils.h" -#include "framework/common/op/attr_define.h" #include "framework/common/types.h" +#include "graph/debug/ge_attr_define.h" #include "graph/manager/graph_manager_utils.h" #include "graph/utils/graph_utils.h" #include "graph/utils/op_desc_utils.h" @@ -31,21 +31,6 @@ #include "init/gelib.h" #include "opskernel_manager/ops_kernel_manager.h" -using domi::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE; -using domi::ATTR_NAME_SESSION_GRAPH_ID; -using domi::ATTR_NAME_STREAM_LABEL; -using domi::END; -using domi::NCHW_DIM_C; -using domi::NCHW_DIM_H; -using domi::NCHW_DIM_N; -using domi::NCHW_DIM_W; -using domi::NHWC_DIM_C; -using domi::NHWC_DIM_H; -using domi::NHWC_DIM_N; -using domi::NHWC_DIM_W; -using domi::PERMUTE_ATTR_ORDER; -using domi::PLACEHOLDER; - namespace { const char *const kEngineDefaultData = "ENGINE_DEFAULT_DATA"; const char *const kEndType = "End"; @@ -65,12 +50,6 @@ Status ge::GraphPartitioner::CheckIfEnd2PldEmpty(ge::ComputeGraphPtr &output_mer return FAILED; } output_merged_compute_graph = partition.first; - // flush all nodes' engine of merged graph - graph_info_.engine_placer_.SetComputeGraph(output_merged_compute_graph); - if 
(graph_info_.engine_placer_.Run() != SUCCESS) { - GELOGE(GE_GRAPH_INIT_FAILED, "[GraphPartitioner]: engine_placer run failed"); - return FAILED; - } } else { // if placeholder to end map is empty, it should be an exception condition GELOGE(GE_GRAPH_EMPTY_PARTITION, "[GraphPartitioner]: placeholder to end map is empty, partitions size is not 1."); return FAILED; @@ -203,8 +182,12 @@ Status ge::GraphPartitioner::MergeAfterSubGraphOptimization(ge::ComputeGraphPtr GELOGE(FAILED, "Find corresponding node failed, parent node name is %s", parent_node->GetName().c_str()); return FAILED;) auto corresponding_node = graph_info.corresponding_node_in_partitions_[parent_node]; + GE_IF_BOOL_EXEC(corresponding_node == nullptr, + GELOGE(FAILED, "Get null node, node name is %s", parent_node->GetName().c_str()); + return FAILED;); merged_sub_graph->SetParentNode(corresponding_node); - merged_sub_graph->SetParentGraph(output_merged_compute_graph); + auto subgraph_parent_graph = corresponding_node->GetOwnerComputeGraph(); + merged_sub_graph->SetParentGraph(subgraph_parent_graph); ret = output_merged_compute_graph->AddSubgraph(sub_graph->GetName(), merged_sub_graph); GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, return ret;) } @@ -286,20 +269,23 @@ Status ge::GraphPartitioner::UpdatePldOpDesc(const NodePtr &src_node, int output GELOGE(GE_GRAPH_ADD_PLC_END_FAILED, "[GraphPartitioner]: pld_op_desc is null."); return FAILED; } - // flush pld data type as original data type - if (output_desc.GetOriginDataType() != DT_UNDEFINED) { - pld_op_desc->MutableOutputDesc(0)->SetDataType(output_desc.GetOriginDataType()); - } else { - GELOGW("Original data type of %s is undefined![data type is %s]", src_node->GetName().c_str(), - TypeUtils::DataTypeToSerialString(output_desc.GetDataType()).c_str()); - } - // flush pld format as original format - if (output_desc.GetOriginFormat() != FORMAT_RESERVED) { - pld_op_desc->MutableOutputDesc(0)->SetFormat(output_desc.GetOriginFormat()); - 
pld_op_desc->MutableOutputDesc(0)->SetShape(output_desc.GetOriginShape()); - } else { - GELOGW("Original format of %s is undefined![format is %s]", src_node->GetName().c_str(), - TypeUtils::FormatToSerialString(output_desc.GetFormat()).c_str()); + const char *buffer_optimize_on = std::getenv("BUFFER_OPTIMIZE_ON"); + if (buffer_optimize_on == nullptr) { + // flush pld data type as original data type + if (output_desc.GetOriginDataType() != DT_UNDEFINED) { + pld_op_desc->MutableOutputDesc(0)->SetDataType(output_desc.GetOriginDataType()); + } else { + GELOGW("Original data type of %s is undefined![data type is %s]", src_node->GetName().c_str(), + TypeUtils::DataTypeToSerialString(output_desc.GetDataType()).c_str()); + } + // flush pld format as original format + if (output_desc.GetOriginFormat() != FORMAT_RESERVED) { + pld_op_desc->MutableOutputDesc(0)->SetFormat(output_desc.GetOriginFormat()); + pld_op_desc->MutableOutputDesc(0)->SetShape(output_desc.GetOriginShape()); + } else { + GELOGW("Original format of %s is undefined![format is %s]", src_node->GetName().c_str(), + TypeUtils::FormatToSerialString(output_desc.GetFormat()).c_str()); + } } return SUCCESS; } @@ -319,20 +305,23 @@ Status ge::GraphPartitioner::UpdateEndOpDesc(const NodePtr &dst_node, int input_ GELOGE(GE_GRAPH_ADD_PLC_END_FAILED, "[GraphPartitioner]: pld_op_desc is null."); return FAILED; } - // flush end data type as original data type - if (input_desc.GetOriginDataType() != DT_UNDEFINED) { - end_op_desc->MutableInputDesc(0)->SetDataType(input_desc.GetOriginDataType()); - } else { - GELOGI("Original data type of %s is undefined![data type is %s]", dst_node->GetName().c_str(), - TypeUtils::DataTypeToSerialString(input_desc.GetDataType()).c_str()); - } - // flush end format as original format - if (input_desc.GetOriginFormat() != FORMAT_RESERVED) { - end_op_desc->MutableInputDesc(0)->SetFormat(input_desc.GetOriginFormat()); - end_op_desc->MutableInputDesc(0)->SetShape(input_desc.GetOriginShape()); - } 
else { - GELOGW("Original format of %s is undefined![format is %s]", dst_node->GetName().c_str(), - TypeUtils::FormatToSerialString(input_desc.GetFormat()).c_str()); + const char *buffer_optimize_on = std::getenv("BUFFER_OPTIMIZE_ON"); + if (buffer_optimize_on == nullptr) { + // flush end data type as original data type + if (input_desc.GetOriginDataType() != DT_UNDEFINED) { + end_op_desc->MutableInputDesc(0)->SetDataType(input_desc.GetOriginDataType()); + } else { + GELOGI("Original data type of %s is undefined![data type is %s]", dst_node->GetName().c_str(), + TypeUtils::DataTypeToSerialString(input_desc.GetDataType()).c_str()); + } + // flush end format as original format + if (input_desc.GetOriginFormat() != FORMAT_RESERVED) { + end_op_desc->MutableInputDesc(0)->SetFormat(input_desc.GetOriginFormat()); + end_op_desc->MutableInputDesc(0)->SetShape(input_desc.GetOriginShape()); + } else { + GELOGW("Original format of %s is undefined![format is %s]", dst_node->GetName().c_str(), + TypeUtils::FormatToSerialString(input_desc.GetFormat()).c_str()); + } } return SUCCESS; } @@ -531,9 +520,8 @@ void ge::GraphPartitioner::AddNewGraphToPartition(ge::ComputeGraphPtr &input_gra } bool ge::GraphPartitioner::IsDataLike(ge::NodePtr node) { - return (node->GetType() == domi::CONSTANT) || (node->GetType() == domi::DATA) || - (node->GetType() == domi::AIPPDATA) || (node->GetType() == domi::CONSTANTOP) || - (node->GetType() == domi::VARIABLE); + return (node->GetType() == CONSTANT) || (node->GetType() == DATA) || (node->GetType() == AIPPDATA) || + (node->GetType() == CONSTANTOP) || (node->GetType() == VARIABLE); } bool ge::GraphPartitioner::HasNoInput(ge::NodePtr node) { diff --git a/src/ge/graph/passes/addn_pass.cc b/src/ge/graph/passes/addn_pass.cc index 7bd32a38..c0592965 100644 --- a/src/ge/graph/passes/addn_pass.cc +++ b/src/ge/graph/passes/addn_pass.cc @@ -30,7 +30,7 @@ Status AddNPass::Run(NodePtr &node) { return PARAM_INVALID; } - if (node->GetType() == domi::ADDN) { + if 
(node->GetType() == ADDN) { if (node->GetOpDesc() == nullptr) { GELOGE(PARAM_INVALID, "Param [node] op desc is null."); return PARAM_INVALID; diff --git a/src/ge/graph/passes/aicpu_constant_folding_pass.cc b/src/ge/graph/passes/aicpu_constant_folding_pass.cc index 667c22a2..748c8d60 100644 --- a/src/ge/graph/passes/aicpu_constant_folding_pass.cc +++ b/src/ge/graph/passes/aicpu_constant_folding_pass.cc @@ -322,6 +322,8 @@ Status AicpuConstantFoldingPass::LaunchSingleOpRunTask(const NodePtr &node, cons STR_FWK_OP_KERNEL aicpu_task; aicpu_task.fwkKernelBase.fwk_kernel.inputOutputAddr = 0; aicpu_task.fwkKernelBase.fwk_kernel.workspaceBaseAddr = 0; + aicpu_task.fwkKernelBase.fwk_kernel.extInfoAddr = 0; + aicpu_task.fwkKernelBase.fwk_kernel.extInfoNum = 0; std::string task_info; Status ret = kernel_info->GenSingleOpRunTask(node, aicpu_task, task_info); if (ret != SUCCESS) { @@ -375,6 +377,8 @@ Status AicpuConstantFoldingPass::LaunchMemCopyTask(const vector &data_ STR_FWK_OP_KERNEL aicpu_task; aicpu_task.fwkKernelBase.fwk_kernel.inputOutputAddr = 0; aicpu_task.fwkKernelBase.fwk_kernel.workspaceBaseAddr = 0; + aicpu_task.fwkKernelBase.fwk_kernel.extInfoAddr = 0; + aicpu_task.fwkKernelBase.fwk_kernel.extInfoNum = 0; std::string task_info; Status ret = kernel_info->GenMemCopyTask(data_infos.size(), aicpu_task, task_info); if (ret != SUCCESS) { @@ -571,8 +575,8 @@ void AicpuConstantFoldingPass::ReleaseMemory(const vector &input_ad bool AicpuConstantFoldingPass::IsSkipFold(const ge::NodePtr &node) { GE_CHECK_NOTNULL(node); string type = node->GetType(); - if (type == domi::FRAMEWORKOP) { - if (!ge::AttrUtils::GetStr(node->GetOpDesc(), domi::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type)) { + if (type == ge::FRAMEWORKOP) { + if (!ge::AttrUtils::GetStr(node->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type)) { GELOGW("Skip aicpu constant folding on frameworkop node [%s]", node->GetName().c_str()); return true; } diff --git a/src/ge/graph/passes/aicpu_constant_folding_pass.h 
b/src/ge/graph/passes/aicpu_constant_folding_pass.h index 1ff4722e..02babd8e 100644 --- a/src/ge/graph/passes/aicpu_constant_folding_pass.h +++ b/src/ge/graph/passes/aicpu_constant_folding_pass.h @@ -17,6 +17,7 @@ #ifndef GE_GRAPH_PASSES_AICPU_CONSTANT_FOLDING_PASS_H_ #define GE_GRAPH_PASSES_AICPU_CONSTANT_FOLDING_PASS_H_ +#include #include #include "common/opskernel/ops_kernel_info_store.h" diff --git a/src/ge/graph/passes/assert_pass.cc b/src/ge/graph/passes/assert_pass.cc index 3207d7be..725016a9 100644 --- a/src/ge/graph/passes/assert_pass.cc +++ b/src/ge/graph/passes/assert_pass.cc @@ -38,7 +38,7 @@ Status AssertPass::Run(NodePtr &node) { return PARAM_INVALID; } std::string op_type = node->GetOpDesc()->GetType(); - if (op_type == domi::ASSERT) { + if (op_type == ASSERT) { GELOGD("op type is assert."); std::vector nodes_unused; @@ -71,8 +71,8 @@ void AssertPass::CollectUnusedNode(const NodePtr &assert_node, vector & if (src_node != nullptr && src_node->GetOpDesc() != nullptr) { auto size = ++invalid_outdata_info[src_node.get()]; // src_node need to be deleted - if (src_node->GetOutDataNodesSize() == size && src_node->GetOpDesc()->GetType() != domi::DATA && - src_node->GetOpDesc()->GetType() != domi::AIPPDATA) { + if (src_node->GetOutDataNodesSize() == size && src_node->GetOpDesc()->GetType() != DATA && + src_node->GetOpDesc()->GetType() != AIPPDATA) { node_queue.push(src_node); } } diff --git a/src/ge/graph/passes/atomic_addr_clean_pass.cc b/src/ge/graph/passes/atomic_addr_clean_pass.cc index 87b40170..63928c53 100644 --- a/src/ge/graph/passes/atomic_addr_clean_pass.cc +++ b/src/ge/graph/passes/atomic_addr_clean_pass.cc @@ -26,28 +26,65 @@ #include "common/ge_inner_error_codes.h" #include "common/ge/ge_util.h" #include "graph/debug/ge_attr_define.h" +#include "graph/utils/node_utils.h" #include "init/gelib.h" -using domi::ATOMICADDRCLEAN; -using domi::ATTR_NAME_STREAM_LABEL; -using domi::LOOPCOND; -using domi::NODE_NAME_ATOMIC_ADDR_CLEAN; - namespace { bool 
is_loop_graph = false; } namespace ge { +namespace { +bool GraphShouldBeSkip(const ge::ComputeGraphPtr &graph) { + // Internal function, guaranteeing graph non-null + auto parent = graph->GetParentGraph(); + if (parent == nullptr) { + return false; + } + for (NodePtr &node : graph->GetDirectNode()) { + bool is_unknown = false; + auto ret_status = NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown); + if (ret_status != GRAPH_SUCCESS) { + GELOGW("Get node unknown status failed, node name:%s, type:%s.", node->GetName().c_str(), + node->GetType().c_str()); + continue; + } + if (is_unknown) { + GELOGI("Node %s, type %s is unknown shape, sub graph %s should be skip.", node->GetName().c_str(), + node->GetType().c_str(), graph->GetName().c_str()); + return true; + } + } + GELOGI("Sub graph %s does not have unknown shape node, run the pass.", graph->GetName().c_str()); + return false; +} +} // namespace + Status AtomicAddrCleanPass::Run(ComputeGraphPtr graph) { GE_TIMESTAMP_START(AtomicAddrCleanPass); if (graph == nullptr) { GELOGE(PARAM_INVALID, "param [graph] must not be null."); return PARAM_INVALID; } + if (GraphShouldBeSkip(graph)) { + return SUCCESS; + } GELOGD("AtomicAddrCleanPass begin."); // 1.Recoginze atomic and loop mark vector atomic_node_vec; for (NodePtr &node : graph->GetDirectNode()) { if (IsAtomicOp(node)) { + bool is_unknown = false; + auto ret_status = NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown); + if (ret_status != GRAPH_SUCCESS) { + GELOGW("Get node unknown status failed, node name:%s, type:%s.", node->GetName().c_str(), + node->GetType().c_str()); + continue; + } + if (is_unknown) { + GELOGI("Current node %s, type %s is unknown shape which should be skip.", node->GetName().c_str(), + node->GetType().c_str()); + continue; + } atomic_node_vec.push_back(node); } if (!is_loop_graph && node->GetType() == LOOPCOND) { @@ -205,7 +242,18 @@ bool AtomicAddrCleanPass::IsAtomicOp(const NodePtr &node) { vector op_info_vec = 
ops_kernel_manager.GetOpsKernelInfo(op_desc->GetType()); for (const auto &op_info : op_info_vec) { if (op_info.isAtomic) { - GELOGI("Recognized atomic op %s from HCCL engine.", op_desc->GetName().c_str()); + GELOGI("Recognized atomic op %s from DNN_HCCL engine.", op_desc->GetName().c_str()); + // check peer input is DATA + for (auto &in_data_anchor : node->GetAllInDataAnchors()) { + if (in_data_anchor->GetPeerOutAnchor() != nullptr && + in_data_anchor->GetPeerOutAnchor()->GetOwnerNode() != nullptr) { + auto peer_in_node = in_data_anchor->GetPeerOutAnchor()->GetOwnerNode(); + if (peer_in_node->GetType() == DATA) { + GELOGI("Recognized atomic op %s from DNN_HCCL engine and input is DATA.", op_desc->GetName().c_str()); + return false; + } + } + } hcom_node_vec_.push_back(node); return true; } diff --git a/src/ge/graph/passes/base_pass.cc b/src/ge/graph/passes/base_pass.cc index 2ac7e938..53025f6a 100644 --- a/src/ge/graph/passes/base_pass.cc +++ b/src/ge/graph/passes/base_pass.cc @@ -29,7 +29,7 @@ namespace { constexpr int kMaxRePassTimes = 1000; constexpr size_t kMaxOneInNodes = 1000; // Each iteration, we take about 0.3k memory on the stack, we should change the recursion to loop later -constexpr int kMaxRecursiveDepth = 10; +constexpr int kMaxRecursiveDepth = 20; void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::queue &input_edge_nodes, std::unordered_set &nodes_seen, std::unordered_set &nodes_last) { diff --git a/src/ge/graph/passes/cast_remove_pass.cc b/src/ge/graph/passes/cast_remove_pass.cc index 00a9581e..a0742a03 100644 --- a/src/ge/graph/passes/cast_remove_pass.cc +++ b/src/ge/graph/passes/cast_remove_pass.cc @@ -22,8 +22,6 @@ #include "graph/debug/ge_attr_define.h" #include "graph/utils/type_utils.h" -using domi::CAST; - namespace ge { Status CastRemovePass::Run(NodePtr &node) { if (node == nullptr) { diff --git a/src/ge/graph/passes/cast_translate_pass.cc b/src/ge/graph/passes/cast_translate_pass.cc index dfda5d10..2d67b0a8 100644 --- 
a/src/ge/graph/passes/cast_translate_pass.cc +++ b/src/ge/graph/passes/cast_translate_pass.cc @@ -23,17 +23,13 @@ #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" #include "graph/common/omg_util.h" +#include "graph/debug/ge_attr_define.h" #include "graph/passes/pass_utils.h" #include "graph/utils/node_utils.h" #include "graph/utils/type_utils.h" #include "init/gelib.h" #include "opskernel_manager/ops_kernel_manager.h" -using domi::ATTR_NAME_INPUT_DATATYPE; -using domi::ATTR_NAME_OUTPUT_DATATYPE; -using domi::CAST; -using domi::TRANSLATE; - namespace ge { bool CastTranslatePass::CheckInAndOutDataAnchor(NodePtr &node) const { if (node == nullptr) { diff --git a/src/ge/graph/passes/common_subexpression_elimination_pass.cc b/src/ge/graph/passes/common_subexpression_elimination_pass.cc index f16be19f..a52535c1 100644 --- a/src/ge/graph/passes/common_subexpression_elimination_pass.cc +++ b/src/ge/graph/passes/common_subexpression_elimination_pass.cc @@ -66,10 +66,22 @@ Status CommonSubexpressionEliminationPass::Run(ComputeGraphPtr graph) { GELOGD("Begin to run the CSE process on the graph"); GE_CHECK_NOTNULL(graph); std::map keys_to_node; - for (const auto &node : graph->GetAllNodes()) { + for (const auto &node : graph->GetDirectNode()) { if (!IsNodeSupportCse(node)) { continue; } + bool is_unknown = false; + auto ret = NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown); + if (ret != GRAPH_SUCCESS) { + GELOGW("Get node unknown status failed, node name:%s, type:%s.", node->GetName().c_str(), + node->GetType().c_str()); + continue; + } + if (is_unknown) { + GELOGI("Current node %s, type %s is unknown shape which should be skip.", node->GetName().c_str(), + node->GetType().c_str()); + continue; + } auto key = GetCseKey(node); auto iter = keys_to_node.find(key); if (iter == keys_to_node.end()) { @@ -88,7 +100,7 @@ Status CommonSubexpressionEliminationPass::Run(ComputeGraphPtr graph) { output_map[i] = i; } - auto ret = 
GraphUtils::ReplaceNodeAnchors(iter->second, node, {}, output_map); + ret = GraphUtils::ReplaceNodeAnchors(iter->second, node, {}, output_map); if (ret != GRAPH_SUCCESS) { GELOGE(INTERNAL_ERROR, "Failed to replace node %s by node %s error node %u", node->GetName().c_str(), iter->second->GetName().c_str(), ret); diff --git a/src/ge/graph/passes/compile_nodes_pass.cc b/src/ge/graph/passes/compile_nodes_pass.cc index f46b11f0..def7655e 100644 --- a/src/ge/graph/passes/compile_nodes_pass.cc +++ b/src/ge/graph/passes/compile_nodes_pass.cc @@ -45,7 +45,7 @@ graphStatus CompileNodesPass::Run(ComputeGraphPtr graph) { return ge::GE_CLI_GE_NOT_INITIALIZED; } std::unordered_map> kernel_to_compile_nodes; - for (auto &node : graph->GetAllNodes()) { + for (auto &node : graph->GetDirectNode()) { if (node == nullptr) { continue; } diff --git a/src/ge/graph/passes/compile_nodes_pass.h b/src/ge/graph/passes/compile_nodes_pass.h index 56df7b87..70f8cbf5 100644 --- a/src/ge/graph/passes/compile_nodes_pass.h +++ b/src/ge/graph/passes/compile_nodes_pass.h @@ -19,6 +19,9 @@ #include #include +#include +#include + #include "inc/graph_pass.h" #include "init/gelib.h" diff --git a/src/ge/graph/passes/cond_pass.cc b/src/ge/graph/passes/cond_pass.cc new file mode 100644 index 00000000..4052950a --- /dev/null +++ b/src/ge/graph/passes/cond_pass.cc @@ -0,0 +1,344 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "graph/passes/cond_pass.h" +#include "common/op/ge_op_utils.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/type_utils.h" + +namespace { +const std::set kIfTypes = {ge::IF, ge::_IF, ge::STATELESSIF}; +const std::set kWhileTypes = {ge::WHILE, ge::_WHILE, ge::STATELESSWHILE}; +const std::string kStringLength = "StringLength"; +const size_t kScalarDimNum = 1; +} // namespace + +namespace ge { +Status CondPass::Run(NodePtr &node) { + GE_CHECK_NOTNULL(node); + ComputeGraphPtr graph = nullptr; + OutDataAnchorPtr cond_out_anchor = nullptr; + InDataAnchorPtr cond_in_anchor = nullptr; + Status ret = GetCondInfo(node, graph, cond_out_anchor, cond_in_anchor); + if (ret == NOT_CHANGED) { + return SUCCESS; + } else if (ret != SUCCESS) { + GELOGE(FAILED, "Get cond_info for node %s failed.", node->GetName().c_str()); + return FAILED; + } + + /// cond + /// 1. NonScalar: cond->Shape->Shape(int32)->If / NetOutput(while) + /// 2. String Scalar: cond->StringLength(int32)->If / NetOutput(while) + /// 3. bool / float / double / uint8 / int16 / int8 / int64 Scalar: cond->Cast(2int32)->If / NetOutput(while) + /// 4. 
Int32 Scalar: cond->If / NetOutput(while) + OpDescPtr op_desc = cond_in_anchor->GetOwnerNode()->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + GELOGI("Handle cond for node %s.", op_desc->GetName().c_str()); + GeTensorDesc cond_tensor = op_desc->GetInputDesc(cond_in_anchor->GetIdx()); + if (!cond_tensor.GetShape().IsScalar()) { + GE_CHK_STATUS_RET(HandleNonScalarCond(graph, cond_out_anchor, cond_in_anchor), "HandleNonScalarCond for %s failed.", + op_desc->GetName().c_str()) + } else { + switch (cond_tensor.GetDataType()) { + case DT_STRING: + GE_CHK_STATUS_RET(HandleStringCond(graph, cond_out_anchor, cond_in_anchor), "HandleStringCond for %s failed.", + op_desc->GetName().c_str()) + break; + case DT_BOOL: + case DT_FLOAT: + case DT_DOUBLE: + case DT_UINT8: + case DT_INT16: + case DT_INT8: + case DT_INT64: + GE_CHK_STATUS_RET(HandleScalarCond(graph, cond_out_anchor, cond_in_anchor, cond_tensor.GetDataType()), + "HandleScalarCond for %s failed.", op_desc->GetName().c_str()) + break; + case DT_INT32: + break; + default: + GELOGE(FAILED, "UpdateInputDesc for node %s failed.", op_desc->GetName().c_str()); + return FAILED; + } + } + + cond_tensor.SetDataType(DT_INT32); + cond_tensor.SetOriginDataType(DT_INT32); + cond_tensor.SetShape(GeShape()); + cond_tensor.SetOriginShape(GeShape()); + if (op_desc->UpdateInputDesc(cond_in_anchor->GetIdx(), cond_tensor) != GRAPH_SUCCESS) { + GELOGE(FAILED, "UpdateInputDesc for node %s failed.", op_desc->GetName().c_str()); + return FAILED; + } + + return SUCCESS; +} + +/// +/// @brief Get cond info for if / while +/// @param [in] node: If / While op +/// @param [out] graph: owner_graph of if node / while_cond subgraph +/// @param [out] cond_out_anchor: peer_cond_anchor +/// @param [out] cond_in_anchor: cond_input +/// @return Status +/// +Status CondPass::GetCondInfo(const NodePtr &node, ComputeGraphPtr &graph, OutDataAnchorPtr &cond_out_anchor, + InDataAnchorPtr &cond_in_anchor) { + GE_CHECK_NOTNULL(node); + std::string type = 
node->GetType(); + if (kIfTypes.count(type) != 0) { + if (GetCondInfoForIf(node, graph, cond_out_anchor, cond_in_anchor) != SUCCESS) { + GELOGE(FAILED, "Get cond_info for if node failed."); + return FAILED; + } + } else if (kWhileTypes.count(type) != 0) { + if (GetCondInfoForWhile(node, graph, cond_out_anchor, cond_in_anchor) != SUCCESS) { + GELOGE(FAILED, "Get cond_info for while node failed."); + return FAILED; + } + } else { + GELOGI("no need cond_pass for node %s.", node->GetName().c_str()); + return NOT_CHANGED; + } + + return SUCCESS; +} + +/// +/// @brief Get cond info for if node +/// @param [in] node: If op +/// @param [out] graph: owner_graph of if node +/// @param [out] cond_out_anchor: peer_cond_anchor +/// @param [out] cond_in_anchor: cond_input of if +/// @return Status +/// +Status CondPass::GetCondInfoForIf(const NodePtr &node, ComputeGraphPtr &graph, OutDataAnchorPtr &cond_out_anchor, + InDataAnchorPtr &cond_in_anchor) { + GE_CHECK_NOTNULL(node); + graph = node->GetOwnerComputeGraph(); + GE_CHECK_NOTNULL(graph); + cond_in_anchor = node->GetInDataAnchor(IF_COND_INPUT); + GE_CHECK_NOTNULL(cond_in_anchor); + cond_out_anchor = cond_in_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(cond_out_anchor); + return SUCCESS; +} + +/// +/// @brief Get cond info for while node +/// @param [in] node: While op +/// @param [out] graph: while_cond subgraph +/// @param [out] cond_out_anchor: peer_cond_anchor +/// @param [out] cond_in_anchor: input of NetOutput in cond_graph +/// @return Status +/// +Status CondPass::GetCondInfoForWhile(const NodePtr &node, ComputeGraphPtr &graph, OutDataAnchorPtr &cond_out_anchor, + InDataAnchorPtr &cond_in_anchor) { + GE_CHECK_NOTNULL(node); + OpDescPtr op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + std::map subgraph_names_to_index = op_desc->GetSubgraphNameIndexes(); + auto iter = subgraph_names_to_index.find(ATTR_NAME_WHILE_COND); + if (iter == subgraph_names_to_index.end()) { + GELOGE(FAILED, "Get cond_graph index 
failed, while_node:%s.", node->GetName().c_str()); + return FAILED; + } + std::string cond_graph_instance_name = op_desc->GetSubgraphInstanceName(iter->second); + graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph())->GetSubgraph(cond_graph_instance_name); + GE_CHECK_NOTNULL(graph); + + NodePtr net_output_node = graph->FindNode(NODE_NAME_NET_OUTPUT); + GE_CHECK_NOTNULL(net_output_node); + // cond_graph has and only has one output + uint32_t output_num = net_output_node->GetAllInDataAnchorsSize(); + if (output_num != 1) { + GELOGE(FAILED, "output size of cond_graph is invalid, expect 1 but %u exactly, while_node:%s.", output_num, + node->GetName().c_str()); + return FAILED; + } + + cond_in_anchor = net_output_node->GetInDataAnchor(0); + GE_CHECK_NOTNULL(cond_in_anchor); + cond_out_anchor = cond_in_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(cond_out_anchor); + + return SUCCESS; +} + +/// +/// @brief Process Cond Op with non-scalar cond_input: cond->Shape->Shape->If / NetOutput(while) +/// @param [in] graph +/// @param [in] out_anchor: peer_cond_anchor +/// @param [in] in_anchor: cond_input +/// @return Status +/// +Status CondPass::HandleNonScalarCond(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_anchor, + const InDataAnchorPtr &in_anchor) { + if (InsertNode(graph, out_anchor, in_anchor, SHAPE) != SUCCESS) { + GELOGE(FAILED, "Insert first Shape node failed."); + return FAILED; + } + + if (InsertNode(graph, in_anchor->GetPeerOutAnchor(), in_anchor, SHAPE) != SUCCESS) { + GELOGE(FAILED, "Insert second Shape node failed."); + return FAILED; + } + + return SUCCESS; +} + +/// +/// @brief Process Cond Op with scalar-string cond_input: cond->StringLength(int32)->If / NetOutput(while) +/// @param [in] graph +/// @param [in] out_anchor: peer_cond_anchor +/// @param [in] in_anchor: cond_input +/// @return Status +/// +Status CondPass::HandleStringCond(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_anchor, + const InDataAnchorPtr &in_anchor) 
{ + GELOGI("Handle cond with scalar-string cond-input."); + return InsertNode(graph, out_anchor, in_anchor, kStringLength); +} + +/// +/// @brief Process Cond Op with scalar cond_input: cond->Cast(2int32)->If / NetOutput(while) +/// @param [in] graph +/// @param [in] out_anchor: peer_cond_anchor +/// @param [in] in_anchor: cond_input +/// @param [in] src_type +/// @return Status +/// +Status CondPass::HandleScalarCond(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_anchor, + const InDataAnchorPtr &in_anchor, DataType src_type) { + GE_CHECK_NOTNULL(in_anchor); + GE_CHECK_NOTNULL(out_anchor); + GE_CHECK_NOTNULL(out_anchor->GetOwnerNode()->GetOpDesc()); + GELOGI("Handle cond with scalar cond-input."); + + GeTensorDesc tensor = out_anchor->GetOwnerNode()->GetOpDesc()->GetOutputDesc(out_anchor->GetIdx()); + std::string cast_name = out_anchor->GetOwnerNode()->GetName() + "_Cast"; + NodePtr cast_node = AddCastNode(graph, cast_name, tensor, src_type, DT_INT32); + if (cast_node == nullptr) { + GELOGE(FAILED, "Add Cast node failed, name:%s.", cast_name.c_str()); + return FAILED; + } + + if (GraphUtils::InsertNodeBefore(out_anchor, {in_anchor}, cast_node) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Insert Cast node %s between %s->%s failed.", cast_node->GetName().c_str(), + out_anchor->GetOwnerNode()->GetName().c_str(), in_anchor->GetOwnerNode()->GetName().c_str()); + return FAILED; + } + + return SUCCESS; +} + +/// +/// @brief Insert node +/// @param [in] graph +/// @param [in] out_anchor +/// @param [in] in_anchor +/// @param [in] type +/// @return Status +/// +Status CondPass::InsertNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_anchor, + const InDataAnchorPtr &in_anchor, const std::string &type) { + GE_CHECK_NOTNULL(out_anchor); + GE_CHECK_NOTNULL(in_anchor); + GELOGD("Begin to insert %s node.", type.c_str()); + + GE_CHECK_NOTNULL(out_anchor->GetOwnerNode()->GetOpDesc()); + GE_CHECK_NOTNULL(in_anchor->GetOwnerNode()->GetOpDesc()); + GeTensorDesc 
in_tensor = out_anchor->GetOwnerNode()->GetOpDesc()->GetOutputDesc(out_anchor->GetIdx()); + GeTensorDesc out_tensor = in_anchor->GetOwnerNode()->GetOpDesc()->GetInputDesc(out_anchor->GetIdx()); + out_tensor.SetDataType(DT_INT32); + out_tensor.SetOriginDataType(DT_INT32); + if (type == SHAPE) { + int64_t size = static_cast(in_tensor.GetShape().GetDimNum()); + if (size == kScalarDimNum) { + out_tensor.SetShape(GeShape()); + out_tensor.SetOriginShape(GeShape()); + } else { + std::vector size_v{size}; + out_tensor.SetShape(GeShape(size_v)); + out_tensor.SetOriginShape(GeShape(size_v)); + } + } + + OpDescBuilder op_desc_builder(out_anchor->GetOwnerNode()->GetName() + "_" + type, type); + OpDescPtr op_desc = op_desc_builder.AddInput("x", in_tensor).AddOutput("y", out_tensor).Build(); + if (op_desc == nullptr) { + GELOGE(FAILED, "Create op_desc failed."); + return FAILED; + } + NodePtr new_node = graph->AddNode(op_desc); + if (new_node == nullptr) { + GELOGE(FAILED, "Create %s node failed.", type.c_str()); + return FAILED; + } + AddRePassNode(new_node); + + if (GraphUtils::InsertNodeBefore(out_anchor, {in_anchor}, new_node) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Insert %s node %s between %s->%s failed.", type.c_str(), new_node->GetName().c_str(), + out_anchor->GetOwnerNode()->GetName().c_str(), in_anchor->GetOwnerNode()->GetName().c_str()); + return FAILED; + } + + return SUCCESS; +} + +/// +/// @brief Add cast node +/// @param [in] graph +/// @param [in] name +/// @param [in] tensor +/// @param [in] src +/// @param [in] dst +/// @return NodePtr +/// +NodePtr CondPass::AddCastNode(const ComputeGraphPtr &graph, const std::string &name, const GeTensorDesc &tensor, + DataType src, DataType dst) { + GELOGI("Begin to create cast op: %s, from %d to %d", name.c_str(), src, dst); + + GeTensorDesc in_tensor = tensor; + in_tensor.SetDataType(src); + in_tensor.SetOriginDataType(src); + GeTensorDesc out_tensor = tensor; + out_tensor.SetDataType(dst); + 
out_tensor.SetOriginDataType(dst); + OpDescBuilder op_desc_builder(name, CAST); + OpDescPtr cast_desc = op_desc_builder.AddInput("x", in_tensor).AddOutput("y", out_tensor).Build(); + if (cast_desc == nullptr) { + GELOGE(FAILED, "Create cast op_desc failed, name: %s.", name.c_str()); + return nullptr; + } + if (!(AttrUtils::SetInt(cast_desc, CAST_ATTR_SRCT, src) && AttrUtils::SetInt(cast_desc, CAST_ATTR_DSTT, dst) && + AttrUtils::SetInt(cast_desc, CAST_ATTR_DST_TYPE, dst) && + AttrUtils::SetBool(cast_desc, CAST_ATTR_TRUNCATE, false))) { + GELOGE(FAILED, "Set CAST_ATTR failed, node: %s.", name.c_str()); + return nullptr; + } + + NodePtr cast_node = graph->AddNode(cast_desc); + if (cast_node == nullptr) { + GELOGE(FAILED, "Add cast node failed, name: %s.", name.c_str()); + return nullptr; + } + AddRePassNode(cast_node); + + return cast_node; +} +} // namespace ge diff --git a/src/ge/graph/passes/cond_pass.h b/src/ge/graph/passes/cond_pass.h new file mode 100644 index 00000000..fead8474 --- /dev/null +++ b/src/ge/graph/passes/cond_pass.h @@ -0,0 +1,116 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef GE_GRAPH_PASSES_COND_PASS_H +#define GE_GRAPH_PASSES_COND_PASS_H + +#include "graph/passes/base_pass.h" + +namespace ge { +class CondPass : public BaseNodePass { + public: + Status Run(NodePtr &node) override; + + private: + /// + /// @brief Get cond info for if / while + /// @param [in] node: If / While op + /// @param [out] graph: owner_graph of if node / while_cond subgraph + /// @param [out] cond_out_anchor: peer_cond_anchor + /// @param [out] cond_in_anchor: cond_input + /// @return Status + /// + static Status GetCondInfo(const NodePtr &node, ComputeGraphPtr &graph, OutDataAnchorPtr &cond_out_anchor, + InDataAnchorPtr &cond_in_anchor); + + /// + /// @brief Get cond info for if node + /// @param [in] node: If op + /// @param [out] graph: owner_graph of if node + /// @param [out] cond_out_anchor: peer_cond_anchor + /// @param [out] cond_in_anchor: cond_input of if + /// @return Status + /// + static Status GetCondInfoForIf(const NodePtr &node, ComputeGraphPtr &graph, OutDataAnchorPtr &cond_out_anchor, + InDataAnchorPtr &cond_in_anchor); + + /// + /// @brief Get cond info for while node + /// @param [in] node: While op + /// @param [out] graph: while_cond subgraph + /// @param [out] cond_out_anchor: peer_cond_anchor + /// @param [out] cond_in_anchor: input of NetOutput in cond_graph + /// @return Status + /// + static Status GetCondInfoForWhile(const NodePtr &node, ComputeGraphPtr &graph, OutDataAnchorPtr &cond_out_anchor, + InDataAnchorPtr &cond_in_anchor); + + /// + /// @brief Process Cond Op with non-scalar cond_input + /// @param [in] graph + /// @param [in] out_anchor: peer_cond_anchor + /// @param [in] in_anchor: cond_input + /// @return Status + /// + Status HandleNonScalarCond(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_anchor, + const InDataAnchorPtr &in_anchor); + + /// + /// @brief Process Cond Op with scalar-string cond_input + /// @param [in] graph + /// @param [in] out_anchor: peer_cond_anchor + /// @param [in] 
in_anchor: cond_input + /// @return Status + /// + Status HandleStringCond(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_anchor, + const InDataAnchorPtr &in_anchor); + + /// + /// @brief Process Cond Op with scalar cond_input + /// @param [in] graph + /// @param [in] out_anchor: peer_cond_anchor + /// @param [in] in_anchor: cond_input + /// @param [in] src_type + /// @return Status + /// + Status HandleScalarCond(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_anchor, + const InDataAnchorPtr &in_anchor, DataType src_type); + + /// + /// @brief Insert node + /// @param [in] graph + /// @param [in] out_anchor + /// @param [in] in_anchor + /// @param [in] type + /// @return Status + /// + Status InsertNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_anchor, const InDataAnchorPtr &in_anchor, + const std::string &type); + + /// + /// @brief Add cast node + /// @param [in] graph + /// @param [in] name + /// @param [in] tensor + /// @param [in] src + /// @param [in] dst + /// @return NodePtr + /// + NodePtr AddCastNode(const ComputeGraphPtr &graph, const std::string &name, const GeTensorDesc &tensor, DataType src, + DataType dst); +}; +} // namespace ge +#endif // GE_GRAPH_PASSES_COND_PASS_H diff --git a/src/ge/graph/passes/constant_fuse_same_pass.cc b/src/ge/graph/passes/constant_fuse_same_pass.cc index f3ef6352..69726e5d 100644 --- a/src/ge/graph/passes/constant_fuse_same_pass.cc +++ b/src/ge/graph/passes/constant_fuse_same_pass.cc @@ -29,9 +29,6 @@ #include "graph/utils/op_desc_utils.h" #include "graph/utils/type_utils.h" -using domi::CONSTANT; -using domi::CONSTANTOP; - namespace ge { namespace { const size_t kCorrectNum = 1; diff --git a/src/ge/graph/passes/control_op_attr_pass.cc b/src/ge/graph/passes/control_op_attr_pass.cc deleted file mode 100644 index 983f22f1..00000000 --- a/src/ge/graph/passes/control_op_attr_pass.cc +++ /dev/null @@ -1,256 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed 
under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "graph/passes/control_op_attr_pass.h" - -#include -#include -#include - -#include "common/ge/ge_util.h" -#include "framework/common/debug/ge_log.h" -#include "framework/common/debug/log.h" -#include "framework/common/ge_inner_error_codes.h" -#include "framework/common/types.h" -#include "graph/common/omg_util.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/utils/graph_utils.h" -#include "init/gelib.h" - -using domi::ATTR_NAME_STREAM_LABEL; - -using domi::STREAMACTIVE; -using domi::STREAMSWITCH; -using domi::STREAMSWITCHN; - -namespace { -const uint32_t kMaxNodeNum = 350; -} // namespace - -namespace ge { -/// -/// @brief Pass for Switch & Active Op attr -/// @param [in] graph -/// @return Status -/// -Status ControlOpAttrPass::Run(ComputeGraphPtr graph) { - GELOGD("ControlOpAttrPass Enter"); - - if (AcquireEngineInfo() != SUCCESS) { - GELOGE(FAILED, "AcquireEngineInfo fail."); - return FAILED; - } - - if (HandleStreamLabel(graph) != SUCCESS) { - GELOGE(FAILED, "HandleStreamLabel fail."); - return FAILED; - } - - if (HandleSwitchNodes(graph) != SUCCESS) { - GELOGE(FAILED, "HandleSwitchNodes fail."); - return FAILED; - } - - GELOGD("ControlOpAttrPass Leave"); - return SUCCESS; -} - -/// -/// @brief acquire engine info -/// @return Status -/// -Status ControlOpAttrPass::AcquireEngineInfo() { - auto gelib = GELib::GetInstance(); - if (gelib == nullptr) { - GELOGE(INTERNAL_ERROR, "Get GELib instance 
failed."); - return INTERNAL_ERROR; - } - - const map &scheduler_confs = gelib->DNNEngineManagerObj().GetSchedulers(); - for (const auto &item : scheduler_confs) { - const SchedulerConf &scheduler = item.second; - for (const auto &engine_pair : scheduler.cal_engines) { - EngineConfPtr engine_conf = engine_pair.second; - if (engine_conf != nullptr) { - engine_confs_[engine_pair.first] = engine_conf; - } - } - } - - return SUCCESS; -} - -/// -/// @brief Handle stream label -/// @param [in] graph -/// @return Status -/// -Status ControlOpAttrPass::HandleStreamLabel(const ComputeGraphPtr &graph) { - std::string stream_label; - for (auto &node : graph->GetDirectNode()) { - OpDescPtr op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - - const std::string type = op_desc->GetType(); - if ((type == STREAMSWITCH) || (type == STREAMSWITCHN)) { - switch_nodes_.emplace_back(node); - } - - if (!AttrUtils::GetStr(op_desc, ATTR_NAME_STREAM_LABEL, stream_label)) { - continue; - } - - auto num_iter = stream_label_num_.find(stream_label); - if (num_iter == stream_label_num_.end()) { - stream_label_num_[stream_label] = 1; - } else { - num_iter->second++; - } - - bool independent = false; - const std::string engine_name = op_desc->GetOpEngineName(); - if (!engine_name.empty()) { - auto engine_conf_iter = engine_confs_.find(engine_name); - bool exist_flag = (engine_conf_iter == engine_confs_.end()) || (engine_conf_iter->second == nullptr); - if (exist_flag) { - GELOGE(INTERNAL_ERROR, "Engine conf of node %s not found (engine name: %s).", op_desc->GetName().c_str(), - engine_name.c_str()); - return INTERNAL_ERROR; - } - independent = engine_conf_iter->second->independent; - } - - auto flag_iter = label_flag_.find(stream_label); - if (flag_iter == label_flag_.end()) { - label_flag_[stream_label] = independent ? 
std::make_pair(false, true) : std::make_pair(true, false); - } else if (flag_iter->second.first && flag_iter->second.second) { - continue; - } else { - bool &flag = (independent ? flag_iter->second.second : flag_iter->second.first); - flag = true; - } - } - - return SUCCESS; -} - -/// -/// @brief Handle Switch Op -/// @param [in] graph -/// @return Status -/// -Status ControlOpAttrPass::HandleSwitchNodes(ComputeGraphPtr &graph) { - for (auto &switch_node : switch_nodes_) { - GE_CHECK_NOTNULL(switch_node); - std::vector ori_active_label_list; - OpDescPtr switch_desc = switch_node->GetOpDesc(); - GE_CHECK_NOTNULL(switch_desc); - if (!AttrUtils::GetListStr(switch_desc, ATTR_NAME_ACTIVE_LABEL_LIST, ori_active_label_list) || - ori_active_label_list.empty()) { - GELOGE(INTERNAL_ERROR, "active label of switch %s is null", switch_node->GetName().c_str()); - return INTERNAL_ERROR; - } - - std::vector active_label_list; - std::vector active_nodes; - size_t label_num = ori_active_label_list.size(); - for (size_t i = 0; i < label_num; i++) { - const std::string active_label = ori_active_label_list[i]; - if (!CheckNeedActiveNode(active_label)) { - active_label_list.emplace_back(active_label); - continue; - } - - std::string name = switch_node->GetName() + "_" + STREAMACTIVE; - if (label_num > 0) { - name = name + "_" + std::to_string(i); - } - GELOGI("Create StreamActive op:%s.", name.c_str()); - OpDescPtr active_op_desc = MakeShared(name, STREAMACTIVE); - if (active_op_desc == nullptr) { - GELOGE(FAILED, "Create node %s fail.", name.c_str()); - return FAILED; - } - NodePtr active_node = graph->AddNode(active_op_desc); - if (active_node == nullptr) { - GELOGE(FAILED, "Create StreamActive node fail."); - return FAILED; - } - - for (NodePtr &node : switch_node->GetOutControlNodes()) { - std::string stream_label; - OpDescPtr op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - (void)AttrUtils::GetStr(op_desc, ATTR_NAME_STREAM_LABEL, stream_label); - if (stream_label != 
active_label) { - continue; - } - GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(switch_node->GetOutControlAnchor(), node->GetInControlAnchor()), - "remove edge failed"); - GE_CHK_STATUS_RET(GraphUtils::AddEdge(active_node->GetOutControlAnchor(), node->GetInControlAnchor()), - "add edge failed"); - } - - GE_CHK_STATUS_RET(SetSwitchBranchNodeLabel(active_node, name), "set switch branch node label failed"); - GE_CHK_STATUS_RET(SetStreamLabel(active_node, name), "set stream label failed"); - GE_CHK_STATUS_RET(SetActiveLabelList(active_node, {active_label}), "set active label list failed"); - - active_nodes.emplace_back(active_node); - active_label_list.emplace_back(name); - } - - GE_CHK_STATUS_RET(SetActiveLabelList(switch_node, {active_label_list}), "set active label list failed"); - - if (active_nodes.empty()) { - continue; - } - - if (!switch_node->GetOutAllNodes().empty()) { - GELOGE(FAILED, "Exist out_node holds stream_label beyond the range of active_label_list, switch_node:%s.", - switch_desc->GetName().c_str()); - return FAILED; - } - for (auto &active_node : active_nodes) { - GE_CHK_STATUS_RET(GraphUtils::AddEdge(switch_node->GetOutControlAnchor(), active_node->GetInControlAnchor()), - "add edge failed"); - } - } - - return SUCCESS; -} - -/// -/// @brief Check if insert active node -/// @param [in] stream_label -/// @return bool -/// -bool ControlOpAttrPass::CheckNeedActiveNode(const std::string &stream_label) { - if (stream_label_num_[stream_label] > kMaxNodeNum) { - return true; - } - - auto iter = label_flag_.find(stream_label); - if (iter == label_flag_.end()) { - GELOGE(INTERNAL_ERROR, "not find label %s", stream_label.c_str()); - return false; - } - if (iter->second.first && iter->second.second) { - return true; - } - - return false; -} -} // namespace ge diff --git a/src/ge/graph/passes/control_op_attr_pass.h b/src/ge/graph/passes/control_op_attr_pass.h deleted file mode 100644 index d53e2191..00000000 --- a/src/ge/graph/passes/control_op_attr_pass.h +++ 
/dev/null @@ -1,47 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef GE_GRAPH_PASSES_CONTROL_OP_ATTR_PASS_H_ -#define GE_GRAPH_PASSES_CONTROL_OP_ATTR_PASS_H_ - -#include -#include -#include -#include -#include - -#include "engine_manager/dnnengine_manager.h" -#include "inc/graph_pass.h" - -namespace ge { -class ControlOpAttrPass : public GraphPass { - public: - Status Run(ComputeGraphPtr graph); - - private: - Status AcquireEngineInfo(); - Status HandleStreamLabel(const ComputeGraphPtr &graph); - Status HandleSwitchNodes(ComputeGraphPtr &graph); - bool CheckNeedActiveNode(const std::string &stream_label); - - std::unordered_map stream_label_num_; - // map> - std::unordered_map> label_flag_; - std::vector switch_nodes_; - std::map engine_confs_; -}; -} // namespace ge -#endif // GE_GRAPH_PASSES_CONTROL_OP_ATTR_PASS_H_ diff --git a/src/ge/graph/passes/control_trigger_pass.cc b/src/ge/graph/passes/control_trigger_pass.cc index a13a84b9..b1218d9f 100644 --- a/src/ge/graph/passes/control_trigger_pass.cc +++ b/src/ge/graph/passes/control_trigger_pass.cc @@ -27,26 +27,10 @@ #include "graph/debug/ge_attr_define.h" #include "graph/utils/type_utils.h" -using domi::ATTR_NAME_WEIGHTS; - -using domi::CONSTANT; -using domi::CONTROLTRIGGER; -using domi::ENTER; -using domi::IDENTITY; -using domi::LOOPCOND; -using domi::MERGE; -using domi::REFENTER; -using domi::REFMERGE; -using domi::REFSWITCH; 
-using domi::SWITCH; - namespace ge { Status ControlTriggerPass::Run(ComputeGraphPtr graph) { GELOGD("ControlTriggerPass Enter"); - GraphUtils::DumpGEGraph(graph, "BeforeControlTriggerPass"); - GraphUtils::DumpGEGraphToOnnx(*graph, "BeforeControlTriggerPass"); - for (NodePtr &node : graph->GetDirectNode()) { if (node->GetType() != CONTROLTRIGGER) { continue; @@ -61,9 +45,6 @@ Status ControlTriggerPass::Run(ComputeGraphPtr graph) { } } - GraphUtils::DumpGEGraph(graph, "AfterControlTriggerPass"); - GraphUtils::DumpGEGraphToOnnx(*graph, "AfterControlTriggerPass"); - GELOGD("ControlTriggerPass Leave"); return SUCCESS; } diff --git a/src/ge/graph/passes/control_trigger_pass.h b/src/ge/graph/passes/control_trigger_pass.h index 39ee515d..b9fff9b4 100644 --- a/src/ge/graph/passes/control_trigger_pass.h +++ b/src/ge/graph/passes/control_trigger_pass.h @@ -49,4 +49,4 @@ class ControlTriggerPass : public GraphPass { std::unordered_map>> control_trigger_map_; }; } // namespace ge -#endif // GE_GRAPH_PASSES_CONTROL_TRIGGER_PASS_H_ \ No newline at end of file +#endif // GE_GRAPH_PASSES_CONTROL_TRIGGER_PASS_H_ diff --git a/src/ge/graph/passes/dimension_adjust_pass.cc b/src/ge/graph/passes/dimension_adjust_pass.cc index ab69693a..28ebbb83 100644 --- a/src/ge/graph/passes/dimension_adjust_pass.cc +++ b/src/ge/graph/passes/dimension_adjust_pass.cc @@ -19,6 +19,7 @@ #include #include #include +#include "graph/utils/node_utils.h" namespace ge { namespace { @@ -49,6 +50,17 @@ Status DimensionAdjustPass::Run(ge::NodePtr &node) { if (op_kernel == nullptr) { return SUCCESS; } + bool is_unknown = false; + auto ret_status = NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown); + if (ret_status != GRAPH_SUCCESS) { + GELOGW("Get node unknown status failed, node name:%s, type:%s.", node->GetName().c_str(), node->GetType().c_str()); + return INTERNAL_ERROR; + } + if (is_unknown) { + GELOGI("Current node %s, type %s is unknown shape which should be skip.", node->GetName().c_str(), + 
node->GetType().c_str()); + return SUCCESS; + } // call compute function ret = op_kernel->Compute(node); diff --git a/src/ge/graph/passes/dropout_pass.cc b/src/ge/graph/passes/dropout_pass.cc index f1be5ba0..ab88aa23 100644 --- a/src/ge/graph/passes/dropout_pass.cc +++ b/src/ge/graph/passes/dropout_pass.cc @@ -39,7 +39,7 @@ Status DropOutPass::Run(NodePtr &node) { return PARAM_INVALID; } std::string op_type = node->GetOpDesc()->GetType(); - if (op_type == domi::DROPOUT) { + if (op_type == DROPOUT) { GELOGD("op type is dropout."); return IsolateAndDeleteNode(node, {0}); } diff --git a/src/ge/graph/passes/end_graph_pass.cc b/src/ge/graph/passes/end_graph_pass.cc deleted file mode 100644 index 8cd5c176..00000000 --- a/src/ge/graph/passes/end_graph_pass.cc +++ /dev/null @@ -1,75 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "graph/passes/end_graph_pass.h" - -#include -#include -#include -#include - -#include "framework/common/debug/ge_log.h" -#include "framework/common/ge_inner_error_codes.h" -#include "graph/passes/pass_utils.h" -#include "graph/utils/tensor_utils.h" -#include "init/gelib.h" -#include "common/ge/ge_util.h" -#include "graph/debug/ge_attr_define.h" - -using domi::ENDGRAPH; -using domi::NODE_NAME_END_GRAPH; -using domi::NODE_NAME_NET_OUTPUT; - -namespace ge { -Status EndGraphPass::Run(ge::ComputeGraphPtr graph) { - GELOGI("EndGraphPass Run."); - if (graph == nullptr) { - GELOGE(GE_GRAPH_PARAM_NULLPTR, "Compute graph is null."); - return GE_GRAPH_PARAM_NULLPTR; - } - - auto gelib = GELib::GetInstance(); - bool head_stream = (gelib == nullptr) ? false : gelib->HeadStream(); - if (!head_stream) { - GELOGI("Configured head stream: %d, No need EndGraph.", head_stream); - return SUCCESS; - } - - NodePtr net_output_node = graph->FindNode(NODE_NAME_NET_OUTPUT); - if (net_output_node == nullptr) { - GELOGI("No output node found."); - return SUCCESS; - } - - OpDescPtr op_desc = MakeShared(NODE_NAME_END_GRAPH, ENDGRAPH); - GE_CHECK_NOTNULL(op_desc); - GELOGI("Create EndGraph op:%s.", op_desc->GetName().c_str()); - (void)AttrUtils::SetListStr(op_desc, ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, std::move(std::vector())); - NodePtr end_graph_node = graph->AddNode(op_desc); - if (end_graph_node == nullptr) { - GELOGI("Add EndGraph:%s node to Graph fail.", op_desc->GetName().c_str()); - return INTERNAL_ERROR; - } - - if (GraphUtils::AddEdge(net_output_node->GetOutControlAnchor(), end_graph_node->GetInControlAnchor()) != SUCCESS) { - GELOGI("Add ctrl edge to EndGraph:%s fail.", end_graph_node->GetName().c_str()); - return INTERNAL_ERROR; - } - - GELOGI("EndGraphPass Leave."); - return SUCCESS; -} -} // namespace ge diff --git a/src/ge/graph/passes/enter_pass.cc b/src/ge/graph/passes/enter_pass.cc index af3e4739..98ca30a5 100644 --- a/src/ge/graph/passes/enter_pass.cc +++ 
b/src/ge/graph/passes/enter_pass.cc @@ -23,11 +23,6 @@ #include "framework/common/ge_inner_error_codes.h" #include "graph/utils/graph_utils.h" -using domi::CONSTANT; -using domi::CONSTANTOP; -using domi::ENTER; -using domi::REFENTER; - namespace ge { Status EnterPass::Run(NodePtr &node) { GELOGD("EnterPass running"); diff --git a/src/ge/graph/passes/flow_ctrl_pass.cc b/src/ge/graph/passes/flow_ctrl_pass.cc index 6e933708..a8c20a79 100644 --- a/src/ge/graph/passes/flow_ctrl_pass.cc +++ b/src/ge/graph/passes/flow_ctrl_pass.cc @@ -29,23 +29,6 @@ namespace ge { // when namespace change to ge, please delete the using code. -using domi::NODE_NAME_FLOWCTRL_LOOP_ASSIGN; -using domi::NODE_NAME_FLOWCTRL_LOOP_ASSIGNADD; -using domi::NODE_NAME_FLOWCTRL_LOOP_COND; -using domi::NODE_NAME_FLOWCTRL_LOOP_INCREMENT; -using domi::NODE_NAME_FLOWCTRL_LOOP_PER_ITER; -using domi::NODE_NAME_FLOWCTRL_LOOP_RESETVALUE; -using domi::NODE_NAME_STREAM_SWITCH; - -using domi::ATTR_NAME_STREAM_CYCLE_EVENT_FLAG; -using domi::TRUE_STREAM_ID; - -using domi::ASSIGN; -using domi::ASSIGNADD; -using domi::STREAMACTIVE; -using domi::STREAMSWITCH; -using domi::VARIABLE; - Status FlowCtrlPass::Run(ComputeGraphPtr compute_graph) { GE_CHECK_NOTNULL(compute_graph); @@ -205,32 +188,27 @@ NodePtr FlowCtrlPass::AddVariableNode(ComputeGraphPtr &compute_graph, const stri } Status FlowCtrlPass::AddGlobalStepVariableNode(ComputeGraphPtr &compute_graph) { - NodePtr output_node = compute_graph->FindNode(domi::NODE_NAME_NET_OUTPUT); + NodePtr output_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT); if (output_node == nullptr) { - GELOGD("Node %s can't be found in graph %u", domi::NODE_NAME_NET_OUTPUT.c_str(), compute_graph->GetGraphID()); + GELOGD("Node %s can't be found in graph %u", NODE_NAME_NET_OUTPUT.c_str(), compute_graph->GetGraphID()); return SUCCESS; } if (compute_graph->GetParentGraph() != nullptr) { // Global step just add to main graph. 
- GELOGD("Graph %s no need global step variable.", compute_graph->GetName().c_str()); - uint32_t parent_index = 0; // Set to 0 as a mark for subgraph. - if (!AttrUtils::SetInt(output_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { - GELOGW("Node: %s Add attr %s failed.", output_node->GetName().c_str(), ATTR_NAME_PARENT_NODE_INDEX.c_str()); - } + GELOGD("Subgraph %s no need global step variable.", compute_graph->GetName().c_str()); return SUCCESS; } - NodePtr exist_node = compute_graph->FindNode(domi::NODE_NAME_GLOBAL_STEP); + NodePtr exist_node = compute_graph->FindNode(NODE_NAME_GLOBAL_STEP); if (exist_node != nullptr) { - GELOGD("Node %s already exist, no need add.", domi::NODE_NAME_GLOBAL_STEP.c_str()); + GELOGD("Node %s already exist, no need add.", NODE_NAME_GLOBAL_STEP.c_str()); return SUCCESS; } // set global step tensor desc GeTensorDesc tensor_desc(GeShape({1}), FORMAT_ND, DT_UINT64); std::vector input_desc_list = {}; std::vector output_desc_list = {tensor_desc}; - NodePtr global_step = - InsertOp(compute_graph, VARIABLE, domi::NODE_NAME_GLOBAL_STEP, input_desc_list, output_desc_list); + NodePtr global_step = InsertOp(compute_graph, VARIABLE, NODE_NAME_GLOBAL_STEP, input_desc_list, output_desc_list); if (global_step == nullptr) { GELOGE(FAILED, "Add global_step node failed, global_step is null."); return FAILED; diff --git a/src/ge/graph/passes/folding_kernel/add_kernel.cc b/src/ge/graph/passes/folding_kernel/add_kernel.cc index 2d786a4a..89f99938 100644 --- a/src/ge/graph/passes/folding_kernel/add_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/add_kernel.cc @@ -23,7 +23,6 @@ #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" -using domi::ADD; namespace ge { namespace { const size_t kAddFirstInput = 0; diff --git a/src/ge/graph/passes/folding_kernel/add_kernel.h b/src/ge/graph/passes/folding_kernel/add_kernel.h index 218bc12a..f8fd272e 100644 --- a/src/ge/graph/passes/folding_kernel/add_kernel.h +++ 
b/src/ge/graph/passes/folding_kernel/add_kernel.h @@ -38,4 +38,4 @@ class AddKernel : public Kernel { std::vector &v_output); }; } // namespace ge -#endif // GE_GRAPH_PASSES_FOLDING_KERNEL_ADD_KERNEL_H_ \ No newline at end of file +#endif // GE_GRAPH_PASSES_FOLDING_KERNEL_ADD_KERNEL_H_ diff --git a/src/ge/graph/passes/folding_kernel/broadcast_args_kernel.cc b/src/ge/graph/passes/folding_kernel/broadcast_args_kernel.cc index 212fd419..364fb415 100644 --- a/src/ge/graph/passes/folding_kernel/broadcast_args_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/broadcast_args_kernel.cc @@ -26,8 +26,6 @@ #include "graph/passes/pass_utils.h" #include "inc/kernel_factory.h" -using domi::BROADCASTARGS; - namespace ge { namespace { const size_t kBCastArgsInputsSize = 2; diff --git a/src/ge/graph/passes/folding_kernel/broadcast_gradient_args_kernel.cc b/src/ge/graph/passes/folding_kernel/broadcast_gradient_args_kernel.cc index 826d3471..0053a9df 100644 --- a/src/ge/graph/passes/folding_kernel/broadcast_gradient_args_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/broadcast_gradient_args_kernel.cc @@ -27,8 +27,6 @@ #include "graph/passes/pass_utils.h" #include "inc/kernel_factory.h" -using domi::BROADCASTGRADIENTARGS; - namespace ge { namespace { const size_t kBCastGradArgsInputsSize = 2; diff --git a/src/ge/graph/passes/folding_kernel/cast_kernel.cc b/src/ge/graph/passes/folding_kernel/cast_kernel.cc index 54634737..99944c20 100644 --- a/src/ge/graph/passes/folding_kernel/cast_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/cast_kernel.cc @@ -33,11 +33,6 @@ #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" -using domi::CAST; -using domi::PARAM_INVALID; -using domi::Status; -using domi::SUCCESS; - namespace ge { namespace { const size_t kCastInputSize = 1; @@ -54,9 +49,11 @@ Status CastKernel::Compute(const OpDescPtr op_desc_ptr, const std::vectorGetData().data(); - if (op_desc_ptr == nullptr || src_data == nullptr) { - GELOGE(PARAM_INVALID, "Parameter's 
invalid, Input opDescPtr or src_data is nullptr."); + // src_data == nullptr is supported + if (op_desc_ptr == nullptr) { + GELOGE(PARAM_INVALID, "Parameter's invalid, Input opDescPtr is nullptr."); return PARAM_INVALID; } GeTensorDesc op_desc = op_desc_ptr->GetOutputDesc(0); @@ -78,7 +75,7 @@ Status CastKernel::Compute(const OpDescPtr op_desc_ptr, const std::vectorGetData().GetSize()); + // const_weight_ptr->GetData().GetSize() == 0 is supported auto src_data_size = src_shape.GetShapeSize(); if (src_data_size == 0 && static_cast(const_weight_ptr->GetData().GetSize()) == GetSizeByDataType(src_data_type)) { @@ -118,7 +115,6 @@ Status CastKernel::Compute(const OpDescPtr op_desc_ptr, const std::vectorSetData(trans_result.data.get(), trans_result.length) != SUCCESS) { GELOGW("Compute: SetData failed"); - return FAILED; } v_output.push_back(output_ptr); return SUCCESS; diff --git a/src/ge/graph/passes/folding_kernel/concat_offset_kernel.cc b/src/ge/graph/passes/folding_kernel/concat_offset_kernel.cc index 2217c58e..f5146b5b 100644 --- a/src/ge/graph/passes/folding_kernel/concat_offset_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/concat_offset_kernel.cc @@ -25,8 +25,6 @@ #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" -using domi::CONCATOFFSET; - namespace ge { namespace { const size_t kConcatOffsetInputIndexZero = 0; @@ -66,6 +64,10 @@ Status ConcatOffsetKernel::Compute(const OpDescPtr op_desc_ptr, const vector buf(new (std::nothrow) int32_t[output_size]()); if (buf == nullptr) { GELOGE(MEMALLOC_FAILED, "new buf failed"); @@ -102,4 +104,4 @@ Status ConcatOffsetKernel::Compute(const OpDescPtr op_desc_ptr, const vector &input, GeTen auto output_size = merged_shape.GetShapeSize(); int64_t data_size = GetSizeByDataType(data_type); auto step = merged_shape.GetDim(kMergedShapeSecondDim); - if (!domi::CheckInt64MulOverflow(output_size, data_size) || !domi::CheckInt64MulOverflow(step, data_size)) { + if (!CheckInt64MulOverflow(output_size, data_size) 
|| !CheckInt64MulOverflow(step, data_size)) { GELOGW("Check int64 mul overflow failed. Output_size is %ld, data_size is %ld, step is %ld.", output_size, data_size, step); return NOT_CHANGED; @@ -196,7 +193,7 @@ Status DynamicStitchKernel::StitchDataFollowIndices(int64_t data_unit, const vec allowance += data_unit; } indices_set.insert(input_indices[j]); - if (!domi::CheckInt64MulOverflow(input_indices[j], data_unit)) { + if (!CheckInt64MulOverflow(input_indices[j], data_unit)) { GELOGW("Check int64 mul overflow failed. Indices is %ld, data_unit is %ld.", input_indices[j], data_unit); return NOT_CHANGED; } diff --git a/src/ge/graph/passes/folding_kernel/empty_kernel.cc b/src/ge/graph/passes/folding_kernel/empty_kernel.cc index 6d882ef9..1b135b9c 100644 --- a/src/ge/graph/passes/folding_kernel/empty_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/empty_kernel.cc @@ -28,8 +28,6 @@ #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" -using domi::EMPTY; - namespace ge { namespace { const size_t kEmptyFirstInput = 0; diff --git a/src/ge/graph/passes/folding_kernel/expanddims_kernel.cc b/src/ge/graph/passes/folding_kernel/expanddims_kernel.cc index 6abf6cfb..f4091d2d 100644 --- a/src/ge/graph/passes/folding_kernel/expanddims_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/expanddims_kernel.cc @@ -25,8 +25,6 @@ #include "graph/passes/folding_kernel/kernel_utils.h" #include "inc/kernel_factory.h" -using domi::EXPANDDIMS; - namespace ge { namespace { const int kExpandDimsIndexZero = 0; diff --git a/src/ge/graph/passes/folding_kernel/fill_kernel.cc b/src/ge/graph/passes/folding_kernel/fill_kernel.cc index 8c453e74..3a3aa597 100644 --- a/src/ge/graph/passes/folding_kernel/fill_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/fill_kernel.cc @@ -27,10 +27,6 @@ #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" -using domi::FILL; -using ge::fp16_t; -using ge::Status; - namespace { const int kFillInputSize = 2; const int kFillDimsInputIndex 
= 0; diff --git a/src/ge/graph/passes/folding_kernel/floordiv_kernel.cc b/src/ge/graph/passes/folding_kernel/floordiv_kernel.cc index dc9602bb..81595822 100644 --- a/src/ge/graph/passes/folding_kernel/floordiv_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/floordiv_kernel.cc @@ -28,8 +28,6 @@ #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" -using domi::FLOORDIV; - namespace ge { namespace { const size_t kFloorDivInputX = 0; diff --git a/src/ge/graph/passes/folding_kernel/floordiv_kernel.h b/src/ge/graph/passes/folding_kernel/floordiv_kernel.h index a692ff67..c8505731 100644 --- a/src/ge/graph/passes/folding_kernel/floordiv_kernel.h +++ b/src/ge/graph/passes/folding_kernel/floordiv_kernel.h @@ -47,4 +47,4 @@ class FloorDivKernel : public Kernel { }; } // namespace ge -#endif // GE_GRAPH_PASSES_FOLDING_KERNEL_FLOORDIV_KERNEL_H_ \ No newline at end of file +#endif // GE_GRAPH_PASSES_FOLDING_KERNEL_FLOORDIV_KERNEL_H_ diff --git a/src/ge/graph/passes/folding_kernel/floormod_kernel.cc b/src/ge/graph/passes/folding_kernel/floormod_kernel.cc index a7fbf1e3..d7fb3b1c 100644 --- a/src/ge/graph/passes/folding_kernel/floormod_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/floormod_kernel.cc @@ -27,8 +27,6 @@ #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" -using domi::FLOORMOD; - namespace ge { namespace { const size_t kFloorModInputX = 0; diff --git a/src/ge/graph/passes/folding_kernel/gather_v2_kernel.cc b/src/ge/graph/passes/folding_kernel/gather_v2_kernel.cc index 916708f1..732e0b53 100644 --- a/src/ge/graph/passes/folding_kernel/gather_v2_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/gather_v2_kernel.cc @@ -29,9 +29,6 @@ #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" -using domi::GATHERV2; -using ge::fp16_t; - namespace ge { namespace { const size_t kGatherV2InputIndexZero = 0; @@ -177,7 +174,7 @@ Status GatherV2Kernel::GenData(const int64_t data_num, ConstGeTensorPtr tensor_x if (data_num <= 0) { 
return PARAM_INVALID; } - if (!domi::CheckInt64MulOverflow(data_num, sizeof(T))) { + if (!CheckInt64MulOverflow(data_num, sizeof(T))) { GELOGE(PARAM_INVALID, "Int64MulOverflow, data_num:%ld, type_len:%zu.", data_num, sizeof(T)); return PARAM_INVALID; } @@ -221,7 +218,7 @@ Status GatherV2Kernel::CalcStride(std::vector &stride, std::vector= 0) { size_t index = static_cast(i) + kGatherV2DimOne; - if (!domi::CheckInt64MulOverflow(stride[index], dims[index])) { + if (!CheckInt64MulOverflow(stride[index], dims[index])) { GELOGE(PARAM_INVALID, "Int64MulOverflow, data_num(%ld) type_len(%ld)", stride[index], dims[index]); return PARAM_INVALID; } diff --git a/src/ge/graph/passes/folding_kernel/greater_kernel.cc b/src/ge/graph/passes/folding_kernel/greater_kernel.cc index 944cd1b2..816d3d05 100644 --- a/src/ge/graph/passes/folding_kernel/greater_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/greater_kernel.cc @@ -29,10 +29,8 @@ #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" -using domi::GREATER; using domi::Status; using domi::SUCCESS; -using ge::fp16_t; namespace ge { namespace { diff --git a/src/ge/graph/passes/folding_kernel/kernel_utils.cc b/src/ge/graph/passes/folding_kernel/kernel_utils.cc index 9448b232..2002643a 100644 --- a/src/ge/graph/passes/folding_kernel/kernel_utils.cc +++ b/src/ge/graph/passes/folding_kernel/kernel_utils.cc @@ -113,12 +113,26 @@ bool KernelUtils::CheckSizeForTransOp(const ge::ConstGeTensorPtr &const_weight_p GELOGI("Const real value Size:%zu, op_desc Shape Size:%ld, data_type:%s.", data_size, cal_size, TypeUtils::DataTypeToSerialString(data_type).c_str()); - if ((shape_size != 0) || (length != 0 && (data_size / static_cast(length) != 1))) { - if (!(data_size == static_cast(cal_size) && data_size != 0)) { + if (shape_size != 0) { + // Standard tensor + if (data_size != static_cast(cal_size) || data_size == 0) { + GELOGW("Const input data size is not equal with tensor desc shape"); + return false; + } + } else if 
(data_shape.GetDimNum() != 0) { + // Empty tensor, has zero in shape vector + if (data_size != 0) { + GELOGW("Const input data size is not equal with tensor desc shape"); + return false; + } + } else { + // Scalar tensor, has only one element in tensor + if (length != 0 && (data_size / static_cast(length) != 1)) { GELOGW("Const input data size is not equal with tensor desc shape"); return false; } } + return true; } diff --git a/src/ge/graph/passes/folding_kernel/kernel_utils.h b/src/ge/graph/passes/folding_kernel/kernel_utils.h index 05f201e9..17b645aa 100644 --- a/src/ge/graph/passes/folding_kernel/kernel_utils.h +++ b/src/ge/graph/passes/folding_kernel/kernel_utils.h @@ -29,6 +29,7 @@ namespace ge { class KernelUtils { public: KernelUtils() = delete; + ~KernelUtils() = delete; static Status CheckDimensionNodeInfo(const NodePtr &node_ptr); static bool CheckFormatSupported(const NodePtr &node_ptr); static bool CheckSizeForTransOp(const ConstGeTensorPtr &const_weight_ptr, const OpDescPtr &op_desc_ptr); @@ -44,7 +45,7 @@ class KernelUtils { template static Status GenData(const int64_t data_num, const T value, const GeTensorPtr &output) { if (data_num > 0) { - if (!domi::CheckInt64MulOverflow(data_num, static_cast(sizeof(T)))) { + if (!CheckInt64MulOverflow(data_num, static_cast(sizeof(T)))) { GELOGE(PARAM_INVALID, "Int64MulOverflow, data_num(%ld) type_len(%zu)", data_num, sizeof(T)); return PARAM_INVALID; } @@ -92,7 +93,7 @@ class KernelUtils { vec_dim.clear(); break; } - if (!domi::CheckInt64MulOverflow(data_num, dim)) { + if (!CheckInt64MulOverflow(data_num, dim)) { GELOGE(PARAM_INVALID, "Int64MulOverflow, data_num(%ld) dim(%ld)", data_num, static_cast(dim)); return PARAM_INVALID; } diff --git a/src/ge/graph/passes/folding_kernel/maximum_kernel.cc b/src/ge/graph/passes/folding_kernel/maximum_kernel.cc index 89b3b159..5f83f0d5 100644 --- a/src/ge/graph/passes/folding_kernel/maximum_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/maximum_kernel.cc @@ -29,9 +29,6 
@@ #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" -using domi::MAXIMUM; -using ge::fp16_t; - namespace ge { namespace { const size_t kMaximumInputNum = 2; diff --git a/src/ge/graph/passes/folding_kernel/mul_kernel.cc b/src/ge/graph/passes/folding_kernel/mul_kernel.cc index 4b1984e2..4ca740d1 100644 --- a/src/ge/graph/passes/folding_kernel/mul_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/mul_kernel.cc @@ -29,8 +29,6 @@ #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" -using domi::MUL; - namespace ge { namespace { const std::set kMulSupportedType = {DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, diff --git a/src/ge/graph/passes/folding_kernel/pack_kernel.cc b/src/ge/graph/passes/folding_kernel/pack_kernel.cc index f9587771..5db3b394 100644 --- a/src/ge/graph/passes/folding_kernel/pack_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/pack_kernel.cc @@ -29,7 +29,6 @@ #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" -using domi::PACK; namespace { const int64_t kShapeItemNumMAX = 2000000000; } // namespace @@ -68,8 +67,8 @@ Status PackKernel::ValidateKernelParams(const ge::OpDescPtr &op_desc_ptr, return PARAM_INVALID; } if (!(AttrUtils::GetInt(op_desc_ptr, PACK_ATTR_NAME_NUM, n_))) { - GELOGE(PARAM_INVALID, "Attr %s is not exist.", PACK_ATTR_NAME_NUM.c_str()); - return PARAM_INVALID; + n_ = 0; + GELOGD("Attr %s is not set, default value %ld is used.", PACK_ATTR_NAME_NUM.c_str(), n_); } if (!(AttrUtils::GetInt(op_desc_ptr, ATTR_NAME_AXIS, axis_))) { GELOGE(PARAM_INVALID, "Attr %s is not exist.", ATTR_NAME_AXIS.c_str()); @@ -106,11 +105,7 @@ Status PackKernel::ValidateInputs(const ge::OpDescPtr &op_desc_ptr, const std::v GELOGW("Input %ld of pack kernel %s is null.", i, op_desc_ptr->GetName().c_str()); return PARAM_INVALID; } - // check if tensor contains data - if (input[i]->GetData().size() == 0) { - GELOGW("Inputs %ld do not have value.", i); - return NOT_CHANGED; - } + if (i == 0) { // get 
first input shape shape = input[0]->GetTensorDesc().GetShape(); @@ -128,8 +123,8 @@ Status PackKernel::ValidateInputs(const ge::OpDescPtr &op_desc_ptr, const std::v auto dst_shape = tensor_desc.GetShape(); int64_t num = 1; for (auto dim : dst_shape.GetDims()) { - if (dim < 1) { - GELOGW("Invalid zero dim in the shape %s", formats::ShapeToString(shape).c_str()); + if (dim < 0) { + GELOGW("Invalid dim ld% in the shape %s", dim, formats::ShapeToString(shape).c_str()); return NOT_CHANGED; } num *= dim; @@ -142,6 +137,12 @@ Status PackKernel::ValidateInputs(const ge::OpDescPtr &op_desc_ptr, const std::v GELOGW("Shape of input %ld is not equal wiht input 0.", i); return NOT_CHANGED; } + + // check tensor data size is zero ot not + if (input[i]->GetData().size() == 0 && num != 0) { + GELOGW("Inputs %ld do not have value.", i); + return NOT_CHANGED; + } } return SUCCESS; } @@ -168,6 +169,13 @@ void PackKernel::ExpandDims(const int64_t axis, const std::vector &input, ge::GeTensorPtr &output_ptr) { + output_ptr->MutableTensorDesc().SetShape(final_shape); + output_ptr->MutableTensorDesc().SetDataType(DataType(data_type_)); + if (final_shape.GetShapeSize() == 0 && final_shape.GetDims().size() != 0) { + // means has zero in shape list, output tnesor data is []. 
+ return SUCCESS; + } + int64_t times = 1; int64_t unit = 1; // calculate data unit @@ -211,8 +219,6 @@ Status PackKernel::CopyOutputData(const GeShape &final_shape, const std::vector< if (output_ptr->SetData(buf.get(), static_cast(output_size * data_size)) != GRAPH_SUCCESS) { GELOGW("CopyOutputData: SetData failed"); } - output_ptr->MutableTensorDesc().SetShape(final_shape); - output_ptr->MutableTensorDesc().SetDataType(DataType(data_type_)); return SUCCESS; } diff --git a/src/ge/graph/passes/folding_kernel/permute_kernel.cc b/src/ge/graph/passes/folding_kernel/permute_kernel.cc index ce0737f0..4f0225ac 100644 --- a/src/ge/graph/passes/folding_kernel/permute_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/permute_kernel.cc @@ -33,13 +33,6 @@ #include "graph/passes/folding_kernel/kernel_utils.h" #include "framework/common/ge_inner_error_codes.h" -using domi::PARAM_INVALID; -using domi::PERMUTE; -using domi::Status; -using domi::SUCCESS; -using domi::TRANSPOSE; -using domi::TRANSPOSED; - namespace ge { namespace { const char *const kAttrOrder = "order"; @@ -139,6 +132,5 @@ Status PermuteKernel::Compute(const OpDescPtr op_desc_ptr, const std::vectorGetName().c_str()); return PARAM_INVALID; } - if (data_tensor->GetData().size() == 0 || axis_tensor->GetData().size() == 0) { - GELOGE(PARAM_INVALID, "ReduceProdKernel data size of inputs is 0, node node: %s", op_desc_ptr->GetName().c_str()); - return PARAM_INVALID; - } + DataType data_type = data_tensor->GetTensorDesc().GetDataType(); if (kReduceProdSupportedType.find(data_type) == kReduceProdSupportedType.end()) { GELOGE(PARAM_INVALID, "ReduceProdKernel data type %s not support, node name: %s", @@ -153,7 +148,6 @@ Status ReduceProdKernel::DataCal(const std::vector &input, static_cast(head_dim_ * end_dim_ * sizeof(int32_t))) != GRAPH_SUCCESS, GELOGW("set data failed"); return INTERNAL_ERROR); - output_ptr->MutableTensorDesc().SetDataType(data_dtype); } return SUCCESS; } @@ -165,10 +159,7 @@ void 
ReduceProdKernel::ShapeCal(const ge::OpDescPtr &op_desc_ptr, const std::vec vector data_dims = data_tensor->GetTensorDesc().GetShape().GetDims(); int32_t data_dim_size = static_cast(data_dims.size()); const uint8_t *axis_data = axis_tensor->GetData().GetData(); - if (axis_data == nullptr) { - DOMI_LOGE(param axis_data must not be null.); - return; - } + GE_CHECK_NOTNULL_EXEC(axis_data, return ); int32_t axis = *(const_cast(reinterpret_cast(axis_data))); bool keep_dims = false; if (!AttrUtils::GetBool(op_desc_ptr, "keep_dims", keep_dims)) { @@ -262,19 +253,32 @@ Status ReduceProdKernel::Compute(const ge::OpDescPtr op_desc_ptr, const std::vec if (ret != SUCCESS) { return NOT_CHANGED; } + } else if (input.at(kReduceProdAxisIndex)->GetData().size() == 0) { + // axis tensor value is [], means no process for input + output_ptr->MutableTensorDesc().SetShape(input.at(kReduceProdDataIndex)->GetTensorDesc().GetShape()); + output_ptr->MutableTensorDesc().SetDataType(input.at(kReduceProdDataIndex)->GetTensorDesc().GetDataType()); + if (output_ptr->SetData(input.at(kReduceProdDataIndex)->GetData()) != GRAPH_SUCCESS) { + GELOGW("Compute: SetData failed"); + } } else { // calculate axis to reduce ret = AxisCal(input); if (ret != SUCCESS) { return NOT_CHANGED; } - // calculate data and data type - ret = DataCal(input, output_ptr); - if (ret != SUCCESS) { - return NOT_CHANGED; - } - // calculate shape + // calculate and set shape ShapeCal(op_desc_ptr, input, output_ptr); + // set data type + output_ptr->MutableTensorDesc().SetDataType(input.at(kReduceProdDataIndex)->GetTensorDesc().GetDataType()); + + // data size == 0 means input tensor has zero in shape, and tensor value is []. 
+ if (input.at(kReduceProdDataIndex)->GetData().size() != 0) { + // calculate data and data type + ret = DataCal(input, output_ptr); + if (ret != SUCCESS) { + return NOT_CHANGED; + } + } } // print output tensor information, and will be deleted diff --git a/src/ge/graph/passes/folding_kernel/reduce_prod_kernel.h b/src/ge/graph/passes/folding_kernel/reduce_prod_kernel.h index 4b858b4a..326dd2f5 100644 --- a/src/ge/graph/passes/folding_kernel/reduce_prod_kernel.h +++ b/src/ge/graph/passes/folding_kernel/reduce_prod_kernel.h @@ -42,4 +42,4 @@ class ReduceProdKernel : public Kernel { }; } // namespace ge -#endif // GE_GRAPH_PASSES_FOLDING_KERNEL_REDUCE_PROD_KERNEL_H_ \ No newline at end of file +#endif // GE_GRAPH_PASSES_FOLDING_KERNEL_REDUCE_PROD_KERNEL_H_ diff --git a/src/ge/graph/passes/folding_kernel/reformat_kernel.cc b/src/ge/graph/passes/folding_kernel/reformat_kernel.cc index 1e43a073..8829d4c4 100644 --- a/src/ge/graph/passes/folding_kernel/reformat_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/reformat_kernel.cc @@ -26,8 +26,6 @@ #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" -using domi::REFORMAT; - namespace ge { namespace { const size_t kReFormatInputSize = 1; diff --git a/src/ge/graph/passes/folding_kernel/reshape_kernel.cc b/src/ge/graph/passes/folding_kernel/reshape_kernel.cc index 525b4e03..4e925836 100644 --- a/src/ge/graph/passes/folding_kernel/reshape_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/reshape_kernel.cc @@ -23,8 +23,6 @@ #include "graph/passes/folding_kernel/kernel_utils.h" #include "inc/kernel_factory.h" -using domi::RESHAPE; - namespace ge { namespace { const int kReshapeDataIndex = 0; diff --git a/src/ge/graph/passes/folding_kernel/rsqrt_kernel.cc b/src/ge/graph/passes/folding_kernel/rsqrt_kernel.cc index 809578eb..25e81713 100644 --- a/src/ge/graph/passes/folding_kernel/rsqrt_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/rsqrt_kernel.cc @@ -28,10 +28,6 @@ #include 
"graph/passes/folding_kernel/kernel_utils.h" #include "inc/kernel_factory.h" -using domi::PARAM_INVALID; -using domi::RSQRT; -using domi::SUCCESS; - namespace ge { namespace { const size_t kRsqrtInputSize = 1; @@ -49,6 +45,10 @@ Status RsqrtKernel::Compute(const OpDescPtr op_desc_ptr, const std::vectorGetTensorDesc().GetDataType() != DT_FLOAT) { + GELOGW("input data type must be FP32."); + return NOT_CHANGED; + } const GeShape &x_shape = input_->GetTensorDesc().GetShape(); size_t data_size = input_->GetData().size(); diff --git a/src/ge/graph/passes/folding_kernel/shape_kernel.cc b/src/ge/graph/passes/folding_kernel/shape_kernel.cc index 9cb005c9..f7475b91 100644 --- a/src/ge/graph/passes/folding_kernel/shape_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/shape_kernel.cc @@ -24,8 +24,6 @@ #include "graph/passes/pass_utils.h" #include "inc/kernel_factory.h" -using domi::SHAPE; - namespace ge { namespace { const size_t kShapeInputSize = 1; diff --git a/src/ge/graph/passes/folding_kernel/shape_n_kernel.cc b/src/ge/graph/passes/folding_kernel/shape_n_kernel.cc index b7844876..8ed546de 100644 --- a/src/ge/graph/passes/folding_kernel/shape_n_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/shape_n_kernel.cc @@ -24,8 +24,6 @@ #include "graph/passes/pass_utils.h" #include "inc/kernel_factory.h" -using domi::SHAPEN; - namespace ge { Status ShapeNKernel::Compute(const NodePtr &node, std::vector &v_output) { GELOGD("ShapeN kernel in"); diff --git a/src/ge/graph/passes/folding_kernel/size_kernel.cc b/src/ge/graph/passes/folding_kernel/size_kernel.cc index 8f9ef8dd..3b121ba4 100644 --- a/src/ge/graph/passes/folding_kernel/size_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/size_kernel.cc @@ -30,7 +30,6 @@ #include "inc/kernel_factory.h" #include "omg/omg_inner_types.h" -using domi::SIZE; namespace ge { namespace { const size_t kSizeInputSize = 1; @@ -63,7 +62,7 @@ Status SizeKernel::Compute(const NodePtr &node, std::vector &v_outp int64_t size = 1; // Calculate the 
number of elements of the sensor for (int64_t dim : op_desc->GetInputDesc(0).GetShape().GetDims()) { - if (!domi::CheckInt64MulOverflow(size, dim)) { + if (!CheckInt64MulOverflow(size, dim)) { GELOGE(INTERNAL_ERROR, "int64 overflow!"); return INTERNAL_ERROR; } diff --git a/src/ge/graph/passes/folding_kernel/slice_d_kernel.cc b/src/ge/graph/passes/folding_kernel/slice_d_kernel.cc index aaac2b44..900828c2 100644 --- a/src/ge/graph/passes/folding_kernel/slice_d_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/slice_d_kernel.cc @@ -26,9 +26,6 @@ #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" -using domi::SLICED; -using ge::fp16_t; - namespace ge { namespace { const int64_t kDimMinusOne = -1; diff --git a/src/ge/graph/passes/folding_kernel/slice_kernel.cc b/src/ge/graph/passes/folding_kernel/slice_kernel.cc index 30baa934..a1250367 100644 --- a/src/ge/graph/passes/folding_kernel/slice_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/slice_kernel.cc @@ -25,8 +25,6 @@ #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" -using domi::SLICE; - namespace ge { namespace { const size_t kSliceInputSize = 3; diff --git a/src/ge/graph/passes/folding_kernel/squeeze_kernel.cc b/src/ge/graph/passes/folding_kernel/squeeze_kernel.cc index dec5db50..b253f9a9 100644 --- a/src/ge/graph/passes/folding_kernel/squeeze_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/squeeze_kernel.cc @@ -23,8 +23,6 @@ #include "graph/passes/folding_kernel/kernel_utils.h" #include "inc/kernel_factory.h" -using domi::SQUEEZE; - namespace { constexpr uint32_t kInputDescIndex = 0; constexpr uint32_t kOutputDescIndex = 0; diff --git a/src/ge/graph/passes/folding_kernel/ssd_prior_box_kernel.cc b/src/ge/graph/passes/folding_kernel/ssd_prior_box_kernel.cc index 42e97a7e..15985c5d 100644 --- a/src/ge/graph/passes/folding_kernel/ssd_prior_box_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/ssd_prior_box_kernel.cc @@ -24,25 +24,12 @@ #include "common/math/math_util.h" 
#include "common/math_util.h" #include "common/types.h" -#include "framework/common/op/attr_define.h" #include "framework/common/util.h" +#include "graph/debug/ge_attr_define.h" #include "graph/passes/pass_utils.h" #include "graph/utils/attr_utils.h" #include "inc/kernel_factory.h" -using domi::NnSet; -using domi::SSD_PRIOR_BOX_ATTR_ASPECT_RATIO; -using domi::SSD_PRIOR_BOX_ATTR_CLIP; -using domi::SSD_PRIOR_BOX_ATTR_FLIP; -using domi::SSD_PRIOR_BOX_ATTR_IMG_H; -using domi::SSD_PRIOR_BOX_ATTR_IMG_W; -using domi::SSD_PRIOR_BOX_ATTR_MAX_SIZE; -using domi::SSD_PRIOR_BOX_ATTR_MIN_SIZE; -using domi::SSD_PRIOR_BOX_ATTR_OFFSET; -using domi::SSD_PRIOR_BOX_ATTR_STEP_H; -using domi::SSD_PRIOR_BOX_ATTR_STEP_W; -using domi::SSD_PRIOR_BOX_ATTR_VARIANCE; -using domi::SSDPRIORBOX; namespace ge { namespace { const float kMinistBias = 1e-6; diff --git a/src/ge/graph/passes/folding_kernel/strided_slice_kernel.cc b/src/ge/graph/passes/folding_kernel/strided_slice_kernel.cc index fa89249d..3448a071 100644 --- a/src/ge/graph/passes/folding_kernel/strided_slice_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/strided_slice_kernel.cc @@ -27,13 +27,6 @@ #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" -using domi::STRIDE_SLICE_ATTR_BEGIN_MASK; -using domi::STRIDE_SLICE_ATTR_ELLIPSIS_MASK; -using domi::STRIDE_SLICE_ATTR_END_MASK; -using domi::STRIDE_SLICE_ATTR_NEW_AXIS_MASK; -using domi::STRIDE_SLICE_ATTR_SHRINK_AXIS_MASK; -using domi::STRIDEDSLICE; - namespace ge { namespace { const int32_t kNumOne = 1; @@ -181,7 +174,7 @@ Status StridedSliceKernel::Compute(const ge::OpDescPtr attr, const std::vectorGetData().data(); - if (op_desc_ptr == nullptr || src_data == nullptr) { + + // src_data == nullptr is supported + if (op_desc_ptr == nullptr) { GELOGE(PARAM_INVALID, "Input opDescPtr is nullptr."); return PARAM_INVALID; } diff --git a/src/ge/graph/passes/folding_kernel/transpose_kernel.cc b/src/ge/graph/passes/folding_kernel/transpose_kernel.cc new file mode 100644 index 
00000000..da5c71d9 --- /dev/null +++ b/src/ge/graph/passes/folding_kernel/transpose_kernel.cc @@ -0,0 +1,161 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/passes/folding_kernel/transpose_kernel.h" +#include +#include +#include "common/debug/log.h" +#include "common/formats/format_transfers/format_transfer_transpose.h" +#include "common/formats/formats.h" +#include "common/formats/utils/formats_trans_utils.h" +#include "common/op/ge_op_utils.h" +#include "common/types.h" +#include "common/util.h" +#include "framework/common/debug/ge_log.h" +#include "framework/common/ge_inner_error_codes.h" +#include "graph/passes/folding_kernel/kernel_utils.h" +#include "graph/utils/type_utils.h" +#include "inc/kernel_factory.h" + +namespace ge { +namespace { +const size_t kTransposeInputX = 0; +const size_t kTransposeInputPerm = 1; +const size_t kTransposeInputSize = 2; +const size_t kTransposeOutputY = 0; +const size_t kTransposeOutputSize = 1; +} // namespace + +Status TransposeKernel::ValidateInput(const OpDescPtr &op_desc_ptr, const std::vector &input) { + if (op_desc_ptr == nullptr) { + GELOGW("Input opDescPtr is nullptr."); + return PARAM_INVALID; + } + if (op_desc_ptr->GetInputsSize() != kTransposeInputSize || op_desc_ptr->GetOutputsSize() != kTransposeOutputSize) { + GELOGW("The input_size(%zu) and output_size(%zu) of op are invalid, op name: %s.", op_desc_ptr->GetInputsSize(), + 
op_desc_ptr->GetOutputsSize(), op_desc_ptr->GetName().c_str()); + return PARAM_INVALID; + } + if (input.size() != kTransposeInputSize) { + GELOGW("The size of input tensor vector is invalid, input size is %zu, op name: %s.", input.size(), + op_desc_ptr->GetName().c_str()); + return PARAM_INVALID; + } + ConstGeTensorPtr tensor_x_ptr = input[kTransposeInputX]; + ConstGeTensorPtr tensor_perm_ptr = input[kTransposeInputPerm]; + if (tensor_x_ptr == nullptr || tensor_perm_ptr == nullptr) { + GELOGW("Input tensor of op is nullptr, node name: %s.", op_desc_ptr->GetName().c_str()); + return PARAM_INVALID; + } + return SUCCESS; +} + +Status TransposeKernel::Compute(const OpDescPtr op_desc_ptr, const std::vector &input, + std::vector &v_output) { + GELOGD("TransposeKernel in."); + Status status = ValidateInput(op_desc_ptr, input); + if (status != SUCCESS) { + GELOGW("TransposeKernel input is invalid, failed to fold node."); + return NOT_CHANGED; + } + + ConstGeTensorPtr const_weight_ptr = input[kTransposeInputX]; + GeTensorDesc op_desc = op_desc_ptr->GetOutputDesc(kTransposeOutputY); + GeTensorDesc op_desc_in = op_desc_ptr->GetInputDesc(kTransposeInputX); + auto src_format = op_desc_in.GetFormat(); + auto src_shape = op_desc_in.GetShape().GetDims(); + auto src_data_type = op_desc_in.GetDataType(); + auto data_shape = op_desc.GetShape().GetDims(); + auto data_format = op_desc.GetFormat(); + auto data_type = op_desc.GetDataType(); + GELOGD( + "current node %s, format %s, input shape %s, data type %s, weight format %s, shape %s, data type %s. 
" + "output format %s, shape %s, data type %s", + op_desc_ptr->GetName().c_str(), TypeUtils::FormatToSerialString(src_format).c_str(), + formats::ShapeToString(src_shape).c_str(), TypeUtils::DataTypeToSerialString(src_data_type).c_str(), + TypeUtils::FormatToSerialString(const_weight_ptr->GetTensorDesc().GetFormat()).c_str(), + formats::ShapeToString(const_weight_ptr->GetTensorDesc().GetShape()).c_str(), + TypeUtils::DataTypeToSerialString(const_weight_ptr->GetTensorDesc().GetDataType()).c_str(), + TypeUtils::FormatToSerialString(data_format).c_str(), formats::ShapeToString(data_shape).c_str(), + TypeUtils::DataTypeToSerialString(data_type).c_str()); + + ConstGeTensorPtr tensor_perm_ptr = input[kTransposeInputPerm]; + DataType data_dtype = tensor_perm_ptr->GetTensorDesc().GetDataType(); + auto input_perm_shape = tensor_perm_ptr->GetTensorDesc().GetShape(); + auto output_size = input_perm_shape.GetShapeSize(); + uint32_t data_size = GetSizeByDataType(data_dtype); + if (static_cast(output_size * data_size) != tensor_perm_ptr->GetData().size()) { + GELOGW("TransposeKernel input perm shape size and data size do not match."); + return NOT_CHANGED; + } + + vector perm_list; + auto input_perm = tensor_perm_ptr->GetData().data(); + if (data_dtype == DT_INT32) { + int32_t *input_perm_data = const_cast(reinterpret_cast(input_perm)); + for (int64_t i = 0; i < output_size; i++) { + perm_list.push_back(static_cast(input_perm_data[i])); + } + } else if (data_dtype == DT_INT64) { + int64_t *input_perm_data = const_cast(reinterpret_cast(input_perm)); + for (int64_t i = 0; i < output_size; i++) { + perm_list.push_back(input_perm_data[i]); + } + } else { + GELOGW("TransposeKernel input perm data type is invalid, data type is %s.", + TypeUtils::DataTypeToSerialString(data_dtype).c_str()); + return NOT_CHANGED; + } + + GELOGD("Transpose from %s to %s, shape %s to %s, perm_list %s, data type %s", + TypeUtils::FormatToSerialString(src_format).c_str(), 
TypeUtils::FormatToSerialString(data_format).c_str(), + formats::ShapeToString(src_shape).c_str(), formats::ShapeToString(data_shape).c_str(), + formats::ShapeToString(perm_list).c_str(), TypeUtils::DataTypeToSerialString(src_data_type).c_str()); + if ((data_shape.empty()) || (src_data_type != data_type)) { + GELOGW("Transpose is not supported. Invalid shape (src: %s, dst: %s) or inconsistent datatype (src: %s, dst: %s)", + formats::ShapeToString(src_shape).c_str(), formats::ShapeToString(data_shape).c_str(), + TypeUtils::DataTypeToSerialString(src_data_type).c_str(), + TypeUtils::DataTypeToSerialString(data_type).c_str()); + return NOT_CHANGED; + } + if (!KernelUtils::CheckSizeForTransOp(const_weight_ptr, op_desc_ptr)) { + GELOGE(FAILED, "CheckSize failed, input size is not equal to weight size"); + return NOT_CHANGED; + } + const uint8_t *src_data = const_weight_ptr->GetData().data(); + formats::TransResult trans_result; + auto ret = formats::TransposeWithShapeCheck(src_data, src_shape, data_shape, src_data_type, perm_list, trans_result); + if (ret != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Failed to Transpose from %s to %s, shape %s to %s, perm_list %s, data type %s", + TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(data_format).c_str(), + formats::ShapeToString(src_shape).c_str(), formats::ShapeToString(data_shape).c_str(), + formats::ShapeToString(perm_list).c_str(), TypeUtils::DataTypeToSerialString(src_data_type).c_str()); + return NOT_CHANGED; + } + + GeTensorPtr output_ptr = MakeShared(op_desc_ptr->GetOutputDesc(kTransposeOutputY)); + GE_CHECK_NOTNULL(output_ptr); + if (output_ptr->SetData(trans_result.data.get(), trans_result.length) != GRAPH_SUCCESS) { + GELOGW("Compute: SetData failed"); + } + v_output.push_back(output_ptr); + + GELOGI("TransposeKernel success."); + return SUCCESS; +} + +REGISTER_KERNEL(TRANSPOSE, TransposeKernel); +} // namespace ge diff --git a/src/ge/graph/passes/folding_kernel/transpose_kernel.h 
b/src/ge/graph/passes/folding_kernel/transpose_kernel.h new file mode 100644 index 00000000..bb073c15 --- /dev/null +++ b/src/ge/graph/passes/folding_kernel/transpose_kernel.h @@ -0,0 +1,34 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_PASSES_FOLDING_KERNEL_TRANSPOSE_KERNEL_H_ +#define GE_GRAPH_PASSES_FOLDING_KERNEL_TRANSPOSE_KERNEL_H_ + +#include +#include "inc/kernel.h" + +namespace ge { +class TransposeKernel : public Kernel { + public: + Status Compute(const OpDescPtr attr, const std::vector &input, + std::vector &v_output) override; + + private: + Status ValidateInput(const OpDescPtr &attr, const std::vector &input); +}; +} // namespace ge + +#endif // GE_GRAPH_PASSES_FOLDING_KERNEL_TRANSPOSE_KERNEL_H_ \ No newline at end of file diff --git a/src/ge/graph/passes/folding_kernel/unpack_kernel.cc b/src/ge/graph/passes/folding_kernel/unpack_kernel.cc index 985f822b..44f666fa 100644 --- a/src/ge/graph/passes/folding_kernel/unpack_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/unpack_kernel.cc @@ -22,8 +22,6 @@ #include "graph/debug/ge_attr_define.h" #include "inc/kernel_factory.h" -using domi::UNPACK; - namespace ge { namespace { const size_t kUnpackInputNum = 1; @@ -63,16 +61,17 @@ Status UnpackKernel::Compute(const OpDescPtr attr, const std::vectorGetTensorDesc().GetShape().GetDimNum() != 1) { + GELOGW("input tensor not 1 dim"); + return NOT_CHANGED; + } + 
ge::DataType data_type; GE_CHK_BOOL_RET_STATUS(AttrUtils::GetDataType(attr, ATTR_NAME_T, data_type), PARAM_INVALID, "get T attr failed."); // data_type must be FLOAT or INT32 GE_CHK_BOOL_RET_STATUS((data_type == DT_FLOAT || data_type == DT_INT32), PARAM_INVALID, "T must be float or int32."); - // input dim size must = 1 - GE_RT_PARAM_INVALID_WITH_LOG_IF_FALSE((dims->GetTensorDesc().GetShape().GetDimNum() == 1), - "input tensor must be 1 dim, real is %zu.", - dims->GetTensorDesc().GetShape().GetDimNum()); - int64_t num = 0; GE_CHK_BOOL_RET_STATUS(AttrUtils::GetInt(attr, UNPACK_ATTR_NAME_NUM, num), PARAM_INVALID, "get num attr failed."); size_t data_count = dims->GetData().size() / sizeof(float); diff --git a/src/ge/graph/passes/folding_pass.cc b/src/ge/graph/passes/folding_pass.cc index dedf095d..41528ec3 100644 --- a/src/ge/graph/passes/folding_pass.cc +++ b/src/ge/graph/passes/folding_pass.cc @@ -20,6 +20,7 @@ #include #include #include +#include #include "framework/common/debug/ge_log.h" #include "graph/utils/graph_utils.h" @@ -27,7 +28,6 @@ #include "inc/kernel.h" #include "inc/kernel_factory.h" #include "graph/debug/ge_attr_define.h" -#include "framework/common/op/attr_define.h" #include "ge_local_engine/engine/host_cpu_engine.h" namespace ge { @@ -39,8 +39,8 @@ shared_ptr GetKernelByType(const NodePtr &node) { } KernelFactory &factory = KernelFactory::Instance(); string type = node->GetType(); - if (type == domi::FRAMEWORKOP) { - if (!ge::AttrUtils::GetStr(node->GetOpDesc(), domi::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type)) { + if (type == FRAMEWORKOP) { + if (!ge::AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type)) { return nullptr; } } @@ -49,7 +49,7 @@ shared_ptr GetKernelByType(const NodePtr &node) { } bool IsNoNeedConstantFolding(const NodePtr &node) { auto node_desc = node->GetOpDesc(); - return node_desc == nullptr || node_desc->HasAttr(domi::ATTR_NO_NEED_CONSTANT_FOLDING); + return node_desc == nullptr || 
node_desc->HasAttr(ATTR_NO_NEED_CONSTANT_FOLDING); } } // namespace folding_pass @@ -100,7 +100,7 @@ NodePtr AddIdentityNodeToGraph(const std::string &name, const GeTensorDesc &tens } desc->SetName(name); - desc->SetType(domi::IDENTITY); + desc->SetType(IDENTITY); auto ret = desc->AddInputDesc(tensor); auto ret2 = desc->AddOutputDesc(tensor); if ((ret != GRAPH_SUCCESS) || (ret2 != GRAPH_SUCCESS)) { @@ -170,7 +170,7 @@ Status FoldingPass::DealWithInNodes(NodePtr &node) { if (in_node == nullptr) { continue; } - if ((in_node->GetType() == domi::SWITCH) || (in_node->GetType() == domi::REFSWITCH)) { + if ((in_node->GetType() == SWITCH) || (in_node->GetType() == REFSWITCH)) { GELOGI("The in_node name is %s, and node type is %s.", in_node->GetName().c_str(), in_node->GetType().c_str()); auto ret = in_node_anchor->Unlink(in_data_anchor); if (ret != SUCCESS) { @@ -257,9 +257,9 @@ Status FoldingPass::AddConstNode(NodePtr &node, IndexsToAnchors indexes_to_ancho } GE_CHECK_NOTNULL(node->GetOpDesc()); std::string stream_label; - if (AttrUtils::GetStr(node->GetOpDesc(), domi::ATTR_NAME_STREAM_LABEL, stream_label)) { + if (AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, stream_label)) { GE_CHECK_NOTNULL(const_node->GetOpDesc()); - if (!AttrUtils::SetStr(const_node->GetOpDesc(), domi::ATTR_NAME_STREAM_LABEL, stream_label)) { + if (!AttrUtils::SetStr(const_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, stream_label)) { GELOGE(INTERNAL_ERROR, "Failed to set stream label on dynamic const node %s, with stream label:%s.", const_node->GetName().c_str(), stream_label.c_str()); return INTERNAL_ERROR; diff --git a/src/ge/graph/passes/for_pass.cc b/src/ge/graph/passes/for_pass.cc new file mode 100644 index 00000000..f63e8627 --- /dev/null +++ b/src/ge/graph/passes/for_pass.cc @@ -0,0 +1,732 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the 
License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/passes/for_pass.h" +#include "common/ge/ge_util.h" +#include "common/op/ge_op_utils.h" +#include "framework/common/debug/ge_log.h" +#include "framework/common/debug/log.h" +#include "framework/common/ge_inner_error_codes.h" +#include "framework/common/types.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/type_utils.h" +#include "graph/utils/node_utils.h" +#include "graph/utils/op_desc_utils.h" + +namespace { +const uint32_t kWhileIInputIndex = 0; +const uint32_t kWhileNInputIndex = 1; +const uint32_t kWhileStartInputIndex = 2; +const uint32_t kWhileDeltaInputIndex = 3; +const uint32_t kWhileDataInputIndex = 4; +const uint32_t kSubgraphLoopVarInputIndex = 0; +const uint32_t kSubgraphInputIndex = 1; +const uint32_t kWhileOutputIndex = 4; +const std::string kAbs = "Abs"; +} // namespace + +namespace ge { +Status ForPass::Run(NodePtr &node) { + GE_CHECK_NOTNULL(node); + GE_CHECK_NOTNULL(node->GetOpDesc()); + if (node->GetType() != FOR) { + return SUCCESS; + } + + GELOGI("Begin to transfer for_op to while_op, node:%s.", node->GetName().c_str()); + + ComputeGraphPtr graph = node->GetOwnerComputeGraph(); + GE_CHECK_NOTNULL(graph); + ComputeGraphPtr root_graph = GraphUtils::FindRootGraph(graph); + GE_CHECK_NOTNULL(root_graph); + + ForInfo for_info; + GE_CHK_STATUS_RET(BuildForInfo(root_graph, node, for_info), "Build ForInfo failed, node:%s.", + node->GetName().c_str()); + + WhileInfo while_info; + GE_CHK_STATUS_RET(TranWhileInfo(graph, for_info, while_info), "Transfer 
WhileInfo from ForInfo failed, node:%s.", + node->GetName().c_str()); + + ComputeGraphPtr cond_graph = BuildCondGraph(while_info); + if ((cond_graph == nullptr) || (root_graph->AddSubgraph(cond_graph) != GRAPH_SUCCESS)) { + GELOGE(FAILED, "Add while_cond_graph failed, node:%s.", node->GetName().c_str()); + return FAILED; + } + + ComputeGraphPtr body_graph = BuildBodyGraph(while_info); + if ((body_graph == nullptr) || (root_graph->AddSubgraph(body_graph) != GRAPH_SUCCESS)) { + GELOGE(FAILED, "Add while_body_graph failed, node:%s.", node->GetName().c_str()); + return FAILED; + } + + GE_CHK_STATUS_RET(UpdateForBodyInputMapping(while_info), "Update InputMapping for for-body-graph failed, node:%s.", + node->GetName().c_str()); + + // for node has and only has one subgraph + node->GetOpDesc()->RemoveSubgraphInstanceName(node->GetOpDesc()->GetSubgraphInstanceName(0)); + + GELOGI("Transfer for_op to while_op succ, node:%s.", node->GetName().c_str()); + return IsolateAndDeleteNode(node, std::vector()); +} + +/// +/// @brief Build for_info +/// @param [in] root_graph +/// @param [in] node +/// @param [out] for_info +/// @return Status +/// +Status ForPass::BuildForInfo(const ComputeGraphPtr &root_graph, const NodePtr &node, ForInfo &for_info) { + GELOGI("Begin to build for_info for node %s.", node->GetName().c_str()); + + OutDataAnchorPtr start = FindInputWithIndex(node, FOR_START_INPUT); + OutDataAnchorPtr limit = FindInputWithIndex(node, FOR_LIMIT_INPUT); + OutDataAnchorPtr delta = FindInputWithIndex(node, FOR_DELTA_INPUT); + if ((start == nullptr) || (limit == nullptr) || (delta == nullptr)) { + GELOGE(FAILED, "BuildForInfo for %s failed: start / limit / delta is NULL.", node->GetName().c_str()); + return FAILED; + } + + std::vector data_inputs; + std::vector> data_outputs; + std::vector ctrl_inputs; + std::vector ctrl_outputs; + if (FindInputsAndOutputs(node, data_inputs, data_outputs, ctrl_inputs, ctrl_outputs) != SUCCESS) { + GELOGE(FAILED, "BuildForInfo for %s failed: 
find inputs /outputs failed.", node->GetName().c_str()); + return FAILED; + } + NodeUtils::UnlinkAll(*node); + + OpDescPtr op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + // For node has and only has one sub_graph + std::string for_body_name = op_desc->GetSubgraphInstanceName(0); + if (for_body_name.empty()) { + GELOGE(FAILED, "BuildForInfo for %s failed: sub_graph_name is empty.", node->GetName().c_str()); + return FAILED; + } + ComputeGraphPtr for_body = root_graph->GetSubgraph(for_body_name); + if (for_body == nullptr) { + GELOGE(FAILED, "BuildForInfo for %s failed: for_body_graph is NULL.", node->GetName().c_str()); + return FAILED; + } + + for_info.for_node = node; + for_info.start = start; + for_info.limit = limit; + for_info.delta = delta; + for_info.body_name = for_body_name; + for_info.for_body = for_body; + for_info.data_inputs = std::move(data_inputs); + for_info.data_outputs = std::move(data_outputs); + for_info.ctrl_inputs = std::move(ctrl_inputs); + for_info.ctrl_outputs = std::move(ctrl_outputs); + + GELOGI("Build for_info for node %s succ.", node->GetName().c_str()); + return SUCCESS; +} + +/// +/// @brief Find input with index for For node +/// @param [in] node +/// @param [in] index +/// @return OutDataAnchorPtr +/// +OutDataAnchorPtr ForPass::FindInputWithIndex(const NodePtr &node, uint32_t index) { + if (node == nullptr) { + GELOGE(FAILED, "FindInputWithIndex failed: node is NULL."); + return nullptr; + } + + InDataAnchorPtr in_data_anchor = node->GetInDataAnchor(index); + if (in_data_anchor == nullptr) { + GELOGE(FAILED, "FindInputWithIndex %s:%u failed: in_data_anchor is NULL.", node->GetName().c_str(), index); + return nullptr; + } + + OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); + if (peer_out_anchor == nullptr) { + GELOGE(FAILED, "FindInputWithIndex %s:%u failed: peer_out_anchor is NULL.", node->GetName().c_str(), index); + return nullptr; + } + + return peer_out_anchor; +} + +/// +/// @brief Find 
inputs / outputs for for node +/// @param [in] node +/// @param [out] data_inputs +/// @param [out] data_outputs +/// @param [out] ctrl_inputs +/// @param [out] ctrl_outputs +/// @return Status +/// +Status ForPass::FindInputsAndOutputs(const NodePtr &node, std::vector &data_inputs, + std::vector> &data_outputs, + std::vector &ctrl_inputs, + std::vector &ctrl_outputs) { + GE_CHECK_NOTNULL(node); + + uint32_t input_data_num = node->GetAllInDataAnchorsSize(); + for (uint32_t index = FOR_DATA_INPUT; index < input_data_num; index++) { + InDataAnchorPtr in_data_anchor = node->GetInDataAnchor(index); + if (in_data_anchor == nullptr) { + GELOGE(FAILED, "FindInputWithIndex %s:%u failed: in_data_anchor is NULL.", node->GetName().c_str(), index); + return FAILED; + } + data_inputs.emplace_back(in_data_anchor->GetPeerOutAnchor()); + } + + for (auto &out_data_anchor : node->GetAllOutDataAnchors()) { + std::vector peer_in_data_anchors; + for (auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { + peer_in_data_anchors.emplace_back(peer_in_data_anchor); + } + data_outputs.emplace_back(peer_in_data_anchors); + } + + InControlAnchorPtr in_ctrl_anchor = node->GetInControlAnchor(); + GE_CHECK_NOTNULL(in_ctrl_anchor); + for (auto &peer_out_ctrl_anchor : in_ctrl_anchor->GetPeerOutControlAnchors()) { + ctrl_inputs.emplace_back(peer_out_ctrl_anchor); + } + + OutControlAnchorPtr out_ctrl_anchor = node->GetOutControlAnchor(); + GE_CHECK_NOTNULL(out_ctrl_anchor); + for (auto &peer_in_ctrl_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { + ctrl_outputs.emplace_back(peer_in_ctrl_anchor); + } + + return SUCCESS; +} + +/// +/// @brief Transfer while_info from for_info +/// @param [in] graph +/// @param [in] for_info +/// @param [out] while_info +/// @return Status +/// +Status ForPass::TranWhileInfo(const ComputeGraphPtr &graph, const ForInfo &for_info, WhileInfo &while_info) { + std::string for_name = for_info.for_node->GetName(); + GELOGI("Begin to transfer for_info to 
while_info, node:%s.", for_name.c_str()); + + std::string i_name = for_name + "_i"; + NodePtr i_node = graph->AddNode(CreateConstDesc(i_name, 0)); + if (i_node == nullptr) { + GELOGE(FAILED, "TranWhileInfo failed: create i_node failed."); + return FAILED; + } + AddRePassNode(i_node); + + // Const node has and only has one output + OutDataAnchorPtr i_input = i_node->GetOutDataAnchor(0); + if (i_input == nullptr) { + GELOGE(FAILED, "TranWhileInfo failed: i_input is NULL."); + return FAILED; + } + + OutDataAnchorPtr n_input = CreateLoopCountInput(graph, for_info); + if (n_input == nullptr) { + GELOGE(FAILED, "TranWhileInfo failed: n_input is NULL."); + return FAILED; + } + + BuildWhileInfo(for_info, i_input, n_input, while_info); + + if (InsertWhileNode(graph, for_name + "_While", while_info) != SUCCESS) { + GELOGE(FAILED, "TranWhileInfo failed: insert while node failed."); + return FAILED; + } + + GELOGI("Transfer for_info to while_info succ, for_node:%s, while_node:%s.", for_name.c_str(), + while_info.while_node->GetName().c_str()); + return SUCCESS; +} + +/// +/// @brief Create const op_desc +/// @param [in] name +/// @param [in] value +/// @return OpDescPtr +/// +OpDescPtr ForPass::CreateConstDesc(const std::string &name, int32_t value) { + OpDescPtr const_op_desc = MakeShared(name, CONSTANT); + if (const_op_desc == nullptr) { + GELOGE(FAILED, "Create op_desc failed, const:%s.", name.c_str()); + return nullptr; + } + + GeTensorDesc data_desc(GeShape(), FORMAT_NCHW, DT_INT32); + GeTensorPtr const_value = MakeShared(data_desc, reinterpret_cast(&value), sizeof(int32_t)); + if (const_value == nullptr) { + GELOGE(FAILED, "Create tensor failed, const:%s.", name.c_str()); + return nullptr; + } + + if (!AttrUtils::SetTensor(const_op_desc, ATTR_NAME_WEIGHTS, const_value)) { + GELOGE(FAILED, "Set ATTR_NAME_WEIGHTS failed, const:%s.", name.c_str()); + return nullptr; + } + + if (const_op_desc->AddOutputDesc("y", data_desc) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Add output 
desc failed, const:%s.", name.c_str()); + return nullptr; + } + + return const_op_desc; +} + +/// +/// @brief Create loop_count node +/// @param [in] graph +/// @param [in] for_info +/// @return OutDataAnchorPtr +/// +OutDataAnchorPtr ForPass::CreateLoopCountInput(const ComputeGraphPtr &graph, const ForInfo &for_info) { + std::string for_name = for_info.for_node->GetName(); + GELOGD("Begin to create loop_count input, node:%s", for_name.c_str()); + + OutDataAnchorPtr start = for_info.start; + OutDataAnchorPtr limit = for_info.limit; + OutDataAnchorPtr delta = for_info.delta; + + std::string sub_name_0 = for_name + "_Sub_0"; + std::string abs_name_0 = for_name + "_Abs_0"; + std::string abs_name_1 = for_name + "_Abs_1"; + std::string add_name_0 = for_name + "_Add_0"; + std::string const_name = for_name + "_Const"; + std::string sub_name_1 = for_name + "_Sub_1"; + std::string cast_name_0 = for_name + "_Cast_0"; + std::string cast_name_1 = for_name + "_Cast_1"; + std::string div_name = for_name + "_RealDiv"; + std::string cast_name_2 = for_name + "_Cast_2"; + + // n = cast(cast(abs(limit-start) + abs(delta) - 1, float) / cast(abs(delta), float), int32) + PartialGraphBuilder graph_builder; + graph_builder.SetOwnerGraph(graph) + .AddExistNode(for_info.start->GetOwnerNode()) + .AddExistNode(for_info.limit->GetOwnerNode()) + .AddExistNode(for_info.delta->GetOwnerNode()) + .AddNode(CreateOpDesc(sub_name_0, SUB, false)) + .AddNode(CreateOpDesc(abs_name_0, kAbs, true)) + .AddNode(CreateOpDesc(abs_name_1, kAbs, true)) + .AddNode(CreateOpDesc(add_name_0, ADD, false)) + .AddNode(CreateConstDesc(const_name, 1)) + .AddNode(CreateOpDesc(sub_name_1, SUB, false)) + .AddNode(CreateCastDesc(cast_name_0, DT_INT32, DT_FLOAT)) + .AddNode(CreateCastDesc(cast_name_1, DT_INT32, DT_FLOAT)) + .AddNode(CreateOpDesc(div_name, REALDIV, false)) + .AddNode(CreateCastDesc(cast_name_2, DT_FLOAT, DT_INT32)) + .AddDataLink(limit->GetOwnerNode()->GetName(), limit->GetIdx(), sub_name_0, 0) + 
.AddDataLink(start->GetOwnerNode()->GetName(), start->GetIdx(), sub_name_0, 1) + .AddDataLink(sub_name_0, 0, abs_name_0, 0) + .AddDataLink(delta->GetOwnerNode()->GetName(), delta->GetIdx(), abs_name_1, 0) + .AddDataLink(abs_name_0, 0, add_name_0, 0) + .AddDataLink(abs_name_1, 0, add_name_0, 1) + .AddDataLink(add_name_0, 0, sub_name_1, 0) + .AddDataLink(const_name, 0, sub_name_1, 1) + .AddDataLink(sub_name_1, 0, cast_name_0, 0) + .AddDataLink(abs_name_1, 0, cast_name_1, 0) + .AddDataLink(cast_name_0, 0, div_name, 0) + .AddDataLink(cast_name_1, 0, div_name, 1) + .AddDataLink(div_name, 0, cast_name_2, 0); + + graphStatus error_code = GRAPH_SUCCESS; + std::string error_msg; + if ((graph_builder.Build(error_code, error_msg) == nullptr) || (error_code != GRAPH_SUCCESS)) { + GELOGE(FAILED, "Create loop_count node failed: error_code:%u, error_msg:%s.", error_code, error_msg.c_str()); + return nullptr; + } + + NodePtr loop_count_node = graph_builder.GetNode(cast_name_2); + if (loop_count_node == nullptr) { + GELOGE(FAILED, "Create loop_count node failed: node is NULL."); + return nullptr; + } + + GELOGD("Create loop_count input succ, node:%s", for_name.c_str()); + // loop_count_node is a Cast node, has and only has one output + return loop_count_node->GetOutDataAnchor(0); +} + +/// +/// @brief Create cast op_desc +/// @param [in] name +/// @param [in] src_data_type +/// @param [in] dst_data_type +/// @return OpDescPtr +/// +OpDescPtr ForPass::CreateCastDesc(const std::string &name, DataType src, DataType dst) { + OpDescPtr cast_desc = CreateOpDesc(name, CAST, true); + if (cast_desc == nullptr) { + GELOGE(FAILED, "Create cast op_desc failed, node: %s.", name.c_str()); + return nullptr; + } + + // cast node has and only has one input /output + GeTensorDesc in_tensor = cast_desc->GetInputDesc(0); + in_tensor.SetDataType(src); + GeTensorDesc out_tensor = cast_desc->GetOutputDesc(0); + out_tensor.SetDataType(dst); + if ((cast_desc->UpdateInputDesc(0, in_tensor) != GRAPH_SUCCESS) 
|| + (cast_desc->UpdateOutputDesc(0, out_tensor) != GRAPH_SUCCESS)) { + GELOGE(FAILED, "Update tensor failed."); + return nullptr; + } + + if (!(AttrUtils::SetInt(cast_desc, CAST_ATTR_SRCT, src) && AttrUtils::SetInt(cast_desc, CAST_ATTR_DSTT, dst) && + AttrUtils::SetInt(cast_desc, CAST_ATTR_DST_TYPE, dst) && + AttrUtils::SetBool(cast_desc, CAST_ATTR_TRUNCATE, false))) { + GELOGE(FAILED, "Set CAST_ATTR failed, node: %s.", name.c_str()); + return nullptr; + } + + return cast_desc; +} + +/// +/// @brief Create op_desc +/// @param [in] name +/// @param [in] type +/// @param [in] io_equal_flag +/// @return OpDescPtr +/// +OpDescPtr ForPass::CreateOpDesc(const std::string &name, const std::string &type, bool io_equal_flag) { + OpDescBuilder op_desc_builder(name, type); + if (io_equal_flag) { + op_desc_builder.AddInput("x").AddOutput("y"); + } else { + op_desc_builder.AddInput("x1").AddInput("x2").AddOutput("y"); + } + + return op_desc_builder.Build(); +} + +/// +/// @brief Build while-info +/// @param [in] for_info +/// @param [in] i_input +/// @param [in] n_input +/// @param [out] while_info +/// @return void +/// +void ForPass::BuildWhileInfo(const ForInfo &for_info, const OutDataAnchorPtr &i_input, const OutDataAnchorPtr &n_input, + WhileInfo &while_info) { + while_info.i = i_input; + while_info.n = n_input; + while_info.start = for_info.start; + while_info.delta = for_info.delta; + while_info.for_body_name = for_info.body_name; + while_info.for_body = for_info.for_body; + while_info.data_inputs.emplace_back(while_info.i); + while_info.data_inputs.emplace_back(while_info.n); + while_info.data_inputs.emplace_back(while_info.start); + while_info.data_inputs.emplace_back(while_info.delta); + for (auto &item : for_info.data_inputs) { + while_info.data_inputs.emplace_back(item); + } + for (auto &item : for_info.data_outputs) { + while_info.data_outputs.emplace_back(item); + } + for (auto &item : for_info.ctrl_inputs) { + while_info.ctrl_inputs.emplace_back(item); + } + for 
(auto &item : for_info.ctrl_outputs) { + while_info.ctrl_outputs.emplace_back(item); + } +} + +/// +/// @brief Insert while_node +/// @param [in] graph +/// @param [in] name +/// @param [in&out] while_info +/// @return Status +/// +Status ForPass::InsertWhileNode(const ComputeGraphPtr &graph, const std::string &name, WhileInfo &while_info) { + GELOGD("Begin to create while node, name:%s.", name.c_str()); + + size_t arg_num = while_info.data_inputs.size(); + OpDescBuilder op_desc_builder(name, WHILE); + OpDescPtr op_desc = op_desc_builder.AddDynamicInput("input", arg_num).AddDynamicOutput("output", arg_num).Build(); + if (op_desc == nullptr) { + GELOGE(FAILED, "Create while op_desc failed, name:%s.", name.c_str()); + return FAILED; + } + NodePtr while_node = graph->AddNode(op_desc); + if (while_node == nullptr) { + GELOGE(FAILED, "Create while node failed, name:%s.", name.c_str()); + return FAILED; + } + AddRePassNode(while_node); + + while_info.while_node = while_node; + if (BuildWhileLink(while_info) != SUCCESS) { + GELOGE(FAILED, "Build while link-edge failed, name:%s.", name.c_str()); + return FAILED; + } + + GELOGD("Create while node succ, name:%s.", name.c_str()); + return SUCCESS; +} + +/// +/// @brief Build while link-edge +/// @param [in] while_info +/// @return Status +/// +Status ForPass::BuildWhileLink(const WhileInfo &while_info) { + NodePtr while_node = while_info.while_node; + GE_CHECK_NOTNULL(while_node); + + size_t input_num = while_info.data_inputs.size(); + for (size_t i = 0; i < input_num; i++) { + InDataAnchorPtr in_data_anchor = while_node->GetInDataAnchor(i); + GE_CHECK_NOTNULL(in_data_anchor); + OutDataAnchorPtr peer_out_anchor = while_info.data_inputs[i]; + if (peer_out_anchor == nullptr) { + continue; + } + GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(peer_out_anchor, in_data_anchor), "Add data-edge %s:%d->%s:%d failed.", + peer_out_anchor->GetOwnerNode()->GetName().c_str(), peer_out_anchor->GetIdx(), + while_node->GetName().c_str(), i); + } 
+ + size_t output_num = while_info.data_outputs.size(); + for (size_t i = 0; i < output_num; i++) { + OutDataAnchorPtr out_data_anchor = while_node->GetOutDataAnchor(static_cast(i + kWhileOutputIndex)); + GE_CHECK_NOTNULL(out_data_anchor); + for (auto &peer_in_anchor : while_info.data_outputs[i]) { + GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(out_data_anchor, peer_in_anchor), + "Add data-edge %s:%d->%s:%d failed.", while_node->GetName().c_str(), + i + kWhileOutputIndex, peer_in_anchor->GetOwnerNode()->GetName().c_str(), + peer_in_anchor->GetIdx()); + } + } + + InControlAnchorPtr in_ctrl_anchor = while_node->GetInControlAnchor(); + GE_CHECK_NOTNULL(in_ctrl_anchor); + for (auto &peer_out_anchor : while_info.ctrl_inputs) { + GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(peer_out_anchor, in_ctrl_anchor), "Add ctrl-edge %s->%s failed.", + peer_out_anchor->GetOwnerNode()->GetName().c_str(), + in_ctrl_anchor->GetOwnerNode()->GetName().c_str()); + } + + OutControlAnchorPtr out_ctrl_anchor = while_node->GetOutControlAnchor(); + GE_CHECK_NOTNULL(out_ctrl_anchor); + for (auto &peer_in_anchor : while_info.ctrl_outputs) { + GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(out_ctrl_anchor, peer_in_anchor), "Add ctrl-edge %s->%s failed.", + out_ctrl_anchor->GetOwnerNode()->GetName().c_str(), + peer_in_anchor->GetOwnerNode()->GetName().c_str()); + } + + return SUCCESS; +} + +/// +/// @brief Build cond_graph for while_node +/// @param [in&out] while_info +/// @return ComputeGraphPtr +/// +ComputeGraphPtr ForPass::BuildCondGraph(WhileInfo &while_info) { + std::string cond_name = while_info.for_body_name + "_Cond"; + CompleteGraphBuilder graph_builder(cond_name); + + // Add parent node + graph_builder.SetParentNode(while_info.while_node); + + // Add Node + const std::string less_name = "Less"; + graph_builder.AddNode(CreateOpDesc(less_name, LESS, false)); + + // Set Input + graph_builder.SetInput(kWhileIInputIndex, {less_name}, {0}) + .SetInput(kWhileNInputIndex, {less_name}, {1}) + 
.SetUselessInput(kWhileStartInputIndex) + .SetUselessInput(kWhileDeltaInputIndex); + size_t input_num = while_info.data_inputs.size(); + for (size_t i = kWhileDataInputIndex; i < input_num; i++) { + graph_builder.SetUselessInput(i); + } + + // Add Output + graph_builder.AddOutput(less_name, 0); + + // Add Input-Mapping + std::map input_mapping; + for (size_t i = 0; i < input_num; i++) { + input_mapping[i] = i; + } + graph_builder.SetInputMapping(input_mapping); + + graphStatus error_code = GRAPH_SUCCESS; + std::string error_msg; + ComputeGraphPtr cond_graph = graph_builder.Build(error_code, error_msg); + if (cond_graph == nullptr) { + GELOGE(FAILED, "Build cond_graph failed: error_code:%u, error_msg:%s.", error_code, error_msg.c_str()); + return nullptr; + } + + size_t index = while_info.while_node->GetOpDesc()->GetSubgraphInstanceNames().size(); + while_info.while_node->GetOpDesc()->AddSubgraphName(ATTR_NAME_WHILE_COND); + while_info.while_node->GetOpDesc()->SetSubgraphInstanceName(index, cond_name); + while_info.while_cond = cond_graph; + return cond_graph; +} + +/// +/// @brief Build body_graph for while_node +/// @param [in&out] while_info +/// @return ComputeGraphPtr +/// +ComputeGraphPtr ForPass::BuildBodyGraph(WhileInfo &while_info) { + std::string body_name = while_info.for_body_name + "_Body"; + CompleteGraphBuilder graph_builder(body_name); + + // Add parent node + graph_builder.SetParentNode(while_info.while_node); + + // Add calculation nodes + std::string const_name = "Const"; + std::string add_name_0 = "Add_0"; + std::string mul_name = "Mul"; + std::string add_name_1 = "Add_1"; + graph_builder.AddNode(CreateConstDesc(const_name, 1)) + .AddNode(CreateOpDesc(add_name_0, ADD, false)) + .AddNode(CreateOpDesc(mul_name, MUL, false)) + .AddNode(CreateOpDesc(add_name_1, ADD, false)); + + // Add Subgraph node + auto input_num = static_cast(while_info.data_inputs.size()); + std::string sub_graph_node_name = while_info.for_body_name; + uint32_t 
sub_graph_input_num = input_num - kWhileDataInputIndex + kSubgraphInputIndex; + auto sub_graph_output_num = static_cast(while_info.data_outputs.size()); + graph_builder.AddNode(CreateSubgraphOpDesc(sub_graph_node_name, sub_graph_input_num, sub_graph_output_num)); + + // Set Input + graph_builder.SetInput(kWhileIInputIndex, {add_name_0, mul_name}, {0, 0}) + .SetUselessInput(kWhileNInputIndex) + .SetInput(kWhileStartInputIndex, {add_name_1}, {0}) + .SetInput(kWhileDeltaInputIndex, {mul_name}, {1}); + for (uint32_t i = 0; i < input_num - kWhileDataInputIndex; i++) { + graph_builder.SetInput(i + kWhileDataInputIndex, {sub_graph_node_name}, {i + kSubgraphInputIndex}); + } + + // Add Outputs + graph_builder.AddOutput(add_name_0, 0); + for (uint32_t i = kWhileNInputIndex; i < kWhileDataInputIndex; i++) { + graph_builder.AddOutput("Data_" + std::to_string(i), 0); + } + for (uint32_t i = 0; i < sub_graph_output_num; i++) { + graph_builder.AddOutput(sub_graph_node_name, i); + } + + // Add Edges + graph_builder.AddDataLink(const_name, 0, add_name_0, 1) + .AddDataLink(mul_name, 0, add_name_1, 1) + .AddDataLink(add_name_1, 0, sub_graph_node_name, kSubgraphLoopVarInputIndex); + + // Add Input-Mapping + std::map input_mapping; + for (size_t i = 0; i < input_num; i++) { + input_mapping[i] = i; + } + graph_builder.SetInputMapping(input_mapping); + + // Add outputMapping + std::map output_mapping; + for (size_t i = 0; i < sub_graph_output_num + kWhileOutputIndex; i++) { + output_mapping[i] = i; + } + graph_builder.SetOutputMapping(output_mapping); + + graphStatus error_code = GRAPH_SUCCESS; + std::string error_msg; + ComputeGraphPtr body_graph = graph_builder.Build(error_code, error_msg); + if (body_graph == nullptr) { + GELOGE(FAILED, "Build body_graph failed: error_code:%u, error_msg:%s.", error_code, error_msg.c_str()); + return nullptr; + } + + NodePtr sub_graph_node = graph_builder.GetNode(sub_graph_node_name); + if (sub_graph_node == nullptr) { + GELOGE(FAILED, "Get 
sub_graph_node failed: name:%s.", sub_graph_node_name.c_str()); + return nullptr; + } + while_info.sub_graph_node = sub_graph_node; + + size_t index = while_info.while_node->GetOpDesc()->GetSubgraphInstanceNames().size(); + while_info.while_node->GetOpDesc()->AddSubgraphName(ATTR_NAME_WHILE_BODY); + while_info.while_node->GetOpDesc()->SetSubgraphInstanceName(index, body_name); + while_info.while_body = body_graph; + return body_graph; +} + +/// +/// @brief Create op_desc for subgraph node +/// @param [in] name +/// @param [in] input_num +/// @param [in] output_num +/// @return OpDescPtr +/// +OpDescPtr ForPass::CreateSubgraphOpDesc(const std::string &name, uint32_t input_num, uint32_t output_num) { + OpDescBuilder op_desc_builder(name, PARTITIONEDCALL); + op_desc_builder.AddDynamicInput("args", input_num).AddDynamicOutput("output", output_num); + + OpDescPtr op_desc = op_desc_builder.Build(); + if (op_desc == nullptr) { + GELOGE(FAILED, "Create op_desc for subgraph node failed, name:%s.", name.c_str()); + return nullptr; + } + + size_t index = op_desc->GetSubgraphInstanceNames().size(); + op_desc->AddSubgraphName("f"); + op_desc->SetSubgraphInstanceName(index, name); + return op_desc; +} + +/// +/// @brief Update InputMapping for for-body-graph +/// @param [in] while_info +/// @return Status +/// +Status ForPass::UpdateForBodyInputMapping(const WhileInfo &while_info) { + ComputeGraphPtr for_body = while_info.for_body; + GE_CHECK_NOTNULL(for_body); + + // index_of_cur_graph_node_input -> index_of_new_graph_node_input + std::map input_mapping; + size_t input_num = while_info.data_inputs.size() - kWhileDataInputIndex + FOR_DATA_INPUT; + for (size_t i = 0; i < input_num; i++) { + if (i == FOR_START_INPUT) { + input_mapping[i] = i; + } else if ((i == FOR_LIMIT_INPUT) || (i == FOR_DELTA_INPUT)) { + continue; + } else { + input_mapping[i] = i - 2; + } + } + for_body->UpdateInputMapping(input_mapping); + for_body->SetParentNode(while_info.sub_graph_node); + 
for_body->SetParentGraph(while_info.while_body); + + return SUCCESS; +} +} // namespace ge diff --git a/src/ge/graph/passes/for_pass.h b/src/ge/graph/passes/for_pass.h new file mode 100644 index 00000000..3611171e --- /dev/null +++ b/src/ge/graph/passes/for_pass.h @@ -0,0 +1,192 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_PASSES_FOR_PASS_H +#define GE_GRAPH_PASSES_FOR_PASS_H + +#include "graph/passes/base_pass.h" + +struct ForInfo { + ForInfo() : for_node(nullptr), start(nullptr), limit(nullptr), delta(nullptr), for_body(nullptr) {} + ge::NodePtr for_node; + ge::OutDataAnchorPtr start; + ge::OutDataAnchorPtr limit; + ge::OutDataAnchorPtr delta; + std::string body_name; + ge::ComputeGraphPtr for_body; + std::vector data_inputs; + std::vector> data_outputs; + std::vector ctrl_inputs; + std::vector ctrl_outputs; +}; + +struct WhileInfo { + WhileInfo() : while_node(nullptr), sub_graph_node(nullptr), i(nullptr), n(nullptr), start(nullptr), delta(nullptr) {} + ge::NodePtr while_node; + ge::NodePtr sub_graph_node; + ge::OutDataAnchorPtr i; + ge::OutDataAnchorPtr n; + ge::OutDataAnchorPtr start; + ge::OutDataAnchorPtr delta; + std::string for_body_name; + ge::ComputeGraphPtr for_body; + ge::ComputeGraphPtr while_cond; + ge::ComputeGraphPtr while_body; + std::vector data_inputs; + std::vector> data_outputs; + std::vector ctrl_inputs; + std::vector ctrl_outputs; +}; + +namespace 
ge { +class ForPass : public BaseNodePass { + public: + Status Run(NodePtr &node) override; + + private: + /// + /// @brief Build for_info + /// @param [in] root_graph + /// @param [in] node + /// @param [out] for_info + /// @return Status + /// + static Status BuildForInfo(const ComputeGraphPtr &root_graph, const NodePtr &node, ForInfo &for_info); + + /// + /// @brief Transfer while_info from for_info + /// @param [in] graph + /// @param [in] for_info + /// @param [out] while_info + /// @return Status + /// + Status TranWhileInfo(const ComputeGraphPtr &graph, const ForInfo &for_info, WhileInfo &while_info); + + /// + /// @brief Build cond_graph for while_node + /// @param [in&out] while_info + /// @return ComputeGraphPtr + /// + static ComputeGraphPtr BuildCondGraph(WhileInfo &while_info); + + /// + /// @brief Build body_graph for while_node + /// @param [in&out] while_info + /// @return ComputeGraphPtr + /// + static ComputeGraphPtr BuildBodyGraph(WhileInfo &while_info); + + /// + /// @brief Update InputMapping for for-body-graph + /// @param [in] while_info + /// @return Status + /// + static Status UpdateForBodyInputMapping(const WhileInfo &while_info); + + /// + /// @brief Find input with index for For node + /// @param [in] node + /// @param [in] index + /// @return OutDataAnchorPtr + /// + static OutDataAnchorPtr FindInputWithIndex(const NodePtr &node, uint32_t index); + + /// + /// @brief Find inputs / outputs for for node + /// @param [in] node + /// @param [out] data_inputs + /// @param [out] data_outputs + /// @param [out] ctrl_inputs + /// @param [out] ctrl_outputs + /// @return Status + /// + static Status FindInputsAndOutputs(const NodePtr &node, std::vector &data_inputs, + std::vector> &data_outputs, + std::vector &ctrl_inputs, + std::vector &ctrl_outputs); + + /// + /// @brief Create const op_desc + /// @param [in] name + /// @param [in] value + /// @return OpDescPtr + /// + static OpDescPtr CreateConstDesc(const std::string &name, int32_t value); + 
+ /// + /// @brief Create loop_count input + /// @param [in] graph + /// @param [in] for_info + /// @return OutDataAnchorPtr + /// + OutDataAnchorPtr CreateLoopCountInput(const ComputeGraphPtr &graph, const ForInfo &for_info); + + /// + /// @brief Create cast op_desc + /// @param [in] name + /// @param [in] src_data_type + /// @param [in] dst_data_type + /// @return OpDescPtr + /// + static OpDescPtr CreateCastDesc(const std::string &name, DataType src, DataType dst); + + /// + /// @brief Create op_desc + /// @param [in] name + /// @param [in] type + /// @param [in] io_equal_flag + /// @return OpDescPtr + /// + static OpDescPtr CreateOpDesc(const std::string &name, const std::string &type, bool io_equal_flag); + + /// + /// @brief Build while-info + /// @param [in] for_info + /// @param [in] i_input + /// @param [in] n_input + /// @param [out] while_info + /// @return void + /// + static void BuildWhileInfo(const ForInfo &for_info, const OutDataAnchorPtr &i_input, const OutDataAnchorPtr &n_input, + WhileInfo &while_info); + + /// + /// @brief Insert while_node + /// @param [in] graph + /// @param [in] name + /// @param [in] while_info + /// @return Status + /// + Status InsertWhileNode(const ComputeGraphPtr &graph, const std::string &name, WhileInfo &while_info); + + /// + /// @brief Build while link-edge + /// @param [in] while_info + /// @return Status + /// + static Status BuildWhileLink(const WhileInfo &while_info); + + /// + /// @brief Create op_desc for subgraph node + /// @param [in] name + /// @param [in] input_num + /// @param [in] output_num + /// @return OpDescPtr + /// + static OpDescPtr CreateSubgraphOpDesc(const std::string &name, uint32_t input_num, uint32_t output_num); +}; +} // namespace ge +#endif // GE_GRAPH_PASSES_FOR_PASS_H diff --git a/src/ge/graph/passes/get_original_format_pass.cc b/src/ge/graph/passes/get_original_format_pass.cc index d065d581..066c46ea 100644 --- a/src/ge/graph/passes/get_original_format_pass.cc +++ 
b/src/ge/graph/passes/get_original_format_pass.cc @@ -19,26 +19,18 @@ #include #include "common/debug/log.h" -#include "framework/common/debug/ge_log.h" #include "common/types.h" #include "common/util.h" +#include "framework/common/debug/ge_log.h" +#include "framework/omg/omg_inner_types.h" #include "graph/utils/attr_utils.h" #include "graph/utils/op_desc_utils.h" -#include "framework/omg/omg_inner_types.h" -using domi::AIPP_DATA_TYPE; -using domi::ATTR_NAME_FORMAT; -using domi::ATTR_NAME_IGNORE_PRED_FORMAT; -using domi::ATTR_NAME_INFERRED_FORMAT; -using domi::BIASADD; -using domi::DATA_TYPE; using domi::DOMI_TENSOR_NCHW; using domi::DOMI_TENSOR_NHWC; using domi::DOMI_TENSOR_RESERVED; using domi::FAILED; using domi::PARAM_INVALID; -using domi::PERMUTE; -using domi::PERMUTE_ATTR_ORDER; using domi::SUCCESS; using domi::GetContext; @@ -70,8 +62,8 @@ Status GetOriginalFormatPass::SetOriginalFormat(const ge::ComputeGraphPtr &graph GE_CHECK_NOTNULL(desc_ptr); auto is_data = (desc_ptr->GetType() == DATA_TYPE || desc_ptr->GetType() == AIPP_DATA_TYPE); if (is_data) { - GELOGI("Data node: %s,format :%d", node_ptr->GetName().c_str(), GetContext().format); - ori_format = static_cast(GetContext().format); + GELOGI("Data node: %s,format :%d", node_ptr->GetName().c_str(), domi::GetContext().format); + ori_format = static_cast(domi::GetContext().format); GE_IF_BOOL_EXEC(!AttrUtils::SetInt(desc_ptr, ATTR_NAME_FORMAT, ori_format), GELOGE(FAILED, "set ATTR_NAME_FORMAT failed"); return FAILED); diff --git a/src/ge/graph/passes/guarantee_const_pass.cc b/src/ge/graph/passes/guarantee_const_pass.cc index 8c34b8f5..f099c01d 100644 --- a/src/ge/graph/passes/guarantee_const_pass.cc +++ b/src/ge/graph/passes/guarantee_const_pass.cc @@ -25,8 +25,6 @@ #include "graph/utils/attr_utils.h" #include "graph/utils/graph_utils.h" -using domi::GUARANTEECONST; - namespace ge { namespace { const uint32_t kGuaranteeConstInputsSize = 1; diff --git a/src/ge/graph/passes/hccl_memcpy_pass.cc 
b/src/ge/graph/passes/hccl_memcpy_pass.cc index 60001e30..ac037d62 100644 --- a/src/ge/graph/passes/hccl_memcpy_pass.cc +++ b/src/ge/graph/passes/hccl_memcpy_pass.cc @@ -25,9 +25,6 @@ #include "framework/common/types.h" #include "graph/utils/graph_utils.h" -using domi::CONSTANTOP; -using domi::DATA; - namespace { const int32_t kAnchorSize = 1; const int kAnchorNum = 0; @@ -56,7 +53,7 @@ Status HcclMemcpyPass::Run(ge::ComputeGraphPtr graph) { NodePtr src_node = src_out_anchor->GetOwnerNode(); std::string src_type = src_node->GetType(); bool check_src_type = (src_type == CONSTANTOP) || (src_type == DATA); - if (check_src_type && node->GetType() == domi::HCOMALLREDUCE) { + if (check_src_type && node->GetType() == HCOMALLREDUCE) { Status ret = ModifyEdgeConnection(graph, src_out_anchor, hccl_in_anchor); if (ret != SUCCESS) { GELOGE(INTERNAL_ERROR, "Failed to modify the connection."); @@ -91,9 +88,9 @@ NodePtr HcclMemcpyPass::CreateMemcpyNode(const ComputeGraphPtr &graph, const Out return nullptr; } - std::string node_name = pre_node->GetName() + "_" + domi::MEMCPYASYNC; + std::string node_name = pre_node->GetName() + "_" + MEMCPYASYNC; node_name = CheckDuplicateName(node_name); - OpDescPtr op_desc = MakeShared(node_name.c_str(), domi::MEMCPYASYNC); + OpDescPtr op_desc = MakeShared(node_name.c_str(), MEMCPYASYNC); if (op_desc == nullptr) { GELOGE(INTERNAL_ERROR, "Create MemcpyAsync op: MakeShared op_desc fail."); return nullptr; @@ -144,8 +141,8 @@ std::string HcclMemcpyPass::CheckDuplicateName(const std::string &node_name) { /// @return bool /// bool HcclMemcpyPass::NeedInsertMemcpyOp(const ge::ConstOpDescPtr &op_desc) const { - return (op_desc->GetType() == domi::HCOMALLGATHER || op_desc->GetType() == domi::HCOMALLREDUCE || - op_desc->GetType() == domi::HCOMREDUCESCATTER); + return (op_desc->GetType() == HCOMALLGATHER || op_desc->GetType() == HCOMALLREDUCE || + op_desc->GetType() == HCOMREDUCESCATTER); } /// diff --git a/src/ge/graph/passes/identity_pass.cc 
b/src/ge/graph/passes/identity_pass.cc index fa6ff4ed..9b15f77a 100644 --- a/src/ge/graph/passes/identity_pass.cc +++ b/src/ge/graph/passes/identity_pass.cc @@ -23,9 +23,6 @@ #include "framework/common/ge_inner_error_codes.h" #include "graph/common/omg_util.h" -using domi::IDENTITY; -using domi::IDENTITYN; - namespace ge { namespace { /// @@ -41,7 +38,7 @@ Status CheckIdentityUsable(const NodePtr &node, bool &usable) { GELOGE(ret, "Failed to get node type from node %s", node->GetName().c_str()); return ret; } - if ((node_type != domi::SWITCH) && (node_type != domi::REFSWITCH)) { + if ((node_type != SWITCH) && (node_type != REFSWITCH)) { GELOGD("skip identity %s connected to switch", node->GetName().c_str()); break; } @@ -57,7 +54,7 @@ Status CheckIdentityUsable(const NodePtr &node, bool &usable) { GELOGE(ret, "Failed to get node type from node %s", node->GetName().c_str()); return ret; } - if ((node_type != domi::MERGE) && (node_type != domi::REFMERGE)) { + if ((node_type != MERGE) && (node_type != REFMERGE)) { GELOGD("skip identity %s connected to merge", node->GetName().c_str()); break; } diff --git a/src/ge/graph/passes/isolated_op_remove_pass.cc b/src/ge/graph/passes/isolated_op_remove_pass.cc index c7e52a64..152104eb 100644 --- a/src/ge/graph/passes/isolated_op_remove_pass.cc +++ b/src/ge/graph/passes/isolated_op_remove_pass.cc @@ -20,9 +20,6 @@ #include "common/types.h" #include "common/util.h" -using domi::SUCCESS; -using domi::TO_BE_OUTPUT; - namespace ge { Status IsolatedOpRemovePass::Run(ge::ComputeGraphPtr graph) { GE_CHECK_NOTNULL(graph); diff --git a/src/ge/graph/passes/iterator_op_pass.cc b/src/ge/graph/passes/iterator_op_pass.cc index d1fe211c..a5dafdca 100644 --- a/src/ge/graph/passes/iterator_op_pass.cc +++ b/src/ge/graph/passes/iterator_op_pass.cc @@ -27,10 +27,8 @@ #include "graph/common/omg_util.h" #include "graph/graph.h" #include "graph/node.h" -#include "graph/utils/graph_utils.h" #include "graph/passes/pass_utils.h" - -using 
domi::MEMCPYASYNC; +#include "graph/utils/graph_utils.h" namespace ge { const char *const kGetNext = "GetNext"; diff --git a/src/ge/graph/passes/link_gen_mask_nodes_pass.cc b/src/ge/graph/passes/link_gen_mask_nodes_pass.cc index 62f8c57a..ff150a54 100644 --- a/src/ge/graph/passes/link_gen_mask_nodes_pass.cc +++ b/src/ge/graph/passes/link_gen_mask_nodes_pass.cc @@ -17,8 +17,6 @@ #include "graph/passes/link_gen_mask_nodes_pass.h" #include -#include -#include #include "common/ge_inner_error_codes.h" #include "framework/common/debug/ge_log.h" @@ -28,10 +26,6 @@ using std::set; using std::vector; -using domi::CONSTANT; -using domi::CONSTANTOP; -using domi::DROPOUTDOMASK; - namespace ge { namespace { const size_t kGenMaskInputIndex = 1; @@ -74,8 +68,8 @@ Status LinkGenMaskNodesPass::Run(ComputeGraphPtr graph) { auto dest_anchor = dest_node->GetInControlAnchor(); GE_CHECK_NOTNULL(dest_anchor); - graphStatus status = src_anchor->LinkTo(dest_anchor); - if (status != GRAPH_SUCCESS) { + graphStatus status_link_to = src_anchor->LinkTo(dest_anchor); + if (status_link_to != GRAPH_SUCCESS) { GELOGE(FAILED, "Link from %s to %s failed.", src_node->GetName().c_str(), dest_node->GetName().c_str()); return FAILED; } diff --git a/src/ge/graph/passes/link_gen_mask_nodes_pass.h b/src/ge/graph/passes/link_gen_mask_nodes_pass.h index decc2d30..f9979ab1 100644 --- a/src/ge/graph/passes/link_gen_mask_nodes_pass.h +++ b/src/ge/graph/passes/link_gen_mask_nodes_pass.h @@ -17,6 +17,10 @@ #ifndef GE_GRAPH_PASSES_LINK_GEN_MASK_NODES_PASS_H_ #define GE_GRAPH_PASSES_LINK_GEN_MASK_NODES_PASS_H_ +#include +#include +#include + #include "graph/graph.h" #include "inc/graph_pass.h" diff --git a/src/ge/graph/passes/merge_pass.cc b/src/ge/graph/passes/merge_pass.cc index 96dbf37f..f4114474 100644 --- a/src/ge/graph/passes/merge_pass.cc +++ b/src/ge/graph/passes/merge_pass.cc @@ -28,8 +28,6 @@ #include "graph/utils/graph_utils.h" #include "graph/passes/pass_utils.h" -using domi::CONSTANT; -using 
domi::MERGE; using domi::PARAM_INVALID; using domi::SUCCESS; @@ -45,7 +43,7 @@ Status MergePass::Run(NodePtr &node) { std::string op_type; GE_CHK_STATUS_RET(GetOriginalType(node, op_type), "get original type failed"); - if (op_type != domi::MERGE) { + if (op_type != MERGE) { return SUCCESS; } diff --git a/src/ge/graph/passes/multi_batch_pass.cc b/src/ge/graph/passes/multi_batch_pass.cc index 428fada5..aac72892 100644 --- a/src/ge/graph/passes/multi_batch_pass.cc +++ b/src/ge/graph/passes/multi_batch_pass.cc @@ -29,21 +29,10 @@ #include "graph/debug/ge_attr_define.h" #include "graph/utils/type_utils.h" -using domi::ATTR_NAME_STREAM_LABEL; - -using domi::NETOUTPUT; -using domi::STREAMACTIVE; -using domi::STREAMMERGE; -using domi::STREAMSWITCHN; -using domi::SWITCHN; - namespace ge { Status MultiBatchPass::Run(ComputeGraphPtr graph) { GELOGD("MultiBatchPass Enter"); - GraphUtils::DumpGEGraph(graph, "BeforeMultiBatchPass"); - GraphUtils::DumpGEGraphToOnnx(*graph, "BeforeMultiBatchPass"); - OutDataAnchorPtr pred_value = nullptr; Status ret = FindPredValue(graph, pred_value); if (ret == NOT_CHANGED) { @@ -75,9 +64,6 @@ Status MultiBatchPass::Run(ComputeGraphPtr graph) { } } - GraphUtils::DumpGEGraph(graph, "AfterMultiBatchPass"); - GraphUtils::DumpGEGraphToOnnx(*graph, "AfterMultiBatchPass"); - GELOGD("MultiBatchPass Leave"); return SUCCESS; } @@ -463,7 +449,7 @@ Status MultiBatchPass::AttachBatchLabel(uint32_t batch_idx) { OpDescPtr op_desc = out_node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); const std::string type = op_desc->GetType(); - if ((type == STREAMMERGE) && (op_desc->HasAttr(ATTR_INSERT_BY_MBATCH))) { + if ((type == MERGE) && (op_desc->HasAttr(ATTR_INSERT_BY_MBATCH))) { continue; } if (type == NETOUTPUT) { diff --git a/src/ge/graph/passes/multi_batch_pass.h b/src/ge/graph/passes/multi_batch_pass.h index fd4e6b57..6e3f5e46 100644 --- a/src/ge/graph/passes/multi_batch_pass.h +++ b/src/ge/graph/passes/multi_batch_pass.h @@ -47,4 +47,4 @@ class MultiBatchPass : 
public GraphPass { std::vector> batch_head_nodes_; }; } // namespace ge -#endif // GE_GRAPH_PASSES_MULTI_BATCH_PASS_H_ \ No newline at end of file +#endif // GE_GRAPH_PASSES_MULTI_BATCH_PASS_H_ diff --git a/src/ge/graph/passes/net_output_pass.cc b/src/ge/graph/passes/net_output_pass.cc index 31b7fb4e..4eed597b 100644 --- a/src/ge/graph/passes/net_output_pass.cc +++ b/src/ge/graph/passes/net_output_pass.cc @@ -30,17 +30,9 @@ #include "graph/utils/type_utils.h" #include "graph/debug/ge_attr_define.h" -using domi::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE; -using domi::ATTR_NAME_NET_OUTPUT_DATATYPE; -using domi::ATTR_NAME_NET_OUTPUT_FORMAT; -using domi::ATTR_NAME_TRUE_BRANCH_STREAM; -using domi::NETOUTPUT; -using domi::NODE_NAME_NET_OUTPUT; -using domi::RETVAL_ATTR_NAME_INDEX; - namespace ge { Status NetOutputPass::GetRetvalOutputInfo(const ge::NodePtr &node, - std::map> &retval_node_index_map) { + std::map &retval_node_index_map) { GE_CHECK_NOTNULL(node); GE_CHECK_NOTNULL(node->GetOpDesc()); int64_t output_index = 0; @@ -52,13 +44,16 @@ Status NetOutputPass::GetRetvalOutputInfo(const ge::NodePtr &node, GELOGE(PARAM_INVALID, "Retval has duplicate index."); return PARAM_INVALID; } + int parent_node_index = -1; + (void)AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_node_index); InDataAnchorPtr in_data_anchor = node->GetInDataAnchor(0); GE_CHECK_NOTNULL(in_data_anchor); GE_CHECK_NOTNULL(in_data_anchor->GetPeerOutAnchor()); int32_t src_node_index = in_data_anchor->GetPeerOutAnchor()->GetIdx(); NodePtr src_node_ptr = in_data_anchor->GetPeerOutAnchor()->GetOwnerNode(); - retval_node_index_map[output_index] = std::make_pair(src_node_ptr, src_node_index); + retval_node_index_map[output_index] = {src_node_ptr, src_node_index, parent_node_index}; // if user targets include retval node,delete it from set and insert its input node instead + // better to GetInNodes here auto iter = targets_.find(node); if (iter != targets_.end()) { targets_.erase(iter); @@ -69,9 
+64,8 @@ Status NetOutputPass::GetRetvalOutputInfo(const ge::NodePtr &node, return SUCCESS; } -Status NetOutputPass::GetOutputNode(const ge::ComputeGraphPtr &graph, - std::vector> &output_nodes_info) { - std::map> retval_node_index_map; +Status NetOutputPass::GetOutputNode(const ge::ComputeGraphPtr &graph, std::vector &output_nodes_info) { + std::map retval_node_index_map; for (NodePtr &node : graph->GetDirectNode()) { Status ret = SUCCESS; if ((node->GetOpDesc() != nullptr) && (node->GetOpDesc()->HasAttr(RETVAL_ATTR_NAME_INDEX))) { @@ -85,21 +79,21 @@ Status NetOutputPass::GetOutputNode(const ge::ComputeGraphPtr &graph, } } GELOGI("Get retval node size:%zu.", retval_node_index_map.size()); - std::vector> out_nodes_tmp; + std::vector out_nodes_tmp; /// The Netoutput output is determined by Retval, and the input order /// of Netoutput is sorted according to the index value of Retval. - for (auto it = retval_node_index_map.begin(); it != retval_node_index_map.end(); ++it) { - out_nodes_tmp.push_back(it->second); + for (auto &it : retval_node_index_map) { + out_nodes_tmp.push_back(it.second); } // when user set targets, mean that no output result for (auto &ele : graph->GetGraphOutNodesInfo()) { auto iter = targets_.find(ele.first); if (iter != targets_.end()) { - GELOGI("user set out node [%s] is found in user def targets, out node is prio!", (ele.first)->GetName().c_str()); + GELOGI("user set out node [%s] is found in user def targets, out node is prio!", ele.first->GetName().c_str()); targets_.erase(iter); } - output_nodes_info.push_back(ele); + output_nodes_info.push_back({ele.first, ele.second, -1}); } GELOGI("Output node set by user or leaf node, size:%zu.", output_nodes_info.size()); for (auto &ele : out_nodes_tmp) { @@ -115,10 +109,9 @@ Status NetOutputPass::GetOutputNode(const ge::ComputeGraphPtr &graph, return SUCCESS; } -Status NetOutputPass::CheckOutputNodeInfo(const ComputeGraphPtr &graph, - const std::vector> &outputs) { +Status 
NetOutputPass::CheckOutputNodeInfo(const ComputeGraphPtr &graph, const std::vector &outputs) { for (auto &item : outputs) { - NodePtr node = item.first; + NodePtr node = item.output_node; if (node == nullptr) { GELOGE(PARAM_INVALID, "Node in outputs is null."); return PARAM_INVALID; @@ -129,7 +122,7 @@ Status NetOutputPass::CheckOutputNodeInfo(const ComputeGraphPtr &graph, } GE_CHECK_NOTNULL(node->GetOpDesc()); int32_t out_size = node->GetOpDesc()->GetOutputsSize(); - int32_t index = item.second; + int32_t index = item.node_output_index; if (index < 0 || index >= out_size) { GELOGE(PARAM_INVALID, "User declared out node (%s) output index:%d must be smaller " @@ -152,8 +145,6 @@ void NetOutputPass::AddInOutForNetOutputOp(const ge::ComputeGraphPtr &graph, con } ge::GeTensorDesc out_desc = src_node->GetOpDesc()->GetOutputDesc(src_index); GE_IF_BOOL_EXEC(net_output_desc->AddInputDesc(out_desc) != SUCCESS, GELOGW("add input desc failed"); return ); - TensorUtils::SetOutputTensor(out_desc, true); - GE_IF_BOOL_EXEC(net_output_desc->AddOutputDesc(out_desc) != SUCCESS, GELOGW("add output desc failed"); return ); } Status NetOutputPass::RemoveUnusedNode(const ge::ComputeGraphPtr &graph) { @@ -211,11 +202,6 @@ Status NetOutputPass::UpdateNetOutputDesc(const ge::NodePtr &net_output) { GELOGE(INTERNAL_ERROR, "Update input desc failed, index:%u.", index); return INTERNAL_ERROR; } - TensorUtils::SetOutputTensor(output_in_desc, true); - if (net_output_desc->UpdateOutputDesc(index, output_in_desc) != GRAPH_SUCCESS) { - GELOGE(INTERNAL_ERROR, "Update output desc failed, index:%u.", index); - return INTERNAL_ERROR; - } GELOGD("Update desc, format:%s, data type:%s, index:%u.", TypeUtils::FormatToSerialString(output_in_desc.GetFormat()).c_str(), TypeUtils::DataTypeToSerialString(output_in_desc.GetDataType()).c_str(), index); @@ -260,20 +246,33 @@ void NetOutputPass::SaveAndRemoveTargets(const ge::ComputeGraphPtr &graph) { } Status NetOutputPass::AddEdgesForNetOutput(const 
ge::ComputeGraphPtr &graph, const ge::NodePtr &net_out_node, - const std::vector> &output_nodes_info) { + const std::vector &output_nodes_info) { int32_t net_input_index = 0; for (auto &item : output_nodes_info) { - NodePtr src_node = item.first; + NodePtr src_node = item.output_node; GE_CHECK_NOTNULL(src_node); - graphStatus status = - GraphUtils::AddEdge(src_node->GetOutDataAnchor(item.second), net_out_node->GetInDataAnchor(net_input_index)); + graphStatus status = GraphUtils::AddEdge(src_node->GetOutDataAnchor(item.node_output_index), + net_out_node->GetInDataAnchor(net_input_index)); if (status != GRAPH_SUCCESS) { GELOGE(INTERNAL_ERROR, "AddEdge failed, src name:%s, src index:%d, dst index:%d.", src_node->GetName().c_str(), - item.second, net_input_index); + item.node_output_index, net_input_index); return INTERNAL_ERROR; } - GELOGD("AddEdge to output node, src name:%s, src index:%d, dst index:%d.", src_node->GetName().c_str(), item.second, - net_input_index); + GELOGD("AddEdge to output node, src name:%s, src index:%d, dst index:%d.", src_node->GetName().c_str(), + item.node_output_index, net_input_index); + if (item.parent_node_index >= 0) { + GELOGI("Add parent node index %d for the netoutput input %d on graph %s", item.parent_node_index, net_input_index, + graph->GetName().c_str()); + auto input_desc = net_out_node->GetOpDesc()->MutableInputDesc(net_input_index); + if (input_desc == nullptr) { + GELOGE(INTERNAL_ERROR, "Can not find intput tensor desc from NetOutput, index %d", net_input_index); + return INTERNAL_ERROR; + } + if (!AttrUtils::SetInt(input_desc, ATTR_NAME_PARENT_NODE_INDEX, item.parent_node_index)) { + GELOGE(INTERNAL_ERROR, "Failed to add parent index to NetOutput, index %d", net_input_index); + return INTERNAL_ERROR; + } + } net_input_index++; } if (RemoveUnusedNode(graph) != SUCCESS) { @@ -438,7 +437,7 @@ Status NetOutputPass::Run(ge::ComputeGraphPtr graph) { GELOGI("NetOutputPass Run."); NodePtr output_node = 
graph->FindNode(NODE_NAME_NET_OUTPUT); OpDescPtr net_output_desc = nullptr; - std::vector> output_nodes_info; + std::vector output_nodes_info; // save user targets node SaveAndRemoveTargets(graph); @@ -486,11 +485,11 @@ Status NetOutputPass::Run(ge::ComputeGraphPtr graph) { } std::vector is_input_const; for (auto iter = output_nodes_info.begin(); iter != output_nodes_info.end();) { - ge::NodePtr src_node = (*iter).first; + ge::NodePtr src_node = iter->output_node; if (src_node == nullptr) { continue; } - int32_t src_index = (*iter).second; + int32_t src_index = iter->node_output_index; // if src_node is in targets_, no need to Add in and out for netoutput auto it = targets_.find(src_node); if (it != targets_.end()) { diff --git a/src/ge/graph/passes/net_output_pass.h b/src/ge/graph/passes/net_output_pass.h index 6c86d8ef..6a022d79 100644 --- a/src/ge/graph/passes/net_output_pass.h +++ b/src/ge/graph/passes/net_output_pass.h @@ -26,6 +26,12 @@ #include "inc/graph_pass.h" namespace ge { +struct RetvalInfo { + NodePtr output_node; + int32_t node_output_index; + int parent_node_index; +}; + class NetOutputPass : public GraphPass { public: /// @@ -47,8 +53,7 @@ class NetOutputPass : public GraphPass { /// @return OTHERS: Execution failed /// @author /// - Status GetRetvalOutputInfo(const ge::NodePtr &node, - std::map> &retval_node_index_map); + Status GetRetvalOutputInfo(const ge::NodePtr &node, std::map &retval_node_index_map); /// /// Get the output node of the graph @@ -58,8 +63,7 @@ class NetOutputPass : public GraphPass { /// @return OTHERS: Execution failed /// @author /// - Status GetOutputNode(const ge::ComputeGraphPtr &graph, - std::vector> &output_nodes_info); + Status GetOutputNode(const ge::ComputeGraphPtr &graph, std::vector &output_nodes_info); /// /// Check if the network output node is legal @@ -69,7 +73,7 @@ class NetOutputPass : public GraphPass { /// @return OTHERS: Execution failed /// @author /// - Status CheckOutputNodeInfo(const ComputeGraphPtr 
&graph, const std::vector> &outputs); + Status CheckOutputNodeInfo(const ComputeGraphPtr &graph, const std::vector &outputs); /// /// Set input and output for the NetOutput node @@ -129,7 +133,7 @@ class NetOutputPass : public GraphPass { /// @author /// Status AddEdgesForNetOutput(const ge::ComputeGraphPtr &graph, const ge::NodePtr &net_out_node, - const std::vector> &output_nodes_info); + const std::vector &output_nodes_info); /// /// Add ctrl edges for leaf node /// @param [in] graph: Input ComputeGraph diff --git a/src/ge/graph/passes/next_iteration_pass.cc b/src/ge/graph/passes/next_iteration_pass.cc index fdea1f8a..030ff6ac 100644 --- a/src/ge/graph/passes/next_iteration_pass.cc +++ b/src/ge/graph/passes/next_iteration_pass.cc @@ -30,15 +30,6 @@ #include "graph/common/omg_util.h" #include "graph/debug/ge_attr_define.h" -using domi::ENTER; -using domi::LOOPCOND; -using domi::MERGE; -using domi::NEXTITERATION; -using domi::REFENTER; -using domi::REFMERGE; -using domi::STREAMACTIVE; -using domi::SWITCH; - namespace ge { Status NextIterationPass::Run(ComputeGraphPtr graph) { GELOGD("NextIterationPass Enter"); diff --git a/src/ge/graph/passes/no_use_reshape_remove_pass.cc b/src/ge/graph/passes/no_use_reshape_remove_pass.cc index c0f46e2a..1e78cc40 100644 --- a/src/ge/graph/passes/no_use_reshape_remove_pass.cc +++ b/src/ge/graph/passes/no_use_reshape_remove_pass.cc @@ -19,14 +19,14 @@ #include #include +#include "common/op/ge_op_utils.h" #include "external/graph/types.h" #include "framework/common/debug/ge_log.h" -#include "common/op/ge_op_utils.h" #include "framework/common/ge_inner_error_codes.h" +#include "graph/passes/pass_utils.h" #include "graph/utils/graph_utils.h" #include "graph/utils/op_desc_utils.h" #include "graph/utils/tensor_utils.h" -#include "graph/passes/pass_utils.h" namespace ge { namespace { @@ -40,7 +40,7 @@ Status NoUseReshapeRemovePass::Run(ge::NodePtr &node) { GELOGE(PARAM_INVALID, "NoUseReshapeRemovePass enter. 
OpDesc is null."); return PARAM_INVALID; } - if (op_desc_ptr->GetType() != domi::RESHAPE) { + if (op_desc_ptr->GetType() != RESHAPE) { return SUCCESS; } GELOGI("NoUseReshapeRemovePass enter."); @@ -60,6 +60,11 @@ Status NoUseReshapeRemovePass::Run(ge::NodePtr &node) { std::vector input_4dims = input_desc->GetShape().GetDims(); std::vector output_4dims = output_desc->GetShape().GetDims(); + if (input_desc->GetShape().IsUnknownShape() || output_desc->GetShape().IsUnknownShape()) { + GELOGI("Current Reshape %s is unkown shape which should be kept.", op_desc_ptr->GetName().c_str()); + return SUCCESS; + } + if (input_4dims.size() != output_4dims.size()) { GELOGI("Input and output dim size is not equal.Keep this reshape op."); return SUCCESS; diff --git a/src/ge/graph/passes/parallel_concat_start_op_pass.cc b/src/ge/graph/passes/parallel_concat_start_op_pass.cc new file mode 100644 index 00000000..0ac26b91 --- /dev/null +++ b/src/ge/graph/passes/parallel_concat_start_op_pass.cc @@ -0,0 +1,76 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "graph/passes/parallel_concat_start_op_pass.h" +#include +#include "common/ge/ge_util.h" +#include "framework/common/debug/ge_log.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/node.h" +#include "graph/utils/attr_utils.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/node_utils.h" + +namespace ge { +namespace { +const size_t kParallelConcatStartOutputSize = 1; +const uint32_t kParallelConcatStartOutputDataIndex = 0; +const char *const kAttrDtype = "dtype"; +const char *const kAttrShape = "shape"; +} // namespace +Status ParallelConcatStartOpPass::Run(NodePtr &node) { + GE_CHECK_NOTNULL(node); + if (node->GetType() != PARALLELCONCATSTART) { + return SUCCESS; + } + + OpDescPtr node_op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(node_op_desc); + string node_name = node->GetName(); + GELOGI("Start to replace operator _ParallelConcatStart with Constant, node name: %s.", node_name.c_str()); + + if (node_op_desc->GetOutputsSize() != kParallelConcatStartOutputSize) { + GELOGE(PARAM_INVALID, "Node[%s] output size is unexpected, the value is %zu.", node_name.c_str(), + node_op_desc->GetOutputsSize()); + return PARAM_INVALID; + } + auto output_tensor_desc = node_op_desc->GetOutputDesc(kParallelConcatStartOutputDataIndex); + GeTensorPtr output_ptr = MakeShared(output_tensor_desc); + if (output_ptr == nullptr) { + GELOGE(MEMALLOC_FAILED, "Malloc GeTensor failed, node name %s.", node_name.c_str()); + return FAILED; + } + + ge::DataType attr_dtype; + if (!ge::AttrUtils::GetDataType(node_op_desc, kAttrDtype, attr_dtype)) { + GELOGE(PARAM_INVALID, "Node:%s failed to get attribute dtype.", node_name.c_str()); + return PARAM_INVALID; + } + output_ptr->MutableTensorDesc().SetDataType(attr_dtype); + + vector attr_shape_list; + if (!ge::AttrUtils::GetListInt(node_op_desc, kAttrShape, attr_shape_list)) { + GELOGE(PARAM_INVALID, "Node:%s failed to get attribute shape.", node_name.c_str()); + return PARAM_INVALID; + } + 
output_ptr->MutableTensorDesc().SetShape(GeShape(attr_shape_list)); + + vector outputs; + outputs.emplace_back(output_ptr); + + return Folding(node, outputs); +} +} // namespace ge diff --git a/src/ge/graph/load/new_model_manager/model_output.h b/src/ge/graph/passes/parallel_concat_start_op_pass.h similarity index 57% rename from src/ge/graph/load/new_model_manager/model_output.h rename to src/ge/graph/passes/parallel_concat_start_op_pass.h index 1b05bdd6..0f6e754a 100644 --- a/src/ge/graph/load/new_model_manager/model_output.h +++ b/src/ge/graph/passes/parallel_concat_start_op_pass.h @@ -14,22 +14,14 @@ * limitations under the License. */ -#ifndef GE_GRAPH_LOAD_NEW_MODEL_MANAGER_MODEL_OUTPUT_H_ -#define GE_GRAPH_LOAD_NEW_MODEL_MANAGER_MODEL_OUTPUT_H_ - -#include "common/ge_inner_error_codes.h" -#include "common/types.h" -#include "common/ge_types.h" -#include "graph/op_desc.h" +#ifndef GE_GRAPH_PASSES_PARALLEL_CONCAT_START_OP_PASS_H_ +#define GE_GRAPH_PASSES_PARALLEL_CONCAT_START_OP_PASS_H_ +#include "graph/passes/folding_pass.h" namespace ge { -class DavinciModel; - -class ModelOutput { +class ParallelConcatStartOpPass : public FoldingPass { public: - static Status CopyResult(DavinciModel *model, OpDescPtr op_desc, OutputData &rslt, uint32_t &data_index, - bool support_mem_share); + Status Run(NodePtr &node) override; }; } // namespace ge - -#endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_MODEL_OUTPUT_H_ +#endif // GE_GRAPH_PASSES_PARALLEL_CONCAT_START_OP_PASS_H_ diff --git a/src/ge/graph/passes/pass_manager.cc b/src/ge/graph/passes/pass_manager.cc index d690e9c1..f62ea160 100644 --- a/src/ge/graph/passes/pass_manager.cc +++ b/src/ge/graph/passes/pass_manager.cc @@ -21,8 +21,6 @@ #include "graph/utils/node_utils.h" #include "omg/omg_inner_types.h" -using domi::SUCCESS; - namespace ge { const vector &PassManager::GraphPasses() const { return graph_passes_; } diff --git a/src/ge/graph/passes/pass_utils.cc b/src/ge/graph/passes/pass_utils.cc index 80b85774..9b3f6b5f 
100644 --- a/src/ge/graph/passes/pass_utils.cc +++ b/src/ge/graph/passes/pass_utils.cc @@ -27,8 +27,8 @@ #include "common/ge/ge_util.h" #include "common/op/ge_op_utils.h" #include "common/types.h" -#include "framework/common/op/attr_define.h" #include "graph/common/omg_util.h" +#include "graph/debug/ge_attr_define.h" #include "graph/ge_tensor.h" #include "graph/manager/graph_var_manager.h" #include "graph/utils/graph_utils.h" @@ -111,7 +111,7 @@ bool PassUtils::IsConstant(const ConstNodePtr &node) { } auto src_node_type = node->GetType(); - bool is_constant = (src_node_type == domi::CONSTANT) || (src_node_type == domi::CONSTANTOP); + bool is_constant = (src_node_type == CONSTANT) || (src_node_type == CONSTANTOP); return is_constant; } @@ -203,7 +203,7 @@ Status PassUtils::RemoveBranch(const NodePtr &node, std::vector &delete auto dst_node = dst_in_anchor->GetOwnerNode(); std::string node_type; GE_CHK_STATUS_RET(GetOriginalType(dst_node, node_type), "get original type failed"); - if (node_type == domi::NETOUTPUT) { + if (node_type == NETOUTPUT) { if (dst_in_anchor->IsTypeOf()) { GELOGE(INTERNAL_ERROR, "[%s] Inactive branch connected to " @@ -215,7 +215,7 @@ Status PassUtils::RemoveBranch(const NodePtr &node, std::vector &delete GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(src_out_anchor, dst_in_anchor), "remove edge failed"); end_nodes.push_back(dst_node); } - } else if (node_type == domi::MERGE) { + } else if (node_type == MERGE) { /// Unlink connection between the inactive branch and Merge/NetOutput. 
/// The removal of inactive nodes will be handled in PrunePass GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(src_out_anchor, dst_in_anchor), "remove edge failed"); @@ -254,7 +254,7 @@ bool PassUtils::IsNeedTrainIteFlowCtrl(const ComputeGraphPtr &compute_graph) { if (compute_graph == nullptr) { return false; } - if (!ge::VarManager::Instance(compute_graph->GetSessionID())->IsVarExist(domi::NODE_NAME_FLOWCTRL_LOOP_PER_ITER)) { + if (!ge::VarManager::Instance(compute_graph->GetSessionID())->IsVarExist(NODE_NAME_FLOWCTRL_LOOP_PER_ITER)) { return false; } return compute_graph->GetNeedIteration(); @@ -319,7 +319,7 @@ Status PassUtils::RemoveInactiveBranchToMerge(const OutDataAnchorPtr &inactive_o if (dst_node != nullptr) { std::string dst_node_type; GE_CHK_STATUS_RET(GetOriginalType(dst_node, dst_node_type), "get original type failed"); - if (dst_node_type == domi::MERGE) { + if (dst_node_type == MERGE) { GELOGD("[%s] Switch connected directly to Merge", inactive_output_anchor->GetOwnerNode()->GetName().c_str()); GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(inactive_output_anchor, dst_anchor), "remove edge failed"); continue; diff --git a/src/ge/graph/passes/pass_utils.h b/src/ge/graph/passes/pass_utils.h index a8b1cfe3..b889a056 100644 --- a/src/ge/graph/passes/pass_utils.h +++ b/src/ge/graph/passes/pass_utils.h @@ -26,6 +26,7 @@ namespace ge { class PassUtils { public: PassUtils() = delete; + ~PassUtils() = delete; static NodePtr GetInDataNode(const ConstNodePtr &node, int index); diff --git a/src/ge/graph/passes/permute_pass.cc b/src/ge/graph/passes/permute_pass.cc index 0847453f..1b04b3fa 100644 --- a/src/ge/graph/passes/permute_pass.cc +++ b/src/ge/graph/passes/permute_pass.cc @@ -25,18 +25,10 @@ #include "inc/kernel_factory.h" #include "framework/omg/omg_inner_types.h" -using domi::ATTR_NAME_FORMAT; -using domi::ATTR_NAME_PRED_PERMUTE_DELETED; -using domi::CONVOLUTION; -using domi::DECONVOLUTION; -using domi::DEPCONVOLUTION; using domi::DOMI_TENSOR_ND; using 
domi::DOMI_TENSOR_NHWC; +using domi::FMK_TYPE_T; using domi::GetContext; -using domi::INTERNAL_ERROR; -using domi::PAD; -using domi::PERMUTE; -using domi::PERMUTE_ATTR_ORDER; using domi::SUCCESS; namespace ge { @@ -44,7 +36,7 @@ Status PermutePass::Run(ComputeGraphPtr graph) { GE_TIMESTAMP_START(PermutePass); GE_CHECK_NOTNULL(graph); std::vector isolate_nodes; - for (NodePtr &node : graph->GetAllNodes()) { + for (NodePtr &node : graph->GetDirectNode()) { OpDescPtr op_desc_ptr = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc_ptr); GE_IF_BOOL_EXEC( @@ -55,7 +47,7 @@ Status PermutePass::Run(ComputeGraphPtr graph) { GetContext().format != DOMI_TENSOR_ND, // Get input origin foramt for (NodePtr &n - : graph->GetAllNodes()) { + : graph->GetDirectNode()) { GE_IF_BOOL_EXEC( n->GetOpDesc()->GetType() == PERMUTE, std::queue q_node; q_node.push(n); bool jump_out = false; while (!q_node.empty()) { diff --git a/src/ge/graph/passes/placeholder_with_default_pass.cc b/src/ge/graph/passes/placeholder_with_default_pass.cc index cf1f84a6..7a72fc36 100644 --- a/src/ge/graph/passes/placeholder_with_default_pass.cc +++ b/src/ge/graph/passes/placeholder_with_default_pass.cc @@ -20,8 +20,6 @@ #include "framework/common/ge_inner_error_codes.h" #include "graph/common/omg_util.h" -using domi::PLACEHOLDERWITHDEFAULT; - namespace ge { Status PlaceholderWithDefaultPass::Run(NodePtr &node) { GE_CHECK_NOTNULL(node); diff --git a/src/ge/graph/passes/prevent_gradient_pass.cc b/src/ge/graph/passes/prevent_gradient_pass.cc index 049fece8..87c1b3a1 100644 --- a/src/ge/graph/passes/prevent_gradient_pass.cc +++ b/src/ge/graph/passes/prevent_gradient_pass.cc @@ -21,8 +21,6 @@ #include "framework/common/ge_inner_error_codes.h" #include "graph/common/omg_util.h" -using domi::PREVENTGRADIENT; - namespace ge { Status PreventGradientPass::Run(NodePtr &node) { GE_CHECK_NOTNULL(node); diff --git a/src/ge/graph/passes/print_op_pass.h b/src/ge/graph/passes/print_op_pass.h index cf8db6c5..64bf6573 100644 --- 
a/src/ge/graph/passes/print_op_pass.h +++ b/src/ge/graph/passes/print_op_pass.h @@ -18,8 +18,8 @@ #define GE_GRAPH_PASSES_PRINT_OP_PASS_H_ #include "framework/common/debug/ge_log.h" -#include "framework/common/op/attr_define.h" #include "framework/common/types.h" +#include "graph/debug/ge_attr_define.h" #include "graph/common/omg_util.h" #include "graph/graph.h" #include "graph/passes/base_pass.h" diff --git a/src/ge/graph/passes/prune_pass.cc b/src/ge/graph/passes/prune_pass.cc index 8122e6e2..f7d09740 100644 --- a/src/ge/graph/passes/prune_pass.cc +++ b/src/ge/graph/passes/prune_pass.cc @@ -24,10 +24,6 @@ #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" -using domi::AIPPDATA; -using domi::DATA; -using domi::NETOUTPUT; - namespace ge { Status PrunePass::Run(ge::ComputeGraphPtr graph) { GELOGD("PrunePass Start"); diff --git a/src/ge/graph/passes/replace_transshape_pass.cc b/src/ge/graph/passes/replace_transshape_pass.cc new file mode 100644 index 00000000..28a8244d --- /dev/null +++ b/src/ge/graph/passes/replace_transshape_pass.cc @@ -0,0 +1,140 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "graph/passes/replace_transshape_pass.h" + +#include + +#include "common/ge/ge_util.h" +#include "common/ge_inner_error_codes.h" +#include "framework/common/debug/ge_log.h" +#include "graph/common/omg_util.h" +#include "graph/utils/graph_utils.h" + +namespace ge { +Status ReplaceTransShapePass::Run(ge::ComputeGraphPtr graph) { + GE_CHECK_NOTNULL(graph); + for (auto &node : graph->GetDirectNode()) { + if (node->GetType() == TRANSSHAPE) { + auto ret = ReplaceTransShapeNode(graph, node); + if (ret != SUCCESS) { + GELOGE(FAILED, "Trans shape node %s failed", node->GetName().c_str()); + return FAILED; + } + } + } + return SUCCESS; +} + +Status ReplaceTransShapePass::ReplaceTransShapeNode(ComputeGraphPtr &graph, NodePtr &trans_shape_node) { + std::string op_type; + auto ret = GetOriginalType(trans_shape_node, op_type); + if (ret != SUCCESS) { + GELOGE(FAILED, "Get node %s original type failede", trans_shape_node->GetName().c_str()); + return FAILED; + } + auto src_op_desc = trans_shape_node->GetOpDesc(); + GE_CHECK_NOTNULL(src_op_desc); + + std::string node_name = trans_shape_node->GetName() + "ToMemcpy"; + auto dst_op_desc = MakeShared(node_name, MEMCPYASYNC); + if (dst_op_desc == nullptr) { + GELOGE(FAILED, "Make node %s opdesc failed", node_name.c_str()); + return FAILED; + } + GELOGI("Create memcpy Op, name=%s.", node_name.c_str()); + for (InDataAnchorPtr &in_anchor : trans_shape_node->GetAllInDataAnchors()) { + auto ret = dst_op_desc->AddInputDesc(src_op_desc->GetInputDesc(in_anchor->GetIdx())); + if (ret != GRAPH_SUCCESS) { + GELOGE(FAILED, "Add input desc failed"); + return FAILED; + } + } + for (OutDataAnchorPtr &out_anchor : trans_shape_node->GetAllOutDataAnchors()) { + auto ret = dst_op_desc->AddOutputDesc(src_op_desc->GetOutputDesc(out_anchor->GetIdx())); + if (ret != GRAPH_SUCCESS) { + GELOGE(FAILED, "Add output desc failed"); + return FAILED; + } + } + NodePtr memcpy_node = graph->AddNode(dst_op_desc); + GE_CHECK_NOTNULL(memcpy_node); + + 
for (InDataAnchorPtr &in_data_anchor : trans_shape_node->GetAllInDataAnchors()) { + OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); + GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); + + GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_out_anchor, in_data_anchor), "Remove Memcpy data input fail."); + GE_CHK_STATUS(GraphUtils::AddEdge(peer_out_anchor, memcpy_node->GetInDataAnchor(in_data_anchor->GetIdx())), + "Memcpy node add edge fail."); + } + + for (OutDataAnchorPtr &out_data_anchor : trans_shape_node->GetAllOutDataAnchors()) { + for (InDataAnchorPtr &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) { + GE_CHK_STATUS(GraphUtils::RemoveEdge(out_data_anchor, peer_in_anchor), "Remove Memcpy data output fail."); + GE_CHK_STATUS(GraphUtils::AddEdge(memcpy_node->GetOutDataAnchor(out_data_anchor->GetIdx()), peer_in_anchor), + "Memcpy node add edge fail."); + } + } + ReplaceControlEdges(trans_shape_node, memcpy_node); + return SUCCESS; +} + +void ReplaceTransShapePass::CopyControlEdges(NodePtr &old_node, NodePtr &new_node, bool input_check_flag) { + GE_CHECK_NOTNULL_JUST_RETURN(old_node); + GE_CHECK_NOTNULL_JUST_RETURN(new_node); + GE_IF_BOOL_EXEC(old_node == new_node, return ); + for (NodePtr &node : old_node->GetInControlNodes()) { + auto out_control_anchor = node->GetOutControlAnchor(); + GE_IF_BOOL_EXEC(!out_control_anchor->IsLinkedWith(new_node->GetInControlAnchor()), { + GE_CHK_STATUS(GraphUtils::AddEdge(out_control_anchor, new_node->GetInControlAnchor()), "Add in ctl edge fail."); + }); + } + + for (NodePtr &node : old_node->GetOutControlNodes()) { + GE_IF_BOOL_EXEC(!new_node->GetOutControlAnchor()->IsLinkedWith(node->GetInControlAnchor()), { + GE_CHK_STATUS(GraphUtils::AddEdge(new_node->GetOutControlAnchor(), node->GetInControlAnchor()), + "Add out ctl edge fail."); + }); + } +} + +void ReplaceTransShapePass::RemoveControlEdges(NodePtr &node) { + GE_CHECK_NOTNULL_JUST_RETURN(node); + for (NodePtr &in_node : 
node->GetInControlNodes()) { + GE_CHK_STATUS(GraphUtils::RemoveEdge(in_node->GetOutControlAnchor(), node->GetInControlAnchor()), + "Remove in ctl edge fail."); + } + + for (auto &out_data_anchor : node->GetAllOutDataAnchors()) { + for (auto &in_ctrl_anchor : out_data_anchor->GetPeerInControlAnchors()) { + GE_CHK_STATUS(GraphUtils::RemoveEdge(out_data_anchor, in_ctrl_anchor), "Remove in ctl edge fail."); + } + } + + auto out_control_anchor = node->GetOutControlAnchor(); + GE_CHECK_NOTNULL_JUST_RETURN(out_control_anchor); + for (auto &peer_anchor : out_control_anchor->GetPeerAnchors()) { + GE_CHK_STATUS(GraphUtils::RemoveEdge(out_control_anchor, peer_anchor), "Remove out ctl edge fail."); + } +} + +void ReplaceTransShapePass::ReplaceControlEdges(NodePtr &old_node, NodePtr &new_node) { + GE_IF_BOOL_EXEC(old_node == new_node, return ); + CopyControlEdges(old_node, new_node); + RemoveControlEdges(old_node); +} +} // namespace ge diff --git a/src/ge/graph/passes/end_graph_pass.h b/src/ge/graph/passes/replace_transshape_pass.h similarity index 59% rename from src/ge/graph/passes/end_graph_pass.h rename to src/ge/graph/passes/replace_transshape_pass.h index c6ff422a..6673b11d 100644 --- a/src/ge/graph/passes/end_graph_pass.h +++ b/src/ge/graph/passes/replace_transshape_pass.h @@ -14,24 +14,21 @@ * limitations under the License. 
*/ -#ifndef GE_GRAPH_PASSES_END_GRAPH_PASS_H_ -#define GE_GRAPH_PASSES_END_GRAPH_PASS_H_ +#ifndef GE_GRAPH_PASSES_REPLACE_TRANS_SHAPE_PASS_H_ +#define GE_GRAPH_PASSES_REPLACE_TRANS_SHAPE_PASS_H_ -#include "graph/types.h" #include "inc/graph_pass.h" namespace ge { -class EndGraphPass : public GraphPass { +class ReplaceTransShapePass : public GraphPass { public: - /// - /// Entry of the NetOutputPass optimizer - /// @param [in] graph: Input ComputeGraph - /// @return SUCCESS: Execution succeed - /// @return OTHERS: Execution failed - /// @author - /// Status Run(ge::ComputeGraphPtr graph) override; + + private: + Status ReplaceTransShapeNode(ComputeGraphPtr &graph, NodePtr &trans_shape_node); + void CopyControlEdges(NodePtr &old_node, NodePtr &new_node, bool input_check_flag = false); + void RemoveControlEdges(NodePtr &node); + void ReplaceControlEdges(NodePtr &old_node, NodePtr &new_node); }; } // namespace ge -#endif // GE_GRAPH_PASSES_END_GRAPH_PASS_H_ - +#endif // GE_GRAPH_PASSES_REPLACE_TRANS_SHAPE_PASS_H_ diff --git a/src/ge/graph/passes/replace_with_empty_const_pass.cc b/src/ge/graph/passes/replace_with_empty_const_pass.cc new file mode 100644 index 00000000..b76b2cc9 --- /dev/null +++ b/src/ge/graph/passes/replace_with_empty_const_pass.cc @@ -0,0 +1,156 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "graph/passes/replace_with_empty_const_pass.h" +#include +#include +#include "common/ge/ge_util.h" +#include "framework/common/debug/ge_log.h" +#include "framework/common/ge_inner_error_codes.h" +#include "graph/utils/graph_utils.h" + +namespace ge { +Status ReplaceWithEmptyConstPass::Run(NodePtr &node) { + GELOGD("ReplaceWithEmptyConstPass in."); + if (node == nullptr) { + GELOGE(PARAM_INVALID, "Parameter is null."); + return PARAM_INVALID; + } + if (node->GetOpDesc() == nullptr) { + GELOGE(PARAM_INVALID, "Param [opDesc] must not be null."); + return PARAM_INVALID; + } + // Node like no op, it has no output + if (node->GetOpDesc()->GetAllOutputsDescPtr().empty()) { + GELOGI("Node %s has no output desc. Ignore current pass.", node->GetName().c_str()); + return SUCCESS; + } + // If outputs of current node are all empty, replace it with empty const + bool is_all_output_empty = true; + for (const auto &output_desc_ptr : node->GetOpDesc()->GetAllOutputsDescPtr()) { + if (output_desc_ptr == nullptr) { + GELOGI("Node %s Got empty output_desc_ptr, ignore current pass.", node->GetName().c_str()); + return SUCCESS; + } + if (!IsEmptyTenor(output_desc_ptr->GetShape())) { + is_all_output_empty = false; + break; + } + } + if (is_all_output_empty) { + GELOGI("Node %s has empty tensor output. 
It will be replaced by empty const.", node->GetName().c_str()); + // Replace op which all output is empty with empty const + Status ret = ReplaceWithEmptyConst(node); + if (ret != SUCCESS) { + // If replace failed, it should not break whole process, so still return success + GELOGW("Failed to repalce node %s with empty const.", node->GetName().c_str()); + } + } + GELOGD("ReplaceWithEmptyConstPass end."); + return SUCCESS; +} + +Status ReplaceWithEmptyConstPass::ReplaceWithEmptyConst(NodePtr &node_to_replace) { + std::map> shape_out_idx_map; + auto op_desc = node_to_replace->GetOpDesc(); + // Collect out_idx follow different out shape + for (const auto &out_anchor : node_to_replace->GetAllOutDataAnchors()) { + auto out_desc = op_desc->GetOutputDesc(out_anchor->GetIdx()); + shape_out_idx_map[GetDimStr(out_desc.GetShape())].emplace_back(out_anchor->GetIdx()); + } + + for (const auto &shape_2_out_idx : shape_out_idx_map) { + // Create empty const + // The out_desc in one group should be same shape, so here only get first out_desc. its valid index. 
+ auto out_desc = op_desc->GetOutputDesc(shape_2_out_idx.second[0]); + NodePtr const_node; + auto graph = node_to_replace->GetOwnerComputeGraph(); + Status ret = InsertEmptyConst(out_desc, const_node, graph); + if (ret != SUCCESS) { + GELOGE(FAILED, "Failed insert const node."); + return FAILED; + } + + // Repalce data anchors + if (GraphUtils::ReplaceNodeDataAnchors(const_node, node_to_replace, {}, shape_2_out_idx.second) != GRAPH_SUCCESS) { + GELOGE(FAILED, "[%s] ReplaceNodeAnchors failed.", node_to_replace->GetName().c_str()); + return FAILED; + } + // Copy in control edge + if (GraphUtils::CopyInCtrlEdges(node_to_replace, const_node) != GRAPH_SUCCESS) { + GELOGE(FAILED, "CopyInCtrlEdges from %s to %s failed.", node_to_replace->GetName().c_str(), + const_node->GetName().c_str()); + return FAILED; + } + // Copy out control edge + if (GraphUtils::CopyOutCtrlEdges(node_to_replace, const_node) != GRAPH_SUCCESS) { + GELOGE(FAILED, "CopyOutCtrlEdges from %s to %s failed.", node_to_replace->GetName().c_str(), + const_node->GetName().c_str()); + return FAILED; + } + GELOGI("Node %s has been replaced by empty const %s.", node_to_replace->GetName().c_str(), + const_node->GetName().c_str()); + } + // Unlink control edge from node_to_replace to graph + if (node_to_replace->GetInControlAnchor() != nullptr) { + node_to_replace->GetInControlAnchor()->UnlinkAll(); + } + if (node_to_replace->GetOutControlAnchor() != nullptr) { + node_to_replace->GetOutControlAnchor()->UnlinkAll(); + } + return SUCCESS; +} +Status ReplaceWithEmptyConstPass::InsertEmptyConst(const GeTensorDesc &out_desc, NodePtr &const_node, + ComputeGraphPtr &graph) { + GeTensorPtr empty_tensor = MakeShared(); + if (empty_tensor == nullptr) { + GELOGE(OUT_OF_MEMORY, "Failed create empty tensor."); + return OUT_OF_MEMORY; + } + empty_tensor->MutableTensorDesc().SetDataType(out_desc.GetDataType()); + empty_tensor->MutableTensorDesc().SetFormat(out_desc.GetFormat()); + 
empty_tensor->MutableTensorDesc().SetShape(out_desc.GetShape()); + auto const_desc = OpDescUtils::CreateConstOp(empty_tensor); + if (const_desc == nullptr) { + GELOGE(OUT_OF_MEMORY, "Failed to get const desc from tensor"); + return OUT_OF_MEMORY; + } + + const_node = graph->AddNode(const_desc); + if (const_node == nullptr) { + GELOGE(FAILED, "Failed insert const node."); + return FAILED; + } + return SUCCESS; +} + +bool ReplaceWithEmptyConstPass::IsEmptyTenor(const GeShape &shape) const { + for (auto dim : shape.GetDims()) { + if (dim == 0) { + return true; + } + } + return false; +} + +string ReplaceWithEmptyConstPass::GetDimStr(const GeShape &shape) { + std::stringstream dim_str; + for (auto dim : shape.GetDims()) { + dim_str << dim << '-'; + } + return dim_str.str(); +} +} // namespace ge diff --git a/src/ge/graph/passes/replace_with_empty_const_pass.h b/src/ge/graph/passes/replace_with_empty_const_pass.h new file mode 100644 index 00000000..495b75b3 --- /dev/null +++ b/src/ge/graph/passes/replace_with_empty_const_pass.h @@ -0,0 +1,34 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef GE_GRAPH_PASSES_REPLACE_WITH_EMPTY_CONST_PASS_H_ +#define GE_GRAPH_PASSES_REPLACE_WITH_EMPTY_CONST_PASS_H_ + +#include "graph/passes/base_pass.h" + +namespace ge { +class ReplaceWithEmptyConstPass : public BaseNodePass { + public: + Status Run(NodePtr &node) override; + + private: + Status ReplaceWithEmptyConst(NodePtr &node_to_replace); + Status InsertEmptyConst(const GeTensorDesc &out_desc, NodePtr &const_node, ComputeGraphPtr &graph); + bool IsEmptyTenor(const GeShape &shape) const; + std::string GetDimStr(const GeShape &shape); +}; +} // namespace ge +#endif // GE_GRAPH_PASSES_REPLACE_WITH_EMPTY_CONST_PASS_H_ diff --git a/src/ge/graph/passes/reshape_remove_pass.cc b/src/ge/graph/passes/reshape_remove_pass.cc index 49945f38..bd84882a 100644 --- a/src/ge/graph/passes/reshape_remove_pass.cc +++ b/src/ge/graph/passes/reshape_remove_pass.cc @@ -15,6 +15,7 @@ */ #include "graph/passes/reshape_remove_pass.h" +#include "framework/common/util.h" #include "graph/passes/pass_utils.h" namespace ge { @@ -24,19 +25,20 @@ const int kReshapeShapeIndex = 1; } // namespace Status ReshapeRemovePass::Run(NodePtr &node) { - if (node == nullptr) { - GELOGE(FAILED, "parameter is null."); - return FAILED; - } - if (node->GetType() != domi::RESHAPE) { + GE_CHECK_NOTNULL(node); + GE_CHECK_NOTNULL(node->GetOpDesc()); + if (node->GetType() != RESHAPE && node->GetType() != REFORMAT) { return SUCCESS; } - GELOGD("Remove reshape node %s", node->GetName().c_str()); - auto ret = PassUtils::UnlinkNodeWithControlCopy(node, kReshapeShapeIndex); - if (ret != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Failed unlink shape edge for reshape node %s", node->GetName().c_str()); - return ret; + auto op_desc = node->GetOpDesc(); + auto output_desc = op_desc->GetOutputDescPtr(kReshapeDataIndex); + GE_CHECK_NOTNULL(output_desc); + if (output_desc->GetShape().IsUnknownShape()) { + GELOGD("Reshape node %s is unknown shape. 
It should be remained.", node->GetName().c_str()); + return SUCCESS; } + + GELOGD("Remove %s node %s", node->GetType().c_str(), node->GetName().c_str()); return IsolateAndDeleteNode(node, {kReshapeDataIndex}); } } // namespace ge diff --git a/src/ge/graph/passes/same_transdata_breadth_fusion_pass.cc b/src/ge/graph/passes/same_transdata_breadth_fusion_pass.cc index 0d22d557..5f861660 100644 --- a/src/ge/graph/passes/same_transdata_breadth_fusion_pass.cc +++ b/src/ge/graph/passes/same_transdata_breadth_fusion_pass.cc @@ -28,13 +28,6 @@ #include "graph/utils/op_desc_utils.h" #include "init/gelib.h" -using domi::ATTR_NAME_STREAM_LABEL; -using domi::CAST; -using domi::RESHAPE; -using domi::TRANSDATA; -using domi::TRANSPOSE; -using domi::TRANSPOSED; - namespace { const char *const kRemainNode = "node_remain"; const int kNoTransOp = 1; @@ -351,7 +344,7 @@ graphStatus SameTransdataBreadthFusionPass::Run(ComputeGraphPtr graph) { return GRAPH_SUCCESS; } - for (auto &node : graph->GetAllNodes()) { + for (auto &node : graph->GetDirectNode()) { if (IsTransOp(node) || node->GetOutDataNodes().size() <= 1) { continue; } diff --git a/src/ge/graph/passes/shape_operate_op_remove_pass.cc b/src/ge/graph/passes/shape_operate_op_remove_pass.cc index b04c2a18..5a6e5f99 100644 --- a/src/ge/graph/passes/shape_operate_op_remove_pass.cc +++ b/src/ge/graph/passes/shape_operate_op_remove_pass.cc @@ -20,13 +20,12 @@ #include "common/util.h" #include "graph/utils/attr_utils.h" -using domi::ATTR_TO_BE_DELETED; using domi::SUCCESS; namespace ge { Status ShapeOperateOpRemovePass::Run(ComputeGraphPtr graph) { GE_CHECK_NOTNULL(graph); - for (auto &node : graph->GetAllNodes()) { + for (auto &node : graph->GetDirectNode()) { OpDescPtr op_desc = node->GetOpDesc(); GE_IF_BOOL_EXEC(op_desc == nullptr, continue); bool to_be_deleted = false; diff --git a/src/ge/graph/passes/snapshot_pass.cc b/src/ge/graph/passes/snapshot_pass.cc index 83510e19..702cf4de 100644 --- a/src/ge/graph/passes/snapshot_pass.cc +++ 
b/src/ge/graph/passes/snapshot_pass.cc @@ -20,8 +20,6 @@ #include "framework/common/ge_inner_error_codes.h" #include "graph/common/omg_util.h" -using domi::SNAPSHOT; - namespace ge { Status SnapshotPass::Run(NodePtr &node) { if (node == nullptr) { diff --git a/src/ge/graph/passes/stop_gradient_pass.cc b/src/ge/graph/passes/stop_gradient_pass.cc index 175c8756..bd5c0ea8 100644 --- a/src/ge/graph/passes/stop_gradient_pass.cc +++ b/src/ge/graph/passes/stop_gradient_pass.cc @@ -17,8 +17,6 @@ #include "graph/passes/stop_gradient_pass.h" #include -using domi::STOPGRADIENT; - namespace ge { Status StopGradientPass::Run(NodePtr &node) { if (node == nullptr) { diff --git a/src/ge/graph/passes/subgraph_pass.cc b/src/ge/graph/passes/subgraph_pass.cc new file mode 100644 index 00000000..6c4ad385 --- /dev/null +++ b/src/ge/graph/passes/subgraph_pass.cc @@ -0,0 +1,214 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/passes/subgraph_pass.h" + +#include "graph/utils/node_utils.h" +#include "graph/utils/op_desc_utils.h" +#include "graph/utils/tensor_utils.h" + +namespace { +const std::set kWhileTypes = {ge::WHILE, ge::_WHILE, ge::STATELESSWHILE}; +} + +namespace ge { + +/** + * @ingroup ge + * @brief Subgraph optimizer. 
+ * @param [in] graph: Input ComputeGraph + * @return: 0 for success / others for fail + */ +Status SubgraphPass::Run(ComputeGraphPtr graph) { + const bool is_sub_graph = graph->GetParentNode() != nullptr; + for (const NodePtr &node : graph->GetDirectNode()) { + GE_CHECK_NOTNULL(node); + GE_CHECK_NOTNULL(node->GetOpDesc()); + + if (is_sub_graph && (node->GetType() == DATA)) { + if (SubgraphInputNode(graph, node) != SUCCESS) { + return FAILED; + } + continue; + } + + // 2. Const->NetOutput in subgraph + // 3. Data->NetOutput in subgraph but not while body + if (is_sub_graph && (node->GetType() == NETOUTPUT)) { + if (SubgraphOutputNode(graph, node) != SUCCESS) { + return FAILED; + } + continue; + } + + // 4. Input->While and Input link to other nodes + if (kWhileTypes.count(node->GetType()) > 0) { + if (WhileInputNodes(graph, node) != SUCCESS) { + return FAILED; + } + continue; + } + } + + return SUCCESS; +} + +/** + * @ingroup ge + * @brief Check Subgraph NetOutput node + * @param [in] graph: ComputeGraph. + * @param [in] node: NetOutput node in Subgraph. + * @return: 0 for SUCCESS / others for FAILED + */ +Status SubgraphPass::SubgraphInputNode(const ComputeGraphPtr &graph, const NodePtr &node) { + uint32_t parent_index = 0; + if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { + return FAILED; + } + + // Subgraph Data Node, check for constant input. + std::string const_type; + NodePtr in_node = NodeUtils::GetParentInput(node); + if (!NodeUtils::GetConstOpType(in_node, const_type)) { + return SUCCESS; + } + + if (!AttrUtils::SetStr(node->GetOpDesc(), ATTR_NAME_PARENT_CONST_TYPE, const_type)) { + return FAILED; + } + + return SUCCESS; +} + +/** + * @ingroup ge + * @brief Check Subgraph NetOutput node + * @param [in] graph: ComputeGraph. + * @param [in] node: NetOutput node in Subgraph. 
+ * @return: 0 for SUCCESS / others for FAILED + */ +Status SubgraphPass::SubgraphOutputNode(const ComputeGraphPtr &graph, const NodePtr &node) { + for (InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) { + const OutDataAnchorPtr &peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); + GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); + + NodePtr in_node = peer_out_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(in_node); + + // Need insert memcpy + // 2. Const->NetOutput in subgraph + // 3. Data->NetOutput in subgraph but not while body + std::string op_type; + bool input_const_flag = NodeUtils::GetConstOpType(in_node, op_type); + if ((in_node->GetType() == DATA) && !IsWhileBodyOutput(in_data_anchor)) { + input_const_flag = true; + } + + if (input_const_flag) { + GELOGI("Insert MemcpyAsync node between %s and %s.", node->GetName().c_str(), in_node->GetName().c_str()); + std::string name = in_node->GetName() + "_" + MEMCPYASYNC + "_" + std::to_string(memcpy_num_++); + if (InsertMemcpyNode(graph, peer_out_anchor, in_data_anchor, name) != SUCCESS) { + return FAILED; + } + } + } + + return SUCCESS; +} + +/** + * @ingroup ge + * @brief Check is Input->While and Input link to other nodes + * @param [in] graph: ComputeGraph. + * @param [in] node: While node. + * @return: 0 for SUCCESS / others for FAILED + */ +Status SubgraphPass::WhileInputNodes(const ComputeGraphPtr &graph, const NodePtr &node) { + for (InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) { + const OutDataAnchorPtr &peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); + GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); + + NodePtr in_node = peer_out_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(in_node); + + // Need insert memcpy + // 4. 
Input->While and Input link to other nodes + if (peer_out_anchor->GetPeerInDataAnchors().size() > 1) { + GELOGI("Insert MemcpyAsync node between %s and %s.", node->GetName().c_str(), in_node->GetName().c_str()); + std::string name = in_node->GetName() + "_" + MEMCPYASYNC + "_" + std::to_string(memcpy_num_++); + if (InsertMemcpyNode(graph, peer_out_anchor, in_data_anchor, name) != SUCCESS) { + return FAILED; + } + } + } + + return SUCCESS; +} + +/** + * @ingroup ge + * @brief Check is data->netoutput in while body + * @param [in] in_data_anchor + * @return: true for data->netoutput in while body / for false for others + */ +bool SubgraphPass::IsWhileBodyOutput(const InDataAnchorPtr &in_data_anchor) { + // Check is subgraph + NodePtr parent_node = in_data_anchor->GetOwnerNode()->GetOwnerComputeGraph()->GetParentNode(); + if (parent_node == nullptr) { + return false; + } + + // Check if parent_node is While + if (kWhileTypes.count(parent_node->GetType()) == 0) { + return false; + } + + // While cond / body + OpDescPtr op_desc = in_data_anchor->GetOwnerNode()->GetOpDesc(); + if (op_desc == nullptr) { + return false; + } + return AttrUtils::HasAttr(op_desc->GetInputDesc(in_data_anchor->GetIdx()), ATTR_NAME_PARENT_NODE_INDEX); +} + +/** + * @ingroup ge + * @brief Insert memcpy node + * @param [in] graph + * @param [in] out_anchor + * @param [in] in_anchor + * @param [in] name + * @return: 0 for success / others for fail + */ +Status SubgraphPass::InsertMemcpyNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_anchor, + const InDataAnchorPtr &in_anchor, const std::string &name) { + GE_CHECK_NOTNULL(out_anchor); + GE_CHECK_NOTNULL(in_anchor); + NodePtr in_node = out_anchor->GetOwnerNode(); + OpDescBuilder op_desc_builder(name, MEMCPYASYNC); + OpDescPtr op_desc = op_desc_builder.AddInput("x", in_node->GetOpDesc()->GetOutputDesc(0)) + .AddOutput("y", in_node->GetOpDesc()->GetOutputDesc(0)) + .Build(); + if (GraphUtils::InsertNodeBefore(out_anchor, {in_anchor}, 
graph->AddNode(op_desc)) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Insert MemcpyAsync node %s between %s->%s failed.", name.c_str(), in_node->GetName().c_str(), + in_anchor->GetOwnerNode()->GetName().c_str()); + return FAILED; + } + + return SUCCESS; +} + +} // namespace ge diff --git a/src/ge/graph/passes/subgraph_pass.h b/src/ge/graph/passes/subgraph_pass.h new file mode 100644 index 00000000..57e4e4c6 --- /dev/null +++ b/src/ge/graph/passes/subgraph_pass.h @@ -0,0 +1,91 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_PASSES_SUBGRAPH_PASS_H_ +#define GE_GRAPH_PASSES_SUBGRAPH_PASS_H_ + +#include +#include +#include +#include + +#include "graph/types.h" +#include "inc/graph_pass.h" + +namespace ge { +class SubgraphPass : public GraphPass { + public: + /** + * @ingroup ge + * @brief Subgraph optimizer. + * @param [in] graph: Input ComputeGraph + * @return: 0 for success / others for fail + */ + Status Run(ComputeGraphPtr graph) override; + + private: + /** + * @ingroup ge + * @brief Check Subgraph Data node. + * @param [in] graph: ComputeGraph. + * @param [in] node: NetOutput node in Subgraph. + * @return: 0 for SUCCESS / others for FAILED + */ + Status SubgraphInputNode(const ComputeGraphPtr &graph, const NodePtr &node); + + /** + * @ingroup ge + * @brief Check Subgraph NetOutput node. + * @param [in] graph: ComputeGraph. + * @param [in] node: NetOutput node in Subgraph. 
+ * @return: 0 for SUCCESS / others for FAILED + */ + Status SubgraphOutputNode(const ComputeGraphPtr &graph, const NodePtr &node); + + /** + * @ingroup ge + * @brief Check is Input->While and Input link to other nodes + * @param [in] graph: ComputeGraph. + * @param [in] node: While node. + * @return: 0 for SUCCESS / others for FAILED + */ + Status WhileInputNodes(const ComputeGraphPtr &graph, const NodePtr &node); + + /** + * @ingroup ge + * @brief Check is data->netoutput in while body + * @param [in] in_data_anchor + * @return: true for data->netoutput in while body / for false for others + */ + bool IsWhileBodyOutput(const InDataAnchorPtr &in_data_anchor); + + /** + * @ingroup ge + * @brief Insert memcpy node + * @param [in] graph + * @param [in] out_anchor + * @param [in] in_anchor + * @param [in] name + * @return: 0 for success / others for fail + */ + Status InsertMemcpyNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_anchor, + const InDataAnchorPtr &in_anchor, const std::string &name); + + // Append index for new memcpy node. 
+ uint32_t memcpy_num_{0}; +}; +} // namespace ge +#endif // GE_GRAPH_PASSES_SUBGRAPH_PASS_H_ diff --git a/src/ge/graph/passes/switch_logic_remove_pass.cc b/src/ge/graph/passes/switch_logic_remove_pass.cc index 1ac25e13..be84a582 100644 --- a/src/ge/graph/passes/switch_logic_remove_pass.cc +++ b/src/ge/graph/passes/switch_logic_remove_pass.cc @@ -37,7 +37,7 @@ char const *GetOutputNameFromIndex(int index) { return "UNKNOWN"; } -inline bool IsSwitch(const std::string &type) { return type == domi::SWITCH || type == domi::REFSWITCH; } +inline bool IsSwitch(const std::string &type) { return type == SWITCH || type == REFSWITCH; } Status GetPredNode(const NodePtr &switch_node, PredNodeAndOut &pred_node_index) { GE_CHECK_NOTNULL(switch_node); diff --git a/src/ge/graph/passes/switch_op_pass.cc b/src/ge/graph/passes/switch_op_pass.cc index 1e1975d5..7eae40f8 100644 --- a/src/ge/graph/passes/switch_op_pass.cc +++ b/src/ge/graph/passes/switch_op_pass.cc @@ -31,36 +31,10 @@ #include "graph/debug/ge_attr_define.h" #include "graph/utils/type_utils.h" -using domi::ATTR_NAME_STREAM_LABEL; -using domi::ATTR_NAME_WEIGHTS; -using domi::CAST_ATTR_DSTT; -using domi::CAST_ATTR_SRCT; - -using domi::CAST; -using domi::CONSTANT; -using domi::ENTER; -using domi::EXIT; -using domi::MEMCPYASYNC; -using domi::MERGE; -using domi::NETOUTPUT; -using domi::NEXTITERATION; -using domi::REFENTER; -using domi::REFEXIT; -using domi::REFMERGE; -using domi::REFNEXTITERATION; -using domi::REFSWITCH; -using domi::STREAMACTIVE; -using domi::STREAMMERGE; -using domi::STREAMSWITCH; -using domi::SWITCH; - namespace ge { Status SwitchOpPass::Run(ComputeGraphPtr graph) { GELOGD("SwitchOpPass Enter"); - GraphUtils::DumpGEGraph(graph, "BeforeSwitchOpPass"); - GraphUtils::DumpGEGraphToOnnx(*graph, "BeforeSwitchOpPass"); - GE_CHK_STATUS_RET(CheckCycleDependence(graph), "CheckCycleDependence fail."); for (auto &switch_node : switch_nodes_) { @@ -68,7 +42,14 @@ Status SwitchOpPass::Run(ComputeGraphPtr graph) { } for 
(auto &merge_node : merge_nodes_) { - GE_CHK_STATUS_RET(ReplaceMergeNode(graph, merge_node), "Add StreamMerge node fail."); + OpDescPtr merge_op_desc = merge_node->GetOpDesc(); + GE_CHECK_NOTNULL(merge_op_desc); + if (merge_op_desc->HasAttr(ATTR_INSERT_BY_MBATCH)) { + GE_CHK_STATUS_RET(AddMemcpyAsyncNodes(graph, merge_node, true), "Merge add memcpy node fail."); + GE_CHK_STATUS_RET(SetStreamLabel(merge_node, merge_node->GetName()), "Set stream label failed"); + } else { + GE_CHK_STATUS_RET(ReplaceMergeNode(graph, merge_node), "Add StreamMerge node fail."); + } } GE_CHK_STATUS_RET(CombineSwitchNode(graph), "Combine StreamSwitch nodes fail."); @@ -94,9 +75,6 @@ Status SwitchOpPass::Run(ComputeGraphPtr graph) { GE_CHK_STATUS_RET(UpdateEnterNode(), "UpdateEnterNode fail."); - GraphUtils::DumpGEGraph(graph, "AfterSwitchOpPass"); - GraphUtils::DumpGEGraphToOnnx(*graph, "AfterSwitchOpPass"); - GELOGD("SwitchOpPass Leave"); return SUCCESS; } @@ -160,7 +138,7 @@ Status SwitchOpPass::ReplaceSwitchNode(ComputeGraphPtr &graph, NodePtr &switch_n NodePtr out_node = peer_in_anchor->GetOwnerNode(); GE_CHK_STATUS_RET(GetOriginalType(out_node, type), "Get node type fail."); if ((type == MERGE) || (type == REFMERGE)) { - NodePtr memcpy_node = CreateMemcpyAsyncNode(graph, peer_data_anchor); + NodePtr memcpy_node = CreateMemcpyAsyncNode(graph, peer_data_anchor, false); GE_CHK_BOOL_EXEC(memcpy_node != nullptr, return FAILED, "Create memcpy_async node fail."); GE_CHK_STATUS(GraphUtils::AddEdge(peer_data_anchor, memcpy_node->GetInDataAnchor(0)), "MemcpyAsync node add edge fail."); @@ -257,16 +235,9 @@ Status SwitchOpPass::ReplaceMergeNode(ComputeGraphPtr &graph, NodePtr &merge_nod need_label_nodes_.emplace_back(stream_merge); } - if (merge_op_desc->HasAttr(ATTR_INSERT_BY_MBATCH)) { - if (!ge::AttrUtils::SetBool(op_desc, ATTR_INSERT_BY_MBATCH, true)) { - GELOGE(FAILED, "Set attr ATTR_INSERT_BY_MBATCH fail, StreamMerge:%s.", node_name.c_str()); - return FAILED; - } - } - 
(void)bypass_nodes_.insert(merge_node); - GE_CHK_STATUS_RET(AddMemcpyAsyncNodes(graph, stream_merge), "StreamMerge add memcpy node fail."); + GE_CHK_STATUS_RET(AddMemcpyAsyncNodes(graph, stream_merge, false), "StreamMerge add memcpy node fail."); return SUCCESS; } @@ -325,17 +296,20 @@ NodePtr SwitchOpPass::CreateStreamSwitchNode(ComputeGraphPtr &graph, const NodeP /// @brief Add MemcpyAsync Node /// @param [in] graph /// @param [in] in_node +/// @param [in] multi_batch_flag /// @return ge::NodePtr /// -NodePtr SwitchOpPass::CreateMemcpyAsyncNode(ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor) { +NodePtr SwitchOpPass::CreateMemcpyAsyncNode(ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor, + bool multi_batch_flag) { GE_CHK_BOOL_EXEC(out_data_anchor != nullptr, return nullptr, "Param of input node is null."); OpDescPtr pre_op_desc = out_data_anchor->GetOwnerNode()->GetOpDesc(); GE_CHK_BOOL_EXEC(pre_op_desc != nullptr, return nullptr, "OpDesc of pre node is invalid."); - std::string node_name = pre_op_desc->GetName() + "_" + MEMCPYASYNC; + std::string memcpy_type = multi_batch_flag ? 
MEMCPYADDRASYNC : MEMCPYASYNC; + std::string node_name = pre_op_desc->GetName() + "_" + memcpy_type; node_name = CheckDuplicateName(node_name); GELOGI("Create MemcpyAsync op:%s.", node_name.c_str()); - OpDescPtr op_desc = MakeShared(node_name, MEMCPYASYNC); + OpDescPtr op_desc = MakeShared(node_name, memcpy_type); if (op_desc == nullptr) { GELOGE(FAILED, "Create op_desc fail, MemcpyAsync:%s.", node_name.c_str()); return nullptr; @@ -455,9 +429,10 @@ NodePtr SwitchOpPass::CreateActiveNode(ComputeGraphPtr &graph, NodePtr &node) { /// @brief Add MemcpyAsync Op as StreamMerge in_node /// @param [in] graph /// @param [in] node +/// @param [in] multi_batch_flag /// @return Status /// -Status SwitchOpPass::AddMemcpyAsyncNodes(ComputeGraphPtr &graph, NodePtr &node) { +Status SwitchOpPass::AddMemcpyAsyncNodes(ComputeGraphPtr &graph, NodePtr &node, bool multi_batch_flag) { GE_CHK_BOOL_EXEC(node != nullptr, return FAILED, "Param of pre node is null."); for (InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) { OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); @@ -470,7 +445,7 @@ Status SwitchOpPass::AddMemcpyAsyncNodes(ComputeGraphPtr &graph, NodePtr &node) continue); GE_IF_BOOL_EXEC(type != MEMCPYASYNC, { - in_node = CreateMemcpyAsyncNode(graph, peer_out_anchor); + in_node = CreateMemcpyAsyncNode(graph, peer_out_anchor, multi_batch_flag); GE_CHK_BOOL_EXEC(in_node != nullptr, return FAILED, "Create MemcpyAsync node fail."); GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_out_anchor, in_data_anchor), "MemcpyAsync node remove edge fail."); GE_CHK_STATUS(GraphUtils::AddEdge(peer_out_anchor, in_node->GetInDataAnchor(0)), @@ -682,7 +657,7 @@ Status SwitchOpPass::UpdateCondBranch(NodePtr &node) { std::stack nodes; nodes.push(node); - static const std::set end_type_set = {STREAMSWITCH, STREAMMERGE}; + static const std::set end_type_set = {STREAMSWITCH, STREAMMERGE, MERGE}; bool merge_flag = false; bool exit_flag = false; bool net_output_flag = false; diff 
--git a/src/ge/graph/passes/switch_op_pass.h b/src/ge/graph/passes/switch_op_pass.h index 14cdd22c..7e107e3b 100644 --- a/src/ge/graph/passes/switch_op_pass.h +++ b/src/ge/graph/passes/switch_op_pass.h @@ -103,13 +103,13 @@ class SwitchOpPass : public GraphPass { NodePtr CreateStreamSwitchNode(ComputeGraphPtr &graph, const NodePtr &switch_node, const std::string &suffix, OutDataAnchorPtr &peer_cond_anchor); - NodePtr CreateMemcpyAsyncNode(ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor); + NodePtr CreateMemcpyAsyncNode(ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor, bool multi_batch_flag); Status CombineSwitchNode(ComputeGraphPtr &graph); NodePtr CreateActiveNode(ComputeGraphPtr &graph, NodePtr &node); - Status AddMemcpyAsyncNodes(ComputeGraphPtr &graph, NodePtr &stream_merge_node); + Status AddMemcpyAsyncNodes(ComputeGraphPtr &graph, NodePtr &stream_merge_node, bool multi_batch_flag); Status BypassSwitchNode(NodePtr &switch_node, OutDataAnchorPtr &peer_data_anchor, OutDataAnchorPtr &peer_cond_anchor); diff --git a/src/ge/graph/passes/switch_pass.cc b/src/ge/graph/passes/switch_pass.cc index 36fb4d81..8230d294 100644 --- a/src/ge/graph/passes/switch_pass.cc +++ b/src/ge/graph/passes/switch_pass.cc @@ -25,10 +25,6 @@ #include "graph/passes/pass_utils.h" #include "graph/utils/graph_utils.h" -using domi::MERGE; -using domi::REFSWITCH; -using domi::SWITCH; - namespace ge { namespace { const std::vector::size_type kDataInputIndex = 0; diff --git a/src/ge/graph/passes/transop_breadth_fusion_pass.cc b/src/ge/graph/passes/transop_breadth_fusion_pass.cc index 30ca6a53..b2f66bfc 100644 --- a/src/ge/graph/passes/transop_breadth_fusion_pass.cc +++ b/src/ge/graph/passes/transop_breadth_fusion_pass.cc @@ -24,15 +24,6 @@ #include "graph/common/transop_util.h" #include "graph/utils/node_utils.h" -using domi::ATTR_NAME_STREAM_LABEL; -using domi::CAST; -using domi::FAILED; -using domi::RESHAPE; -using domi::SUCCESS; -using domi::TRANSDATA; -using 
domi::TRANSPOSE; -using domi::TRANSPOSED; - namespace ge { Status TransOpBreadthFusionPass::Run(ge::ComputeGraphPtr graph) { GE_TIMESTAMP_START(TransOpBreadthFusionPass); @@ -40,7 +31,7 @@ Status TransOpBreadthFusionPass::Run(ge::ComputeGraphPtr graph) { return SUCCESS; } - for (auto const &node : graph->GetAllNodes()) { + for (auto const &node : graph->GetDirectNode()) { GE_CHECK_NOTNULL(node); auto ids_to_trans_nodes = GetOutputTransOpNodes(node); for (auto const &id_to_trans_nodes : ids_to_trans_nodes) { diff --git a/src/ge/graph/passes/transop_depth_fusion_pass.cc b/src/ge/graph/passes/transop_depth_fusion_pass.cc index 39989580..68899e2e 100644 --- a/src/ge/graph/passes/transop_depth_fusion_pass.cc +++ b/src/ge/graph/passes/transop_depth_fusion_pass.cc @@ -26,15 +26,6 @@ #include "graph/utils/graph_utils.h" #include "graph/common/transop_util.h" -using domi::CAST; -using domi::EXPANDDIMS; -using domi::REFORMAT; -using domi::RESHAPE; -using domi::SQUEEZE; -using domi::TRANSDATA; -using domi::TRANSPOSE; -using domi::TRANSPOSED; - namespace ge { graphStatus TransOpDepthFusionPass::Run(ComputeGraphPtr graph) { GE_TIMESTAMP_START(TransOpDepthFusionPass); @@ -42,7 +33,7 @@ graphStatus TransOpDepthFusionPass::Run(ComputeGraphPtr graph) { if (graph == nullptr) { return GRAPH_SUCCESS; } - for (const auto &node : graph->GetAllNodes()) { + for (const auto &node : graph->GetDirectNode()) { GE_CHECK_NOTNULL(node); if (TransOpUtil::IsTransOp(node)) { continue; diff --git a/src/ge/graph/passes/transop_nearby_allreduce_fusion_pass.cc b/src/ge/graph/passes/transop_nearby_allreduce_fusion_pass.cc index 383ab285..4b08e956 100644 --- a/src/ge/graph/passes/transop_nearby_allreduce_fusion_pass.cc +++ b/src/ge/graph/passes/transop_nearby_allreduce_fusion_pass.cc @@ -28,7 +28,7 @@ Status TransOpNearbyAllreduceFusionPass::Run(NodePtr &node) { return SUCCESS; } - if (node->GetType() == domi::HCOMALLREDUCE) { + if (node->GetType() == HCOMALLREDUCE) { GELOGI("found allreduce op %s", 
node->GetName().c_str()); Status ret = RemoveNearbyPairedTransOps(node); if (ret != SUCCESS) { @@ -46,7 +46,7 @@ bool TransOpNearbyAllreduceFusionPass::IsSymmetricTransOps(const NodePtr &node1, return false; } - if (node1->GetType() != domi::TRANSDATA || node2->GetType() != domi::TRANSDATA) { + if (node1->GetType() != TRANSDATA || node2->GetType() != TRANSDATA) { return false; } diff --git a/src/ge/graph/passes/transop_symmetry_elimination_pass.cc b/src/ge/graph/passes/transop_symmetry_elimination_pass.cc new file mode 100644 index 00000000..af75d9d0 --- /dev/null +++ b/src/ge/graph/passes/transop_symmetry_elimination_pass.cc @@ -0,0 +1,169 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "transop_symmetry_elimination_pass.h" +#include "framework/common/debug/ge_log.h" +#include "framework/common/util.h" +#include "graph/common/transop_util.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/type_utils.h" + +namespace { +const int kTransOpOutIndex = 0; +static std::map precision_loss_transfer_map = {{ge::DT_FLOAT, ge::DT_BOOL}}; + +} // namespace +namespace ge { +Status TransOpSymmetryEliminationPass::Run(NodePtr &node) { + GE_CHECK_NOTNULL(node); + GE_CHECK_NOTNULL(node->GetOpDesc()); + if (!TransOpUtil::IsTransOp(node)) { + return SUCCESS; + } + GELOGD("Symmetry Elimination Pass in."); + for (const auto &out_anchor : node->GetAllOutDataAnchors()) { + GE_CHECK_NOTNULL(out_anchor); + for (const auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) { + GE_CHECK_NOTNULL(peer_in_anchor); + GE_CHECK_NOTNULL(peer_in_anchor->GetOwnerNode()); + GE_CHECK_NOTNULL(peer_in_anchor->GetOwnerNode()->GetOpDesc()); + if (!CheckCanBeEliminated(node, peer_in_anchor)) { + break; + } + + auto dst_node = peer_in_anchor->GetOwnerNode(); + Status ret = EliminateTransOp(node, out_anchor, dst_node, peer_in_anchor); + if (ret != SUCCESS) { + // if eliminate failed ,it should't break precess, so give a warning here + GELOGW("Eliminate %s and %s failed, ignore current pass.", node->GetName().c_str(), + dst_node->GetName().c_str()); + return ret; + } + } + } + GELOGD("Symmetry Elimination Pass end."); + return SUCCESS; +} + +bool TransOpSymmetryEliminationPass::CheckCanBeEliminated(const ge::NodePtr &src_node, + const InDataAnchorPtr &dst_in_anchor) { + auto dst_node = dst_in_anchor->GetOwnerNode(); + if (src_node->GetType() != dst_node->GetType()) { + GELOGD("Pre node %s type %s is not equal with node %s type %s. 
Ignore pass.", src_node->GetName().c_str(), + src_node->GetType().c_str(), dst_node->GetName().c_str(), dst_node->GetType().c_str()); + return false; + } + if (dst_in_anchor->GetIdx() != TransOpUtil::GetTransOpDataIndex(src_node)) { + GELOGD("Next node %s type %s input %d is not for transform. Ignore pass.", dst_node->GetName().c_str(), + dst_node->GetType().c_str(), dst_in_anchor->GetIdx()); + return false; + } + if (!DescAreSymmetry(src_node, dst_node) || !CheckPrecisionLoss(src_node)) { + GELOGD("Not satisfied symmetry or has precision loss, ignore pass."); + return false; + } + return true; +} +bool TransOpSymmetryEliminationPass::DescAreSymmetry(const NodePtr &src_node, const NodePtr &dst_node) { + const auto &src_input_desc = src_node->GetOpDesc()->MutableInputDesc(0); + const auto &dst_output_desc = dst_node->GetOpDesc()->MutableOutputDesc(0); + GE_CHECK_NOTNULL(src_input_desc); + GE_CHECK_NOTNULL(dst_output_desc); + const auto &src_input_dtype = src_input_desc->GetDataType(); + const auto &src_input_format = src_input_desc->GetFormat(); + const auto &src_input_shape = src_input_desc->GetShape().GetDims(); + const auto &dst_output_dtype = dst_output_desc->GetDataType(); + const auto &dst_output_format = dst_output_desc->GetFormat(); + const auto &dst_output_shape = dst_output_desc->GetShape().GetDims(); + + if (src_node->GetType() == CAST && dst_node->GetType() == CAST) { + return (src_input_dtype == dst_output_dtype) && (src_input_format == dst_output_format); + } else { + return (src_input_dtype == dst_output_dtype) && (src_input_shape == dst_output_shape) && + (src_input_format == dst_output_format); + } +} +bool TransOpSymmetryEliminationPass::CheckPrecisionLoss(const ge::NodePtr &src_node) { + auto idx = TransOpUtil::GetTransOpDataIndex(src_node); + auto input_desc = src_node->GetOpDesc()->GetInputDesc(idx); + auto output_desc = src_node->GetOpDesc()->GetOutputDesc(kTransOpOutIndex); + auto src_dtype = input_desc.GetDataType(); + auto dst_dtype = 
output_desc.GetDataType(); + auto iter = precision_loss_transfer_map.find(src_dtype); + if (iter != precision_loss_transfer_map.end() && iter->second == dst_dtype) { + GELOGW("Node %s transfer data type from %s to %s ,it will cause precision loss.", src_node->GetName().c_str(), + TypeUtils::DataTypeToSerialString(src_dtype).c_str(), TypeUtils::DataTypeToSerialString(dst_dtype).c_str()); + return false; + } + return true; +} + +Status TransOpSymmetryEliminationPass::EliminateTransOp(NodePtr &src_node, const OutDataAnchorPtr &src_out_anchor, + NodePtr &dst_node, const InDataAnchorPtr &dst_in_anchor) { + // Two transform nodes can be offset like A->T1->T2->B + // 1.Unlink T1->T2 + auto ret = src_out_anchor->Unlink(dst_in_anchor); + if (ret != GRAPH_SUCCESS) { + GELOGE(FAILED, "Unlink data anchor from %s to %s.", src_node->GetName().c_str(), dst_node->GetName().c_str()); + return ret; + } + // 2.Link A->T2 + auto data_idx = TransOpUtil::GetTransOpDataIndex(src_node); + auto in_anchor = src_node->GetInDataAnchor(data_idx); + GE_CHECK_NOTNULL(in_anchor); + GE_CHECK_NOTNULL(in_anchor->GetPeerOutAnchor()); + auto pre_normal_node = in_anchor->GetPeerOutAnchor()->GetOwnerNode(); + ret = GraphUtils::AddEdge(in_anchor->GetPeerOutAnchor(), dst_in_anchor); + if (ret != GRAPH_SUCCESS) { + GELOGE(FAILED, "Add data edge from %s to %s failed.", pre_normal_node->GetName().c_str(), + dst_node->GetName().c_str()); + return ret; + } + // 3.Copy in-control/data-in-control from T1->T2 + ret = GraphUtils::CopyInCtrlEdges(src_node, dst_node); + if (ret != GRAPH_SUCCESS) { + GELOGE(FAILED, "Copy control edge from %s to %s failed.", src_node->GetName().c_str(), dst_node->GetName().c_str()); + return ret; + } + // 4.IsolateAndDelete T2, A will link to B automatically, and all control edge will also relink. 
+ ret = IsolateAndDeleteNode(dst_node, {0}); + if (ret != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "Isolate removed node: %s, type: %s failed", dst_node->GetName().c_str(), + dst_node->GetType().c_str()); + return ret; + } + GELOGI("Trans op symmetry eliminate successfully. Node %s has been removed.", dst_node->GetName().c_str()); + // 5.If T1 has no data out, isolate and deleted it. + if (src_node->GetOutDataNodesSize() == 0) { + // 5.1 Copy out control to pre normal node + ret = GraphUtils::CopyOutCtrlEdges(src_node, pre_normal_node); + if (ret != GRAPH_SUCCESS) { + GELOGE(FAILED, "Copy control edge from %s to %s failed.", src_node->GetName().c_str(), + dst_node->GetName().c_str()); + return ret; + } + // 5.2 Isolate and delete T1 + ret = IsolateAndDeleteNode(src_node, {}); + if (ret != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "Isolate removed node: %s, type: %s failed", src_node->GetName().c_str(), + src_node->GetType().c_str()); + return ret; + } + GELOGI("Trans op symmetry eliminate successfully. Node %s has been removed.", src_node->GetName().c_str()); + } + return SUCCESS; +} +} // namespace ge diff --git a/src/ge/graph/passes/transop_symmetry_elimination_pass.h b/src/ge/graph/passes/transop_symmetry_elimination_pass.h new file mode 100644 index 00000000..b0cff0c9 --- /dev/null +++ b/src/ge/graph/passes/transop_symmetry_elimination_pass.h @@ -0,0 +1,74 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_SYMMETRY_ELIMINATION_PASS_H +#define GE_SYMMETRY_ELIMINATION_PASS_H + +#include "graph/passes/base_pass.h" + +namespace ge { +class TransOpSymmetryEliminationPass : public BaseNodePass { + public: + Status Run(NodePtr &node) override; + + private: + /// + /// Judge whether the node can be offset + /// 1.both are transform op + /// 2.is symmetry position + /// 3.satisfy precision loss + /// @param node + /// @return True or False + /// + static bool CheckCanBeEliminated(const ge::NodePtr &src_node, const InDataAnchorPtr &dst_in_anchor); + /// + /// two transform nodes can be offset only when the front node's input is + /// consistent with the back one's output + /// @param src_node: the front node + /// @param dst_node: the back node + /// @return True or False, whether can be offset or not + /// + static bool DescAreSymmetry(const NodePtr &src_node, const NodePtr &dst_node); + + /// + /// two transform nodes can not be offset if there is precision loss, like FP32->BOOL BOOL->FP32. + /// keep this pair of transform nodes if it has precision loss. + /// @param src_node: the front node + /// @param dst_node: the back node + /// @return True or False, whether can be offset or not + /// + static bool CheckPrecisionLoss(const NodePtr &src_node); + + /// + /// two transform nodes can be offset like A->T1->T2->B + /// 1.unlink T1->T2 + /// 2.link A->T2 + /// 3.copy in-control/data-in-control from T1->T2 + /// 4.isolateAndDelete T2, it will re-pass all in and out node + /// then we get A->B . Leave T1 to prune pass. 
+ /// ->T1 + /// @param src_node: the front node + /// @param src_out_anchor: the front node out anchor + /// @param dst_node: the back node + /// @param dst_in_anchor: the back node in anchor + /// @return SUCCESS or Fail, whether + /// + Status EliminateTransOp(NodePtr &src_node, const OutDataAnchorPtr &src_out_anchor, NodePtr &dst_node, + const InDataAnchorPtr &dst_in_anchor); +}; +} // namespace ge + +#endif // GE_SYMMETRY_ELIMINATION_PASS_H diff --git a/src/ge/graph/passes/transop_without_reshape_fusion_pass.cc b/src/ge/graph/passes/transop_without_reshape_fusion_pass.cc index 80ed5d56..92ae75e6 100644 --- a/src/ge/graph/passes/transop_without_reshape_fusion_pass.cc +++ b/src/ge/graph/passes/transop_without_reshape_fusion_pass.cc @@ -28,18 +28,11 @@ #include "graph/ge_tensor.h" #include "graph/op_desc.h" #include "graph/utils/graph_utils.h" +#include "graph/utils/node_utils.h" #include "graph/utils/op_desc_utils.h" #include "graph/utils/type_utils.h" #include "init/gelib.h" -using domi::ATTR_NAME_INPUT_FORMAT; -using domi::ATTR_NAME_OUTPUT_FORMAT; -using domi::CAST; -using domi::RESHAPE; -using domi::TRANSDATA; -using domi::TRANSPOSE; -using domi::TRANSPOSED; - namespace { const char *const kRemainNode = "node_remain"; const int kInvalidFusionOpCount = -1; @@ -745,11 +738,23 @@ graphStatus TransOpWithoutReshapeFusionPass::Run(ComputeGraphPtr graph) { return GRAPH_SUCCESS; } - for (const auto &node : graph->GetAllNodes()) { + for (const auto &node : graph->GetDirectNode()) { GE_CHECK_NOTNULL(node); if (IsTransOp(node)) { continue; } + bool is_unknown = false; + auto ret = NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown); + if (ret != GRAPH_SUCCESS) { + GELOGW("Get node unknown status failed, node name:%s, type:%s.", node->GetName().c_str(), + node->GetType().c_str()); + continue; + } + if (is_unknown) { + GELOGI("Current node %s, type %s is unknown shape which should be skip.", node->GetName().c_str(), + node->GetType().c_str()); + continue; + } 
GELOGI("Current normal node name: %s, type: %s.", node->GetName().c_str(), node->GetType().c_str()); for (const auto &out_anchor : node->GetAllOutDataAnchors()) { GE_CHECK_NOTNULL(out_anchor); diff --git a/src/ge/graph/passes/transpose_transdata_pass.cc b/src/ge/graph/passes/transpose_transdata_pass.cc index ebc068a9..7ac7b7a3 100644 --- a/src/ge/graph/passes/transpose_transdata_pass.cc +++ b/src/ge/graph/passes/transpose_transdata_pass.cc @@ -23,13 +23,10 @@ #include "framework/common/debug/ge_log.h" #include "graph/utils/type_utils.h" #include "graph/debug/ge_attr_define.h" +#include "graph/utils/node_utils.h" #include "init/gelib.h" #include "opskernel_manager/ops_kernel_manager.h" -using domi::TRANSDATA; -using domi::TRANSPOSE; -using domi::TRANSPOSED; - namespace { const char *const kAttrNameSrcFormat = "src_format"; } // namespace @@ -52,6 +49,17 @@ Status TransposeTransDataPass::Run(NodePtr &node) { if (CheckOneInAndOneOutDataAnchor(node) != SUCCESS) { return FAILED; } + bool is_unknown = false; + auto ret = NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown); + if (ret != GRAPH_SUCCESS) { + GELOGW("Get node unknown status failed, node name:%s, type:%s.", node->GetName().c_str(), node->GetType().c_str()); + return INTERNAL_ERROR; + } + if (is_unknown) { + GELOGI("Current node %s, type %s is unknown shape which should be skip.", node->GetName().c_str(), + node->GetType().c_str()); + return SUCCESS; + } GELOGD("[%s] TransposeTransDataPass in.", node->GetName().c_str()); auto out_nodes = node->GetOutDataNodes(); diff --git a/src/ge/graph/passes/unused_const_pass.cc b/src/ge/graph/passes/unused_const_pass.cc index 750c95f8..386633b5 100644 --- a/src/ge/graph/passes/unused_const_pass.cc +++ b/src/ge/graph/passes/unused_const_pass.cc @@ -19,8 +19,6 @@ #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" -using domi::UNUSEDCONST; - namespace ge { /// /// run pass @@ -38,7 +36,7 @@ Status UnusedConstPass::Run(NodePtr 
&node) { } std::string op_type = node->GetOpDesc()->GetType(); - if (op_type == domi::UNUSEDCONST) { + if (op_type == UNUSEDCONST) { GELOGD("op type is unused const."); return IsolateAndDeleteNode(node, {-1}); } diff --git a/src/ge/graph/passes/unused_op_remove_pass.cc b/src/ge/graph/passes/unused_op_remove_pass.cc index 9a56e3a2..093d931a 100644 --- a/src/ge/graph/passes/unused_op_remove_pass.cc +++ b/src/ge/graph/passes/unused_op_remove_pass.cc @@ -29,13 +29,7 @@ #include "inc/pass_manager.h" #include "graph/passes/isolated_op_remove_pass.h" -using domi::ASSERT; -using domi::ATTENTIONDECODER; -using domi::DROPOUT; -; -using domi::PERMUTE; using domi::SUCCESS; -using domi::UNUSEDCONST; namespace ge { const std::set kRemoveOpSet = {DROPOUT, PERMUTE, UNUSEDCONST, ASSERT}; diff --git a/src/ge/graph/passes/var_is_initialized_op_pass.cc b/src/ge/graph/passes/var_is_initialized_op_pass.cc index 0e5e4674..c88db80c 100644 --- a/src/ge/graph/passes/var_is_initialized_op_pass.cc +++ b/src/ge/graph/passes/var_is_initialized_op_pass.cc @@ -26,10 +26,6 @@ #include "graph/utils/graph_utils.h" #include "graph/utils/node_utils.h" -using domi::CONSTANT; -using domi::VARIABLE; -using domi::VARISINITIALIZEDOP; - namespace ge { namespace { const int kAssignVarRefIndex = 0; @@ -284,7 +280,7 @@ bool VarIsInitializedOpPass::IsVarInitedOnTheGraphAndNode(const NodePtr &node, i Status VarIsInitializedOpPass::CheckAndSetVarInited(const NodePtr &node, bool &inited, int64_t &inited_var) { GE_CHECK_NOTNULL(node); inited = false; - if (node->GetType() != domi::ASSIGN) { + if (node->GetType() != ASSIGN) { return SUCCESS; } auto ref_in_anchor = node->GetInDataAnchor(kAssignVarRefIndex); diff --git a/src/ge/graph/passes/variable_format_pass.cc b/src/ge/graph/passes/variable_format_pass.cc index 302011fe..28f6a4f7 100644 --- a/src/ge/graph/passes/variable_format_pass.cc +++ b/src/ge/graph/passes/variable_format_pass.cc @@ -26,7 +26,7 @@ Status VariableFormatPass::Run(ge::ComputeGraphPtr graph) { 
for (auto &node : graph->GetDirectNode()) { GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue); - GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != domi::VARIABLE, continue); + GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != VARIABLE, continue); ge::NodePtr use_node = nullptr; if (GetApplyMomentumOpByVariableInput(node, use_node)) { @@ -79,7 +79,7 @@ Status VariableFormatPass::UpdateVariableOutFormat(const ge::NodePtr &var_node, NodePtr in_node = use_node->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); if (in_node != nullptr) { string in_op_type = in_node->GetType(); - if ((in_op_type == domi::VARIABLE) && (in_node->GetOpDesc() != nullptr) && + if ((in_op_type == VARIABLE) && (in_node->GetOpDesc() != nullptr) && (in_node->GetOpDesc()->MutableOutputDesc(0) != nullptr)) { ge::Format format = in_node->GetOpDesc()->MutableOutputDesc(0)->GetFormat(); ge::OpDescPtr cur_op_desc_ptr = var_node->GetOpDesc(); @@ -104,7 +104,7 @@ Status VariableFormatPass::UpdateApplyMomentumInputFormat(const ge::NodePtr &nod NodePtr in_node = node->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); if (in_node != nullptr) { string in_op_type = in_node->GetType(); - if ((in_op_type == domi::VARIABLE) && (in_node->GetOpDesc() != nullptr)) { + if ((in_op_type == VARIABLE) && (in_node->GetOpDesc() != nullptr)) { ge::Format format = in_node->GetOpDesc()->MutableOutputDesc(0)->GetFormat(); op_desc_ptr->MutableInputDesc(0)->SetFormat(format); op_desc_ptr->MutableInputDesc(0)->SetOriginFormat(format); diff --git a/src/ge/graph/passes/variable_op_pass.cc b/src/ge/graph/passes/variable_op_pass.cc index 302598da..d5dedbdc 100644 --- a/src/ge/graph/passes/variable_op_pass.cc +++ b/src/ge/graph/passes/variable_op_pass.cc @@ -18,15 +18,15 @@ #include #include +#include "common/formats/formats.h" +#include "common/formats/utils/formats_trans_utils.h" #include "framework/common/debug/ge_log.h" +#include "graph/ge_context.h" #include "graph/graph.h" #include 
"graph/manager/graph_var_manager.h" #include "graph/utils/graph_utils.h" #include "graph/utils/tensor_utils.h" #include "graph/utils/type_utils.h" -#include "common/formats/formats.h" -#include "common/formats/utils/formats_trans_utils.h" -#include "graph/ge_context.h" namespace ge { namespace { @@ -91,9 +91,12 @@ Status ByPassTransNode(NodePtr &trans_node, NodePtr &ref_node) { } bool IsTransSupport(const TransNodeInfo &trans_info) { - if (trans_info.node_type == domi::RESHAPE || trans_info.node_type == domi::REFORMAT) { + if (trans_info.output.GetShape().IsUnknownShape()) { + return false; + } + if (trans_info.node_type == RESHAPE || trans_info.node_type == REFORMAT) { return true; - } else if (trans_info.node_type == domi::TRANSDATA) { + } else if (trans_info.node_type == TRANSDATA) { formats::TransArgs args{nullptr, trans_info.input.GetFormat(), trans_info.output.GetFormat(), @@ -101,7 +104,7 @@ bool IsTransSupport(const TransNodeInfo &trans_info) { trans_info.output.GetShape().GetDims(), trans_info.input.GetDataType()}; return formats::IsTransFormatSupport(args); - } else if (trans_info.node_type == domi::CAST) { + } else if (trans_info.node_type == CAST) { formats::CastArgs datatype_args{nullptr, static_cast(trans_info.input.GetShape().GetShapeSize()), trans_info.input.GetDataType(), trans_info.output.GetDataType()}; return formats::IsTransDataTypeSupport(datatype_args); @@ -422,12 +425,12 @@ Status VariableOpPass::GenerateVariableVariableRefMap(const ComputeGraphPtr &com std::map names_to_var; std::map> names_to_refs; GE_CHECK_NOTNULL(compute_graph); - for (auto &node : compute_graph->GetAllNodes()) { - if (node->GetType() != domi::VARIABLE) { + for (auto &node : compute_graph->GetDirectNode()) { + if (node->GetType() != VARIABLE) { continue; } std::string ref_var_name; - if (!ge::AttrUtils::GetStr(node->GetOpDesc(), domi::REF_VAR_SRC_VAR_NAME, ref_var_name)) { + if (!ge::AttrUtils::GetStr(node->GetOpDesc(), REF_VAR_SRC_VAR_NAME, ref_var_name)) { 
names_to_var[node->GetName()] = node; } else { names_to_refs[ref_var_name].insert(node); @@ -583,8 +586,8 @@ Status VariableOpPass::RenewVarDesc(ge::ComputeGraphPtr &graph) { // renew var manager desc Status ret = SUCCESS; for (auto &node : graph->GetDirectNode()) { - bool is_var_node = (node->GetType() == domi::VARIABLE) || (node->GetType() == domi::VARIABLEV2) || - (node->GetType() == domi::VARHANDLEOP); + bool is_var_node = + (node->GetType() == VARIABLE) || (node->GetType() == VARIABLEV2) || (node->GetType() == VARHANDLEOP); if (is_var_node) { if (!ge::VarManager::Instance(graph->GetSessionID())->IsVarExist(node->GetName())) { GELOGD("var manager does not exist var node[%s]", node->GetName().c_str()); diff --git a/src/ge/graph/passes/variable_prepare_op_pass.cc b/src/ge/graph/passes/variable_prepare_op_pass.cc index 981e1235..4db78a46 100644 --- a/src/ge/graph/passes/variable_prepare_op_pass.cc +++ b/src/ge/graph/passes/variable_prepare_op_pass.cc @@ -22,16 +22,14 @@ #include "common/ge/ge_util.h" #include "external/graph/graph.h" #include "framework/common/debug/ge_log.h" +#include "graph/common/omg_util.h" #include "graph/debug/ge_attr_define.h" #include "graph/node.h" #include "graph/utils/tensor_utils.h" -using domi::ASSIGN; -using domi::REF_VAR_SRC_VAR_NAME; -using domi::VAR_ATTR_VAR_OUT_INDEX; -using domi::VARIABLE; - namespace ge { +std::map> VariablePrepareOpPass::ref_node_without_prototype_map_{ + {REFSWITCH, {{0, 0}, {0, 1}}}}; Status VariablePrepareOpPass::Run(ComputeGraphPtr graph) { GE_CHECK_NOTNULL(graph); for (const auto &node : graph->GetDirectNode()) { @@ -48,9 +46,7 @@ Status VariablePrepareOpPass::Run(ComputeGraphPtr graph) { for (auto &node : graph->GetDirectNode()) { GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue); - bool is_variable = node->GetOpDesc()->GetType() == VARIABLE; - bool is_deal = has_dealed_variable_.find(node->GetName()) == has_dealed_variable_.end(); - if (is_variable && is_deal) { + if (node->GetOpDesc()->GetType() 
== VARIABLE) { Status ret = DealVariableNode(node); if (ret != SUCCESS) { GELOGE(ret, "variable add back edge failed"); @@ -63,7 +59,7 @@ Status VariablePrepareOpPass::Run(ComputeGraphPtr graph) { GELOGI("ref type:[ %s ]", iter->first.c_str()); auto index_map = iter->second; for (auto index_iter = index_map.begin(); index_iter != index_map.end(); ++index_iter) { - GELOGI("{ %d:%d }", index_iter->first, index_iter->second); + GELOGI("{ %d : %d }", index_iter->first, index_iter->second); } } @@ -154,7 +150,7 @@ NodePtr VariablePrepareOpPass::GetFinalWritableNode(ge::NodePtr &writable_node, } } if (!found_writeable_node) { - GELOGI("final writable node is %s", current_node->GetName().c_str()); + GELOGD("final writable node is %s", current_node->GetName().c_str()); return current_node; } } @@ -164,53 +160,68 @@ Status VariablePrepareOpPass::AddVariableRef(ge::NodePtr &final_writable_node, g GE_CHECK_NOTNULL(final_writable_node); GE_CHECK_NOTNULL(var_node); - NodePtr var_ref_node = CreatVariableRef(final_writable_node, var_node); - GE_CHECK_NOTNULL(var_ref_node); - // add control anchor between var_ref_node and final peer node - // var_ref_node need to execute before other nodes + if (final_writable_node->GetType() == FRAMEWORKOP) { + GELOGD("No need to add variable_ref for frameworkop"); + return SUCCESS; + } + // Check for duplicate creation + OutDataAnchorPtr out_anchor = final_writable_node->GetOutDataAnchor(index); + GE_CHECK_NOTNULL(out_anchor); + for (const auto &peer_anchor : out_anchor->GetPeerAnchors()) { + NodePtr peer_node = peer_anchor->GetOwnerNode(); + OpDescPtr peer_opdesc = peer_node->GetOpDesc(); + GE_CHECK_NOTNULL(peer_opdesc); + string src_var_name; + (void)ge::AttrUtils::GetStr(peer_opdesc, REF_VAR_SRC_VAR_NAME, src_var_name); + if (peer_node->GetType() == VARIABLE && var_node->GetName() == src_var_name) { + GELOGI("The corresponding variable_ref has been added to this connection."); + return SUCCESS; + } + } + // creat variable_ref + 
std::stringstream variable_ref_name; + variable_ref_name << "_TO_" << final_writable_node->GetName() << "_REF_" << index; + NodePtr variable_ref_node = CreatVariableRef(var_node->GetName() + variable_ref_name.str(), var_node); + Status ret_check = CheckStreamLabel(variable_ref_node, final_writable_node); + if (ret_check != SUCCESS) { + GELOGE(FAILED, "check stream lable failed"); + return FAILED; + } + + GELOGI("Add variable_ref between [%s] and [%s]", var_node->GetName().c_str(), variable_ref_node->GetName().c_str()); + GE_CHECK_NOTNULL(variable_ref_node); + // add control anchor between variable_ref and final peer node + // variable_ref_node need to execute before other nodes auto final_writable_outAnchors = final_writable_node->GetAllOutAnchors(); for (auto &final_writable_outAnchor : final_writable_outAnchors) { GE_CHECK_NOTNULL(final_writable_outAnchor); for (auto &final_writable_peerAnchor : final_writable_outAnchor->GetPeerAnchors()) { GE_CHECK_NOTNULL(final_writable_peerAnchor); NodePtr peer_node = final_writable_peerAnchor->GetOwnerNode(); - graphStatus ret = ge::GraphUtils::AddEdge(var_ref_node->GetOutControlAnchor(), peer_node->GetInControlAnchor()); + graphStatus ret = + ge::GraphUtils::AddEdge(variable_ref_node->GetOutControlAnchor(), peer_node->GetInControlAnchor()); if (ret != GRAPH_SUCCESS) { - GELOGE(FAILED, "add control anchor between var_ref_node and final_writable peer_node failed"); + GELOGE(FAILED, "add control anchor between variable_ref and final_writable peer node failed"); return FAILED; } } } - // add edge final node:index ---> var_ref_node:0 - graphStatus ret = - ge::GraphUtils::AddEdge(final_writable_node->GetOutDataAnchor(index), var_ref_node->GetInDataAnchor(0)); + graphStatus ret = ge::GraphUtils::AddEdge(out_anchor, variable_ref_node->GetInDataAnchor(0)); if (ret != GRAPH_SUCCESS) { - GELOGE(FAILED, "add data anchor between var_ref_node and final_writable peer_node failed"); + GELOGE(FAILED, "add data anchor between variable_ref and 
final_writable peer node failed"); return FAILED; } return SUCCESS; } -ge::NodePtr VariablePrepareOpPass::CreatVariableRef(ge::NodePtr &final_writable_node, ge::NodePtr &var_node) { - if ((final_writable_node == nullptr) || (var_node == nullptr) || (var_node->GetOwnerComputeGraph() == nullptr)) { - GELOGE(FAILED, "parameter ptr is null."); - return nullptr; - } - GELOGD("Create VarRef Op: final_writable_node: [%s] var_node: [%s]>>>>", final_writable_node->GetName().c_str(), - var_node->GetName().c_str()); - - static uint32_t var_ref_count = 0; - std::stringstream var_ref_name; - var_ref_name << "_to_" << final_writable_node->GetName() << "_REF_" << var_ref_count++; - +ge::NodePtr VariablePrepareOpPass::CreatVariableRef(const std::string &variable_ref_name, ge::NodePtr &var_node) { OpDescPtr var_op_desc = var_node->GetOpDesc(); if (var_op_desc == nullptr) { GELOGE(FAILED, "get var opdesc is nullptr"); return nullptr; } - OpDescPtr var_ref_op_desc = - MakeShared(var_node->GetName() + var_ref_name.str().c_str(), var_op_desc->GetType()); + OpDescPtr var_ref_op_desc = MakeShared(variable_ref_name.c_str(), var_op_desc->GetType()); if (var_ref_op_desc == nullptr) { GELOGE(FAILED, "var_ref opdesc is nullptr"); return nullptr; @@ -222,15 +233,15 @@ ge::NodePtr VariablePrepareOpPass::CreatVariableRef(ge::NodePtr &final_writable_ GE_IF_BOOL_EXEC(var_ref_op_desc->AddInputDesc(var_op_desc->GetOutputDesc(0)) != SUCCESS, GELOGW("add input desc edge failed"); return nullptr); - NodePtr var_ref_node = var_node->GetOwnerComputeGraph()->AddNode(var_ref_op_desc); - GE_IF_BOOL_EXEC(var_ref_node == nullptr, GELOGW("var_ref_node is null"); return nullptr); - has_dealed_variable_.insert(var_node->GetName()); + NodePtr variable_ref_node = var_node->GetOwnerComputeGraph()->AddNode(var_ref_op_desc); + GE_IF_BOOL_EXEC(variable_ref_node == nullptr, GELOGW("variable_ref_node is null"); return nullptr); bool is_set_str = ge::AttrUtils::SetStr(var_ref_op_desc, REF_VAR_SRC_VAR_NAME, 
var_op_desc->GetName()); if (is_set_str) { - GELOGD("Set node [%s] REF_VAR_SRC_VAR_NAME [%s]", var_ref_node->GetName().c_str(), var_op_desc->GetName().c_str()); + GELOGD("Set node [%s] REF_VAR_SRC_VAR_NAME [%s]", variable_ref_node->GetName().c_str(), + var_op_desc->GetName().c_str()); } - return var_ref_node; + return variable_ref_node; } int VariablePrepareOpPass::GetWritableNodeOutIndex(const NodePtr &node, int input_index) { @@ -239,22 +250,14 @@ int VariablePrepareOpPass::GetWritableNodeOutIndex(const NodePtr &node, int inpu } GELOGD("get writable node and input index %s:%d", node->GetName().c_str(), input_index); auto node_type = node->GetType(); - if (node_type == ASSIGN) { - if (UpdateAssignOpDesc(node) != SUCCESS) { - return -1; - } - } - auto node_iter = ref_input_output_map_.find(node_type); - if (node_iter == ref_input_output_map_.end()) { - return -1; + if (node_type == FRAMEWORKOP) { + std::string original_type; + GE_IF_BOOL_EXEC(GetOriginalType(node, original_type) != SUCCESS, GELOGW("Get node original type fail")); + GELOGD("find frameworkop: [%s], original type is %s", node->GetName().c_str(), original_type.c_str()); + return FindRefOutIndex(original_type, input_index, ref_node_without_prototype_map_); } - - auto index_iter = node_iter->second.find(input_index); - if (index_iter == node_iter->second.end()) { - return -1; - } - return index_iter->second; + return FindRefOutIndex(node_type, input_index, ref_input_output_map_); } void VariablePrepareOpPass::GenerateRefTypeAndInputOutputMap(const NodePtr &node) { @@ -274,7 +277,7 @@ void VariablePrepareOpPass::GenerateRefTypeAndInputOutputMap(const NodePtr &node } auto ref_type_and_input_output_iter = ref_input_output_map_.find(node->GetType()); if (ref_type_and_input_output_iter != ref_input_output_map_.end()) { - auto input_output_index_map = ref_type_and_input_output_iter->second; + auto &input_output_index_map = ref_type_and_input_output_iter->second; if (input_output_index_map.find(input_index) == 
input_output_index_map.end()) { input_output_index_map.emplace(input_index, output_index); GELOGD("Add RefInputOutputMap %s:{ %d, %d }", node->GetType().c_str(), input_index, output_index); @@ -286,23 +289,31 @@ void VariablePrepareOpPass::GenerateRefTypeAndInputOutputMap(const NodePtr &node } } -Status VariablePrepareOpPass::UpdateAssignOpDesc(const ge::NodePtr &node) { - GE_CHECK_NOTNULL(node); - ge::InDataAnchorPtr var_anchor = node->GetInDataAnchor(0); - GE_CHECK_NOTNULL(var_anchor); - GE_CHECK_NOTNULL(var_anchor->GetPeerOutAnchor()); - ge::NodePtr var_node = var_anchor->GetPeerOutAnchor()->GetOwnerNode(); - ge::OpDescPtr var_op_desc = var_node->GetOpDesc(); - GE_CHECK_NOTNULL(var_op_desc); - ge::GeTensorDesc var_tensor_desc = var_op_desc->GetOutputDesc(0); +int VariablePrepareOpPass::FindRefOutIndex(const std::string &node_type, int input_index, + const std::map> &ref_map) { + auto node_iter = ref_map.find(node_type); + if (node_iter == ref_map.end()) { + return -1; + } - ge::OpDescPtr assign_op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(assign_op_desc); - Status update_input_desc_ret = assign_op_desc->UpdateInputDesc(0, var_tensor_desc); - Status update_output_desc_ret = assign_op_desc->UpdateOutputDesc(0, var_tensor_desc); - if (update_input_desc_ret != GRAPH_SUCCESS || update_output_desc_ret != GRAPH_SUCCESS) { - GELOGE(FAILED, "update input or output desc success"); - return FAILED; + auto index_iter = node_iter->second.find(input_index); + if (index_iter == node_iter->second.end()) { + return -1; + } + return index_iter->second; +} + +Status VariablePrepareOpPass::CheckStreamLabel(const ge::NodePtr &var_ref_node, + const ge::NodePtr &final_writable_node) { + // Solve the problem that the writable node is not in the same stream as the subsequent node. + // Causes the stream to not trigger properly. + // The label of node should be handled uniformly. 
+ OpDescPtr writable_desc = final_writable_node->GetOpDesc(); + GE_CHECK_NOTNULL(writable_desc); + std::string stream_label; + (void)AttrUtils::GetStr(writable_desc, ATTR_NAME_STREAM_LABEL, stream_label); + if (!stream_label.empty()) { + GE_CHK_STATUS_RET(SetStreamLabel(var_ref_node, stream_label), "set stream label failed"); } return SUCCESS; } diff --git a/src/ge/graph/passes/variable_prepare_op_pass.h b/src/ge/graph/passes/variable_prepare_op_pass.h index 0fbd311c..c8b9883e 100644 --- a/src/ge/graph/passes/variable_prepare_op_pass.h +++ b/src/ge/graph/passes/variable_prepare_op_pass.h @@ -33,13 +33,15 @@ class VariablePrepareOpPass : public GraphPass { Status DealWritableNode(ge::NodePtr &writable_node, ge::NodePtr &var_node, int out_index); NodePtr GetFinalWritableNode(ge::NodePtr &writable_node, int &out_index); Status AddVariableRef(ge::NodePtr &node, ge::NodePtr &var_node, int index); - NodePtr CreatVariableRef(ge::NodePtr &final_ref_type_node, ge::NodePtr &var_node); + NodePtr CreatVariableRef(const std::string &variable_ref_name, ge::NodePtr &var_node); int GetWritableNodeOutIndex(const NodePtr &node, int input_index); - Status UpdateAssignOpDesc(const ge::NodePtr &node); void GenerateRefTypeAndInputOutputMap(const NodePtr &node); + int FindRefOutIndex(const std::string &node_type, int input_index, + const std::map> &ref_map); + Status CheckStreamLabel(const ge::NodePtr &var_ref_node, const ge::NodePtr &final_writable_node); std::map> ref_input_output_map_; - std::unordered_set has_dealed_variable_{}; + static std::map> ref_node_without_prototype_map_; }; } // namespace ge diff --git a/src/ge/graph/passes/variable_ref_delete_op_pass.cc b/src/ge/graph/passes/variable_ref_delete_op_pass.cc index 1daa6e5c..dfdb8335 100644 --- a/src/ge/graph/passes/variable_ref_delete_op_pass.cc +++ b/src/ge/graph/passes/variable_ref_delete_op_pass.cc @@ -18,10 +18,6 @@ #include #include "framework/common/debug/ge_log.h" -using domi::REF_VAR_PRE_PEER_OUT_INDEX; -using 
domi::REF_VAR_SRC_VAR_NAME; -using domi::VARIABLE; - namespace ge { Status VariableRefDeleteOpPass::Run(ge::ComputeGraphPtr graph) { GE_TIMESTAMP_START(VariableRefDeleteOpPass); @@ -35,8 +31,8 @@ Status VariableRefDeleteOpPass::Run(ge::ComputeGraphPtr graph) { for (auto &node : graph->GetDirectNode()) { GE_CHECK_NOTNULL(node->GetOpDesc()); std::string ref_var_src_var_name; - bool is_variable_ref = (node->GetOpDesc()->GetType() == domi::VARIABLE) && - (ge::AttrUtils::GetStr(node->GetOpDesc(), domi::REF_VAR_SRC_VAR_NAME, ref_var_src_var_name)); + bool is_variable_ref = (node->GetOpDesc()->GetType() == VARIABLE) && + (ge::AttrUtils::GetStr(node->GetOpDesc(), REF_VAR_SRC_VAR_NAME, ref_var_src_var_name)); if (!is_variable_ref) { continue; } @@ -58,13 +54,7 @@ Status VariableRefDeleteOpPass::Run(ge::ComputeGraphPtr graph) { Status VariableRefDeleteOpPass::DealVariableRef(ge::ComputeGraphPtr &graph, ge::NodePtr &variable_ref, const std::string &ref_var_src_var_name) { - GE_CHECK_NOTNULL(graph); GE_CHECK_NOTNULL(variable_ref); - // remove variable_ref all out anchor - for (auto &variable_ref_outAnchor : variable_ref->GetAllOutAnchors()) { - variable_ref_outAnchor->UnlinkAll(); - } - auto inAnchor0 = variable_ref->GetInDataAnchor(0); if (inAnchor0 == nullptr) { GELOGE(FAILED, "variable_ref [%s] no input", variable_ref->GetName().c_str()); @@ -78,30 +68,34 @@ Status VariableRefDeleteOpPass::DealVariableRef(ge::ComputeGraphPtr &graph, ge:: // get previous node of variable_ref NodePtr peer_node = inAnchor0->GetPeerOutAnchor()->GetOwnerNode(); - // remove in anchor [0] of variable_ref - inAnchor0->UnlinkAll(); - if (ge::GraphUtils::RemoveJustNode(graph, variable_ref) != GRAPH_SUCCESS) { - GELOGE(FAILED, "remove variable_ref failed"); - return FAILED; - } - // add attr [REF_VAR_SRC_VAR_NAME] to the previous node of the variable_ref GE_CHECK_NOTNULL(peer_node->GetOpDesc()); - bool is_set_str = ge::AttrUtils::SetStr(peer_node->GetOpDesc(), domi::REF_VAR_SRC_VAR_NAME, 
ref_var_src_var_name); + bool is_set_str = ge::AttrUtils::SetStr(peer_node->GetOpDesc(), REF_VAR_SRC_VAR_NAME, ref_var_src_var_name); - ge::NodePtr var_ref_src_var = graph->FindNode(ref_var_src_var_name); - if (var_ref_src_var == nullptr) { - GELOGE(FAILED, "get var_ref_src_var failed"); + ge::NodePtr ref_var_src_var = graph->FindNode(ref_var_src_var_name); + if (ref_var_src_var == nullptr) { + GELOGE(FAILED, "get ref_var_src_var failed"); return FAILED; } - GE_CHECK_NOTNULL(var_ref_src_var->GetOpDesc()); - bool is_set_index = ge::AttrUtils::SetInt(var_ref_src_var->GetOpDesc(), domi::REF_VAR_PRE_PEER_OUT_INDEX, index); + GE_CHECK_NOTNULL(ref_var_src_var->GetOpDesc()); + bool is_set_index = ge::AttrUtils::SetInt(ref_var_src_var->GetOpDesc(), REF_VAR_PRE_PEER_OUT_INDEX, index); if (is_set_str && is_set_index) { GELOGI("[%s]: add attr [REF_VAR_SRC_VAR_NAME: %s ] ", peer_node->GetName().c_str(), ref_var_src_var_name.c_str()); - GELOGI("[%s]: add attr [ REF_VAR_PRE_PEER_OUT_INDEX: %d ]", var_ref_src_var->GetName().c_str(), index); + GELOGI("[%s]: add attr [REF_VAR_PRE_PEER_OUT_INDEX: %d]", ref_var_src_var->GetName().c_str(), index); } + // remove variable_ref + if (GraphUtils::IsolateNode(variable_ref, {0}) != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "Isolate removed node: %s, type: %s failed", variable_ref->GetName().c_str(), + variable_ref->GetType().c_str()); + return FAILED; + } + if (GraphUtils::RemoveNodeWithoutRelink(graph, variable_ref) != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "Remove node: %s, type: %s without relink failed", variable_ref->GetName().c_str(), + variable_ref->GetType().c_str()); + return FAILED; + } return SUCCESS; } } // namespace ge diff --git a/src/ge/graph/preprocess/graph_preprocess.cc b/src/ge/graph/preprocess/graph_preprocess.cc index 8447552d..9850ef9b 100644 --- a/src/ge/graph/preprocess/graph_preprocess.cc +++ b/src/ge/graph/preprocess/graph_preprocess.cc @@ -18,6 +18,7 @@ #include #include #include +#include #include 
"common/formats/format_transfers/format_transfer_nchw_nc1hwc0.h" #include "common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.h" #include "common/helper/model_helper.h" @@ -25,6 +26,7 @@ #include "common/op/ge_op_utils.h" #include "framework/common/debug/ge_log.h" #include "graph/common/transop_util.h" +#include "graph/common/ge_call_wrapper.h" #include "graph/debug/ge_attr_define.h" #include "graph/ge_context.h" #include "graph/manager/graph_var_manager.h" @@ -40,19 +42,21 @@ #include "graph/passes/dimension_adjust_pass.h" #include "graph/passes/dimension_compute_pass.h" #include "graph/passes/dropout_pass.h" -#include "graph/passes/end_graph_pass.h" #include "graph/passes/enter_pass.h" #include "graph/passes/flow_ctrl_pass.h" +#include "graph/passes/for_pass.h" #include "graph/passes/get_original_format_pass.h" #include "graph/passes/guarantee_const_pass.h" #include "graph/passes/hccl_memcpy_pass.h" #include "graph/passes/identity_pass.h" +#include "graph/passes/cond_pass.h" #include "graph/passes/infershape_pass.h" #include "graph/passes/iterator_op_pass.h" #include "graph/passes/merge_pass.h" #include "graph/passes/net_output_pass.h" #include "graph/passes/next_iteration_pass.h" #include "graph/passes/no_use_reshape_remove_pass.h" +#include "graph/passes/parallel_concat_start_op_pass.h" #include "graph/passes/placeholder_with_default_pass.h" #include "graph/passes/prevent_gradient_pass.h" #include "graph/passes/print_op_pass.h" @@ -71,6 +75,9 @@ #include "graph/passes/var_is_initialized_op_pass.h" #include "graph/passes/variable_prepare_op_pass.h" #include "graph/passes/common_subexpression_elimination_pass.h" +#include "graph/passes/replace_with_empty_const_pass.h" +#include "graph/passes/subgraph_pass.h" +#include "graph/passes/replace_transshape_pass.h" #include "graph/preprocess/insert_op/util_insert_aipp_op.h" #include "graph/types.h" #include "graph/utils/tensor_utils.h" @@ -80,14 +87,22 @@ #include "multi_batch_copy_graph.h" #include 
"runtime/dev.h" -using domi::AIPP; -using domi::AIPPDATA; -using domi::ASSIGN; -using domi::CAST; -using domi::DATA; -using domi::MULTIPLY; -using domi::StringUtils; -using ge::CheckInt64Uint32MulOverflow; +#include "graph/passes/transop_nearby_allreduce_fusion_pass.h" +#include "graph/passes/reshape_remove_pass.h" +#include "graph/passes/dimension_adjust_pass.h" +#include "graph/passes/identify_reference_pass.h" +#include "graph/passes/link_gen_mask_nodes_pass.h" +#include "graph/passes/permute_pass.h" +#include "graph/passes/same_transdata_breadth_fusion_pass.h" +#include "graph/passes/transop_breadth_fusion_pass.h" +#include "graph/passes/transop_depth_fusion_pass.h" + +#include "graph/passes/transop_without_reshape_fusion_pass.h" +#include "graph/passes/cast_remove_pass.h" +#include "graph/passes/transpose_transdata_pass.h" +#include "graph/passes/variable_op_pass.h" +#include "graph/passes/variable_prepare_op_pass.h" +#include "graph/passes/variable_ref_delete_op_pass.h" namespace ge { namespace { @@ -96,6 +111,8 @@ static std::map output_type_str_to_datatype = { {"UINT16", ge::DT_UINT16}, {"UINT8", ge::DT_UINT8}, {"INT32", ge::DT_INT32}, {"INT64", ge::DT_INT64}, {"UINT32", ge::DT_UINT32}, {"UINT64", ge::DT_UINT64}, {"DOUBLE", ge::DT_DOUBLE}}; +const char *const kMbatchSwitchnName = "mbatch-switch-name"; + OpDescPtr CreateTensorShape(const GeTensorDesc &data_tensor) { GeTensorPtr tensor = MakeShared(); if (tensor == nullptr) { @@ -137,7 +154,7 @@ OpDescPtr CreateTensorShape(const GeTensorDesc &data_tensor) { void AddTransNodeAttr(const std::string &node_type, const GeTensorDesc &input, const GeTensorDesc &output, OpDescPtr &op_desc) { // For format transfer node, the IR definition has src/dst format attrs - if (node_type == domi::TRANSDATA) { + if (node_type == TRANSDATA) { GE_IF_BOOL_EXEC( !AttrUtils::SetStr(op_desc, FORMAT_TRANSFER_SRC_FORMAT, TypeUtils::FormatToSerialString(input.GetFormat())), GELOGW("SetStr FORMAT_TRANSFER_SRC_FORMAT failed");) @@ -146,7 
+163,7 @@ void AddTransNodeAttr(const std::string &node_type, const GeTensorDesc &input, c GELOGW("SetStr FORMAT_TRANSFER_DST_FORMAT failed");) } // For cast node, the IR definition has src/dst attrs - if (node_type == domi::CAST) { + if (node_type == CAST) { GE_IF_BOOL_EXEC(!AttrUtils::SetInt(op_desc, CAST_ATTR_SRCT, static_cast(input.GetDataType())), GELOGW("SetInt CAST_ATTR_SRCT failed");) GE_IF_BOOL_EXEC(!AttrUtils::SetInt(op_desc, CAST_ATTR_DSTT, static_cast(output.GetDataType())), @@ -201,7 +218,7 @@ NodePtr CreateTransNode(const std::string &name, const std::string &node_type, c AddTransNodeAttr(node_type, input, output, op_desc); NodePtr shape_node = nullptr; - if (node_type == domi::RESHAPE) { + if (node_type == RESHAPE) { auto shape_desc = CreateTensorShape(output); if (shape_desc == nullptr) { GELOGE(INTERNAL_ERROR, "Failed to add shape for reshape %s, can not create the shape input", @@ -227,7 +244,7 @@ NodePtr CreateTransNode(const std::string &name, const std::string &node_type, c return nullptr; } - if (node_type == domi::RESHAPE) { + if (node_type == RESHAPE) { if (GraphUtils::AddEdge(shape_node->GetOutDataAnchor(0), trans_node->GetInDataAnchor(1)) != GRAPH_SUCCESS) { GELOGE(INTERNAL_ERROR, "Failed to add shape node for reshape %s, can not add the edge", name.c_str()); return nullptr; @@ -338,6 +355,14 @@ Status RecoverTransRoadForVar(const NodePtr &var, const VarTransRoad &road) { index, iter->node_type.c_str()); return INTERNAL_ERROR; } + // set stream_label + OpDescPtr var_desc = var->GetOpDesc(); + GE_CHECK_NOTNULL(var_desc); + std::string stream_label; + (void)AttrUtils::GetStr(var_desc, ATTR_NAME_STREAM_LABEL, stream_label); + if (!stream_label.empty()) { + GE_CHK_STATUS_RET(SetStreamLabel(last_node, stream_label), "set stream label failed"); + } GE_CHK_BOOL_EXEC((ge::AttrUtils::SetBool(last_node->GetOpDesc(), ge::ATTR_INSERTED_BY_GE, true)), return INTERNAL_ERROR, "Set attr ATTR_INSERTED_BY_GE failed."); GELOGD("Recover trans node %s type %s 
success", trans_name.c_str(), iter->node_type.c_str()); @@ -362,6 +387,15 @@ Status RecoverTransRoadForVarRef(const std::set &nodes, const VarTransR var->GetName().c_str(), index, iter->node_type.c_str()); return INTERNAL_ERROR; } + // set stream_label + OpDescPtr var_desc = var->GetOpDesc(); + GE_CHECK_NOTNULL(var_desc); + std::string stream_label; + (void)AttrUtils::GetStr(var_desc, ATTR_NAME_STREAM_LABEL, stream_label); + if (!stream_label.empty()) { + GE_CHK_STATUS_RET(SetStreamLabel(last_node, stream_label), "set stream label failed"); + } + GE_CHK_BOOL_EXEC((ge::AttrUtils::SetBool(last_node->GetOpDesc(), ge::ATTR_INSERTED_BY_GE, true)), return INTERNAL_ERROR, "Set attr ATTR_INSERTED_BY_GE failed."); } @@ -382,10 +416,10 @@ VarNamesToRefs CollectVarNamesToRefs(const ComputeGraphPtr &graph) { return names_to_refs; } for (auto &node : graph->GetAllNodes()) { - if (node->GetType() != domi::VARIABLE) { + if (node->GetType() != VARIABLE) { continue; } - if (AttrUtils::GetStr(node->GetOpDesc(), domi::REF_VAR_SRC_VAR_NAME, var_name)) { + if (AttrUtils::GetStr(node->GetOpDesc(), REF_VAR_SRC_VAR_NAME, var_name)) { (void)names_to_refs[var_name].insert(node); } } @@ -412,7 +446,7 @@ NodePtr CreateCastOp(const ge::GeShape &shape, const ge::DataType input_data_typ static uint32_t transop_count = 0; std::string name = std::string("cast_node").append(std::to_string(transop_count++)); - GELOGI("create cast op:%s, input datatype:%s, out datatype:%s", name.c_str(), + GELOGI("create cast op:%s, input datatype:%s, out datatype:%s.", name.c_str(), TypeUtils::DataTypeToSerialString(input_data_type).c_str(), TypeUtils::DataTypeToSerialString(output_data_type).c_str()); GeTensorDesc input(shape, format, input_data_type); @@ -427,12 +461,12 @@ NodePtr CreateCastOp(const ge::GeShape &shape, const ge::DataType input_data_typ output.SetOriginDataType(output_data_type); ge::TensorUtils::SetRealDimCnt(output, static_cast(shape.GetDims().size())); - auto cast_node = CreateTransNode(name, 
domi::CAST, input, output, node); + auto cast_node = CreateTransNode(name, CAST, input, output, node); GELOGD("Create cast node success."); return cast_node; } -Status ProcessInputFP16(NodePtr &node_ptr) { +Status ProcessInputFP16(NodePtr &node_ptr, bool &is_dynamic_batch, NodePtr &switchn_node) { GE_CHECK_NOTNULL(node_ptr); auto op_desc = node_ptr->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); @@ -440,35 +474,61 @@ Status ProcessInputFP16(NodePtr &node_ptr) { GE_CHECK_NOTNULL(input); ge::DataType src_dtype = input->GetDataType(); if (src_dtype == DT_FLOAT16) { - GELOGI("The node name %s dtype is fp16", node_ptr->GetName().c_str()); + GELOGI("The node name, %s dtype is fp16", node_ptr->GetName().c_str()); return SUCCESS; } - int64_t desc_shape = input->GetShape().GetShapeSize(); - uint32_t len = 0; - if (!TypeUtils::GetDataTypeLength(DT_FLOAT16, len)) { - GELOGE(INTERNAL_ERROR, "GET FP16 datatype length failed"); - return FAILED; - } - FMK_INT64_UINT32_MULCHECK(desc_shape, len); - int64_t shape_size = desc_shape * len; input->SetDataType(DT_FLOAT16); input->SetOriginDataType(DT_FLOAT16); - ge::TensorUtils::SetSize(*input, shape_size); + int64_t input_shape_size = 0; + int64_t output_shape_size = 0; + ge::graphStatus input_graph_status = ge::TensorUtils::GetTensorSizeInBytes(*input, input_shape_size); + ge::graphStatus output_graph_status = ge::TensorUtils::GetTensorMemorySizeInBytes(*input, output_shape_size); + if (input_graph_status != ge::GRAPH_SUCCESS && output_graph_status != ge::GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "GetTensorSize failed!"); + return FAILED; + } + ge::TensorUtils::SetSize(*input, input_shape_size); const GeTensorDescPtr &output = op_desc->MutableOutputDesc(0); GE_CHECK_NOTNULL(output); output->SetDataType(DT_FLOAT16); output->SetOriginDataType(DT_FLOAT16); - ge::TensorUtils::SetSize(*output, shape_size); + ge::TensorUtils::SetSize(*output, output_shape_size); - NodePtr cast_node = CreateCastOp(output->GetShape(), DT_FLOAT16, src_dtype, 
output->GetFormat(), node_ptr); - GE_CHECK_NOTNULL(cast_node); - OutDataAnchorPtr src_out = node_ptr->GetOutDataAnchor(0); - InDataAnchorPtr cast_in = cast_node->GetInDataAnchor(0); - OutDataAnchorPtr cast_out = cast_node->GetOutDataAnchor(0); - if (AddTransNodeBetweenTwoNodes(src_out, cast_in, cast_out) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "add node between two nodes failed, src name:%s, cast node name:%s.", - node_ptr->GetName().c_str(), cast_node->GetName().c_str()); - return FAILED; + if (!is_dynamic_batch) { + NodePtr cast_node = CreateCastOp(output->GetShape(), DT_FLOAT16, src_dtype, output->GetFormat(), node_ptr); + GE_CHECK_NOTNULL(cast_node); + OutDataAnchorPtr src_out = node_ptr->GetOutDataAnchor(0); + InDataAnchorPtr cast_in = cast_node->GetInDataAnchor(0); + OutDataAnchorPtr cast_out = cast_node->GetOutDataAnchor(0); + if (AddTransNodeBetweenTwoNodes(src_out, cast_in, cast_out) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "add node between two nodes failed, src name:%s, cast node name:%s.", + node_ptr->GetName().c_str(), cast_node->GetName().c_str()); + return FAILED; + } + } else { + auto switchn_op_desc = switchn_node->GetOpDesc(); + GE_CHECK_NOTNULL(switchn_op_desc); + const GeTensorDescPtr &switchn_input = switchn_op_desc->MutableInputDesc(0); + GE_CHECK_NOTNULL(switchn_input); + switchn_input->SetDataType(DT_FLOAT16); + switchn_input->SetOriginDataType(DT_FLOAT16); + for (uint32_t i = 0; i < switchn_node->GetAllOutDataAnchorsSize(); ++i) { + const GeTensorDescPtr &switchn_output = switchn_op_desc->MutableOutputDesc(i); + GE_CHECK_NOTNULL(switchn_output); + switchn_output->SetDataType(DT_FLOAT16); + switchn_output->SetOriginDataType(DT_FLOAT16); + NodePtr cast_node = + CreateCastOp(switchn_output->GetShape(), DT_FLOAT16, src_dtype, switchn_output->GetFormat(), node_ptr); + GE_CHECK_NOTNULL(cast_node); + OutDataAnchorPtr src_out = switchn_node->GetOutDataAnchor(i); + InDataAnchorPtr cast_in = cast_node->GetInDataAnchor(0); + OutDataAnchorPtr cast_out = 
cast_node->GetOutDataAnchor(0); + if (AddTransNodeBetweenTwoNodes(src_out, cast_in, cast_out) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "add node between two nodes failed, src name:%s, cast node name:%s.", + switchn_node->GetName().c_str(), cast_node->GetName().c_str()); + return FAILED; + } + } } return SUCCESS; } @@ -479,7 +539,7 @@ NodePtr CreateTransdataNode(const ge::GeShape &in_shape, const ge::Format input_ // Does not involve multithreading. std::string name = std::string("transdata_node").append(std::to_string(transop_count++)); - GELOGI("create trandata op:%s, input format:%s, out format:%s", name.c_str(), + GELOGI("create trandata op:%s, input format:%s, out format:%s.", name.c_str(), TypeUtils::FormatToSerialString(input_format).c_str(), TypeUtils::FormatToSerialString(output_format).c_str()); GeTensorDesc input(in_shape, input_format, dt); @@ -492,7 +552,7 @@ NodePtr CreateTransdataNode(const ge::GeShape &in_shape, const ge::Format input_ output.SetOriginShape(out_shape); output.SetOriginDataType(dt); - return CreateTransNode(name, domi::TRANSDATA, input, output, node); + return CreateTransNode(name, TRANSDATA, input, output, node); } Status TransferShape2NC1HWC0(Format src_format, const std::vector &src_shape, DataType dt, Format dst_format, @@ -552,7 +612,24 @@ Status ModifyInputFormatAndShape(NodePtr &node_ptr) { return SUCCESS; } -Status ProcessInputNC1HWC0(NodePtr &node_ptr) { +Status ModifyFormatAndShapeForSingleTensor(const GeTensorDescPtr &input_output) { + GE_CHECK_NOTNULL(input_output); + ge::Format old_format = input_output->GetFormat(); + std::vector old_shape = input_output->GetShape().GetDims(); + ge::DataType dt = input_output->GetDataType(); + std::vector dst_shape_dims; + if (TransferShape2NC1HWC0(old_format, old_shape, dt, FORMAT_NC1HWC0, dst_shape_dims) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Trans shape failed"); + return FAILED; + } + input_output->SetFormat(FORMAT_NC1HWC0); + input_output->SetOriginFormat(FORMAT_NC1HWC0); + 
input_output->SetShape(ge::GeShape(dst_shape_dims)); + input_output->SetOriginShape(ge::GeShape(dst_shape_dims)); + return SUCCESS; +} + +Status ProcessInputNC1HWC0(NodePtr &node_ptr, bool &is_dynamic_batch, NodePtr &switchn_node) { GE_CHECK_NOTNULL(node_ptr); auto op_desc = node_ptr->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); @@ -573,16 +650,71 @@ Status ProcessInputNC1HWC0(NodePtr &node_ptr) { GELOGE(INTERNAL_ERROR, "modify format and shape failed"); return FAILED; } + if (!is_dynamic_batch) { + NodePtr trans_node = + CreateTransdataNode(input->GetShape(), FORMAT_NC1HWC0, old_shape, old_format, input->GetDataType(), node_ptr); + GE_CHECK_NOTNULL(trans_node); + OutDataAnchorPtr src_out = node_ptr->GetOutDataAnchor(0); + InDataAnchorPtr trans_in = trans_node->GetInDataAnchor(0); + OutDataAnchorPtr trans_out = trans_node->GetOutDataAnchor(0); + if (AddTransNodeBetweenTwoNodes(src_out, trans_in, trans_out) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "add node between two nodes failed"); + return FAILED; + } + } else { + auto switchn_op_desc = switchn_node->GetOpDesc(); + GE_CHECK_NOTNULL(switchn_op_desc); + const GeTensorDescPtr &switchn_input = switchn_op_desc->MutableInputDesc(0); + if (ModifyFormatAndShapeForSingleTensor(switchn_input) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "modify format and shape failed"); + return FAILED; + } + for (uint32_t i = 0; i < switchn_node->GetAllOutDataAnchorsSize(); ++i) { + const GeTensorDescPtr &switchn_output = switchn_op_desc->MutableOutputDesc(i); + GE_CHECK_NOTNULL(switchn_output); + old_format = switchn_output->GetFormat(); + old_shape = switchn_output->GetShape(); + if (ModifyFormatAndShapeForSingleTensor(switchn_output) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "modify format and shape failed"); + return FAILED; + } + NodePtr trans_node = CreateTransdataNode(switchn_output->GetShape(), FORMAT_NC1HWC0, old_shape, old_format, + switchn_output->GetDataType(), node_ptr); + GE_CHECK_NOTNULL(trans_node); + OutDataAnchorPtr src_out = 
switchn_node->GetOutDataAnchor(i); + InDataAnchorPtr cast_in = trans_node->GetInDataAnchor(0); + OutDataAnchorPtr cast_out = trans_node->GetOutDataAnchor(0); + if (AddTransNodeBetweenTwoNodes(src_out, cast_in, cast_out) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "add node between two nodes failed, src name:%s, cast node name:%s.", + switchn_node->GetName().c_str(), trans_node->GetName().c_str()); + return FAILED; + } + } + } + return SUCCESS; +} - NodePtr trans_node = - CreateTransdataNode(input->GetShape(), FORMAT_NC1HWC0, old_shape, old_format, input->GetDataType(), node_ptr); - GE_CHECK_NOTNULL(trans_node); - OutDataAnchorPtr src_out = node_ptr->GetOutDataAnchor(0); - InDataAnchorPtr trans_in = trans_node->GetInDataAnchor(0); - OutDataAnchorPtr trans_out = trans_node->GetOutDataAnchor(0); - if (AddTransNodeBetweenTwoNodes(src_out, trans_in, trans_out) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "add node between two nodes failed"); - return FAILED; +Status CheckIfDynamicBatchScene(NodePtr &data_node, bool &is_dynamic_batch, NodePtr &switchn_node) { + is_dynamic_batch = false; + std::string related_node_name; + if (AttrUtils::GetStr(data_node->GetOpDesc(), kMbatchSwitchnName, related_node_name)) { + if (related_node_name.empty()) { + GELOGE(INTERNAL_ERROR, "The data node %s has switchn node flag, but the value is empty", + data_node->GetName().c_str()); + return INTERNAL_ERROR; + } + for (const NodePtr &next_node : data_node->GetOutNodes()) { + if (next_node->GetName() == related_node_name) { + switchn_node = next_node; + break; + } + } + if (switchn_node == nullptr) { + GELOGE(INTERNAL_ERROR, "The data node %s has switchn node %s, but can not find it on the graph", + data_node->GetName().c_str(), related_node_name.c_str()); + return INTERNAL_ERROR; + } + is_dynamic_batch = true; } return SUCCESS; } @@ -601,8 +733,14 @@ Status ProcessDataNode(NodePtr &node_ptr) { return FAILED; } } - GELOGI("input_fp16 is found, the node name is %s", node_ptr->GetName().c_str()); - if 
(ProcessInputFP16(node_ptr) != SUCCESS) { + GELOGI("input_fp16 is found, the node name is %s.", node_ptr->GetName().c_str()); + bool is_dynamic_batch = false; + NodePtr switchn_node = nullptr; + if (CheckIfDynamicBatchScene(node_ptr, is_dynamic_batch, switchn_node)) { + GELOGE(INTERNAL_ERROR, "CheckIfDynamicBatchScene failed"); + return FAILED; + } + if (ProcessInputFP16(node_ptr, is_dynamic_batch, switchn_node) != SUCCESS) { GELOGE(INTERNAL_ERROR, "ProcessInputFP16 failed"); return FAILED; } @@ -611,28 +749,371 @@ Status ProcessDataNode(NodePtr &node_ptr) { if (!ge::AttrUtils::GetBool(node_ptr->GetOpDesc(), "input_set_nc1hwc0", set_format) || !set_format) { return SUCCESS; } - GELOGI("The format of node [%s] should be set NC1HWC0", node_ptr->GetName().c_str()); - if (ProcessInputNC1HWC0(node_ptr) != SUCCESS) { + GELOGI("The format of node [%s] should be set NC1HWC0.", node_ptr->GetName().c_str()); + if (ProcessInputNC1HWC0(node_ptr, is_dynamic_batch, switchn_node) != SUCCESS) { GELOGE(INTERNAL_ERROR, "ProcessInputNC1HWC0 failed"); return FAILED; } return SUCCESS; } +bool CheckIfSetOutputType(std::string &output_type, ge::DataType &output_data_type) { + if (output_type_str_to_datatype.find(output_type) != output_type_str_to_datatype.end()) { + output_data_type = output_type_str_to_datatype[output_type]; + return true; + } else { + GELOGI("output_type [%s] is not set or set unexpected", output_type.c_str()); + return false; + } + return false; +} +bool CheckOpType(const NodePtr &node, const std::string type) { + if (node->GetType() == type) { + return true; + } + return false; +} + +Status ProcessFp16Nc1hwc0Dynamic(const OpDescPtr &src_op_desc, NodePtr &node) { + auto merge_out = src_op_desc->MutableOutputDesc(0); + GE_CHECK_NOTNULL(merge_out); + if (ModifyFormatAndShapeForSingleTensor(merge_out) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "modify format and shape failed"); + return FAILED; + } + for (uint32_t i = 0; i < node->GetAllInDataAnchorsSize(); ++i) { + auto 
merge_in = src_op_desc->MutableInputDesc(i); + GE_CHECK_NOTNULL(merge_in); + ge::Format old_format = merge_in->GetFormat(); + ge::GeShape old_shape = merge_in->GetShape(); + if (ModifyFormatAndShapeForSingleTensor(merge_in) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "modify format and shape failed"); + return FAILED; + } + ge::GeShape new_shape = merge_in->GetShape(); + NodePtr trans_node = CreateTransdataNode(old_shape, old_format, new_shape, FORMAT_NC1HWC0, DT_FLOAT16, node); + GE_CHECK_NOTNULL(trans_node); + const InDataAnchorPtr &dst_in_anchor = node->GetInDataAnchor(i); + GE_CHECK_NOTNULL(dst_in_anchor); + const OutDataAnchorPtr &src_out_anchor = dst_in_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(src_out_anchor); + if (GraphUtils::InsertNodeBetweenDataAnchors(src_out_anchor, dst_in_anchor, trans_node) != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "InsertNodeBetweenDataAnchors failed"); + return FAILED; + } + } + return SUCCESS; +} + Status ProcessNetoutputNodeFp16Nc1hwc0(GeTensorDesc &src_desc, const InDataAnchorPtr &in_anchor, GeTensorDescPtr &net_output_input_desc, NodePtr &node) { + bool is_dynamic = CheckOpType(node, MERGE); + auto src_op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(src_op_desc); ge::GeShape src_shape = src_desc.GetShape(); ge::Format src_format = src_desc.GetFormat(); ge::DataType src_dtype = src_desc.GetDataType(); if (src_dtype != DT_FLOAT16) { + if (!is_dynamic) { + auto peer_out = in_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(peer_out); + NodePtr cast_node = CreateCastOp(src_shape, src_dtype, DT_FLOAT16, src_format, node); + GE_CHECK_NOTNULL(cast_node); + if (GraphUtils::InsertNodeBetweenDataAnchors(peer_out, in_anchor, cast_node) != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "InsertNodeBetweenDataAnchors failed"); + return FAILED; + } + } else { + // Update outputdesc + const GeTensorDescPtr &merge_output = src_op_desc->MutableOutputDesc(0); + GE_CHECK_NOTNULL(merge_output); + merge_output->SetDataType(DT_FLOAT16); + 
merge_output->SetOriginDataType(DT_FLOAT16); + // Update input + for (uint32_t i = 0; i < node->GetAllInDataAnchorsSize(); ++i) { + const GeTensorDescPtr &merge_input = src_op_desc->MutableInputDesc(i); + GE_CHECK_NOTNULL(merge_input); + src_shape = merge_input->GetShape(); + src_format = merge_input->GetFormat(); + src_dtype = merge_input->GetDataType(); + merge_input->SetDataType(DT_FLOAT16); + merge_input->SetOriginDataType(DT_FLOAT16); + const InDataAnchorPtr &dst_in_anchor = node->GetInDataAnchor(i); + const OutDataAnchorPtr &src_out_anchor = dst_in_anchor->GetPeerOutAnchor(); + NodePtr cast_node = CreateCastOp(src_shape, src_dtype, DT_FLOAT16, src_format, node); + if (GraphUtils::InsertNodeBetweenDataAnchors(src_out_anchor, dst_in_anchor, cast_node) != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "InsertNodeBetweenDataAnchors failed"); + return FAILED; + } + } + } + net_output_input_desc->SetDataType(DT_FLOAT16); + net_output_input_desc->SetOriginDataType(DT_FLOAT16); + } + if (src_format == FORMAT_NC1HWC0) { + GELOGI("Format is NC1HWC0, no need to transfer"); + return SUCCESS; + } + std::vector dst_shape_dims; + std::vector src_shape_dims = src_shape.GetDims(); + if (TransferShape2NC1HWC0(src_format, src_shape_dims, DT_FLOAT16, FORMAT_NC1HWC0, dst_shape_dims) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Trans shape failed"); + return FAILED; + } + ge::GeShape dst_shape(dst_shape_dims); + net_output_input_desc->SetFormat(FORMAT_NC1HWC0); + net_output_input_desc->SetOriginFormat(FORMAT_NC1HWC0); + net_output_input_desc->SetShape(dst_shape); + net_output_input_desc->SetOriginShape(dst_shape); + if (!is_dynamic) { + NodePtr trans_node = CreateTransdataNode(src_shape, src_format, dst_shape, FORMAT_NC1HWC0, DT_FLOAT16, node); + GE_CHECK_NOTNULL(trans_node); + auto peer_out_new = in_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(peer_out_new); + if (GraphUtils::InsertNodeBetweenDataAnchors(peer_out_new, in_anchor, trans_node) != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, 
"InsertNodeBetweenDataAnchors failed"); + return FAILED; + } + } else { + if (ProcessFp16Nc1hwc0Dynamic(src_op_desc, node) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "ProcessFp16Nc1hwc0Dynamic failed"); + return FAILED; + } + } + return SUCCESS; +} + +Status ProcessOutputDynamic(const NodePtr &src_node, NodePtr &node, ge::DataType &output_data_type) { + OpDescPtr src_op_desc = src_node->GetOpDesc(); + const GeTensorDescPtr &merge_output = src_op_desc->MutableOutputDesc(0); + GE_CHECK_NOTNULL(merge_output); + merge_output->SetDataType(output_data_type); + merge_output->SetOriginDataType(output_data_type); + // Update input + for (uint32_t i = 0; i < src_node->GetAllInDataAnchorsSize(); ++i) { + const GeTensorDescPtr &merge_input = src_op_desc->MutableInputDesc(i); + GE_CHECK_NOTNULL(merge_input); + ge::GeShape src_shape = merge_input->GetShape(); + ge::Format src_format = merge_input->GetFormat(); + ge::DataType src_dtype = merge_input->GetDataType(); + merge_input->SetDataType(output_data_type); + merge_input->SetOriginDataType(output_data_type); + const InDataAnchorPtr &dst_in_anchor = src_node->GetInDataAnchor(i); + GE_CHECK_NOTNULL(dst_in_anchor); + const OutDataAnchorPtr &src_out_anchor = dst_in_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(src_out_anchor); + NodePtr cast_node = CreateCastOp(src_shape, src_dtype, output_data_type, src_format, node); + if (GraphUtils::InsertNodeBetweenDataAnchors(src_out_anchor, dst_in_anchor, cast_node) != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "InsertNodeBetweenDataAnchors failed"); + return FAILED; + } + } + return SUCCESS; +} + +Status ProcessNetoutputNode(NodePtr &node, std::string &output_type) { + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + ge::DataType output_data_type = ge::DT_FLOAT; + bool is_set_output_type = CheckIfSetOutputType(output_type, output_data_type); + + for (const auto &in_anchor : node->GetAllInDataAnchors()) { + auto index = static_cast(in_anchor->GetIdx()); auto peer_out = 
in_anchor->GetPeerOutAnchor(); GE_CHECK_NOTNULL(peer_out); - NodePtr cast_node = CreateCastOp(src_shape, src_dtype, DT_FLOAT16, src_format, node); - GE_CHECK_NOTNULL(cast_node); - if (GraphUtils::InsertNodeBetweenDataAnchors(peer_out, in_anchor, cast_node) != GRAPH_SUCCESS) { - GELOGE(INTERNAL_ERROR, "InsertNodeBetweenDataAnchors failed"); + auto src_index = static_cast(peer_out->GetIdx()); + auto src_node = peer_out->GetOwnerNode(); + GE_CHECK_NOTNULL(src_node); + bool is_dynamic = CheckOpType(src_node, MERGE); + + OpDescPtr src_op_desc = src_node->GetOpDesc(); + GE_CHECK_NOTNULL(src_op_desc); + auto net_output_input_desc = op_desc->MutableInputDesc(index); + GE_CHECK_NOTNULL(net_output_input_desc); + + ge::GeShape src_shape = src_op_desc->GetOutputDesc(src_index).GetShape(); + ge::Format src_format = src_op_desc->GetOutputDesc(src_index).GetFormat(); + ge::DataType src_dtype = src_op_desc->GetOutputDesc(src_index).GetDataType(); + // Update datatype + if (is_set_output_type) { + GELOGI("Enter into process output_type schedule"); + if (src_dtype == output_data_type) { + GELOGI("Data type is same ,no need to transfer."); + continue; + } + if (!is_dynamic) { + NodePtr cast_node = CreateCastOp(src_shape, src_dtype, output_data_type, src_format, node); + if (GraphUtils::InsertNodeBetweenDataAnchors(peer_out, in_anchor, cast_node) != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "InsertNodeBetweenDataAnchors failed"); + return FAILED; + } + } else { + // Update outputdesc + if (ProcessOutputDynamic(src_node, node, output_data_type) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "ProcessOutputDynamic failed"); + return FAILED; + } + } + net_output_input_desc->SetDataType(output_data_type); + net_output_input_desc->SetOriginDataType(output_data_type); + continue; + } + // output_node is not set,check if is_output_adjust_hw_layout is set + bool set_fp16_nc1hwc0 = false; + if (!is_dynamic) { + (void)AttrUtils::GetBool(src_op_desc, "output_set_fp16_nc1hwc0", set_fp16_nc1hwc0); + } else 
{ + // need check dynamic scene, graph structure: node->merge->netoutput + const InDataAnchorPtr &merge_input_anchor = src_node->GetInDataAnchor(0); + GE_CHECK_NOTNULL(merge_input_anchor); + const OutDataAnchorPtr &src_out_anchor = merge_input_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(src_out_anchor); + auto src_merge_node = src_out_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(src_merge_node); + auto src_merge_node_opdesc = src_merge_node->GetOpDesc(); + (void)AttrUtils::GetBool(src_merge_node_opdesc, "output_set_fp16_nc1hwc0", set_fp16_nc1hwc0); + } + if (set_fp16_nc1hwc0) { + GELOGI("Node [%s] should be set FP16 and NC1HWC0", src_op_desc->GetName().c_str()); + if ((src_format != FORMAT_NCHW) && (src_format != FORMAT_NHWC) && (src_format != FORMAT_NC1HWC0)) { + GELOGE(INTERNAL_ERROR, "Format is not one of NCHW, NHWC, NC1HWC0."); + return FAILED; + } + GeTensorDesc src_desc(src_shape, src_format, src_dtype); + if (ProcessNetoutputNodeFp16Nc1hwc0(src_desc, in_anchor, net_output_input_desc, src_node) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Process netoutput fp16 nc1hwc0."); + return FAILED; + } + } + } + return SUCCESS; +} + +Status CheckIfNeedSetNdFormat(const NodePtr &node_ptr) { + auto op = node_ptr->GetOpDesc(); + GE_CHECK_NOTNULL(op); + auto inputDescsPtr = op->GetAllInputsDescPtr(); + auto outputDescsPtr = op->GetAllOutputsDescPtr(); + ge::Format format = ge::FORMAT_ND; + // if user set shape larger than 4, inferformat may set NCHW or NHWC, GE should set ND before FE + // process, otherwise fe will insert transdata. 
+ for (auto &inputDescPtr : inputDescsPtr) { + GE_CHECK_NOTNULL(inputDescPtr); + if ((inputDescPtr->GetShape().GetDims().size() > ge::DIM_DEFAULT_SIZE) && + ((inputDescPtr->GetFormat() == ge::FORMAT_NCHW) || (inputDescPtr->GetFormat() == ge::FORMAT_NHWC))) { + GELOGI("The node inputdesc [%s] format need to be set ND", op->GetName().c_str()); + inputDescPtr->SetFormat(format); + inputDescPtr->SetOriginFormat(format); + } + } + for (auto &outputDescPtr : outputDescsPtr) { + GE_CHECK_NOTNULL(outputDescPtr); + if ((outputDescPtr->GetShape().GetDims().size() > ge::DIM_DEFAULT_SIZE) && + ((outputDescPtr->GetFormat() == ge::FORMAT_NCHW) || (outputDescPtr->GetFormat() == ge::FORMAT_NHWC))) { + GELOGI("The node outputdesc [%s] format need to be set ND", op->GetName().c_str()); + outputDescPtr->SetFormat(format); + outputDescPtr->SetOriginFormat(format); + } + } + return SUCCESS; +} + +// A new function ending in 'DynShape' has been added for the dynamic shape processing. +// In the dynamic shape process, transnode insertion by FE is advanced to the stage of whole +// graph optimization, GE only sets the final data_type/format/shape information for variable, +// data and netoutput, and no longer inserts the transnode. 
+Status ProcessInputFP16DynShape(NodePtr &node_ptr) { + GE_CHECK_NOTNULL(node_ptr); + auto op_desc = node_ptr->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + const GeTensorDescPtr &input = op_desc->MutableInputDesc(0); + GE_CHECK_NOTNULL(input); + ge::DataType src_dtype = input->GetDataType(); + if (src_dtype == DT_FLOAT16) { + GELOGI("The node name, %s dtype is fp16", node_ptr->GetName().c_str()); + return SUCCESS; + } + input->SetDataType(DT_FLOAT16); + input->SetOriginDataType(DT_FLOAT16); + int64_t shape_size = 0; + ge::graphStatus graph_status = ge::TensorUtils::GetTensorMemorySizeInBytes(*input, shape_size); + if (graph_status != ge::GRAPH_SUCCESS) { + GELOGE(graph_status, "GetTensorSizeInBytes failed!"); + return FAILED; + } + ge::TensorUtils::SetSize(*input, shape_size); + const GeTensorDescPtr &output = op_desc->MutableOutputDesc(0); + GE_CHECK_NOTNULL(output); + output->SetDataType(DT_FLOAT16); + output->SetOriginDataType(DT_FLOAT16); + ge::TensorUtils::SetSize(*output, shape_size); + + return SUCCESS; +} + +Status ProcessInputNC1HWC0DynShape(NodePtr &node_ptr) { + GE_CHECK_NOTNULL(node_ptr); + auto op_desc = node_ptr->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + const GeTensorDescPtr &input = op_desc->MutableInputDesc(0); + GE_CHECK_NOTNULL(input); + ge::Format old_format = input->GetFormat(); + ge::GeShape old_shape = input->GetShape(); + bool support = ((old_format == FORMAT_NC1HWC0) || (old_format == FORMAT_NCHW) || (old_format == FORMAT_NHWC)); + if (!support) { + GELOGE(INTERNAL_ERROR, "The format [%s] is unsupported", TypeUtils::FormatToSerialString(old_format).c_str()); + return FAILED; + } + if (old_format == FORMAT_NC1HWC0) { + GELOGI("No need to transfer format"); + return SUCCESS; + } + if (ModifyInputFormatAndShape(node_ptr) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "modify format and shape failed"); + return FAILED; + } + + return SUCCESS; +} + +Status ProcessDataNodeDynShape(NodePtr &node_ptr) { + bool set_fp16 = false; + if 
(!ge::AttrUtils::GetBool(node_ptr->GetOpDesc(), "input_fp16", set_fp16) || !set_fp16) { + return SUCCESS; + } + for (auto const &next_node : node_ptr->GetOutNodes()) { + if (next_node->GetType() == AIPP) { + GELOGE(INTERNAL_ERROR, + "This input node [%s] is linked to aipp, can not be set to fp16," + "please check your atc parma insert_op_conf, input_fp16_nodes.", + node_ptr->GetName().c_str()); return FAILED; } + } + GELOGI("input_fp16 is found, the node name is %s.", node_ptr->GetName().c_str()); + if (ProcessInputFP16DynShape(node_ptr) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "ProcessInputFP16 failed"); + return FAILED; + } + // check if need to set format + bool set_format = false; + if (!ge::AttrUtils::GetBool(node_ptr->GetOpDesc(), "input_set_nc1hwc0", set_format) || !set_format) { + return SUCCESS; + } + GELOGI("The format of node [%s] should be set NC1HWC0.", node_ptr->GetName().c_str()); + if (ProcessInputNC1HWC0DynShape(node_ptr) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "ProcessInputNC1HWC0 failed"); + return FAILED; + } + return SUCCESS; +} + +Status ProcessNetoutputNodeFp16Nc1hwc0DynShape(GeTensorDesc &src_desc, GeTensorDescPtr &net_output_input_desc, + NodePtr &node) { + ge::GeShape src_shape = src_desc.GetShape(); + ge::Format src_format = src_desc.GetFormat(); + ge::DataType src_dtype = src_desc.GetDataType(); + if (src_dtype != DT_FLOAT16) { net_output_input_desc->SetDataType(DT_FLOAT16); net_output_input_desc->SetOriginDataType(DT_FLOAT16); } @@ -647,14 +1128,6 @@ Status ProcessNetoutputNodeFp16Nc1hwc0(GeTensorDesc &src_desc, const InDataAncho return FAILED; } ge::GeShape dst_shape(dst_shape_dims); - NodePtr trans_node = CreateTransdataNode(src_shape, src_format, dst_shape, FORMAT_NC1HWC0, DT_FLOAT16, node); - GE_CHECK_NOTNULL(trans_node); - auto peer_out_new = in_anchor->GetPeerOutAnchor(); - GE_CHECK_NOTNULL(peer_out_new); - if (GraphUtils::InsertNodeBetweenDataAnchors(peer_out_new, in_anchor, trans_node) != GRAPH_SUCCESS) { - GELOGE(INTERNAL_ERROR, 
"InsertNodeBetweenDataAnchors failed"); - return FAILED; - } net_output_input_desc->SetFormat(FORMAT_NC1HWC0); net_output_input_desc->SetOriginFormat(FORMAT_NC1HWC0); net_output_input_desc->SetShape(dst_shape); @@ -662,7 +1135,7 @@ Status ProcessNetoutputNodeFp16Nc1hwc0(GeTensorDesc &src_desc, const InDataAncho return SUCCESS; } -Status ProcessNetoutputNode(NodePtr &node, std::string &output_type) { +Status ProcessNetoutputNodeDynShape(NodePtr &node, std::string &output_type) { auto op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); ge::DataType output_data_type = ge::DT_FLOAT; @@ -685,15 +1158,6 @@ Status ProcessNetoutputNode(NodePtr &node, std::string &output_type) { GE_CHECK_NOTNULL(src_op_desc); auto net_output_input_desc = op_desc->MutableInputDesc(index); GE_CHECK_NOTNULL(net_output_input_desc); - auto net_output_output_desc = op_desc->MutableOutputDesc(index); - GE_CHECK_NOTNULL(net_output_output_desc); - // Update netoutput outputdesc - net_output_output_desc->SetDataType(net_output_input_desc->GetDataType()); - net_output_output_desc->SetOriginDataType(net_output_input_desc->GetDataType()); - net_output_output_desc->SetFormat(net_output_input_desc->GetFormat()); - net_output_output_desc->SetOriginFormat(net_output_input_desc->GetFormat()); - net_output_output_desc->SetShape(net_output_input_desc->GetShape()); - net_output_output_desc->SetOriginShape(net_output_input_desc->GetShape()); ge::GeShape src_shape = src_op_desc->GetOutputDesc(src_index).GetShape(); ge::Format src_format = src_op_desc->GetOutputDesc(src_index).GetFormat(); @@ -705,15 +1169,8 @@ Status ProcessNetoutputNode(NodePtr &node, std::string &output_type) { GELOGI("Data type is same ,no need to transfer."); continue; } - NodePtr cast_node = CreateCastOp(src_shape, src_dtype, output_data_type, src_format, node); - if (GraphUtils::InsertNodeBetweenDataAnchors(peer_out, in_anchor, cast_node) != GRAPH_SUCCESS) { - GELOGE(INTERNAL_ERROR, "InsertNodeBetweenDataAnchors failed"); - return 
FAILED; - } net_output_input_desc->SetDataType(output_data_type); net_output_input_desc->SetOriginDataType(output_data_type); - net_output_output_desc->SetDataType(output_data_type); - net_output_output_desc->SetOriginDataType(output_data_type); continue; } // output_node is not set,check if is_output_adjust_hw_layout is set @@ -726,21 +1183,16 @@ Status ProcessNetoutputNode(NodePtr &node, std::string &output_type) { return FAILED; } GeTensorDesc src_desc(src_shape, src_format, src_dtype); - if (ProcessNetoutputNodeFp16Nc1hwc0(src_desc, in_anchor, net_output_input_desc, node) != SUCCESS) { + if (ProcessNetoutputNodeFp16Nc1hwc0DynShape(src_desc, net_output_input_desc, node) != SUCCESS) { GELOGE(INTERNAL_ERROR, "Process netoutput fp16 nc1hwc0."); return FAILED; } - net_output_output_desc->SetDataType(net_output_input_desc->GetDataType()); - net_output_output_desc->SetOriginDataType(net_output_input_desc->GetDataType()); - net_output_output_desc->SetFormat(net_output_input_desc->GetFormat()); - net_output_output_desc->SetOriginFormat(net_output_input_desc->GetFormat()); - net_output_output_desc->SetShape(net_output_input_desc->GetShape()); - net_output_output_desc->SetOriginShape(net_output_input_desc->GetShape()); } } } return SUCCESS; } + } // namespace GraphPrepare::GraphPrepare() : compute_graph_(nullptr) {} @@ -758,7 +1210,7 @@ Status GraphPrepare::UpdateVariableFormats(ComputeGraphPtr &graph) { if (node == nullptr) { continue; } - if (node->GetType() != domi::VARIABLE) { + if (node->GetType() != VARIABLE) { continue; } auto trans_road = VarManager::Instance(graph->GetSessionID())->GetTransRoad(node->GetName()); @@ -787,8 +1239,48 @@ Status GraphPrepare::UpdateVariableFormats(ComputeGraphPtr &graph) { return SUCCESS; } -void GraphPrepare::SetOptions(const ge::GraphManagerOptions &options) { options_ = options; } - +Status GraphPrepare::UpdateVariableFormatsDynShape(ComputeGraphPtr &graph) { + GE_CHECK_NOTNULL(graph); + auto var_names_to_refs = 
CollectVarNamesToRefs(graph); + for (auto &node : graph->GetAllNodes()) { + if (node == nullptr) { + continue; + } + if (node->GetType() != VARIABLE) { + continue; + } + auto trans_road = VarManager::Instance(graph->GetSessionID())->GetTransRoad(node->GetName()); + if (trans_road == nullptr) { + GELOGD("The variable %s does not have any trans road", node->GetName().c_str()); + continue; + } + + GELOGI("Recover the trans road for var %s reversely", node->GetName().c_str()); + + if (!(trans_road->empty())) { + auto ret = UpdateVarFormats(node, trans_road->rbegin()->output); + if (ret != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Failed to update var formats for var %s", node->GetName().c_str()); + return INTERNAL_ERROR; + } + } + + auto iter = var_names_to_refs.find(node->GetName()); + if (iter != var_names_to_refs.end()) { + for (auto &var : iter->second) { + if (!(trans_road->empty()) && (UpdateVarFormats(var, trans_road->rbegin()->input) != SUCCESS)) { + GELOGE(INTERNAL_ERROR, "Failed to update var formats for ref var %s", var->GetName().c_str()); + return INTERNAL_ERROR; + } + } + } + } + + return SUCCESS; +} + +void GraphPrepare::SetOptions(const ge::GraphManagerOptions &options) { options_ = options; } + Status GraphPrepare::Init(const ge::Graph &graph, uint64_t session_id) { compute_graph_ = GraphUtils::GetComputeGraph(graph); if (compute_graph_ != nullptr) { @@ -799,7 +1291,7 @@ Status GraphPrepare::Init(const ge::Graph &graph, uint64_t session_id) { GELOGE(ret, "RunGraph graph check fail, ret:%u", ret); return ret; } - compute_graph_->TopologicalSorting(); + (void)compute_graph_->TopologicalSorting(); ret = CheckRefOp(); if (ret != SUCCESS) { GELOGE(ret, "RunGraph check ref op fail, ret:%u", ret); @@ -831,9 +1323,12 @@ Status GraphPrepare::CheckGraph() { Status GraphPrepare::CheckRefInputNode(const NodePtr &node, const std::string &input_name, const std::unordered_set &ref_nodes) { - static std::unordered_set acceptable_types = { - domi::VARIABLE, 
domi::VARIABLEV2, domi::VARHANDLEOP, domi::REFSWITCH, - domi::REFMERGE, domi::REFENTER, domi::REFNEXTITERATION, domi::REFEXIT}; + // Acceptable input types should be ref node, variable or Switch operator, which is issued by ME for dynamic + // lossscale and would be optimized in SwitchOpPass. Since ME dont differentiate between RefSwitch and Switch, + // and only issue Switch. + static std::unordered_set acceptable_types = {ge::VARIABLE, ge::VARIABLEV2, ge::VARHANDLEOP, + ge::REFSWITCH, ge::REFMERGE, ge::REFENTER, + ge::REFNEXTITERATION, ge::REFEXIT, ge::SWITCH}; GE_CHECK_NOTNULL(node); const auto &op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); @@ -852,7 +1347,7 @@ Status GraphPrepare::CheckRefInputNode(const NodePtr &node, const std::string &i return SUCCESS; } auto input_type = input_op_desc->GetType(); - if (input_type == domi::FRAMEWORKOP) { + if (input_type == ge::FRAMEWORKOP) { if (!ge::AttrUtils::GetStr(input_op_desc, ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, input_type)) { GELOGE(PARAM_INVALID, "Get original type failed."); return PARAM_INVALID; @@ -944,7 +1439,7 @@ Status GraphPrepare::UpdateInput(const std::vector &user_input) { GE_CHECK_NOTNULL(input_node); OpDescPtr op = input_node->GetOpDesc(); GE_CHECK_NOTNULL(op); - if (op->GetType() == domi::DATA) { + if (op->GetType() == DATA) { GeAttrValue::INT index = 0; if ((!(AttrUtils::GetInt(op, ATTR_NAME_INDEX, index))) || (domi::GetContext().is_dynamic_input)) { GELOGW("Get index from data attr failed"); @@ -977,7 +1472,7 @@ Status GraphPrepare::UpdateInput(const std::vector &user_input) { int64_t desc_shape = desc.GetShape().GetShapeSize(); FMK_INT64_UINT32_MULCHECK(desc_shape, length); int64_t shape_size = desc_shape * length; - GE_IF_BOOL_EXEC(shape_size == 0, shape_size = static_cast(length)); + GE_IF_BOOL_EXEC(shape_size == 0 && desc.GetShape().GetDimNum() == 0, shape_size = static_cast(length)); int64_t size = 0; GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(desc, size) != GRAPH_SUCCESS, 
GELOGE(INTERNAL_ERROR, "TensorUtils GetSize failed"); @@ -1111,6 +1606,10 @@ Status GraphPrepare::OptimizeAfterInfershapeByAtcParams() { GE_RETURN_IF_ERROR(InsertNewOpUtil::Instance().UpdateDataNodeByAipp(compute_graph_)); for (auto &node_ptr : compute_graph_->GetDirectNode()) { GE_CHECK_NOTNULL(node_ptr); + if (CheckIfNeedSetNdFormat(node_ptr) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Set node [%s] format ND failed", node_ptr->GetName().c_str()); + return FAILED; + } if (node_ptr->GetType() == DATA) { if (ProcessDataNode(node_ptr) != SUCCESS) { GELOGE(INTERNAL_ERROR, "Process data node failed"); @@ -1118,7 +1617,7 @@ Status GraphPrepare::OptimizeAfterInfershapeByAtcParams() { } } - if (node_ptr->GetType() == domi::NETOUTPUT) { + if (node_ptr->GetType() == ge::NETOUTPUT) { if (ProcessNetoutputNode(node_ptr, options_.output_datatype) != SUCCESS) { GELOGE(INTERNAL_ERROR, "Process netoutput node failed"); return FAILED; @@ -1193,7 +1692,7 @@ Status GraphPrepare::OptimizeBeforeInfershape() { return SUCCESS; } -void GraphPrepare::SaveOriginalGraphToOmModel() { +Status GraphPrepare::SaveOriginalGraphToOmModel() { if (options_.save_original_model == "true") { ModelHelper model_helper; Status ret = model_helper.SaveOriginalGraphToOmModel(ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph_), @@ -1203,6 +1702,7 @@ void GraphPrepare::SaveOriginalGraphToOmModel() { GELOGW("SaveOriginalGraphToOmModel fail"); } } + return SUCCESS; } Status GraphPrepare::Preprocess(const std::vector &user_input) { @@ -1213,34 +1713,27 @@ Status GraphPrepare::Preprocess(const std::vector &user_input) { return ret; } - ret = CheckUserInput(user_input); + ret = CheckAndUpdateInput(user_input); if (ret != SUCCESS) { GELOGE(ret, "Check user input failed."); return ret; } + GraphUtils::DumpGEGraph(compute_graph_, "after_update_input"); + GraphUtils::DumpGEGraphToOnnx(*compute_graph_, "after_update_input"); - compute_graph_->SetInputSize(user_input.size()); - - ret = UpdateInput(user_input); + 
GEPass ge_passes(compute_graph_); + NamesToPass names_to_passes; + ForPass for_pass; + names_to_passes.emplace_back("ForPass", &for_pass); + GE_TIMESTAMP_START(names_to_passes); + ret = ge_passes.Run(names_to_passes); + GE_TIMESTAMP_END(names_to_passes, "GraphPrepare::ForPass"); if (ret != SUCCESS) { - GELOGE(ret, "UpdateInput fail, ret:%u", ret); + GELOGE(ret, "Run ForPass optimize for preprocess failed, ret:%u.", ret); return ret; } - GraphUtils::DumpGEGraph(compute_graph_, "after_update_input"); - GraphUtils::DumpGEGraphToOnnx(*compute_graph_, "after_update_input"); - if (user_input.size() != 0) { - ret = CheckConstOp(); - if (ret != SUCCESS) { - GELOGE(ret, "CheckConstOp fail, ret:%u", ret); - return ret; - } - } else { - ret = compute_graph_->TopologicalSorting(); - if (ret != SUCCESS) { - GELOGE(ret, "graph prepare error: compute_graph_->Topological Sorting"); - return FAILED; - } - } + GraphUtils::DumpGEGraph(compute_graph_, "after_for_pass"); + GraphUtils::DumpGEGraphToOnnx(*compute_graph_, "after_for_pass"); GE_TIMESTAMP_START(netoutput_process); ret = ProcessNetOutput(); @@ -1306,8 +1799,76 @@ Status GraphPrepare::Preprocess(const std::vector &user_input) { return SUCCESS; } +#define PP_RUN_AND_DUMP(name, func, ...) \ + do { \ + GE_RUN(Prepare, func, __VA_ARGS__); \ + GraphUtils::DumpGEGraph(compute_graph, "PrepareAfter" name); \ + GraphUtils::DumpGEGraphToOnnx(*compute_graph, "PrepareAfter" name); \ + GELOGI("Prepare %s on graph %s success.", name, compute_graph->GetName().c_str()); \ + } while (0) + +#define PP_RUN(name, func, ...) 
\ + do { \ + GE_RUN(Prepare, func, __VA_ARGS__); \ + GELOGI("Prepare %s on graph %s success.", name, compute_graph->GetName().c_str()); \ + } while (0) + +Status GraphPrepare::PrepareDynShape(ConstGraphPtr graph, const std::vector &user_input, + ge::ComputeGraphPtr &compute_graph, uint64_t session_id) { + GE_CHECK_NOTNULL(graph); + GE_CHECK_NOTNULL(compute_graph); + if (options_.train_graph_flag) { + domi::GetContext().train_flag = true; + } + domi::GetContext().type = static_cast(options_.framework_type); + const Graph &const_graph = *graph; + + PP_RUN("Init", Init, const_graph, session_id); + PP_RUN("SetRtContext", SetRtContext, rtContext_t(), RT_CTX_GEN_MODE); + PP_RUN_AND_DUMP("CheckAndUpdateInput", CheckAndUpdateInput, user_input); + PP_RUN_AND_DUMP("ProcessOutput", ProcessNetOutput); + PP_RUN_AND_DUMP("ProcessMultiBatch", multibatch::ProcessMultiBatch, compute_graph_); + PP_RUN_AND_DUMP("InsertAipp", TryDoAipp); + PP_RUN_AND_DUMP("InferFormatAndShape", FormatAndShapeProcess); + PP_RUN_AND_DUMP("ProcessAippStage2", InsertNewOpUtil::Instance().UpdateDataNodeByAipp, compute_graph_); + // todo: return when save mode + PP_RUN("SaveOriginalGraphToOmModel", SaveOriginalGraphToOmModel); + PP_RUN_AND_DUMP("PrepareOptimize", PrepareOptimize); + PP_RUN_AND_DUMP("UpdateInputOutputByUserOptions", UpdateInputOutputByOptions); + PP_RUN_AND_DUMP("UpdateVariableByRunningFormat", UpdateVariableFormatsDynShape, compute_graph_); + + return SUCCESS; +} + +#undef PP_RUN_AND_DUMP +#undef PP_RUN + +Status GraphPrepare::GenerateInfershapeGraph(ConstGraphPtr graph) { + if (graph == nullptr) { + GELOGE(GE_GRAPH_NULL_INPUT, "Input Graph is NULL"); + return GE_GRAPH_NULL_INPUT; + } + const Graph &const_graph = *graph; + Status ret = Init(const_graph, 0); + if (ret != SUCCESS) { + GELOGE(ret, "Init graph_prepare fail, ret:%u", ret); + return ret; + } + GELOGI("Start infershape for dump json."); + GEPass ge_passes(compute_graph_); + NamesToPass names_to_passes; + InferShapePass 
infer_shape_pass; + names_to_passes.emplace_back("InferShapePass", &infer_shape_pass); + ret = ge_passes.Run(names_to_passes); + if (ret != SUCCESS) { + GELOGE(ret, "Run ge_passes infershape for preprocess failed, ret:%u.", ret); + return ret; + } + return SUCCESS; +} + Status GraphPrepare::Prepare(ConstGraphPtr graph, const std::vector &user_input, - ge::ComputeGraphPtr &compute_graph, uint64_t session_id) { + ge::ComputeGraphPtr &compute_graph, VarAccelerateCtrl &var_acc_ctrl, uint64_t session_id) { // train graph flag if (options_.train_graph_flag) { domi::GetContext().train_flag = true; @@ -1325,6 +1886,19 @@ Status GraphPrepare::Prepare(ConstGraphPtr graph, const std::vector &u return ret; } + GraphOptimize graph_optimize; + if (!domi::GetContext().train_flag) { + GraphUtils::DumpGEGraph(compute_graph_, "BeforeOriginalGraphForQuantize"); + GraphUtils::DumpGEGraphToOnnx(*compute_graph_, "BeforeOriginalGraphForQuantize"); + GE_TIMESTAMP_START(OptimizeOriginalGraphForQuantize); + ret = graph_optimize.OptimizeOriginalGraphForQuantize(compute_graph_); + GE_TIMESTAMP_END(OptimizeOriginalGraphForQuantize, "GraphPrepare::OptimizeOriginalGraphForQuantize"); + if (ret != SUCCESS) { + GELOGE(ret, "originalGraph optimize for Quantize Failed"); + return ret; + } + } + GraphUtils::DumpGEGraph(compute_graph_, "BeforePreprocess"); GraphUtils::DumpGEGraphToOnnx(*compute_graph_, "BeforePreprocess"); @@ -1336,7 +1910,6 @@ Status GraphPrepare::Prepare(ConstGraphPtr graph, const std::vector &u return ret; } // OriginalGraph optimize - GraphOptimize graph_optimize; ret = graph_optimize.SetOptions(options_); GE_CHK_STATUS_RET(ret, "Graph optimize initial fail"); if (options_.local_fmk_op_flag) { @@ -1346,25 +1919,28 @@ Status GraphPrepare::Prepare(ConstGraphPtr graph, const std::vector &u GraphUtils::DumpGEGraph(compute_graph_, "Prepare"); GraphUtils::DumpGEGraphToOnnx(*compute_graph_, "Prepare"); - if (!domi::GetContext().train_flag) { - 
GE_TIMESTAMP_START(OptimizeOriginalGraphForQuantize); - ret = graph_optimize.OptimizeOriginalGraphForQuantize(compute_graph_); - GE_TIMESTAMP_END(OptimizeOriginalGraphForQuantize, "GraphPrepare::OptimizeOriginalGraphForQuantize"); - if (ret != SUCCESS) { - GELOGE(ret, "originalGraph optimize for Quantize Failed"); - return ret; - } - } GE_TIMESTAMP_START(OptimizeOriginalGraph); - ret = graph_optimize.OptimizeOriginalGraph(compute_graph_); + const char *buffer_optimize_on = std::getenv("BUFFER_OPTIMIZE_ON"); + if (buffer_optimize_on != nullptr) { + ret = graph_optimize.NewOptimizeOriginalGraph(compute_graph_); + } else { + ret = graph_optimize.OptimizeOriginalGraph(compute_graph_); + } GE_TIMESTAMP_END(OptimizeOriginalGraph, "GraphPrepare::OptimizeOriginalGraph"); + GraphUtils::DumpGEGraph(compute_graph_, "PreProcessOptimizeOriginalGraphAfter"); + GraphUtils::DumpGEGraphToOnnx(*compute_graph_, "PreProcessOptimizeOriginalGraphAfter"); if (ret != SUCCESS) { GELOGE(ret, "originalGraph optimize Failed"); return ret; } GE_TIMESTAMP_START(OptimizeBeforeSubGraph); - ret = OptimizeGraphBeforeSubGraph(); + + if (buffer_optimize_on != nullptr) { + ret = NewOptimizeGraphBeforeSubGraph(var_acc_ctrl); + } else { + ret = OptimizeGraphBeforeSubGraph(); + } GE_TIMESTAMP_END(OptimizeBeforeSubGraph, "GraphPrepare::OptimizeBeforeSubGraph"); if (ret != SUCCESS) { GELOGE(ret, "originalGraph optimize Failed"); @@ -1377,10 +1953,10 @@ Status GraphPrepare::Prepare(ConstGraphPtr graph, const std::vector &u Status GraphPrepare::CheckConstOp() { for (auto &node_ptr : compute_graph_->GetAllNodes()) { GE_CHECK_NOTNULL(node_ptr); - if (node_ptr->GetType() == domi::CONSTANT) { + if (node_ptr->GetType() == CONSTANT) { Status ret = VerifyConstOp(node_ptr); GE_CHK_BOOL_RET_STATUS(ret == SUCCESS, ret, "Const Op Check failed"); - } else if (node_ptr->GetType() == domi::FRAMEWORKOP) { + } else if (node_ptr->GetType() == FRAMEWORKOP) { auto op_desc = node_ptr->GetOpDesc(); if (op_desc == nullptr) { 
GELOGE(PARAM_INVALID, "Get op desc failed"); @@ -1390,7 +1966,7 @@ Status GraphPrepare::CheckConstOp() { GE_IF_BOOL_EXEC(ge::AttrUtils::GetStr(op_desc, ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, original_type), GELOGI("Get FrameWorkOp original type [%s]", original_type.c_str())); GELOGI("original type is %s", original_type.c_str()); - if (original_type == domi::CONSTANT) { + if (original_type == CONSTANT) { Status ret = VerifyConstOp(node_ptr); GE_CHK_BOOL_RET_STATUS(ret == SUCCESS, ret, "Const Op Check failed"); } @@ -1421,9 +1997,17 @@ Status GraphPrepare::VerifyConstOp(const NodePtr &node) { FMK_INT64_UINT32_MULCHECK(shape_size, length); GELOGI("Const real value Size:%zu, op_desc Shape Size:%ld, data_type:%s.", data_size, shape_size * length, TypeUtils::DataTypeToSerialString(data_type).c_str()); - if ((shape_size != 0) || (data_size / length != 1)) { - GE_CHK_BOOL_EXEC(data_size == static_cast(shape_size * length) && data_size != 0, - return GRAPH_PARAM_INVALID, "Const input data size is not equal with tensor desc shape"); + if (shape_size == 0) { + if (ge_tensor_desc.GetShape().GetDims().size() == 0) { + // shape = [], means it's a sclar tensor. + GE_CHK_BOOL_EXEC(data_size / length == 1, return PARAM_INVALID, "Const is invalid scalar tensor."); + } else { + // shape = [x, y, 0,...], means it's a vector tensor that value is []. 
+ GE_CHK_BOOL_EXEC(data_size == 0, return PARAM_INVALID, "Const is invalid vector scalar."); + } + } else { + GE_CHK_BOOL_EXEC(data_size == static_cast(shape_size * length) && data_size != 0, return PARAM_INVALID, + "Const input data size is not equal with tensor desc shape"); } return SUCCESS; } @@ -1439,7 +2023,7 @@ Status GraphPrepare::CheckUserInput(const std::vector &user_input) { OpDescPtr op = input_node->GetOpDesc(); GE_CHECK_NOTNULL(op); node_num++; - if (op->GetType() == domi::DATA || op->GetType() == domi::AIPPDATA) { + if (op->GetType() == DATA || op->GetType() == AIPPDATA) { data_num++; GeAttrValue::INT index = 0; if (!(AttrUtils::GetInt(op, ATTR_NAME_INDEX, index))) { @@ -1453,8 +2037,8 @@ Status GraphPrepare::CheckUserInput(const std::vector &user_input) { GeTensorDesc desc(user_input[index].GetTensorDesc()); for (size_t i = 0; i < desc.GetShape().GetDimNum(); ++i) { - if (desc.GetShape().GetDim(i) <= 0) { - GELOGE(GE_GRAPH_INIT_FAILED, "data dim %zu is not supported, need > 0, real:%ld.", i, + if (desc.GetShape().GetDim(i) < 0) { + GELOGE(GE_GRAPH_INIT_FAILED, "data dim %zu is not supported, need >= 0, real:%ld.", i, desc.GetShape().GetDim(i)); return GE_GRAPH_INIT_FAILED; } @@ -1513,7 +2097,113 @@ Status GraphPrepare::InferShapeForPreprocess() { } return SUCCESS; } +Status GraphPrepare::PrepareOptimize() { + GELOGI("Start optimize for preprocess."); + PassManager original_graph_passes; + // Graph pass + try { + (void)original_graph_passes.AddPass(new ShapeOperateOpRemovePass); + } catch (std::bad_alloc &e) { + GELOGE(INTERNAL_ERROR, "Add pass failed, bad memory allocation occurs."); + return INTERNAL_ERROR; + } + + GE_TIMESTAMP_START(original_graph_passes); + Status ret = original_graph_passes.Run(compute_graph_); + GE_TIMESTAMP_END(original_graph_passes, "GraphPrepare::OriginalGraphPasses"); + if (ret != SUCCESS && ret != NOT_CHANGED) { + GELOGE(ret, "Run graph passes optimize for preprocess failed, ret:%u.", ret); + return ret; + } + // New pass 
+ GEPass ge_passes(compute_graph_); + NamesToPass names_to_passes; + EnterPass enter_pass; + PrintOpPass print_pass; + names_to_passes.emplace_back("EnterPass", &enter_pass); + if (options_.enable_print_op_pass) { + names_to_passes.emplace_back("PrintOpPass", &print_pass); + } + NoUseReshapeRemovePass no_use_reshape_remove_pass; + names_to_passes.emplace_back("NoUseReshapeRemovePass", &no_use_reshape_remove_pass); + DropOutPass dropout_pass; + AssertPass assert_pass; + UnusedConstPass unused_const_pass; + StopGradientPass stop_gradient_pass; + PreventGradientPass prevent_gradient_pass; + PlaceholderWithDefaultPass placeholder_with_default_pass; + GuaranteeConstPass guarantee_const_pass; + VarIsInitializedOpPass var_is_initialized_pass; + IdentityPass identity_pass(false); + SnapshotPass snapshot_pass; + if (!options_.train_graph_flag) { + names_to_passes.emplace_back("DropOutPass", &dropout_pass); + names_to_passes.emplace_back("AssertPass", &assert_pass); + } + names_to_passes.emplace_back("UnusedConstPass", &unused_const_pass); + names_to_passes.emplace_back("StopGradientPass", &stop_gradient_pass); + names_to_passes.emplace_back("PreventGradientPass", &prevent_gradient_pass); + names_to_passes.emplace_back("PlaceholderWithDefaultPass", &placeholder_with_default_pass); + names_to_passes.emplace_back("SnapshotPass", &snapshot_pass); + names_to_passes.emplace_back("GuaranteeConstPass", &guarantee_const_pass); + names_to_passes.emplace_back("VarIsInitializedOpPass", &var_is_initialized_pass); + names_to_passes.emplace_back("IdentityPass", &identity_pass); + GE_TIMESTAMP_START(names_to_passes); + ret = ge_passes.Run(names_to_passes); + GE_TIMESTAMP_END(names_to_passes, "GraphPrepare::NamesToPasses"); + if (ret != SUCCESS) { + GELOGE(ret, "Run ge_passes optimize for preprocess failed, ret:%u.", ret); + return ret; + } + + PassManager graph_pass; + try { + (void)graph_pass.AddPass(new PrunePass); + (void)graph_pass.AddPass(new NextIterationPass); + 
(void)graph_pass.AddPass(new ControlTriggerPass); + (void)graph_pass.AddPass(new SwitchOpPass); + } catch (std::bad_alloc &e) { + GELOGE(INTERNAL_ERROR, "Add pass failed, bad memory allocation occurs."); + return INTERNAL_ERROR; + } + + ret = graph_pass.Run(compute_graph_); + if (ret != SUCCESS && ret != NOT_CHANGED) { + GELOGE(ret, "Run graph passes optimize for preprocess failed, ret:%u.", ret); + return ret; + } + + NamesToPass identity_remove_pass; + GE_TIMESTAMP_START(identity_remove_pass); + IdentityPass identity_force_pass(true); // after SwitchOpPass + identity_remove_pass.emplace_back("IdentityPass", &identity_force_pass); + ret = ge_passes.Run(identity_remove_pass); + GE_TIMESTAMP_END(identity_remove_pass, "GraphPrepare::IdentityRemovePass"); + if (ret != SUCCESS) { + GELOGE(ret, "Run identity remove pass for preprocess failed, ret:%u.", ret); + return ret; + } + // The constant for train is CONSTANTOP, and is CONSTANT for inference. They will be unified in future. + if (options_.train_graph_flag) { + for (ge::NodePtr &n : compute_graph_->GetAllNodes()) { + // This can ensure that n is not a null pointer + if (n->GetOpDesc()->GetType() == CONSTANT) { + n->GetOpDesc()->SetType(CONSTANTOP); + } + } + } + + ret = compute_graph_->TopologicalSorting(); + if (ret != SUCCESS) { + GELOGE(ret, "Graph topological sort failed, ret:%u.", ret); + return ret; + } + + GELOGI("End optimize for preprocess."); + + return SUCCESS; +} Status GraphPrepare::OptimizeForPreprocess() { GELOGI("Start optimize for preprocess."); PassManager original_graph_passes; @@ -1523,6 +2213,7 @@ Status GraphPrepare::OptimizeForPreprocess() { (void)original_graph_passes.AddPass(new VariablePrepareOpPass); (void)original_graph_passes.AddPass(new IteratorOpPass); (void)original_graph_passes.AddPass(new ShapeOperateOpRemovePass); + (void)original_graph_passes.AddPass(new ReplaceTransShapePass); } catch (std::bad_alloc &e) { GELOGE(INTERNAL_ERROR, "Add pass failed, bad memory allocation occurs."); 
return INTERNAL_ERROR; @@ -1540,6 +2231,8 @@ Status GraphPrepare::OptimizeForPreprocess() { NamesToPass names_to_passes; EnterPass enter_pass; names_to_passes.emplace_back("EnterPass", &enter_pass); + CondPass cond_pass; + names_to_passes.emplace_back("CondPass", &cond_pass); AddNPass addn_pass; names_to_passes.emplace_back("AddNPass", &addn_pass); PrintOpPass print_pass; @@ -1570,6 +2263,8 @@ Status GraphPrepare::OptimizeForPreprocess() { names_to_passes.emplace_back("GuaranteeConstPass", &guarantee_const_pass); VarIsInitializedOpPass var_is_initialized_pass; names_to_passes.emplace_back("VarIsInitializedOpPass", &var_is_initialized_pass); + ParallelConcatStartOpPass parallel_concat_start_op_pass; + names_to_passes.emplace_back("ParallelConcatStartOpPass", ¶llel_concat_start_op_pass); IdentityPass identity_pass(false); names_to_passes.emplace_back("IdentityPass", &identity_pass); SwitchPass switch_pass; @@ -1593,8 +2288,7 @@ Status GraphPrepare::OptimizeForPreprocess() { (void)graph_pass.AddPass(new ControlTriggerPass); (void)graph_pass.AddPass(new SwitchOpPass); (void)graph_pass.AddPass(new HcclMemcpyPass); - GE_IF_BOOL_EXEC(options_.train_graph_flag, (void)graph_pass.AddPass(new FlowCtrlPass);) - (void)graph_pass.AddPass(new EndGraphPass); + GE_IF_BOOL_EXEC(options_.train_graph_flag, (void)graph_pass.AddPass(new FlowCtrlPass);); } catch (std::bad_alloc &e) { GELOGE(INTERNAL_ERROR, "Add pass failed, bad memory allocation occurs."); return INTERNAL_ERROR; @@ -1620,8 +2314,8 @@ Status GraphPrepare::OptimizeForPreprocess() { if (options_.train_graph_flag) { for (ge::NodePtr &n : compute_graph_->GetAllNodes()) { // This can ensure that n is not a null pointer - if (n->GetOpDesc()->GetType() == domi::CONSTANT) { - n->GetOpDesc()->SetType(domi::CONSTANTOP); + if (n->GetOpDesc()->GetType() == CONSTANT) { + n->GetOpDesc()->SetType(CONSTANTOP); } } } @@ -1656,8 +2350,91 @@ Status GraphPrepare::ProcessNetOutput() { } return SUCCESS; } + +Status 
GraphPrepare::NewOptimizeGraphBeforeSubGraph(VarAccelerateCtrl &var_acc_ctrl) { + GELOGD("NewOptimizeGraphBeforeSubGraph in"); + PassManager passes; + (void)passes.AddPass(new (std::nothrow) CommonSubexpressionEliminationPass); + auto ret = passes.Run(compute_graph_); + if (ret != SUCCESS) { + GELOGE(ret, "Failed to optimize for graph"); + return ret; + } + + GEPass ge_passes_for_shape(compute_graph_); + NamesToPass names_to_passes_for_shape; + IdentifyReferencePass identify_reference_pass; + names_to_passes_for_shape.emplace_back("IdentifyReferencePass", &identify_reference_pass); + CastRemovePass cast_remove_pass; + names_to_passes_for_shape.emplace_back("CastRemovePass", &cast_remove_pass); + TransposeTransDataPass transpose_transdata_pass; + names_to_passes_for_shape.emplace_back("TransposeTransDataPass", &transpose_transdata_pass); + GE_TIMESTAMP_START(ge_passes_for_shape); + ret = ge_passes_for_shape.Run(names_to_passes_for_shape); + GE_TIMESTAMP_END(ge_passes_for_shape, "GraphManager::GePassesForShape"); + if (ret != SUCCESS) { + GELOGE(ret, "Run ge_passes_for_shape optimize for OptimizeGraphBeforeSubGraph failed, ret:%d.", ret); + return ret; + } + + string options = "default"; + if (GetContext().GetOption("ge.exec.variable_acc", options) != SUCCESS) { + GELOGI("get ge.exec.variable_acc failed. 
set default value."); + } + PassManager pass_manager; + GE_CHK_STATUS_RET(pass_manager.AddPass(new (std::nothrow) PermutePass)) + GE_CHK_STATUS_RET(pass_manager.AddPass(new (std::nothrow) VariablePrepareOpPass)) + GE_IF_BOOL_EXEC(options == "default" || options == "1", GELOGI("turn on variable accelerator"); + GE_CHK_STATUS_RET(pass_manager.AddPass(new (std::nothrow) VariableOpPass(&var_acc_ctrl)))) + GE_CHK_STATUS_RET(pass_manager.AddPass(new (std::nothrow) TransOpDepthFusionPass)) + GE_CHK_STATUS_RET(pass_manager.AddPass(new (std::nothrow) TransOpBreadthFusionPass)) + GE_CHK_STATUS_RET(pass_manager.AddPass(new (std::nothrow) VariableRefDeleteOpPass)) + GE_CHK_STATUS_RET(pass_manager.AddPass(new (std::nothrow) SameTransdataBreadthFusionPass)) + GE_CHK_STATUS_RET(pass_manager.AddPass(new (std::nothrow) TransOpWithoutReshapeFusionPass)) + GE_CHK_STATUS_RET(pass_manager.AddPass(new (std::nothrow) LinkGenMaskNodesPass(options_.stream_max_parallel_num))) + + GE_TIMESTAMP_START(pass_manager); + ret = pass_manager.Run(compute_graph_); + GE_TIMESTAMP_END(pass_manager, "GraphManager::BeforeSubGraph"); + if (ret != SUCCESS && ret != NOT_CHANGED) { + GELOGE(ret, "Run passes after merge sub graph failed, ret:%d.", ret); + return ret; + } + + // add variable attr for hccl broadcast,need to be removed after variable pass online + for (const ge::NodePtr &node : compute_graph_->GetDirectNode()) { + if (node->GetOpDesc()->GetType() != VARIABLE) { + continue; + } + if (IsBroadCastOpData(node)) { + AdjustBroadCastOpData(node); + } + if (IsAssignOpData(node)) { + AdjustAssignOpData(node); + } + } + + NamesToPass names_to_passes; + TransOpNearbyAllreduceFusionPass trans_op_nearby_allreduce_fusion_pass; + names_to_passes.emplace_back("TransOpNearbyAllreduceFusionPass", &trans_op_nearby_allreduce_fusion_pass); + ReshapeRemovePass reshape_remove_pass; + names_to_passes.emplace_back("ReshapeRemovePass", &reshape_remove_pass); + ConstantFoldingPass constant_folding_pass; + 
names_to_passes.emplace_back("ConstantFoldingPass", &constant_folding_pass); + DimensionAdjustPass dimension_adjust_pass; + names_to_passes.emplace_back("DimensionAdjustPass", &dimension_adjust_pass); + GEPass ge_passes(compute_graph_); + ret = ge_passes.Run(names_to_passes); + if (ret != SUCCESS) { + GELOGE(ret, "Failed to optimize for graph"); + return ret; + } + return SUCCESS; +} + Status GraphPrepare::OptimizeGraphBeforeSubGraph() { PassManager passes; + (void)passes.AddPass(new (std::nothrow) VariablePrepareOpPass); (void)passes.AddPass(new (std::nothrow) CommonSubexpressionEliminationPass); auto ret = passes.Run(compute_graph_); if (ret != SUCCESS) { @@ -1665,7 +2442,9 @@ Status GraphPrepare::OptimizeGraphBeforeSubGraph() { return ret; } ConstantFoldingPass constant_folding_pass; + DimensionComputePass dimension_compute_pass; NamesToPass names_to_passes; + names_to_passes.emplace_back("DimensionComputePass", &dimension_compute_pass); names_to_passes.emplace_back("ConstantFoldingPass", &constant_folding_pass); GEPass ge_passes(compute_graph_); ret = ge_passes.Run(names_to_passes); @@ -1675,4 +2454,137 @@ Status GraphPrepare::OptimizeGraphBeforeSubGraph() { } return SUCCESS; } +Status GraphPrepare::CheckAndUpdateInput(const std::vector &user_input) { + auto ret = CheckUserInput(user_input); + if (ret != SUCCESS) { + GELOGE(ret, "Check user input failed."); + return ret; + } + + compute_graph_->SetInputSize(user_input.size()); + + ret = UpdateInput(user_input); + if (ret != SUCCESS) { + GELOGE(ret, "UpdateInput fail, ret:%u", ret); + return ret; + } + if (user_input.size() != 0) { + ret = CheckConstOp(); + if (ret != SUCCESS) { + GELOGE(ret, "CheckConstOp fail, ret:%u", ret); + return ret; + } + } else { + ret = compute_graph_->TopologicalSorting(); + if (ret != SUCCESS) { + GELOGE(ret, "graph prepare error: compute_graph_->Topological Sorting"); + return FAILED; + } + } + return SUCCESS; +} +Status GraphPrepare::UpdateInputOutputByOptions() { + if 
(options_.train_graph_flag) { + GELOGI("This is train mode, no need to do this schedule."); + return SUCCESS; + } + for (auto &node_ptr : compute_graph_->GetDirectNode()) { + GE_CHECK_NOTNULL(node_ptr); + if (CheckIfNeedSetNdFormat(node_ptr) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Set node [%s] format ND failed", node_ptr->GetName().c_str()); + return FAILED; + } + // todo do not insert trans op + if (node_ptr->GetType() == DATA) { + if (ProcessDataNodeDynShape(node_ptr) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Process data node failed"); + return FAILED; + } + } + + if (node_ptr->GetType() == ge::NETOUTPUT) { + if (ProcessNetoutputNodeDynShape(node_ptr, options_.output_datatype) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Process netoutput node failed"); + return FAILED; + } + } + } + return SUCCESS; +} + +bool GraphPrepare::IsBroadCastOpData(const ge::NodePtr &var_node) { + for (auto &out_anchor : var_node->GetAllOutDataAnchors()) { + GE_RT_FALSE_CHECK_NOTNULL(out_anchor); + for (auto &in_anchor : out_anchor->GetPeerInDataAnchors()) { + GE_RT_FALSE_CHECK_NOTNULL(in_anchor); + ge::NodePtr dst_node = in_anchor->GetOwnerNode(); + GE_RT_FALSE_CHECK_NOTNULL(dst_node); + if (dst_node->GetType() == HCOMBROADCAST) { + return true; + } + } + } + return false; +} + +bool GraphPrepare::ConfirmUseOpAndIndexByAnchor(const ge::InDataAnchorPtr &in_anchor, + const map> &confirm_ops, ge::NodePtr &use_node) { + GE_RT_FALSE_CHECK_NOTNULL(in_anchor); + ge::NodePtr dst_node = in_anchor->GetOwnerNode(); + GE_RT_FALSE_CHECK_NOTNULL(dst_node); + ge::OpDescPtr dst_op_desc = dst_node->GetOpDesc(); + GE_RT_FALSE_CHECK_NOTNULL(dst_op_desc); + const string &dst_type = dst_op_desc->GetType(); + int input_index = in_anchor->GetIdx(); + + GELOGD("ConfirmUseOpAndIndex, var name %s, dst_type = %s, input index %d", dst_node->GetName().c_str(), + dst_type.c_str(), input_index); + + if (confirm_ops.count(dst_type) > 0) { + if (confirm_ops.at(dst_type).count(input_index) > 0) { + use_node = dst_node; + 
return true; + } + } + return false; +} + +bool GraphPrepare::ConfirmUseOpAndIndexByNode(const ge::NodePtr &var_node, + const map> &confirm_ops, ge::NodePtr &use_node) { + GE_RT_FALSE_CHECK_NOTNULL(var_node); + for (auto &out_anchor : var_node->GetAllOutDataAnchors()) { + GE_RT_FALSE_CHECK_NOTNULL(out_anchor); + for (auto &in_anchor : out_anchor->GetPeerInDataAnchors()) { + GE_RT_FALSE_CHECK_NOTNULL(in_anchor); + if (ConfirmUseOpAndIndexByAnchor(in_anchor, confirm_ops, use_node)) { + return true; + } + } + } + return false; +} +void GraphPrepare::AdjustBroadCastOpData(const ge::NodePtr &var_node) { + if (!ge::AttrUtils::SetStr(var_node->GetOpDesc(), VAR_ATTR_VAR_IS_BROADCAST, "var_is_restore")) { + GELOGW("set var_is_restore failed"); + } +} + +bool GraphPrepare::IsAssignOpData(const ge::NodePtr &var_node) { + GELOGD("IsAssignOpData var_node %s", var_node->GetName().c_str()); + std::map> assign_ops = {{ASSIGN, {0}}}; + + ge::NodePtr assign_node = nullptr; + if (ConfirmUseOpAndIndexByNode(var_node, assign_ops, assign_node)) { + return true; + } + + return false; +} + +void GraphPrepare::AdjustAssignOpData(const ge::NodePtr &var_node) { + if (!ge::AttrUtils::SetStr(var_node->GetOpDesc(), VAR_ATTR_VAR_IS_RESTORE, "var_is_restore")) { + GELOGW("SetStr var_is_restore failed"); + } +} + } // namespace ge diff --git a/src/ge/graph/preprocess/graph_preprocess.h b/src/ge/graph/preprocess/graph_preprocess.h index 002d45ab..767ef96e 100644 --- a/src/ge/graph/preprocess/graph_preprocess.h +++ b/src/ge/graph/preprocess/graph_preprocess.h @@ -30,6 +30,7 @@ #include "common/util.h" #include "graph/compute_graph.h" #include "graph/manager/graph_manager_utils.h" +#include "graph/manager/util/variable_accelerate_ctrl.h" #include "graph/model.h" #include "graph/node.h" #include "graph/utils/graph_utils.h" @@ -45,8 +46,11 @@ class GraphPrepare { GraphPrepare(const GraphPrepare &in) = delete; GraphPrepare &operator=(const GraphPrepare &in) = delete; Status Prepare(ConstGraphPtr graph, 
const std::vector &user_input, ge::ComputeGraphPtr &compute_graph, - uint64_t session_id = 0); + VarAccelerateCtrl &var_acc_ctrl, uint64_t session_id = 0); + Status PrepareDynShape(ConstGraphPtr graph, const std::vector &user_input, + ge::ComputeGraphPtr &compute_graph, uint64_t session_id = 0); void SetOptions(const GraphManagerOptions &options); + Status GenerateInfershapeGraph(ConstGraphPtr graph); private: Status Init(const ge::Graph &graph, uint64_t session_id = 0); @@ -58,21 +62,40 @@ class GraphPrepare { Status SetRtContext(rtContext_t rt_context, rtCtxMode_t mode); Status AdjustDataOpOutput(const NodePtr &node); Status UpdateInput(const std::vector &user_input); + Status CheckAndUpdateInput(const std::vector &user_input); Status CheckConstOp(); Status VerifyConstOp(const NodePtr &node); Status CheckUserInput(const std::vector &user_input); Status OptimizeForPreprocess(); + Status PrepareOptimize(); Status InferShapeForPreprocess(); Status TryDoAipp(); Status OptimizeAfterInfershapeByAtcParams(); Status UpdateVariableFormats(ComputeGraphPtr &graph); + Status UpdateVariableFormatsDynShape(ComputeGraphPtr &graph); Status FormatAndShapeProcess(); Status ResourcePairProcess(const std::string &action); void ProcessCCEFormat(); Status OptimizeBeforeInfershape(); Status OptimizeGraphBeforeSubGraph(); - void SaveOriginalGraphToOmModel(); + Status NewOptimizeGraphBeforeSubGraph(VarAccelerateCtrl &var_acc_ctrl); + Status SaveOriginalGraphToOmModel(); Status ProcessNetOutput(); + Status UpdateInputOutputByOptions(); + bool IsBroadCastOpData(const ge::NodePtr &var_node); + + void AdjustBroadCastOpData(const ge::NodePtr &var_node); + + bool IsAssignOpData(const ge::NodePtr &var_node); + + void AdjustAssignOpData(const ge::NodePtr &var_node); + + bool ConfirmUseOpAndIndexByAnchor(const ge::InDataAnchorPtr &in_anchor, const map> &confirm_ops, + ge::NodePtr &use_node); + + bool ConfirmUseOpAndIndexByNode(const ge::NodePtr &var_node, const map> &confirm_ops, + ge::NodePtr 
&use_node); + ge::ComputeGraphPtr compute_graph_; GraphManagerOptions options_; }; diff --git a/src/ge/graph/preprocess/insert_op/base_insert_op.h b/src/ge/graph/preprocess/insert_op/base_insert_op.h index f482e34b..355aaae6 100644 --- a/src/ge/graph/preprocess/insert_op/base_insert_op.h +++ b/src/ge/graph/preprocess/insert_op/base_insert_op.h @@ -17,16 +17,16 @@ #ifndef GE_GRAPH_PREPROCESS_INSERT_OP_BASE_INSERT_OP_H_ #define GE_GRAPH_PREPROCESS_INSERT_OP_BASE_INSERT_OP_H_ +#include #include #include #include -#include #include "common/fmk_error_codes.h" #include "common/types.h" #include "framework/common/ge_inner_error_codes.h" #include "graph/compute_graph.h" -#include "proto/om.pb.h" #include "proto/insert_op.pb.h" +#include "proto/om.pb.h" namespace ge { class InsertOpBase { diff --git a/src/ge/graph/preprocess/insert_op/ge_aipp_op.cc b/src/ge/graph/preprocess/insert_op/ge_aipp_op.cc index b14aa4b9..277d711a 100644 --- a/src/ge/graph/preprocess/insert_op/ge_aipp_op.cc +++ b/src/ge/graph/preprocess/insert_op/ge_aipp_op.cc @@ -39,8 +39,6 @@ #include "external/graph/operator_factory.h" #include "base_insert_op.h" -using domi::AIPP; - #define SAVE_AIPP_ATTR(KEY, SAVE_TYPE) \ do { \ (void)aipp_attrs.SetAttr(#KEY, GeAttrValue::CreateFrom(aipp_params_->KEY())); \ @@ -84,7 +82,7 @@ const float DEFAULT_VAR_RECI_CHN = 1.0; namespace ge { namespace { const char *const kMbatchSwitchnName = "mbatch-switch-name"; -const char *const kAippConfigPath = "aipp_config_route"; +const char *const kAippConfigPath = "aipp_config_path"; const char *const kCurrentAippIndex = "current_aipp_index"; const char *const kDynamicAippData = "ascend_dynamic_aipp_data"; const uint64_t kMinTransferShape = 3; @@ -100,13 +98,13 @@ Status GetDataDimN(const ge::NodePtr &data_node, ge::Format format, int64_t &bat batch = 1; return SUCCESS; } - if (shape.size() == domi::DIM_DEFAULT_SIZE) { + if (shape.size() == DIM_DEFAULT_SIZE) { switch (format) { case FORMAT_NCHW: - batch = shape[domi::NCHW_DIM_N]; + 
batch = shape[NCHW_DIM_N]; return SUCCESS; case FORMAT_NHWC: - batch = shape[domi::NHWC_DIM_N]; + batch = shape[NHWC_DIM_N]; return SUCCESS; default: GELOGE(PARAM_INVALID, "Not support data format: %s", TypeUtils::FormatToSerialString(format).c_str()); @@ -210,21 +208,16 @@ NodePtr AippOp::CreateAipp(const ComputeGraphPtr &graph, const OutDataAnchorPtr const std::string &aippConfigPath, const uint32_t &index) { std::string current_name = out_anchor->GetOwnerNode()->GetName() + "_" + std::to_string(out_anchor->GetIdx()) + "_huawei_aipp"; - auto aipp_opdesc_ptr = MakeShared(current_name, domi::AIPP); + auto aipp_opdesc_ptr = MakeShared(current_name, AIPP); if (aipp_opdesc_ptr == nullptr) { GELOGE(OUT_OF_MEMORY, "Failed to alloc aipp desc, name %s", current_name.c_str()); return nullptr; } // Update attributes - GeAttrValue::NamedAttrs aipp_attr; + GeAttrValue::NAMED_ATTRS aipp_attr; ConvertParamToAttr(aipp_attr); - // a useless attr but defined in IR, we use `aipp_config_route` actrually - if (!AttrUtils::SetStr(aipp_opdesc_ptr, "aipp_config_path", "./aipp.cfg")) { - GELOGE(INTERNAL_ERROR, "Set config file path attr for aipp node failed"); - return nullptr; - } - if (!AttrUtils::SetNamedAttrs(aipp_opdesc_ptr, domi::ATTR_NAME_AIPP, aipp_attr)) { + if (!AttrUtils::SetNamedAttrs(aipp_opdesc_ptr, ATTR_NAME_AIPP, aipp_attr)) { GELOGE(INTERNAL_ERROR, "Set name attrs for aipp node failed"); return nullptr; } @@ -291,7 +284,7 @@ domi::AippOpParams::AippMode AippOp::GetAippMode() { return aipp_params_->aipp_m NodePtr AippOp::FindDataByIndex(const ComputeGraphPtr &graph, int rank) { int64_t data_index = 0; for (auto &node : graph->GetDirectNode()) { - if (node->GetType() != domi::DATA) { + if (node->GetType() != DATA) { continue; } // There is no `index` attribute on the `Data` node when compile in inference scene @@ -536,7 +529,7 @@ Status AippOp::GenerateOpDesc(OpDescPtr op_desc) { static int op_idx = 0; 
op_desc->SetName(std::string("aipp_node").append(std::to_string(op_idx++))); - op_desc->SetType(domi::AIPP); + op_desc->SetType(AIPP); // Add two InputDesc, add the second after the first one is added successfully. if ((op_desc->AddInputDesc(GeTensorDesc()) != GRAPH_SUCCESS) || @@ -549,17 +542,17 @@ Status AippOp::GenerateOpDesc(OpDescPtr op_desc) { GELOGE(FAILED, "add output desc failed."); return FAILED; } - GeAttrValue::NamedAttrs aipp_attrs; + GeAttrValue::NAMED_ATTRS aipp_attrs; ConvertParamToAttr(aipp_attrs); - GE_IF_BOOL_EXEC(!AttrUtils::SetNamedAttrs(op_desc, domi::ATTR_NAME_AIPP, aipp_attrs), + GE_IF_BOOL_EXEC(!AttrUtils::SetNamedAttrs(op_desc, ATTR_NAME_AIPP, aipp_attrs), GELOGE(FAILED, "failed to set ATTR_NAME_AIPP"); return FAILED); return SUCCESS; } -void AippOp::ConvertParamToAttr(GeAttrValue::NamedAttrs &aipp_attrs) { +void AippOp::ConvertParamToAttr(GeAttrValue::NAMED_ATTRS &aipp_attrs) { GE_CHECK_NOTNULL_JUST_RETURN(aipp_params_); SAVE_AIPP_ATTR(aipp_mode, GeAttrValue::INT); @@ -654,7 +647,7 @@ Status AippOp::CreateAippData(const ComputeGraphPtr &graph, const NodePtr &aipp_ TensorUtils::SetSize(input_tensor, max_dynamic_aipp_size); // new add aipp_data ops for dynamic aipp param input - OpDescPtr op_desc_ptr_data = MakeShared(kDynamicAippData, domi::AIPPDATA); + OpDescPtr op_desc_ptr_data = MakeShared(kDynamicAippData, AIPPDATA); GE_CHECK_NOTNULL(op_desc_ptr_data); auto stat1 = op_desc_ptr_data->AddInputDesc(input_tensor); diff --git a/src/ge/graph/preprocess/insert_op/ge_aipp_op.h b/src/ge/graph/preprocess/insert_op/ge_aipp_op.h index 61baacfd..d4f916e4 100644 --- a/src/ge/graph/preprocess/insert_op/ge_aipp_op.h +++ b/src/ge/graph/preprocess/insert_op/ge_aipp_op.h @@ -68,7 +68,7 @@ class AippOp : public InsertOpBase { AippOp &operator=(const AippOp &aipp_op); AippOp(const AippOp &aipp_op); - void ConvertParamToAttr(ge::GeAttrValue::NamedAttrs &aipp_attrs); + void ConvertParamToAttr(ge::GeAttrValue::NAMED_ATTRS &aipp_attrs); void 
SetCscDefaultValue(); void SetDtcDefaultValue(); NodePtr FindDataByIndex(const ComputeGraphPtr &graph, int rank); diff --git a/src/ge/graph/preprocess/insert_op/util_insert_aipp_op.cc b/src/ge/graph/preprocess/insert_op/util_insert_aipp_op.cc index 218fc7f7..52799156 100644 --- a/src/ge/graph/preprocess/insert_op/util_insert_aipp_op.cc +++ b/src/ge/graph/preprocess/insert_op/util_insert_aipp_op.cc @@ -18,7 +18,6 @@ #include #include #include "common/ge/ge_util.h" -#include "common/op/attr_define.h" #include "common/op/ge_op_utils.h" #include "common/util.h" #include "framework/common/debug/ge_log.h" @@ -32,18 +31,10 @@ #include "graph/utils/op_desc_utils.h" #include "graph/utils/tensor_utils.h" #include "graph/utils/type_utils.h" -#include "inc/common/dynamic_aipp.h" +#include "common/dynamic_aipp.h" #include "common/formats/utils/formats_trans_utils.h" -using domi::AIPPDATA; using domi::AippOpParams; -using domi::DATA; -using domi::DEFAULT_FORMAT; -using domi::DIM_DEFAULT_SIZE; -using domi::NCHW_DIM_C; -using domi::NCHW_DIM_H; -using domi::NCHW_DIM_N; -using domi::NCHW_DIM_W; namespace ge { namespace { @@ -130,25 +121,24 @@ Status InsertNewOpUtil::CheckGraph(const ComputeGraphPtr &graph) { domi::AippOpParams::AippMode aippMode = domi::AippOpParams::undefined; for (const auto &node : graph->GetDirectNode()) { - if (node->GetType() != domi::DATA) { + if (node->GetType() != DATA) { continue; } - + size_t next_nodes_cnt = 0; std::vector aippNodes; for (const auto &anchor : node->GetAllOutDataAnchors()) { for (const auto &inAnchor : anchor->GetPeerInDataAnchors()) { const std::string &nodeType = inAnchor->GetOwnerNode()->GetType(); - - GE_CHK_BOOL_RET_STATUS(aippNodes.size() == 0 || nodeType == domi::AIPP, PARAM_INVALID, - "Can not config part of outputs of Data node to support AIPP, config all of the " - "outputs of Data to support AIPP, or config none of them"); - - if (nodeType == domi::AIPP) { + next_nodes_cnt++; + if (nodeType == AIPP) { 
aippNodes.push_back(inAnchor->GetOwnerNode()); continue; } } } + GE_CHK_BOOL_RET_STATUS((aippNodes.size() == 0) || (aippNodes.size() == next_nodes_cnt), PARAM_INVALID, + "Can not config part of outputs of Data node to support AIPP, config all " + "of the outputs of Data to support AIPP, or config none of them"); std::unique_ptr aippParams(new (std::nothrow) domi::AippOpParams()); GE_CHECK_NOTNULL(aippParams); @@ -184,10 +174,10 @@ Status InsertNewOpUtil::CheckGraph(const ComputeGraphPtr &graph) { Status InsertNewOpUtil::GetAippParams(const std::unique_ptr &aippParams, const NodePtr &aipp_node) { GE_CHECK_NOTNULL(aipp_node); - ge::GeAttrValue::NamedAttrs aipp_attr; + ge::GeAttrValue::NAMED_ATTRS aipp_attr; const OpDescPtr tmpOpPtr = aipp_node->GetOpDesc(); GE_CHECK_NOTNULL(tmpOpPtr); - GE_CHK_BOOL_RET_STATUS(AttrUtils::GetNamedAttrs(tmpOpPtr, domi::ATTR_NAME_AIPP, aipp_attr), FAILED, + GE_CHK_BOOL_RET_STATUS(AttrUtils::GetNamedAttrs(tmpOpPtr, ATTR_NAME_AIPP, aipp_attr), FAILED, "Aipp node should contain param aipp!"); GE_CHK_STATUS_RET(OpUtils::ConvertAippParams(aipp_attr, aippParams.get()), "get aipp params failed"); @@ -198,13 +188,13 @@ Status InsertNewOpUtil::UpdateDataNodeByAipp(const ComputeGraphPtr &graph) { std::set updated_switchn; for (auto &node : graph->GetDirectNode()) { - if (node->GetType() == domi::DATA) { + if (node->GetType() == DATA) { std::string switchn_name; if (AttrUtils::GetStr(node->GetOpDesc(), kMbatchSwitchnName, switchn_name)) { switchn_names_to_data[switchn_name] = node; } } - if (node->GetType() == domi::AIPP) { + if (node->GetType() == AIPP) { GE_RETURN_IF_ERROR(UpdatePrevNodeByAipp(node, updated_switchn)); } } @@ -253,6 +243,7 @@ Status InsertNewOpUtil::UpdatePrevNodeByAipp(NodePtr &node, std::set &s aipp_input->SetOriginDataType(aipp_dt); DataType aipp_origni_dt = aipp_input->GetOriginDataType(); GeShape aipp_shape = aipp_input->GetShape(); + Format aipp_format = aipp_input->GetFormat(); GELOGI("Aipp [%s] input datatype is %s, origin 
datatype is %s, input shape is %s", aipp_op_desc->GetName().c_str(), TypeUtils::DataTypeToSerialString(aipp_dt).c_str(), TypeUtils::DataTypeToSerialString(aipp_origni_dt).c_str(), ge::formats::ShapeToString(aipp_shape.GetDims()).c_str()); @@ -263,6 +254,8 @@ Status InsertNewOpUtil::UpdatePrevNodeByAipp(NodePtr &node, std::set &s input->SetOriginDataType(aipp_origni_dt); input->SetShape(aipp_shape); input->SetOriginShape(aipp_shape); + input->SetFormat(aipp_format); + input->SetOriginFormat(aipp_format); ge::TensorUtils::SetSize(*input, size); const GeTensorDescPtr &output = src_op->MutableOutputDesc(peer_out_anchor->GetIdx()); @@ -271,8 +264,10 @@ Status InsertNewOpUtil::UpdatePrevNodeByAipp(NodePtr &node, std::set &s output->SetOriginDataType(aipp_origni_dt); output->SetShape(aipp_shape); output->SetOriginShape(aipp_shape); + output->SetFormat(aipp_format); + output->SetOriginFormat(aipp_format); ge::TensorUtils::SetSize(*output, size); - if (src_node->GetType() == domi::SWITCHN) { + if (src_node->GetType() == SWITCHN) { switchns.insert(src_node); } GELOGI("Set node %s output %d size %ld by aipp.", src_node->GetName().c_str(), peer_out_anchor->GetIdx(), size); @@ -304,6 +299,8 @@ Status InsertNewOpUtil::UpdateDataBySwitchN(const NodePtr &switchn, const NodePt input_desc->SetOriginDataType(output_desc->GetOriginDataType()); input_desc->SetShape(output_desc->GetShape()); input_desc->SetOriginShape(output_desc->GetOriginShape()); + input_desc->SetFormat(output_desc->GetFormat()); + input_desc->SetOriginFormat(output_desc->GetOriginFormat()); TensorUtils::SetSize(*input_desc, max_size); auto data_opdesc = data->GetOpDesc(); diff --git a/src/ge/graph/preprocess/multi_batch_copy_graph.cc b/src/ge/graph/preprocess/multi_batch_copy_graph.cc index 9edd1d0a..47d7701f 100644 --- a/src/ge/graph/preprocess/multi_batch_copy_graph.cc +++ b/src/ge/graph/preprocess/multi_batch_copy_graph.cc @@ -46,9 +46,7 @@ const int kMergeDataOutIndex = 0; const size_t kMaxShapesCount = 100; 
const size_t kMinShapesCount = 2; -inline bool IsDataLikeType(const std::string &node_type) { - return (node_type == domi::DATA) || (node_type == domi::AIPP); -} +inline bool IsDataLikeType(const std::string &node_type) { return (node_type == DATA) || (node_type == AIPP); } NodePtr InsertMergeNodeToGraph(const std::string &name, size_t input_num, const ComputeGraphPtr &graph) { OpDescPtr desc = MakeShared(); @@ -57,7 +55,7 @@ NodePtr InsertMergeNodeToGraph(const std::string &name, size_t input_num, const return nullptr; } desc->SetName(name); - desc->SetType(domi::MERGE); + desc->SetType(MERGE); GeTensorDesc tensor_desc; for (size_t i = 0; i < input_num; ++i) { auto ret = desc->AddInputDesc("x" + std::to_string(i), tensor_desc); @@ -144,7 +142,7 @@ Status CalcShape(const std::vector &batch_shape, GeShape &data_shape) { bool IsAllDimsPositive(const std::vector &dims) { for (auto dim : dims) { - if (dim <= 0) { + if (dim < 0) { return false; } } @@ -158,7 +156,7 @@ NodePtr InsertConst(const std::string &name, const ComputeGraphPtr &graph) { return nullptr; } desc->SetName(name); - desc->SetType(domi::CONSTANT); + desc->SetType(CONSTANT); GeTensor tensor; tensor.SetData(std::vector({0})); if (!AttrUtils::SetTensor(desc, ATTR_NAME_WEIGHTS, tensor)) { @@ -178,7 +176,7 @@ NodePtr InsertConst(const std::string &name, const ComputeGraphPtr &graph) { bool IsOnlyOutputToAipp(const NodePtr &node) { for (const auto &out_node : node->GetOutDataNodes()) { - if (out_node->GetType() != domi::AIPP) { + if (out_node->GetType() != AIPP) { return false; } } @@ -188,7 +186,7 @@ bool IsOnlyOutputToAipp(const NodePtr &node) { Status CheckDataShape(const std::vector &nodes) { size_t unknown_shape_count = 0; for (const auto &node : nodes) { - if (node->GetType() != domi::DATA) { + if (node->GetType() != DATA) { continue; } for (auto dim : NodeUtils::GetOutputDesc(*node, kDataOutIndex).GetShape().GetDims()) { @@ -290,7 +288,7 @@ Status MultiBatchGraphCopyer::CreateNewNodes() { return 
SUCCESS; } NodeStatus MultiBatchGraphCopyer::GetNodeStatus(const NodePtr &node) { - if (node->GetType() == domi::NETOUTPUT) { + if (node->GetType() == NETOUTPUT) { return kNodeOutBatchBranch; } if (IsDataLikeType(node->GetType()) && !IsOnlyOutputToAipp(node)) { @@ -429,7 +427,7 @@ NodePtr MultiBatchGraphCopyer::InsertShapeDataNode() { return nullptr; } desc->SetName("ascend_mbatch_shape_data"); - desc->SetType(domi::DATA); + desc->SetType(DATA); GeTensorDesc tensor_desc; tensor_desc.SetFormat(FORMAT_ND); @@ -612,6 +610,8 @@ Status MultiBatchGraphCopyer::UpdateMaxShapeToData(const NodePtr &data) { } Status MultiBatchGraphCopyer::InsertSwitchNForData(const NodePtr &data) { auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape(); + (void)AttrUtils::SetListInt(data->GetOpDesc(), ATTR_MBATCH_ORIGIN_INPUT_DIMS, data_shape.GetDims()); + if (IsAllDimsPositive(data_shape.GetDims())) { GELOGI("The shape of data %s are positive(%s), skip the multi batch process", data->GetName().c_str(), data_shape.ToString().c_str()); @@ -624,7 +624,7 @@ Status MultiBatchGraphCopyer::InsertSwitchNForData(const NodePtr &data) { return OUT_OF_MEMORY; } switchn_desc->SetName(data->GetName() + "_ascend_mbatch_switchn"); - switchn_desc->SetType(domi::SWITCHN); + switchn_desc->SetType(SWITCHN); GeTensorDesc tensor(NodeUtils::GetOutputDesc(*data, kDataOutIndex)); if (switchn_desc->AddInputDesc(tensor) != GRAPH_SUCCESS) { // data @@ -874,7 +874,7 @@ Status ProcessMultiBatch(ComputeGraphPtr &graph) { std::vector> shapes; if (!domi::GetContext().dynamic_batch_size.empty()) { GELOGD("Found dynamic batch option, value %s", domi::GetContext().dynamic_batch_size.c_str()); - std::vector dims = domi::StringUtils::Split(domi::GetContext().dynamic_batch_size, ','); + std::vector dims = ge::StringUtils::Split(domi::GetContext().dynamic_batch_size, ','); for (const auto &dim : dims) { if (dim.empty()) { continue; @@ -885,13 +885,13 @@ Status ProcessMultiBatch(ComputeGraphPtr &graph) { } if 
(!domi::GetContext().dynamic_image_size.empty()) { GELOGD("Found dynamic image size option, value %s", domi::GetContext().dynamic_image_size.c_str()); - std::vector shape_strs = domi::StringUtils::Split(domi::GetContext().dynamic_image_size, ';'); + std::vector shape_strs = ge::StringUtils::Split(domi::GetContext().dynamic_image_size, ';'); for (const auto &shape_str : shape_strs) { if (shape_str.empty()) { continue; } std::vector shape; - std::vector dims = domi::StringUtils::Split(shape_str, ','); + std::vector dims = ge::StringUtils::Split(shape_str, ','); for (const auto &dim : dims) { if (dim.empty()) { continue; diff --git a/src/ge/inc/graph_pass.h b/src/ge/inc/graph_pass.h index fb2a6238..b50269d0 100644 --- a/src/ge/inc/graph_pass.h +++ b/src/ge/inc/graph_pass.h @@ -73,13 +73,13 @@ class GraphPass : public Pass { static bool IsConstNode(const ge::NodePtr &node) { GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, GELOGE(FAILED, "Node GetOpDesc is nullptr"); return false); - if (node->GetOpDesc()->GetType() == domi::CONSTANTOP) { + if (node->GetOpDesc()->GetType() == CONSTANTOP) { return true; - } else if (node->GetOpDesc()->GetType() == domi::FRAMEWORKOP) { + } else if (node->GetOpDesc()->GetType() == FRAMEWORKOP) { string type; - GE_CHK_BOOL_EXEC(ge::AttrUtils::GetStr(node->GetOpDesc(), domi::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type), - return false, "Get original_type for op %s fail!", node->GetName().c_str()); - GE_IF_BOOL_EXEC(type == domi::CONSTANT, GELOGI("Is const op"); return true); + GE_CHK_BOOL_EXEC(ge::AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type), return false, + "Get original_type for op %s fail!", node->GetName().c_str()); + GE_IF_BOOL_EXEC(type == CONSTANT, GELOGI("Is const op"); return true); return false; } else { return false; diff --git a/src/ge/inc/kernel.h b/src/ge/inc/kernel.h index ec0e5e40..9f7e1308 100644 --- a/src/ge/inc/kernel.h +++ b/src/ge/inc/kernel.h @@ -24,9 +24,9 @@ #include "graph/graph.h" #include 
"graph/op_desc.h" -using std::vector; -using std::unique_ptr; using std::shared_ptr; +using std::unique_ptr; +using std::vector; namespace ge { /// @@ -51,20 +51,6 @@ class Kernel { return NOT_CHANGED; } - /// - /// Data description transformation interface - /// @param [in] op_desc_ptr Operator related parameters - /// @param [in] input Data description(include dimensionã€formatã€data type etc.) - /// @param [inout] output save the transformation result - /// @author - /// - virtual Status Compute(const ge::OpDescPtr op_desc_ptr, const ge::GeTensorDescPtr input, ge::GeTensorDescPtr output) { - (void)op_desc_ptr; - (void)input; - (void)output; - return NOT_CHANGED; - } - virtual Status Compute(const NodePtr& node, std::vector& v_output) { (void)node; (void)v_output; diff --git a/src/ge/inc/kernel_factory.h b/src/ge/inc/kernel_factory.h index 8e5912eb..c0624e14 100644 --- a/src/ge/inc/kernel_factory.h +++ b/src/ge/inc/kernel_factory.h @@ -103,5 +103,5 @@ class KernelFactory { return ptr; \ } \ KernelFactory::Registerar g_##type##_Kernel_Creator(type, Creator_##type##_Kernel) -}; // end namespace ge +}; // end namespace ge #endif // GE_INC_KERNEL_FACTORY_H_ diff --git a/src/ge/inc/node_pass.h b/src/ge/inc/node_pass.h deleted file mode 100644 index 4334c50d..00000000 --- a/src/ge/inc/node_pass.h +++ /dev/null @@ -1,66 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef GE_INC_NODE_PASS_H_ -#define GE_INC_NODE_PASS_H_ - -#include -#include "common/op/ge_op_utils.h" -#include "graph/compute_graph.h" -#include "graph/graph.h" -#include "graph/node.h" -#include "graph/op_desc.h" -#include "graph/range_vistor.h" -#include "graph/utils/attr_utils.h" -#include "graph/utils/graph_utils.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/tensor_utils.h" -#include "inc/pass.h" -namespace ge { -/// -/// @ingroup domi_omg -/// @brief node pass -/// @author -/// -class NodePass : public Pass { - public: - /// - /// run node pass - /// @param [in] node node to be optimized - /// @return SUCCESS optimized successfully - /// @return TO_BE_DELETED optimized successfully and the node need to be deleted - /// @return NOT_CHANGED not optimized - /// @return others optimize failed - /// @author - /// - virtual Status Run(ge::NodePtr node) = 0; - - /// Optimize to weight, Set the "is input const" flag of the output node to true - /// @param [in] node node to be optimized - /// @return SUCCESS optimized successfully - /// @return others optimize failed - /// - Status SetOutNodeWeightDef(ge::NodePtr node, std::vector &v_weight); - - /// Update node connection relationship - /// @param [in] node The node to be optimized - /// @return SUCCESS Optimized successfully - /// @return FAILED Optimization failure - /// - Status UpdateNodeInfo(ge::NodePtr node); -}; -} // namespace ge -#endif // GE_INC_NODE_PASS_H_ diff --git a/src/ge/init/gelib.cc b/src/ge/init/gelib.cc index 84ecc506..db12ef79 100644 --- a/src/ge/init/gelib.cc +++ b/src/ge/init/gelib.cc @@ -40,7 +40,6 @@ #include "runtime/kernel.h" using Json = nlohmann::json; -using domi::StringUtils; namespace ge { namespace { @@ -58,10 +57,15 @@ Status GELib::Initialize(const map &options) { GELOGE(GE_CLI_INIT_FAILED, "GeLib initialize failed, malloc shared_ptr failed."); return GE_CLI_INIT_FAILED; } + Status ret = instancePtr_->SetRTSocVersion(options); + if (ret != SUCCESS) { 
+ GELOGE(ret, "GeLib initial failed."); + return ret; + } GetMutableGlobalOptions().insert(options.begin(), options.end()); GetThreadLocalContext().SetGlobalOption(GetMutableGlobalOptions()); GE_TIMESTAMP_START(Init); - Status ret = instancePtr_->InnerInitialize(options); + ret = instancePtr_->InnerInitialize(options); if (ret != SUCCESS) { GELOGE(ret, "GeLib initial failed."); instancePtr_ = nullptr; @@ -143,6 +147,35 @@ Status GELib::InnerInitialize(const map &options) { return SUCCESS; } +void GELib::SetIncreBuild(const map &options) { + auto iter = options.find(OPTION_EXEC_ENABLE_INCRE_BUILD); + if (iter != options.end()) { + const std::string enable_incre_build = "true"; + const std::string disable_incre_build = "false"; + if (iter->second == enable_incre_build) { + is_incre_build_ = true; + GELOGI("Enable incre build."); + auto path_iter = options.find(OPTION_EXEC_INCRE_BUILD_CACHE_PATH); + if (path_iter != options.end()) { + std::string cache_path = path_iter->second; + if (!cache_path.empty() && cache_path[cache_path.size() - 1] != '/') { + cache_path += "/"; + } + incre_build_cache_path_ = cache_path; + } else { + incre_build_cache_path_ = ".ge_cache/"; + } + GELOGD("Using incre build cache path: %s.", incre_build_cache_path_.c_str()); + } else if (iter->second == disable_incre_build) { + is_incre_build_ = false; + GELOGI("Disable incre build."); + } else { + is_incre_build_ = false; + GELOGW("Invalid ENABLE_INCRE_BUILD option, it should be true or false."); + } + } +} + Status GELib::SystemInitialize(const map &options) { Status status = FAILED; auto iter = options.find(OPTION_GRAPH_RUN_MODE); @@ -174,7 +207,15 @@ Status GELib::SystemInitialize(const map &options) { GELOGD("Get dump step %s successfully", dump_step.c_str()); PropertiesManager::Instance().SetDumpStep(dump_step); } + auto mode_iter = options.find(OPTION_EXEC_DUMP_MODE); + if (mode_iter != options.end()) { + std::string dump_mode = mode_iter->second; + GELOGD("Get dump mode %s successfully", 
dump_mode.c_str()); + PropertiesManager::Instance().SetDumpMode(dump_mode); + } } + // check incre build flag + SetIncreBuild(options); if (is_train_mode_) { InitOptions(options); @@ -185,6 +226,17 @@ Status GELib::SystemInitialize(const map &options) { return status; } +Status GELib::SetRTSocVersion(const map &options) { + GELOGI("start SetRTSocVersion"); + auto it = options.find(ge::SOC_VERSION); + if (it != options.end()) { + GE_CHK_RT_RET(rtSetSocVersion(it->second.c_str())); + } else { + GELOGW("options not find SOC_VERSION"); + } + return SUCCESS; +} + void GELib::InitOptions(const map &options) { this->options_.session_id = 0; auto iter = options.find(OPTION_EXEC_SESSION_ID); @@ -210,12 +262,18 @@ void GELib::InitOptions(const map &options) { if (iter != options.end()) { std::istringstream(iter->second) >> this->options_.deployMode; } - iter = options.find(OPTION_EXEC_POD_NAME); if (iter != options.end()) { this->options_.podName = iter->second.c_str(); } - + iter = options.find(OPTION_EXEC_PROFILING_MODE); + if (iter != options.end()) { + this->options_.profiling_mode = iter->second.c_str(); + } + iter = options.find(OPTION_EXEC_PROFILING_OPTIONS); + if (iter != options.end()) { + this->options_.profiling_options = iter->second.c_str(); + } iter = options.find(OPTION_EXEC_RANK_ID); if (iter != options.end()) { this->options_.rankId = std::strtoll(iter->second.c_str(), nullptr, kDecimal); @@ -302,7 +360,9 @@ Status GELib::SystemShutdownWithOptions(const Options &options) { if (!ProfilingManager::Instance().ProfilingOpTraceOn() && ProfilingManager::Instance().ProfilingOn()) { ProfilingManager::Instance().StopProfiling(); } - + if (ProfilingManager::Instance().ProfilingOn()) { + ProfilingManager::Instance().PluginUnInit(GE_PROFILING_MODULE); + } is_system_inited = false; is_shutdown = true; @@ -384,6 +444,8 @@ Status GELib::Finalize() { } is_train_mode_ = false; + GetMutableGlobalOptions().erase(ENABLE_SINGLE_STREAM); + instancePtr_ = nullptr; init_flag_ = 
false; if (final_state != SUCCESS) { diff --git a/src/ge/init/gelib.h b/src/ge/init/gelib.h index 0945907a..60cbc0c0 100644 --- a/src/ge/init/gelib.h +++ b/src/ge/init/gelib.h @@ -65,6 +65,12 @@ class GELib { // add head stream to model bool HeadStream() const { return head_stream_; } + // get incre build flag + bool IsIncreBuild() const { return is_incre_build_; } + + // get incre build cache path + const std::string &GetIncreBuildCachePath() const { return incre_build_cache_path_; } + Status InitSystemWithoutOptions(); Status InitSystemWithOptions(Options &options); Status SystemShutdownWithOptions(const Options &options); @@ -74,8 +80,10 @@ class GELib { const GELib &operator=(const GELib &); Status InnerInitialize(const map &options); Status SystemInitialize(const map &options); + Status SetRTSocVersion(const map &options); void RollbackInit(); void InitOptions(const map &options); + void SetIncreBuild(const map &options); DNNEngineManager engineManager_; OpsKernelManager opsManager_; @@ -87,8 +95,9 @@ class GELib { bool is_system_inited = false; bool is_shutdown = false; bool is_use_hcom = false; - + bool is_incre_build_ = false; bool head_stream_ = false; + std::string incre_build_cache_path_; }; } // namespace ge diff --git a/src/ge/ir_build/atc_ir_common.cc b/src/ge/ir_build/atc_ir_common.cc new file mode 100644 index 00000000..109e6e6f --- /dev/null +++ b/src/ge/ir_build/atc_ir_common.cc @@ -0,0 +1,254 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "atc_ir_common.h" +#include "framework/common/string_util.h" +#include "framework/common/types.h" +#include "framework/common/util.h" +#include "common/util/error_manager/error_manager.h" + +using std::pair; +using std::string; +using std::vector; + +namespace ge { +namespace { +const int64_t kDynamicInputDim = -1; +const int64_t kDynamicImageSizeNum = 2; +} // namespace + +bool CheckDynamicBatchSizeInputShapeValid(unordered_map> shape_map, + std::string &dynamic_batch_size) { + int32_t size = 0; + for (auto iter = shape_map.begin(); iter != shape_map.end(); ++iter) { + vector shape = iter->second; + if (shape.size() < 1) { + ErrorManager::GetInstance().ATCReportErrMessage("E10017"); + GELOGE(ge::PARAM_INVALID, "The input shape size can not be less than 0 in dynamic batchsize scenario."); + return false; + } + if (shape[0] == kDynamicInputDim) { + for (size_t i = 1; i < shape.size(); ++i) { + if (shape[i] < 1) { + ErrorManager::GetInstance().ATCReportErrMessage("E10018", {"index", "shape"}, + {std::to_string(i), std::to_string(shape[i])}); + GELOGE(ge::PARAM_INVALID, "Only batch N can be -1 in dynamic batchsize scenario, current shape[%zu] is %ld", + i, shape[i]); + return false; + } + } + size++; + } + } + + if (size == 0) { + ErrorManager::GetInstance().ATCReportErrMessage("E10043"); + GELOGE(ge::PARAM_INVALID, "At least one batch n must be equal to -1 in dynamic batchsize scenario."); + return false; + } + + for (char c : dynamic_batch_size) { + if (!isdigit(c) && (c != ',') && (c != ' ')) { + GELOGE(ge::PARAM_INVALID, "dynamic_batch_size input : %s is invalid.", dynamic_batch_size.c_str()); + return false; + } + } + if (dynamic_batch_size.back() == ',') { + dynamic_batch_size.erase(dynamic_batch_size.end() - 1); + } + return true; +} + +bool CheckDynamicImagesizeInputShapeValid(unordered_map> shape_map, + const std::string input_format, 
std::string &dynamic_image_size) { + int32_t size = 0; + for (unordered_map>::iterator iter = shape_map.begin(); iter != shape_map.end(); ++iter) { + vector shape = iter->second; + // only support four dim + if (shape.size() != DIM_DEFAULT_SIZE) { + if (std::count(shape.begin(), shape.end(), kDynamicInputDim) > 0) { + ErrorManager::GetInstance().ATCReportErrMessage("E10019"); + GELOGE(ge::PARAM_INVALID, "Only height or width can be -1 in dynamic imagesize scenario."); + return false; + } + continue; + } + + int64_t height = 0; + int64_t width = 0; + if (input_format == "NCHW") { + height = shape[NCHW_DIM_H]; + width = shape[NCHW_DIM_W]; + } + + if (input_format == "NHWC") { + height = shape[NHWC_DIM_H]; + width = shape[NHWC_DIM_W]; + } + + if (height == kDynamicInputDim && width == kDynamicInputDim && + std::count(shape.begin(), shape.end(), kDynamicInputDim) == kDynamicImageSizeNum) { + size++; + } else if (std::count(shape.begin(), shape.end(), kDynamicInputDim) == 0) { + continue; + } else { + ErrorManager::GetInstance().ATCReportErrMessage("E10019"); + GELOGE(ge::PARAM_INVALID, "Only height or width can be -1 in dynamic imagesize scenario."); + return false; + } + } + if (size == 0) { + ErrorManager::GetInstance().ATCReportErrMessage("E10019"); + GELOGE(ge::PARAM_INVALID, "Only height or width can be -1 in dynamic imagesize scenario."); + return false; + } + + if (dynamic_image_size.back() == ';') { + dynamic_image_size.erase(dynamic_image_size.end() - 1); + } + + // Different parameter sets are split string by ';' + std::vector split_set = StringUtils::Split(dynamic_image_size, ';'); + // Different dimensions are split by ',' + std::vector split_dim; + for (auto str : split_set) { + split_dim = StringUtils::Split(str, ','); + if (split_dim.size() != static_cast(kDynamicImageSizeNum)) { + ErrorManager::GetInstance().ATCReportErrMessage("E10020", {"DynamicImageSizeNum"}, + {std::to_string(kDynamicImageSizeNum)}); + GELOGE(ge::PARAM_INVALID, + "Invalid 
dynamic_image_size : dynamic_image_size's number of dimensions of each " + "group must be %ld.", + kDynamicImageSizeNum); + return false; + } + } + + return true; +} + +Status CheckDynamicBatchSizeOrImageSizeParamValid(std::string &dynamic_batch_size, std::string &dynamic_image_size, + const std::string input_shape, const std::string input_format, + bool &is_dynamic_input) { + if (!dynamic_batch_size.empty() && !dynamic_image_size.empty()) { + ErrorManager::GetInstance().ATCReportErrMessage("E10009", {"parameter0", "parameter1"}, + {dynamic_batch_size, dynamic_image_size}); + GELOGE(ge::PARAM_INVALID, "dynamic_batch_size and dynamic_image_size can not both exist"); + return ge::PARAM_INVALID; + } + + if (dynamic_batch_size.empty() && dynamic_image_size.empty()) { + return ge::SUCCESS; + } + + unordered_map> shape_map; + vector>> user_shape_map; + is_dynamic_input = true; + if (!ParseInputShape(input_shape, shape_map, user_shape_map, is_dynamic_input)) { + GELOGE(ge::PARAM_INVALID, "Failed to parse input shape: %s", input_shape.c_str()); + return ge::PARAM_INVALID; + } + + if (shape_map.empty()) { + ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {"input_shape"}); + GELOGE(ge::PARAM_INVALID, "The input_shape can not be empty in dynamic batchsize scenario."); + return ge::PARAM_INVALID; + } + + if (!dynamic_batch_size.empty()) { + if (!CheckDynamicBatchSizeInputShapeValid(shape_map, dynamic_batch_size)) { + GELOGE(ge::PARAM_INVALID, "Check dynamic batch size input shape failed: %s", input_shape.c_str()); + return ge::PARAM_INVALID; + } + } + + if (!dynamic_image_size.empty()) { + if (!CheckDynamicImagesizeInputShapeValid(shape_map, input_format, dynamic_image_size)) { + GELOGE(ge::PARAM_INVALID, "Check dynamic image size input shape failed: %s", input_shape.c_str()); + return ge::PARAM_INVALID; + } + } + return ge::SUCCESS; +} + +bool ParseInputShape(const string &input_shape, unordered_map> &shape_map, + vector>> &user_shape_map, bool 
is_dynamic_input) { + vector shape_vec = StringUtils::Split(input_shape, ';'); + const int DEFAULT_SHAPE_PAIR_SIZE = 2; + for (const auto &shape : shape_vec) { + vector shape_pair_vec = StringUtils::Split(shape, ':'); + if (shape_pair_vec.size() != DEFAULT_SHAPE_PAIR_SIZE) { + ErrorManager::GetInstance().ATCReportErrMessage("E10010", {"shape"}, {shape}); + GELOGW("Input parameter[--input_shape]’s shape is [%s], correct sample is input_name1:n1,c1,h1,w1", + shape.c_str()); + return false; + } + if (shape_pair_vec[1].empty()) { + ErrorManager::GetInstance().ATCReportErrMessage("E10011", {"shape"}, {shape}); + GELOGW( + "Input parameter[--input_shape]’s shape is [%s], can not empty, " + "correct sample is input_name1:n1,c1,h1,w1", + shape.c_str()); + return false; + } + + vector shape_value_strs = StringUtils::Split(shape_pair_vec[1], ','); + vector shape_values; + for (auto &shape_value_str : shape_value_strs) { + // stoul: The method may throw an exception: invalid_argument/out_of_range + if (std::string::npos != shape_value_str.find('.')) { + ErrorManager::GetInstance().ATCReportErrMessage("E10012", {"shape"}, {shape_value_str}); + GELOGW("--input_shape's shape value[%s] exist float number the correct sample is \"input_name1:1,3,224,224\"", + shape_value_str.c_str()); + return false; + } + + long left_result = 0; + try { + left_result = stol(StringUtils::Trim(shape_value_str)); + } catch (const std::out_of_range &) { + ErrorManager::GetInstance().ATCReportErrMessage("E10013", {"parameter", "shape"}, {"input_shape", shape}); + GELOGW("--input_shape’s shape_value_str[%s] cause out of range execption!", shape_value_str.c_str()); + return false; + } catch (const std::invalid_argument &) { + ErrorManager::GetInstance().ATCReportErrMessage("E10014", {"parameter", "shape"}, + {"input_shape", shape_value_str}); + GELOGW("--input_shape’s shape_value_str[%s] cause invalid argument!shape_value_str:%s", + shape_value_str.c_str()); + return false; + } catch (...) 
{ + ErrorManager::GetInstance().ATCReportErrMessage("E10015", {"parameter", "shape"}, + {"input_shape", shape_value_str}); + GELOGW("--input_shape’s shape_value_str[%s] stol fail!", shape_value_str.c_str()); + return false; + } + int64_t result = left_result; + // - 1 is not currently supported + if (!is_dynamic_input && result <= 0) { + GELOGW("Invalid parameter for input shape: %s ,expect positive integer , but value = %ld", shape.c_str(), + result); + return false; + } + shape_values.push_back(result); + } + + shape_map.emplace(make_pair(StringUtils::Trim(shape_pair_vec[0]), shape_values)); + user_shape_map.push_back(make_pair(StringUtils::Trim(shape_pair_vec[0]), shape_values)); + } + + return true; +} +} // namespace ge diff --git a/src/ge/ir_build/atc_ir_common.h b/src/ge/ir_build/atc_ir_common.h new file mode 100644 index 00000000..5b268b48 --- /dev/null +++ b/src/ge/ir_build/atc_ir_common.h @@ -0,0 +1,43 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef FRAMEWORK_DOMI_ATC_IR_COMMON_H_ +#define FRAMEWORK_DOMI_ATC_IR_COMMON_H_ + +#include +#include +#include +#include +#include +#include "framework/common/debug/ge_log.h" +#include "framework/common/ge_inner_error_codes.h" +#include "framework/omg/omg_inner_types.h" + +namespace ge { +bool CheckDynamicBatchSizeInputShapeValid(unordered_map> shape_map, + std::string &dynamic_batch_size); + +bool CheckDynamicImagesizeInputShapeValid(unordered_map> shape_map, + const std::string input_format, std::string &dynamic_image_size); + +Status CheckDynamicBatchSizeOrImageSizeParamValid(std::string &dynamic_batch_size, std::string &dynamic_image_size, + const std::string input_shape, const std::string input_format, + bool &is_dynamic_input); + +bool ParseInputShape(const std::string &input_shape, std::unordered_map> &shape_map, + std::vector>> &user_shape_map, bool is_dynamic_input = false); +} // namespace ge +#endif // FRAMEWORK_DOMI_ATC_IR_COMMON_H_ \ No newline at end of file diff --git a/src/ge/ir_build/ge_ir_build.cc b/src/ge/ir_build/ge_ir_build.cc index 2c871559..cf507c42 100644 --- a/src/ge/ir_build/ge_ir_build.cc +++ b/src/ge/ir_build/ge_ir_build.cc @@ -17,32 +17,26 @@ #include "external/ge/ge_ir_build.h" #include -#include "generator/ge_generator.h" -#include "model/ge_model.h" -#include "graph/ge_tensor.h" -#include "init/gelib.h" -#include "ge/ge_api_types.h" -#include "graph/compute_graph.h" -#include "graph/utils/type_utils.h" -#include "external/register/register_types.h" #include "common/auth/file_saver.h" -#include "offline/atc_ir_common.h" +#include "external/register/register_types.h" #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" -#include "framework/omg/omg_inner_types.h" +#include "framework/common/string_util.h" #include "framework/common/types.h" #include "framework/common/util.h" -#include "framework/common/string_util.h" #include "framework/omg/omg_inner_types.h" +#include 
"framework/omg/omg_inner_types.h" +#include "ge/ge_api_types.h" +#include "generator/ge_generator.h" +#include "graph/compute_graph.h" +#include "graph/ge_tensor.h" +#include "graph/utils/type_utils.h" +#include "init/gelib.h" +#include "ir_build/atc_ir_common.h" +#include "model/ge_model.h" using domi::GetContext; -using domi::StringUtils; -using ge::FileSaver; -using ge::GRAPH_PARAM_INVALID; -using ge::GRAPH_SUCCESS; -using ge::ParseInputShape; using std::string; - using namespace std; namespace ge { @@ -90,10 +84,10 @@ class Impl { GetContext().out_nodes_map.clear(); GetContext().user_out_nodes.clear(); GetContext().net_format = domi::DOMI_TENSOR_RESERVED; - GetContext().type = domi::FRAMEWORK_RESERVED; - GetContext().run_mode = domi::ONLY_PRE_CHECK; + GetContext().type = domi::FMK_TYPE_RESERVED; + GetContext().run_mode = ONLY_PRE_CHECK; GetContext().train_flag = false; - GetContext().fp16_high_precision = domi::HIGH_PRECISION_DEFAULT; + GetContext().fp16_high_precision = HIGH_PRECISION_DEFAULT; GetContext().output_type.clear(); GetContext().net_name.clear(); GetContext().is_dynamic_input = false; @@ -134,17 +128,6 @@ graphStatus Impl::Init(const std::map &options) { return ret; } - auto iter = options_.find(ge::ir_option::OP_NAME_MAP); - if (iter != options_.end()) { - // divided by ":" - PropertiesManager::Instance().SetPropertyDelimiter(IR_OP_CONF_DELIMITER); - // Parsing the op_conf configuration item file - GE_RETURN_WITH_LOG_IF_FALSE(PropertiesManager::Instance().Init(iter->second), "op_name_map init failed!"); - // Return map and put it into ATC global variable - GetContext().op_conf_map.clear(); - GetContext().op_conf_map = PropertiesManager::Instance().GetPropertyMap(); - } - string input_shape = options_.find("input_shape") == options_.end() ? "" : options_["input_shape"]; string input_format = options_.find("input_format") == options_.end() ? "" : options_["input_format"]; string net_format = options_.find("net_format") == options_.end() ? 
"" : options_["net_format"]; @@ -189,7 +172,7 @@ graphStatus Impl::CreateInputsForIRBuild(const ge::Graph &graph, vectorGetOpDesc(); GE_CHECK_NOTNULL(op); - if (op->GetType() == domi::DATA) { + if (op->GetType() == DATA) { GELOGI("Data op inputDesc size is: %zu", op->GetAllInputsDesc().size()); ge::GeTensorDesc tensor = op->GetInputDesc(0); string data_op_name = op->GetName(); @@ -289,6 +272,7 @@ graphStatus aclgrphSaveModel(const string &output_file, const ModelBufferData &m GELOGE(GRAPH_PARAM_INVALID, "input model is not illegal"); return GRAPH_PARAM_INVALID; } - return FileSaver::SaveToFile((output_file + ".om"), (void *)model.data.get(), static_cast(model.length)); + return FileSaver::SaveToFile((output_file + ".om"), reinterpret_cast(model.data.get()), + static_cast(model.length)); } } // namespace ge diff --git a/src/ge/omm/csa_interact.cc b/src/ge/omm/csa_interact.cc index 075da863..dd3f6240 100644 --- a/src/ge/omm/csa_interact.cc +++ b/src/ge/omm/csa_interact.cc @@ -25,8 +25,6 @@ #include "mmpa/mmpa_api.h" #include "nlohmann/json.hpp" -using domi::CurrentTimeInStr; - namespace ge { namespace { const char FMK_STATUS_FILE_DIR_ENV[] = "FMK_STATUS_FILE_DIR"; diff --git a/src/ge/opskernel_manager/ops_kernel_manager.cc b/src/ge/opskernel_manager/ops_kernel_manager.cc index 0785ad81..b5276483 100644 --- a/src/ge/opskernel_manager/ops_kernel_manager.cc +++ b/src/ge/opskernel_manager/ops_kernel_manager.cc @@ -21,9 +21,13 @@ #include #include +#include +#include +#include #include "../init/gelib.h" #include "framework/common/debug/ge_log.h" #include "ge/ge_api.h" +#include "proto/optimizer_priority.pb.h" namespace { const char *const kInitialize = "Initialize"; @@ -74,10 +78,11 @@ Status OpsKernelManager::Initialize(const map &options_const) { plugin_manager_.InvokeAll &>(kGetOpsKernelInfoStores, ops_kernel_store_); Status rst2 = plugin_manager_.InvokeAll &>(kGetGraphOptimizerObjs, graph_optimizers_); - if ((rst0 != SUCCESS) || (rst1 != SUCCESS) || (rst2 != SUCCESS)) 
{ + if ((rst0 != SUCCESS) && (rst1 != SUCCESS) && (rst2 != SUCCESS)) { GELOGE(GE_OPS_GET_NO_VALID_SO); return GE_OPS_GET_NO_VALID_SO; } + ret = CheckPluginPtr(); if (ret != SUCCESS) { return ret; @@ -91,6 +96,11 @@ Status OpsKernelManager::Initialize(const map &options_const) { if (ret != SUCCESS) { return ret; } + ret = InitGraphOptimizerPriority(); + if ((ret != SUCCESS)) { + GELOGE(ret, "Init graph optimizer priority failed."); + return ret; + } init_flag_ = true; return SUCCESS; } else { @@ -366,6 +376,39 @@ bool OpsKernelManager::GetEnableAICPUFlag() const { return enable_aicpu_flag_; } bool OpsKernelManager::GetEnablePluginFlag() const { return (enable_fe_flag_ || enable_aicpu_flag_); } +Status OpsKernelManager::InitGraphOptimizerPriority() { + string priority_conf_path = "plugin/opskernel/optimizer_priority.pbtxt"; + string path = PluginManager::GetPath(); + path.append(priority_conf_path); + + optimizers::Priority optimizerPriority; + bool ret = ReadProtoFromText(path.c_str(), &optimizerPriority); + if (!ret) { + GELOGW("Read priority file failed. Follow loading sequence."); + return SUCCESS; + } + auto priorities = optimizerPriority.optimizer(); + if (priorities.empty()) { + GELOGI("No priority file config. Follow loading sequence."); + return SUCCESS; + } + // sort optimizer map by priority + map original_optimizers(graph_optimizers_); + graph_optimizers_.clear(); + std::stringstream priority_seq; + for (const auto optimizer_name : priorities) { + auto name_to_optimizer_pair = original_optimizers.find(optimizer_name); + if (name_to_optimizer_pair != original_optimizers.end()) { + graph_optimizers_.emplace(*name_to_optimizer_pair); + priority_seq << optimizer_name.c_str() << ' '; + } else { + GELOGW("Unknown optimizer %s show up in priority config file. Please check.", optimizer_name.c_str()); + } + } + GELOGI("Graph Optimizers priority initialized. 
The sequence will follow : %s.", priority_seq.str().c_str()); + return SUCCESS; +} + Status OpsKernelManager::FinalizeOpsKernel() { GELOGI("ge invoke ops kernal finalize."); Status ret = plugin_manager_.InvokeAll(kFinalize); diff --git a/src/ge/opskernel_manager/ops_kernel_manager.h b/src/ge/opskernel_manager/ops_kernel_manager.h index d83b7bc4..df7e06b2 100644 --- a/src/ge/opskernel_manager/ops_kernel_manager.h +++ b/src/ge/opskernel_manager/ops_kernel_manager.h @@ -97,6 +97,10 @@ class OpsKernelManager { Status ParsePluginOptions(const map &options, const string &plugin_name, bool &enable_flag); + Status LoadGEGraphOptimizer(map &graphOptimizer); + + Status InitGraphOptimizerPriority(); + PluginManager plugin_manager_; // opsKernelInfoStore map ops_kernel_store_{}; diff --git a/src/ge/opskernel_manager/optimizer_priority.pbtxt b/src/ge/opskernel_manager/optimizer_priority.pbtxt new file mode 100644 index 00000000..06bcf520 --- /dev/null +++ b/src/ge/opskernel_manager/optimizer_priority.pbtxt @@ -0,0 +1 @@ +optimizer:["AIcoreEngine","VectorEngine","aicpu_optimizer","hccl_graph_optimizer"] \ No newline at end of file diff --git a/src/ge/plugin/engine/CMakeLists.txt b/src/ge/plugin/engine/CMakeLists.txt index 45c3d302..a3f14ee2 100644 --- a/src/ge/plugin/engine/CMakeLists.txt +++ b/src/ge/plugin/engine/CMakeLists.txt @@ -14,7 +14,7 @@ # ============================================================================ # libengine.so -file(GLOB_RECURSE SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} +file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "*.cc" ) diff --git a/src/ge/session/inner_session.cc b/src/ge/session/inner_session.cc index 4798de90..74a43d96 100644 --- a/src/ge/session/inner_session.cc +++ b/src/ge/session/inner_session.cc @@ -18,12 +18,12 @@ #include #include #include -#include "graph/load/new_model_manager/model_manager.h" +#include "common/util.h" +#include "framework/common/debug/ge_log.h" +#include "graph/ge_context.h" #include 
"graph/ge_global_options.h" #include "graph/ge_local_context.h" -#include "graph/ge_context.h" -#include "framework/common/debug/ge_log.h" -#include "common/util.h" +#include "graph/load/new_model_manager/model_manager.h" #include "graph/manager/graph_var_manager.h" #include "graph/utils/tensor_adapter.h" #include "runtime/mem.h" @@ -180,11 +180,11 @@ Status InnerSession::RegisterCallBackFunc( return SUCCESS; } -Status InnerSession::RunGraphAsync(uint32_t graph_id, const std::vector &inputs, - std::vector &outputs, std::function callback) { +Status InnerSession::RunGraphAsync(uint32_t graph_id, const std::vector &inputs, + RunAsyncCallback callback) { UpdateThreadContext(graph_id); GELOGI("[InnerSession:%lu] run graph on session, graph_id=%u.", session_id_, graph_id); - Status ret = graph_manager_.RunGraphAsync(graph_id, inputs, outputs, session_id_, callback); + Status ret = graph_manager_.RunGraphAsync(graph_id, inputs, session_id_, callback); if (ret != SUCCESS) { GELOGE(ret, "[InnerSession:%lu] run graph failed, graph_id=%u.", session_id_, graph_id); return ret; diff --git a/src/ge/session/inner_session.h b/src/ge/session/inner_session.h index b35d01e6..3b009a44 100644 --- a/src/ge/session/inner_session.h +++ b/src/ge/session/inner_session.h @@ -41,8 +41,7 @@ class InnerSession { Status RemoveGraph(uint32_t graph_id); - Status RunGraphAsync(uint32_t graph_id, const std::vector &inputs, std::vector &outputs, - std::function callback); + Status RunGraphAsync(uint32_t graph_id, const std::vector &inputs, RunAsyncCallback callback); Status Finalize(); diff --git a/src/ge/session/session_manager.cc b/src/ge/session/session_manager.cc index aa34441a..bfdd9f2d 100644 --- a/src/ge/session/session_manager.cc +++ b/src/ge/session/session_manager.cc @@ -17,13 +17,12 @@ #include "session/session_manager.h" #include #include -#include "framework/common/debug/ge_log.h" #include "common/ge/ge_util.h" -#include "graph/manager/util/rt_context_util.h" -#include 
"graph/load/new_model_manager/model_manager.h" +#include "framework/common/debug/ge_log.h" #include "graph/ge_context.h" +#include "graph/load/new_model_manager/model_manager.h" +#include "graph/manager/util/rt_context_util.h" -using domi::ATTR_NAME_SESSION_GRAPH_ID; using std::map; using std::string; using std::vector; @@ -157,12 +156,16 @@ Status SessionManager::AddGraph(SessionId session_id, uint32_t graph_id, const G innerSession = it->second; } auto compute_graph = GraphUtils::GetComputeGraph(graph); + GE_CHECK_NOTNULL(compute_graph); std::string session_graph_id = std::to_string(session_id) + "_" + std::to_string(graph_id); if (!AttrUtils::SetStr(*compute_graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id)) { GELOGW("Set graph session_graph_id attr failed."); } else { GELOGD("Set graph session_graph_id attr to [%s]", session_graph_id.c_str()); } + for (auto graph : compute_graph->GetAllSubgraphs()) { + AttrUtils::SetStr(*graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id); + } } return innerSession->AddGraph(graph_id, graph, options); } @@ -243,8 +246,8 @@ Status SessionManager::RegisterCallBackFunc( return innerSession->RegisterCallBackFunc(key, callback); } -Status SessionManager::RunGraphAsync(SessionId session_id, uint32_t graph_id, const std::vector &inputs, - std::vector &outputs, std::function callback) { +Status SessionManager::RunGraphAsync(SessionId session_id, uint32_t graph_id, + const std::vector &inputs, RunAsyncCallback callback) { if (!init_flag_) { GELOGE(GE_SESSION_MANAGER_NOT_INIT); return GE_SESSION_MANAGER_NOT_INIT; @@ -259,7 +262,7 @@ Status SessionManager::RunGraphAsync(SessionId session_id, uint32_t graph_id, co innerSession = it->second; } } - return innerSession->RunGraphAsync(graph_id, inputs, outputs, callback); + return innerSession->RunGraphAsync(graph_id, inputs, callback); } bool SessionManager::IsGraphNeedRebuild(SessionId session_id, uint32_t graph_id) { if (!init_flag_) { diff --git a/src/ge/session/session_manager.h 
b/src/ge/session/session_manager.h index 10ff3edf..111795ed 100644 --- a/src/ge/session/session_manager.h +++ b/src/ge/session/session_manager.h @@ -110,11 +110,10 @@ class SessionManager { /// @param [in] session_id session id /// @param [in] graph_id graph id /// @param [in] inputs input data - /// @param [out] outputs output data /// @return Status result of function /// - Status RunGraphAsync(SessionId session_id, uint32_t graph_id, const std::vector &inputs, - std::vector &outputs, std::function callback); + Status RunGraphAsync(SessionId session_id, uint32_t graph_id, const std::vector &inputs, + RunAsyncCallback callback); /// /// @ingroup ge_graph diff --git a/src/ge/single_op/single_op.cc b/src/ge/single_op/single_op.cc index 475e463f..04b09389 100644 --- a/src/ge/single_op/single_op.cc +++ b/src/ge/single_op/single_op.cc @@ -19,6 +19,7 @@ #include "common/fmk_types.h" #include "common/profiling/profiling_manager.h" #include "framework/common/debug/ge_log.h" +#include "framework/common/util.h" #include "graph/load/new_model_manager/model_utils.h" #include "runtime/mem.h" @@ -75,7 +76,7 @@ Status SingleOp::ValidateArgs(const std::vector &inputs, const std:: return SUCCESS; } -Status SingleOp::UpdateArgs(const std::vector &inputs, const std::vector &outputs) { +Status SingleOp::GetArgs(const std::vector &inputs, const std::vector &outputs) { size_t arg_index = 0; if (use_physical_addr_) { for (auto &input : inputs) { @@ -108,7 +109,14 @@ Status SingleOp::UpdateArgs(const std::vector &inputs, const std::ve args_[arg_index++] = reinterpret_cast(output.data); } } + return SUCCESS; +} +Status SingleOp::UpdateArgs(const std::vector &inputs, const std::vector &outputs) { + Status ret = GetArgs(inputs, outputs); + if (ret != SUCCESS) { + return ret; + } size_t num_args = arg_table_.size(); for (size_t i = 0; i < num_args; ++i) { std::vector &ptr_to_arg_in_tasks = arg_table_[i]; @@ -121,7 +129,19 @@ Status SingleOp::UpdateArgs(const std::vector &inputs, const 
std::ve *arg_addr = args_[i]; } } - + for (auto &task : tasks_) { + if (task->GetOpTaskType() == OP_TASK_AICPU) { + GELOGD("Update aicpu task args"); + AiCpuTask *task_aicpu = dynamic_cast(task); + GE_CHECK_NOTNULL(task_aicpu); + auto rt_ret = rtMemcpyAsync(task_aicpu->GetIOAddr(), sizeof(uint64_t) * args_.size(), &args_[0], + sizeof(uint64_t) * args_.size(), RT_MEMCPY_HOST_TO_DEVICE_EX, stream_); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "rtMemcpyAsync addresses failed, ret = %d", rt_ret); + return RT_FAILED; + } + } + } return SUCCESS; } diff --git a/src/ge/single_op/single_op.h b/src/ge/single_op/single_op.h index ba025c0b..08782b3b 100644 --- a/src/ge/single_op/single_op.h +++ b/src/ge/single_op/single_op.h @@ -39,6 +39,7 @@ class SingleOp { private: Status ValidateArgs(const std::vector &inputs, const std::vector &outputs); Status UpdateArgs(const std::vector &inputs, const std::vector &outputs); + Status GetArgs(const std::vector &inputs, const std::vector &outputs); friend class SingleOpModel; rtStream_t stream_ = nullptr; diff --git a/src/ge/single_op/single_op_manager.cc b/src/ge/single_op/single_op_manager.cc index 8014cc2a..79f3f044 100644 --- a/src/ge/single_op/single_op_manager.cc +++ b/src/ge/single_op/single_op_manager.cc @@ -23,19 +23,17 @@ #include "framework/common/debug/ge_log.h" namespace ge { -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY -SingleOpManager::~SingleOpManager() { +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY SingleOpManager::~SingleOpManager() { for (auto &it : stream_resources_) { delete it.second; it.second = nullptr; } } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY -Status SingleOpManager::GetOpFromModel(const std::string &model_name, - const ModelData &model_data, - void *stream, - SingleOp **single_op) { +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status SingleOpManager::GetOpFromModel(const std::string &model_name, + const ModelData &model_data, + void *stream, + SingleOp **single_op) { if 
(single_op == nullptr) { GELOGE(PARAM_INVALID, "single op is null"); return PARAM_INVALID; @@ -57,14 +55,13 @@ Status SingleOpManager::GetOpFromModel(const std::string &model_name, resource_id = reinterpret_cast(stream); } - GELOGI("GetOpFromModel in. model name = %s, resource id = 0x%lx", - model_name.c_str(), + GELOGI("GetOpFromModel in. model name = %s, resource id = 0x%lx", model_name.c_str(), static_cast(resource_id)); StreamResource *res = GetResource(resource_id); if (res == nullptr) { - GELOGE(MEMALLOC_FAILED, "GetResource failed"); - return MEMALLOC_FAILED; + GELOGE(MEMALLOC_FAILED, "GetResource failed"); + return MEMALLOC_FAILED; } SingleOp *op = res->GetOperator(model_data.model_data); @@ -81,7 +78,7 @@ Status SingleOpManager::GetOpFromModel(const std::string &model_name, return ret; } - auto *new_op = new(std::nothrow)SingleOp(); + auto *new_op = new (std::nothrow) SingleOp(); if (new_op == nullptr) { GELOGE(MEMALLOC_FAILED, "new SingleOp failed"); return MEMALLOC_FAILED; @@ -90,10 +87,8 @@ Status SingleOpManager::GetOpFromModel(const std::string &model_name, GELOGI("To build operator: %s", model_name.c_str()); ret = model.BuildOp(*res, *new_op); if (ret != SUCCESS) { - GELOGE(ret, "Build op failed. op = %s, resource id = 0x%lx, ret = %u", - model_name.c_str(), - static_cast(resource_id), - ret); + GELOGE(ret, "Build op failed. op = %s, resource id = 0x%lx, ret = %u", model_name.c_str(), + static_cast(resource_id), ret); delete new_op; new_op = nullptr; return ret; @@ -106,8 +101,7 @@ Status SingleOpManager::GetOpFromModel(const std::string &model_name, return SUCCESS; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY -Status SingleOpManager::ReleaseResource(void *stream) { +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status SingleOpManager::ReleaseResource(void *stream) { auto resource_id = reinterpret_cast(stream); GELOGI("ReleaseResource in. 
resource id = 0x%lx", static_cast(resource_id)); std::lock_guard lock(mutex_); @@ -126,7 +120,7 @@ StreamResource *SingleOpManager::GetResource(uintptr_t resource_id) { auto it = stream_resources_.find(resource_id); StreamResource *res = nullptr; if (it == stream_resources_.end()) { - res = new (std::nothrow)StreamResource(); + res = new (std::nothrow) StreamResource(); if (res != nullptr) { stream_resources_.emplace(resource_id, res); } diff --git a/src/ge/single_op/single_op_model.cc b/src/ge/single_op/single_op_model.cc index 22e46008..16375741 100644 --- a/src/ge/single_op/single_op_model.cc +++ b/src/ge/single_op/single_op_model.cc @@ -27,22 +27,13 @@ #include "graph/utils/graph_utils.h" #include "graph/utils/tensor_utils.h" #include "runtime/rt.h" +#include "task/aicpu_task_builder.h" #include "task/tbe_task_builder.h" +using domi::TaskDef; using std::unique_ptr; using std::vector; -using domi::AIPP_DATA_TYPE; -using domi::ALLOC_MEMORY_MAX_SIZE; -using domi::DATA_TYPE; -using domi::MODEL_ATTR_TASK_GEN_BASE_ADDR; -using domi::MODEL_ATTR_TASK_GEN_WEIGHT_ADDR; -using domi::ModelFileHeader; -using domi::ModelHelper; -using domi::NETOUTPUT; -using domi::OmFileLoadHelper; -using domi::TaskDef; - namespace ge { namespace { const size_t kDataOutputNum = 1; @@ -80,14 +71,19 @@ void SingleOpModel::ParseOpModelParams(ModelHelper &model_helper, SingleOpModelP GE_CHECK_NOTNULL_JUST_RETURN(model); ret = ge::AttrUtils::GetInt(model, ATTR_MODEL_MEMORY_SIZE, value); param.memory_size = ret ? static_cast(value) : 0; + ret = ge::AttrUtils::GetInt(model, ATTR_MODEL_ZERO_COPY_MEMORY_SIZE, value); + param.zero_copy_mem_size = ret ? static_cast(value) : 0; ret = ge::AttrUtils::GetInt(model, ATTR_MODEL_WEIGHT_SIZE, value); param.weight_size = ret ? static_cast(value) : 0; ret = ge::AttrUtils::GetInt(model, MODEL_ATTR_TASK_GEN_BASE_ADDR, value); param.base_addr = ret ? 
static_cast(value) : 0; ret = ge::AttrUtils::GetInt(model, MODEL_ATTR_TASK_GEN_WEIGHT_ADDR, value); param.weight_addr = ret ? static_cast(value) : 0; + ret = ge::AttrUtils::GetInt(model, ATTR_MODEL_CORE_TYPE, value); + param.core_type = ret ? value : 0; - GELOGI("ParseOpModelParams(), memory_size:%lu, weight_size:%lu.", param.memory_size, param.weight_size); + GELOGI("ParseOpModelParams(), total_memory_size:%lu, zero_copy_size:%lu, weight_size:%lu. core_type = %lu", + param.memory_size, param.zero_copy_mem_size, param.weight_size, param.core_type); } Status SingleOpModel::InitModelMem(StreamResource &res) { @@ -99,14 +95,17 @@ Status SingleOpModel::InitModelMem(StreamResource &res) { } if (model_params_.memory_size > 0) { - model_params_.mem_base = res.MallocMemory(model_params_.memory_size); + const string purpose("malloc feature map memory on model execute."); + GELOGI("total memory: %lu, zero_copy_mem: %lu", model_params_.memory_size, model_params_.zero_copy_mem_size); + model_params_.mem_base = res.MallocMemory(purpose, model_params_.memory_size - model_params_.zero_copy_mem_size); if (model_params_.mem_base == nullptr) { return RT_FAILED; } } if (model_params_.weight_size > 0) { - model_params_.weight_base = res.MallocWeight(model_params_.weight_size); + const string purpose("malloc weights memory on model execute."); + model_params_.weight_base = res.MallocWeight(purpose, model_params_.weight_size); if (model_params_.weight_base == nullptr) { // no need to free memory, for that was handled by StreamResources return RT_FAILED; @@ -235,6 +234,7 @@ Status SingleOpModel::BuildTaskList(SingleOp &single_op) { task_def.DebugString().c_str()); auto task_type = static_cast(task_def.type()); if (task_type == RT_MODEL_TASK_KERNEL) { + GELOGD("Building TBE task"); OpTask *task = nullptr; auto ret = BuildKernelTask(task_def.kernel(), single_op, &task); if (ret != SUCCESS) { @@ -243,8 +243,13 @@ Status SingleOpModel::BuildTaskList(SingleOp &single_op) { 
single_op.tasks_.emplace_back(task); } else if (task_type == RT_MODEL_TASK_KERNEL_EX) { - GELOGD("BuildKernelExTask is not supported. modelName = %s", model_name_.c_str()); - return UNSUPPORTED; + GELOGD("Building AICPU task"); + OpTask *task = nullptr; + auto ret = BuildKernelExTask(task_def.kernel_ex(), single_op, &task); + if (ret != SUCCESS) { + return ret; + } + single_op.tasks_.emplace_back(task); } else { // skip GELOGD("Skip task type: %d", static_cast(task_type)); @@ -309,6 +314,29 @@ Status SingleOpModel::BuildKernelTask(const domi::KernelDef &kernel_def, SingleO return SUCCESS; } +Status SingleOpModel::BuildKernelExTask(const domi::KernelExDef &kernel_def, SingleOp &single_op, OpTask **task) { + auto iter = op_list_.find(kernel_def.op_index()); + if (iter == op_list_.end()) { + GELOGE(INTERNAL_ERROR, "op desc not found. op index = %u", kernel_def.op_index()); + return INTERNAL_ERROR; + } + + std::unique_ptr aicpu_task(new (std::nothrow) AiCpuTask()); + if (aicpu_task == nullptr) { + GELOGE(MEMALLOC_FAILED, "create aicpu op task failed"); + return MEMALLOC_FAILED; + } + auto builder = AiCpuTaskBuilder(iter->second, kernel_def); + auto ret = builder.BuildTask(*aicpu_task, model_params_); + if (ret != SUCCESS) { + GELOGE(ret, "build aicpu op task failed"); + return ret; + } + + *task = aicpu_task.release(); + return SUCCESS; +} + Status SingleOpModel::BuildOp(StreamResource &resource, SingleOp &single_op) { auto ret = InitModelMem(resource); if (ret != SUCCESS) { diff --git a/src/ge/single_op/single_op_model.h b/src/ge/single_op/single_op_model.h index 528004b8..4d8aae30 100644 --- a/src/ge/single_op/single_op_model.h +++ b/src/ge/single_op/single_op_model.h @@ -34,11 +34,13 @@ struct SingleOpModelParam { uint64_t memory_size = 0; uint64_t weight_addr = 0; uint64_t weight_size = 0; + uint64_t zero_copy_mem_size = 0; uint8_t *mem_base = nullptr; uint8_t *weight_base = nullptr; std::map addr_mapping_; + int64_t core_type = 0; }; class SingleOpModel { @@ 
-61,15 +63,16 @@ class SingleOpModel { Status BuildTaskList(SingleOp &single_op); Status BuildKernelTask(const domi::KernelDef &kernel_def, SingleOp &single_op, OpTask **task); + Status BuildKernelExTask(const domi::KernelExDef &kernel_def, SingleOp &single_op, OpTask **task); - static void ParseOpModelParams(domi::ModelHelper &model_helper, SingleOpModelParam ¶m); + static void ParseOpModelParams(ModelHelper &model_helper, SingleOpModelParam ¶m); void ParseArgTable(TbeOpTask *task, SingleOp &op); std::string model_name_; const void *ori_model_data_; uint32_t ori_model_size_; - domi::ModelHelper model_helper_; + ModelHelper model_helper_; map op_list_; SingleOpModelParam model_params_; diff --git a/src/ge/single_op/stream_resource.cc b/src/ge/single_op/stream_resource.cc index 53dfb183..e48afb96 100644 --- a/src/ge/single_op/stream_resource.cc +++ b/src/ge/single_op/stream_resource.cc @@ -55,7 +55,8 @@ SingleOp *StreamResource::GetOperator(const void *key) { return it->second; } -uint8_t *StreamResource::DoMallocMemory(size_t size, size_t &max_allocated, std::vector &allocated) { +uint8_t *StreamResource::DoMallocMemory(const std::string &purpose, size_t size, size_t &max_allocated, + std::vector &allocated) { if (size <= max_allocated && !allocated.empty()) { GELOGD("reuse last memory"); return allocated.back(); @@ -67,7 +68,7 @@ uint8_t *StreamResource::DoMallocMemory(size_t size, size_t &max_allocated, std: GELOGE(RT_FAILED, "rtMalloc failed, size = %zu, ret = %d", size, ret); return nullptr; } - GE_PRINT_DYNAMIC_MEMORY(rtMalloc, "malloc function.", size) + GE_PRINT_DYNAMIC_MEMORY(rtMalloc, purpose.c_str(), size) ret = rtMemset(buffer, size, 0U, size); if (ret != RT_ERROR_NONE) { @@ -83,15 +84,15 @@ uint8_t *StreamResource::DoMallocMemory(size_t size, size_t &max_allocated, std: return buffer; } -uint8_t *StreamResource::MallocMemory(size_t size) { +uint8_t *StreamResource::MallocMemory(const std::string &purpose, size_t size) { GELOGD("To Malloc memory, size = 
%zu", size); - uint8_t *buffer = DoMallocMemory(size, max_memory_size_, memory_list_); + uint8_t *buffer = DoMallocMemory(purpose, size, max_memory_size_, memory_list_); return buffer; } -uint8_t *StreamResource::MallocWeight(size_t size) { +uint8_t *StreamResource::MallocWeight(const std::string &purpose, size_t size) { GELOGD("To Malloc weight, size = %zu", size); - uint8_t *buffer = DoMallocMemory(size, max_weight_size_, weight_list_); + uint8_t *buffer = DoMallocMemory(purpose, size, max_weight_size_, weight_list_); return buffer; } } // namespace ge diff --git a/src/ge/single_op/stream_resource.h b/src/ge/single_op/stream_resource.h index 043a718c..fc114c08 100644 --- a/src/ge/single_op/stream_resource.h +++ b/src/ge/single_op/stream_resource.h @@ -41,11 +41,12 @@ class StreamResource { SingleOp *GetOperator(const void *key); - uint8_t *MallocMemory(size_t size); - uint8_t *MallocWeight(size_t size); + uint8_t *MallocMemory(const std::string &purpose, size_t size); + uint8_t *MallocWeight(const std::string &purpose, size_t size); private: - static uint8_t *DoMallocMemory(size_t size, size_t &max_allocated, std::vector &allocated); + static uint8_t *DoMallocMemory(const std::string &purpose, size_t size, size_t &max_allocated, + std::vector &allocated); size_t max_memory_size_ = 0; size_t max_weight_size_ = 0; diff --git a/src/ge/single_op/task/aicpu_task_builder.cc b/src/ge/single_op/task/aicpu_task_builder.cc new file mode 100644 index 00000000..3f571d30 --- /dev/null +++ b/src/ge/single_op/task/aicpu_task_builder.cc @@ -0,0 +1,135 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "single_op/task/aicpu_task_builder.h" +#include +#include "single_op/task/build_task_utils.h" +#include "runtime/mem.h" +#include "framework/common/debug/ge_log.h" +#include "graph/load/new_model_manager/model_utils.h" +#include "graph/load/new_model_manager/model_manager.h" + +namespace ge { +AiCpuTaskBuilder::AiCpuTaskBuilder(const OpDescPtr &op_desc, const domi::KernelExDef &kernel_def) + : op_desc_(op_desc), kernel_def_(kernel_def) {} + +Status AiCpuTaskBuilder::SetInputOutputAddr(void **io_addr, const std::vector &addresses) { + size_t arg_size = kernel_def_.args_size(); + auto rt_ret = rtMalloc(io_addr, arg_size, RT_MEMORY_HBM); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "rtMallocHost failed, size = %zu, ret = %d", arg_size, rt_ret); + return RT_FAILED; + } + + const void *src_addr = reinterpret_cast(addresses.data()); + uint64_t src_len = sizeof(void *) * addresses.size(); + rt_ret = rtMemcpy(*io_addr, arg_size, src_addr, src_len, RT_MEMCPY_HOST_TO_HOST); + if (rt_ret != RT_ERROR_NONE) { + (void)rtFree(*io_addr); + GELOGE(RT_FAILED, "rtMemcpy addresses failed, ret = %d", rt_ret); + return RT_FAILED; + } + + return SUCCESS; +} + +Status AiCpuTaskBuilder::SetFmkOpKernel(void *io_addr, void *ws_addr, STR_FWK_OP_KERNEL &fwk_op_kernel) { + auto sec_ret = + memcpy_s(&fwk_op_kernel, sizeof(STR_FWK_OP_KERNEL), kernel_def_.args().data(), sizeof(STR_FWK_OP_KERNEL)); + if (sec_ret != EOK) { + GELOGE(FAILED, "memcpy failed, ret: %d", sec_ret); + return FAILED; + } + + auto io_addr_val = static_cast(reinterpret_cast(io_addr)); 
+ fwk_op_kernel.fwkKernelBase.fwk_kernel.inputOutputAddr = io_addr_val; + auto ws_addr_val = static_cast(reinterpret_cast(ws_addr)); + fwk_op_kernel.fwkKernelBase.fwk_kernel.workspaceBaseAddr = ws_addr_val; + return SUCCESS; +} + +Status AiCpuTaskBuilder::SetKernelArgs(void **args, STR_FWK_OP_KERNEL &fwk_op_kernel) { + void *fwk_op_args = nullptr; + auto rt_ret = rtMalloc(&fwk_op_args, sizeof(STR_FWK_OP_KERNEL), RT_MEMORY_HBM); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "malloc arg memory failed, ret = %d", rt_ret); + return RT_FAILED; + } + + rt_ret = + rtMemcpy(fwk_op_args, sizeof(STR_FWK_OP_KERNEL), &fwk_op_kernel, sizeof(STR_FWK_OP_KERNEL), RT_MEMCPY_HOST_TO_HOST); + if (rt_ret != RT_ERROR_NONE) { + (void)rtFree(fwk_op_args); + GELOGE(RT_FAILED, "copy args failed, ret = %d", rt_ret); + return RT_FAILED; + } + *args = fwk_op_args; + return SUCCESS; +} + +Status AiCpuTaskBuilder::BuildTask(ge::AiCpuTask &task, const SingleOpModelParam ¶m) { + if (kernel_def_.args_size() != sizeof(STR_FWK_OP_KERNEL)) { + GELOGE(PARAM_INVALID, "sizeof STR_FWK_OP_KERNEL is: %lu, but args_size is: %d", sizeof(STR_FWK_OP_KERNEL), + kernel_def_.args_size()); + return PARAM_INVALID; + } + auto addresses = BuildTaskUtils::GetAddresses(op_desc_, param); + auto ws_addr_vec = addresses.at(BuildTaskUtils::kAddressIndexWorkspace); + if (ws_addr_vec.empty()) { + GELOGE(PARAM_INVALID, "workspace Data Address is empty."); + return PARAM_INVALID; + } + auto rt_ret = rtMemcpy(ws_addr_vec[0], kernel_def_.task_info_size(), kernel_def_.task_info().data(), + kernel_def_.task_info_size(), RT_MEMCPY_HOST_TO_DEVICE); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(FAILED, "rtMemcpy error: 0x%X", rt_ret); + return FAILED; + } + + void *io_addr = nullptr; + auto ret = SetInputOutputAddr(&io_addr, BuildTaskUtils::JoinAddresses(addresses)); + if (ret != SUCCESS) { + return ret; + } + + STR_FWK_OP_KERNEL fwk_op_kernel; + ret = SetFmkOpKernel(io_addr, ws_addr_vec[0], fwk_op_kernel); + if (ret != SUCCESS) { 
+ rtFree(io_addr); + return ret; + } + // Create session + auto session_id = fwk_op_kernel.fwkKernelBase.fwk_kernel.sessionID; + GE_CHECK_NOTNULL(ModelManager::GetInstance()); + GE_IF_BOOL_EXEC(ModelManager::GetInstance()->CreateAicpuSession(session_id) != SUCCESS, + GELOGE(FAILED, "CreateAicpuSession error. session id: %lu", session_id); + return FAILED;) + ret = SetKernelArgs(&task.args_, fwk_op_kernel); + if (ret != SUCCESS) { + rtFree(io_addr); + return ret; + } + + task.arg_size_ = sizeof(STR_FWK_OP_KERNEL); + task.op_type_ = op_desc_->GetName(); + task.io_addr_ = io_addr; + task.task_info_ = kernel_def_.task_info(); + task.workspace_addr_ = ws_addr_vec[0]; + + return SUCCESS; +} + +} // namespace ge diff --git a/src/ge/single_op/task/aicpu_task_builder.h b/src/ge/single_op/task/aicpu_task_builder.h new file mode 100644 index 00000000..0253ebd0 --- /dev/null +++ b/src/ge/single_op/task/aicpu_task_builder.h @@ -0,0 +1,44 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef GE_SINGLE_OP_TASK_AICPU_TASK_BUILDER_H_ +#define GE_SINGLE_OP_TASK_AICPU_TASK_BUILDER_H_ + +#include +#include "graph/op_desc.h" +#include "single_op/single_op.h" +#include "single_op/single_op_model.h" +#include "cce/aicpu_engine_struct.h" + +namespace ge { +class AiCpuTaskBuilder { + public: + AiCpuTaskBuilder(const OpDescPtr &op_desc, const domi::KernelExDef &kernel_def); + ~AiCpuTaskBuilder() = default; + + Status BuildTask(AiCpuTask &task, const SingleOpModelParam ¶m); + + private: + static Status SetKernelArgs(void **args, STR_FWK_OP_KERNEL &kernel); + Status SetInputOutputAddr(void **io_addr, const std::vector &addresses); + Status SetFmkOpKernel(void *io_addr, void *ws_addr, STR_FWK_OP_KERNEL &kernel); + + const OpDescPtr &op_desc_; + const domi::KernelExDef &kernel_def_; +}; +} // namespace ge + +#endif // GE_SINGLE_OP_TASK_AICPU_TASK_BUILDER_H_ \ No newline at end of file diff --git a/src/ge/single_op/task/op_task.cc b/src/ge/single_op/task/op_task.cc index f983e080..d515336b 100644 --- a/src/ge/single_op/task/op_task.cc +++ b/src/ge/single_op/task/op_task.cc @@ -31,9 +31,7 @@ void TbeOpTask::SetKernelArgs(void *args, size_t arg_size, uint32_t block_dim) { block_dim_ = block_dim; } -void TbeOpTask::SetSmDesc(void *sm_desc) { - sm_desc_ = sm_desc; -} +void TbeOpTask::SetSmDesc(void *sm_desc) { sm_desc_ = sm_desc; } TbeOpTask::~TbeOpTask() { if (args_ != nullptr) { @@ -45,22 +43,14 @@ TbeOpTask::~TbeOpTask() { } } -const void* TbeOpTask::GetArgs() const { - return args_; -} +const void *TbeOpTask::GetArgs() const { return args_; } -size_t TbeOpTask::GetArgSize() const { - return arg_size_; -} +size_t TbeOpTask::GetArgSize() const { return arg_size_; } -const std::string& TbeOpTask::GetStubName() const { - return stub_name_; -} +const std::string &TbeOpTask::GetStubName() const { return stub_name_; } Status TbeOpTask::LaunchKernel(rtStream_t stream) { - GELOGD("To invoke rtKernelLaunch. 
task = %s, block_dim = %u", - this->stub_name_.c_str(), - block_dim_); + GELOGD("To invoke rtKernelLaunch. task = %s, block_dim = %u", this->stub_name_.c_str(), block_dim_); auto *sm_desc = reinterpret_cast(sm_desc_); auto ret = rtKernelLaunch(stub_func_, block_dim_, args_, static_cast(arg_size_), sm_desc, stream); if (ret != RT_ERROR_NONE) { @@ -71,4 +61,35 @@ Status TbeOpTask::LaunchKernel(rtStream_t stream) { GELOGD("Invoke rtKernelLaunch succeeded. task = %s", this->stub_name_.c_str()); return SUCCESS; } + +AiCpuTask::~AiCpuTask() { + if (args_ != nullptr) { + rtFree(args_); + } + + if (io_addr_ != nullptr) { + (void)rtFree(io_addr_); + } +} + +void *AiCpuTask::GetIOAddr() { return io_addr_; } + +Status AiCpuTask::LaunchKernel(rtStream_t stream) { + auto ret = rtMemcpyAsync(workspace_addr_, task_info_.size(), task_info_.data(), task_info_.size(), + RT_MEMCPY_HOST_TO_DEVICE, stream); + if (ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "rtMemcpyAsync workspace data failed. ret = %d, task = %s", ret, this->op_type_.c_str()); + return RT_FAILED; + } + + GELOGD("To invoke rtKernelLaunchEx. task = %s", this->op_type_.c_str()); + ret = rtKernelLaunchEx(args_, arg_size_, 0, stream); + if (ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Invoke rtKernelLaunch failed. ret = %d, task = %s", ret, this->op_type_.c_str()); + return RT_FAILED; + } + + GELOGD("Invoke rtKernelLaunch succeeded. 
task = %s", this->op_type_.c_str()); + return SUCCESS; +} } // namespace ge diff --git a/src/ge/single_op/task/op_task.h b/src/ge/single_op/task/op_task.h index 5cda8ba4..95e42772 100644 --- a/src/ge/single_op/task/op_task.h +++ b/src/ge/single_op/task/op_task.h @@ -25,24 +25,32 @@ #include "graph/op_kernel_bin.h" namespace ge { +enum OpTaskType { + OP_TASK_TBE = 0, + OP_TASK_AICPU, + OP_TASK_INVALID, +}; + class OpTask { public: OpTask() = default; virtual ~OpTask() = default; virtual Status LaunchKernel(rtStream_t stream) = 0; + virtual OpTaskType GetOpTaskType() = 0; }; class TbeOpTask : public OpTask { public: ~TbeOpTask() override; Status LaunchKernel(rtStream_t stream) override; + OpTaskType GetOpTaskType() override { return OP_TASK_TBE; } void SetSmDesc(void *sm_desc); void SetStubFunc(const std::string &name, const void *stub_func); void SetKernelArgs(void *args, size_t arg_size, uint32_t block_dim); - const void* GetArgs() const; + const void *GetArgs() const; size_t GetArgSize() const; - const std::string& GetStubName() const; + const std::string &GetStubName() const; private: const void *stub_func_ = nullptr; @@ -52,6 +60,25 @@ class TbeOpTask : public OpTask { void *sm_desc_ = nullptr; std::string stub_name_; }; + +class AiCpuTask : public OpTask { + public: + AiCpuTask() = default; + ~AiCpuTask() override; + + Status LaunchKernel(rtStream_t stream) override; + OpTaskType GetOpTaskType() override { return OP_TASK_AICPU; } + void *GetIOAddr(); + + private: + friend class AiCpuTaskBuilder; + void *workspace_addr_ = nullptr; + std::string task_info_; + void *args_ = nullptr; + size_t arg_size_ = 0; + std::string op_type_; + void *io_addr_ = nullptr; +}; } // namespace ge #endif // GE_SINGLE_OP_TASK_OP_TASK_H_ diff --git a/src/ge/single_op/task/tbe_task_builder.cc b/src/ge/single_op/task/tbe_task_builder.cc index 1a47402e..c0f6877f 100644 --- a/src/ge/single_op/task/tbe_task_builder.cc +++ b/src/ge/single_op/task/tbe_task_builder.cc @@ -22,15 +22,13 @@ 
#include "common/helper/model_helper.h" #include "framework/common/debug/ge_log.h" -#include "framework/common/op/attr_define.h" #include "graph/load/new_model_manager/model_utils.h" +#include "graph/debug/ge_attr_define.h" #include "graph/load/new_model_manager/task_info/task_info.h" #include "graph/manager/graph_var_manager.h" #include "runtime/rt.h" #include "single_op/task/build_task_utils.h" -using domi::TVM_ATTR_NAME_METADATA; - namespace ge { namespace { std::mutex g_reg_mutex; @@ -91,16 +89,17 @@ TbeTaskBuilder::TbeTaskBuilder(const std::string &model_name, const OpDescPtr &o const domi::KernelDef &kernel_def) : op_desc_(op_desc), kernel_def_(kernel_def), stub_name_(model_name + "/" + op_desc->GetName() + "_tvmbin") {} -Status TbeTaskBuilder::DoRegisterBinary(const OpKernelBin &kernel_bin, void **bin_handle) const { +Status TbeTaskBuilder::DoRegisterBinary(const OpKernelBin &kernel_bin, void **bin_handle, + const SingleOpModelParam ¶m) const { rtDevBinary_t binary; binary.version = 0; binary.data = kernel_bin.GetBinData(); binary.length = kernel_bin.GetBinDataSize(); - binary.magic = RT_DEV_BINARY_MAGIC_ELF; + binary.magic = param.core_type == 0 ? 
RT_DEV_BINARY_MAGIC_ELF : RT_DEV_BINARY_MAGIC_ELF_AIVEC; auto ret = rtDevBinaryRegister(&binary, bin_handle); if (ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "rtDevBinaryRegister failed, bin key = %s, rt ret = %d", stub_name_.c_str(), - static_cast(ret)); + GELOGE(RT_FAILED, "rtDevBinaryRegister failed, bin key = %s, core_type = %ld, rt ret = %d", stub_name_.c_str(), + param.core_type, static_cast(ret)); return RT_FAILED; } @@ -134,13 +133,13 @@ Status TbeTaskBuilder::DoRegisterFunction(void *bin_handle, const char *stub_nam return SUCCESS; } -Status TbeTaskBuilder::DoRegisterKernel(const ge::OpKernelBin &tbe_kernel, const char *bin_file_key, - void **bin_handle) { +Status TbeTaskBuilder::DoRegisterKernel(const ge::OpKernelBin &tbe_kernel, const char *bin_file_key, void **bin_handle, + const SingleOpModelParam ¶m) { std::string kernel_name; GetKernelName(op_desc_, kernel_name); void *handle = nullptr; - auto ret = DoRegisterBinary(tbe_kernel, &handle); + auto ret = DoRegisterBinary(tbe_kernel, &handle, param); if (ret != SUCCESS) { return ret; } @@ -162,7 +161,7 @@ Status TbeTaskBuilder::DoRegisterKernel(const ge::OpKernelBin &tbe_kernel, const return SUCCESS; } -Status TbeTaskBuilder::RegisterKernel(TbeOpTask &task) { +Status TbeTaskBuilder::RegisterKernel(TbeOpTask &task, const SingleOpModelParam ¶m) { KernelBinRegistry ®istry = KernelBinRegistry::GetInstance(); // check if already registered const char *stub_func = registry.GetStubFunc(stub_name_); @@ -192,7 +191,7 @@ Status TbeTaskBuilder::RegisterKernel(TbeOpTask &task) { } void *bin_handle = nullptr; - auto ret = DoRegisterKernel(*tbe_kernel, stub_func, &bin_handle); + auto ret = DoRegisterKernel(*tbe_kernel, stub_func, &bin_handle, param); if (ret == SUCCESS) { holder->SetBinHandle(bin_handle); if (!registry.AddKernel(stub_name_, holder)) { @@ -287,7 +286,7 @@ Status TbeTaskBuilder::BuildTask(TbeOpTask &task, const SingleOpModelParam ¶ return ret; } - ret = RegisterKernel(task); + ret = RegisterKernel(task, 
param); if (ret != SUCCESS) { return ret; } diff --git a/src/ge/single_op/task/tbe_task_builder.h b/src/ge/single_op/task/tbe_task_builder.h index 25441289..5e0965bf 100644 --- a/src/ge/single_op/task/tbe_task_builder.h +++ b/src/ge/single_op/task/tbe_task_builder.h @@ -74,9 +74,10 @@ class TbeTaskBuilder { Status SetKernelArgs(TbeOpTask &task, const SingleOpModelParam ¶m); Status GetSmDesc(void **sm_desc, const SingleOpModelParam ¶m) const; - Status RegisterKernel(TbeOpTask &task); - Status DoRegisterKernel(const OpKernelBin &kernel_bin, const char *bin_file_key, void **bin_handle); - Status DoRegisterBinary(const OpKernelBin &kernel_bin, void **bin_handle) const; + Status RegisterKernel(TbeOpTask &task, const SingleOpModelParam ¶m); + Status DoRegisterKernel(const OpKernelBin &kernel_bin, const char *bin_file_key, void **bin_handle, + const SingleOpModelParam ¶m); + Status DoRegisterBinary(const OpKernelBin &kernel_bin, void **bin_handle, const SingleOpModelParam ¶m) const; Status DoRegisterMeta(void *bin_handle); static Status DoRegisterFunction(void *bin_handle, const char *stub_name, const char *kernel_name); diff --git a/src/proto/ge_ir.proto b/src/proto/ge_ir.proto index 96962346..87886c84 100644 --- a/src/proto/ge_ir.proto +++ b/src/proto/ge_ir.proto @@ -176,6 +176,7 @@ message OpDef repeated bool is_input_const = 32; repeated TensorDescriptor input_desc = 33; repeated TensorDescriptor output_desc = 34; + repeated string subgraph_name = 35; } // Graph definition @@ -187,7 +188,7 @@ message GraphDef repeated string output = 5; // Graph output repeated OpDef op = 6; // List of operators - + map attr = 11; // Extended field } @@ -197,7 +198,7 @@ message ModelDef string name = 1; // name uint32 version = 2; // IR Proto verion string custom_version = 3; // User model version number, passed in by user - + repeated GraphDef graph = 7; // Graph definition,graph[0] represents the main diagram in modeldef map attr = 11; // Extended field diff --git 
a/src/proto/op_mapping_info.proto b/src/proto/op_mapping_info.proto index ea4c4a8d..35383c5b 100644 --- a/src/proto/op_mapping_info.proto +++ b/src/proto/op_mapping_info.proto @@ -30,6 +30,15 @@ message Output { int32 original_output_index = 6; int32 original_output_data_type = 7; int32 original_output_format = 8; + uint64 size = 9; +}; + +message Input { + int32 data_type =1; + int32 format = 2; + Shape shape = 3; + uint64 address = 4; + uint64 size = 5; }; message Op { @@ -43,6 +52,7 @@ message Task { Op op = 3; repeated Output output = 4; bool end_graph = 5; + repeated Input input = 6; }; message OpMappingInfo { diff --git a/src/proto/task.proto b/src/proto/task.proto index 3eb8de5c..8ef5c2e2 100644 --- a/src/proto/task.proto +++ b/src/proto/task.proto @@ -31,7 +31,7 @@ message ModelTaskDef { repeated bytes op = 15; // input/output opdef in bytes - uint64 base_addr = 16; // base addr + uint64 base_addr = 16; // base addr uint64 weight_addr = 17; // weight addr uint32 batch_num = 18; } @@ -58,6 +58,10 @@ message TaskDef { bytes private_def = 34; uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future StreamSwitchNDef stream_switch_n = 36; + + LabelSetDef label_set = 37; + LabelGotoExDef label_goto_ex = 38; + LabelSwitchByIndexDef label_switch_by_index = 39; } message KernelDef { @@ -119,6 +123,7 @@ message MemcpyAsyncDef { uint64 src = 3; uint64 count = 4; uint32 kind = 5; + uint32 op_index = 6; } message StreamSwitchDef { @@ -142,3 +147,20 @@ message StreamSwitchNDef { uint32 element_size = 5; uint32 data_type = 6; } + +message LabelSetDef { + uint32 op_index = 1; + uint32 label_id = 2; + uint32 model_id = 3; +} + +message LabelGotoExDef { + uint32 op_index = 1; + uint32 label_id = 2; + uint32 model_id = 3; +} + +message LabelSwitchByIndexDef { + uint32 op_index = 1; + uint32 label_max = 2; +} diff --git a/third_party/fwkacllib/inc/cce/fwk_adpt_struct.h b/third_party/fwkacllib/inc/cce/fwk_adpt_struct.h index 3075f795..09827358 100644 --- 
a/third_party/fwkacllib/inc/cce/fwk_adpt_struct.h +++ b/third_party/fwkacllib/inc/cce/fwk_adpt_struct.h @@ -50,6 +50,14 @@ enum FWKOperateType { FWK_ADPT_SINGLE_OP_RUN }; +// Extend Info type for task +enum FWKTaskExtInfoType { + FWK_ADPT_EXT_SHAPE_TYPE = 0, + FWK_ADPT_EXT_INPUT_SHAPE, + FWK_ADPT_EXT_OUTPUT_SHAPE, + FWK_ADPT_EXT_INVALID +}; + // API Parameter Structure struct StrFWKKernel { FWKOperateType opType; @@ -66,10 +74,20 @@ struct StrFWKKernel { uint64_t inputOutputBuf; // InputOutput shap protobuf addr, need convert to void* uint64_t workspaceBaseAddr; // Workspace base addr, need convert to void* uint64_t inputOutputAddr; // InputOutput addr, need convert to void* + + uint64_t extInfoNum; // extend info number + uint64_t extInfoAddr; // extend info addr list, ExtInfo structure, num equal to extInfoNum } __attribute__((packed)); typedef StrFWKKernel FWKOperateParam; +// Extend info structure for extInfoAddr +struct ExtInfo{ + int32_t infoType; + uint32_t infoLen; + uint64_t infoAddr; +} __attribute__((packed)); + struct ResultSummary { uint64_t shape_data_ptr; // shape data addr, need convert to void* uint64_t shape_data_size; // num of dims diff --git a/third_party/fwkacllib/inc/hccl/base.h b/third_party/fwkacllib/inc/hccl/base.h index d85d7bc4..89c21f1c 100644 --- a/third_party/fwkacllib/inc/hccl/base.h +++ b/third_party/fwkacllib/inc/hccl/base.h @@ -100,6 +100,12 @@ struct model_feature { float *gradient_time; /**< The BP compution time of each gradient */ }; +enum GradSplitForceMode { + FORCE_NONE, /**< no force */ + FORCE_SIZE, /**< force split gradient by size */ + FORCE_RESERVED /**< reserved */ +}; + /** * @brief stream handle. 
*/ diff --git a/third_party/fwkacllib/inc/hccl/hcom.h b/third_party/fwkacllib/inc/hccl/hcom.h index 8ac2b4bc..a448d411 100644 --- a/third_party/fwkacllib/inc/hccl/hcom.h +++ b/third_party/fwkacllib/inc/hccl/hcom.h @@ -247,7 +247,7 @@ hcclResult_t hcom_receive(const char *tag, void *outputPtr, u64 count, hcclDataT * @return hcclResult_t */ hcclResult_t hcom_get_split_strategy(const char *group, const struct model_feature *feature, - u32 maxSegmentNum, u32 *segmentNum, u32 *segmentIdx); + u32 maxSegmentNum, u32 *segmentNum, u32 *segmentIdx, GradSplitForceMode force = FORCE_NONE); /** * @brief Set the gradient split strategy with in the group, according to gradient index. diff --git a/third_party/fwkacllib/inc/ops/aipp.h b/third_party/fwkacllib/inc/ops/aipp.h index da2a36ca..6053bb76 100644 --- a/third_party/fwkacllib/inc/ops/aipp.h +++ b/third_party/fwkacllib/inc/ops/aipp.h @@ -17,7 +17,7 @@ #ifndef GE_OP_AIPP_H #define GE_OP_AIPP_H -#include "../graph/operator_reg.h" +#include "graph/operator_reg.h" namespace ge { /** diff --git a/third_party/fwkacllib/inc/ops/all_ops.h b/third_party/fwkacllib/inc/ops/all_ops.h index 36c991ff..d6bd1353 100644 --- a/third_party/fwkacllib/inc/ops/all_ops.h +++ b/third_party/fwkacllib/inc/ops/all_ops.h @@ -35,7 +35,6 @@ #include "decode_wheels_target.h" #include "elewise_calculation_ops.h" #include "fastrcnn_predictions.h" -#include "fsrdetectionoutput_ops.h" #include "functional_ops.h" #include "get_data_ops.h" #include "hcom_ops.h" @@ -58,14 +57,12 @@ #include "outfeed_ops.h" #include "pad_ops.h" #include "parsing_ops.h" -#include "power_ops.h" #include "quantize_ops.h" #include "ragged_conversion_ops.h" #include "random_ops.h" #include "reduce_ops.h" #include "resource_variable_ops.h" #include "rnn.h" -#include "roipooling_ops.h" #include "rpn_ops.h" #include "rpn_proposals.h" #include "save_ops.h" @@ -80,4 +77,5 @@ #include "string_ops.h" #include "swap_co_ops.h" #include "transformation_ops.h" +#include "condtake_ops.h" #endif // 
BUILT_IN_OP_PROTO_INC_ALL_OPS_H_ diff --git a/third_party/fwkacllib/inc/ops/array_ops.h b/third_party/fwkacllib/inc/ops/array_ops.h index 0d1126aa..e1ea5537 100644 --- a/third_party/fwkacllib/inc/ops/array_ops.h +++ b/third_party/fwkacllib/inc/ops/array_ops.h @@ -397,6 +397,24 @@ REG_OP(ListDiff) .ATTR(out_idx, Type, DT_INT32) .OP_END_FACTORY_REG(ListDiff) +/** +*@brief Create an empty tensor, using the shape and dtype specified in attributes. + +*@par Attributes: +*@li dtype: Specify the data type of the empty tensor. +*@li shape: Specify the shape of the empty tensor. + +*@par Outputs: +*y: The empty constant tensor. + +*/ +REG_OP(_ParallelConcatStart) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, + DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .ATTR(dtype, Type, DT_INT32) + .ATTR(shape, ListInt, {}) + .OP_END_FACTORY_REG(_ParallelConcatStart) + /** *@brief Creates a constant tensor from a tensor-like object. This operator is used for inference. \n Operator Const has the same definition as operator Constant. @@ -595,6 +613,9 @@ REG_OP(ExpandDims) *@par Outputs: *y: A tensor. + +*@par Attention: +*This operator cannot be directly called by the acllopExecute API. */ REG_OP(Reshape) .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, DT_INT32, @@ -701,6 +722,47 @@ REG_OP(PlaceHolder) .ATTR(anchorIndex, Int, 0) // check if these node are from save anchor .OP_END_FACTORY_REG(PlaceHolder) +/** +*@brief Inserts a placeholder with default value for a tensor. + +*@par Inputs: +*x: A tensor. + +*@par Attributes: +*@li dtype: data type of tensor. +*@li shape: tensor shape. + +*@par Outputs: +*y: The created placeholder tensor. + +*/ +REG_OP(PlaceholderWithDefault) + .INPUT(x, TensorType::ALL()) + .OUTPUT(y, TensorType::ALL()) + .REQUIRED_ATTR(shape, ListInt) + .OP_END_FACTORY_REG(PlaceholderWithDefault) + +/** +*@brief Reads and returns the value of the input variable tensor. 
+ +*@par Inputs: +*x: A tensor. + +*@par Attributes: +*dtype: An optional int32 or int64. The output data type. Defaults to int32. + +*@par Outputs: +*y: A tensor. + +*/ +REG_OP(ReadVariableOp) + .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, + DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, + DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .ATTR(dtype, Int, DT_INT32) + .OP_END_FACTORY_REG(ReadVariableOp) + REG_OP(End) .INPUT(x, TensorType::ALL()) .OUTPUT(y, TensorType::ALL()) @@ -719,7 +781,7 @@ REG_OP(Summary) *x: A tensor. *@par Attributes: -*out_type: An optional int32 or int64. The output data type. Defaults to int32. +*dtype: An optional int32 or int64. The output data type. Defaults to int32. *@par Outputs: *y: A tensor. The shape of the input tensor. @@ -738,7 +800,7 @@ REG_OP(Shape) *x: A list of input tensors. *@par Attributes: -*out_type: An optional int32 or int64. The output data type. Defaults to "int32". +*dtype: An optional int32 or int64. The output data type. Defaults to "int32". *@par Outputs: *y: A list of tensors with the same length as the input list of tensors. @@ -829,14 +891,17 @@ REG_OP(Where) *The Split node is removed from the graph after the split operation is completed. *@par Inputs: -*x: A Tensor. Must be one of the following types: float16, float32, int8, int32. +*x: A Tensor. Must be one of the following types: \n +fp16, fp32, int8, uint8, int16, uint16, int32, uint32, int64, uint64. *@par Outputs: *y: A Tensor. Has the same type as "x".It's required and the value should equal to output_num. 
*/ REG_OP(Copy) - .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT32})) - .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT32})) + .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_UINT8, DT_INT16, \ + DT_UINT16, DT_INT32, DT_UINT32, DT_INT64, DT_UINT64})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_UINT8, DT_INT16, \ + DT_UINT16, DT_INT32, DT_UINT32, DT_INT64, DT_UINT64})) .OP_END_FACTORY_REG(Copy); /** @@ -859,6 +924,26 @@ REG_OP(Fingerprint) .INPUT(method, TensorType({DT_STRING})) .OUTPUT(y, TensorType({DT_UINT8})) .OP_END_FACTORY_REG(Fingerprint) + +/** +*@brief Change the shape of output according to the attr outShape +* + +*@par Inputs: +*x: A Tensor. + +*@par Outputs: +*y: A Tensor. Has the same type as "x".It's required and the value should equal to output_num. + +*@par Attributes: +*outShape: The shape of output will be inferred according to the attribute +*/ +REG_OP(TransShape) + .INPUT(x, TensorType::ALL()) + .OUTPUT(y, TensorType::ALL()) + .ATTR(outShape,ListInt ,{}) + .OP_END_FACTORY_REG(TransShape); + } // namespace ge #endif // GE_OP_ARRAY_OPS_H_ diff --git a/third_party/fwkacllib/inc/ops/condtake_ops.h b/third_party/fwkacllib/inc/ops/condtake_ops.h new file mode 100644 index 00000000..37d3b92a --- /dev/null +++ b/third_party/fwkacllib/inc/ops/condtake_ops.h @@ -0,0 +1,55 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef GE_OP_CONDTAKE_OPS_H_ +#define GE_OP_CONDTAKE_OPS_H_ + +#include "graph/operator_reg.h" +#include "graph/operator.h" + +namespace ge { +/** +*@brief Take elements from data if specific condition is satisfied on mask. + +*@par Inputs: +*@li data: input tensor from which to take elements, High-dimension input would \n +first be flattened. +*@li mask: condition param; must be the same shape with data. + +*@par Attributes: +*@li mode:convert by convert in Mode. +*@li val:convert by +*@li eps:convert by (default: 1e-06) + +*@par Outputs: +*@li out_data: the elements taken +*@li out_index: the indices corresponding to those elements +*@li valid_num: elements of out_data and out_index from zeros to valid_num is valid. +*/ + +REG_OP(CondTake) + .INPUT(data, TensorType({DT_FLOAT})) + .INPUT(mask, TensorType({DT_FLOAT})) + .OUTPUT(out_data, TensorType({DT_FLOAT})) + .OUTPUT(out_index, TensorType({DT_INT32})) + .OUTPUT(valid_num, TensorType({DT_INT32})) + .REQUIRED_ATTR(mode, String) + .REQUIRED_ATTR(val, Float) + .ATTR(eps, Float, 1e-06) + .OP_END_FACTORY_REG(CondTake) +} // namespace ge + +#endif // GE_OP_CONDTAKE_OPS_H_ diff --git a/third_party/fwkacllib/inc/ops/data_flow_ops.h b/third_party/fwkacllib/inc/ops/data_flow_ops.h index 08cbd1ff..dac7fb0b 100644 --- a/third_party/fwkacllib/inc/ops/data_flow_ops.h +++ b/third_party/fwkacllib/inc/ops/data_flow_ops.h @@ -19,6 +19,7 @@ #include #include "graph/operator_reg.h" +#include "graph/operator.h" namespace ge { @@ -259,7 +260,7 @@ match this name to the matching Unstage Op. 
REG_OP(Stage) .DYNAMIC_INPUT(values, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8, \ DT_INT16, DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_BOOL, \ - DT_DOUBLE})) + DT_DOUBLE, DT_UINT32, DT_UINT64})) .ATTR(capacity, Int, 0) .ATTR(memory_limit, Int, 0) .ATTR(container, String, "") @@ -312,7 +313,7 @@ REG_OP(StagePeek) .INPUT(index, TensorType({DT_INT32})) .DYNAMIC_OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8, DT_INT16, \ DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_BOOL, \ - DT_DOUBLE})) + DT_DOUBLE, DT_UINT32, DT_UINT64})) .ATTR(capacity, Int, 0) .ATTR(memory_limit, Int, 0) .ATTR(container, String, "") @@ -363,7 +364,7 @@ REG_OP(StackPop) .INPUT(handle, TensorType({DT_RESOURCE})) .OUTPUT(element, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8, DT_INT16, \ DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_BOOL, \ - DT_DOUBLE})) + DT_DOUBLE, DT_UINT32, DT_UINT64})) .REQUIRED_ATTR(elem_type, Type) .OP_END_FACTORY_REG(StackPop) @@ -388,10 +389,10 @@ REG_OP(StackPush) .INPUT(handle, TensorType({DT_RESOURCE})) .INPUT(element, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8, DT_INT16, \ DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_BOOL, \ - DT_DOUBLE})) + DT_DOUBLE, DT_UINT32, DT_UINT64})) .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8, DT_INT16, \ DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_BOOL, \ - DT_DOUBLE})) + DT_DOUBLE, DT_UINT32, DT_UINT64})) .ATTR(swap_memory, Bool, false) .OP_END_FACTORY_REG(StackPush) @@ -600,7 +601,7 @@ REG_OP(MapIncompleteSize) REG_OP(Unstage) .DYNAMIC_OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8, DT_INT16, \ DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_BOOL, \ - DT_DOUBLE})) + DT_DOUBLE, DT_UINT32, DT_UINT64})) .ATTR(capacity, Int, 0) .ATTR(memory_limit, Int, 0) .ATTR(container, String, "") @@ -1869,7 +1870,7 @@ REG_OP(SparseAccumulatorApplyGradient) .INPUT(local_step, TensorType({DT_INT64})) .INPUT(indices, TensorType({DT_INT64})) .INPUT(values, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_UINT16, \ - DT_INT32, DT_INT64, DT_DOUBLE, 
DT_FLOAT,DT_FLOAT16, DT_UINT32, \ + DT_INT32, DT_INT64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_UINT32, \ DT_UINT64, DT_COMPLEX64, DT_COMPLEX128, DT_QINT16, DT_QUINT16, \ DT_QINT8, DT_QUINT8, DT_QINT32})) .INPUT(shape, TensorType({DT_INT64})) @@ -1905,6 +1906,145 @@ REG_OP(SparseAccumulatorTakeGradient) .OUTPUT(shape, TensorType({DT_INT64})) .REQUIRED_ATTR(dtype, Type) .OP_END_FACTORY_REG(SparseAccumulatorTakeGradient) -} // namespace ge + +/** +*@brief A conditional accumulator for aggregating gradients. + +*@par Attributes: +* @li dtype: The type of the value being accumulated. +* @li shape: The shape of the values, can be [], in which case shape is unknown. +* @li container: If non-empty, this accumulator is placed in the given container. \n +Otherwise, a default container is used. +* @li shared_name: If non-empty, this accumulator will be shared under the given \n +name across multiple sessions. +* @li reduction_type: reduction operator type, default "MEAN". + +*@par Outputs: +*handle: A Tensor of type DT_RESOURCE. The handle to the accumulator. + +*@attention Constraints: +*ResourceConditionalAccumulator runs on the Ascend AI CPU, which delivers poor performance. + +*/ + +REG_OP(ResourceConditionalAccumulator) + .OUTPUT(handle, TensorType({DT_RESOURCE})) + .REQUIRED_ATTR(dtype, Type) + .REQUIRED_ATTR(shape, ListInt) + .ATTR(container, String, "") + .ATTR(shared_name, String, "") + .ATTR(reduction_type, String, "MEAN") + .OP_END_FACTORY_REG(ResourceConditionalAccumulator) + +/** +*@brief Applies a gradient to a given accumulator. \n +Does not add if "local_step" is lesser than the accumulator's "global_step". + +*@par Inputs: +* @li handle: The handle to an accumulator. +* @li local_step: The "local_step" value at which the gradient was computed. +* @li gradient: A tensor of the gradient to be accumulated. 
\n +Must be one of the following types: \n +DT_FLOAT16, DT_FLOAT, DT_DOUBLE + +*@attention Constraints: +*ResourceAccumulatorApplyGradient runs on the Ascend AI CPU, which delivers poor performance. + +*/ + +REG_OP(ResourceAccumulatorApplyGradient) + .INPUT(handle, TensorType({DT_RESOURCE})) + .INPUT(local_step, TensorType({DT_INT64})) + .INPUT(gradient, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) + .OP_END_FACTORY_REG(ResourceAccumulatorApplyGradient) + +/** +*@brief Returns the number of gradients aggregated in the given accumulators. + +*@par Inputs: +*handle: The handle to an accumulator. + +*@par Outputs: +*num_accumulated: The number of gradients aggregated in the given accumulator. + +*@attention Constraints: +*ResourceAccumulatorNumAccumulated runs on the Ascend AI CPU, which delivers poor performance. + +*/ + +REG_OP(ResourceAccumulatorNumAccumulated) + .INPUT(handle, TensorType({DT_RESOURCE})) + .OUTPUT(num_accumulated, TensorType({DT_INT32})) + .OP_END_FACTORY_REG(ResourceAccumulatorNumAccumulated) + +/** +*@brief Updates the accumulator with a new value for "global_step". + +*@par Inputs: +* @li handle: The handle to an accumulator. +* @li new_global_step: The new "global_step" value to set. + +*@attention Constraints: +*ResourceAccumulatorSetGlobalStep runs on the Ascend AI CPU, which delivers poor performance. + +*/ + +REG_OP(ResourceAccumulatorSetGlobalStep) + .INPUT(handle, TensorType({DT_RESOURCE})) + .INPUT(new_global_step, TensorType({DT_INT64})) + .OP_END_FACTORY_REG(ResourceAccumulatorSetGlobalStep) + +/** +*@brief Extracts the average gradient in the given ConditionalAccumulator. + +*@par Inputs: +* @li handle: The handle to an accumulator. +* @li num_required: Number of gradients required before an aggregate is returned. + +*@par Attributes: +*dtype: The data type of accumulated gradients. \n +Needs to correspond to the type of the accumulator. + +*@par Outputs: +*average: The average of the accumulated gradients. 
\n +Must be one of the following types: \n +DT_FLOAT16, DT_FLOAT, DT_DOUBLE. + +*@attention Constraints: +*ResourceAccumulatorTakeGradient runs on the Ascend AI CPU, which delivers poor performance. + +*/ + +REG_OP(ResourceAccumulatorTakeGradient) + .INPUT(handle, TensorType({DT_RESOURCE})) + .INPUT(num_required, TensorType({DT_INT32})) + .OUTPUT(average, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) + .REQUIRED_ATTR(dtype, Type) + .OP_END_FACTORY_REG(ResourceAccumulatorTakeGradient) + +/** +*@brief Enqueue a Tensor on the computation outfeed. + +*@par Inputs: +*Inputs include: \n +*x: A Tensor. Must be one of the following types: float16, float32, \n +float64, int8, int16, uint16, uint8, int32, int64, uint32, uint64, \n +bool, double, string. + +*@par Attributes: +*channel_name: name of operator channel, default "". + +*@attention Constraints:\n +*-The implementation for OutfeedEnqueueOp on Ascend uses AICPU, with bad performance.\n + +*/ +REG_OP(OutfeedEnqueueOp) + .DYNAMIC_INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, + DT_INT16, DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, + DT_UINT64, DT_BOOL, DT_DOUBLE, DT_STRING})) + .ATTR(channel_name, String, "") + .OP_END_FACTORY_REG(OutfeedEnqueueOp) + +} // namespace ge #endif // GE_OP_DATA_FLOW_OPS_H_ diff --git a/third_party/fwkacllib/inc/ops/elewise_calculation_ops.h b/third_party/fwkacllib/inc/ops/elewise_calculation_ops.h index d5272805..11475819 100644 --- a/third_party/fwkacllib/inc/ops/elewise_calculation_ops.h +++ b/third_party/fwkacllib/inc/ops/elewise_calculation_ops.h @@ -16,7 +16,7 @@ #ifndef GE_OP_ELEWISE_CALCULATION_OPS_H #define GE_OP_ELEWISE_CALCULATION_OPS_H -#include "../graph/operator_reg.h" +#include "graph/operator_reg.h" namespace ge { /** @@ -112,8 +112,7 @@ REG_OP(MinimumGrad) int64, uint64, int16, uint16, double, complex64, complex128, qint8, quint8, qint16, quint16, qint32. *@par Attributes: -*@li dst_type: An required attribute of type int32, specifying the dst data type. 
-*@li truncate: An optional attribute of type bool, specifying the src data type. Defaults to "false". +*dst_type: An required attribute of type int32, specifying the dst data type. *@par Outputs: *y:A `Tensor`. Has the same type as `x`. @@ -126,7 +125,6 @@ REG_OP(Cast) DT_INT64, DT_UINT64, DT_INT16, DT_UINT16, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128, DT_QINT8, DT_QUINT8, DT_QINT16, DT_QUINT16, DT_QINT32})) .REQUIRED_ATTR(dst_type, Int) - .ATTR(truncate, Bool, false) .OP_END_FACTORY_REG(Cast) /** @@ -886,7 +884,10 @@ REG_OP(BesselI1e) * y: A Tensor of type UnaryDataType. * @attention Constraints: -* @li base > 0 or if base is set to default (-1), base is set to e; +* @li "base" is supposed to be greater than 0. Retaining the default \n +* value "-1" sets "base" to "e". +* @li If the input value of operator Log is within the range (0, 0.01] or \n +* [0.95, 1.05], the output accuracy is subject to change. */ REG_OP(Log) .INPUT(x, TensorType::UnaryDataType()) @@ -911,7 +912,6 @@ REG_OP(Log) * uint8, int8, uint16, int16, int32, int64, complex64, complex128. * @attention Constraints: -* @li "x1" and "x2" have incompatible shapes or types. */ REG_OP(Mul) .INPUT(x1, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_UINT8, DT_INT8, @@ -1272,12 +1272,8 @@ REG_OP(Greater) * The output has the same shape and type as the input. */ REG_OP(ZerosLike) - .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT8, - DT_UINT8, DT_INT16, DI_UINT16, DT_INT32, - DT_INT64, DT_COMPLEX128, DT_BOOL})) - .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT8, - DT_UINT8, DT_INT16, DI_UINT16, DT_INT32, - DT_INT64, DT_COMPLEX128, DT_BOOL})) + .INPUT(x, TensorType::BasicType()) + .OUTPUT(y, TensorType::BasicType()) .OP_END_FACTORY_REG(ZerosLike) /** @@ -2056,6 +2052,7 @@ REG_OP(ArgMinWithValue) * "0": product, "1": sum, "2": max. *@li coeff: A required attribute. Must met all of following rules: * size of "coeff" must be equal to len("x") or is null. 
+* the absolute value of "coeff" must be less than or equal to 1. */ REG_OP(Eltwise) .DYNAMIC_INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT})) .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT})) .REQUIRED_ATTR(N, Int) .ATTR(mode, Int, 1) .ATTR(coeff, ListFloat, {}) .OP_END_FACTORY_REG(Eltwise) @@ -2526,9 +2523,102 @@ REG_OP(ArgMaxWithK) .OUTPUT(values, TensorType({DT_FLOAT, DT_FLOAT16})) .ATTR(axis, Int, 10000) .ATTR(out_max_val, Bool, false) - .ATTR(top_k, Int, 1) + .ATTR(topk, Int, 1) .OP_END_FACTORY_REG(ArgMaxWithK) +/** +*@brief Multiply tensor with scale. + +*@par Inputs: +*Five inputs, including: +* @li x1: A Tensor. Must be one of the following types:int32,int16, float16, float32. +* @li x2: A scale. Must be float. + +*@par Outputs: +*@li y: A Tensor. Has the same type and shape as "x1". + +*/ +REG_OP(Muls) + .INPUT(x, TensorType({DT_FLOAT,DT_INT16,DT_INT32,DT_FLOAT16})) + .OUTPUT(y, TensorType({DT_FLOAT,DT_INT16,DT_INT32,DT_FLOAT16})) + .REQUIRED_ATTR(value, Float) + .OP_END_FACTORY_REG(Muls) + +/** +*@brief Fill tensor with scale. + +*@par Inputs: +*Five inputs, including: +* @li x1: A Tensor. Must be one of the following types:int32,int16, float16, float32. +* @li x2: A scale. Must be float. + +*@par Outputs: +*@li y: A Tensor. Has the same type and shape as "x1". + +*/ +REG_OP(Fills) + .INPUT(x, TensorType({DT_FLOAT,DT_INT16,DT_INT32,DT_FLOAT16})) + .OUTPUT(y, TensorType({DT_FLOAT,DT_INT16,DT_INT32,DT_FLOAT16})) + .REQUIRED_ATTR(value,Float) + .OP_END_FACTORY_REG(Fills) + +/** +*@brief Add tensor with scale. + +*@par Inputs: +*Five inputs, including: +* @li x1: A Tensor. Must be one of the following types:int32,int16, float16, float32. +* @li x2: A scale. Must be float. + +*@par Outputs: +*@li y: A Tensor. Has the same type and shape as "x1". + +*/ + REG_OP(Adds) + .INPUT(x, TensorType({DT_FLOAT,DT_INT16,DT_INT32,DT_FLOAT16})) + .OUTPUT(y, TensorType({DT_FLOAT,DT_INT16,DT_INT32,DT_FLOAT16})) + .REQUIRED_ATTR(value,Float) + .OP_END_FACTORY_REG(Adds) + + REG_OP(MulNoNan) + .INPUT(x1, TensorType::NumberType()) /* "First operand." */ + .INPUT(x2, TensorType::NumberType()) /* "Second operand." 
*/ + .OUTPUT(y, TensorType::NumberType()) /* "Result, has same element type as two inputs" */ + .OP_END_FACTORY_REG(MulNoNan) + +REG_OP(Axpy) + .INPUT(x1, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16})) + .INPUT(x2, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16})) + .REQUIRED_ATTR(alpha, Float) + .OP_END_FACTORY_REG(Axpy) + +/** +*@brief Creates a criterion that measures the loss given input tensors x1 x2 and a Tensor label y with values 1 or -1. + +*@par Inputs: +*@li x1: A ND Tensor with one of the following types: int8, uint8, int32, float16, float32. +*@li x2: A ND Tensor with one of the following types: int8, uint8, int32, float16, float32. +*@li target: A ND Tensor with one of the following types: int8, int32, float16, float32. + +*@par Attributes: +*@li margin: A optional float32. Defaults to "0.0". +*@li reduction: A optional string. Defaults to "mean". + +*@par Outputs: +*@li y: A ND Tensor with Must be float32. +*/ +REG_OP(CosineEmbeddingLoss) + .INPUT(x1, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) + .INPUT(x2, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) + .INPUT(target, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) + .ATTR(margin, Float, 0) + .ATTR(reduction, String, "mean") + .OUTPUT(y, TensorType({DT_FLOAT})) + .OP_END_FACTORY_REG(CosineEmbeddingLoss) + } // namespace ge + + #endif // GE_OP_ELEWISE_CALCULATION_OPS_H diff --git a/third_party/fwkacllib/inc/ops/fsrdetectionoutput_ops.h b/third_party/fwkacllib/inc/ops/fsrdetectionoutput_ops.h deleted file mode 100644 index 2b3e206d..00000000 --- a/third_party/fwkacllib/inc/ops/fsrdetectionoutput_ops.h +++ /dev/null @@ -1,67 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in 
compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef GE_OP_FSRDETECTIONOUTPUT_OPS_H_ -#define GE_OP_FSRDETECTIONOUTPUT_OPS_H_ -#include "graph/operator_reg.h" - -namespace ge { -/** -*@brief Returns detection result. - -*@par Inputs: -* Four inputs, including: -*@li rois: An NCHW tensor of type floa16 or float32, output from operator proposal_d at the preceding layer, used as the input of operator FSRDetectionOutput. -*@li prior_box: An NCHWC0 tensor of type floa16 or float32, specifying the prediction offset, used to update the coordinates [x1, y1, x2, y2] of each ROI. -*@li score: An NCHWC0 tensor of type floa16 or float32, specifying the probability of each class. Class 0 is the background class. -*@li actual_rois_num: An NCHW tensor of type int32, specifying the number of valid boxes per batch. -*@par Attributes: -*@li batch_rois: An optional int32, specifying the number of images to be predicted. Defaults to "1024". The value range is [1, 1024]. -*@li im_info: An optional list of two ints. Defaults to (375, 1024). The value range is [1, 1024]. -*@li num_classes: An optional int32, specifying the number of classes to be predicted. Defaults to "80". The value must be greater than 0. -*@li max_rois_num: An optional int32, specifying the maximum number of ROIs per batch. Defaults to "1024". The value must be a multiple of 16. -*@li score_thresh: An optional float32, specifying the threshold for box filtering. Defaults to 0.45. The value range is [0.0, 1.0]. 
-*@li nms_thresh: An optional float32, specifying the confidence threshold for box filtering, which is the output "obj" of operator Region. Defaults to 0.7. The value range is (0.0, 1.0). -*@li bbox_reg_weights: An optional list of four ints. Defaults to (1, 1, 1, 1). Must not have value "0". -*@li post_nms_topn: An optional int, specifying the number of output boxes. Defaults to "304". The value must be less than or equal to 1024 and must be a multiple of 16. -*@li kernel_name: An optional string, specifying the operator name. Defaults to "fsr_detection_output". -*@par Outputs: -*box: An NCHW tensor of type float16, describing the information of each output box, including the coordinates, class, and confidence. -*actual_bbox_num: An NCHW tensor of type int32, specifying the number of output boxes. - -*@attention Constraints:\n -*@li totalnum < max_rois_num * batch_rois. -*@li "score" must be with shape (total_num, (num_classes+15)//16, 1, 1, 16), where "total_num" indicates the number of valid input boxes of all images. -*@li "prior_box" must be with shape (total_num, (num_classes*4+15)//16, 1, 1, 16), where "total_num" indicates the number of valid input boxes of all images. 
-*/ -REG_OP(FSRDetectionOutput) - .INPUT(rois, TensorType({DT_FLOAT, DT_FLOAT16})) - .INPUT(prior_box, TensorType({DT_FLOAT, DT_FLOAT16})) - .INPUT(score, TensorType({DT_FLOAT, DT_FLOAT16})) - .INPUT(actual_rois_num, TensorType({DT_INT32})) - .OUTPUT(actual_bbox_num, TensorType({DT_INT32})) - .OUTPUT(box, TensorType({DT_FLOAT, DT_FLOAT16})) - .ATTR(batch_rois, Int, 1024) - .ATTR(im_info, ListInt, {375,1024}) - .ATTR(num_classes, Int, 80) - .ATTR(max_rois_num, Int, 1024) - .ATTR(score_thresh, Float, 0.45) - .ATTR(nms_thresh, Float, 0.7) - .ATTR(bbox_reg_weights, ListInt, {1,1,1,1}) - .ATTR(post_nms_topn, Int, 304) - .OP_END_FACTORY_REG(FSRDetectionOutput) -} -#endif diff --git a/third_party/fwkacllib/inc/ops/functional_ops.h b/third_party/fwkacllib/inc/ops/functional_ops.h index 1529d45c..ea15dba8 100644 --- a/third_party/fwkacllib/inc/ops/functional_ops.h +++ b/third_party/fwkacllib/inc/ops/functional_ops.h @@ -19,9 +19,107 @@ #include "graph/operator_reg.h" #include "graph/operator.h" +#include "graph/ge_attr_value.h" namespace ge { +REG_OP(SymbolicGradient) + .DYNAMIC_INPUT(input, TensorType::ALL()) + .DYNAMIC_OUTPUT(output, TensorType::ALL()) + .GRAPH(f) + .OP_END_FACTORY_REG(SymbolicGradient) +REG_OP(RemoteCall) + .INPUT(target, DT_STRING) + .DYNAMIC_INPUT(args, TensorType::ALL()) + .DYNAMIC_OUTPUT(output, TensorType::ALL()) + .GRAPH(f) + .OP_END_FACTORY_REG(RemoteCall) + +REG_OP(_If) + .INPUT(cond, TensorType::ALL()) + .DYNAMIC_INPUT(input, TensorType::ALL()) + .DYNAMIC_OUTPUT(output, TensorType::ALL()) + .GRAPH(then_branch) + .GRAPH(else_branch) + .OP_END_FACTORY_REG(_If) + +REG_OP(StatelessIf) + .INPUT(cond, TensorType::ALL()) + .DYNAMIC_INPUT(input, TensorType::ALL()) + .DYNAMIC_OUTPUT(output, TensorType::ALL()) + .GRAPH(then_branch) + .GRAPH(else_branch) + .OP_END_FACTORY_REG(StatelessIf) + +REG_OP(If) + .INPUT(cond, TensorType::ALL()) + .DYNAMIC_INPUT(input, TensorType::ALL()) + .DYNAMIC_OUTPUT(output, TensorType::ALL()) + .GRAPH(then_branch) + 
.GRAPH(else_branch) + .OP_END_FACTORY_REG(If) + +REG_OP(Case) + .INPUT(branch_index, DT_INT32) + .DYNAMIC_INPUT(input, TensorType::ALL()) + .DYNAMIC_OUTPUT(output, TensorType::ALL()) + .DYNAMIC_GRAPH(branches) + .OP_END_FACTORY_REG(Case) + +REG_OP(_While) + .DYNAMIC_INPUT(input, TensorType::ALL()) + .DYNAMIC_OUTPUT(output, TensorType::ALL()) + .GRAPH(cond) + .GRAPH(body) + .OP_END_FACTORY_REG(_While) + +REG_OP(While) + .DYNAMIC_INPUT(input, TensorType::ALL()) + .DYNAMIC_OUTPUT(output, TensorType::ALL()) + .GRAPH(cond) + .GRAPH(body) + .ATTR(parallel_iterations, Int, 10) + .OP_END_FACTORY_REG(While) + +REG_OP(StatelessWhile) + .DYNAMIC_INPUT(input, TensorType::ALL()) + .DYNAMIC_OUTPUT(output, TensorType::ALL()) + .GRAPH(cond) + .GRAPH(body) + .ATTR(parallel_iterations, Int, 10) + .OP_END_FACTORY_REG(StatelessWhile) + +REG_OP(For) + .INPUT(start, DT_INT32) + .INPUT(limit, DT_INT32) + .INPUT(delta, DT_INT32) + .DYNAMIC_INPUT(input, TensorType::ALL()) + .DYNAMIC_OUTPUT(output, TensorType::ALL()) + .GRAPH(body) + .OP_END_FACTORY_REG(For) + +REG_OP(PartitionedCall) + .DYNAMIC_INPUT(args, TensorType::ALL()) + .DYNAMIC_OUTPUT(output, TensorType::ALL()) + .GRAPH(f) + .ATTR(config, String, "") + .ATTR(config_proto, String, "") + .ATTR(executor_type, String, "") + .OP_END_FACTORY_REG(PartitionedCall) + +REG_OP(StatefulPartitionedCall) + .DYNAMIC_INPUT(args, TensorType::ALL()) + .DYNAMIC_OUTPUT(output, TensorType::ALL()) + .GRAPH(f) + .ATTR(config, String, "") + .ATTR(config_proto, String, "") + .ATTR(executor_type, String, "") + .OP_END_FACTORY_REG(StatefulPartitionedCall) + +REG_OP(FakeParam) + .OUTPUT(output, TensorType::ALL()) + .ATTR(shape, ListInt, {}) + .OP_END_FACTORY_REG(FakeParam) } // namespace ge diff --git a/third_party/fwkacllib/inc/ops/hcom_ops.h b/third_party/fwkacllib/inc/ops/hcom_ops.h index 598f3d11..5a69ed80 100644 --- a/third_party/fwkacllib/inc/ops/hcom_ops.h +++ b/third_party/fwkacllib/inc/ops/hcom_ops.h @@ -57,7 +57,9 @@ REG_OP(HcomAllGather) * @li 
group: A required string identifying the group name of ranks * participating in the op. * @li fusion: An optional integer identifying the fusion flag of the op. \n - * 0: no fusion; 1 (default): fusion. + * 0: no fusion; 1 (default): fusion; 2: fusion the ops by fusion id. + * @li fusion_id: An optional integer identifying the fusion id of the op. + * The HcomAllReduce ops with the same fusion id will be fused. * @par Outputs: * y: A Tensor. Has the same type as "x". * @attention Constraints: \n @@ -70,6 +72,7 @@ REG_OP(HcomAllReduce) .REQUIRED_ATTR(reduction, String) .REQUIRED_ATTR(group, String) .ATTR(fusion, Int, 1) + .ATTR(fusion_id, Int, -1) .ATTR(alpha, Float, 1.0) .ATTR(beta, Float, 0.0) .OP_END_FACTORY_REG(HcomAllReduce) diff --git a/third_party/fwkacllib/inc/ops/image_ops.h b/third_party/fwkacllib/inc/ops/image_ops.h index 2ac7a70e..aaad03c6 100644 --- a/third_party/fwkacllib/inc/ops/image_ops.h +++ b/third_party/fwkacllib/inc/ops/image_ops.h @@ -525,8 +525,7 @@ REG_OP(ResizeBilinearV2) .INPUT(x, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_UINT16, DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) .INPUT(size, TensorType({DT_INT32})) - .OUTPUT(y, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_UINT16, - DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) + .OUTPUT(y, TensorType({DT_FLOAT})) .ATTR(align_corners, Bool, false) .ATTR(half_pixel_centers, Bool, false) .OP_END_FACTORY_REG(ResizeBilinearV2) @@ -925,7 +924,7 @@ images[3] <= 2048. 
*/ REG_OP(ResizeBilinearV2D) .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT})) - .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT})) + .OUTPUT(y, TensorType({DT_FLOAT})) .ATTR(align_corners, Bool, false) .ATTR(half_pixel_centers, Bool, false) .REQUIRED_ATTR(size, ListInt) diff --git a/third_party/fwkacllib/inc/ops/linalg_ops.h b/third_party/fwkacllib/inc/ops/linalg_ops.h index b8a12950..985af4eb 100644 --- a/third_party/fwkacllib/inc/ops/linalg_ops.h +++ b/third_party/fwkacllib/inc/ops/linalg_ops.h @@ -18,7 +18,7 @@ #define GE_OP_LINALG_OPS_H_ #include "graph/operator_reg.h" -#include "../graph/operator.h" +#include "graph/operator.h" namespace ge { diff --git a/third_party/fwkacllib/inc/ops/math_ops.h b/third_party/fwkacllib/inc/ops/math_ops.h index aa318c94..b75991e2 100644 --- a/third_party/fwkacllib/inc/ops/math_ops.h +++ b/third_party/fwkacllib/inc/ops/math_ops.h @@ -22,6 +22,29 @@ namespace ge { +/** +*@brief Computes the output as (shift + scale * x) ^ power. + +*@par Inputs: +* x: A Tensor of type float16 or float32. + +*@par Attributes: +*@li power: Optional. Defaults to 1.0. +*@li scale: Optional. Defaults to 1.0. +*@li shift: Optional. Defaults to 0.0. + +*@par Outputs: +* y: A Tensor. Has the same type and shape as "x". +*/ + +REG_OP(Power) + .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT})) + .ATTR(power, Float, 1.0) + .ATTR(scale, Float, 1.0) + .ATTR(shift, Float, 0.0) + .OP_END_FACTORY_REG(Power); + /** *@brief Compute the lower regularized incomplete Gamma function P(a, x). @@ -420,11 +443,127 @@ REG_OP(NextAfter) * * */ REG_OP(IsFinite) - .INPUT(x, TensorType({DT_INT8, DT_INT16, DT_INT32, DT_INT64, - DT_UINT8, DT_UINT16, DT_UINT32, DT_UINT64, - DT_FLOAT, DT_FLOAT16, DT_DOUBLE, DT_BOOL})) + .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) .OUTPUT(y, TensorType({DT_BOOL})) .OP_END_FACTORY_REG(IsFinite) + +/** + * *@brief Computes the complex absolute value of a tensor. + * + * *@par Inputs: + * *x:A Tensor. 
+ * + * *@par Outputs: + * *y:A tensor of type `float` or `double` that is the absolute value of each element in `x`. + * + * */ +REG_OP(ComplexAbs) + .INPUT(x, TensorType({DT_COMPLEX64, DT_COMPLEX128})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_DOUBLE})) + .ATTR(Tout, Type, DT_FLOAT) + .OP_END_FACTORY_REG(ComplexAbs) + +/** + * *@brief Returns which elements of x are NaN. + * + * *@par Inputs: + * *x:A Tensor. + * + * *@par Outputs: + * *y:A Tensor. Has the same shape as x. + * + * */ +REG_OP(IsNan) + .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) + .OUTPUT(y, TensorType({DT_BOOL})) + .OP_END_FACTORY_REG(IsNan) + +/** + * *@brief Returns the real part of a complex number. + * + * *@par Inputs: + * *input:A Tensor. + * + * *@par Outputs: + * *output:A Tensor. Has the same shape as input. + * + * */ +REG_OP(Real) + .INPUT(input, TensorType({DT_COMPLEX64, DT_COMPLEX128})) + .OUTPUT(output, TensorType({DT_FLOAT, DT_DOUBLE})) + .ATTR(Tout, Type, DT_FLOAT) + .OP_END_FACTORY_REG(Real) + +/** + * *@brief Returns the complex conjugate of a complex number. + * + * *@par Inputs: + * *input:A Tensor. + * + * *@par Outputs: + * *output:A Tensor. Has the same shape as input. + * + * */ +REG_OP(Conj) + .INPUT(input, TensorType({DT_COMPLEX64, DT_COMPLEX128})) + .OUTPUT(output, TensorType({DT_COMPLEX64, DT_COMPLEX128})) + .OP_END_FACTORY_REG(Conj) + +/** + * *@brief The negative log likelihood loss. + * + * *@par Inputs: + * *The input x and weight must have the same type. Inputs include: \n + * *@li x:A Tensor. Must be the type: float32. + * *@li target:A Tensor. Must be the type: int32. + * *@li weight:A Tensor. Must be the type: float32. + * + * *@par Attributes: + * *@li reduction: An optional attribute. Defaults to "mean". + * + * *@par Outputs: + * *Two outputs, including: + * *@li y: A Tensor. Must be the following type: float32. + * *@li total_weight: A Tensor. Must be the type: float32. 
+ * + * */ +REG_OP(NLLLoss) + .INPUT(x, TensorType({DT_FLOAT})) + .INPUT(target, TensorType({DT_INT32})) + .INPUT(weight, TensorType({DT_FLOAT})) + .OUTPUT(y, TensorType({DT_FLOAT})) + .OUTPUT(total_weight, TensorType({DT_FLOAT})) + .ATTR(reduction, String, "mean") + .OP_END_FACTORY_REG(NLLLoss) + +/** + * *@brief The negative log likelihood loss grad. + + * *@par Inputs: + * *Inputs include: + * *@li x:A Tensor. Must be the type: float32. + * *@li y_grad:A Tensor. Must be the type: float32. + * *@li target:A Tensor. Must be the type: int32. + * *@li weight:A Tensor. Must be the type: float32. + * *@li total_weight:A Tensor. Must be the type: float32. + * + * *@par Attributes: + * *@li reduction: An optional attribute. Defaults to "mean". + * + * *@par Outputs: + * *One outputs, including: + * *@li x_grad: A Tensor. Must be the following type: float32. + * + * */ +REG_OP(NLLLossGrad) + .INPUT(x, TensorType({DT_FLOAT})) + .INPUT(y_grad, TensorType({DT_FLOAT})) + .INPUT(target, TensorType({DT_INT32})) + .INPUT(weight, TensorType({DT_FLOAT})) + .INPUT(total_weight, TensorType({DT_FLOAT})) + .OUTPUT(x_grad, TensorType({DT_FLOAT})) + .ATTR(reduction, String, "mean") + .OP_END_FACTORY_REG(NLLLossGrad) } // namespace ge #endif // GE_OP_MATH_OPS_H_ diff --git a/third_party/fwkacllib/inc/ops/matrix_calculation_ops.h b/third_party/fwkacllib/inc/ops/matrix_calculation_ops.h index 597a8982..4f0f4557 100644 --- a/third_party/fwkacllib/inc/ops/matrix_calculation_ops.h +++ b/third_party/fwkacllib/inc/ops/matrix_calculation_ops.h @@ -17,7 +17,7 @@ #ifndef GE_OP_MATRIX_CALCULATION_OPS_H #define GE_OP_MATRIX_CALCULATION_OPS_H -#include "../graph/operator_reg.h" +#include "graph/operator_reg.h" namespace ge { @@ -302,6 +302,32 @@ REG_OP(ScatterNdUpdate) .ATTR(use_locking, Bool, false) .OP_END_FACTORY_REG(ScatterNdUpdate) +/** +*@brief Applies sparse addition to individual values or slices in a Variable. + +*@par Inputs: +* Three inputs, including: +*@li x: An ND Tensor. 
\n + +*Must be one of the following types: float16, float32, int32, int8, uint8 +*@li indices: An ND Tensor. \n + +*Must be one of the following types: int32 +*@li updates: An ND Tensor. \n + +*Must be one of the following types: float16, float32, int32, int8, uint8 + +*@par Outputs: +*y: A Tensor. Has the same type and format as input "x". + +*/ +REG_OP(TensorScatterUpdate) + .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8})) + .INPUT(indices, TensorType::IndexNumberType()) + .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8})) + .OP_END_FACTORY_REG(TensorScatterUpdate) + /** *@brief Adds sparse "updates" to a variable reference. @@ -393,6 +419,32 @@ REG_OP(ScatterNdAdd) .ATTR(use_locking, Bool, false) .OP_END_FACTORY_REG(ScatterNdAdd) +/** +*@brief Applies sparse addition to individual values or slices in a Variable. + +*@par Inputs: +* Three inputs, including: +*@li x: An ND Tensor. \n + +*Must be one of the following types: float16, float32, int32, int8, uint8 +*@li indices: An ND Tensor. \n + +*Must be one of the following types: int32 +*@li updates: An ND Tensor. \n + +*Must be one of the following types: float16, float32, int32, int8, uint8 + +*@par Outputs: +*y: A Tensor. Has the same type and format as input "x". + +*/ +REG_OP(TensorScatterAdd) + .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8})) + .INPUT(indices, TensorType::IndexNumberType()) + .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8})) + .OP_END_FACTORY_REG(TensorScatterAdd) + /** *@brief Applies sparse subtraction to individual values or slices in a Variable. @@ -423,6 +475,32 @@ REG_OP(ScatterNdSub) .ATTR(use_locking, Bool, false) .OP_END_FACTORY_REG(ScatterNdSub) +/** +*@brief Applies sparse addition to individual values or slices in a Variable. 
+ +*@par Inputs: +* Three inputs, including: +*@li x: An ND Tensor. \n + +*Must be one of the following types: float16, float32, int32, int8, uint8 +*@li indices: An ND Tensor. \n + +*Must be one of the following types: int32 +*@li updates: An ND Tensor. \n + +*Must be one of the following types: float16, float32, int32, int8, uint8 + +*@par Outputs: +*y: A Tensor. Has the same type and format as input "x". + +*/ +REG_OP(TensorScatterSub) + .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8})) + .INPUT(indices, TensorType::IndexNumberType()) + .INPUT(updates, TensorType({DT_FLOAT16, DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT,DT_INT32,DT_INT8,DT_UINT8})) + .OP_END_FACTORY_REG(TensorScatterSub) + /** *@brief Subtracts sparse updates to a variable reference. @@ -492,34 +570,35 @@ REG_OP(DiagPart) *@brief Also known as a "fully-connected" layer, computes an inner product with a set of learned weights, and (optionally) adds biases. *@par Inputs: -* Two inputs, including: +* Four inputs, including: *@li x: A Tensor of type float16, int8. *@li w: A weight matrix of type float16, int8. -*@li b: A Tensor of type float16, int32. +*@li b: A Tensor of type float16, int32, float32. *@li offset_w: A Tensor of type int8. *@par Attributes: *@li num_output: Reserved. *@li transpose: A bool, specifying whether to transpose, either "true" or "false". Defaults to "false". -*@li bias_term: A bool, specifying whether to learn and apply a set of additive biases to the filter outputs, either "true" or "false". Defaults to "true". -*@li axis: only support axis is 1. Defaults to "1". -*@li offset_a: A type of Int, Defaults to "1". +*@li axis: Reserved. +*@li offset_x: Reserved. *@par Outputs: -*y: The result tensor of type float16, int8. +*y: The result tensor of type float16, int8, float32. 
+ +*@par Quantization supported or not +* Yes */ -REG_OP(InnerProduct) +REG_OP(FullyConnection) .INPUT(x, TensorType({DT_FLOAT16, DT_INT8})) .INPUT(w, TensorType({DT_FLOAT16, DT_INT8})) - .OPTIONAL_INPUT(b, TensorType({DT_FLOAT16, DT_INT32})) + .OPTIONAL_INPUT(b, TensorType({DT_FLOAT16, DT_INT32,DT_FLOAT32})) .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8})) - .OUTPUT(y, TensorType({DT_FLOAT16, DT_INT32})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_INT32,DT_FLOAT32})) .REQUIRED_ATTR(num_output, Int) .ATTR(transpose, Bool, false) - .ATTR(bias_term, Bool, true) .ATTR(axis, Int, 1) - .ATTR(offset_a, Int, 0) - .OP_END_FACTORY_REG(InnerProduct) + .ATTR(offset_x, Int, 0) + .OP_END_FACTORY_REG(FullyConnection) /** *@brief Computes the confusion matrix from predictions and labels. @@ -673,6 +752,96 @@ REG_OP(ScatterUpdate) .ATTR(use_locking, Bool, false) .OP_END_FACTORY_REG(ScatterUpdate) +/** +*@brief Returns a tensor with the `k[0]`-th to `k[1]`-th diagonals of the batched `input`. + +*@par Inputs: +* Three inputs, including: +*@li input: Rank `r` tensor where `r >= 2`. \n + +*@li k: \n +*Diagonal offset(s). Positive value means superdiagonal, 0 refers to the main \n +*diagonal, and negative value means subdiagonals. `k` can be a single integer \n +*(for a single diagonal) or a pair of integers specifying the low and high ends \n +*of a matrix band. `k[0]` must not be larger than `k[1]`. \n + +*@li padding_value: The value to fill the area outside the specified diagonal band with. \n + +*@par Outputs: +*diagonal: The extracted diagonal(s). + +*/ +REG_OP(MatrixDiagPartV2) + .INPUT(input, TensorType::BasicType()) + .INPUT(k, TensorType({DT_INT32})) + .INPUT(padding_value, TensorType::BasicType()) + .OUTPUT(diagonal, TensorType::BasicType()) + .OP_END_FACTORY_REG(MatrixDiagPartV2) + +/** +*@brief Returns a batched matrix tensor with new batched diagonal values. + +*@par Inputs: +* Three inputs, including: +*@li input: "Rank `r+1`, where `r >= 1`. 
\n + +*@li diagonal: Rank `r` when `k` is an integer or `k[0] == k[1]`. Otherwise, it has rank `r+1`. \n + +*@li k: +*Diagonal offset(s). Positive value means superdiagonal, 0 refers to the main \n +*diagonal, and negative value means subdiagonals. `k` can be a single integer \n +*(for a single diagonal) or a pair of integers specifying the low and high ends \n +*of a matrix band. `k[0]` must not be larger than `k[1]`. \n + +*@par Outputs: +*output: Rank `r+1`, with `output.shape = input.shape`. + +*/ +REG_OP(MatrixSetDiagV2) + .INPUT(input, TensorType::BasicType()) + .INPUT(diagonal, TensorType::BasicType()) + .INPUT(k, TensorType({DT_INT32})) + .OUTPUT(output, TensorType::BasicType()) + .OP_END_FACTORY_REG(MatrixSetDiagV2) + +/** +*@brief Returns a batched diagonal tensor with given batched diagonal values. + +*@par Inputs: +* Five inputs, including: +*@li diagonal: Rank `r`, where `r >= 1` \n + +*@li k: +*Diagonal offset(s). Positive value means superdiagonal, 0 refers to the main \n +*diagonal, and negative value means subdiagonals. `k` can be a single integer \n +*(for a single diagonal) or a pair of integers specifying the low and high ends \n +*of a matrix band. `k[0]` must not be larger than `k[1]`. \n + +*@li num_rows: +*The number of rows of the output matrix. If it is not provided, the op assumes \n +*the output matrix is a square matrix and infers the matrix size from k and the \n +*innermost dimension of `diagonal`. \n + +*@li num_cols: An NCHW, NHWC, or ND Tensor. +*The number of columns of the output matrix. If it is not provided, the op \n +*assumes the output matrix is a square matrix and infers the matrix size from \n +*k and the innermost dimension of `diagonal`. \n + +*@li padding_value: The number to fill the area outside the specified diagonal band with. \n + +*@par Outputs: +*output: Has rank `r+1` when `k` is an integer or `k[0] == k[1]`, rank `r` otherwise. 
+ +*/ +REG_OP(MatrixDiagV2) + .INPUT(diagonal, TensorType::BasicType()) + .INPUT(k, TensorType({DT_INT32})) + .INPUT(num_rows, TensorType({DT_INT32})) + .INPUT(num_cols, TensorType({DT_INT32})) + .INPUT(padding_value, TensorType::BasicType()) + .OUTPUT(output, TensorType::BasicType()) + .OP_END_FACTORY_REG(MatrixDiagV2) + } // namespace ge #endif // GE_OP_MATRIX_CALCULATION_OPS_H diff --git a/third_party/fwkacllib/inc/ops/nn_batch_norm_ops.h b/third_party/fwkacllib/inc/ops/nn_batch_norm_ops.h index 4b5c5f23..0a1337c0 100644 --- a/third_party/fwkacllib/inc/ops/nn_batch_norm_ops.h +++ b/third_party/fwkacllib/inc/ops/nn_batch_norm_ops.h @@ -17,7 +17,7 @@ #ifndef GE_OP_NN_BATCH_NORM_OPS_H #define GE_OP_NN_BATCH_NORM_OPS_H -#include "../graph/operator_reg.h" +#include "graph/operator_reg.h" namespace ge { diff --git a/third_party/fwkacllib/inc/ops/nn_calculation_ops.h b/third_party/fwkacllib/inc/ops/nn_calculation_ops.h index 1be85a0e..b7b55fb0 100644 --- a/third_party/fwkacllib/inc/ops/nn_calculation_ops.h +++ b/third_party/fwkacllib/inc/ops/nn_calculation_ops.h @@ -17,7 +17,7 @@ #ifndef GE_OP_NN_CALCULATION_OPS_H #define GE_OP_NN_CALCULATION_OPS_H -#include "../graph/operator_reg.h" +#include "graph/operator_reg.h" namespace ge { /** @@ -62,7 +62,7 @@ namespace ge { * data is 5D with shape [N, C1, Ho, Wo, C0], * where C is the same as that of the feature map and C0 is 16.\n * Limited by Tiling and L1 / L0 buffer memory: 512 * ceil(Wo, 16) + (480 * -* stride_h + 32 * filter_h) * ceil(Wi, 16) â‰?l1_size and Hf*Wf â‰?l0b_size/512.\n +* stride_h + 32 * filter_h) * ceil(Wi, 16) �?l1_size and Hf*Wf �?l0b_size/512.\n */ REG_OP(DepthwiseConv2DBackpropFilter) .INPUT(input, TensorType({float16})) @@ -115,7 +115,7 @@ REG_OP(DepthwiseConv2DBackpropFilter) * data is 5D with shape [N, C1, Ho, Wo, C0], * where C is the same as that of the feature map and C0 is 16.\n * Limited by Tiling and L1 / L0 buffer memory: 512 * ceil(Wo, 16) + (480 * -* stride_h + 32 * filter_h) * ceil(Wi, 16) 
â‰?l1_size and Hf*Wf â‰?l0b_size/512.\n +* stride_h + 32 * filter_h) * ceil(Wi, 16) �?l1_size and Hf*Wf �?l0b_size/512.\n */ REG_OP(DepthwiseConv2DBackpropFilterD) .INPUT(input, TensorType({float16})) @@ -170,7 +170,7 @@ REG_OP(DepthwiseConv2DBackpropFilterD) * Output backprop is 4D with shape [N, C, Ho, Wo] or [N, Ho, Wo, C], but the * data is 5D with shape [N, C1, Ho, Wo, C0], * where C is the same as that of the feature map and C0 is 16.\n -* Limited by Tiling: max_h_in_l1 â‰?C0, where max_h_in_l1 = (l1_size - Hf * +* Limited by Tiling: max_h_in_l1 �?C0, where max_h_in_l1 = (l1_size - Hf * * Wf * C0 * C0 * 2) / (2 * Wo *C0).\n */ REG_OP(DepthwiseConv2DBackpropInput) @@ -223,7 +223,7 @@ REG_OP(DepthwiseConv2DBackpropInput) * Output backprop is 4D with shape [N, C, Ho, Wo] or [N, Ho, Wo, C], but the * data is 5D with shape [N, C1, Ho, Wo, C0], * where C is the same as that of the feature map and C0 is 16.\n -* Limited by Tiling: max_h_in_l1 â‰?C0, where max_h_in_l1 = (l1_size - Hf * +* Limited by Tiling: max_h_in_l1 �?C0, where max_h_in_l1 = (l1_size - Hf * * Wf * C0 * C0 * 2) / (2 * Wo *C0).\n */ REG_OP(DepthwiseConv2DBackpropInputD) @@ -276,13 +276,16 @@ REG_OP(DepthwiseConv2DBackpropInputD) * Limited by the size of L1 buffer memory: \n * (l1_size - filter_h*filter_w*BLOCK_SIZE*BLOCK_SIZE*data_size) // (Wi * * BLOCK_SIZE * data_size) >= (BLOCK_SIZE * strides_h + filter_h - strides_h).\n + +* @par Quantization supported or not +* Yes */ REG_OP(DepthwiseConv2D) - .INPUT(x, TensorType({DT_FLOAT16})) - .INPUT(filter, TensorType({DT_FLOAT16})) - .OPTIONAL_INPUT(bias, TensorType({DT_INT8})) - .OPTIONAL_INPUT(offset_w, TensorType({DT_FLOAT16})) - .OUTPUT(y, TensorType({DT_FLOAT16})) + .INPUT(x, TensorType({DT_FLOAT16, DT_INT8})) + .INPUT(filter, TensorType({DT_FLOAT16, DT_INT8})) + .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT16, DT_INT32})) + .OPTIONAL_INPUT(offset_w, TensorType({DT_FLOAT16, DT_INT8})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_INT32})) 
.REQUIRED_ATTR(strides, ListInt) .ATTR(dilations, ListInt, {1, 1, 1, 1}) .REQUIRED_ATTR(pads, ListInt) @@ -439,13 +442,17 @@ REG_OP(Conv2DBackpropInputD) * One optional input: * @li bias: An optional tensor of type int8 *@par Attributes: - * Three attributes: + * Five attributes: * @li strides: A tuple or list of 2 integers. The stride of the sliding window * for H/W dimension. * @li pads: A tuple or list of 4 integers. The [top, bottom, left, right] * padding on the feature map * @li dilations: A tuple or list of 4 integers. The dilation factor for each * dimension of input. Must be [1, 1, 1, 1]. + * @li groups: Number of blocked connections from input channels to \n + output channels. + * @li data_format: An optional string from: "NHWC", "NCHW". Defaults to "NHWC".\n + Specify the data format of the input and output data. *@par Outputs: * y: A Tensor. Has the same type as "filter". 4D tensor with shape * [batch, height, width, channels] or [batch, channels, height, width]. @@ -458,6 +465,8 @@ REG_OP(Deconvolution) .ATTR(strides, ListInt, {1, 1, 1, 1}) .ATTR(pads, ListInt, {0, 0, 0, 0}) .ATTR(dilations, ListInt, {1, 1, 1, 1}) + .ATTR(groups, Int, 1) + .ATTR(data_format, String, "NHWC") .OP_END_FACTORY_REG(Deconvolution) /** *@brief Computes the gradients of convolution with respect to the filter @@ -606,6 +615,8 @@ REG_OP(Conv2DBackpropFilterD) * As shown above, "HxW(input)" indicates the image size after padding and * "HxW(filter)" indicates the filter size after dilation. 
+*@par Quantization supported or not +* Yes */ REG_OP(Conv2D) .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT8})) @@ -621,6 +632,21 @@ REG_OP(Conv2D) .ATTR(offset_x, Int, 0) .OP_END_FACTORY_REG(Conv2D) +REG_OP(Conv2DCompress) + .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT8})) + .INPUT(filter_compress, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT8})) + .INPUT(compress_index, TensorType({DT_INT8})) + .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32})) + .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32})) + .REQUIRED_ATTR(strides, ListInt) + .REQUIRED_ATTR(pads, ListInt) + .ATTR(dilations, ListInt, {1, 1, 1, 1}) + .ATTR(groups, Int, 1) + .ATTR(data_format, String, "NHWC") + .ATTR(offset_x, Int, 0) + .OP_END_FACTORY_REG(Conv2DCompress) + /** *@brief Computes a 3D convolution given 5D "x" and "filter" tensors. *@par Inputs: @@ -631,7 +657,6 @@ REG_OP(Conv2D) *@par Attributes: *@li strides: A list of 5 ints. Specifies the stride of the sliding window for each dimension of "x". The N and C dimensions must be 1. Has the same format as "x". *@li pads: A list of 6 ints. Supports only padding along the D, H and W dimensions in sequence of head, tail, top, bottom, left and right. -*@li padding_mode: An optional string from: "zeros", "circular". Defaults to "zeros". *@li data_format: An optional string from: "NDHWC", "NCDHW". Defaults to "NDHWC". Specify the data format of the input and output data. *@li dilations: A list of 5 ints. Specifies the dilation factor for each dimension of "x". The N and C dimensions must be 1. Has the same format as "x". 
@@ -649,7 +674,6 @@ REG_OP(Conv3D) .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) .ATTR(strides, ListInt, {1, 1, 1, 1, 1}) .ATTR(pads, ListInt, {0, 0, 0, 0, 0, 0}) - .ATTR(padding_mode, String, "zeros") .ATTR(data_format, String, "NDHWC") .ATTR(dilations, ListInt, {1, 1, 1, 1, 1}) .OP_END_FACTORY_REG(Conv3D) @@ -658,9 +682,9 @@ REG_OP(Conv3D) *@brief Computes the gradients of convolution 3d with respect to the input. *@par Inputs: * Three inputs: - * @li input_sizes: A Tensor of type int32, int64. An integer vector representing the shape of input, + * @li input_size: A Tensor of type int32, int64. An integer vector representing the shape of input, * where input is a 5-D tensor [batch, depth, height, width, channels] or [batch, channels, depth, height, width]. - * @li filters: A Tensor. Must be one of the following types: float16, float32, float64. + * @li filter: A Tensor. Must be one of the following types: float16, float32, float64. * @li grads: A Tensor. Must have the same type as filter. 5-D with shape [batch, depth, out_height, out_width, out_channels] * or [batch, out_channels, depth, out_height, out_width]. Gradients with respect to the output of the convolution. *@par Attributes: @@ -671,10 +695,10 @@ REG_OP(Conv3D) * @li data_format: An optional string from: "NDHWC", "NCHWD". Defaults to "NDHWC". Specify the data format of the input and output data. *@par Outputs: * y: A Tensor. 
Has the same type as filter,and has same format as input_size -*/ +*/ REG_OP(Conv3DBackpropInput) - .INPUT(input_sizes, TensorType({DT_INT32, DT_INT64})) - .INPUT(filters, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) + .INPUT(input_size, TensorType({DT_INT32, DT_INT64})) + .INPUT(filter, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) .INPUT(grads, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) .REQUIRED_ATTR(strides, ListInt) @@ -687,54 +711,44 @@ REG_OP(Conv3DBackpropInput) *@brief Computes the gradients of convolution 3d with respect to the input. *@par Inputs: * Two inputs: - * @li filters: A Tensor. Types is float16. + * @li filter: A Tensor. Types is float16. * @li grads: A Tensor. Must have the same type as filter. *@par Attributes: * Five attributes: - * @li input_sizes A Tensor of type int32. An integer vector representing the shape of input, + * @li input_size A Tensor of type int32. An integer vector representing the shape of input, * @li strides: A tuple/list of 3 integers. The stride of the sliding window for D/H/W dimension. * @li pads: A tuple/list of 4 integers * @li dilations: A tuple/list of 5 integers, The dilation factor for each dimension of input, now only support [1,1,1,1,1] * @li data_format: An optional string from: "NDHWC", "NCHWD". Defaults to "NDHWC". Specify the data format of the input and output data. *@par Outputs: * y: A Tensor. 
Has the same type as filter -*/ +*/ REG_OP(Conv3DBackpropInputD) - .INPUT(filters, TensorType({DT_FLOAT16})) + .INPUT(filter, TensorType({DT_FLOAT16})) .INPUT(grads, TensorType({DT_FLOAT16})) .OUTPUT(y, TensorType({DT_FLOAT16})) - .REQUIRED_ATTR(input_sizes, ListInt) + .REQUIRED_ATTR(input_size, ListInt) .REQUIRED_ATTR(strides, ListInt) .ATTR(pads, ListInt, {0, 0, 0, 0, 0, 0}) .ATTR(data_format, String, "NDHWC") .ATTR(dilations, ListInt, {1, 1, 1, 1, 1}) .OP_END_FACTORY_REG(Conv3DBackpropInputD) -REG_OP(LSTMQuant) - .INPUT(x, TensorType({DT_FLOAT16,DT_INT8})) +REG_OP(LSTM) + .INPUT(x, TensorType({DT_FLOAT16})) .INPUT(cont, TensorType({DT_FLOAT32,DT_FLOAT16})) - .OPTIONAL_INPUT(x_static, TensorType({DT_FLOAT16,DT_INT8})) - .OPTIONAL_INPUT(h_0, TensorType({DT_FLOAT16,DT_FLOAT32,DT_INT8})) - .OPTIONAL_INPUT(c_0, TensorType({DT_FLOAT16,DT_FLOAT32})) - .INPUT(w_x, TensorType({DT_FLOAT16,DT_INT8})) + .INPUT(w_x, TensorType({DT_FLOAT16})) .INPUT(bias, TensorType({DT_FLOAT16,DT_FLOAT32,DT_INT16,DT_INT32})) - .OPTIONAL_INPUT(w_x_static, TensorType({DT_FLOAT16,DT_INT8})) - .INPUT(w_h, TensorType({DT_FLOAT16,DT_INT8})) - .OPTIONAL_INPUT(w_xh_deqscale, TensorType({DT_FLOAT16})) - .OPTIONAL_INPUT(w_x_static_deqscale, TensorType({DT_FLOAT16})) - .OUTPUT(h, TensorType({DT_FLOAT16, DT_FLOAT,DT_INT8})) - .OUTPUT(h_t, TensorType({DT_FLOAT16, DT_FLOAT,DT_INT8})) + .INPUT(w_h, TensorType({DT_FLOAT16})) + .OPTIONAL_INPUT(x_static, TensorType({DT_FLOAT16})) + .OPTIONAL_INPUT(h_0, TensorType({DT_FLOAT16,DT_FLOAT32})) + .OPTIONAL_INPUT(c_0, TensorType({DT_FLOAT16,DT_FLOAT32})) + .OPTIONAL_INPUT(w_x_static, TensorType({DT_FLOAT16})) + .OUTPUT(h, TensorType({DT_FLOAT16, DT_FLOAT})) + .OUTPUT(h_t, TensorType({DT_FLOAT16, DT_FLOAT})) .OUTPUT(c_t, TensorType({DT_FLOAT16, DT_FLOAT})) .ATTR(num_output, Int, 0) .ATTR(expose_hidden, Bool, false) - .ATTR(xh_scale, Float,0) - .ATTR(sqrt_mode_xh, Bool, false) - .ATTR(sqrt_mode_x_static, Bool, false) - .ATTR(xh_offset, Int,0) - .ATTR(x_static_scale, 
Float,0.0) - .ATTR(x_static_offset, Int,0) - .ATTR(w_xh_offset,ListInt,{0}) - .ATTR(w_x_static_offset,ListInt,{0}) - .OP_END_FACTORY_REG(LSTMQuant) + .OP_END_FACTORY_REG(LSTM) } // namespace ge #endif // GE_OP_NN_CALCULATION_OPS_H diff --git a/third_party/fwkacllib/inc/ops/nn_detect_ops.h b/third_party/fwkacllib/inc/ops/nn_detect_ops.h index 1d8f0ae5..ce06a9b2 100644 --- a/third_party/fwkacllib/inc/ops/nn_detect_ops.h +++ b/third_party/fwkacllib/inc/ops/nn_detect_ops.h @@ -305,11 +305,415 @@ REG_OP(ROIAlign) REG_OP(PSROIPooling) .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16})) .INPUT(rois, TensorType({DT_FLOAT, DT_FLOAT16})) - .ATTR(output_dim, Int, 0) - .ATTR(group_size, Int, 0) - .ATTR(spatial_scale, Float, 0.0625) + .REQUIRED_ATTR(output_dim, Int) + .REQUIRED_ATTR(group_size, Int) + .REQUIRED_ATTR(spatial_scale, Float) .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16})) .OP_END_FACTORY_REG(PSROIPooling) + +/** +*@brief Returns detection result. + +*@par Inputs: +* Four inputs, including: +*@li rois: An NCHW tensor of type floa16 or float32, output from operator proposal_d at the preceding layer, used as the input of operator FSRDetectionOutput. +*@li bbox_delta: An NCHWC0 tensor of type floa16 or float32, specifying the prediction offset, used to update the coordinates [x1, y1, x2, y2] of each ROI. +*@li score: An NCHWC0 tensor of type floa16 or float32, specifying the probability of each class. Class 0 is the background class. +*@li im_info: An ND tensor of type float16 or float32, specifying the Image information. +*@li actual_rois_num: An optional NCHW tensor of type int32, specifying the number of valid boxes per batch. +*@par Attributes: +*@li batch_rois: An optional int32, specifying the number of images to be predicted. +*@li num_classes: An required int32, specifying the number of classes to be predicted. The value must be greater than 0. +*@li score_threshold: An required float32, specifying the threshold for box filtering. The value range is [0.0, 1.0]. 
+*@li iou_threshold: An required float32, specifying the confidence threshold for box filtering, which is the output "obj" of operator Region. The value range is (0.0, 1.0). +*@par Outputs: +*box: An NCHW tensor of type float16, describing the information of each output box, including the coordinates, class, and confidence. +*actual_bbox_num: An NCHW tensor of type int32, specifying the number of output boxes. + +*@attention Constraints:\n +*@li totalnum < max_rois_num * batch_rois. +*@li "score" must be with shape (total_num, (num_classes+15)//16, 1, 1, 16), where "total_num" indicates the number of valid input boxes of all images. +*@li "bbox_delta" must be with shape (total_num, (num_classes*4+15)//16, 1, 1, 16), where "total_num" indicates the number of valid input boxes of all images. +*/ +REG_OP(FSRDetectionOutput) + .INPUT(rois, TensorType({DT_FLOAT, DT_FLOAT16})) + .INPUT(bbox_delta, TensorType({DT_FLOAT, DT_FLOAT16})) + .INPUT(score, TensorType({DT_FLOAT, DT_FLOAT16})) + .INPUT(im_info, TensorType({DT_FLOAT, DT_FLOAT16})) + .OPTIONAL_INPUT(actual_rois_num, TensorType({DT_INT32})) + .OUTPUT(actual_bbox_num, TensorType({DT_INT32})) + .OUTPUT(box, TensorType({DT_FLOAT, DT_FLOAT16})) + .ATTR(batch_rois, Int, 1) + .REQUIRED_ATTR(num_classes, Int) + .REQUIRED_ATTR(score_threshold, Float) + .REQUIRED_ATTR(iou_threshold, Float) + .OP_END_FACTORY_REG(FSRDetectionOutput) + +/** +*@brief Normalizes data. It is called Region on YOLO v2 and Yolo on YOLO v3. + +*@par Inputs: +*x: An NCHW tensor of type float16 or float32. The data is with shape (N, boxes*(coords+obj+classes), H, W),where, "obj" indicates the confidence of an object, and only one confidence is supported. Boxes are arranged as xx...xyy...yww...whh...hbb...bc0c0..c0c1c1...c1......cncn...cn. + +*@par Attributes: +*@li boxes: A required int32, specifying the number of anchor boxes. Defaults to "5" for V2 or "3" for V3. 
+*@li coords: An int32, specifying the number of parameters required for locating an object. The value is fixed at "4", corresponding to (x,y,w,h). +*@li classes: An int32, specifying the number of prediction classes. Defaults to "80". The value range is [1, 1024]. +*@li yolo_version: A string, specifying the YOLO version, either "V2" or "V3". +*@li softmax: A bool, specifying whether to perform softmax, valid only when "yolo_version = V2". +*@li background: A bool, specifying the operation types of the obj and classes, used in conjunction with "softmax" and valid only when "yolo_version = V2". +*@li background: A bool. + +*@par Outputs: +*@li coord_data: A float16 or float32 with shape [N, boxes*coords, ceilx(height*width*2+32, 32)/2], where "ceil" indicates that a detected box is aligned upwards with the second parameter. Specifies the coordinates of a detected box. +*@li obj_prob: A float16 or float32 with shape [N, ceilx(boxes*height*width *2+32, 32)/2], where "ceil" indicates that a detected box is aligned upwards with the second parameter. Specifies the confidence. +*@li classes_prob: A float16 or float32 with shape [N, classes, ceilx(boxes*height*width *2+32, 32)/2], where "ceil" indicates that a detected box is aligned upwards with the second parameter. Specifies the prediction classes. + +*@attention Constraints: +*@li This operator applies to YOLO v2 and v3 networks. +*@li The succeeding layer of the Yolo operator must be operator Yolov3DetectionOutput. 
+*/ +REG_OP(Yolo) + .INPUT(x, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(coord_data, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(obj_prob, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(classes_prob, TensorType({DT_FLOAT16,DT_FLOAT})) + .ATTR(boxes, Int, 3) + .ATTR(coords, Int, 4) + .ATTR(classes, Int, 80) + .ATTR(yolo_version, String, "V3") + .ATTR(softmax, Bool, false) + .ATTR(background, Bool, false) + .ATTR(softmaxtree, Bool, false) + .OP_END_FACTORY_REG(Yolo) + +/** +*@brief Performs YOLO V2 detection. + +*@par Inputs: +* Four inputs, including: +*@li The outputs of operator Yolo at the preceding layer (that is, one Yolo operator on YOLO v2) are used as the inputs of operator Yolov3DetectionOutput. \n +Each Yolo operator has three outputs: "coords", "obj", and "class". For details, see the description of operator Yolo. +*@li imginfo: A float16, describing the image information including the required image height and width \n +and the actual image height and width. +* +*@par Attributes: +*@li biases: A required float. "biases = Number of Yolo operators at the preceding layer x 2 x boxes" +*@li boxes: A required int32, specifying the number of anchor boxes predicted for each Yolo layer. +*@li coords: Specifies the number of coordinate parameters. Must be 4. +*@li classes: A required int32, specifying the number of classes to be predicted. The value range is [1, 20]. +*@li relative: An optional bool. Defaults to and must be "true". +*@li obj_threshold: A required float, specifying the confidence threshold for box filtering, which is the output "obj" of operator Yolo). The value range is [0.0, 1.0]. + +*@li post_nms_topn: An optional int32. This attribute is reserved. +*@li score_threshold: A required float, specifying the class score threshold for box filtering, which is the output "class" of operator Yolo). The value range is [0.0, 1.0]. +*@li iou_threshold: A required float, specifying the intersection-over-union (IOU) threshold for box filtering. 
The value range is [0.0, 1.0].\n +*@li pre_nms_topn: An optional int, specifying the number of boxes for non-maximum suppression (NMS). Defaults to "512". +* +*@par Outputs: +*@li boxout: An NCHW tensor of type float16, describing the information of each output box, including the coordinates, class, and confidence. +*@li boxoutnum: An NCHW tensor of type int32, specifying the number of output boxes. + +*@attention Constraints:\n +*@li This operator applies only to the YOLO v2 network. +*@li The preceding layer of operator Yolov2DetectionOutput must be one Yolo operator. + +*@see Yolo() +*/ +REG_OP(YoloV2DetectionOutput) + .INPUT(coord_data, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(obj_prob, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(classes_prob, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(img_info, TensorType({DT_FLOAT16,DT_FLOAT})) + .REQUIRED_ATTR(biases, ListFloat) + .ATTR(boxes, Int, 5) + .ATTR(coords, Int, 4) + .ATTR(classes, Int, 20) + .ATTR(relative, Bool, true) + .ATTR(obj_threshold, Float, 0.5) + .ATTR(post_nms_topn, Int, 512) + .ATTR(score_threshold, Float, 0.5) + .ATTR(iou_threshold, Float, 0.45) + .ATTR(pre_nms_topn, Int, 512) + .OUTPUT(box_out, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(box_out_num, TensorType({DT_INT32})) + .OP_END_FACTORY_REG(YoloV2DetectionOutput) + +/** +*@brief Performs YOLO V2 detection. + +*@par Inputs: +*Six inputs, including: +*@li The outputs of operator Yolo at the preceding layer (that is, one Yolo operator on YOLO v2) are used as the inputs of operator Yolov2DetectionOutput. \n +Each Yolo operator has three outputs: "coords", "obj", and "class". For details, see the description of operator Yolo. +*@li imginfo: A float16, describing the image information including the required image height and width \n +and the actual image height and width. +*@li windex: A windex tensor with shape [height, weight]. Has the same type as the inputs. 
[[0,1,2...(weight-1)],[0,1,2...(w-1)]...[0,1,2...(weight-1)]] consisting of h groups of [0, 1, 2...(weight-1)] is formed. \n + +*@li hindex: A hindex tensor with shape [height, weight]. Has the same type as the inputs. [[0,0...0],[1,1...1],[2,2...2]...[height-1,height-1...,height-1]]. \n + +* +*@par Attributes: +*@li biases: A required float. "biases = Number of Yolo operators at the preceding layer x 2 x boxes" +*@li boxes: A required int32, specifying the number of anchor boxes predicted for each Yolo layer. +*@li coords: Specifies the number of coordinate parameters. Must be 4. +*@li classes: A required int32, specifying the number of classes to be predicted. The value range is [1, 20]. +*@li relative: An optional bool. Defaults to and must be "true". +*@li obj_threshold: A required float, specifying the confidence threshold for box filtering, which is the output "obj" of operator Yolo). The value range is [0.0, 1.0]. +*@li post_nms_topn: An optional int32. This attribute is reserved. +*@li score_threshold: A required float, specifying the class score threshold for box filtering, which is the output "class" of operator Yolo). The value range is [0.0, 1.0]. + +*@li iou_threshold: A required float, specifying the intersection-over-union (IOU) threshold for box filtering. The value range is [0.0, 1.0].\n +*@li pre_nms_topn: An optional int, specifying the number of boxes for non-maximum suppression (NMS). Defaults to "512". +* +*@par Outputs: +*@li boxout: An NCHW tensor of type float16, describing the information of each output box, including the coordinates, class, and confidence. +*@li boxoutnum: An NCHW tensor of type int32, specifying the number of output boxes. +* +*@attention Constraints:\n +*@li This operator applies only to the YOLO v2 network. +*@li The preceding layer of operator Yolov2DetectionOutput must be one Yolo operator. 
+ +*@see Yolo() +*/ +REG_OP(YoloV2DetectionOutputD) + .INPUT(coord_data, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(obj_prob, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(classes_prob, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(img_info, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(windex, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(hindex, TensorType({DT_FLOAT16,DT_FLOAT})) + .REQUIRED_ATTR(biases, ListFloat) + .ATTR(boxes, Int, 5) + .ATTR(coords, Int, 4) + .ATTR(classes, Int, 20) + .ATTR(relative, Bool, true) + .ATTR(obj_threshold, Float, 0.5) + .ATTR(post_nms_topn, Int, 512) + .ATTR(score_threshold, Float, 0.5) + .ATTR(iou_threshold, Float, 0.45) + .ATTR(pre_nms_topn, Int, 512) + .OUTPUT(box_out, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(box_out_num, TensorType({DT_INT32})) + .OP_END_FACTORY_REG(YoloV2DetectionOutputD) + +/** +*@brief Performs YOLO V3 detection. + +*@par Inputs: +*Ten inputs, including: +*@li Operator Yolov3DetectionOutput takes the outputs of operator Yolo as its inputs. A Yolo operator has three outputs: "coords", "obj", and "class". \n +There are three Yolo operators at Yolov3DetectionOutput's preceding layer on Yolo v3. For details, see the description of operator Yolo. +*@li imginfo: A float16, describing the image information including the required image height and width \n +and the actual image height and width. +* +*@par Attributes: +*@li biases: A required float. "biases = Number of Yolo operators at the preceding layer x 2 x boxes" +*@li boxes: A required int32, specifying the number of anchor boxes predicted for each Yolo layer. +*@li coords: Specifies the number of coordinate parameters. Must be 4. +*@li classes: A required int32, specifying the number of classes to be predicted. The value range is [1, 80]. +*@li relative: An optional bool. Defaults to and must be "true". +*@li obj_threshold: A required float, specifying the confidence threshold for box filtering, which is the output "obj" of operator Yolo). 
The value range is [0.0, 1.0]. + +*@li post_nms_topn: An optional int32. This attribute is reserved. +*@li score_threshold: A required float, specifying the class score threshold for box filtering, which is the output "class" of operator Yolo). The value range is [0.0, 1.0]. + +*@li iou_threshold: A required float, specifying the intersection-over-union (IOU) threshold for box filtering. The value range is [0.0, 1.0].\n + +*@li pre_nms_topn: An optional int, specifying the number of boxes for non-maximum suppression (NMS). Defaults to "512". +* +*@par Outputs: +*@li boxout: An NCHW tensor of type float16, describing the information of each output box, including the coordinates, class, and confidence. +*@li boxoutnum: An NCHW tensor of type int32, specifying the number of output boxes. + +*@attention Constraints:\n +*@li This operator applies only to the YOLO v3 network. +*@li The preceding layer of operator Yolov3DetectionOutput must be three Yolo operators. + +*@see Yolo() +*/ +REG_OP(YoloV3DetectionOutput) + .INPUT(coord_data_low, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(coord_data_mid, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(coord_data_high, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(obj_prob_low, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(obj_prob_mid, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(obj_prob_high, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(classes_prob_low, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(classes_prob_mid, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(classes_prob_high, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(img_info, TensorType({DT_FLOAT16,DT_FLOAT})) + .REQUIRED_ATTR(biases_low, ListFloat) + .REQUIRED_ATTR(biases_mid, ListFloat) + .REQUIRED_ATTR(biases_high, ListFloat) + .ATTR(boxes, Int, 3) + .ATTR(coords, Int, 4) + .ATTR(classes, Int, 80) + .ATTR(relative, Bool, true) + .ATTR(obj_threshold, Float, 0.5) + .ATTR(post_nms_topn, Int, 512) + .ATTR(score_threshold, Float, 0.5) + .ATTR(iou_threshold, Float, 0.45) + 
.ATTR(pre_nms_topn, Int, 512) + .OUTPUT(box_out, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(box_out_num, TensorType({DT_INT32})) + .OP_END_FACTORY_REG(YoloV3DetectionOutput) + +/** +*@brief Performs YOLO V3 detection. + +*@par Inputs: +*16 Input, including: +*@li The outputs of operator Yolo at the preceding layer (that is, three Yolo operators on YOLO v3) are used as the inputs of operator Yolov3DetectionOutput. \n +A Yolo operator has three outputs: "coords", "obj", and "class". For details, see the description of operator Yolo. +*@li imginfo: A float16, describing the image information including the required image height and width \n +and the actual image height and width. +*@li windex: A windex tensor with shape [height,weight]. Has the same type as the inputs. [[0,1,2...(weight-1)],[0,1,2...(w-1)]...[0,1,2...(weight-1)]] consisting of h groups of [0, 1, 2...(weight-1)] is formed for the three Yolo outputs, respectively. + +*@li hindex: A hindex tensor with shape [height,weight]. Has the same type as the inputs. [[0,0...0],[1,1...1],[2,2...2]...[height-1,height-1...,height-1]] is formed for the three Yolo outputs, respectively. + +* +*@par Attributes: +*@li biases: A required float32. "biases = Number of Yolo operators at the preceding layer x 2 x boxes" +*@li boxes: A required int32, specifying the number of anchor boxes predicted for each Yolo layer. +*@li coords: Specifies the number of coordinate parameters. Must be 4. +*@li classes: A required int32, specifying the number of classes to be predicted. The value range is [1, 80]. +*@li relative: An optional bool. Defaults to and must be "true". +*@li obj_threshold: A required float, specifying the confidence threshold for box filtering, which is the output "obj" of operator Yolo). The value range is [0.0, 1.0]. +*@li post_nms_topn: An optional int32. This attribute is reserved. 
+*@li score_threshold: A required float, specifying the class score threshold for box filtering, which is the output "class" of operator Yolo). The value range is [0.0, 1.0]. +*@li iou_threshold: A required float, specifying the intersection-over-union (IOU) threshold for box filtering. The value range is [0.0, 1.0].\n +*@li pre_nms_topn: An optional int, specifying the number of boxes for non-maximum suppression (NMS). Defaults to "512". +* +*@par Outputs: +*@li boxout: An NCHW tensor of type float16, describing the information of each output box, including the coordinates, class, and confidence. +*@li boxoutnum: An NCHW tensor of type int32, specifying the number of output boxes. + +*@attention Constraints:\n +*@li This operator applies only to the YOLO v3 network. +*@li The preceding layer of operator Yolov3DetectionOutput must be three Yolo operators. +*@see Yolo() +*/ +REG_OP(YoloV3DetectionOutputD) + .INPUT(coord_data_low, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(coord_data_mid, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(coord_data_high, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(obj_prob_low, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(obj_prob_mid, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(obj_prob_high, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(classes_prob_low, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(classes_prob_mid, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(classes_prob_high, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(img_info, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(windex1, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(windex2, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(windex3, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(hindex1, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(hindex2, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(hindex3, TensorType({DT_FLOAT16,DT_FLOAT})) + .REQUIRED_ATTR(biases_low, ListFloat) + .REQUIRED_ATTR(biases_mid, ListFloat) + .REQUIRED_ATTR(biases_high, ListFloat) + .ATTR(boxes, Int, 3) + .ATTR(coords, Int, 
4) + .ATTR(classes, Int, 80) + .ATTR(relative, Bool, true) + .ATTR(obj_threshold, Float, 0.5) + .ATTR(post_nms_topn, Int, 512) + .ATTR(score_threshold, Float, 0.5) + .ATTR(iou_threshold, Float, 0.45) + .ATTR(pre_nms_topn, Int, 512) + .OUTPUT(box_out, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(box_out_num, TensorType({DT_INT32})) + .OP_END_FACTORY_REG(YoloV3DetectionOutputD) + +/** +*@brief Spatial Pyramid Pooling, multi-level pooling. +* Pooling out(n, sigma(c*2^i*2^i)) tensor, i in range[0,pyramid_height). + +*@par Inputs: +*x: An NCHW tensor, support float16 or float32 type. + +*@par Attributes: +* @li pyramid_height: A required int32. +* Multi-level pooling out from 2^0 to 2^(pyramid_height-1). +* @li pool_method: An optional int32, pooling method: 0-MAX, 1-AVE. +* Defaults to "0". + +*@par Outputs: +*y: A NCHW tensor, support float16 or float32 type. + +*@attention Constraints: +* @li pyramid_height: pyramid_height should be in range [0,7). +* @li feature_size: input feature map h and w should be [1, 510]. + +*/ +REG_OP(SPP) + .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16})) + .REQUIRED_ATTR(pyramid_height, Int) + .ATTR(pool_method, Int, 0) + .OP_END_FACTORY_REG(SPP) + +/** +*@brief Performs Region of Interest (ROI) Pooling. + +*@par Inputs: +* Three inputs, including: +*@li x: An NC1HWC0 tensor of type float16 or float32, describing the feature +* map. +*@li rois: A tensor of type float16 or float32, with shape +* [batch, 5, roi_max_num], describing the ROIs. +*@li roi_actual_num: An optional tensor of type int32, with shape [batch, 8], specifying +* the number of ROIs per batch. + +*@par Attributes: +*@li pooled_h: A required int32, specifying the pooled H. Must be greater +* than 0. +*@li pooled_w: A required int32, specifying the pooled W. Must be greater +* than 0. +*@li spatial_scale_h: A required scaling factor for mapping the input +* coordinates of height to the ROI coordinates.
+*@li spatial_scale_w: An required scaling factor for mapping the input +* coordinates of width to the ROI coordinates. + +*@par Outputs: +*y: An NC1HWC0 tensor of type float16 or float32, describing the result +* feature map. + +*@attention Constraints:\n +*@li For the feature map input: \n +(1) If pooled_h = pooled_w = 2, the feature map size must not exceed 50. \n +(2) If pooled_h = pooled_w = 3, the feature map size must not exceed 60. \n +(3) If pooled_h = pooled_w = 4, the feature map size must not exceed 70. \n +(4) If pooled_h = pooled_w = 5, the feature map size must not exceed 70. \n +(5) If pooled_h = pooled_w = 6, the feature map size must not exceed 80. \n +(6) If pooled_h = pooled_w = 7, the feature map size must not exceed 80. \n +(7) If pooled_h = pooled_w = 8, the feature map size must not exceed 80. \n +(8) If pooled_h = pooled_w = 9, the feature map size must not exceed 70. \n +(9) If pooled_h = pooled_w = 10, the feature map size must not exceed 70. \n +(10) If pooled_h = pooled_w = 11, the feature map size must not exceed 70. \n +(11) If pooled_h = pooled_w = 12, the feature map size must not exceed 70. \n +(12) If pooled_h = pooled_w = 13, the feature map size must not exceed 70. \n +(13) If pooled_h = pooled_w = 14, the feature map size must not exceed 70. \n +(14) If pooled_h = pooled_w = 15, the feature map size must not exceed 70. \n +(15) If pooled_h = pooled_w = 16, the feature map size must not exceed 70. \n +(16) If pooled_h = pooled_w = 17, the feature map size must not exceed 50. \n +(17) If pooled_h = pooled_w = 18, the feature map size must not exceed 40. \n +(18) If pooled_h = pooled_w = 19, the feature map size must not exceed 40. \n +(19) If pooled_h = pooled_w = 20, the feature map size must not exceed 40. 
\n +*/ +REG_OP(ROIPooling) + .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16})) + .INPUT(rois, TensorType({DT_FLOAT, DT_FLOAT16})) + .OPTIONAL_INPUT(roi_actual_num, TensorType({DT_INT32})) + .REQUIRED_ATTR(pooled_h, Int) + .REQUIRED_ATTR(pooled_w, Int) + .REQUIRED_ATTR(spatial_scale_h, Float) + .REQUIRED_ATTR(spatial_scale_w, Float) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16})) + .OP_END_FACTORY_REG(ROIPooling) + } // namespace ge #endif // GE_OP_NN_DETECT_OPS_H_ diff --git a/third_party/fwkacllib/inc/ops/nn_norm_ops.h b/third_party/fwkacllib/inc/ops/nn_norm_ops.h index 9b82f565..45e0eb08 100644 --- a/third_party/fwkacllib/inc/ops/nn_norm_ops.h +++ b/third_party/fwkacllib/inc/ops/nn_norm_ops.h @@ -17,7 +17,7 @@ #ifndef GE_OP_NN_NORM_OPS_H #define GE_OP_NN_NORM_OPS_H -#include "../graph/operator_reg.h" +#include "graph/operator_reg.h" namespace ge { /** @@ -686,6 +686,100 @@ REG_OP(LRNGrad) .ATTR(beta, Float, 0.5) .OP_END_FACTORY_REG(LRNGrad) + /** + *@brief Calculates the RNNT Loss (log probability) for each batch entry. \n + Also calculates the gradient. + + *@par Inputs: + *@li acts: 4-D, shape: `(batch x seqLength x labelLength x outputDim)`, the logits. + *@li labels: 2-D Tensor containing all the targets of the batch with zero padded. + *@li input_lengths: Tensor of size (batch) containing size of each output sequence. + *@li label_lengths: Tensor of (batch) containing label length of each example. + + *@par Outputs: + *@li costs: 1-D Tensor, the cost of each example in the batch. + *@li grads: A Tensor. Has the same type as acts. + + *@par Attributes: + *@li blank_label: An optional attribute. Defaults to 0. 
+ + */ +REG_OP(RNNTLoss) + .INPUT(acts, TensorType({DT_FLOAT})) + .INPUT(labels, TensorType({DT_INT32})) + .INPUT(input_lengths, TensorType({DT_INT32})) + .INPUT(label_lengths, TensorType({DT_INT32})) + .ATTR(blank_label, Int, 0) + .OUTPUT(costs, TensorType({DT_FLOAT})) + .OUTPUT(grads, TensorType({DT_FLOAT})) + .OP_END_FACTORY_REG(RNNTLoss) + +/** +*@brief Performs group normalization. + +*@par Inputs:\n +* Five inputs, including: (NHWC, NCHW supported) +*@li x: A 4D Tensor of type float16 or float32, with format NHWC or \n +NCHW for 4D. +*@li scale: A Tensor of type float32. Must be 1D if input "x" is with format \n +NHWC or NCHW. Specifies the scaling factor. +*@li offset: A Tensor of type float32. Must be 1D if input "x" is with \n +format NHWC or NCHW. Specifies the offset. +*@li mean: A Tensor of type float32. Must be 1D if input "x" is with format \n +NHWC or NCHW. Reserved. Mu +st be "None" if the operation is used for training. +*@li variance: A Tensor of type float32. Must be 1D if input "x" is with \n +format NHWC or NCHW. Specifies the variance used for inference. Reserved. + +*@par Attributes: +*@li epsilon: An optional float32, specifying the small value added to \n +variance to avoid dividing by zero. Defaults to "0.0001". +*@li data_format: An optional string, specifying the format of "x". \n +Defaults to "NHWC". +*@li is_training: An optional bool, specifying if the operation is used for \n +training or inference. Defaults to "True". + +*@par Outputs:\n +* Five outputs, including: (NHWC, NCHW supported) +*@li y: A 4D Tensor of type float16 or float32 for the normalized "x", \n +with format NHWC or NCHW for 4D. +*@li batch_mean: A Tensor of type float32. Must be 1D if input "x" is with \n +format NHWC or NCHW. Specifies the mean of "x". +*@li batch_variance: A Tensor of type float32. Must be 1D if input "x" is \n +with format NHWC or NCHW. Specifies the variance of "x". +*@li reserve_space_1: An optional Tensor of type float32. 
Must be 1D if \n +input "x" is with format NHWC or NCHW. Specifies the mean o +f "x" for gradient computation. Pass "None" to skip this output. +*@li reserve_space_2: An optional Tensor of type float32. Must be 1D if \n +input "x" is with format NHWC or NCHW. Specifies the varian +ce of "x" for gradient computation. Pass "None" to skip this output. + +*@attention Constraints: +*@li If the operation is used for inference and outputs "reserve_space_1" \n +and "reserve_space_2" are available, then "reserve_space_1" has the same \n +value as "mean" and "reserve_spa +ce_2" has the same value as "variance". +*@li For Ascend 310, the result accuracy fails due to the square root \n +instruction. + +*/ +REG_OP(GroupNorm) + .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT})) + .INPUT(scale, TensorType({DT_FLOAT,})) + .INPUT(offset, TensorType({DT_FLOAT,})) + .OPTIONAL_INPUT(mean, TensorType({DT_FLOAT})) + .OPTIONAL_INPUT(variance, TensorType({DT_FLOAT})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT})) + .OUTPUT(batch_mean, TensorType({DT_FLOAT})) + .OUTPUT(batch_variance, TensorType({DT_FLOAT})) + .OUTPUT(reserve_space_1, TensorType({DT_FLOAT})) + .OUTPUT(reserve_space_2, TensorType({DT_FLOAT})) + .ATTR(epsilon, Float, 0.0001) + .ATTR(data_format, String, "NHWC") + .ATTR(is_training, Bool, true) + .ATTR(num_groups, Int, 2) + .OP_END_FACTORY_REG(GroupNorm) + } // namespace ge #endif //GE_OP_NN_NORM_OPS_H diff --git a/third_party/fwkacllib/inc/ops/nn_ops.h b/third_party/fwkacllib/inc/ops/nn_ops.h index 3fd6d74b..7637da07 100644 --- a/third_party/fwkacllib/inc/ops/nn_ops.h +++ b/third_party/fwkacllib/inc/ops/nn_ops.h @@ -17,203 +17,6 @@ #ifndef GE_OP_NN_OPS_H_ #define GE_OP_NN_OPS_H_ -#include "graph/operator_reg.h" -#include "graph/operator.h" - -namespace ge { - -/** -*@brief Computes gradient of the FractionalMaxPool function. - -*@par Inputs: -*Inputs include: \n -* @li orig_input: A Tensor. Must be one of the following types: float32, float64, int32, int64. 
-* @li orig_output: A Tensor. Must have the same type as orig_input. -* @li out_backprop: A Tensor. Must have the same type as orig_input. \n - 4-D with shape [batch, height, width, channels]. -* @li row_pooling_sequence: A Tensor of type int64. -* @li col_pooling_sequence: A Tensor of type int64. - -*@par Attributes: -*overlapping: An optional bool. Defaults to False. - -*@par Outputs: -*y: A Tensor. Has the same type as orig_input. - -*@attention Constraints:\n -*-The implementation for FractionalMaxPoolGrad on Ascend uses AICPU, with bad performance.\n - -*/ -REG_OP(FractionalMaxPoolGrad) - .INPUT(orig_input, TensorType({DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64})) - .INPUT(orig_output, TensorType({DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64})) - .INPUT(out_backprop, TensorType({DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64})) - .INPUT(row_pooling_sequence, TensorType({ DT_INT64 })) - .INPUT(col_pooling_sequence, TensorType({ DT_INT64 })) - .OUTPUT(y, TensorType({ DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64 })) - .ATTR(overlapping, Bool, false) - .OP_END_FACTORY_REG(FractionalMaxPoolGrad) - -/** -*@brief Performs fractional average pooling on the input. - -*@par Inputs: -*Inputs include: \n -*x: A Tensor. Must be one of the following types: float32, float64, int32, int64. \n - 4-D with shape [batch, height, width, channels]. - -*@par Attributes: -*@li pooling_ratio: A list of floats that has length >= 4. -*@li pseudo_random: An optional bool. Defaults to False. -*@li overlapping: An optional bool. Defaults to False. When set to True, it means when pooling. -*@li deterministic: An optional bool. Defaults to False. -*@li seed: An optional int. Defaults to 0. -*@li seed2: An optional int. Defaults to 0. - -*@par Outputs: -*@li y: A Tensor. Has the same type as x. -*@li row_pooling_sequence: A Tensor of type int64. -*@li col_pooling_sequence: A Tensor of type int64. 
- -*@attention Constraints:\n -*-The implementation for FractionalAvgPool on Ascend uses AICPU, with bad performance.\n - -*/ -REG_OP(FractionalAvgPool) - .INPUT(x, TensorType({DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64})) - .OUTPUT(y, TensorType({DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64})) - .OUTPUT(row_pooling_sequence, TensorType({DT_INT64})) - .OUTPUT(col_pooling_sequence, TensorType({DT_INT64})) - .ATTR(pooling_ratio, ListFloat, {}) - .ATTR(pseudo_random, Bool, false) - .ATTR(overlapping, Bool, false) - .ATTR(deterministic, Bool, false) - .ATTR(seed, Int, 0) - .ATTR(seed2, Int, 0) - .OP_END_FACTORY_REG(FractionalAvgPool) - -/** -*@brief Performs fractional max pooling on the input. - -*@par Inputs: -*Inputs include: \n -*x: A Tensor. Must be one of the following types: float32, float64, int32, int64. \n - 4-D with shape [batch, height, width, channels]. - -*@par Attributes: -*@li pooling_ratio: A list of floats that has length >= 4. Pooling ratio for each dimension of value. -*@li pseudo_random: An optional bool. Defaults to False. -*@li overlapping: An optional bool. Defaults to False. -*@li deterministic: An optional bool. Defaults to False. -*@li seed: An optional int. Defaults to 0. -*@li seed2: An optional int. Defaults to 0. - -*@par Outputs: -*@li y: A Tensor. Has the same type as x. -*@li row_pooling_sequence: A Tensor of type int64. -*@li col_pooling_sequence: A Tensor of type int64. 
- -*@attention Constraints:\n -*-The implementation for FractionalMaxPool on Ascend uses AICPU, with bad performance.\n - -*/ -REG_OP(FractionalMaxPool) - .INPUT(x, TensorType({DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64})) - .OUTPUT(y, TensorType({DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64})) - .OUTPUT(row_pooling_sequence, TensorType({DT_INT64})) - .OUTPUT(col_pooling_sequence, TensorType({DT_INT64})) - .ATTR(pooling_ratio, ListFloat, {}) - .ATTR(pseudo_random, Bool, false) - .ATTR(overlapping, Bool, false) - .ATTR(deterministic, Bool, false) - .ATTR(seed, Int, 0) - .ATTR(seed2, Int, 0) - .OP_END_FACTORY_REG(FractionalMaxPool) - -/** -*@brief Finds values of the n-th order statistic for the last dimension. - -*@par Inputs: -*Inputs include: \n -* @li x: A Tensor. Must be one of the following types: float32, float64, int32, uint8, \n - int16, int8, int64, bfloat16, uint16, half, uint32, uint64. -* @li n: A Tensor of type int32. 0-D. - -*@par Attributes: -*reverse: An optional bool. Defaults to False. - -*@par Outputs: -*y: A Tensor. Has the same type as x. - -*@attention Constraints:\n -*-The implementation for NthElement on Ascend uses AICPU, with bad performance.\n - -*/ -REG_OP(NthElement) - .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, - DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_DOUBLE})) - .INPUT(n, TensorType({DT_INT32})) - .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, - DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_DOUBLE})) - .ATTR(reverse, Bool, false) - .OP_END_FACTORY_REG(NthElement) - -/** -*@brief Computes gradient of the FractionalAvgPool function. - -*@par Inputs: -*Inputs include: \n -* @li orig_input_tensor_shape: A Tensor of type int64. -* @li out_backprop: A Tensor. Must be one of the following types: float32, float64, \n - int32, int64. 4-D with shape [batch, height, width, channels]. -* @li row_pooling_sequence: A Tensor of type int64. -* @li col_pooling_sequence: A Tensor of type int64. 
- -*@par Attributes: -*overlapping: An optional bool. Defaults to False. - -*@par Outputs: -*y: A Tensor. Has the same type as out_backprop. - -*@attention Constraints:\n -*-The implementation for FractionalAvgPoolGrad on Ascend uses AICPU, with bad performance.\n - -*/ -REG_OP(FractionalAvgPoolGrad) - .INPUT(orig_input_tensor_shape, TensorType({DT_INT64})) - .INPUT(out_backprop, TensorType({DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64})) - .INPUT(row_pooling_sequence, TensorType({DT_INT64})) - .INPUT(col_pooling_sequence, TensorType({DT_INT64})) - .OUTPUT(y, TensorType({DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64})) - .ATTR(overlapping, Bool, false) - .OP_END_FACTORY_REG(FractionalAvgPoolGrad) - -/** -*@brief Returns the permuted vector/tensor in the destination data format given the. - -*@par Inputs: -*Inputs include: \n -*x: A Tensor. Must be one of the following types: int32, int64. Vector of size 4 \n - or Tensor of shape (4, 2) in source data format. - -*@par Attributes: -*@li src_format: An optional string. Defaults to "NHWC". source data format. -*@li dst_format: An optional string. Defaults to "NCHW". destination data format. - -*@par Outputs: -*y: A Tensor. Has the same type as x. 
- -*@attention Constraints:\n -*-The implementation for DataFormatVecPermute on Ascend uses AICPU, with bad performance.\n - -*/ -REG_OP(DataFormatVecPermute) - .INPUT(x, TensorType({ DT_INT32, DT_INT64 })) - .OUTPUT(y, TensorType({ DT_INT32, DT_INT64 })) - .ATTR(src_format, String, "NHWC") - .ATTR(dst_format, String, "NCHW") - .OP_END_FACTORY_REG(DataFormatVecPermute) - -} // namespace ge +#include "nn_pooling_ops.h" #endif // GE_OP_NN_OPS_H_ diff --git a/third_party/fwkacllib/inc/ops/nn_pooling_ops.h b/third_party/fwkacllib/inc/ops/nn_pooling_ops.h index 10f3f369..d3635c3f 100644 --- a/third_party/fwkacllib/inc/ops/nn_pooling_ops.h +++ b/third_party/fwkacllib/inc/ops/nn_pooling_ops.h @@ -17,14 +17,15 @@ #ifndef GE_OP_NN_POOLING_OPS_H #define GE_OP_NN_POOLING_OPS_H -#include "../graph/operator_reg.h" +#include "graph/operator_reg.h" +#include "graph/operator.h" namespace ge { /** *@brief Performs pooling on the input. *@par Inputs: -*@li x: An NCHW tensor of type float16. +*@li x: An NCHW tensor of type float16, float32. *@par Attributes: *@li mode: An optional int32, specifying the pooling algorithm, either "1" (max pooling) or "0" (avg pooling). Defaults to "0". *@li global_pooling: An optional bool. Defaults to "false". @@ -46,14 +47,15 @@ namespace ge { *dilation[3]: An optional int32, specifying the right dilation. Defaults to "1". \n *@li ceil_mode: An optional int32, either "0" (ceil mode) or "1" (floor mode). Defaults to "0". *@par Outputs: -*y: An NCHW tensor of type float16. +*y: An NCHW tensor of type float16, float32. *@attention Constraints:\n *@li window[0] * window[1] < 256; *@li 1<=input_h<=4096,1<=input_w<=4096 +*@li If input tensor N is a prime number, it should be less than 65535. 
*/ REG_OP(Pooling) - .INPUT(x, TensorType({DT_FLOAT16})) - .OUTPUT(y, TensorType({DT_FLOAT16})) + .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT32, DT_INT8})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT32, DT_INT32})) .ATTR(mode, Int, 0) // 0:max pooling or 1:avg pooling .ATTR(global_pooling, Bool, false) .ATTR(window, ListInt, {1,1}) // kernel size @@ -475,7 +477,7 @@ REG_OP(MaxPoolGradWithArgmaxCCE) *@li x: A tensor of type float16 or float32. *@par Attributes: *@li scale:scale factor of x -*@li stride_h:broadcast the axis of h +*@li stride_h:broadcast the axis of h *@li stride_w:broadcast the axis of w *@par Outputs: *y: A tensor of type float16 or float32. @@ -489,32 +491,195 @@ REG_OP(Upsample) .OP_END_FACTORY_REG(Upsample) /** -*@brief Spatial Pyramid Pooling, multi-level pooling. -* Pooling out(n, sigma(c*2^i*2^i)) tensor, i in range[0,pyramid_height). +*@brief Computes gradient of the FractionalMaxPool function. *@par Inputs: -*x: An NCHW tensor, support float16 or float32 type. +*Inputs include: \n +* @li orig_input: A Tensor. Must be one of the following types: float32, float64, int32, int64. +* @li orig_output: A Tensor. Must have the same type as orig_input. +* @li out_backprop: A Tensor. Must have the same type as orig_input. \n + 4-D with shape [batch, height, width, channels]. +* @li row_pooling_sequence: A Tensor of type int64. +* @li col_pooling_sequence: A Tensor of type int64. *@par Attributes: -* @li pyramid_height: An required int32. -* Multi-level pooling out from 2^0 to 2^(pyramid_height-1). -* @li pool_method: An optional int32, pooling method: 0-MAX, 1-AVE. -* Defaults to "0". +*overlapping: An optional bool. Defaults to False. *@par Outputs: -*y: A NCHW tensor, support float16 or float32 type. +*y: A Tensor. Has the same type as orig_input. -*@attention Constraints: -* @li pyramid_height: pyramid_heigjt should be in range [0,7). -* @li feature_size:input feture map h and w should be [1, 510]. 
+*@attention Constraints:\n +*-The implementation for FractionalMaxPoolGrad on Ascend uses AICPU, with bad performance.\n +*/ +REG_OP(FractionalMaxPoolGrad) + .INPUT(orig_input, TensorType({DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64})) + .INPUT(orig_output, TensorType({DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64})) + .INPUT(out_backprop, TensorType({DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64})) + .INPUT(row_pooling_sequence, TensorType({ DT_INT64 })) + .INPUT(col_pooling_sequence, TensorType({ DT_INT64 })) + .OUTPUT(y, TensorType({ DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64 })) + .ATTR(overlapping, Bool, false) + .OP_END_FACTORY_REG(FractionalMaxPoolGrad) + +/** +*@brief Performs fractional average pooling on the input. + +*@par Inputs: +*Inputs include: \n +*x: A Tensor. Must be one of the following types: float32, float64, int32, int64. \n + 4-D with shape [batch, height, width, channels]. + +*@par Attributes: +*@li pooling_ratio: A list of floats that has length >= 4. +*@li pseudo_random: An optional bool. Defaults to False. +*@li overlapping: An optional bool. Defaults to False. When set to True, it means when pooling. +*@li deterministic: An optional bool. Defaults to False. +*@li seed: An optional int. Defaults to 0. +*@li seed2: An optional int. Defaults to 0. + +*@par Outputs: +*@li y: A Tensor. Has the same type as x. +*@li row_pooling_sequence: A Tensor of type int64. +*@li col_pooling_sequence: A Tensor of type int64. 
+ +*@attention Constraints:\n +*-The implementation for FractionalAvgPool on Ascend uses AICPU, with bad performance.\n +*/ +REG_OP(FractionalAvgPool) + .INPUT(x, TensorType({DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64})) + .OUTPUT(row_pooling_sequence, TensorType({DT_INT64})) + .OUTPUT(col_pooling_sequence, TensorType({DT_INT64})) + .ATTR(pooling_ratio, ListFloat, {}) + .ATTR(pseudo_random, Bool, false) + .ATTR(overlapping, Bool, false) + .ATTR(deterministic, Bool, false) + .ATTR(seed, Int, 0) + .ATTR(seed2, Int, 0) + .OP_END_FACTORY_REG(FractionalAvgPool) +/** +*@brief Performs fractional max pooling on the input. + +*@par Inputs: +*Inputs include: \n +*x: A Tensor. Must be one of the following types: float32, float64, int32, int64. \n + 4-D with shape [batch, height, width, channels]. + +*@par Attributes: +*@li pooling_ratio: A list of floats that has length >= 4. Pooling ratio for each dimension of value. +*@li pseudo_random: An optional bool. Defaults to False. +*@li overlapping: An optional bool. Defaults to False. +*@li deterministic: An optional bool. Defaults to False. +*@li seed: An optional int. Defaults to 0. +*@li seed2: An optional int. Defaults to 0. + +*@par Outputs: +*@li y: A Tensor. Has the same type as x. +*@li row_pooling_sequence: A Tensor of type int64. +*@li col_pooling_sequence: A Tensor of type int64. 
+ +*@attention Constraints:\n +*-The implementation for FractionalMaxPool on Ascend uses AICPU, with bad performance.\n */ -REG_OP(SPP) - .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16})) - .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16})) - .REQUIRED_ATTR(pyramid_height, Int) - .ATTR(pool_method, Int, 0) - .OP_END_FACTORY_REG(SPP) +REG_OP(FractionalMaxPool) + .INPUT(x, TensorType({DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64})) + .OUTPUT(row_pooling_sequence, TensorType({DT_INT64})) + .OUTPUT(col_pooling_sequence, TensorType({DT_INT64})) + .ATTR(pooling_ratio, ListFloat, {}) + .ATTR(pseudo_random, Bool, false) + .ATTR(overlapping, Bool, false) + .ATTR(deterministic, Bool, false) + .ATTR(seed, Int, 0) + .ATTR(seed2, Int, 0) + .OP_END_FACTORY_REG(FractionalMaxPool) + +/** +*@brief Finds values of the n-th order statistic for the last dimension. + +*@par Inputs: +*Inputs include: \n +* @li x: A Tensor. Must be one of the following types: float32, float64, int32, uint8, \n + int16, int8, int64, bfloat16, uint16, half, uint32, uint64. +* @li n: A Tensor of type int32. 0-D. + +*@par Attributes: +*reverse: An optional bool. Defaults to False. + +*@par Outputs: +*y: A Tensor. Has the same type as x. + +*@attention Constraints:\n +*-The implementation for NthElement on Ascend uses AICPU, with bad performance.\n + +*/ +REG_OP(NthElement) + .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, + DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_DOUBLE})) + .INPUT(n, TensorType({DT_INT32})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, + DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_DOUBLE})) + .ATTR(reverse, Bool, false) + .OP_END_FACTORY_REG(NthElement) + +/** +*@brief Computes gradient of the FractionalAvgPool function. + +*@par Inputs: +*Inputs include: \n +* @li orig_input_tensor_shape: A Tensor of type int64. +* @li out_backprop: A Tensor. 
Must be one of the following types: float32, float64, \n
+    int32, int64. 4-D with shape [batch, height, width, channels].
+* @li row_pooling_sequence: A Tensor of type int64.
+* @li col_pooling_sequence: A Tensor of type int64.
+
+*@par Attributes:
+*overlapping: An optional bool. Defaults to False.
+
+*@par Outputs:
+*y: A Tensor. Has the same type as out_backprop.
+
+*@attention Constraints:\n
+*-The implementation for FractionalAvgPoolGrad on Ascend uses AICPU, with bad performance.\n
+
+*/
+REG_OP(FractionalAvgPoolGrad)
+    .INPUT(orig_input_tensor_shape, TensorType({DT_INT64}))
+    .INPUT(out_backprop, TensorType({DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64}))
+    .INPUT(row_pooling_sequence, TensorType({DT_INT64}))
+    .INPUT(col_pooling_sequence, TensorType({DT_INT64}))
+    .OUTPUT(y, TensorType({DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64}))
+    .ATTR(overlapping, Bool, false)
+    .OP_END_FACTORY_REG(FractionalAvgPoolGrad)
+
+/**
+*@brief Returns the permuted vector/tensor in the destination data format given the one in the source data format.
+
+*@par Inputs:
+*Inputs include: \n
+*x: A Tensor. Must be one of the following types: int32, int64. Vector of size 4 \n
+    or Tensor of shape (4, 2) in source data format.
+
+*@par Attributes:
+*@li src_format: An optional string. Defaults to "NHWC". source data format.
+*@li dst_format: An optional string. Defaults to "NCHW". destination data format.
+
+*@par Outputs:
+*y: A Tensor. Has the same type as x.
+ +*@attention Constraints:\n +*-The implementation for DataFormatVecPermute on Ascend uses AICPU, with bad performance.\n + +*/ +REG_OP(DataFormatVecPermute) + .INPUT(x, TensorType({ DT_INT32, DT_INT64 })) + .OUTPUT(y, TensorType({ DT_INT32, DT_INT64 })) + .ATTR(src_format, String, "NHWC") + .ATTR(dst_format, String, "NCHW") + .OP_END_FACTORY_REG(DataFormatVecPermute) + + } // namespace ge #endif // GE_OP_NN_POOLING_OPS_H diff --git a/third_party/fwkacllib/inc/ops/nn_training_ops.h b/third_party/fwkacllib/inc/ops/nn_training_ops.h index d800d075..f09c1a8c 100644 --- a/third_party/fwkacllib/inc/ops/nn_training_ops.h +++ b/third_party/fwkacllib/inc/ops/nn_training_ops.h @@ -17,7 +17,7 @@ #ifndef GE_OP_TRAINING_OPS_H #define GE_OP_TRAINING_OPS_H -#include "../graph/operator_reg.h" +#include "graph/operator_reg.h" namespace ge { /** *@brief Updates "var" according to the AdaMax algorithm.\n @@ -67,6 +67,57 @@ REG_OP(ApplyAdaMax) .ATTR(use_locking, Bool, false) .OP_END_FACTORY_REG(ApplyAdaMax) +/** +*@brief Updates "var" according to the AdaMax algorithm.\n +* t-1 mean previous period. +* m_t <- beta1 * m{t-1} + (1 - beta1) * grad\n +* v_t <- max(beta2 * v{t-1}, abs(grad))\n +* var <- var - lr / (1 - beta1^t) * m_t / (v_t + epsilon) +* +*@attention Constraints:\n +* the input tensors must have the same shape. +* +*@par Inputs: +*@li var: A mutable tensor. Must be one of the following types: TensorType::NumberType(). +* Should be from a Variable(). +*@li m: A mutable tensor. Has the same type as "var". +* Should be from a Variable(). +*@li v: A mutable tensor. Has the same type as "var". +* Should be from a Variable(). +*@li beta1_power: A scalar. Has the same type as "var". +*@li lr: learning_rate. A scalar. Has the same type as "var". +*@li beta1: A scalar. Has the same type as "var". +*@li beta2: A scalar. Has the same type as "var". +*@li epsilon: A scalar. Has the same type as "var". +*@li grad: A tensor for the gradient. Has the same type as "var". 
+* +*@par Attributes:\n +* use_locking: An optional bool. Defaults to "False". +* If "True", updating of the "var", "ms", and "mom" tensors is protected +* by a lock; otherwise the behavior is undefined, but may exhibit less +* contention. +* +*@par Outputs: +* var: A mutable tensor. Has the same type as input "var". +* +* +*/ +REG_OP(ApplyAdaMaxD) + .INPUT(var, TensorType::NumberType()) + .INPUT(m, TensorType::NumberType()) + .INPUT(v, TensorType::NumberType()) + .INPUT(beta1_power, TensorType::NumberType()) + .INPUT(lr, TensorType::NumberType()) + .INPUT(beta1, TensorType::NumberType()) + .INPUT(beta2, TensorType::NumberType()) + .INPUT(epsilon, TensorType::NumberType()) + .INPUT(grad, TensorType::NumberType()) + .OUTPUT(var, TensorType::NumberType()) + .OUTPUT(m, TensorType::NumberType()) + .OUTPUT(v, TensorType::NumberType()) + .ATTR(use_locking, Bool, false) + .OP_END_FACTORY_REG(ApplyAdaMaxD) + /** *@brief Updates relevant entries in "var" and "accum" according to the adagrad scheme. @@ -113,7 +164,8 @@ REG_OP(SparseApplyAdagrad) *@li update_slots: An optional bool. Defaults to "True". If "True", the calcution will be different as "False". *@par Outputs: -*var: A Tensor. Has the same type and format as input "var". +*@li var: A Tensor. Has the same type and format as input "var". +*@li accum: A Tensor. Has the same type and format as input "var". */ REG_OP(SparseApplyAdagradD) @@ -131,7 +183,7 @@ REG_OP(SparseApplyAdagradD) *@brief Updates relevant entries in "var" and "accum" according to the adagrad scheme. *@par Inputs: -* Five inputs, including: +*Six inputs, including: *@li var: An NCHW, NHWC, or ND Tensor of type float32. *@li accum: An NCHW, NHWC, or ND Tensor of type float32. *@li lr: An NCHW, NHWC, or ND Tensor of type float32. @@ -141,7 +193,7 @@ REG_OP(SparseApplyAdagradD) *@par Attributes: *@li use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock. -*@li update_slots: An optional bool. 
Defaults to "True". If "True", the calcution will be different as "False". +*@li update_slots: An optional bool. Defaults to "True". If "False", the computation logic will be different. *@par Outputs: *var: A Tensor. Has the same type and format as input "var". @@ -163,7 +215,7 @@ REG_OP(SparseApplyAdagradV2) *@brief Updates relevant entries in "var" and "accum" according to the adagrad scheme. *@par Inputs: -* Four inputs, including: +*Four inputs, including: *@li var: An NCHW, NHWC, or ND Tensor of type float32. *@li accum: An NCHW, NHWC, or ND Tensor of type float32. *@li grad: An NCHW, NHWC, or ND Tensor of type float32. @@ -173,11 +225,11 @@ REG_OP(SparseApplyAdagradV2) *@li lr: Required, used for computation. *@li epsilon: Required, used for computation. *@li use_locking: An optional bool. Defaults to "False". If "True", the operation will be protected by a lock. -*@li update_slots: An optional bool. Defaults to "True". If "True", the calcution will be different as "False". +*@li update_slots: An optional bool. Defaults to "True". If "False", the computation logic will be different. *@par Outputs: -*var: A Tensor. Has the same type and format as input "var". -*accum: A Tensor. Has the same type and format as input "accum". +*@li var: A Tensor. Has the same type and format as input "var". +*@li accum: A Tensor. Has the same type and format as input "accum". */ REG_OP(SparseApplyAdagradV2D) @@ -247,6 +299,273 @@ REG_OP(ApplyMomentumCCE) .ATTR(use_locking, Bool, false) .OP_END_FACTORY_REG(ApplyMomentumCCE) +/** +*@brief Updates "var" according to the momentum scheme. Set use_nesterov = True if you +* want to use Nesterov momentum.\n +* computing process: \n +* accum = accum * momentum + grad\n +* var -= lr * accum +* +*@attention Constraints:\n +* the input tensors must have the same shape. +* +*@par Inputs: +*@li var: A mutable tensor. Should be from a Variable(). +*@li accum: A mutable tensor. Has the same type as "var". +* Should be from a Variable(). 
+*@li lr: A scalar. Has the same type as "var". +*@li grad: A tensor for the gradient. Has the same type as "var". +* +*@par Attributes: +*@li use_nesterov: An optional bool. Defaults to "False". +* If "True", the tensor passed to compute grad will be +* var - lr * momentum * accum, so in the end, the var you get is actually +* var - lr * momentum * accum. +* +*@li use_locking: An optional bool. Defaults to "False".\n +* If "True", updating of the "var", "ms", and "mom" tensors is protected by a lock; +* otherwise the behavior is undefined, but may exhibit less contention. +* +*@par Outputs: +* var: A mutable tensor. Has the same type as input "var". +* accum: A mutable tensor. Has the same type as input "accum". +* +*/ + +REG_OP(ApplyMomentumD) + .INPUT(var, TensorType::NumberType()) + .INPUT(accum, TensorType::NumberType()) + .INPUT(lr, TensorType::NumberType()) + .INPUT(grad, TensorType::NumberType()) + .INPUT(momentum, TensorType::NumberType()) + .OUTPUT(var, TensorType::NumberType()) + .OUTPUT(accum, TensorType::NumberType()) + .ATTR(use_nesterov, Bool, false) + .ATTR(use_locking, Bool, false) + .OP_END_FACTORY_REG(ApplyMomentumD) + +/** +*@brief Updates '*var' according to the momentum scheme. +* accum = accum * momentum - grad * lr \n +* if use_nesterov is True: \n +* var += accum * momentum - grad * lr \n +* else: \n +* var += accum +* +*@par Inputs: +*@li var: A mutable tensor. Must be one of the data types defined in +* TensorType::NumberType(). Should be from a Variable(). +*@li accum: A mutable tensor. Has the same type as "var". Should be from a +* Variable(). +*@li lr: A tensor for the learning rate. Has the same type as "var". Should be +* from a Variable(). +*@li grad: A tensor for the gradient. Has the same type as "var". Should be +* from a Variable(). +*@li momentum: A scalar. Has the same type as "var". +* +*@par Attributes: +*@li use_nesterov: An optional bool. Defaults to "False". +* If "True", var will be updated by using Nesterov momentum. 
+*@li use_locking: An optional bool. Defaults to "False". +* If "True", updating of the "var" tensor is protected by a lock; +* otherwise the behavior is undefined, but may exhibit less contention. +* +*@par Outputs: +* var: A mutable tensor. Has the same type as input "var". +* +*@attention Constraints: +* The input tensors must have the same shape. +* +* +*/ +REG_OP(ApplyKerasMomentum) + .INPUT(var, TensorType::NumberType()) + .INPUT(accum, TensorType::NumberType()) + .INPUT(lr, TensorType::NumberType()) + .INPUT(grad, TensorType::NumberType()) + .INPUT(momentum, TensorType::NumberType()) + .OUTPUT(var, TensorType::NumberType()) + .ATTR(use_locking, Bool, false) + .ATTR(use_nesterov, Bool, false) + .OP_END_FACTORY_REG(ApplyKerasMomentum) + + +/** +*@brief Updates '*var' according to the momentum scheme. +* accum = accum * momentum - grad * lr \n +* if use_nesterov is True: \n +* var += accum * momentum - grad * lr \n +* else: \n +* var += accum +* +*@par Inputs: +*@li var: A mutable tensor. Must be one of the data types defined in +* TensorType::NumberType(). Should be from a Variable(). +*@li accum: A mutable tensor. Has the same type as "var". Should be from a +* Variable(). +*@li lr: A tensor for the learning rate. Has the same type as "var". Should be +* from a Variable(). +*@li grad: A tensor for the gradient. Has the same type as "var". Should be +* from a Variable(). +*@li momentum: A scalar. Has the same type as "var". Should be from a +* Variable(). +* +*@par Attributes: +*@li use_nesterov: An optional bool. Defaults to "False". +* If "True", var will be updated by using nesterov momentum +*@li use_locking: An optional bool. Defaults to "False". +* If "True", updating of the "var" tensor is protected by a lock; +* otherwise the behavior is undefined, but may exhibit less contention. +* +*@par Outputs: +*@li var: A mutable tensor. Has the same type as input "var". +*@li accum: A mutable tensor. 
Has the same type as input "var" +* +*@attention Constraints: +* The input tensors must have the same shape. +* +* +*/ +REG_OP(ApplyKerasMomentumD) + .INPUT(var, TensorType::NumberType()) + .INPUT(accum, TensorType::NumberType()) + .INPUT(lr, TensorType::NumberType()) + .INPUT(grad, TensorType::NumberType()) + .INPUT(momentum, TensorType::NumberType()) + .OUTPUT(var, TensorType::NumberType()) + .OUTPUT(accum, TensorType::NumberType()) + .ATTR(use_locking, Bool, false) + .ATTR(use_nesterov, Bool, false) + .OP_END_FACTORY_REG(ApplyKerasMomentumD) + + +/** +*@brief Updates '*var' according to the Adam algorithm.. +* lr_t := {learning_rate} * sqrt{1 - beta_2^t} / (1 - beta_1^t) +* m_t := beta_1 * m_{t-1} + (1 - beta_1) * g +* v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g +* vhat_t := max{vhat_{t-1}, v_t} +* variable := variable - lr_t * m_t / (sqrt{vhat_t} + epsilon) +* +*@par Inputs: +*@li var: A mutable tensor. Must be one of the data types defined in +* TensorType::NumberType(). Should be from a Variable(). +*@li m: A mutable tensor. Has the same type as "var". Should be from a +* Variable(). +*@li v: A mutable tensor. Has the same type as "var". Should be from a +* Variable(). +*@li vhat: A mutable tensor. Has the same type as "var". Should be from a +* Variable(). +*@li beta1_power: A mutable tensor. Has the same type as "var". Should be from a +* Variable(). +*@li beta2_power: A mutable tensor. Has the same type as "var". Should be from a +* Variable(). +*@li lr: A tensor for the learning rate. Has the same type as "var". Should be +* from a Variable(). +*@li grad: A tensor for the gradient. Has the same type as "var". Should be +* from a Variable(). +* +*@par Attributes: +*@li beta1: A scalar. Has the same type as "var". +*@li beta2: A scalar. Has the same type as "var". +*@li epsilon: A scalar. Has the same type as "var". +*@li use_locking: An optional bool. Defaults to "False". 
+* If "True", updating of the "var" tensor is protected by a lock; +* otherwise the behavior is undefined, but may exhibit less contention. +* +*@par Outputs: +*@li var: A mutable tensor. Has the same type as input "var". +*@li m: A mutable tensor. Has the same type as input "var" +*@li v: A mutable tensor. Has the same type as input "var" +*@li vhat: A mutable tensor. Has the same type as input "var" +* +*@attention Constraints: +* The input tensors must have the same shape. +* +* +*/ +REG_OP(ApplyAdamWithAmsgradD) + .INPUT(var, TensorType::NumberType()) + .INPUT(m, TensorType::NumberType()) + .INPUT(v, TensorType::NumberType()) + .INPUT(vhat, TensorType::NumberType()) + .INPUT(beta1_power, TensorType::NumberType()) + .INPUT(beta2_power, TensorType::NumberType()) + .INPUT(lr, TensorType::NumberType()) + .INPUT(grad, TensorType::NumberType()) + .OUTPUT(var, TensorType::NumberType()) + .OUTPUT(m, TensorType::NumberType()) + .OUTPUT(v, TensorType::NumberType()) + .OUTPUT(vhat, TensorType::NumberType()) + .REQUIRED_ATTR(beta1, Float) + .REQUIRED_ATTR(beta2, Float) + .REQUIRED_ATTR(epsilon, Float) + .ATTR(use_locking, Bool, false) + .OP_END_FACTORY_REG(ApplyAdamWithAmsgradD) + + +/** +*@brief Updates '*var' according to the Adam algorithm.. +* lr_t := {learning_rate} * sqrt{1 - beta_2^t} / (1 - beta_1^t) +* m_t := beta_1 * m_{t-1} + (1 - beta_1) * g +* v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g +* vhat_t := max{vhat_{t-1}, v_t} +* variable := variable - lr_t * m_t / (sqrt{vhat_t} + epsilon) +* +*@par Inputs: +*@li var: A mutable tensor. Must be one of the data types defined in +* TensorType::NumberType(). Should be from a Variable(). +*@li m: A mutable tensor. Has the same type as "var". Should be from a +* Variable(). +*@li v: A mutable tensor. Has the same type as "var". Should be from a +* Variable(). +*@li vhat: A mutable tensor. Has the same type as "var". Should be from a +* Variable(). +*@li beta1_power: A mutable tensor. Has the same type as "var". 
Should be from a +* Variable(). +*@li beta2_power: A mutable tensor. Has the same type as "var". Should be from a +* Variable(). +*@li lr: A tensor for the learning rate. Has the same type as "var". Should be +* from a Variable(). +*@li grad: A tensor for the gradient. Has the same type as "var". Should be +* from a Variable(). +* +*@par Attributes: +*@li beta1: A scalar. Has the same type as "var". +*@li beta2: A scalar. Has the same type as "var". +*@li epsilon: A scalar. Has the same type as "var". +*@li use_locking: An optional bool. Defaults to "False". +* If "True", updating of the "var" tensor is protected by a lock; +* otherwise the behavior is undefined, but may exhibit less contention. +* +*@par Outputs: +*@li var: A mutable tensor. Has the same type as input "var". +*@li m: A mutable tensor. Has the same type as input "var" +*@li v: A mutable tensor. Has the same type as input "var" +*@li vhat: A mutable tensor. Has the same type as input "var" +* +*@attention Constraints: +* The input tensors must have the same shape. +* +* +*/ +REG_OP(ApplyAdamWithAmsgrad) + .INPUT(var, TensorType::NumberType()) + .INPUT(m, TensorType::NumberType()) + .INPUT(v, TensorType::NumberType()) + .INPUT(vhat, TensorType::NumberType()) + .INPUT(beta1_power, TensorType::NumberType()) + .INPUT(beta2_power, TensorType::NumberType()) + .INPUT(lr, TensorType::NumberType()) + .INPUT(beta1, TensorType::NumberType()) + .INPUT(beta2, TensorType::NumberType()) + .INPUT(epsilon, TensorType::NumberType()) + .INPUT(grad, TensorType::NumberType()) + .OUTPUT(var, TensorType::NumberType()) + .ATTR(use_locking, Bool, false) + .OP_END_FACTORY_REG(ApplyAdamWithAmsgrad) + + /** *@brief Updates "var" according to the AddSign update.\n * t-1 mean previous period. @@ -289,6 +608,51 @@ REG_OP(ApplyPowerSign) .ATTR(use_locking, Bool, false) .OP_END_FACTORY_REG(ApplyPowerSign) +/** +*@brief Updates "var" according to the AddSign update.\n +* t-1 mean previous period. 
+* m_t <- beta1 * m_{t-1} + (1 - beta1) * grad\n +* update <- exp(logbase * sign_decay * sign(grad) * sign(m_t)) * grad\n +* var <- var - lr * update +* +*@attention Constraints:\n +* the input tensors must have the same shape. +* +*@par Inputs: +*@li var: A mutable tensor. Should be from a Variable(). +*@li m: A mutable tensor. Has the same type as "var". +* Should be from a Variable(). +*@li lr: A scalar. Has the same type as "var". +*@li logbase: A scalar. Has the same type as "var". +*@li sign_decay: A scalar. Has the same type as "var". +*@li beta: A scalar. Has the same type as "var". +*@li grad: A tensor for the gradient. Has the same type as "var". +* +*@par Attributes: +* use_locking: An optional bool. Defaults to "False". +* If "True", updating of the "var", "ms", and "mom" tensors is protected +* by a lock; otherwise the behavior is undefined, but may exhibit less +* contention. +* +*@par Outputs: +*@li var: A mutable tensor. Has the same type as input "var". +*@li m: A mutable tensor. Has the same type as input "var". +* +* +*/ +REG_OP(ApplyPowerSignD) + .INPUT(var, TensorType::NumberType()) + .INPUT(m, TensorType::NumberType()) + .INPUT(lr, TensorType::NumberType()) + .INPUT(logbase, TensorType::NumberType()) + .INPUT(sign_decay, TensorType::NumberType()) + .INPUT(beta, TensorType::NumberType()) + .INPUT(grad, TensorType::NumberType()) + .OUTPUT(var, TensorType::NumberType()) + .OUTPUT(m, TensorType::NumberType()) + .ATTR(use_locking, Bool, false) + .OP_END_FACTORY_REG(ApplyPowerSignD) + /** *@brief Updates "var" as FOBOS algorithm with fixed learning rate.\n * prox_v = var - alpha * delta\n @@ -361,6 +725,46 @@ REG_OP(ApplyAddSign) .ATTR(use_locking, Bool, false) .OP_END_FACTORY_REG(ApplyAddSign) +/** +*@brief Updates "var" according to the AddSign update. + +*@par Inputs: +*Seven inputs, including: +* @li var: A mutable Tensor of type TensorType::NumberType(). +* Should be a Variable Tensor. +* @li m: A mutable Tensor of the same type as "var". 
+* Should be a Variable Tensor. +* @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar. +* @li alpha: A Tensor of the same type as "var". Must be a scalar. +* @li sign_decay: A Tensor of the same type as "var". Must be a scalar. +* @li beta: A Tensor of the same type as "var". Must be a scalar. +* @li grad: A Tensor of the same type as "var", for the gradient. + + +*@par Attributes: +*use_locking: An optional bool. Defaults to "False". +* If "True", updating of the "var" and "m" tensors will be +* protected by a lock; otherwise the behavior is undefined, +* but may exhibit less contention. + +*@par Outputs: +*@li var: A mutable Tensor. Has the same type as "var". +*@li m: A mutable Tensor. Has the same type as "m". + +*/ +REG_OP(ApplyAddSignD) + .INPUT(var, TensorType::NumberType()) + .INPUT(m, TensorType::NumberType()) + .INPUT(lr, TensorType::NumberType()) + .INPUT(alpha, TensorType::NumberType()) + .INPUT(sign_decay, TensorType::NumberType()) + .INPUT(beta, TensorType::NumberType()) + .INPUT(grad, TensorType::NumberType()) + .OUTPUT(var, TensorType::NumberType()) + .OUTPUT(m, TensorType::NumberType()) + .ATTR(use_locking, Bool, false) + .OP_END_FACTORY_REG(ApplyAddSignD) + /** *@brief Updates "var" according to the centered RMSProp algorithm.\n * The centered RMSProp algorithm uses an estimate of the centered second moment @@ -417,6 +821,70 @@ REG_OP(ApplyCenteredRMSProp) .OUTPUT(var, TensorType::NumberType()) .ATTR(use_locking, Bool, false) .OP_END_FACTORY_REG(ApplyCenteredRMSProp) + +/** +*@brief Updates "var" according to the centered RMSProp algorithm.\n +* The centered RMSProp algorithm uses an estimate of the centered second moment +* (i.e., the variance) for normalization, as opposed to regular RMSProp, which +* uses the (uncentered) second moment. This often helps with training, but is +* slightly more expensive in terms of computation and memory. +* +* t-1 mean previous period. 
+* mg <- rho * mg{t-1} + (1-rho) * grad\n +* ms <- rho * ms{t-1} + (1-rho) * grad * grad\n +* mom <- momentum * mom{t-1} + lr * grad / sqrt(ms - mg * mg + epsilon)\n +* var <- var - mom\n +* +*@attention Constraints:\n +*@li in dense implementation of this algorithm, mg, ms, and mom will +* update even if the grad is zero, but in this sparse implementation, mg, ms, +* and mom will not update in iterations during which the grad is zero. +*@li the input tensors must have the same shape. +* +*@par Inputs: +*@li var: A mutable tensor. Should be from a Variable(). +*@li mg: A mutable tensor. Has the same type as "var". +* Should be from a Variable(). +*@li ms: A mutable tensor. Has the same type as "var". +* Should be from a Variable(). +*@li mom: A mutable tensor. Has the same type as "var". +* Should be from a Variable(). +*@li lr: A scalar. Has the same type as "var". +*@li rho: A scalar. Has the same type as "var". +*@li momentum: A tensor. Has the same type as "var". +*@li epsilon: A scalar. Has the same type as "var". +*@li grad: A tensor for the gradient. Has the same type as "var". +* +*@par Attributes: +* use_locking: An optional bool. Defaults to "False". +* If "True", updating of the "var", "ms", and "mom" tensors is protected +* by a lock; otherwise the behavior is undefined, but may exhibit less +* contention. +* +*@par Outputs: +*@li var: A mutable Tensor. Has the same type as "var". +*@li mg: A mutable Tensor. Has the same type as "mg". +*@li ms: A mutable Tensor. Has the same type as "ms". +*@li mom: A mutable Tensor. Has the same type as "mom". 
+ +* +*/ +REG_OP(ApplyCenteredRMSPropD) + .INPUT(var, TensorType::NumberType()) + .INPUT(mg, TensorType::NumberType()) + .INPUT(ms, TensorType::NumberType()) + .INPUT(mom, TensorType::NumberType()) + .INPUT(lr, TensorType::NumberType()) + .INPUT(rho, TensorType::NumberType()) + .INPUT(momentum, TensorType::NumberType()) + .INPUT(epsilon, TensorType::NumberType()) + .INPUT(grad, TensorType::NumberType()) + .OUTPUT(var, TensorType::NumberType()) + .OUTPUT(mg, TensorType::NumberType()) + .OUTPUT(ms, TensorType::NumberType()) + .OUTPUT(mom, TensorType::NumberType()) + .ATTR(use_locking, Bool, false) + .OP_END_FACTORY_REG(ApplyCenteredRMSPropD) /** *@brief Updates "var" by subtracting 'alpha' * 'delta' from it.\n @@ -442,11 +910,46 @@ REG_OP(ApplyCenteredRMSProp) */ REG_OP(ApplyGradientDescent) .INPUT(var, TensorType::NumberType()) - .INPUT(alpha, TensorType::NumberType()) - .INPUT(delta, TensorType::NumberType()) + .INPUT(alpha, TensorType::NumberType()) + .INPUT(delta, TensorType::NumberType()) + .OUTPUT(var, TensorType::NumberType()) + .ATTR(use_locking, Bool, false) + .OP_END_FACTORY_REG(ApplyGradientDescent) + +/** +*@brief Updates "var" according to the adagrad scheme.\n +* accum += grad * grad\n +* var -= lr * grad * (1 / sqrt(accum)) +* +*@attention Constraints:\n +* the input tensors must have the same shape. +* +*@par Inputs: +*@li var: A mutable tensor. Should be from a Variable(). +*@li accum: A mutable tensor. Has the same type as "var". +* Should be from a Variable(). +*@li lr: A scalar. Has the same type as "var". +*@li grad: A tensor for the gradient. Has the same type as "var". +* +*@par Attributes: +* use_locking: An optional bool. Defaults to "False". +* If "True", updating of the "var", "ms", and "mom" tensors is protected +* by a lock; otherwise the behavior is undefined, but may exhibit less +* contention. +* +*@par Outputs: +* var: A mutable tensor. Has the same type as input "var". 
+* +*/ +REG_OP(ApplyAdagrad) + .INPUT(var, TensorType::NumberType()) + .INPUT(accum, TensorType::NumberType()) + .INPUT(lr, TensorType::NumberType()) + .INPUT(grad, TensorType::NumberType()) .OUTPUT(var, TensorType::NumberType()) + .ATTR(update_slots, Bool, true) .ATTR(use_locking, Bool, false) - .OP_END_FACTORY_REG(ApplyGradientDescent) + .OP_END_FACTORY_REG(ApplyAdagrad) /** *@brief Updates "var" according to the adagrad scheme.\n @@ -470,47 +973,50 @@ REG_OP(ApplyGradientDescent) * contention. * *@par Outputs: -* var: A mutable tensor. Has the same type as input "var". +*@li var: A mutable tensor. Has the same type as input "var". +*@li accum: A mutable tensor. Has the same type as input "var". +* * */ -REG_OP(ApplyAdagrad) +REG_OP(ApplyAdagradD) .INPUT(var, TensorType::NumberType()) .INPUT(accum, TensorType::NumberType()) .INPUT(lr, TensorType::NumberType()) .INPUT(grad, TensorType::NumberType()) .OUTPUT(var, TensorType::NumberType()) + .OUTPUT(accum, TensorType::NumberType()) .ATTR(update_slots, Bool, true) .ATTR(use_locking, Bool, false) - .OP_END_FACTORY_REG(ApplyAdagrad) + .OP_END_FACTORY_REG(ApplyAdagradD) /** -* @brief Updates "var" according to the adagradv2 scheme.\n +* @brief Updates "var" according to the adagradv2 scheme. * accum += grad * grad \n * var -= lr * grad * (1 / sqrt(accum) + epsilon) * -* @attention Constraints: -* the input tensors must have the same shape. -* * @par Inputs: * @li var: A mutable tensor. Must be one of the data types defined in -* TensorType::NumberType(). Should be from a Variable(). +* TensorType::NumberType(). Should be from a Variable(). * @li accum: A mutable tensor. Has the same type as "var". Should be from a -* Variable(). +* Variable(). * @li lr: A tensor for the learning rate. Has the same type as "var". Should be -* from a Variable(). +* from a Variable(). * @li grad: A tensor for the gradient. Has the same type as "var". Should be -* from a Variable(). +* from a Variable(). * @li epsilon: A scalar. 
Has the same type as "var". * * @par Attributes: * @li update_slots: An optional bool. Defaults to "True". -* If "True", accum will be updated +* If "True", "accum" will be updated * @li use_locking: An optional bool. Defaults to "False". -* If "True", updating of the "var" tensor is protected by a lock; -* otherwise the behavior is undefined, but may exhibit less contention. +* If "True", updating of the "var" tensor is protected by a lock; +* otherwise the behavior is undefined, but may exhibit less contention. * * @par Outputs: -* var: A mutable tensor. Has the same type as input "var". +* var: A mutable tensor. Has the same type as input "var". +* +* @attention Constraints: +* The input tensors must have the same shape. * */ REG_OP(ApplyAdagradV2) @@ -526,33 +1032,33 @@ REG_OP(ApplyAdagradV2) /** -* @brief Updates "var" according to the adagradv2 scheme.\n -* accum += grad * grad \n -* var -= lr * grad * (1 / sqrt(accum) + epsilon) -* -* @attention Constraints: -* the input tensors must have the same shape. +* @brief Updates "var" according to the adagradv2 scheme. +* accum += grad * grad \n +* var -= lr * grad * (1 / sqrt(accum) + epsilon) * * @par Inputs: * @li var: A mutable tensor. Must be one of the data types defined in -* TensorType::NumberType(). Should be from a Variable(). +* TensorType::NumberType(). Should be from a Variable(). * @li accum: A mutable tensor. Has the same type as "var". Should be from a -* Variable(). +* Variable(). * @li lr: A tensor for the learning rate. Has the same type as "var". Should be -* from a Variable(). +* from a Variable(). * @li grad: A tensor for the gradient. Has the same type as "var". Should be -* from a Variable(). +* from a Variable(). * * @par Attributes: * @li epsilon: A scalar. Has the same type as "var". * @li update_slots: An optional bool. Defaults to "True". -* If "True", accum will be updated +* If "True", "accum" will be updated * @li use_locking: An optional bool. Defaults to "False". 
-* If "True", updating of the "var" tensor is protected by a lock; -* otherwise the behavior is undefined, but may exhibit less contention. +* If "True", updating of the "var" tensor is protected by a lock; +* otherwise the behavior is undefined, but may exhibit less contention. * * @par Outputs: -* var: A mutable tensor. Has the same type as input "var". +* var: A mutable tensor. Has the same type as input "var". +* +* @attention Constraints: +* The input tensors must have the same shape. * */ REG_OP(ApplyAdagradV2D) @@ -610,6 +1116,54 @@ REG_OP(ApplyAdagradDA) .ATTR(use_locking, Bool, false) .OP_END_FACTORY_REG(ApplyAdagradDA) +/** +*@brief Updates "var" according to the proximal adagrad scheme. + +*@par Inputs: +*Eight inputs, including: +* @li var: A mutable Tensor. Must be one of the following types: +* TensorType::NumberType(). Should be a Variable Tensor. +* @li gradient_accumulator: A mutable Tensor. Must have the same +* type as "var". Should be a Variable Tensor. +* @li gradient_squared_accumulator: A mutable Tensor of the same type as "var". +* Should be a Variable Tensor. +* @li grad: A Tensor of the same type as "var", for the gradient. +* @li lr: A Tensor of the same type as "var". +* Scaling factor. Must be a scalar. +* @li l1: A Tensor of the same type as "var". +* L1 regulariation. Must be a scalar. +* @li l2: A Tensor of the same type as "var". +* L2 regulariation. Must be a scalar. +* @li global_step: A Tensor of type int32 or int64. +* Training step number. Must be a scalar. + +*@par Attributes: +*use_locking: An optional bool. Defaults to "False". +* If "True", updating of the var and accum tensors will be +* protected by a lock; otherwise the behavior is undefined, +* but may exhibit less contention. + +*@par Outputs: +*var: A mutable Tensor. Has the same type as "var". +*gradient_accumulator: A mutable Tensor. Has the same type as "var". +*gradient_squared_accumulator: A mutable Tensor. Has the same type as "var". 
+ +*/ +REG_OP(ApplyAdagradDAD) + .INPUT(var, TensorType::NumberType()) + .INPUT(gradient_accumulator, TensorType::NumberType()) + .INPUT(gradient_squared_accumulator, TensorType::NumberType()) + .INPUT(grad, TensorType::NumberType()) + .INPUT(lr, TensorType::NumberType()) + .INPUT(l1, TensorType::NumberType()) + .INPUT(l2, TensorType::NumberType()) + .INPUT(global_step, TensorType({DT_INT32, DT_INT64})) + .OUTPUT(var, TensorType::NumberType()) + .OUTPUT(gradient_accumulator, TensorType::NumberType()) + .OUTPUT(gradient_squared_accumulator, TensorType::NumberType()) + .ATTR(use_locking, Bool, false) + .OP_END_FACTORY_REG(ApplyAdagradDAD) + /** *@brief Returns the dimension index in the destination data format given the one in * the source data format. @@ -798,7 +1352,9 @@ REG_OP(ApplyRMSPropD) *use_locking: An optional bool. Defaults to "False". If "True", updating of the "var" and "accum" *tensors will be protected by a lock; otherwise the behavior is undefined, but may exhibit less *contention. *@par Outputs: -*var: A mutable Tensor. Has the same type as "var". +* @li var: A mutable tensor. Must have the same type as input "var". +* @li ms: A mutable tensor. Must have the same type as input "ms". +* @li mom: A mutable tensor. Must have the same type as input "mom". */ REG_OP(ApplyProximalAdagrad) .INPUT(var, TensorType::NumberType()) @@ -811,6 +1367,39 @@ REG_OP(ApplyProximalAdagrad) .ATTR(use_locking, Bool, false) .OP_END_FACTORY_REG(ApplyProximalAdagrad) +/** +*@brief Update "var" and "accum" according to FOBOS with Adagrad learning rate. + +*@par Inputs: +*Six inputs, including: +* @li var: A mutable Tensor of type TensorType::NumberType(). +* Should be from a Variable(). +* @li accum: A mutable Tensor of the same type as "var". Should be from a Variable(). +* @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar. +* @li l1: A Tensor of the same type as "var", for L1 regulariation. Must be a scalar. 
+* @li l2: A Tensor of the same type as "var", for L2 regulariation. Must be a scalar. +* @li grad: A Tensor of the same type as "var", for the gradient. + +*@par Attributes: +*use_locking: An optional bool. Defaults to "False". If "True", updating of the "var" and "accum" *tensors will be protected by a lock; otherwise the behavior is undefined, but may exhibit less *contention. + +*@par Outputs: +* @li var: A mutable Tensor. Has the same type as "var". +* @li accum: A mutable Tensor. Has the same type as "var". + +*/ +REG_OP(ApplyProximalAdagradD) + .INPUT(var, TensorType::NumberType()) + .INPUT(accum, TensorType::NumberType()) + .INPUT(lr, TensorType::NumberType()) + .INPUT(l1, TensorType::NumberType()) + .INPUT(l2, TensorType::NumberType()) + .INPUT(grad, TensorType::NumberType()) + .OUTPUT(var, TensorType::NumberType()) + .OUTPUT(accum, TensorType::NumberType()) + .ATTR(use_locking, Bool, false) + .OP_END_FACTORY_REG(ApplyProximalAdagradD) + /** *@brief Updates entries in 'var' and 'accum' according to the Proximal Adagrad algorithm.\ n * Compared with op ApplyProximalAdagrad, an additional index tensor is input, @@ -853,6 +1442,51 @@ REG_OP(SparseApplyProximalAdagrad) .ATTR(use_locking, Bool, false) .OP_END_FACTORY_REG(SparseApplyProximalAdagrad) +/** +*@brief Updates entries in 'var' and 'accum' according to the Proximal Adagrad algorithm.\ n +* Compared with op ApplyProximalAdagrad, an additional index tensor is input, +* Only the indices into the first dimensions of "var" and "accum" are updated. + +*@par Inputs: +* Seven inputs, including:\n +* @li var: A mutable Tensor.\n +* TensorType::NumberType(). Should be a Variable Tensor. +* @li accum: A mutable Tensor of the same type as "var".\n +* Should be a Variable Tensor. +* @li lr: A Tensor of the same type as "var".\n +* Scaling factor. Must be a scalar. +* @li l1: A Tensor of the same type as "var".\n +* L1 regulariation. Must be a scalar. 
+* @li l2: A Tensor of the same type as "var".\n +* L2 regulariation. Must be a scalar. +* @li grad: A Tensor. Has the same type as "var". \n +* The gradient. +* @li indices: A vector of indices into the first dimension of "var" and "accum".\n +* TensorType::IndexNumberType(). + +*@par Attributes: +*use_locking: An optional bool. Defaults to "False".\n +* If "True", updating of the var and accum tensors will be protected by a lock; \n +* If "False", the behavior is undefined, but may exhibit less contention. + +*@par Outputs: +*@li var: A mutable Tensor. Has the same type as "var". +*@li accum: A mutable Tensor. Has the same type as "var". + +*/ +REG_OP(SparseApplyProximalAdagradD) + .INPUT(var, TensorType::NumberType()) + .INPUT(accum, TensorType::NumberType()) + .INPUT(lr, TensorType::NumberType()) + .INPUT(l1, TensorType::NumberType()) + .INPUT(l2, TensorType::NumberType()) + .INPUT(grad, TensorType::NumberType()) + .INPUT(indices, TensorType::IndexNumberType()) + .OUTPUT(var, TensorType::NumberType()) + .OUTPUT(accum, TensorType::NumberType()) + .ATTR(use_locking, Bool, false) + .OP_END_FACTORY_REG(SparseApplyProximalAdagradD) + /** *@brief Updates "var" according to the Ftrl-proximal scheme. @@ -892,6 +1526,92 @@ REG_OP(ApplyFtrl) .ATTR(use_locking, Bool, false) .OP_END_FACTORY_REG(ApplyFtrl) +/** +*@brief Updates "var" according to the Ftrl-proximal scheme. + +*@par Inputs: +*Eight inputs, including: +* @li var: A mutable Tensor. Must be of type TensorType::NumberType(). +* Should be a Variable Tensor. +* @li accum: A mutable Tensor of the same type as "var". +* Should be a Variable Tensor. +* @li linear: A mutable Tensor of the same type as "var". +* Should be a Variable Tensor. +* @li grad: A Tensor of the same type as "var", for the gradient. +* @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar. +* @li l1: A Tensor of the same type as "var", for L1 regulariation. Must be a scalar. 
+* @li l2: A Tensor of the same type as "var", for L2 regulariation. Must be a scalar. +* @li lr_power: A Tensor of the same type as "var", for the scaling factor. Must be a scalar. + +*@par Attributes: +*use_locking: An optional bool. Defaults to "False". +* If "True", updating of the "var" and "accum" tensors will be +* protected by a lock; otherwise the behavior is undefined, +* but may exhibit less contention. + +*@par Outputs: +*@li var: A mutable Tensor. Has the same type as "var". +*@li accum: A mutable Tensor. Has the same type as "accum". +*@li linear: A mutable Tensor. Has the same type as "linear". + +*/ +REG_OP(ApplyFtrlD) + .INPUT(var, TensorType::NumberType()) + .INPUT(accum, TensorType::NumberType()) + .INPUT(linear, TensorType::NumberType()) + .INPUT(grad, TensorType::NumberType()) + .INPUT(lr, TensorType::NumberType()) + .INPUT(l1, TensorType::NumberType()) + .INPUT(l2, TensorType::NumberType()) + .INPUT(lr_power, TensorType::NumberType()) + .OUTPUT(var, TensorType::NumberType()) + .OUTPUT(accum, TensorType::NumberType()) + .OUTPUT(linear, TensorType::NumberType()) + .ATTR(use_locking, Bool, false) + .OP_END_FACTORY_REG(ApplyFtrlD) + +/** +*@brief Update "var" according to the Ftrl-proximal scheme. + +*@par Inputs: +*Nine inputs, including: +* @li var: A mutable Tensor. Must be of type TensorType::NumberType(). +* Should be a Variable Tensor. +* @li accum: A mutable Tensor of the same type as "var". +* Should be a Variable Tensor. +* @li linear: A mutable Tensor of the same type as "var". +* Should be a Variable Tensor. +* @li grad: A Tensor of the same type as "var", for the gradient. +* @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar. +* @li l1: A Tensor of the same type as "var", for L1 regulariation. Must be a scalar. +* @li l2: A Tensor of the same type as "var", for L2 regulariation. Must be a scalar. +* @li l2_shrinkage: A Tensor of the same type as "var". 
+* @li lr_power: A Tensor of the same type as "var", for the scaling factor. Must be a scalar. + +*@par Attributes: +*use_locking: An optional bool. Defaults to "False". +* If "True", updating of the "var" and "accum" tensors will be +* protected by a lock; otherwise the behavior is undefined, +* but may exhibit less contention. + +*@par Outputs: +*var: A mutable Tensor. Has the same type as "var". + +*/ +REG_OP(ApplyFtrlV2) + .INPUT(var, TensorType::NumberType()) + .INPUT(accum, TensorType::NumberType()) + .INPUT(linear, TensorType::NumberType()) + .INPUT(grad, TensorType::NumberType()) + .INPUT(lr, TensorType::NumberType()) + .INPUT(l1, TensorType::NumberType()) + .INPUT(l2, TensorType::NumberType()) + .INPUT(l2_shrinkage, TensorType::NumberType()) + .INPUT(lr_power, TensorType::NumberType()) + .OUTPUT(var, TensorType::NumberType()) + .ATTR(use_locking, Bool, false) + .OP_END_FACTORY_REG(ApplyFtrlV2) + /** *@brief Update "var" according to the Ftrl-proximal scheme. @@ -917,22 +1637,78 @@ REG_OP(ApplyFtrl) * but may exhibit less contention. *@par Outputs: -*var: A mutable Tensor. Has the same type as "var". - +*var: A mutable Tensor. Has the same type as "var". +*accum: A mutable Tensor. Has the same type as "accum". +*linear: A mutable Tensor. Has the same type as "linear". 
+ +*/ +REG_OP(ApplyFtrlV2D) + .INPUT(var, TensorType::NumberType()) + .INPUT(accum, TensorType::NumberType()) + .INPUT(linear, TensorType::NumberType()) + .INPUT(grad, TensorType::NumberType()) + .INPUT(lr, TensorType::NumberType()) + .INPUT(l1, TensorType::NumberType()) + .INPUT(l2, TensorType::NumberType()) + .INPUT(l2_shrinkage, TensorType::NumberType()) + .INPUT(lr_power, TensorType::NumberType()) + .OUTPUT(var, TensorType::NumberType()) + .OUTPUT(accum, TensorType::NumberType()) + .OUTPUT(linear, TensorType::NumberType()) + .ATTR(use_locking, Bool, false) + .OP_END_FACTORY_REG(ApplyFtrlV2D) + +/** +*@brief Updates "var" according to the Adam algorithm.\n +* lr_t <- text{learning\_rate} * sqrt{1 - beta_2^t} / (1 - beta_1^t)\n +* m_t <- beta_1 * m_{t-1} + (1 - beta_1) * g\n +* v_t <- max(beta2 * v{t-1}, abs(g))\n +* variable <- variable - lr_t * m_t / (sqrt{v_t} + epsilon) +* +*@attention Constraints:\n +* *The input tensors must have the same shape.* +* +*@par Inputs: +*@li var: A mutable Tensor of the type TensorType::NumberType(). +* Should be from a Variable(). +*@li m: A mutable Tensor of the same type as "var". +* Should be from a Variable(). +*@li v: A mutable Tensor of the same type as "var". +* Should be from a Variable(). +*@li beta1_power: A scalar of the same type as "var". +*@li beta2_power: A scalar of the same type as "var". +*@li lr: learning_rate. A scalar of the same type as "var". +*@li beta1: A scalar of the same type as "var". +*@li beta2: A scalar of the same type as "var". +*@li epsilon: A scalar of the same type as "var". +*@li grad: A Tensor of the same type as "var", for the gradient. +* +*@par Attributes:\n +*@li use_locking: An optional bool. Defaults to "False". +* If "True", updating of the "var", m", and "v" tensors will be protected +* by a lock; otherwise the behavior is undefined, but may exhibit less +* contention. +*@li use_nesterov: An optional bool. Defaults to "False". + If "True", uses the nesterov update. 
+* +*@par Outputs: +* var: A mutable Tensor. Has the same type as intput "var". */ -REG_OP(ApplyFtrlV2) +REG_OP(ApplyAdam) .INPUT(var, TensorType::NumberType()) - .INPUT(accum, TensorType::NumberType()) - .INPUT(linear, TensorType::NumberType()) - .INPUT(grad, TensorType::NumberType()) + .INPUT(m, TensorType::NumberType()) + .INPUT(v, TensorType::NumberType()) + .INPUT(beta1_power, TensorType::NumberType()) + .INPUT(beta2_power, TensorType::NumberType()) .INPUT(lr, TensorType::NumberType()) - .INPUT(l1, TensorType::NumberType()) - .INPUT(l2, TensorType::NumberType()) - .INPUT(l2_shrinkage, TensorType::NumberType()) - .INPUT(lr_power, TensorType::NumberType()) + .INPUT(beta1, TensorType::NumberType()) + .INPUT(beta2, TensorType::NumberType()) + .INPUT(epsilon, TensorType::NumberType()) + .INPUT(grad, TensorType::NumberType()) .OUTPUT(var, TensorType::NumberType()) .ATTR(use_locking, Bool, false) - .OP_END_FACTORY_REG(ApplyFtrlV2) + .ATTR(use_nesterov, Bool, false) + .OP_END_FACTORY_REG(ApplyAdam) /** *@brief Updates "var" according to the Adam algorithm.\n @@ -968,9 +1744,11 @@ REG_OP(ApplyFtrlV2) If "True", uses the nesterov update. * *@par Outputs: -* var: A mutable Tensor. Has the same type as intput "var". +*@li var: A mutable tensor. Has the same type as input "var". +*@li m: A mutable tensor. Has the same type as input "m". +*@li v: A mutable tensor. Has the same type as input "v". */ -REG_OP(ApplyAdam) +REG_OP(ApplyAdamD) .INPUT(var, TensorType::NumberType()) .INPUT(m, TensorType::NumberType()) .INPUT(v, TensorType::NumberType()) @@ -986,7 +1764,7 @@ REG_OP(ApplyAdam) .OUTPUT(v, TensorType::NumberType()) .ATTR(use_locking, Bool, false) .ATTR(use_nesterov, Bool, false) - .OP_END_FACTORY_REG(ApplyAdam) + .OP_END_FACTORY_REG(ApplyAdamD) /** *@brief Updates "var" according to the proximal adadelta scheme. 
@@ -1025,6 +1803,48 @@ REG_OP(ApplyAdadelta) .ATTR(use_locking, Bool, false) .OP_END_FACTORY_REG(ApplyAdadelta) +/** +*@brief Updates "var" according to the proximal adadelta scheme. + +*@par Inputs: +*Seven inputs, including: +* @li var: A mutable Tensor of type TensorType::NumberType(). +* Should be a Variable Tensor. +* @li accum: A mutable Tensor of the same type as "var". +* Should be a Variable Tensor. +* @li accum_update: A mutable Tensor of the same type as "var". +* Should be a Variable Tensor. +* @li lr: A scalar of the same type as "var", for the scaling factor. +* @li rho: A scalar of the same type as "var", for the decay factor. +* @li epsilon: A scalar of the same type as "var", for the constant factor. +* @li grad: A Tensor of the same type as "var", for the gradient. + +*@par Attributes: +*use_locking: An optional bool. Defaults to "False". +* If "True", updating of the "var", "accum" and "accum_update" tensors will be +* protected by a lock; otherwise the behavior is undefined, +* but may exhibit less contention. + +*@par Outputs: +*@li var: A mutable Tensor. Has the same type as "var". +*@li accum: A mutable Tensor. Has the same type as "var". +*@li accum_update: A mutable Tensor. Has the same type as "var". + +*/ +REG_OP(ApplyAdadeltaD) + .INPUT(var, TensorType::NumberType()) + .INPUT(accum, TensorType::NumberType()) + .INPUT(accum_update, TensorType::NumberType()) + .INPUT(lr, TensorType::NumberType()) + .INPUT(rho, TensorType::NumberType()) + .INPUT(epsilon, TensorType::NumberType()) + .INPUT(grad, TensorType::NumberType()) + .OUTPUT(var, TensorType::NumberType()) + .OUTPUT(accum, TensorType::NumberType()) + .OUTPUT(accum_update, TensorType::NumberType()) + .ATTR(use_locking, Bool, false) + .OP_END_FACTORY_REG(ApplyAdadeltaD) + /** * @brief Updates "var" according to the ApplyMomentum algorithm. 
\n * accum = accum * momentum + x1 * x2 \n @@ -1193,11 +2013,11 @@ REG_OP(LarsV2Update) * @par Inputs: * Nine inputs, including: * @li var: A mutable Tensor. Must be of type TensorType::NumberType(). -* Should be a Variable Tensor. +* Should be a Variable Tensor. * @li accum: A mutable Tensor of the same type as "var". -* Should be a Variable Tensor. +* Should be a Variable Tensor. * @li linear: A mutable Tensor of the same type as "var". -* Should be a Variable Tensor. +* Should be a Variable Tensor. * @li grad: A Tensor of the same type as "var", for the gradient. * @li indices: A vector of indices into the first dimension of var and accum. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar. @@ -1207,9 +2027,9 @@ REG_OP(LarsV2Update) * @par Attributes: * use_locking: An optional bool. Defaults to "False". -* If "True", updating of the "var" and "accum" tensors will be -* protected by a lock; otherwise the behavior is undefined, -* but may exhibit less contention. +* If "True", updating of the "var" and "accum" tensors will be +* protected by a lock; otherwise the behavior is undefined, +* but may exhibit less contention. * @par Outputs: * var: A Tensor. Has the same type and format as input "var". @@ -1233,13 +2053,13 @@ REG_OP(SparseApplyFtrl) * @brief Update relevant entries in '*var' according to the Ftrl-proximal scheme. * @par Inputs: -* Nine inputs, including: +* Five inputs, including: * @li var: A mutable Tensor. Must be of type TensorType::NumberType(). -* Should be a Variable Tensor. +* Should be a Variable Tensor. * @li accum: A mutable Tensor of the same type as "var". -* Should be a Variable Tensor. +* Should be a Variable Tensor. * @li linear: A mutable Tensor of the same type as "var". -* Should be a Variable Tensor. +* Should be a Variable Tensor. * @li grad: A Tensor of the same type as "var", for the gradient. * @li indices: A vector of indices into the first dimension of var and accum. 
* @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar. @@ -1249,14 +2069,14 @@ REG_OP(SparseApplyFtrl) * @par Attributes: * use_locking: An optional bool. Defaults to "False". -* If "True", updating of the "var" and "accum" tensors will be -* protected by a lock; otherwise the behavior is undefined, -* but may exhibit less contention. +* If "True", updating of the "var" and "accum" tensors will be +* protected by a lock; otherwise the behavior is undefined, +* but may exhibit less contention. * @par Outputs: -* var: A Tensor. Has the same type and format as input "var". -* accum: A Tensor. Has the same type and format as input "accum". -* linear: A Tensor. Has the same type and format as input "linear". +* @li var: A Tensor. Has the same type and format as input "var". +* @li accum: A Tensor. Has the same type and format as input "accum". +* @li linear: A Tensor. Has the same type and format as input "linear". */ REG_OP(SparseApplyFtrlD) @@ -1276,8 +2096,8 @@ REG_OP(SparseApplyFtrlD) .OP_END_FACTORY_REG(SparseApplyFtrlD) /** -* @brief Update relevant entries in '*var' according to the Ftrl-proximal scheme. -* That is for rows we have grad for, we update var, accum and linear +* @brief Updates relevant entries in '*var' according to the Ftrl-proximal scheme. +* That is for rows we have grad for, "var", "accum" and "linear" are updated. * @par Inputs: * Ten inputs, including: @@ -1288,7 +2108,7 @@ REG_OP(SparseApplyFtrlD) * @li linear: A mutable Tensor of the same type as "var". * Should be a Variable Tensor. * @li grad: A Tensor of the same type as "var", for the gradient. -* @li indices: A vector of indices into the first dimension of var and accum. +* @li indices: A vector of indices into the first dimension of "var" and "accum". * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar. * @li l1: A Tensor of the same type as "var", for L1 regulariation. Must be a scalar. 
* @li l2: A Tensor of the same type as "var", for L2 regulariation. Must be a scalar. @@ -1298,7 +2118,7 @@ REG_OP(SparseApplyFtrlD) * @par Attributes: * use_locking: An optional bool. Defaults to "False". * If "True", updating of the "var" and "accum" tensors will be -* rotected by a lock; otherwise the behavior is undefined, +* protected by a lock; otherwise the behavior is undefined, * but may exhibit less contention. * @par Outputs: @@ -1321,19 +2141,19 @@ REG_OP(SparseApplyFtrlV2) .OP_END_FACTORY_REG(SparseApplyFtrlV2) /** -* @brief Update relevant entries in '*var' according to the Ftrl-proximal scheme. -* That is for rows we have grad for, we update var, accum and linear +* @brief Updates relevant entries in '*var' according to the Ftrl-proximal scheme. +* That is for rows we have grad for, "var", "accum" and "linear" are updated. * @par Inputs: -* Ten inputs, including: +* Five inputs, including: * @li var: A mutable Tensor. Must be of type TensorType::NumberType(). -* Should be a Variable Tensor. +* Should be a Variable Tensor. * @li accum: A mutable Tensor of the same type as "var". -* Should be a Variable Tensor. +* Should be a Variable Tensor. * @li linear: A mutable Tensor of the same type as "var". -* Should be a Variable Tensor. +* Should be a Variable Tensor. * @li grad: A Tensor of the same type as "var", for the gradient. -* @li indices: A vector of indices into the first dimension of var and accum. +* @li indices: A vector of indices into the first dimension of "var" and "accum". * @par Attributes: * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar. @@ -1342,14 +2162,14 @@ REG_OP(SparseApplyFtrlV2) * @li l2_shrinkage: A Tensor of the same type as "var", L2 shrinkage regulariation. Must be a scalar. * @li lr_power: A Tensor of the same type as "var", for the scaling factor. Must be a scalar. * @li use_locking: An optional bool. Defaults to "False". 
-* If "True", updating of the "var" and "accum" tensors will be -* rotected by a lock; otherwise the behavior is undefined, -* but may exhibit less contention. +* If "True", updating of the "var" and "accum" tensors will be +* protected by a lock; otherwise the behavior is undefined, +* but may exhibit less contention. * @par Outputs: -* var: A Tensor. Has the same type and format as input "var". -* accum: A Tensor. Has the same type and format as input "accum". -* linear: A Tensor. Has the same type and format as input "linear". +* @li var: A Tensor. Has the same type and format as input "var". +* @li accum: A Tensor. Has the same type and format as input "accum". +* @li linear: A Tensor. Has the same type and format as input "linear". */ REG_OP(SparseApplyFtrlV2D) @@ -1369,6 +2189,208 @@ REG_OP(SparseApplyFtrlV2D) .ATTR(use_locking, Bool, false) .OP_END_FACTORY_REG(SparseApplyFtrlV2D) +/** +* @brief Updates "var" in specified index according to the RMSProp algorithm. +* mean_square = decay * mean_square + (1-decay) * gradient ** 2\n +* Delta = learning_rate * gradient / sqrt(mean_square + epsilon)\n +* ms <- rho * ms_{t-1} + (1-rho) * grad * grad\n +* mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)\n +* var <- var - mom\n +* +* @par Inputs: +* @li var: A mutable tensor. Must be one of the data types defined in\n +* TensorType::NumberType(). Should be from a Variable(). +* @li ms: A mutable tensor. Must have the same type as "var". Should be from a +* Variable(). +* @li mom: A mutable tensor. Must have the same type as "var". Should be from a +* Variable(). +* @li lr: A scalar. Must have the same type as "var". +* @li rho: A scalar. Must have the same type as "var". +* @li momentum: A scalar. Must have the same type as "var". +* @li epsilon: A scalar. Must have the same type as "var". +* @li grad: A tensor, specifying the gradient. +* @li indices: A vector of indices into the first dimension of "var", "mom" and "ms". 
+* +* @par Attributes: +* use_locking: An optional "bool". Defaults to "False". If "True", updating of +* the "var", "ms", and "mom" tensors will be protected by a lock; otherwise the +* behavior is undefined, but may exhibit less contention. +* +* @par Outputs: +* var: A mutable tensor. Has the same type as input "var". +* +* @attention Constraints: +* @li Note that in this sparse implementation, "ms" and "mom" will not update +* in iterations during which "grad" is 0. +* @li The input tensors "var", "ms", and "mom" must have the same shape. +* +*/ +REG_OP(SparseApplyRMSProp) + .INPUT(var, TensorType::NumberType()) + .INPUT(ms, TensorType::NumberType()) + .INPUT(mom, TensorType::NumberType()) + .INPUT(lr, TensorType::NumberType()) + .INPUT(rho, TensorType::NumberType()) + .INPUT(momentum, TensorType::NumberType()) + .INPUT(epsilon, TensorType::NumberType()) + .INPUT(grad, TensorType::NumberType()) + .INPUT(indices, TensorType::IndexNumberType()) + .OUTPUT(var, TensorType::NumberType()) + .ATTR(use_locking, Bool, false) + .OP_END_FACTORY_REG(SparseApplyRMSProp) + +/** +* @brief Updates "var" in specified index according to the RMSProp algorithm. +* a const input will be considered as an attribute.\n +* mean_square = decay * mean_square + (1-decay) * gradient ** 2\n +* Delta = learning_rate * gradient / sqrt(mean_square + epsilon)\n +* ms <- rho * ms_{t-1} + (1-rho) * grad * grad\n +* mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)\n +* var <- var - mom +* +* @par Inputs: +* @li var: A mutable tensor. Must be one of the data types defined in +* TensorType::NumberType(). Should be from a Variable(). +* @li ms: A mutable tensor. Must have the same type as "var". Should be from a +* Variable(). +* @li mom: A mutable tensor. Must have the same type as "var". Should be from a +* Variable(). +* @li lr: A scalar. Must have the same type as "var". +* @li grad: A tensor, specifying the gradient. +* +* @par Attributes: +* @li use_locking: An optional "bool". 
Defaults to "False". If "True", +* updating of the "var", "ms", and "mom" tensors will be protected by a lock; +* otherwise the behavior is undefined, but may exhibit less contention. +* @li rho: A required scalar. Must have the same type as "var". +* @li momentum: A required scalar. Must have the same type as "var". +* @li epsilon: A required scalar. Must have the same type as "var". +* +* @par Outputs: +* @li var: A mutable tensor. Must have the same type as input "var". +* @li ms: A mutable tensor. Must have the same type as input "ms". +* @li mom: A mutable tensor. Must have the same type as input "mom". +* +* @attention Constraints: +* @li Note that in this sparse implementation, "ms" and "mom" will not update +* in iterations during which "grad" is 0. +* @li The input tensors "var", "ms" and "mom" must have the same shape. +*/ +REG_OP(SparseApplyRMSPropD) + .INPUT(var, TensorType::NumberType()) + .INPUT(ms, TensorType::NumberType()) + .INPUT(mom, TensorType::NumberType()) + .INPUT(lr, TensorType::NumberType()) + .INPUT(grad, TensorType::NumberType()) + .INPUT(indices, TensorType::IndexNumberType()) + .OUTPUT(var, TensorType::NumberType()) + .OUTPUT(ms, TensorType::NumberType()) + .OUTPUT(mom, TensorType::NumberType()) + .REQUIRED_ATTR(rho, Float) + .REQUIRED_ATTR(momentum, Float) + .REQUIRED_ATTR(epsilon, Float) + .ATTR(use_locking, Bool, false) + .OP_END_FACTORY_REG(SparseApplyRMSPropD) + +/** +* @brief Updates "var" in specified index according to the Adadelta algorithm. +* accum <- rho * accum + (1 - rho) * grad.square()\n +* update <- (accum_update + epsilon).sqrt() * (accum + epsilon()).rsqrt() * grad\n +* var <- var - update * lr\n +* accum_update <- rho() * accum_update + (1 - rho()) * update.square()\n +* +* @par Inputs: +* @li var: A mutable tensor. Must be one of the data types defined in\n +* TensorType::NumberType(). Should be from a Variable(). +* @li accum: A mutable tensor. Must have the same type as "var". Should be from a +* Variable(). 
+* @li accum_update: A mutable tensor. Must have the same type as "var". Should be from a +* Variable(). +* @li lr: A scalar. Must have the same type as "var". +* @li rho: A scalar. Must have the same type as "var". +* @li epsilon: A scalar. Must have the same type as "var". +* @li grad: A tensor, specifying the gradient. +* @li indices: A vector of indices into the first dimension of "var", "accum" and "accum_update". +* +* @par Attributes: +* use_locking: An optional "bool". Defaults to "False". If "True", updating of +* the "var", "accum", and "accum_update" tensors will be protected by a lock; otherwise the +* behavior is undefined, but may exhibit less contention. +* +* @par Outputs: +* var: A mutable tensor. Has the same type as input "var". +* +* @attention Constraints: +* @li Note that in this sparse implementation, "accum" and "accum_update" will not update +* in iterations during which "grad" is 0. +* @li The input tensors "var", "accum", and "accum_update" must have the same shape. +* +*/ +REG_OP(SparseApplyAdadelta) + .INPUT(var, TensorType::NumberType()) + .INPUT(accum, TensorType::NumberType()) + .INPUT(accum_update, TensorType::NumberType()) + .INPUT(lr, TensorType::NumberType()) + .INPUT(rho, TensorType::NumberType()) + .INPUT(epsilon, TensorType::NumberType()) + .INPUT(grad, TensorType::NumberType()) + .INPUT(indices, TensorType::IndexNumberType()) + .OUTPUT(var, TensorType::NumberType()) + .ATTR(use_locking, Bool, false) + .OP_END_FACTORY_REG(SparseApplyAdadelta) + +/** +* @brief Updates "var" in specified index according to the Adadelta algorithm. +* a const input will be considered as an attribute.\n +* accum <- rho * accum + (1 - rho) * grad.square()\n +* update <- (accum_update + epsilon).sqrt() * (accum + epsilon()).rsqrt() * grad\n +* var <- var - update * lr\n +* accum_update <- rho() * accum_update + (1 - rho()) * update.square()\n +* +* @par Inputs: +* @li var: A mutable tensor. 
Must be one of the data types defined in +* TensorType::NumberType(). Should be from a Variable(). +* @li accum: A mutable tensor. Must have the same type as "var". Should be from a +* Variable(). +* @li accum_update: A mutable tensor. Must have the same type as "var". Should be from a +* Variable(). +* @li lr: A scalar. Must have the same type as "var". +* @li rho: A scalar. Must have the same type as "var". +* @li grad: A tensor, specifying the gradient. +* @li indices: A vector of indices into the first dimension of "var", "accum" and "accum_update". +* +* @par Attributes: +* @li use_locking: An optional "bool". Defaults to "False". If "True", +* updating of the "var", "accum", and "accum_update" tensors will be protected by a lock; +* otherwise the behavior is undefined, but may exhibit less contention. +* @li epsilon: A required scalar. Must have the same type as "var". +* +* @par Outputs: +* @li var: A mutable tensor. Must have the same type as input "var". +* @li accum: A mutable tensor. Must have the same type as input "accum". +* @li accum_update: A mutable tensor. Must have the same type as input "accum_update". +* +* @attention Constraints: +* @li Note that in this sparse implementation, "accum" and "accum_update" will not update +* in iterations during which "grad" is 0. +* @li The input tensors "var", "accum" and "accum_update" must have the same shape. 
+*/ +REG_OP(SparseApplyAdadeltaD) + .INPUT(var, TensorType::NumberType()) + .INPUT(accum, TensorType::NumberType()) + .INPUT(accum_update, TensorType::NumberType()) + .INPUT(lr, TensorType::NumberType()) + .INPUT(rho, TensorType::NumberType()) + .INPUT(grad, TensorType::NumberType()) + .INPUT(indices, TensorType::IndexNumberType()) + .OUTPUT(var, TensorType::NumberType()) + .OUTPUT(accum, TensorType::NumberType()) + .OUTPUT(accum_update, TensorType::NumberType()) + .REQUIRED_ATTR(epsilon, Float) + .ATTR(use_locking, Bool, false) + .OP_END_FACTORY_REG(SparseApplyAdadeltaD) + + /** *@brief Clean memory of workspace list. diff --git a/third_party/fwkacllib/inc/ops/nonlinear_fuc_ops.h b/third_party/fwkacllib/inc/ops/nonlinear_fuc_ops.h index 992077ad..15bd8812 100644 --- a/third_party/fwkacllib/inc/ops/nonlinear_fuc_ops.h +++ b/third_party/fwkacllib/inc/ops/nonlinear_fuc_ops.h @@ -17,7 +17,7 @@ #ifndef GE_OP_NONLINEAR_FUC_OPS_H #define GE_OP_NONLINEAR_FUC_OPS_H -#include "../graph/operator_reg.h" +#include "graph/operator_reg.h" namespace ge { /** @@ -172,24 +172,6 @@ REG_OP(SigmoidGrad) .OUTPUT(z, TensorType(UnaryDataType)) .OP_END_FACTORY_REG(SigmoidGrad) -REG_OP(Activation) - .INPUT(x, TensorType::ALL()) - .OUTPUT(y, TensorType::ALL()) - /* - 0: sigmod, 1: relu, 2: tanh, 3: clipped ReLU, 4: Elu, - 5: leaky relu, 6: abs, 7: relu1, 8: softsign, 9: softplus - */ - .ATTR(mode, Int, 1) - .ATTR(coef, Float, 0) - .OP_END_FACTORY_REG(Activation) - -REG_OP(ActivationGrad) - .INPUT(dy, TensorType{DT_FLOAT}) - .INPUT(x, TensorType{DT_FLOAT}) - .OUTPUT(dx, TensorType{DT_FLOAT}) - .ATTR(mode, Int, 1) - .OP_END_FACTORY_REG(ActivationGrad) - /** *@brief Computes the binomial normal log likelihood (BNLL) output:\n *if x>0, x+log(1+exp(-x)); otherwise log(1+exp(x)). @@ -433,7 +415,7 @@ REG_OP(EluGrad) *@par Inputs: * One input: -* x: A Tensor. Must be one of the following types: float32, float16, int32, int8, double. +* x: A Tensor. 
Must be one of the following types: float32, float16, double. * *@par Attributes: *negative_slope: A float32. Defaults to "0.0". @@ -442,31 +424,42 @@ REG_OP(EluGrad) *y: A Tensor. Has the same type as "x". */ REG_OP(LeakyRelu) - .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT8, DT_DOUBLE})) + .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_DOUBLE})) .ATTR(negative_slope, Float, 0.0) - .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT8, DT_DOUBLE})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_DOUBLE})) .OP_END_FACTORY_REG(LeakyRelu) /** -*@brief Computes the output as g if x > 0 and negative_slope * g if x <= 0. +*@brief Computes the output as gradients if features > 0 and negative_slope * gradients if features <= 0. *@par Inputs: * Two inputs, including: -* @li g: A Tensor. Must be one of the following types: float16, float32, double. -* @li x: A Tensor. Has the same type as "g". +* @li gradients: A Tensor. Must be one of the following types: float16, float32, double. +* @li features: A Tensor. Has the same type as "gradients". *@par Attributes: *negative_slope: A float32. Defaults to "0.0". *@par Outputs: -*y: A Tensor. Has the same type as "g". +*backprops: A Tensor. Has the same type as "gradients". 
*/ REG_OP(LeakyReluGrad) -.INPUT(g, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) -.INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) -.ATTR(negative_slope, Float, 0.0) -.OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) -.OP_END_FACTORY_REG(LeakyReluGrad) + .INPUT(gradients, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) + .INPUT(features, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) + .ATTR(negative_slope, Float, 0.0) + .OUTPUT(backprops, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) + .OP_END_FACTORY_REG(LeakyReluGrad) + +REG_OP(threshold_grad_v2_d) + .INPUT(input_x, TensorType({DT_INT32, DT_FLOAT16})) + .INPUT(input_y, TensorType({DT_INT32, DT_FLOAT16})) + .OUTPUT(output_z, TensorType({DT_INT32, DT_FLOAT16})) + .OP_END_FACTORY_REG(threshold_grad_v2_d) + +REG_OP(ThresholdV2D) + .INPUT(x, TensorType::RealNumberType()) + .OUTPUT(y, TensorType::RealNumberType()) + .OP_END_FACTORY_REG(ThresholdV2D) } // namespace ge diff --git a/third_party/fwkacllib/inc/ops/npu_loss_scale_ops.h b/third_party/fwkacllib/inc/ops/npu_loss_scale_ops.h index daeea466..1c702738 100644 --- a/third_party/fwkacllib/inc/ops/npu_loss_scale_ops.h +++ b/third_party/fwkacllib/inc/ops/npu_loss_scale_ops.h @@ -16,7 +16,7 @@ #ifndef GE_OP_NN_LOSS_SCALE_OPS_H #define GE_OP_NN_LOSS_SCALE_OPS_H -#include "../graph/operator_reg.h" +#include "graph/operator_reg.h" namespace ge { REG_OP(NPUAllocFloatStatusOperator) diff --git a/third_party/fwkacllib/inc/ops/outfeed_ops.h b/third_party/fwkacllib/inc/ops/outfeed_ops.h index 049d83d9..af27140a 100644 --- a/third_party/fwkacllib/inc/ops/outfeed_ops.h +++ b/third_party/fwkacllib/inc/ops/outfeed_ops.h @@ -17,35 +17,7 @@ #ifndef GE_OP_OUTFEED_OPS_H #define GE_OP_OUTFEED_OPS_H -#include "graph/operator.h" -#include "graph/operator_reg.h" - -namespace ge { - -/** -*@brief Enqueue a Tensor on the computation outfeed. - -*@par Inputs: -*Inputs include: \n -*x: A Tensor. 
Must be one of the following types: float16, float32, \n -float64, int8, int16, uint16, uint8, int32, int64, uint32, uint64, \n -bool, double, string. - -*@par Attributes: -*channel_name: name of operator channel, default "". - -*@attention Constraints:\n -*-The implementation for OutfeedEnqueueOp on Ascend uses AICPU, with bad performance.\n - -*/ -REG_OP(OutfeedEnqueueOp) - .DYNAMIC_INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, - DT_INT16, DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, - DT_UINT64, DT_BOOL, DT_DOUBLE, DT_STRING})) - .ATTR(channel_name, String, "") - .OP_END_FACTORY_REG(OutfeedEnqueueOp) - -} // namespace ge +#include "data_flow_ops.h" #endif // GE_OP_OUTFEED_OPS_H diff --git a/third_party/fwkacllib/inc/ops/pad_ops.h b/third_party/fwkacllib/inc/ops/pad_ops.h index dc471909..346c72a1 100644 --- a/third_party/fwkacllib/inc/ops/pad_ops.h +++ b/third_party/fwkacllib/inc/ops/pad_ops.h @@ -17,7 +17,7 @@ #ifndef GE_OP_PAD_OPS_H #define GE_OP_PAD_OPS_H -#include "../graph/operator_reg.h" +#include "graph/operator_reg.h" namespace ge { /** diff --git a/third_party/fwkacllib/inc/ops/power_ops.h b/third_party/fwkacllib/inc/ops/power_ops.h deleted file mode 100644 index b1f5bc24..00000000 --- a/third_party/fwkacllib/inc/ops/power_ops.h +++ /dev/null @@ -1,49 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - - #ifndef GE_OP_POWER_H - #define GE_OP_POWER_H - - #include "../graph/operator_reg.h" - - namespace ge { - -/** -*@brief Computes the output as (shift + scale * x) ^ power. - -*@par Inputs: -* x: A Tensor of type float16 or float32. - -*@par Attributes: -*@li power: Optional. Defaults to 1.0. -*@li scale: Optional. Defaults to 1.0. -*@li shift: Optional. Defaults to 0.0. - -*@par Outputs: -* y: A Tensor. Has the same type and shape as "x". -*/ - - REG_OP(Power) - .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT})) - .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT})) - .ATTR(power, Float, 1.0) - .ATTR(scale, Float, 1.0) - .ATTR(shift, Float, 0.0) - .OP_END_FACTORY_REG(Power); - - } // namespace ge - - #endif // GE_OP_POWER_H diff --git a/third_party/fwkacllib/inc/ops/quantize_ops.h b/third_party/fwkacllib/inc/ops/quantize_ops.h index 235f2645..d9fe2540 100644 --- a/third_party/fwkacllib/inc/ops/quantize_ops.h +++ b/third_party/fwkacllib/inc/ops/quantize_ops.h @@ -16,25 +16,9 @@ #ifndef GE_OP_QUANTIZE_OPS_H #define GE_OP_QUANTIZE_OPS_H -#include "../graph/operator_reg.h" +#include "graph/operator_reg.h" namespace ge { -REG_OP(QuantizedInnerProduct) - .INPUT(x, TensorType({DT_UINT8})) - .INPUT(w, TensorType({DT_INT8})) - .OPTIONAL_INPUT(b, TensorType({DT_INT32})) - .OPTIONAL_INPUT(scale_q, TensorType({DT_FLOAT16})) - .OPTIONAL_INPUT(offset_q, TensorType({DT_FLOAT16})) - .OPTIONAL_INPUT(scale_deq_req, TensorType({DT_FLOAT16})) - .OPTIONAL_INPUT(offset_req, TensorType({DT_FLOAT16})) - .OUTPUT(y, TensorType({DT_FLOAT16})) - .REQUIRED_ATTR(quant_algo, ListInt) - .REQUIRED_ATTR(scale_sqrt, ListInt) - .REQUIRED_ATTR(num_output, Int) - .ATTR(transpose, Bool, false) - .ATTR(bias_term, Bool, false) - .ATTR(axis, Int, 1) - .OP_END_FACTORY_REG(QuantizedInnerProduct) /** * @brief Dequantizes the input tensor into a float tensor.\n @@ -118,6 +102,40 @@ REG_OP(AscendDequant) .ATTR(dtype, Int, DT_FLOAT) .OP_END_FACTORY_REG(AscendDequant) +REG_OP(AscendAntiQuant) + .INPUT(x, 
TensorType({DT_INT8})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT})) + .REQUIRED_ATTR(scale, Float) + .REQUIRED_ATTR(offset, Float) + .ATTR(dtype, Int, DT_FLOAT) + .ATTR(sqrt_mode, Bool, false) + .OP_END_FACTORY_REG(AscendAntiQuant) + +REG_OP(AscendDequantS16) + .INPUT(x0, TensorType({DT_INT32})) + .INPUT(deq_scale, TensorType({DT_UINT64})) + .OPTIONAL_INPUT(x1, TensorType({DT_INT16})) + .OUTPUT(y, TensorType({DT_INT16})) + .ATTR(relu_flag, Bool, false) + .OP_END_FACTORY_REG(AscendDequantS16) + +REG_OP(AscendRequant) + .INPUT(x, TensorType({DT_INT32})) + .INPUT(req_scale, TensorType({DT_UINT64})) + .OUTPUT(y, TensorType({DT_INT8})) + .ATTR(relu_flag, Bool, false) + .OP_END_FACTORY_REG(AscendRequant) + +REG_OP(AscendRequantS16) + .INPUT(x, TensorType({DT_INT16})) + .INPUT(req_scale, TensorType({DT_UINT64})) + .OPTIONAL_INPUT(x1, TensorType({DT_INT16})) + .OUTPUT(y, TensorType({DT_INT8})) + .OUTPUT(y1, TensorType({DT_INT16})) + .ATTR(dual_output, Bool, false) + .ATTR(relu_flag, Bool, false) + .OP_END_FACTORY_REG(AscendRequantS16) + } // namespace ge #endif // GE_OP_QUANTIZE_OPS_H diff --git a/third_party/fwkacllib/inc/ops/ragged_array_ops.h b/third_party/fwkacllib/inc/ops/ragged_array_ops.h index 245f3551..4f3cf97e 100644 --- a/third_party/fwkacllib/inc/ops/ragged_array_ops.h +++ b/third_party/fwkacllib/inc/ops/ragged_array_ops.h @@ -45,12 +45,10 @@ namespace ge { REG_OP(RaggedGather) .DYNAMIC_INPUT(params_nested_splits, TensorType({DT_INT32, DT_INT64})) - .INPUT(params_dense_values, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT8, DT_UINT8, DT_INT16, \ - DT_UINT16, DT_INT32, DT_UINT32, DT_INT64, DT_UINT64, DT_BOOL})) + .INPUT(params_dense_values, TensorType({DT_INT32, DT_INT64})) .INPUT(indices, TensorType({DT_INT32, DT_INT64})) .DYNAMIC_OUTPUT(output_nested_splits, TensorType({DT_INT32, DT_INT64})) - .OUTPUT(output_dense_values, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT8, DT_UINT8, DT_INT16, \ - DT_UINT16, DT_INT32, DT_UINT32, DT_INT64, DT_UINT64, 
DT_BOOL})) + .OUTPUT(output_dense_values, TensorType({DT_INT32, DT_INT64})) .REQUIRED_ATTR(Tsplits, Type) .ATTR(PARAMS_RAGGED_RANK, Int, 1) .ATTR(OUTPUT_RAGGED_RANK, Int, 0) diff --git a/third_party/fwkacllib/inc/ops/ragged_conversion_ops.h b/third_party/fwkacllib/inc/ops/ragged_conversion_ops.h index 8e07bdc5..7a42e4d9 100644 --- a/third_party/fwkacllib/inc/ops/ragged_conversion_ops.h +++ b/third_party/fwkacllib/inc/ops/ragged_conversion_ops.h @@ -50,5 +50,43 @@ REG_OP(RaggedTensorToSparse) .ATTR(RAGGED_RANK, Int, 1) .ATTR(Tsplits, Type, DT_INT64) .OP_END_FACTORY_REG(RaggedTensorToSparse) + +/** +*@brief Create a dense tensor from a ragged tensor, possibly altering its shape. + +*@par Inputs: +*Six inputs, including: +*@li shape:A `Tensor`. Must be one of the following types: `int64`, `int32`. +*@li values:A 1D tensor representing the values of the ragged tensor. +*@li default_value:A `Tensor`. Must have the same type as `values`. +*@li row_partition_tensors:A list of at least 1 `Tensor` objects with the same \n +type in: `int64`, `int32`. + +*@par Attributes: +*@li num_row_partition_tensors:Numbers of row partition tensors. +*@li row_partition_types: A list of `strings`. \n +The types of the row partition tensors. At present, these can be: \n +* "ROW_SPLITS": the row_splits tensor from the ragged tensor. \n +* "VALUE_ROWIDS": the value_rowids tensor from the ragged tensor. \n +* "FIRST_DIM_SIZE": if value_rowids is used for the first dimension, then it \n +is preceeded by "FIRST_DIM_SIZE". + +*@par Outputs: +*@li result: A `Tensor`. Has the same type as `values`. 
+*/ +REG_OP(RaggedTensorToTensor) + .INPUT(shape, TensorType({DT_INT32, DT_INT64})) + .INPUT(values, TensorType({DT_BOOL, DT_INT8, DT_UINT8, DT_INT16, DT_UINT16, + DT_INT32, DT_INT64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16})) + .INPUT(default_value, TensorType({DT_BOOL, DT_INT8, DT_UINT8, DT_INT16, + DT_UINT16, DT_INT32, DT_INT64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16})) + .DYNAMIC_INPUT(row_partition_tensors, TensorType({DT_INT32, DT_INT64})) + .OUTPUT(result, TensorType({DT_BOOL, DT_INT8, DT_UINT8, DT_INT16, DT_UINT16, + DT_INT32, DT_INT64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16})) + .REQUIRED_ATTR(num_row_partition_tensors, Int) + .REQUIRED_ATTR(row_partition_types, ListString) + .OP_END_FACTORY_REG(RaggedTensorToTensor) + + } // namespace ge #endif // GE_OP_RAGGED_CONVERSION_OPS_H \ No newline at end of file diff --git a/third_party/fwkacllib/inc/ops/ragged_math_ops.h b/third_party/fwkacllib/inc/ops/ragged_math_ops.h index 51797ff8..80669f0f 100644 --- a/third_party/fwkacllib/inc/ops/ragged_math_ops.h +++ b/third_party/fwkacllib/inc/ops/ragged_math_ops.h @@ -41,11 +41,11 @@ namespace ge { */ REG_OP(RaggedRange) - .INPUT(starts, TensorType({DT_FLOAT16,DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64})) - .INPUT(limits, TensorType({DT_FLOAT16,DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64})) - .INPUT(deltas, TensorType({DT_FLOAT16,DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64})) + .INPUT(starts, TensorType({DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64})) + .INPUT(limits, TensorType({DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64})) + .INPUT(deltas, TensorType({DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64})) .OUTPUT(rt_nested_splits, TensorType({DT_INT32, DT_INT64})) - .OUTPUT(rt_dense_values, TensorType({DT_FLOAT16,DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64})) + .OUTPUT(rt_dense_values, TensorType({DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64})) .REQUIRED_ATTR(Tsplits, Type) .OP_END_FACTORY_REG(RaggedRange) diff --git a/third_party/fwkacllib/inc/ops/reduce_ops.h b/third_party/fwkacllib/inc/ops/reduce_ops.h index 0ba3e17f..a0f78291 100644 --- 
a/third_party/fwkacllib/inc/ops/reduce_ops.h +++ b/third_party/fwkacllib/inc/ops/reduce_ops.h @@ -17,7 +17,7 @@ #ifndef GE_OP_REDUCE_OPS_H #define GE_OP_REDUCE_OPS_H -#include "../graph/operator_reg.h" +#include "graph/operator_reg.h" namespace ge { /** @@ -630,6 +630,55 @@ REG_OP(Reduction) .ATTR(axis, Int, 0) .ATTR(coeff, Float, 1.0) .OP_END_FACTORY_REG(Reduction); + +/** +*@brief Computes the euclidean norm of elements across dimensions of a tensor. + +*@par Inputs: +*@li input_tensor: A Tensor. Must be one of the following types: float16, float32, int32. +*@li axes: A Tensor of type int8 or int32. Specifies the dimensions to reduce. Defaults to "None". + +*@par Attributes:\n +*keep_dims: An optional bool. If "True", reduced dimensions will be retained. Defaults to "False". + +*@par Outputs:\n +*output_tensor: A Tensor. Must be one of the following types: float16, float32, int32. + +*@attention Constraints:\n +* If "axes = None", all dimensions will be reduced. "axes" must be in the range [-rank(input_shape), rank(input_shape)). + +*/ +REG_OP(EuclideanNorm) + .INPUT(x, TensorType::NumberType()) + .INPUT(axes, TensorType::IndexNumberType()) + .OUTPUT(y, TensorType::NumberType()) + .ATTR(keep_dims, Bool, false) + .OP_END_FACTORY_REG(EuclideanNorm) + +/** +*@brief Computes the euclidean norm of elements across dimensions of a tensor. + +*@par Inputs:\n +*input_min: A Tensor. Must be one of the following types: float16, float32, int32. + +*@par Attributes: +*@li axes: An optional int32, list, tuple, or NoneType value. Specifies the dimensions to reduce. Defaults to "None". +*@li keep_dims: An optional bool or NoneType value. If "True", reduced dimensions will be retained. Defaults to "None" (equivalent to "False"). + +*@par Outputs:\n +*output_min: A Tensor. Must be one of the following types: float16, float32, int32. + +*@attention Constraints:\n +* If "axes = None", all dimensions will be reduced. 
"axes" must be in the range [-rank(input_shape), rank(input_shape)). + +*/ +REG_OP(EuclideanNormD) + .INPUT(x, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_FLOAT16})) + .ATTR(axes, ListInt, {}) + .ATTR(keep_dims, Bool, false) + .OP_END_FACTORY_REG(EuclideanNormD) + } //namespace ge diff --git a/third_party/fwkacllib/inc/ops/rnn.h b/third_party/fwkacllib/inc/ops/rnn.h index abd98695..c4d64b0a 100644 --- a/third_party/fwkacllib/inc/ops/rnn.h +++ b/third_party/fwkacllib/inc/ops/rnn.h @@ -17,7 +17,7 @@ #ifndef GE_OP_RNN_H #define GE_OP_RNN_H -#include "../graph/operator_reg.h" +#include "graph/operator_reg.h" namespace ge { /** @@ -180,15 +180,15 @@ REG_OP(RNN) .OPTIONAL_INPUT(x_static, TensorType({DT_FLOAT16})) .OPTIONAL_INPUT(h_0, TensorType({DT_FLOAT16, DT_FLOAT})) .INPUT(w_xh, TensorType({DT_FLOAT16})) - .INPUT(w_sh, TensorType({DT_FLOAT16})) + .INPUT(bias_h, TensorType({DT_FLOAT16, DT_FLOAT})) + .OPTIONAL_INPUT(w_sh, TensorType({DT_FLOAT16})) .INPUT(w_hh, TensorType({DT_FLOAT16})) .INPUT(w_ho, TensorType({DT_FLOAT16})) - .INPUT(bias_h, TensorType({DT_FLOAT16, DT_FLOAT})) .INPUT(bias_o, TensorType({DT_FLOAT16, DT_FLOAT})) .OUTPUT(o, TensorType({DT_FLOAT16, DT_FLOAT})) .OUTPUT(h_t, TensorType({DT_FLOAT16, DT_FLOAT})) - .ATTR(expose_hidden, Bool, false) .ATTR(num_output, Int, 0) + .ATTR(expose_hidden, Bool, false) .OP_END_FACTORY_REG(RNN) /** @@ -220,9 +220,9 @@ REG_OP(BasicRNNCell) .OPTIONAL_INPUT(w_xh_x_static, TensorType({DT_FLOAT16, DT_FLOAT})) .OPTIONAL_INPUT(h_0, TensorType({DT_FLOAT16, DT_FLOAT})) .INPUT(w_xh, TensorType({DT_FLOAT16})) + .INPUT(bias_h, TensorType({DT_FLOAT16, DT_FLOAT})) .OPTIONAL_INPUT(w_hh, TensorType({DT_FLOAT16})) .INPUT(w_ho, TensorType({DT_FLOAT16})) - .INPUT(bias_h, TensorType({DT_FLOAT16, DT_FLOAT})) .INPUT(bias_o, TensorType({DT_FLOAT16, DT_FLOAT})) .OUTPUT(o_t, TensorType({DT_FLOAT16, DT_FLOAT})) .OUTPUT(h_t, TensorType({DT_FLOAT16, DT_FLOAT})) diff --git 
a/third_party/fwkacllib/inc/ops/roipooling_ops.h b/third_party/fwkacllib/inc/ops/roipooling_ops.h deleted file mode 100644 index dd7a2213..00000000 --- a/third_party/fwkacllib/inc/ops/roipooling_ops.h +++ /dev/null @@ -1,78 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef GE_OP_ROIPOOLING_OPS_H_ -#define GE_OP_ROIPOOLING_OPS_H_ - -#include "graph/operator_reg.h" - -namespace ge { - -/** -*@brief Performs Region of Interest (ROI) pooling. - -*@par Inputs: -* Three inputs, including: -*@li x: An NC1HWC0 tensor of type float16 or float32, describing the feature map. -*@li rois: A tensor of type float16 or float32, with shape [batch, 5, roi_max_num], describing the RIOs. -*@li roi_actual_num: A tensor of type int32, with shape [batch, 8], specifying the number of ROIs per batch. - -*@par Attributes: -*@li roi_max_num: An optional int32, specifying the maximum number of ROIs per batch, at most 6000. Defaults to "3008". The value must be a multiple of 16. -*@li pooled_h: A required int32, specifying the pooled H. Must be greater than 0. -*@li pooled_w: A required int32, specifying the pooled W. Must be greater than 0. -*@li spatial_scale: An optional scaling factor for mapping the input coordinates to the ROI coordinates. Defaults to "0.0625". - -*@par Outputs: -*y: An NC1HWC0 tensor of type float16 or float32, describing the result feature map. 
- -*@attention Constraints:\n -*@li For the feature map input: \n -(1) If pooled_h = pooled_w = 2, the feature map size must not exceed 50. \n -(2) If pooled_h = pooled_w = 3, the feature map size must not exceed 60. \n -(3) If pooled_h = pooled_w = 4, the feature map size must not exceed 70. \n -(4) If pooled_h = pooled_w = 5, the feature map size must not exceed 70. \n -(5) If pooled_h = pooled_w = 6, the feature map size must not exceed 80. \n -(6) If pooled_h = pooled_w = 7, the feature map size must not exceed 80. \n -(7) If pooled_h = pooled_w = 8, the feature map size must not exceed 80. \n -(8) If pooled_h = pooled_w = 9, the feature map size must not exceed 70. \n -(9) If pooled_h = pooled_w = 10, the feature map size must not exceed 70. \n -(10) If pooled_h = pooled_w = 11, the feature map size must not exceed 70. \n -(11) If pooled_h = pooled_w = 12, the feature map size must not exceed 70. \n -(12) If pooled_h = pooled_w = 13, the feature map size must not exceed 70. \n -(13) If pooled_h = pooled_w = 14, the feature map size must not exceed 70. \n -(14) If pooled_h = pooled_w = 15, the feature map size must not exceed 70. \n -(15) If pooled_h = pooled_w = 16, the feature map size must not exceed 70. \n -(16) If pooled_h = pooled_w = 17, the feature map size must not exceed 50. \n -(17) If pooled_h = pooled_w = 18, the feature map size must not exceed 40. \n -(18) If pooled_h = pooled_w = 19, the feature map size must not exceed 40. \n -(19) If pooled_h = pooled_w = 20, the feature map size must not exceed 40. 
\n -*/ - -REG_OP(RoiPooling) - .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16})) - .INPUT(rois, TensorType({DT_FLOAT, DT_FLOAT16})) - .INPUT(roi_actual_num, TensorType({DT_INT32})) - .ATTR(roi_max_num, Int,3008) - .REQUIRED_ATTR(pooled_h, Int) - .REQUIRED_ATTR(pooled_w, Int) - .ATTR(spatial_scale, Float, 0.0625) - .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16})) - .OP_END_FACTORY_REG(RoiPooling) - -} // namespace ge - -#endif // GE_OP_BITWISE_OPS_H_ diff --git a/third_party/fwkacllib/inc/ops/rpn_ops.h b/third_party/fwkacllib/inc/ops/rpn_ops.h index 29c0fbc9..252bfdb0 100644 --- a/third_party/fwkacllib/inc/ops/rpn_ops.h +++ b/third_party/fwkacllib/inc/ops/rpn_ops.h @@ -17,7 +17,7 @@ #ifndef GE_OP_RPN_OPS_H #define GE_OP_RPN_OPS_H -#include "../graph/operator_reg.h" +#include "graph/operator_reg.h" namespace ge { /** *@brief Iteratively removes lower scoring boxes which have an IoU greater than diff --git a/third_party/fwkacllib/inc/ops/rpn_proposal_post_processing.h b/third_party/fwkacllib/inc/ops/rpn_proposal_post_processing.h new file mode 100644 index 00000000..b8861f49 --- /dev/null +++ b/third_party/fwkacllib/inc/ops/rpn_proposal_post_processing.h @@ -0,0 +1,39 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + + #ifndef GE_OP_RPN_PROPOSAL_POST_PROCESSING_H + #define GE_OP_RPN_PROPOSAL_POST_PROCESSING_H + + #include "graph/operator_reg.h" + +namespace ge { + REG_OP(RpnProposalPostProcessing) + .INPUT(sorted_proposal, TensorType({DT_FLOAT16})) + .INPUT(proposal_num, TensorType({DT_UINT32})) + .OUTPUT(sorted_box, TensorType({ DT_FLOAT16})) + .REQUIRED_ATTR(img_size, ListInt) + .REQUIRED_ATTR(score_threshold, Float) + .REQUIRED_ATTR(k, Int) + .REQUIRED_ATTR(min_size, Float) + .REQUIRED_ATTR(nms_threshold, Float) + .REQUIRED_ATTR(post_nms_num, Int) + .ATTR(box_filter, Bool, true) + .ATTR(core_max_num, Int, 8) + .OP_END_FACTORY_REG(RpnProposalPostProcessing) + } // namespace ge + + #endif // GE_OP_GENERATE_RPN_PROPOSAL_POST_PROCESSING_H + diff --git a/third_party/fwkacllib/inc/ops/score_filter_pre_sort.h b/third_party/fwkacllib/inc/ops/score_filter_pre_sort.h new file mode 100644 index 00000000..8cfac8cf --- /dev/null +++ b/third_party/fwkacllib/inc/ops/score_filter_pre_sort.h @@ -0,0 +1,36 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + + #ifndef GE_OP_SCORE_FILTER_PRE_SORT_H + #define GE_OP_SCORE_FILTER_PRE_SORT_H + + #include "graph/operator_reg.h" + +namespace ge { + REG_OP(ScoreFiltePreSort) + .INPUT(rois, TensorType({DT_FLOAT16})) + .INPUT(cls_bg_prob, TensorType({DT_FLOAT16})) + .OUTPUT(sorted_proposal, TensorType({ DT_FLOAT16})) + .OUTPUT(proposal_num, TensorType({ DT_UINT32})) + .REQUIRED_ATTR(score_threshold, Float) + .REQUIRED_ATTR(k, Int) + .ATTR(score_filter, Bool, true) + .ATTR(core_max_num, Int, 8) + .OP_END_FACTORY_REG(ScoreFiltePreSort) + } // namespace ge + + #endif // GE_OP_SCORE_FILTER_PRE_SORT_H + diff --git a/third_party/fwkacllib/inc/ops/sdca_ops.h b/third_party/fwkacllib/inc/ops/sdca_ops.h index 3f1e938a..15428d2b 100644 --- a/third_party/fwkacllib/inc/ops/sdca_ops.h +++ b/third_party/fwkacllib/inc/ops/sdca_ops.h @@ -64,7 +64,7 @@ REG_OP(SdcaOptimizerV2) .INPUT(example_weights, TensorType({DT_FLOAT})) .INPUT(example_labels, TensorType({DT_FLOAT})) .DYNAMIC_INPUT(sparse_indices, TensorType({DT_INT64})) - .DYNAMIC_INPUT(sparse_weights, TensorType({DT_INT64})) + .DYNAMIC_INPUT(sparse_weights, TensorType({DT_FLOAT})) .DYNAMIC_INPUT(dense_weights, TensorType({DT_FLOAT})) .INPUT(example_state_data, TensorType({DT_FLOAT})) .OUTPUT(out_example_state_data, TensorType({DT_FLOAT})) diff --git a/third_party/fwkacllib/inc/ops/selection_ops.h b/third_party/fwkacllib/inc/ops/selection_ops.h index dab71025..f3b588b1 100644 --- a/third_party/fwkacllib/inc/ops/selection_ops.h +++ b/third_party/fwkacllib/inc/ops/selection_ops.h @@ -16,7 +16,7 @@ #ifndef GE_OP_SELECTION_OPS_H #define GE_OP_SELECTION_OPS_H -#include "../graph/operator_reg.h" +#include "graph/operator_reg.h" namespace ge { /** @@ -240,7 +240,7 @@ REG_OP(GatherV2D) REG_OP(StridedSlice) .INPUT(x, TensorType::BasicType()) .INPUT(begin, TensorType::IndexNumberType()) - .INPUT(end, TensorType::IndexNumberTypeT()) + .INPUT(end, TensorType::IndexNumberType()) .INPUT(strides, TensorType::IndexNumberType()) .ATTR(begin_mask, Int, 
0) .ATTR(end_mask, Int, 0) @@ -514,23 +514,23 @@ REG_OP(Select) .OP_END_FACTORY_REG(Select) /** -*@brief: SelectV2s elements from "x2" or "x3", depending on "condition". +*@brief: SelectV2s elements from "then" or "else", depending on "condition". *@par Inputs: * Three inputs, including: -* @li x1: A Tensor of type bool. -* @li x2: A Tensor. Must be one of the following types: float16, float32, int32, int8, uint8. -* @li x3: A Tensor of the same type as "x2". +* @li condition: A Tensor of type bool. +* @li then: A Tensor. Must be one of the following types: float16, float32, int32, int8, uint8. +* @li else: A Tensor of the same type as "then". *@par Outputs: -*y: A Tensor. Has the same type as "x2". +*result: A Tensor. Has the same type as "then". */ REG_OP(SelectV2) - .INPUT(x1, TensorType({DT_BOOL})) - .INPUT(x2,TensorType::BasicType()) - .INPUT(x3,TensorType::BasicType()) - .OUTPUT(y,TensorType::BasicType()) + .INPUT(condition, TensorType({DT_BOOL})) + .INPUT(then,TensorType::BasicType()) + .INPUT(else,TensorType::BasicType()) + .OUTPUT(result,TensorType::BasicType()) .OP_END_FACTORY_REG(SelectV2) @@ -571,7 +571,7 @@ REG_OP(SegmentMax) *@par Outputs: *y:A Tensor with same type as "x". -*/ +*/ REG_OP(SegmentMaxD) .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32})) .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32})) @@ -685,7 +685,9 @@ REG_OP(SliceD) * least "k". * Specifies the data to sort. * @li assist_seq: A 1D tensor of type float16. -* With values 0, 1, 2, ..., N-1, where "N" is the last dimension. +* with size of 2N, which "N" is the last dimension. +* The first N numbers is indices, and the next N numbers is deviation of casting +* float16 to int32. 
* @par Attributes: * @li k: A required int that is at least 0, specifying the number of top elements \n @@ -703,6 +705,8 @@ REG_OP(SliceD) * @attention Constraints: * @li k =< 4096 * @li Size of the last dimension =< 65500 +* @li sorted = true +* @li Don't support to get score on the platform of Ascend310 */ REG_OP(TopKD) .INPUT(x, TensorType::RealNumberType()) @@ -1088,9 +1092,9 @@ REG_OP(InplaceUpdate) * An alias of "x". The content of "y" is undefined if there are duplicates in indices. */ REG_OP(InplaceUpdateD) - .INPUT(x, TensorType::BasicType()) - .INPUT(v, TensorType::BasicType()) - .OUTPUT(y, TensorType::BasicType()) + .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32})) + .INPUT(v, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32})) .REQUIRED_ATTR(indices, ListInt) .OP_END_FACTORY_REG(InplaceUpdateD) @@ -1136,9 +1140,9 @@ REG_OP(InplaceAdd) * An alias of "x". The content of "y" is undefined if there are duplicates in indices. */ REG_OP(InplaceAddD) - .INPUT(x, TensorType::BasicType()) - .INPUT(v, TensorType::BasicType()) - .OUTPUT(y, TensorType::BasicType()) + .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32})) + .INPUT(v, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32})) .REQUIRED_ATTR(indices, ListInt) .OP_END_FACTORY_REG(InplaceAddD) @@ -1181,9 +1185,9 @@ REG_OP(InplaceSub) * An alias of x. The content of y is undefined if there are duplicates in indices. 
*/ REG_OP(InplaceSubD) - .INPUT(x, TensorType::BasicType()) - .INPUT(v, TensorType::BasicType()) - .OUTPUT(y, TensorType::BasicType()) + .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32})) + .INPUT(v, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32})) .REQUIRED_ATTR(indices, ListInt) .OP_END_FACTORY_REG(InplaceSubD) @@ -1308,191 +1312,24 @@ REG_OP(UnsortedSegmentProdD) .REQUIRED_ATTR(num_segments, Int) .OP_END_FACTORY_REG(UnsortedSegmentProdD) -/** -*@brief Normalizes data. It is called Region on YOLO v2 and Yolo on YOLO v3. - -*@par Inputs: -*x: An NCHW tensor of type float16 or float32. The data is with shape (N, boxes*(coords+obj+classes), H, W),where, "obj" indicates the confidence of an object, and only one confidence is supported. Boxes are arranged as xx...xyy...yww...whh...hbb...bc0c0..c0c1c1...c1......cncn...cn. - -*@par Attributes: -*@li boxes: A required int32, specifying the number of anchor boxes. Defaults to "5" for V2 or "3" for V3. -*@li coords: An int32, specifying the number of parameters required for locating an object. The value is fixed at "4", corresponding to (x,y,w,h). -*@li classes: An int32, specifying the number of prediction classes. Defaults to "80". The value range is [1, 1024]. -*@li yolo_version: A string, specifying the YOLO version, either "V2" or "V3". -*@li softmax: A bool, specifying whether to perform softmax, valid only when "yolo_version = V2". -*@li background: A bool, specifying the operation types of the obj and classes, used in conjunction with "softmax" and valid only when "yolo_version = V2". -*@li background: A bool. - -*@par Outputs: -*@li coord_data: A float16 or float32 with shape [N, boxes*coords, ceilx(height*width*2+32, 32)/2], where "ceil" indicates that a detected box is aligned upwards with the second parameter. Specifies the coordinates of a detected box. 
-*@li obj_prob: A float16 or float32 with shape [N, ceilx(boxes*height*width *2+32, 32)/2], where "ceil" indicates that a detected box is aligned upwards with the second parameter. Specifies the confidence. -*@li classes_prob: A float16 or float32 with shape [N, classes, ceilx(boxes*height*width *2+32, 32)/2], where "ceil" indicates that a detected box is aligned upwards with the second parameter. Specifies the prediction classes. - -*@attention Constraints: -*@li This operator applies to YOLO v2 and v3 networks. -*@li The succeeding layer of the Yolo operator must be operator Yolov3DetectionOutput. -*/ -REG_OP(Yolo) - .INPUT(x, TensorType({DT_FLOAT16,DT_FLOAT})) - .OUTPUT(coord_data, TensorType({DT_FLOAT16,DT_FLOAT})) - .OUTPUT(obj_prob, TensorType({DT_FLOAT16,DT_FLOAT})) - .OUTPUT(classes_prob, TensorType({DT_FLOAT16,DT_FLOAT})) - .ATTR(boxes, Int, 3) - .ATTR(coords, Int, 4) - .ATTR(classes, Int, 80) - .ATTR(yolo_version, String, "V3") - .ATTR(softmax, Bool, false) - .ATTR(background, Bool, false) - .ATTR(softmaxtree, Bool, false) - .OP_END_FACTORY_REG(Yolo) - -/** -*@brief Performs YOLO V3 detection. - -*@par Inputs: -*Ten inputs, including: -*@li Operator Yolov3DetectionOutput takes the outputs of operator Yolo as its inputs. A Yolo operator has three outputs: "coords", "obj", and "class". \n -There are three Yolo operators at Yolov3DetectionOutput's preceding layer on Yolo v3. For details, see the description of operator Yolo. -*@li imginfo: A float16, describing the image information including the required image height and width \n -and the actual image height and width. -* -*@par Attributes: -*@li biases: A required float. "biases = Number of Yolo operators at the preceding layer x 2 x boxes" -*@li boxes: A required int32, specifying the number of anchor boxes predicted for each Yolo layer. -*@li coords: Specifies the number of coordinate parameters. Must be 4. -*@li classes: A required int32, specifying the number of classes to be predicted. 
The value range is [1, 80]. -*@li relative: An optional bool. Defaults to and must be "true". -*@li obj_threshold: A required float, specifying the confidence threshold for box filtering, which is the output "obj" of operator Yolo). The value range is [0.0, 1.0]. - -*@li post_nms_topn: An optional int32. This attribute is reserved. -*@li score_threshold: A required float, specifying the class score threshold for box filtering, which is the output "class" of operator Yolo). The value range is [0.0, 1.0]. - -*@li iou_threshold: A required float, specifying the intersection-over-union (IOU) threshold for box filtering. The value range is [0.0, 1.0].\n - -*@li pre_nms_topn: An optional int, specifying the number of boxes for non-maximum suppression (NMS). Defaults to "1024". -* -*@par Outputs: -*@li boxout: An NCHW tensor of type float16, describing the information of each output box, including the coordinates, class, and confidence. -*@li boxoutnum: An NCHW tensor of type int32, specifying the number of output boxes. - -*@attention Constraints:\n -*@li This operator applies only to the YOLO v3 network. -*@li The preceding layer of operator Yolov3DetectionOutput must be three Yolo operators. 
- -*@see Yolo() -*/ -REG_OP(YoloV3DetectionOutput) - .INPUT(coord_data_low, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(coord_data_mid, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(coord_data_high, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(obj_prob_low, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(obj_prob_mid, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(obj_prob_high, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(classes_prob_low, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(classes_prob_mid, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(classes_prob_high, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(img_info, TensorType({DT_FLOAT16,DT_FLOAT})) - .REQUIRED_ATTR(biases_low, ListFloat) - .REQUIRED_ATTR(biases_mid, ListFloat) - .REQUIRED_ATTR(biases_high, ListFloat) - .ATTR(boxes, Int, 3) - .ATTR(coords, Int, 4) - .ATTR(classes, Int, 80) - .ATTR(relative, Bool, true) - .ATTR(obj_threshold, Float, 0.5) - .ATTR(post_nms_topn, Int, 1024) - .ATTR(score_threshold, Float, 0.5) - .ATTR(iou_threshold, Float, 0.45) - .ATTR(pre_nms_topn, Int, 512) - .OUTPUT(box_out, TensorType({DT_FLOAT16,DT_FLOAT})) - .OUTPUT(box_out_num, TensorType({DT_INT32})) - .OP_END_FACTORY_REG(YoloV3DetectionOutput) - -/** -*@brief Performs YOLO V3 detection. - -*@par Inputs: -*16 Input, including: -*@li The outputs of operator Yolo at the preceding layer (that is, three Yolo operators on YOLO v3) are used as the inputs of operator Yolov3DetectionOutput. \n -A Yolo operator has three outputs: "coords", "obj", and "class". For details, see the description of operator Yolo. -*@li imginfo: A float16, describing the image information including the required image height and width \n -and the actual image height and width. -*@li windex: A windex tensor with shape [height,weight]. Has the same type as the inputs. [[0,1,2...(weight-1)],[0,1,2...(w-1)]...[0,1,2...(weight-1)]] consisting of h groups of [0, 1, 2...(weight-1)] is formed for the three Yolo outputs, respectively. 
- -*@li hindex: A hindex tensor with shape [height,weight]. Has the same type as the inputs. [[0,0...0],[1,1...1],[2,2...2]...[height-1,height-1...,height-1]] is formed for the three Yolo outputs, respectively. - -* -*@par Attributes: -*@li biases: A required float32. "biases = Number of Yolo operators at the preceding layer x 2 x boxes" -*@li boxes: A required int32, specifying the number of anchor boxes predicted for each Yolo layer. -*@li coords: Specifies the number of coordinate parameters. Must be 4. -*@li classes: A required int32, specifying the number of classes to be predicted. The value range is [1, 80]. -*@li relative: An optional bool. Defaults to and must be "true". -*@li obj_threshold: A required float, specifying the confidence threshold for box filtering, which is the output "obj" of operator Yolo). The value range is [0.0, 1.0]. -*@li post_nms_topn: An optional int32. This attribute is reserved. -*@li score_threshold: A required float, specifying the class score threshold for box filtering, which is the output "class" of operator Yolo). The value range is [0.0, 1.0]. -*@li iou_threshold: A required float, specifying the intersection-over-union (IOU) threshold for box filtering. The value range is [0.0, 1.0].\n -*@li pre_nms_topn: An optional int, specifying the number of boxes for non-maximum suppression (NMS). Defaults to "1024". -* -*@par Outputs: -*@li boxout: An NCHW tensor of type float16, describing the information of each output box, including the coordinates, class, and confidence. -*@li boxoutnum: An NCHW tensor of type int32, specifying the number of output boxes. - -*@attention Constraints:\n -*@li This operator applies only to the YOLO v3 network. -*@li The preceding layer of operator Yolov3DetectionOutput must be three Yolo operators. 
-*@see Yolo() -*/ -REG_OP(YoloV3DetectionOutputD) - .INPUT(coord_data_low, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(coord_data_mid, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(coord_data_high, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(obj_prob_low, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(obj_prob_mid, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(obj_prob_high, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(classes_prob_low, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(classes_prob_mid, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(classes_prob_high, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(img_info, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(windex1, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(windex2, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(windex3, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(hindex1, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(hindex2, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(hindex3, TensorType({DT_FLOAT16,DT_FLOAT})) - .REQUIRED_ATTR(biases_low, ListFloat) - .REQUIRED_ATTR(biases_mid, ListFloat) - .REQUIRED_ATTR(biases_high, ListFloat) - .ATTR(boxes, Int, 3) - .ATTR(coords, Int, 4) - .ATTR(classes, Int, 80) - .ATTR(relative, Bool, true) - .ATTR(obj_threshold, Float, 0.5) - .ATTR(post_nms_topn, Int, 1024) - .ATTR(score_threshold, Float, 0.5) - .ATTR(iou_threshold, Float, 0.45) - .ATTR(pre_nms_topn, Int, 512) - .OUTPUT(box_out, TensorType({DT_FLOAT16,DT_FLOAT})) - .OUTPUT(box_out_num, TensorType({DT_INT32})) - .OP_END_FACTORY_REG(YoloV3DetectionOutputD) - /** *@brief Performs object detection. *@par Inputs: *@li cls_prob: An NCHW tensor of type float16 or float32, specifying the probability of the proposal is the background class. -*@li bbox_pred: An NCHW tensor of type float16 or float32, specifying the coordinates of the proposals bounding boxes. +*@li bbox_delta: An NCHW tensor of type float16 or float32, specifying the coordinates of the proposals bounding boxes. 
+*@li im_info: An ND tensor of type float16 or float32, specifying the Image information.
 *@par Attributes:
-*@li im_info: A required list of floats, specifying the Image information. The value range is [1, 4096].
-*@li feat_stride: A required float32, specifying the stride of the sliding window. Must be greater than "0". Defaults to "16".
-*@li base_size: A required float32, specifying the size of the generated base box. Must be greater than "0". Defaults to "16".
-*@li min_size: A required float32, specifying the minimum edge length of a proposal. A box with any edge less than this value is removed. Must be greater than "0". Defaults to "16".
-*@li ratio: A required list of floats, specifying the aspect ratio of the generated base box. Defaults to [0.5, 1, 2].
-*@li scale: A required list of floats, specifying the ratio of the size of the generated base box to "base_size". Defaults to [8, 16, 32].
+*@li feat_stride: An optional float32, specifying the stride of the sliding window. Must be greater than "0". Defaults to "16".
+*@li base_size: An optional float32, specifying the size of the generated base box. Must be greater than "0". Defaults to "16".
+*@li min_size: An optional float32, specifying the minimum edge length of a proposal. A box with any edge less than this value is removed. Must be greater than "0". Defaults to "16".
+*@li ratio: An optional list of floats, specifying the aspect ratio of the generated base box. Defaults to [0.5, 1, 2].
+*@li scale: An optional list of floats, specifying the ratio of the size of the generated base box to "base_size". Defaults to [8, 16, 32].
 *@li pre_nms_topn: A required int, specifying top K boxes before NMS. For float16 input, pre_nms_topn <= 6000. For float32 input, pre_nms_topn <= 3000. Defaults to "3000".
 *@li post_nms_topn: A required int, specifying the number of boxes to be output after NMS. The value is a multiple of 16. For float16 input, post_nms_topn <= 6000.
For float32 input, post_nms_topn <= 3000 (the maximum multiple of 16 is 2992 within the range). Defaults to "304". -*@li nms_thresh: A required float32, specifying the NMS threshold. The value range is (0,1]. Defaults to "0.7". +*@li iou_threshold: A required float32, specifying the NMS threshold. The value range is (0,1]. Defaults to "0.7". +*@li output_actual_rois_num: An optional bool. Defaults to "false". *@par Outputs: *@li rois: A Tensor with shape [batch, 5, post_nms_topn], of type float16, specifying the output box information. "post_nms_topn" must be a multiple of 16. The dimension "5" indicates (batchID, x1, y1, x2, y2). The number of BBoxes output per batch is determined by "actual_rois_num". @@ -1500,18 +1337,19 @@ REG_OP(YoloV3DetectionOutputD) */ REG_OP(Proposal) .INPUT(cls_prob, TensorType({DT_FLOAT16, DT_FLOAT})) - .INPUT(bbox_pred, TensorType({DT_FLOAT16, DT_FLOAT})) + .INPUT(bbox_delta, TensorType({DT_FLOAT16, DT_FLOAT})) + .INPUT(im_info, TensorType({DT_FLOAT16, DT_FLOAT})) .OUTPUT(rois, TensorType({DT_FLOAT16, DT_FLOAT})) .OUTPUT(actual_rois_num, TensorType({DT_INT32})) - .ATTR(im_info, ListFloat, {375, 1240}) .ATTR(feat_stride, Float, 16) .ATTR(base_size, Float, 16) - .ATTR(min_size, ListFloat, {16, 16}) + .ATTR(min_size, Float, 16) .ATTR(ratio, ListFloat, {0.5, 1, 2}) .ATTR(scale, ListFloat, {8, 16, 32}) - .ATTR(pre_nms_topn, Int, 6000) + .ATTR(pre_nms_topn, Int, 3000) .ATTR(post_nms_topn, Int, 304) - .ATTR(nms_thresh, Float, 0.7) + .ATTR(iou_threshold, Float, 0.7) + .ATTR(output_actual_rois_num, Bool, false) .OP_END_FACTORY_REG(Proposal) /** @@ -1519,19 +1357,20 @@ REG_OP(YoloV3DetectionOutputD) *@par Inputs: *@li cls_prob: An NCHW tensor of type float16, specifying the probability of the proposal is the background class. -*@li bbox_pred: An NCHW tensor of type float16, specifying the coordinates of the proposals bounding boxes. +*@li bbox_delta: An NCHW tensor of type float16, specifying the coordinates of the proposals bounding boxes. 
+*@li im_info: An ND tensor of type float16 or float32, specifying the Image information.
 *@li rpn_bbox: An NCHW tensor of type float16, specifying the coordinates of the proposals bounding boxes.
 *@par Attributes:
-*@li im_info: A required list of floats, specifying the Image information. The value range is [1, 4096].
-*@li feat_stride: A required float32, specifying the stride of the sliding window. Must be greater than "0". Defaults to "16".
+*@li feat_stride: A required float32, specifying the stride of the sliding window. Must be greater than "0". Defaults to "16".
 *@li base_size: A required float32, specifying the size of the generated base box. Must be greater than "0". Defaults to "16".
 *@li min_size: A required float32, specifying the minimum edge length of a proposal. A box with any edge less than this value is removed. Must be greater than "0". Defaults to "16".
 *@li ratio: A required list of floats, specifying the aspect ratio of the generated base box. Defaults to [0.5, 1, 2].
 *@li scale: A required list of floats, specifying the ratio of the size of the generated base box to "base_size". Defaults to [8, 16, 32].
 *@li pre_nms_topn: A required int, specifying top K boxes before NMS. For float16 input, pre_nms_topn <= 6000. For float32 input, pre_nms_topn <= 3000. Defaults to "3000".
 *@li post_nms_topn: A required int, specifying the number of boxes to be output after NMS. The value is a multiple of 16. For float16 input, post_nms_topn <= 6000. For float32 input, post_nms_topn <= 3000 (the maximum multiple of 16 is 2992 within the range). Defaults to "304".
-*@li nms_thresh: A required float32, specifying the NMS threshold. The value range is (0,1]. Defaults to 0.7.
+*@li iou_threshold: A required float32, specifying the NMS threshold. The value range is (0,1]. Defaults to 0.7.
+*@li output_actual_rois_num: An optional bool. Defaults to "false".
*@par Outputs: *@li rois: A Tensor with shape [batch, 5, post_nms_topn], of type float16, specifying the output box information. "post_nms_topn" must be a multiple of 16. The dimension "5" indicates (batchID, x1, y1, x2, y2). The number of BBoxes output per batch is determined by "actual_rois_num". @@ -1539,131 +1378,22 @@ REG_OP(YoloV3DetectionOutputD) */ REG_OP(ProposalD) .INPUT(cls_prob, TensorType({DT_FLOAT16, DT_FLOAT})) - .INPUT(bbox_pred, TensorType({DT_FLOAT16, DT_FLOAT})) + .INPUT(bbox_delta, TensorType({DT_FLOAT16, DT_FLOAT})) + .INPUT(im_info, TensorType({DT_FLOAT16, DT_FLOAT})) .INPUT(rpn_bbox, TensorType({DT_FLOAT16, DT_FLOAT})) .OUTPUT(rois, TensorType({DT_FLOAT16, DT_FLOAT})) .OUTPUT(actual_rois_num, TensorType({DT_INT32})) - .ATTR(im_info, ListFloat, {375, 1240}) .ATTR(feat_stride, Float, 16) .ATTR(base_size, Float, 16) - .ATTR(min_size, ListFloat, {16, 16}) + .ATTR(min_size, Float, 16) .ATTR(ratio, ListFloat, {0.5, 1, 2}) .ATTR(scale, ListFloat, {8, 16, 32}) - .ATTR(pre_nms_topn, Int, 6000) + .ATTR(pre_nms_topn, Int, 3000) .ATTR(post_nms_topn, Int, 304) - .ATTR(nms_thresh, Float, 0.7) + .ATTR(iou_threshold, Float, 0.7) + .ATTR(output_actual_rois_num, Bool, false) .OP_END_FACTORY_REG(ProposalD) -/** -*@brief Performs YOLO V2 detection. - -*@par Inputs: -* Four inputs, including: -*@li The outputs of operator Yolo at the preceding layer (that is, one Yolo operator on YOLO v2) are used as the inputs of operator Yolov3DetectionOutput. \n -Each Yolo operator has three outputs: "coords", "obj", and "class". For details, see the description of operator Yolo. -*@li imginfo: A float16, describing the image information including the required image height and width \n -and the actual image height and width. -* -*@par Attributes: -*@li biases: A required float. "biases = Number of Yolo operators at the preceding layer x 2 x boxes" -*@li boxes: A required int32, specifying the number of anchor boxes predicted for each Yolo layer. 
-*@li coords: Specifies the number of coordinate parameters. Must be 4. -*@li classes: A required int32, specifying the number of classes to be predicted. The value range is [1, 80]. -*@li relative: An optional bool. Defaults to and must be "true". -*@li obj_threshold: A required float, specifying the confidence threshold for box filtering, which is the output "obj" of operator Yolo). The value range is [0.0, 1.0]. - -*@li post_nms_topn: An optional int32. This attribute is reserved. -*@li score_threshold: A required float, specifying the class score threshold for box filtering, which is the output "class" of operator Yolo). The value range is [0.0, 1.0]. -*@li iou_threshold: A required float, specifying the intersection-over-union (IOU) threshold for box filtering. The value range is [0.0, 1.0].\n -*@li pre_nms_topn: An optional int, specifying the number of boxes for non-maximum suppression (NMS). Defaults to "1024". -* -*@par Outputs: -*@li boxout: An NCHW tensor of type float16, describing the information of each output box, including the coordinates, class, and confidence. -*@li boxoutnum: An NCHW tensor of type int32, specifying the number of output boxes. - -*@attention Constraints:\n -*@li This operator applies only to the YOLO v2 network. -*@li The preceding layer of operator Yolov2DetectionOutput must be one Yolo operator. 
- -*@see Yolo() -*/ -REG_OP(YoloV2DetectionOutput) - .INPUT(coord_data, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(obj_prob, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(classes_prob, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(img_info, TensorType({DT_FLOAT16,DT_FLOAT})) - .REQUIRED_ATTR(biases, ListFloat) - .ATTR(boxes, Int, 5) - .ATTR(coords, Int, 4) - .ATTR(classes, Int, 80) - .ATTR(relative, Bool, true) - .ATTR(obj_threshold, Float, 0.5) - .ATTR(post_nms_topn, Int, 1024) - .ATTR(score_threshold, Float, 0.5) - .ATTR(iou_threshold, Float, 0.45) - .ATTR(pre_nms_topn, Int, 512) - .OUTPUT(box_out, TensorType({DT_FLOAT16,DT_FLOAT})) - .OUTPUT(box_out_num, TensorType({DT_INT32})) - .OP_END_FACTORY_REG(YoloV2DetectionOutput) - -/** -*@brief Performs YOLO V2 detection. - -*@par Inputs: -*Six inputs, including: -*@li The outputs of operator Yolo at the preceding layer (that is, one Yolo operator on YOLO v2) are used as the inputs of operator Yolov2DetectionOutput. \n -Each Yolo operator has three outputs: "coords", "obj", and "class". For details, see the description of operator Yolo. -*@li imginfo: A float16, describing the image information including the required image height and width \n -and the actual image height and width. -*@li windex: A windex tensor with shape [height, weight]. Has the same type as the inputs. [[0,1,2...(weight-1)],[0,1,2...(w-1)]...[0,1,2...(weight-1)]] consisting of h groups of [0, 1, 2...(weight-1)] is formed. \n - -*@li hindex: A hindex tensor with shape [height, weight]. Has the same type as the inputs. [[0,0...0],[1,1...1],[2,2...2]...[height-1,height-1...,height-1]]. \n - -* -*@par Attributes: -*@li biases: A required float. "biases = Number of Yolo operators at the preceding layer x 2 x boxes" -*@li boxes: A required int32, specifying the number of anchor boxes predicted for each Yolo layer. -*@li coords: Specifies the number of coordinate parameters. Must be 4. 
-*@li classes: A required int32, specifying the number of classes to be predicted. The value range is [1, 80]. -*@li relative: An optional bool. Defaults to and must be "true". -*@li obj_threshold: A required float, specifying the confidence threshold for box filtering, which is the output "obj" of operator Yolo). The value range is [0.0, 1.0]. -*@li post_nms_topn: An optional int32. This attribute is reserved. -*@li score_threshold: A required float, specifying the class score threshold for box filtering, which is the output "class" of operator Yolo). The value range is [0.0, 1.0]. - -*@li iou_threshold: A required float, specifying the intersection-over-union (IOU) threshold for box filtering. The value range is [0.0, 1.0].\n -*@li pre_nms_topn: An optional int, specifying the number of boxes for non-maximum suppression (NMS). Defaults to "1024". -* -*@par Outputs: -*@li boxout: An NCHW tensor of type float16, describing the information of each output box, including the coordinates, class, and confidence. -*@li boxoutnum: An NCHW tensor of type int32, specifying the number of output boxes. -* -*@attention Constraints:\n -*@li This operator applies only to the YOLO v2 network. -*@li The preceding layer of operator Yolov2DetectionOutput must be one Yolo operator. 
- -*@see Yolo() -*/ -REG_OP(YoloV2DetectionOutputD) - .INPUT(coord_data, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(obj_prob, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(classes_prob, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(img_info, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(windex, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(hindex, TensorType({DT_FLOAT16,DT_FLOAT})) - .REQUIRED_ATTR(biases, ListFloat) - .ATTR(boxes, Int, 5) - .ATTR(coords, Int, 4) - .ATTR(classes, Int, 80) - .ATTR(relative, Bool, true) - .ATTR(obj_threshold, Float, 0.5) - .ATTR(post_nms_topn, Int, 1024) - .ATTR(score_threshold, Float, 0.5) - .ATTR(iou_threshold, Float, 0.45) - .ATTR(pre_nms_topn, Int, 512) - .OUTPUT(box_out, TensorType({DT_FLOAT16,DT_FLOAT})) - .OUTPUT(box_out_num, TensorType({DT_INT32})) - .OP_END_FACTORY_REG(YoloV2DetectionOutputD) - /** *@brief Performs plane or channel conversion on YoloV2. * If reverse=true: (N, H, W, C)->(N, H*stride, W*stride, C/(stride*stride)) @@ -1784,5 +1514,97 @@ REG_OP(WriteSelect) .INPUT(x, TensorType::ALL()) .OUTPUT(y, TensorType::ALL()) .OP_END_FACTORY_REG(WriteSelect) + +/** +*@brief Read data by stride. + +*@par Inputs: +*One input:\n +*x: A Tensor. Must be one of the following types: float16, int8. + +*@par Attributes: +*@li axis: A required int32, specifying the index of axis to read by stride. + +*@par Attributes: +*@li stride: A required int32, specifying the value of reading stride. + +*@par Outputs: +*y: A Tensor of the same type as "x". +*/ +REG_OP(StridedRead) + .INPUT(x, TensorType::ALL()) + .OUTPUT(y, TensorType::ALL()) + .ATTR(axis, Int, 1) + .ATTR(stride, Int, 1) + .OP_END_FACTORY_REG(StridedRead) + +/** +*@brief: Write data by stride. + +*@par Inputs:\n +*x: A Tensor. Must be one of the following types: float16, int8. + +*@par Attributes: +*@li axis: A required int32, specifying the index of axis to write by stride. + +*@par Attributes: +*@li stride: A required int32, specifying the value of writing stride. 
+ +*@par Outputs: +*y: A Tensor. Has the same type as "x". +*/ +REG_OP(StridedWrite) + .INPUT(x, TensorType::ALL()) + .OUTPUT(y, TensorType::ALL()) + .ATTR(axis, Int, 1) + .ATTR(stride, Int, 1) + .OP_END_FACTORY_REG(StridedWrite) + +/** +*@brief Computes the cumulative log sum exp of the tensor "x" along "axis". + +*@par Inputs: +* Two inputs, including: +*@li x: A Tensor. Must be one of the following types: float32, float16. +*@li axis A Tensor of type int32 or int16. Defaults to "0". +* +*@par Attributes: +*@li exclusive: If "False", performs inclusive CumulativeLogsumexp, which means that the first element of the input is identical to the first element of the output. If "True", performs exclusive CumulativeLogsumexp. +*@li reverse: A bool. Defaults to "False". +* +*@par Outputs: +*@li y: A Tensor. Has the same type as "x". +*/ +REG_OP(CumulativeLogsumexp) + .INPUT(x, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16})) + .INPUT(axis, TensorType({DT_INT32, DT_INT16})) + .OUTPUT(y, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16})) + .ATTR(exclusive, Bool, false) + .ATTR(reverse, Bool, false) + .OP_END_FACTORY_REG(CumulativeLogsumexp) + +/** +*@brief Computes the cumulative log sum exp of the tensor "x" along "axis". +* +*@par Inputs: +* One input: +*x: A Tensor. Must be one of the following types: float32, float16. +* +*@par Attributes: +*@li axis A Tensor of type int32 or int16. Defaults to "0". +*@li exclusive: If "False", performs inclusive cumulativeLogsumexp, which means that the first element of the input is identical to the first element of the output. If "True", performs exclusive CumulativeLogsumexp. +*@li reverse: A bool. Defaults to "False". +* +*@par Outputs: +*y: A Tensor. Has the same type as "x". 
+*/ +REG_OP(CumulativeLogsumexpD) + .INPUT(x, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16})) + .OUTPUT(y, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16})) + .REQUIRED_ATTR(axis, Int) + .ATTR(exclusive, Bool, false) + .ATTR(reverse, Bool, false) + .OP_END_FACTORY_REG(CumulativeLogsumexpD) + } // namespace ge #endif // GE_OP_SELECTION_OPS_H diff --git a/third_party/fwkacllib/inc/ops/sparse_ops.h b/third_party/fwkacllib/inc/ops/sparse_ops.h index abb1361c..87f0d81b 100644 --- a/third_party/fwkacllib/inc/ops/sparse_ops.h +++ b/third_party/fwkacllib/inc/ops/sparse_ops.h @@ -215,7 +215,7 @@ REG_OP(SparseDenseCwiseMul) REG_OP(AddSparseToTensorsMap) .INPUT(indices, TensorType({DT_INT64})) .INPUT(values, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, \ - DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_BOOL, DT_DOUBLE \ + DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_BOOL, DT_DOUBLE, \ DT_COMPLEX64, DT_COMPLEX128, DT_RESOURCE, DT_STRING})) .INPUT(shape, TensorType({DT_INT64})) .OUTPUT(handle, TensorType({DT_INT64})) diff --git a/third_party/fwkacllib/inc/ops/spectral_ops.h b/third_party/fwkacllib/inc/ops/spectral_ops.h new file mode 100644 index 00000000..c74bebe9 --- /dev/null +++ b/third_party/fwkacllib/inc/ops/spectral_ops.h @@ -0,0 +1,46 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef GE_OP_SPECTRAL_OPS_H +#define GE_OP_SPECTRAL_OPS_H + +#include "graph/operator.h" +#include "graph/operator_reg.h" + +namespace ge { + +/** +*@brief Real-valued fast Fourier transform. + +*@par Inputs: +*@li input: A float32 tensor. +*@li fft_length: An int32 tensor of shape [1]. The FFT length. + +*@par Outputs: +*@li y: A complex64 tensor of the same rank as `input`. The inner-most \n +dimension of `input` is replaced with the `fft_length / 2 + 1` unique \n +frequency components of its 1D Fourier transform. + +*/ +REG_OP(RFFT) + .INPUT(input, TensorType({DT_FLOAT})) + .INPUT(fft_length, TensorType({DT_INT32})) + .OUTPUT(y, TensorType({DT_COMPLEX64})) + .OP_END_FACTORY_REG(RFFT) + +} // namespace ge + +#endif //GE_OP_SPECTRAL_OPS_H \ No newline at end of file diff --git a/third_party/fwkacllib/inc/ops/split_combination_ops.h b/third_party/fwkacllib/inc/ops/split_combination_ops.h index 734847f4..521d05f7 100644 --- a/third_party/fwkacllib/inc/ops/split_combination_ops.h +++ b/third_party/fwkacllib/inc/ops/split_combination_ops.h @@ -16,7 +16,7 @@ #ifndef GE_OP_SPLIT_COMBINATION_OPS_H #define GE_OP_SPLIT_COMBINATION_OPS_H -#include "../graph/operator_reg.h" +#include "graph/operator_reg.h" namespace ge { /** diff --git a/third_party/fwkacllib/inc/ops/state_ops.h b/third_party/fwkacllib/inc/ops/state_ops.h index 879d7c67..2b2d1362 100644 --- a/third_party/fwkacllib/inc/ops/state_ops.h +++ b/third_party/fwkacllib/inc/ops/state_ops.h @@ -102,6 +102,16 @@ REG_OP(IsVariableInitialized) .OUTPUT(y, TensorType({DT_BOOL})) .OP_END_FACTORY_REG(IsVariableInitialized) +/** +*@brief Checks whether a tensor has been initialized. Outputs boolean scalar indicating whether the tensor has been initialized. + +*@par Inputs: +*x: A tensor. + +*@par Outputs: +*y: A tensor, indicating whether "x" has been initialized, and the data type is boolean. 
+ +*/ REG_OP(VarIsInitializedOp) .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) diff --git a/third_party/fwkacllib/inc/ops/stateful_random_ops.h b/third_party/fwkacllib/inc/ops/stateful_random_ops.h index 929481d5..9ba09dd6 100644 --- a/third_party/fwkacllib/inc/ops/stateful_random_ops.h +++ b/third_party/fwkacllib/inc/ops/stateful_random_ops.h @@ -87,9 +87,9 @@ smaller than the range of the output (either `2^32` or `2^64`). REG_OP(StatefulRandomBinomial) .INPUT(x, TensorType({DT_RESOURCE})) .INPUT(algorithm, TensorType({DT_INT64})) - .INPUT(shape, TensorType({DT_INT32, DT_INT64})) - .INPUT(counts, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64})) - .INPUT(probs, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64})) + .INPUT(shape, TensorType({DT_INT32})) + .INPUT(counts, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) + .INPUT(probs, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64})) .REQUIRED_ATTR(dtype, Type) .OP_END_FACTORY_REG(StatefulRandomBinomial) @@ -111,7 +111,7 @@ REG_OP(StatefulRandomBinomial) REG_OP(StatefulStandardNormalV2) .INPUT(x, TensorType({DT_RESOURCE})) .INPUT(algorithm, TensorType({DT_INT64})) - .INPUT(shape, TensorType({DT_INT64})) + .INPUT(shape, TensorType({DT_INT32,DT_INT64})) .OUTPUT(y, TensorType({DT_FLOAT})) .OP_END_FACTORY_REG(StatefulStandardNormalV2) @@ -134,7 +134,7 @@ REG_OP(StatefulStandardNormalV2) REG_OP(StatefulTruncatedNormal) .INPUT(x, TensorType({DT_RESOURCE})) .INPUT(algorithm, TensorType({DT_INT64})) - .INPUT(shape, TensorType({DT_INT64})) + .INPUT(shape, TensorType({DT_INT32,DT_INT64})) .OUTPUT(y, TensorType({DT_FLOAT})) .OP_END_FACTORY_REG(StatefulTruncatedNormal) @@ -156,7 +156,7 @@ lower bound 0 is included in the range, while the upper bound 1 is excluded. 
\n REG_OP(StatefulUniform) .INPUT(x, TensorType({DT_RESOURCE})) .INPUT(algorithm, TensorType({DT_INT64})) - .INPUT(shape, TensorType({DT_INT64})) + .INPUT(shape, TensorType({DT_INT32,DT_INT64})) .OUTPUT(y, TensorType({DT_FLOAT})) .OP_END_FACTORY_REG(StatefulUniform) @@ -177,8 +177,8 @@ The generated values are uniform integers covering the whole range of `dtype`. REG_OP(StatefulUniformFullInt) .INPUT(x, TensorType({DT_RESOURCE})) .INPUT(algorithm, TensorType({DT_INT64})) - .INPUT(shape, TensorType({DT_INT64})) - .OUTPUT(y, TensorType({DT_INT64})) + .INPUT(shape, TensorType({DT_INT32,DT_INT64})) + .OUTPUT(y, TensorType({DT_UINT64})) .OP_END_FACTORY_REG(StatefulUniformFullInt) /** @@ -205,7 +205,7 @@ smaller than the range of the output (either `2^32` or `2^64`). REG_OP(StatefulUniformInt) .INPUT(x, TensorType({DT_RESOURCE})) .INPUT(algorithm, TensorType({DT_INT64})) - .INPUT(shape, TensorType({DT_INT64})) + .INPUT(shape, TensorType({DT_INT32,DT_INT64})) .INPUT(minval, TensorType({DT_INT64})) .INPUT(maxval, TensorType({DT_INT64})) .OUTPUT(y, TensorType({DT_INT64})) diff --git a/third_party/fwkacllib/inc/ops/transformation_ops.h b/third_party/fwkacllib/inc/ops/transformation_ops.h index 689cde4e..eb8655d0 100644 --- a/third_party/fwkacllib/inc/ops/transformation_ops.h +++ b/third_party/fwkacllib/inc/ops/transformation_ops.h @@ -17,7 +17,7 @@ #ifndef GE_OP_TRANSFORMATION_OPS_H #define GE_OP_TRANSFORMATION_OPS_H -#include "../graph/operator_reg.h" +#include "graph/operator_reg.h" namespace ge { REG_OP(DepthwiseWeight4DTo6D) @@ -31,6 +31,8 @@ REG_OP(DepthwiseWeight6DTo4D) .ATTR(channel_size, Int, 16) .OP_END_FACTORY_REG(DepthwiseWeight6DTo4D) + + /** *@brief Permutes the dimensions according to perm.\n The returned tensor's dimension i will correspond to the input dimension perm[i]. @@ -45,8 +47,10 @@ REG_OP(DepthwiseWeight6DTo4D) *y: A Tensor. Has the same type as "x". 
*/ REG_OP(TransposeD) - .INPUT(x, TensorType::BasicType()) - .OUTPUT(y, TensorType::BasicType()) + .INPUT(x, TensorType({DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, + DT_UINT16, DT_UINT32, DT_UINT64, DT_FLOAT16, DT_FLOAT})) + .OUTPUT(y, TensorType({DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, + DT_UINT16, DT_UINT32, DT_UINT64, DT_FLOAT16, DT_FLOAT})) .REQUIRED_ATTR(perm, ListInt) .OP_END_FACTORY_REG(TransposeD) @@ -400,14 +404,45 @@ REG_OP(Unpack) * "ksizes", "strides" and "rates" are lists of integers. */ REG_OP(ExtractImagePatches) - .INPUT(x, TensorType::REALNUMBERTYPE()) - .OUTPUT(y, TensorType::REALNUMBERTYPE()) + .INPUT(x, TensorType::RealNumberType()) + .OUTPUT(y, TensorType::RealNumberType()) .REQUIRED_ATTR(ksizes, ListInt) .REQUIRED_ATTR(strides, ListInt) .REQUIRED_ATTR(rates, ListInt) .REQUIRED_ATTR(padding, String) .OP_END_FACTORY_REG(ExtractImagePatches) +/** +* @brief Extract "patches" from "input" and put them in the "depth" +* dimension of the output. + +* @par Inputs: +* x: A 5D Tensor with shape [batch, in_planes, in_rows, in_cols, depth]. + +* @par Attributes: +* @li ksizes: A required list or tuple. The size of the sliding window for each +* dimension of "x". +* @li strides: A required list or tuple. How far the centers of two consecutive +* patches are in "x". Must be: [1, stride_planes, stride_rows, stride_cols, 1]. +* @li padding: A required string. The type of padding algorithm to use. + +* @par Outputs: +* Output: A 5D Tensor with shape [batch, out_planes, out_rows, out_cols, ksize_planes * \n +* ksize_rows * ksize_cols * depth] containing patches with size (ksize_rows * ksize_cols\n +* * depth) vectorized in the "depth" dimension. Note "out_planes", "out_rows" and "out_cols"\n +* are the dimensions of the output patches. + +* @attention Constraints: +* "ksizes" and "strides" are lists of integers. 
+*/ +REG_OP(ExtractVolumePatches) + .INPUT(x, TensorType::REALNUMBERTYPE()) + .OUTPUT(y, TensorType::REALNUMBERTYPE()) + .REQUIRED_ATTR(ksizes, ListInt) + .REQUIRED_ATTR(strides, ListInt) + .REQUIRED_ATTR(padding, String) + .OP_END_FACTORY_REG(ExtractVolumePatches) + /** *@brief Confuse reshape and transpose. @@ -423,8 +458,10 @@ REG_OP(ExtractImagePatches) *y: A Tensor. Has the same type as "x". */ REG_OP(ConfusionTransposeD) - .INPUT(x, TensorType::BasicType()) - .OUTPUT(y, TensorType::BasicType()) + .INPUT(x, TensorType({DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, + DT_UINT16, DT_UINT32, DT_UINT64, DT_FLOAT16, DT_FLOAT})) + .OUTPUT(y, TensorType({DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, + DT_UINT16, DT_UINT32, DT_UINT64, DT_FLOAT16, DT_FLOAT})) .REQUIRED_ATTR(perm, ListInt) .REQUIRED_ATTR(shape, ListInt) .REQUIRED_ATTR(transpose_first, Bool) @@ -466,7 +503,7 @@ REG_OP(ConfusionTranspose) *y: The flattened ND tensor. All data types are supported. *@attention Constraints: -* "axis" and "end_axis" must be within the dimension range of the input. +* "axis" and "end_axis" must be within the dimension range of the input. This operator cannot be directly called by the acllopExecute API. 
*/ REG_OP(FlattenV2) .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8, DT_UINT8, DT_INT16, DT_UINT16, @@ -481,6 +518,13 @@ REG_OP(DeConvTrans) .INPUT(x, TensorType({DT_INT8})) .OUTPUT(y, TensorType({DT_INT8})) .OP_END_FACTORY_REG(DeConvTrans) + +REG_OP(Compress) + .INPUT(weight, TensorType({DT_INT8, DT_FLOAT16})) + .OUTPUT(weight_compress, TensorType({DT_INT8, DT_FLOAT16})) + .OUTPUT(compress_index, TensorType({DT_INT8})) + .REQUIRED_ATTR(compress_parameters, ListInt) + .OP_END_FACTORY_REG(Compress) } // namespace ge #endif // GE_OP_TRANSFORMATION_OPS_H diff --git a/third_party/fwkacllib/inc/register/op_kernel_registry.h b/third_party/fwkacllib/inc/register/op_kernel_registry.h index cc8924b5..2c479e92 100644 --- a/third_party/fwkacllib/inc/register/op_kernel_registry.h +++ b/third_party/fwkacllib/inc/register/op_kernel_registry.h @@ -18,7 +18,8 @@ #define INC_REGISTER_OP_KERNEL_REGISTRY_H_ #include #include -#include "register/register.h" +#include "register/register_types.h" +#include "register.h" namespace ge { class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpKernelRegistry { diff --git a/third_party/fwkacllib/inc/register/op_registry.h b/third_party/fwkacllib/inc/register/op_registry.h index 4dd1dc5b..9a214955 100644 --- a/third_party/fwkacllib/inc/register/op_registry.h +++ b/third_party/fwkacllib/inc/register/op_registry.h @@ -61,6 +61,8 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistry { domi::FusionParseParamFunc GetFusionParseParamFunc(const std::string &op_type); + domi::ParseSubgraphFunc GetParseSubgraphPostFunc(const std::string &op_type); + domi::ImplyType GetImplyTypeByOriOpType(const std::string &ori_optype); const std::vector &GetRemoveInputConfigure(const std::string &ori_optype) const; @@ -70,6 +72,7 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistry { std::unordered_map op_run_mode_map_; std::unordered_map opParseParamsFnMap_; std::unordered_map fusionOpParseParamsFnMap_; + std::unordered_map 
op_types_to_parse_subgraph_post_func_; std::unordered_map> remove_input_configure_map_; std::unordered_map originOpType2OmOpType_; }; diff --git a/third_party/fwkacllib/inc/register/register.h b/third_party/fwkacllib/inc/register/register.h new file mode 100644 index 00000000..27da0b0b --- /dev/null +++ b/third_party/fwkacllib/inc/register/register.h @@ -0,0 +1,53 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_REGISTER_REGISTRY_H_ +#define INC_REGISTER_REGISTRY_H_ + +#include "external/register/register.h" + +namespace ge { +class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY HostCpuOp { + public: + HostCpuOp() = default; + virtual ~HostCpuOp() = default; + + virtual graphStatus Compute(Operator &op, + const std::map &inputs, + std::map &outputs) = 0; +}; + +class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY HostCpuOpRegistrar { + public: + HostCpuOpRegistrar(const char *op_type, HostCpuOp *(*create_fn)()); + ~HostCpuOpRegistrar() = default; +}; + +#define REGISTER_HOST_CPU_OP_BUILDER(name, op) \ + REGISTER_HOST_CPU_OP_BUILDER_UNIQ_HELPER(__COUNTER__, name, op) + +#define REGISTER_HOST_CPU_OP_BUILDER_UNIQ_HELPER(ctr, name, op) \ + REGISTER_HOST_CPU_OP_BUILDER_UNIQ(ctr, name, op) + +#define REGISTER_HOST_CPU_OP_BUILDER_UNIQ(ctr, name, op) \ + static ::ge::HostCpuOpRegistrar register_host_cpu_op##ctr \ + __attribute__((unused)) = \ + ::ge::HostCpuOpRegistrar(name, 
[]()->::ge::HostCpuOp* { \ + return new (std::nothrow) op(); \ + }) +} // namespace ge + +#endif //INC_REGISTER_REGISTRY_H_ diff --git a/src/ge/common/formats/format_transfers/format_transfer.h b/third_party/fwkacllib/inc/register/register_format_transfer.h similarity index 83% rename from src/ge/common/formats/format_transfers/format_transfer.h rename to third_party/fwkacllib/inc/register/register_format_transfer.h index 3d03ebbe..72da94fc 100644 --- a/src/ge/common/formats/format_transfers/format_transfer.h +++ b/third_party/fwkacllib/inc/register/register_format_transfer.h @@ -14,16 +14,15 @@ * limitations under the License. */ -#ifndef GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_H_ -#define GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_H_ +#ifndef INC_REGISTER_REGISTER_FORMAT_TRANSFER_H_ +#define INC_REGISTER_REGISTER_FORMAT_TRANSFER_H_ #include #include #include #include "external/graph/types.h" -#include "framework/common/ge_inner_error_codes.h" -#include "common/ge/ge_util.h" +#include "ge/ge_api_error_codes.h" namespace ge { namespace formats { @@ -65,19 +64,16 @@ class FormatTransferRegister { #define REGISTER_FORMAT_TRANSFER(TransferClass, format1, format2) \ namespace { \ FormatTransferRegister format_transfer_register_##TransferClass##format1##format2( \ - []() { return ge::MakeShared(); }, format1, format2); \ + []() { return std::make_shared(); }, format1, format2); \ } -/** - * Build a formattransfer according to 'args' - * @param args - * @param result - * @return - */ +/// Build a formattransfer according to 'args' +/// @param args +/// @param result +/// @return std::shared_ptr BuildFormatTransfer(const TransArgs &args); bool FormatTransferExists(const TransArgs &args); - } // namespace formats } // namespace ge -#endif // GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_H_ +#endif // INC_REGISTER_REGISTER_FORMAT_TRANSFER_H_ \ No newline at end of file diff --git a/third_party/fwkacllib/inc/runtime/base.h 
b/third_party/fwkacllib/inc/runtime/base.h index a1b3d762..868e16ce 100644 --- a/third_party/fwkacllib/inc/runtime/base.h +++ b/third_party/fwkacllib/inc/runtime/base.h @@ -259,6 +259,50 @@ RTS_API rtError_t rtLabelGoto(rtLabel_t label, rtStream_t stream); * @return RT_ERROR_INVALID_RESOURCE_HANDLE for error input handle */ RTS_API rtError_t rtNameLabel(rtLabel_t label, const char *name); + +/** + * @ingroup dvrt_base + * @brief label switch by index + * @param [in] ptr index value ptr + * @param [in] max index max value + * @param [in] labelInfoPtr label content info ptr + * @param [in] stream set stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_RESOURCE_HANDLE for error input handle + */ +RTS_API rtError_t rtLabelSwitchByIndex(void *ptr, uint32_t max, void *labelInfoPtr, rtStream_t stream); + +/** + * @ingroup dvrt_base + * @brief stream goto label + * @param [in] label goto label + * @param [in] stream stream to submit label_goto task + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_RESOURCE_HANDLE for error input handle + */ +RTS_API rtError_t rtLabelGotoEx(rtLabel_t label, rtStream_t stream); + +/** + * @ingroup dvrt_base + * @brief labels to dev info + * @param [in] label model label list + * @param [in] labelNumber label number + * @param [in] dst device ptr + * @param [in] dstMax dst size + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_RESOURCE_HANDLE for error input handle + */ +RTS_API rtError_t rtLabelListCpy(rtLabel_t *label, uint32_t labelNumber, void *dst, uint32_t dstMax); + +/** + * @ingroup dvrt_base + * @brief labels to dev info + * @param [out] label created label handle + * @param [in] stream label bind stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_RESOURCE_HANDLE for error input handle + */ +RTS_API rtError_t rtLabelCreateEx(rtLabel_t *label, rtStream_t stream); #ifdef __cplusplus } #endif diff --git a/third_party/fwkacllib/inc/runtime/config.h 
b/third_party/fwkacllib/inc/runtime/config.h index e5d5d360..fcdcf2ec 100644 --- a/third_party/fwkacllib/inc/runtime/config.h +++ b/third_party/fwkacllib/inc/runtime/config.h @@ -39,6 +39,8 @@ typedef enum tagRtChipType { CHIP_BEGIN = 0, CHIP_MINI = CHIP_BEGIN, CHIP_CLOUD, + CHIP_MDC, + CHIP_LHISI, CHIP_OTHER_PHN, CHIP_OTHER_OLD, CHIP_END, diff --git a/third_party/fwkacllib/inc/runtime/dev.h b/third_party/fwkacllib/inc/runtime/dev.h index 6f5ff62b..08fa3970 100644 --- a/third_party/fwkacllib/inc/runtime/dev.h +++ b/third_party/fwkacllib/inc/runtime/dev.h @@ -36,9 +36,7 @@ typedef struct tagRTDeviceInfo { uint32_t ai_core_id; uint32_t aicpu_occupy_bitmap; uint32_t hardware_version; -#ifdef DRIVER_NEW_API uint32_t ts_num; -#endif } rtDeviceInfo_t; typedef enum tagRtRunMode { @@ -213,6 +211,13 @@ RTS_API rtError_t rtSetTSDevice(uint32_t tsId); * @return RT_ERROR_DRV_ERR for can not get run mode */ RTS_API rtError_t rtGetRunMode(rtRunMode *mode); + +/** + * @ingroup dvrt_dev + * @brief set chipType + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtSetSocVersion(const char *version); #ifdef __cplusplus } #endif diff --git a/third_party/fwkacllib/inc/runtime/kernel.h b/third_party/fwkacllib/inc/runtime/kernel.h index 1609519f..c99eb96f 100644 --- a/third_party/fwkacllib/inc/runtime/kernel.h +++ b/third_party/fwkacllib/inc/runtime/kernel.h @@ -448,7 +448,7 @@ RTS_API rtError_t rtSubscribeReport(uint64_t threadId, rtStream_t stream); * @param [in] stream subscribed stream * @return RT_ERROR_NONE for ok, errno for failed */ -RTS_API rtError_t rtCallbackLaunch(rtCallback_t callBackFunc, void *fnData, rtStream_t stream); +RTS_API rtError_t rtCallbackLaunch(rtCallback_t callBackFunc, void *fnData, rtStream_t stream, bool isBlock); /** * @ingroup rt_kernel diff --git a/third_party/fwkacllib/inc/runtime/mem.h b/third_party/fwkacllib/inc/runtime/mem.h index 1597c436..27ee26d2 100644 --- a/third_party/fwkacllib/inc/runtime/mem.h +++ 
b/third_party/fwkacllib/inc/runtime/mem.h @@ -37,7 +37,12 @@ extern "C" { #define RT_MEMORY_P2P_HBM ((uint32_t)0x10) // HBM memory on other 4P device #define RT_MEMORY_P2P_DDR ((uint32_t)0x11) // DDR memory on other device #define RT_MEMORY_DDR_NC ((uint32_t)0x20) // DDR memory of non-cache -#define RT_MEMORY_RESERVED ((uint32_t)0x40) +#define RT_MEMORY_TS_4G ((uint32_t)0x40) +#define RT_MEMORY_TS ((uint32_t)0x80) +#define RT_MEMORY_RESERVED ((uint32_t)0x100) + +#define RT_MEMORY_L1 ((uint32_t)0x1<<16) +#define RT_MEMORY_L2 ((uint32_t)0x1<<17) /** * @ingroup dvrt_mem @@ -75,6 +80,8 @@ typedef enum tagRtMemcpyKind { RT_MEMCPY_DEVICE_TO_HOST, // device to host RT_MEMCPY_DEVICE_TO_DEVICE, // device to device, 1P && P2P RT_MEMCPY_MANAGED, // managed memory + RT_MEMCPY_ADDR_DEVICE_TO_DEVICE, + RT_MEMCPY_HOST_TO_DEVICE_EX, // host to device ex (only used for 8 bytes) RT_MEMCPY_RESERVED, } rtMemcpyKind_t; @@ -85,6 +92,8 @@ typedef enum tagRtRecudeKind { typedef enum tagRtDataType { RT_DATA_TYPE_FP32 = 0, // fp32 + RT_DATA_TYPE_FP16 = 1, // fp16 + RT_DATA_TYPE_INT16 = 2, // int16 RT_DATA_TYPE_END } rtDataType_t; diff --git a/third_party/fwkacllib/inc/runtime/rt_model.h b/third_party/fwkacllib/inc/runtime/rt_model.h index 1e03e853..790492fc 100644 --- a/third_party/fwkacllib/inc/runtime/rt_model.h +++ b/third_party/fwkacllib/inc/runtime/rt_model.h @@ -24,38 +24,41 @@ extern "C" { #endif typedef enum tagModelTaskType { - RT_MODEL_TASK_KERNEL = 0, - RT_MODEL_TASK_EVENT_RECORD, - RT_MODEL_TASK_EVENT_WAIT, - RT_MODEL_TASK_FUSION_START, - RT_MODEL_TASK_FUSION_END, - RT_MODEL_TASK_KERNEL_EX, - RT_MODEL_TASK_HCCL, - RT_MODEL_TASK_STREAM_SWITCH, - RT_MODEL_TASK_STREAM_ACTIVE, - RT_MODEL_TASK_LABEL_SET, - RT_MODEL_TASK_LABEL_SWITCH, - RT_MODEL_TASK_LABEL_GOTO, - RT_MODEL_TASK_PROFILER_TRACE, - RT_MODEL_TASK_MEMCPY_ASYNC, - RT_MODEL_TASK_NOTIFY_RECORD, - RT_MODEL_TASK_NOTIFY_WAIT, - RT_MODEL_TASK_REDUCE_ASYNC, - RT_MODEL_TASK_RDMA_SEND, - RT_MODEL_TASK_EVENT_RESET = 18, - 
RT_MODEL_TASK_MODEL_END_GRAPH, - RT_MODEL_TASK_STREAM_SWITCH_N, - RT_MODEL_TASK_RDMA_DB_SEND + RT_MODEL_TASK_KERNEL = 0, + RT_MODEL_TASK_EVENT_RECORD, + RT_MODEL_TASK_EVENT_WAIT, + RT_MODEL_TASK_FUSION_START, + RT_MODEL_TASK_FUSION_END, + RT_MODEL_TASK_KERNEL_EX, + RT_MODEL_TASK_HCCL, + RT_MODEL_TASK_STREAM_SWITCH, + RT_MODEL_TASK_STREAM_ACTIVE, + RT_MODEL_TASK_LABEL_SET, + RT_MODEL_TASK_LABEL_SWITCH, + RT_MODEL_TASK_LABEL_GOTO, + RT_MODEL_TASK_PROFILER_TRACE, + RT_MODEL_TASK_MEMCPY_ASYNC, + RT_MODEL_TASK_NOTIFY_RECORD, + RT_MODEL_TASK_NOTIFY_WAIT, + RT_MODEL_TASK_REDUCE_ASYNC, + RT_MODEL_TASK_RDMA_SEND, + RT_MODEL_TASK_EVENT_RESET = 18, + RT_MODEL_TASK_MODEL_END_GRAPH, + RT_MODEL_TASK_STREAM_SWITCH_N, + RT_MODEL_TASK_RDMA_DB_SEND, + RT_MODEL_TASK_MEMCPY_ADDR_ASYNC, + RT_MODEL_TASK_STREAM_LABEL_SWITCH_BY_INDEX, + RT_MODEL_TASK_STREAM_LABEL_GOTO, } rtModelTaskType_t; -typedef enum tagModelStreamType { - RT_MODEL_HEAD_STREAM = 0, - RT_MODEL_WAIT_ACTIVE_STREAM = 1 +typedef enum tagModelStreamType { + RT_MODEL_HEAD_STREAM = 0, + RT_MODEL_WAIT_ACTIVE_STREAM = 1 } rtModelStreamType_t; typedef enum tagModelQueueFlag { - RT_MODEL_INPUT_QUEUE = 0, - RT_MODEL_OUTPUT_QUEUE = 1 + RT_MODEL_INPUT_QUEUE = 0, + RT_MODEL_OUTPUT_QUEUE = 1 } rtModelQueueFlag_t; #define EXECUTOR_NONE ((uint32_t)0x0) @@ -67,177 +70,198 @@ typedef enum tagModelQueueFlag { * @brief the type defination of aicpu model task command */ typedef enum tagTsAicpuModelCmd { - TS_AICPU_MODEL_LOAD = 1, - TS_AICPU_MODEL_EXECUTE, - TS_AICPU_MODEL_DESTROY, - TS_AICPU_MODEL_ABORT, - TS_AICPU_MODEL_RESERVED, + TS_AICPU_MODEL_LOAD = 1, + TS_AICPU_MODEL_EXECUTE, + TS_AICPU_MODEL_DESTROY, + TS_AICPU_MODEL_ABORT, + TS_AICPU_MODEL_RESERVED, } tsAicpuModelCmd; typedef struct tagAicpuTaskInfo { - uint32_t taskID; - uint32_t streamID; - uint32_t kernelType; - uint64_t kernelName; - uint64_t kernelSo; - uint64_t paraBase; - uint32_t taskFlag; + uint32_t taskID; + uint32_t streamID; + uint32_t kernelType; + uint64_t kernelName; + 
uint64_t kernelSo; + uint64_t paraBase; + uint32_t taskFlag; } rtAicpuTaskInfo_t; typedef struct tagModelStreamInfo { - uint32_t streamID; - uint32_t streamFlag; + uint32_t streamID; + uint32_t streamFlag; } rtModelStreamInfo_t; typedef struct tagModelQueueInfo { - uint32_t queueID; - uint32_t flag; + uint32_t queueID; + uint32_t flag; } rtModelQueueInfo_t; typedef struct tagAicpuModelInfo { - uint32_t moduleID; - uint32_t tsId; - uint16_t streamInfoNum; - uint16_t aicpuTaskNum; - uint64_t streamInfoPtr; - uint64_t aicpuTaskPtr; - uint16_t queueSize; - uint64_t queueInfoPtr; + uint32_t moduleID; + uint32_t tsId; + uint16_t streamInfoNum; + uint16_t aicpuTaskNum; + uint64_t streamInfoPtr; + uint64_t aicpuTaskPtr; + uint16_t queueSize; + uint64_t queueInfoPtr; } rtAicpuModelInfo_t; typedef struct tagKernelTaskInfo { - uint16_t blockDim; - uint16_t argsCount; - uint16_t argsSize; - uint16_t reserved; - char *stubFunc; - uint8_t *smDesc; - uint8_t *args; - uint16_t *argsOffset; + uint16_t blockDim; + uint16_t argsCount; + uint16_t argsSize; + uint16_t reserved; + char *stubFunc; + uint8_t *smDesc; + uint8_t *args; + uint16_t *argsOffset; } rtKernelTaskInfo_t; typedef struct tagKernelTaskInfoEx { - uint32_t flags; - uint32_t argsSize; - void *args; - uint32_t reserved[6]; + uint32_t flags; + uint32_t argsSize; + void *args; + uint32_t reserved[6]; } rtKernelTaskInfoEx_t; typedef struct tagEventTaskInfo { - uint32_t eventID; - uint32_t reserved[9]; + uint32_t eventID; + uint32_t reserved[9]; } rtEventTaskInfo_t; typedef struct tagStreamSwitchTaskInfo { - int64_t value; - uint64_t pValuePtr; - uint32_t trueStreamID; - uint32_t dataType; - uint32_t reserved[4]; + int64_t value; + uint64_t pValuePtr; + uint32_t trueStreamID; + uint32_t dataType; + uint32_t reserved[4]; } rtStreamSwitchTaskInfo_t; typedef struct tagStreamSwitchNTaskInfo { - uint64_t pValuePtr; - uint64_t pTrueStreamPtr; - uint32_t size; - uint32_t elementSize; - uint32_t dataType; - uint32_t reserved[3]; + 
uint64_t pValuePtr; + uint64_t pTrueStreamPtr; + uint32_t size; + uint32_t elementSize; + uint32_t dataType; + uint32_t reserved[3]; } rtStreamSwitchNTaskInfo_t; typedef struct tagStreamActiveTaskInfo { - uint32_t activeStreamID; - uint32_t reserved[9]; + uint32_t activeStreamID; + uint32_t reserved[9]; } rtStreamActiveTaskInfo_t; typedef struct tagSetTaskInfo { - uint16_t labelId; - uint32_t reserved[9]; + uint16_t labelId; + uint32_t reserved[9]; } rtLabelSetTaskInfo_t; typedef struct tagSwitchTaskInfo { - uint32_t value; - uint32_t reserved[9]; + uint32_t value; + uint32_t reserved[9]; } rtLabelSwitchTaskInfo_t; typedef struct tagLabelGotoTaskInfo { - uint16_t labelId; - uint32_t reserved[9]; + uint16_t labelId; + uint32_t reserved[9]; } rtLabelGotoTaskInfo_t; typedef struct tagProfilerTraceTaskInfo { - uint64_t profilerTraceId; - uint32_t notify : 8; - uint32_t reserved_ : 24; - uint32_t flags; - uint32_t reserved[6]; + uint64_t profilerTraceId; + uint32_t notify : 8; + uint32_t reserved_ : 24; + uint32_t flags; + uint32_t reserved[6]; } rtProfilerTrace_t; typedef struct tagrtMemcpyAsyncTaskInfo { - void *dst; - uint64_t destMax; - void *src; - uint64_t count; - uint32_t kind; - uint32_t reserved; + void *dst; + uint64_t destMax; + void *src; + uint64_t count; + uint32_t kind; + uint32_t reserved; } rtMemcpyAsyncTaskInfo_t; typedef struct tagrtNotifyTaskInfo { - uint32_t notifyID; - uint32_t reserved[9]; + uint32_t notifyID; + uint32_t reserved[9]; } rtNotifyTaskInfo_t; typedef struct tagrtReduceAsyncTaskInfo { - void *dst; - uint64_t destMax; - void *src; - uint64_t count; - uint32_t kind; - uint32_t type; + void *dst; + uint64_t destMax; + void *src; + uint64_t count; + uint32_t kind; + uint32_t type; } rtReduceAsyncTaskInfo_t; typedef struct tagrtRdmaSendTaskInfo { - uint32_t index; - uint32_t wqe_index; - uint32_t reserved[8]; + uint32_t index; + uint32_t wqe_index; + uint32_t reserved[8]; } rtRdmaSendTaskInfo_t; typedef struct tagrtRdmaDbSendTaskInfo { - 
uint64_t dbInfo; - uint32_t dbIndex; - uint32_t reserved[7]; // offset 7 + uint64_t dbInfo; + uint32_t dbIndex; + uint32_t reserved[7]; // offset 7 } rtRdmaDbSendTaskInfo_t; typedef struct tagrtModelEndGraphTaskInfo { - uint32_t modelId; - uint32_t executorFlag; - uint32_t reserved[8]; + uint32_t modelId; + uint32_t executorFlag; + uint32_t reserved[8]; } rtModelEndGraphTaskInfo_t; +typedef struct tagrtStreamLabelSwitchByIndexTask_t { + uint64_t indexPtr; + uint64_t labelInfoPtr; + uint32_t max; + uint8_t reserved[20]; +} rtStreamLabelSwitchByIndexTask_t; + +typedef struct tagrtStreamLabelGotoTask_t { + uint16_t labelId; + uint16_t modelId; + uint8_t reserved[36]; +} rtStreamLabelGotoTask_t; + typedef struct tagTaskInfo { - uint32_t type; - uint32_t streamID; - union { - rtKernelTaskInfoEx_t kernelTaskEx; - rtKernelTaskInfo_t kernelTask; - rtEventTaskInfo_t eventTask; - rtStreamSwitchTaskInfo_t streamSwitchTask; - rtStreamActiveTaskInfo_t streamActiveTask; - rtLabelSetTaskInfo_t labelSetTask; - rtLabelSwitchTaskInfo_t labelSwitchTask; - rtLabelGotoTaskInfo_t labelGotoTask; - rtProfilerTrace_t profilertraceTask; - rtMemcpyAsyncTaskInfo_t memcpyAsyncTask; - rtNotifyTaskInfo_t notifyTask; - rtReduceAsyncTaskInfo_t reduceAsyncTask; - rtRdmaSendTaskInfo_t rdmaSendTask; - rtRdmaDbSendTaskInfo_t rdmaDbSendTask; - rtModelEndGraphTaskInfo_t modelEndGraphTask; - rtStreamSwitchNTaskInfo_t streamSwitchNTask; - uint32_t reserved[10]; - } u; + uint32_t type; + uint32_t streamID; + union { + rtKernelTaskInfoEx_t kernelTaskEx; + rtKernelTaskInfo_t kernelTask; + rtEventTaskInfo_t eventTask; + rtStreamSwitchTaskInfo_t streamSwitchTask; + rtStreamActiveTaskInfo_t streamActiveTask; + rtLabelSetTaskInfo_t labelSetTask; + rtLabelSwitchTaskInfo_t labelSwitchTask; + rtLabelGotoTaskInfo_t labelGotoTask; + rtProfilerTrace_t profilertraceTask; + rtMemcpyAsyncTaskInfo_t memcpyAsyncTask; + rtNotifyTaskInfo_t notifyTask; + rtReduceAsyncTaskInfo_t reduceAsyncTask; + rtRdmaSendTaskInfo_t 
rdmaSendTask; + rtRdmaDbSendTaskInfo_t rdmaDbSendTask; + rtModelEndGraphTaskInfo_t modelEndGraphTask; + rtStreamSwitchNTaskInfo_t streamSwitchNTask; + rtStreamLabelSwitchByIndexTask_t streamLabelSwitchIndexTask; + rtStreamLabelGotoTask_t streamLabelGotoTask; + uint32_t reserved[10]; + } u; } rtTaskInfo_t; +typedef struct tagLabelDevInfo_t { + uint16_t modelId; + uint16_t streamId; + uint16_t labelId; +}rtLabelDevInfo; + typedef void *rtModel_t; typedef rtError_t (*rtTaskGenCallback)(rtModel_t model, rtTaskInfo_t *taskInfo); @@ -311,11 +335,12 @@ RTS_API rtError_t rtModelExecute(rtModel_t model, rtStream_t stream, uint32_t fl * @ingroup rt_model * @brief get model the last persist task id * @param [in] model model to execute - * @param [out] taskid task id of the model + * @param [out] taskid last task id of the model + * @param [out] streamid last steam id of the model * @return RT_ERROR_NONE for ok * @return RT_ERROR_INVALID_VALUE for error input handle */ -RTS_API rtError_t rtModelGetTaskId(rtModel_t model, uint32_t *taskid); +RTS_API rtError_t rtModelGetTaskId(rtModel_t model, uint32_t *taskid, uint32_t *streamid); /** * @ingroup rt_model diff --git a/third_party/fwkacllib/inc/runtime/stream.h b/third_party/fwkacllib/inc/runtime/stream.h index 83bb4b63..232b5169 100644 --- a/third_party/fwkacllib/inc/runtime/stream.h +++ b/third_party/fwkacllib/inc/runtime/stream.h @@ -36,6 +36,13 @@ extern "C" { #define RT_STREAM_FORBIDDEN_DEFAULT (0x10) #define RT_STREAM_HEAD (0x20) +/** + * @ingroup stream_type + * @brief stream type + */ +#define RT_NORMAL_STREAM (0x00) +#define RT_HUGE_STREAM (0x01) + /** * priority level default value when create a stream */ @@ -114,12 +121,13 @@ RTS_API rtError_t rtGetStreamId(rtStream_t stream, int32_t *streamId); /** * @ingroup dvrt_stream * @brief inquire max stream count and max task count per stream + * @param [in] streamType Stream Type * @param [in] MaxStrCount Max stream count * @param [in] MaxTaskCount max task count per stream * 
@return RT_ERROR_NONE for complete * @return RT_ERROR_INVALID_RESOURCE_HANDLE for error input stream handle */ -RTS_API rtError_t rtGetMaxStreamAndTask(uint32_t *MaxStrCount, uint32_t *MaxTaskCount); +RTS_API rtError_t rtGetMaxStreamAndTask(uint32_t streamType, uint32_t *MaxStrCount, uint32_t *MaxTaskCount); /** * @ingroup dvrt_stream diff --git a/third_party/fwkacllib/inc/tdt/data_common.h b/third_party/fwkacllib/inc/tdt/data_common.h index 14145a60..da9881ff 100644 --- a/third_party/fwkacllib/inc/tdt/data_common.h +++ b/third_party/fwkacllib/inc/tdt/data_common.h @@ -72,5 +72,27 @@ struct DataItem { uint64_t dataLen_; /**< Input data type length*/ std::shared_ptr dataPtr_; /**< Data pointer*/ }; + +/** + * @ingroup Tsdclient. + * + * tsdclient func type; + */ +enum TsdCmdType { + TSDCLOSE = 0, + TSDOPEN = 1 +}; + +/** + * @ingroup Tsdclient. + * + * tsdclient func input value object. + */ +enum InputItem { + OPEN_DEVICEID = 0, + OPEN_RANKSIZE, + CLOSE_DEVICEID +}; + } // namespace tdt #endif // HOST_INNER_INC_DATA_COMMON_H_ diff --git a/third_party/fwkacllib/inc/tdt/status.h b/third_party/fwkacllib/inc/tdt/status.h index 50a656c9..ec624b35 100644 --- a/third_party/fwkacllib/inc/tdt/status.h +++ b/third_party/fwkacllib/inc/tdt/status.h @@ -191,6 +191,7 @@ enum { TDT_HDC_SRV_TYPE_ERROR_CODE, TDT_TSD_CLT_OPEN_FAILED_CODE, TDT_TSD_CLT_CLOSE_FAILED_CODE, + TDT_TSD_CLT_INTERFACE_NOT_SUPPORT_CODE, TDT_SUPERVISOR_ILLEGAL_HEARTBEAT_TIME_CODE, TDT_SUPERVISOR_INOTIFY_READ_SIZE_ERROR_CODE, TDT_SUPERVISOR_INOTIFY_INTERRUPT_CODE, @@ -685,6 +686,7 @@ TDT_DEF_ERROR_CODE(MODID_HDC_SERVER, TDT_ERROR, TDT_BIND_CPUCORE_FAILED, "thread TDT_DEF_ERROR_CODE(MODID_HDC_SERVER, TDT_ERROR, TDT_HDC_SRV_CLOSED, "hdc server has been closed"); TDT_DEF_ERROR_CODE(MODID_TSD_CLIENT, TDT_ERROR, TDT_TSD_CLT_OPEN_FAILED, "tsd client open failed"); TDT_DEF_ERROR_CODE(MODID_TSD_CLIENT, TDT_ERROR, TDT_TSD_CLT_CLOSE_FAILED, "tsd client close failed"); +TDT_DEF_ERROR_CODE(MODID_TSD_CLIENT, TDT_ERROR, 
TDT_TSD_CLT_INTERFACE_NOT_SUPPORT, "tsd client func not support"); TDT_DEF_ERROR_CODE(MODID_TDT_PREFETCH, TDT_ERROR, TDT_PREFETCH_FILELIST_NOT_EXIST, "tdt filelist open failed"); TDT_DEF_ERROR_CODE(MODID_TDT_PREFETCH, TDT_ERROR, TDT_PREFETCH_SAMPLE_FILE_NOT_FOUND, "tdt sample file is empty"); TDT_DEF_ERROR_CODE(MODID_TDT_PREFETCH, TDT_ERROR, TDT_PREFETCH_FILE_OPEN_FAIL, "tdt open sample file fail"); diff --git a/third_party/fwkacllib/inc/tdt/tdt_host_interface.h b/third_party/fwkacllib/inc/tdt/tdt_host_interface.h index 821ee819..0e62a85c 100644 --- a/third_party/fwkacllib/inc/tdt/tdt_host_interface.h +++ b/third_party/fwkacllib/inc/tdt/tdt_host_interface.h @@ -80,6 +80,24 @@ int32_t TdtHostPushData(const std::string &channelName, const std::vector #include #include "tdt/status.h" +#include "tdt/data_common.h" #ifdef __cplusplus extern "C" { @@ -68,6 +69,87 @@ TDT_StatusT TsdOpen(const uint32_t phyDeviceId, const uint32_t rankSize); */ TDT_StatusT TsdClose(const uint32_t phyDeviceId); +/** +* @ingroup CreateCmdParameterObj +* @brief creat tsdclient func parameter obj. +* +* @par Function +* creat tsdclient func parameter obj. +* +* @param type [IN] type tdt::TsdCmdType, tsd func type. +* @param cmdParameterObj [IN] type void *, func parameter obj. +* @retval TDT_OK Success +* @retval TDT_INTERFACE_NOT_SUPPORT +* +* @par Dependency +* @li libtsdclient.so: Library to which the interface belongs. +* @li data_common.h: Header file where tdt::TsdCmdType and tdt::InputItem defined. +* @li status.h: Header file where 'TDT_StatusT' defined +*/ +TDT_StatusT CreateCmdParameterObj(tdt::TsdCmdType type, void **cmdParameterObj); + +/** +* @ingroup SetCmdParameterObjAttribute +* @brief set cmdParameterObj input value. +* +* @par Function +* set cmdParameterObj input value. +* +* @param type [IN] type tdt::TsdCmdType, tsd func type. +* @param cmdParameterObj [IN] type void *, func parameter obj. +* @param itemType [IN] type tdt::InputItem, func input type. 
+* @param valuePtr [IN] type const void *, input value. +* @param valueLength [IN] type int, input value length. +* @retval TDT_OK Success +* @retval TDT_INTERFACE_NOT_SUPPORT +* +* @par Dependency +* @li libtsdclient.so: Library to which the interface belongs. +* @li data_common.h: Header file where tdt::TsdCmdType and tdt::InputItem defined. +* @li status.h: Header file where 'TDT_StatusT' defined +*/ +TDT_StatusT SetCmdParameterObjAttribute(tdt::TsdCmdType type, void *cmdParameterObj, tdt::InputItem itemType, const void *valuePtr, int valueLength); + +/** +* @ingroup GetCmdParameterObjAttribute +* @brief set cmdParameterObj input value. +* +* @par Function +* set cmdParameterObj input value. +* +* @param type [IN] type tdt::TsdCmdType, tsd func type. +* @param cmdParameterObj [IN] type void *, func parameter obj. +* @param itemType [IN] type tdt::InputItem, func input type. +* @param valuePtr [IN] type const void *, input value. +* @param valueLength [IN] type int, input value length. +* @retval TDT_OK Success +* @retval TDT_INTERFACE_NOT_SUPPORT +* +* @par Dependency +* @li libtsdclient.so: Library to which the interface belongs. +* @li data_common.h: Header file where tdt::TsdCmdType and tdt::InputItem defined. +* @li status.h: Header file where 'TDT_StatusT' defined +*/ +TDT_StatusT GetCmdParameterObjAttribute(tdt::TsdCmdType type, void *cmdParameterObj, tdt::InputItem itemType, void *valuePtr, int &valueLength); + +/** +* @ingroup TsdClientCmd +* @brief creat tsdclient func parameter obj. +* +* @par Function +* creat tsdclient func parameter obj. +* +* @param type [IN] type tdt::TsdCmdType, tsd func type. +* @param cmdParameterObj [IN] type void *, func parameter obj. +* @retval TDT_OK Success +* @retval TDT_INTERFACE_NOT_SUPPORT +* +* @par Dependency +* @li libtsdclient.so: Library to which the interface belongs. +* @li data_common.h: Header file where tdt::TsdCmdType and tdt::InputItem defined. 
+* @li status.h: Header file where 'TDT_StatusT' defined +*/ +TDT_StatusT TsdClientCmd(tdt::TsdCmdType cmd, void *cmdParameterObj); namespace tdt { /** diff --git a/third_party/fwkacllib/inc/toolchain/slog.h b/third_party/fwkacllib/inc/toolchain/slog.h index 1fb9aff2..f77df225 100644 --- a/third_party/fwkacllib/inc/toolchain/slog.h +++ b/third_party/fwkacllib/inc/toolchain/slog.h @@ -168,6 +168,7 @@ enum { DSS, PROCMGR, // Process Manager, Base Platform BBOX, + AIVECTOR, INVLID_MOUDLE_ID }; @@ -241,6 +242,7 @@ static DCODE g_moduleIdName[] = {SET_MOUDLE_ID_MAP_NAME(SLOG), SET_MOUDLE_ID_MAP_NAME(DSS), SET_MOUDLE_ID_MAP_NAME(PROCMGR), SET_MOUDLE_ID_MAP_NAME(BBOX), + SET_MOUDLE_ID_MAP_NAME(AIVECTOR), { NULL, -1 }}; #endif // MODULE_ID_NAME @@ -253,22 +255,33 @@ extern void dlog_init(void); /** * @ingroup slog - * @brief dlog_getlevel: get module level + * @brief dlog_getlevel: get module loglevel and enableEvent * - * @param [in]moduleId: module id, eg: CCE + * @param [in]moduleId: moudule id(see slog.h, eg: CCE), others: invalid * @param [out]enableEvent: 1: enable; 0: disable * @return: module level(0: debug, 1: info, 2: warning, 3: error, 4: null output) */ extern int dlog_getlevel(int moduleId, int *enableEvent); /** -* @ingroup slog -* @brief CheckLogLevel: check module level enable or not -* -* @param [in]moduleId: module id, eg: CCE -* @param [in]logLevel: eg: DLOG_EVENT/DLOG_ERROR/DLOG_WARN/DLOG_INFO/DLOG_DEBUG -* @return: 1:enable, 0:disable -*/ + * @ingroup slog + * @brief dlog_setlevel: set module loglevel and enableEvent + * + * @param [in]moduleId: moudule id(see slog.h, eg: CCE), -1: all modules, others: invalid + * @param [in]level: log level(0: debug, 1: info, 2: warning, 3: error, 4: null output) + * @param [in]enableEvent: 1: enable; 0: disable, others:invalid + * @return: 0: SUCCEED, others: FAILED + */ +extern int dlog_setlevel(int moduleId, int level, int enableEvent); + +/** + * @ingroup slog + * @brief CheckLogLevel: check module level enable or 
not + * + * @param [in]moduleId: module id, eg: CCE + * @param [in]logLevel: eg: DLOG_EVENT/DLOG_ERROR/DLOG_WARN/DLOG_INFO/DLOG_DEBUG + * @return: 1:enable, 0:disable + */ extern int CheckLogLevel(int moduleId, int logLevel); /**