From 828eb025030e66608cadd46495c75b75b3f517d6 Mon Sep 17 00:00:00 2001 From: yanghaoran Date: Sun, 17 May 2020 11:20:17 +0800 Subject: [PATCH] Update GraphEngine to synchronize with latest Ascend driver software suite 17 May 2020 --- CMakeLists.txt | 10 +- build.sh | 70 +- cmake/external_libs/securec.cmake | 11 - inc/common/util/compress/compress.h | 36 - inc/external/ge/ge_api_types.h | 2 - inc/external/graph/inference_context.h | 2 +- inc/external/register/register.h | 22 + inc/framework/common/debug/ge_log.h | 16 +- inc/framework/common/debug/log.h | 22 +- inc/framework/common/ge_inner_error_codes.h | 1 + inc/framework/common/gflags_util.h | 4 +- inc/framework/common/helper/model_helper.h | 5 +- inc/framework/common/helper/om_file_helper.h | 6 +- inc/framework/common/l2_cache_optimize.h | 4 +- inc/framework/common/op/attr_value_util.h | 12 +- inc/framework/common/op/ge_op_utils.h | 2 - inc/framework/common/op/op_parser_util.h | 4 +- inc/framework/common/op_types.h | 4 +- inc/framework/common/scope_guard.h | 6 +- inc/framework/common/string_util.h | 4 +- inc/framework/common/types.h | 35 +- inc/framework/common/util.h | 4 +- inc/framework/omg/omg_inner_types.h | 20 +- inc/framework/omg/version.h | 4 +- inc/graph/debug/ge_attr_define.h | 314 +-- inc/graph/detail/model_serialize_imp.h | 2 +- inc/graph/model.h | 4 + inc/graph/usr_types.h | 6 +- inc/graph/utils/graph_utils.h | 2 - src/common/graph/CMakeLists.txt | 1 + src/common/graph/anchor.cc | 1 + src/common/graph/compute_graph.cc | 37 +- src/common/graph/format_refiner.cc | 1 + src/common/graph/ge_attr_define.cc | 223 +- src/common/graph/shape_refiner.cc | 3 +- src/common/graph/utils/graph_utils.cc | 57 - src/common/graph/utils/tensor_utils.cc | 1 - src/ge/CMakeLists.txt | 10 +- src/ge/client/CMakeLists.txt | 1 + src/ge/client/ge_api.cc | 2 + src/ge/common/CMakeLists.txt | 2 +- src/ge/common/auth/file_saver.cc | 5 +- src/ge/common/auth/file_saver.h | 52 +- src/ge/common/context/ctx.cc | 1 - .../format_transfers/datatype_transfer.cc | 11 +- .../format_transfer_c1hwncoc0_hwcn.cc | 21 +- .../format_transfer_dhwcn_fracz3D.cc | 5 - ...format_transfer_dhwnc_fracz3D_transpose.cc | 5 - .../format_transfer_fractal_nz.cc | 10 - .../format_transfer_fractal_z.cc | 15 - .../format_transfer_fractal_zz.cc | 10 - .../format_transfer_fracz_hwcn.cc | 6 - .../format_transfer_fracz_nchw.cc | 7 - .../format_transfer_fracz_nhwc.cc | 6 - .../format_transfer_hwcn_c1hwncoc0.cc | 24 +- .../format_transfer_nc1hwc0_nchw.cc | 8 +- .../format_transfer_nc1hwc0_nhwc.cc | 8 +- .../format_transfer_nchw_fz_c04.cc | 314 --- .../format_transfer_nchw_fz_c04.h | 35 - .../format_transfer_nchw_nc1hwc0.cc | 54 +- .../format_transfer_nhwc_nc1hwc0.cc | 8 +- .../format_transfer_transpose.cc | 10 +- src/ge/common/formats/formats.cc | 12 +- .../formats/utils/formats_trans_utils.cc | 6 +- .../formats/utils/formats_trans_utils.h | 3 - src/ge/common/fp16_t.h | 2 +- src/ge/common/ge/plugin_manager.cc | 5 +- src/ge/common/helper/model_cache_helper.cc | 1707 ------------- src/ge/common/helper/model_cache_helper.h | 121 - src/ge/common/helper/model_helper.cc | 6 +- src/ge/common/helper/om_file_helper.cc | 5 +- src/ge/common/math_util.h | 4 +- src/ge/common/model_parser/base.cc | 8 +- src/ge/common/model_saver.cc | 2 +- src/ge/common/op/attr_value_util.cc | 4 +- src/ge/common/op/ge_op_utils.cc | 6 +- src/ge/common/profiling/profiling_manager.cc | 179 +- src/ge/common/profiling/profiling_manager.h | 7 +- src/ge/common/properties_manager.cc | 2 +- src/ge/common/types.cc | 386 +-- src/ge/common/util.cc | 27 +- src/ge/executor/CMakeLists.txt | 2 +- src/ge/executor/ge_executor.cc | 17 +- src/ge/ge_local_engine/CMakeLists.txt | 1 + .../ge_local_engine/engine/host_cpu_engine.cc | 2 +- .../ge_local_engine/engine/host_cpu_engine.h | 2 +- .../ge_local_ops_kernel_info.cc | 2 + src/ge/ge_runtime/CMakeLists.txt | 1 + src/ge/ge_runtime/runtime_model.cc | 7 +- src/ge/generator/ge_generator.cc | 11 +- src/ge/generator/generator_api.cc | 2 +- src/ge/graph/build/graph_builder.cc | 15 +- src/ge/graph/build/graph_builder.h | 2 +- .../graph/build/logical_stream_allocator.cc | 287 +-- src/ge/graph/build/logical_stream_allocator.h | 26 +- src/ge/graph/build/memory/CMakeLists.txt | 1 + .../build/memory/binary_block_mem_assigner.cc | 6 +- .../graph/build/memory/block_mem_assigner.cc | 28 +- .../graph/build/memory/graph_mem_assigner.cc | 26 +- .../graph/build/memory/var_mem_assign_util.cc | 40 +- src/ge/graph/build/model_builder.cc | 86 +- src/ge/graph/build/model_builder.h | 4 +- src/ge/graph/build/run_context.cc | 2 + src/ge/graph/build/stream_allocator.cc | 19 +- src/ge/graph/build/stream_allocator.h | 4 +- src/ge/graph/build/stream_graph_optimizer.cc | 125 +- src/ge/graph/build/stream_graph_optimizer.h | 4 +- src/ge/graph/build/task_generator.cc | 29 +- src/ge/graph/build/task_generator.h | 2 +- src/ge/graph/common/omg_util.cc | 8 +- src/ge/graph/common/transop_util.cc | 4 +- src/ge/graph/execute/graph_execute.cc | 1 - src/ge/graph/label/case_label_maker.cc | 19 +- src/ge/graph/label/if_label_maker.cc | 22 +- src/ge/graph/label/label_maker.cc | 141 +- src/ge/graph/label/label_maker.h | 5 - .../label/partitioned_call_label_maker.cc | 3 + src/ge/graph/label/while_label_maker.cc | 6 +- src/ge/graph/load/graph_loader.cc | 4 +- .../new_model_manager/cpu_queue_schedule.cc | 99 +- .../new_model_manager/cpu_queue_schedule.h | 21 +- .../load/new_model_manager/data_dumper.cc | 7 +- .../load/new_model_manager/data_dumper.h | 4 +- .../load/new_model_manager/davinci_model.cc | 823 ++++--- .../load/new_model_manager/davinci_model.h | 100 +- .../new_model_manager/davinci_model_parser.cc | 4 +- .../load/new_model_manager/model_manager.cc | 63 +- .../load/new_model_manager/model_manager.h | 1 - .../load/new_model_manager/model_utils.cc | 31 +- .../load/new_model_manager/model_utils.h | 7 + .../task_info/hccl_task_info.cc | 20 +- .../task_info/kernel_ex_task_info.cc | 24 +- .../task_info/kernel_ex_task_info.h | 1 - .../task_info/kernel_task_info.cc | 74 +- .../task_info/memcpy_addr_async_task_info.cc | 149 -- .../task_info/memcpy_addr_async_task_info.h | 55 - .../task_info/memcpy_async_task_info.cc | 3 +- .../task_info/stream_switch_task_info.cc | 8 +- .../task_info/super_kernel/super_kernel.cc | 12 +- .../task_info/super_kernel/super_kernel.h | 15 +- .../super_kernel/super_kernel_factory.cc | 53 +- .../super_kernel/super_kernel_factory.h | 2 +- src/ge/graph/load/output/output.h | 3 +- src/ge/graph/manager/graph_manager.cc | 217 +- src/ge/graph/manager/graph_manager.h | 13 +- src/ge/graph/manager/graph_manager_utils.cc | 10 +- src/ge/graph/manager/graph_var_manager.cc | 84 +- src/ge/graph/manager/graph_var_manager.h | 19 +- src/ge/graph/manager/util/debug.cc | 2 +- src/ge/graph/manager/util/hcom_util.cc | 8 + .../manager/util/variable_accelerate_ctrl.cc | 11 +- src/ge/graph/optimize/common/params.h | 4 + src/ge/graph/optimize/graph_optimize.cc | 4 +- src/ge/graph/partition/graph_partition.cc | 28 +- src/ge/graph/passes/addn_pass.cc | 2 +- .../passes/aicpu_constant_folding_pass.cc | 4 +- .../passes/aicpu_constant_folding_pass.h | 1 - src/ge/graph/passes/assert_pass.cc | 6 +- src/ge/graph/passes/atomic_addr_clean_pass.cc | 18 +- src/ge/graph/passes/cast_remove_pass.cc | 2 + src/ge/graph/passes/cast_translate_pass.cc | 6 +- src/ge/graph/passes/compile_nodes_pass.h | 3 - .../graph/passes/constant_fuse_same_pass.cc | 3 + src/ge/graph/passes/control_op_attr_pass.cc | 6 + src/ge/graph/passes/control_trigger_pass.cc | 13 + src/ge/graph/passes/control_trigger_pass.h | 2 +- src/ge/graph/passes/dropout_pass.cc | 2 +- src/ge/graph/passes/end_graph_pass.cc | 7 +- src/ge/graph/passes/enter_pass.cc | 14 +- src/ge/graph/passes/flow_ctrl_pass.cc | 28 +- .../graph/passes/folding_kernel/add_kernel.cc | 1 + .../graph/passes/folding_kernel/add_kernel.h | 2 +- .../folding_kernel/broadcast_args_kernel.cc | 2 + .../broadcast_gradient_args_kernel.cc | 13 +- .../passes/folding_kernel/cast_kernel.cc | 14 +- .../folding_kernel/concat_offset_kernel.cc | 4 +- .../passes/folding_kernel/concat_v2_kernel.cc | 2 + .../folding_kernel/dynamic_stitch_kernel.cc | 7 +- .../passes/folding_kernel/empty_kernel.cc | 2 + .../folding_kernel/expanddims_kernel.cc | 5 +- .../passes/folding_kernel/fill_kernel.cc | 1 + .../passes/folding_kernel/floordiv_kernel.cc | 2 + .../passes/folding_kernel/floordiv_kernel.h | 2 +- .../passes/folding_kernel/floormod_kernel.cc | 2 + .../passes/folding_kernel/gather_v2_kernel.cc | 5 +- .../passes/folding_kernel/greater_kernel.cc | 1 + .../passes/folding_kernel/kernel_utils.cc | 18 +- .../passes/folding_kernel/kernel_utils.h | 5 +- .../passes/folding_kernel/maximum_kernel.cc | 1 + .../graph/passes/folding_kernel/mul_kernel.cc | 2 + .../passes/folding_kernel/pack_kernel.cc | 30 +- .../passes/folding_kernel/permute_kernel.cc | 7 + .../passes/folding_kernel/range_kernel.cc | 2 + .../passes/folding_kernel/rank_kernel.cc | 1 + .../folding_kernel/reduce_prod_kernel.cc | 33 +- .../folding_kernel/reduce_prod_kernel.h | 2 +- .../passes/folding_kernel/reformat_kernel.cc | 2 + .../passes/folding_kernel/reshape_kernel.cc | 2 + .../passes/folding_kernel/rsqrt_kernel.cc | 4 + .../passes/folding_kernel/shape_kernel.cc | 2 + .../passes/folding_kernel/shape_n_kernel.cc | 2 + .../passes/folding_kernel/size_kernel.cc | 3 +- .../passes/folding_kernel/slice_d_kernel.cc | 1 + .../passes/folding_kernel/slice_kernel.cc | 2 + .../passes/folding_kernel/squeeze_kernel.cc | 2 + .../folding_kernel/ssd_prior_box_kernel.cc | 15 +- .../folding_kernel/strided_slice_kernel.cc | 7 + .../graph/passes/folding_kernel/sub_kernel.cc | 1 + .../passes/folding_kernel/transdata_kernel.cc | 10 +- .../passes/folding_kernel/unpack_kernel.cc | 2 + src/ge/graph/passes/folding_pass.cc | 16 +- .../graph/passes/get_original_format_pass.cc | 12 +- src/ge/graph/passes/guarantee_const_pass.cc | 2 + src/ge/graph/passes/hccl_memcpy_pass.cc | 13 +- src/ge/graph/passes/identity_pass.cc | 7 +- .../graph/passes/isolated_op_remove_pass.cc | 3 + src/ge/graph/passes/iterator_op_pass.cc | 2 + .../graph/passes/link_gen_mask_nodes_pass.cc | 10 +- .../graph/passes/link_gen_mask_nodes_pass.h | 4 - src/ge/graph/passes/merge_pass.cc | 15 +- src/ge/graph/passes/multi_batch_pass.cc | 8 + src/ge/graph/passes/multi_batch_pass.h | 2 +- src/ge/graph/passes/net_output_pass.cc | 8 + src/ge/graph/passes/next_iteration_pass.cc | 8 + .../passes/no_use_reshape_remove_pass.cc | 2 +- src/ge/graph/passes/pass_manager.cc | 2 + src/ge/graph/passes/pass_utils.cc | 12 +- src/ge/graph/passes/pass_utils.h | 1 - src/ge/graph/passes/permute_pass.cc | 9 + .../passes/placeholder_with_default_pass.cc | 2 + src/ge/graph/passes/prevent_gradient_pass.cc | 2 + src/ge/graph/passes/print_op_pass.h | 2 +- src/ge/graph/passes/prune_pass.cc | 4 + .../passes/replace_with_empty_const_pass.cc | 156 -- .../passes/replace_with_empty_const_pass.h | 34 - src/ge/graph/passes/reshape_remove_pass.cc | 2 +- .../same_transdata_breadth_fusion_pass.cc | 7 + .../passes/shape_operate_op_remove_pass.cc | 1 + src/ge/graph/passes/snapshot_pass.cc | 2 + src/ge/graph/passes/stop_gradient_pass.cc | 2 + .../graph/passes/switch_logic_remove_pass.cc | 2 +- src/ge/graph/passes/switch_op_pass.cc | 43 +- src/ge/graph/passes/switch_op_pass.h | 4 +- src/ge/graph/passes/switch_pass.cc | 4 + .../passes/transop_breadth_fusion_pass.cc | 9 + .../graph/passes/transop_depth_fusion_pass.cc | 9 + .../transop_nearby_allreduce_fusion_pass.cc | 4 +- .../transop_without_reshape_fusion_pass.cc | 8 + .../graph/passes/transpose_transdata_pass.cc | 4 + src/ge/graph/passes/unused_const_pass.cc | 4 +- src/ge/graph/passes/unused_op_remove_pass.cc | 6 + .../passes/var_is_initialized_op_pass.cc | 6 +- src/ge/graph/passes/variable_format_pass.cc | 6 +- src/ge/graph/passes/variable_op_pass.cc | 14 +- .../graph/passes/variable_prepare_op_pass.cc | 100 +- .../graph/passes/variable_prepare_op_pass.h | 6 +- .../passes/variable_ref_delete_op_pass.cc | 12 +- src/ge/graph/preprocess/graph_preprocess.cc | 107 +- .../graph/preprocess/insert_op/ge_aipp_op.cc | 20 +- .../insert_op/util_insert_aipp_op.cc | 32 +- .../preprocess/multi_batch_copy_graph.cc | 26 +- src/ge/inc/graph_pass.h | 8 +- src/ge/init/gelib.cc | 36 +- src/ge/init/gelib.h | 10 +- src/ge/ir_build/ge_ir_build.cc | 8 +- src/ge/omm/csa_interact.cc | 2 + src/ge/session/session_manager.cc | 1 + src/ge/single_op/single_op_model.cc | 17 +- src/ge/single_op/single_op_model.h | 5 +- src/ge/single_op/task/tbe_task_builder.cc | 25 +- src/ge/single_op/task/tbe_task_builder.h | 7 +- src/proto/fusion_model.proto | 3 +- src/proto/task.proto | 24 +- tests/depends/cce/CMakeLists.txt | 2 + tests/depends/mmpa/CMakeLists.txt | 1 + tests/depends/omg/CMakeLists.txt | 1 + tests/depends/omg/src/omg_stub.cc | 5 +- tests/depends/runtime/CMakeLists.txt | 1 + tests/ut/common/graph/CMakeLists.txt | 3 +- tests/ut/ge/CMakeLists.txt | 4 +- .../ge/common/datatype_transfer_unittest.cc | 14 +- .../format_transfer_nhwc_5d_unittest.cc | 2 +- .../logical_stream_allocator_unittest.cc | 104 +- ...ew_model_manager_davinci_model_unittest.cc | 10 +- .../graph/load/output_net_output_unittest.cc | 19 + third_party/fwkacllib/inc/ops/all_ops.h | 2 + third_party/fwkacllib/inc/ops/array_ops.h | 4 - third_party/fwkacllib/inc/ops/data_flow_ops.h | 21 +- .../inc/ops/elewise_calculation_ops.h | 6 +- .../inc/ops/fsrdetectionoutput_ops.h | 67 + third_party/fwkacllib/inc/ops/image_ops.h | 5 +- third_party/fwkacllib/inc/ops/math_ops.h | 23 - .../inc/ops/matrix_calculation_ops.h | 14 +- .../fwkacllib/inc/ops/nn_calculation_ops.h | 22 +- third_party/fwkacllib/inc/ops/nn_detect_ops.h | 352 --- .../fwkacllib/inc/ops/nn_pooling_ops.h | 28 + .../fwkacllib/inc/ops/nn_training_ops.h | 880 +------ .../fwkacllib/inc/ops/nonlinear_fuc_ops.h | 18 + third_party/fwkacllib/inc/ops/power_ops.h | 49 + third_party/fwkacllib/inc/ops/quantize_ops.h | 16 + .../fwkacllib/inc/ops/ragged_array_ops.h | 6 +- .../fwkacllib/inc/ops/ragged_conversion_ops.h | 38 - .../fwkacllib/inc/ops/ragged_math_ops.h | 8 +- third_party/fwkacllib/inc/ops/rnn.h | 6 +- third_party/fwkacllib/inc/ops/sdca_ops.h | 2 +- third_party/fwkacllib/inc/ops/selection_ops.h | 283 ++- third_party/fwkacllib/inc/ops/sparse_ops.h | 10 +- .../fwkacllib/inc/ops/stateful_random_ops.h | 18 +- third_party/fwkacllib/inc/ops/string_ops.h | 3 - .../fwkacllib/inc/ops/transformation_ops.h | 37 +- .../inc/register/op_kernel_registry.h | 3 +- third_party/fwkacllib/inc/register/register.h | 53 - third_party/fwkacllib/inc/runtime/kernel.h | 2 +- third_party/fwkacllib/inc/runtime/mem.h | 2 - third_party/fwkacllib/inc/runtime/rt_model.h | 3 +- third_party/fwkacllib/inc/toolchain/slog.h | 2 - third_party/fwkacllib/version.info | 1 - third_party/patch/securec/securec.patch001 | 23 - third_party/prebuild/x86_64/libc_sec.so | Bin 0 -> 80080 bytes third_party/securec/CMakeLists.txt | 11 + third_party/securec/include/securec.h | 634 +++++ third_party/securec/include/securectype.h | 542 +++++ third_party/securec/src/CMakeLists.txt | 3 + third_party/securec/src/fscanf_s.c | 56 + third_party/securec/src/fwscanf_s.c | 55 + third_party/securec/src/gets_s.c | 75 + third_party/securec/src/input.inl | 2125 +++++++++++++++++ third_party/securec/src/memcpy_s.c | 577 +++++ third_party/securec/src/memmove_s.c | 120 + third_party/securec/src/memset_s.c | 522 ++++ third_party/securec/src/output.inl | 1401 +++++++++++ third_party/securec/src/scanf_s.c | 55 + third_party/securec/src/secinput.h | 156 ++ third_party/securec/src/securecutil.c | 74 + third_party/securec/src/securecutil.h | 541 +++++ third_party/securec/src/secureinput_a.c | 25 + third_party/securec/src/secureinput_w.c | 46 + third_party/securec/src/secureprintoutput.h | 98 + third_party/securec/src/secureprintoutput_a.c | 101 + third_party/securec/src/secureprintoutput_w.c | 170 ++ third_party/securec/src/snprintf_s.c | 113 + third_party/securec/src/sprintf_s.c | 61 + third_party/securec/src/sscanf_s.c | 61 + third_party/securec/src/strcat_s.c | 102 + third_party/securec/src/strcpy_s.c | 351 +++ third_party/securec/src/strncat_s.c | 121 + third_party/securec/src/strncpy_s.c | 143 ++ third_party/securec/src/strtok_s.c | 117 + third_party/securec/src/swprintf_s.c | 51 + third_party/securec/src/swscanf_s.c | 57 + third_party/securec/src/vfscanf_s.c | 67 + third_party/securec/src/vfwscanf_s.c | 66 + third_party/securec/src/vscanf_s.c | 68 + third_party/securec/src/vsnprintf_s.c | 149 ++ third_party/securec/src/vsprintf_s.c | 73 + third_party/securec/src/vsscanf_s.c | 88 + third_party/securec/src/vswprintf_s.c | 66 + third_party/securec/src/vswscanf_s.c | 79 + third_party/securec/src/vwscanf_s.c | 67 + third_party/securec/src/wcscat_s.c | 111 + third_party/securec/src/wcscpy_s.c | 91 + third_party/securec/src/wcsncat_s.c | 118 + third_party/securec/src/wcsncpy_s.c | 111 + third_party/securec/src/wcstok_s.c | 116 + third_party/securec/src/wmemcpy_s.c | 68 + third_party/securec/src/wmemmove_s.c | 67 + third_party/securec/src/wscanf_s.c | 55 + 367 files changed, 13180 insertions(+), 7302 deletions(-) delete mode 100644 cmake/external_libs/securec.cmake delete mode 100644 inc/common/util/compress/compress.h delete mode 100644 src/ge/common/formats/format_transfers/format_transfer_nchw_fz_c04.cc delete mode 100644 src/ge/common/formats/format_transfers/format_transfer_nchw_fz_c04.h delete mode 100644 src/ge/common/helper/model_cache_helper.cc delete mode 100644 src/ge/common/helper/model_cache_helper.h delete mode 100644 src/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc delete mode 100644 src/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.h delete mode 100644 src/ge/graph/passes/replace_with_empty_const_pass.cc delete mode 100644 src/ge/graph/passes/replace_with_empty_const_pass.h create mode 100644 third_party/fwkacllib/inc/ops/fsrdetectionoutput_ops.h create mode 100644 third_party/fwkacllib/inc/ops/power_ops.h delete mode 100644 third_party/fwkacllib/inc/register/register.h delete mode 100644 third_party/fwkacllib/version.info delete mode 100644 third_party/patch/securec/securec.patch001 create mode 100755 third_party/prebuild/x86_64/libc_sec.so create mode 100644 third_party/securec/CMakeLists.txt create mode 100644 third_party/securec/include/securec.h create mode 100644 third_party/securec/include/securectype.h create mode 100644 third_party/securec/src/CMakeLists.txt create mode 100644 third_party/securec/src/fscanf_s.c create mode 100644 third_party/securec/src/fwscanf_s.c create mode 100644 third_party/securec/src/gets_s.c create mode 100644 third_party/securec/src/input.inl create mode 100644 third_party/securec/src/memcpy_s.c create mode 100644 third_party/securec/src/memmove_s.c create mode 100644 third_party/securec/src/memset_s.c create mode 100644 third_party/securec/src/output.inl create mode 100644 third_party/securec/src/scanf_s.c create mode 100644 third_party/securec/src/secinput.h create mode 100644 third_party/securec/src/securecutil.c create mode 100644 third_party/securec/src/securecutil.h create mode 100644 third_party/securec/src/secureinput_a.c create mode 100644 third_party/securec/src/secureinput_w.c create mode 100644 third_party/securec/src/secureprintoutput.h create mode 100644 third_party/securec/src/secureprintoutput_a.c create mode 100644 third_party/securec/src/secureprintoutput_w.c create mode 100644 third_party/securec/src/snprintf_s.c create mode 100644 third_party/securec/src/sprintf_s.c create mode 100644 third_party/securec/src/sscanf_s.c create mode 100644 third_party/securec/src/strcat_s.c create mode 100644 third_party/securec/src/strcpy_s.c create mode 100644 third_party/securec/src/strncat_s.c create mode 100644 third_party/securec/src/strncpy_s.c create mode 100644 third_party/securec/src/strtok_s.c create mode 100644 third_party/securec/src/swprintf_s.c create mode 100644 third_party/securec/src/swscanf_s.c create mode 100644 third_party/securec/src/vfscanf_s.c create mode 100644 third_party/securec/src/vfwscanf_s.c create mode 100644 third_party/securec/src/vscanf_s.c create mode 100644 third_party/securec/src/vsnprintf_s.c create mode 100644 third_party/securec/src/vsprintf_s.c create mode 100644 third_party/securec/src/vsscanf_s.c create mode 100644 third_party/securec/src/vswprintf_s.c create mode 100644 third_party/securec/src/vswscanf_s.c create mode 100644 third_party/securec/src/vwscanf_s.c create mode 100644 third_party/securec/src/wcscat_s.c create mode 100644 third_party/securec/src/wcscpy_s.c create mode 100644 third_party/securec/src/wcsncat_s.c create mode 100644 third_party/securec/src/wcsncpy_s.c create mode 100644 third_party/securec/src/wcstok_s.c create mode 100644 third_party/securec/src/wmemcpy_s.c create mode 100644 third_party/securec/src/wmemmove_s.c create mode 100644 third_party/securec/src/wscanf_s.c diff --git a/CMakeLists.txt b/CMakeLists.txt index fff8c055..94cf6ae5 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -42,12 +42,12 @@ include(${GE_SOURCE_DIR}/cmake/external_libs/eigen.cmake) include(${GE_SOURCE_DIR}/cmake/external_libs/gtest.cmake) include(${GE_SOURCE_DIR}/cmake/external_libs/protobuf.cmake) include(${GE_SOURCE_DIR}/cmake/external_libs/onnx.cmake) -include(${GE_SOURCE_DIR}/cmake/external_libs/securec.cmake) set(CMAKE_SKIP_RPATH TRUE) # for CPU/GPU mode, find c_sec and slog from local prebuild if(NOT ENABLE_D AND NOT GE_ONLY) set(GE_PREBUILD_PATH ${GE_SOURCE_DIR}/third_party/prebuild/${CMAKE_HOST_SYSTEM_PROCESSOR}) + find_library(c_sec libc_sec.so ${GE_PREBUILD_PATH}) find_library(slog libslog.so ${GE_PREBUILD_PATH}) # if D_LINK_PATH is set in environment variables, search libraries in given path elseif(DEFINED ENV{D_LINK_PATH}) @@ -64,6 +64,7 @@ elseif(DEFINED ENV{D_LINK_PATH}) message(FATAL_ERROR "Running on a unsupported architecture: ${SYSTEM_TYPE}, build terminated") endif() set(GE_LIB_PATH ${GE_LIB_PATH}/${GE_SYS_ARCH}) + find_library(c_sec libc_sec.so ${GE_LIB_PATH}) find_library(slog libslog.so ${GE_LIB_PATH}) find_library(mmpa libmmpa.so ${GE_LIB_PATH}) find_library(runtime libruntime.so ${GE_LIB_PATH}) @@ -80,6 +81,7 @@ else() endif() set(ASCEND_DRIVER_DIR ${ASCEND_DIR}/driver/lib64/common) set(ASCEND_RUNTIME_DIR ${ASCEND_DIR}/fwkacllib/lib64) + find_library(c_sec libc_sec.so ${ASCEND_DRIVER_DIR}) find_library(slog libslog.so ${ASCEND_DRIVER_DIR}) find_library(mmpa libmmpa.so ${ASCEND_DRIVER_DIR}) find_library(msprof libmsprof.so ${ASCEND_DRIVER_DIR}) @@ -128,7 +130,7 @@ elseif(GE_ONLY) add_subdirectory(${GE_SOURCE_DIR}/src/ge/plugin/engine) endif() -if (ENABLE_GE_COV OR ENABLE_GE_UT OR ENABLE_GE_ST) - add_subdirectory(tests) -endif() +# if (ENABLE_GE_COV OR ENABLE_GE_UT OR ENABLE_GE_ST) +# add_subdirectory(tests) +# endif() diff --git a/build.sh b/build.sh index 74f13849..0afaa7fb 100644 --- a/build.sh +++ b/build.sh @@ -41,7 +41,7 @@ checkopts() { VERBOSE="" THREAD_NUM=8 - ENABLE_GE_UT_ONLY_COMPILE="off" + # ENABLE_GE_UT_ONLY_COMPILE="off" ENABLE_GE_UT="off" ENABLE_GE_ST="off" ENABLE_GE_COV="off" @@ -52,7 +52,7 @@ checkopts() OPTARG=$(echo ${OPTARG} | tr '[A-Z]' '[a-z]') case "${opt}" in u) - ENABLE_GE_UT_ONLY_COMPILE="on" + # ENABLE_GE_UT_ONLY_COMPILE="on" ENABLE_GE_UT="on" GE_ONLY="off" ;; @@ -137,39 +137,39 @@ find ${OUTPUT_PATH} -name "*.so*" -print0 | xargs -0 chmod 500 echo "---------------- GraphEngine output generated ----------------" -if [[ "X$ENABLE_GE_ST" = "Xon" ]]; then - cp ${BUILD_PATH}/graphengine/tests/st/st_resnet50_train ${OUTPUT_PATH} -fi - -if [[ "X$ENABLE_GE_UT" = "Xon" || "X$ENABLE_GE_COV" = "Xon" ]]; then - cp ${BUILD_PATH}/graphengine/tests/ut/common/graph/ut_libgraph ${OUTPUT_PATH} - cp ${BUILD_PATH}/graphengine/tests/ut/ge/ut_libge_multiparts_utest ${OUTPUT_PATH} - cp ${BUILD_PATH}/graphengine/tests/ut/ge/ut_libge_distinct_load_utest ${OUTPUT_PATH} - cp ${BUILD_PATH}/graphengine/tests/ut/ge/ut_libge_others_utest ${OUTPUT_PATH} - cp ${BUILD_PATH}/graphengine/tests/ut/ge/ut_libge_kernel_utest ${OUTPUT_PATH} - - if [[ "X${ENABLE_GE_UT_ONLY_COMPILE}" != "Xon" ]]; then - export LD_LIBRARY_PATH=${D_LINK_PATH}/x86_64/:${BUILD_PATH}../third_party/prebuild/x86_64/:${BUILD_PATH}/graphengine/:/usr/local/HiAI/driver/lib64:/usr/local/HiAI/runtime/lib64:${LD_LIBRARY_PATH} - echo ${LD_LIBRARY_PATH} - ${OUTPUT_PATH}/ut_libgraph && - ${OUTPUT_PATH}/ut_libge_multiparts_utest && - ${OUTPUT_PATH}/ut_libge_distinct_load_utest && - ${OUTPUT_PATH}/ut_libge_others_utest && - ${OUTPUT_PATH}/ut_libge_kernel_utest - if [[ "$?" -ne 0 ]]; then - echo "!!! UT FAILED, PLEASE CHECK YOUR CHANGES !!!" - exit 1; - fi - fi - - if [[ "X$ENABLE_GE_COV" = "Xon" ]]; then - echo "Generating coverage statistics, please wait..." - cd ${BASEPATH} - rm -rf ${BASEPATH}/cov - mkdir ${BASEPATH}/cov - gcovr -r ./ --exclude 'third_party' --exclude 'build' --exclude 'tests' --exclude 'prebuild' --exclude 'inc' --print-summary --html --html-details -d -o cov/index.html - fi -fi +# if [[ "X$ENABLE_GE_ST" = "Xon" ]]; then +# cp ${BUILD_PATH}/graphengine/tests/st/st_resnet50_train ${OUTPUT_PATH} +# fi + +# if [[ "X$ENABLE_GE_UT" = "Xon" || "X$ENABLE_GE_COV" = "Xon" ]]; then +# cp ${BUILD_PATH}/graphengine/tests/ut/common/graph/ut_libgraph ${OUTPUT_PATH} +# cp ${BUILD_PATH}/graphengine/tests/ut/ge/ut_libge_multiparts_utest ${OUTPUT_PATH} +# cp ${BUILD_PATH}/graphengine/tests/ut/ge/ut_libge_distinct_load_utest ${OUTPUT_PATH} +# cp ${BUILD_PATH}/graphengine/tests/ut/ge/ut_libge_others_utest ${OUTPUT_PATH} +# cp ${BUILD_PATH}/graphengine/tests/ut/ge/ut_libge_kernel_utest ${OUTPUT_PATH} + +# if [[ "X${ENABLE_GE_UT_ONLY_COMPILE}" != "Xon" ]]; then +# export LD_LIBRARY_PATH=${D_LINK_PATH}/x86_64/:${BUILD_PATH}../third_party/prebuild/x86_64/:${BUILD_PATH}/graphengine/:/usr/local/HiAI/driver/lib64:/usr/local/HiAI/runtime/lib64:${LD_LIBRARY_PATH} +# echo ${LD_LIBRARY_PATH} +# ${OUTPUT_PATH}/ut_libgraph && +# ${OUTPUT_PATH}/ut_libge_multiparts_utest && +# ${OUTPUT_PATH}/ut_libge_distinct_load_utest && +# ${OUTPUT_PATH}/ut_libge_others_utest && +# ${OUTPUT_PATH}/ut_libge_kernel_utest +# if [[ "$?" -ne 0 ]]; then +# echo "!!! UT FAILED, PLEASE CHECK YOUR CHANGES !!!" +# exit 1; +# fi +# fi + +# if [[ "X$ENABLE_GE_COV" = "Xon" ]]; then +# echo "Generating coverage statistics, please wait..." +# cd ${BASEPATH} +# rm -rf ${BASEPATH}/cov +# mkdir ${BASEPATH}/cov +# gcovr -r ./ --exclude 'third_party' --exclude 'build' --exclude 'tests' --exclude 'prebuild' --exclude 'inc' --print-summary --html --html-details -d -o cov/index.html +# fi +# fi # generate output package in tar form, including ut/st libraries/executables cd ${BASEPATH} diff --git a/cmake/external_libs/securec.cmake b/cmake/external_libs/securec.cmake deleted file mode 100644 index 34488b6f..00000000 --- a/cmake/external_libs/securec.cmake +++ /dev/null @@ -1,11 +0,0 @@ -graphengine_add_pkg(securec - VER 1.1.10 - URL https://gitee.com/openeuler/bounds_checking_function/repository/archive/v1.1.10.tar.gz - MD5 0782dd2351fde6920d31a599b23d8c91 - LIBS c_sec - PATCHES ${GE_SOURCE_DIR}/third_party/patch/securec/securec.patch001 - CMAKE_OPTION " " - ) -include_directories(${securec_INC}) -file(COPY ${securec_INC}/../lib/libc_sec.so DESTINATION ${CMAKE_SOURCE_DIR}/build/graphengine) -add_library(graphengine::securec ALIAS securec::c_sec) diff --git a/inc/common/util/compress/compress.h b/inc/common/util/compress/compress.h deleted file mode 100644 index 6908fb75..00000000 --- a/inc/common/util/compress/compress.h +++ /dev/null @@ -1,36 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef COMPRESS_H -#define COMPRESS_H - -#include - -enum CmpStatus { RET_SUCCESS = 0, RET_ERROR = -1 }; - -struct CompressConfig { - size_t inputSize; // length of data to compress - size_t engineNum; // how many decompress engines - size_t maxRatio; // how much size of a basic compression block, only 64 supported now (8x: 64 4x: 32) - size_t channel; // channels of L2 or DDR. For load balance - size_t fractalSize; // size of compressing block - bool isTight; // whether compose compressed data tightly -}; - -CmpStatus CompressWeights(char* input, const CompressConfig& compressConfig, char* indexs, char* output, - size_t& compressedLength); - -#endif // COMPRESS_H diff --git a/inc/external/ge/ge_api_types.h b/inc/external/ge/ge_api_types.h index 13477bbd..bf9a10b4 100644 --- a/inc/external/ge/ge_api_types.h +++ b/inc/external/ge/ge_api_types.h @@ -40,8 +40,6 @@ const char *const OPTION_EXEC_EXTERN_PLUGIN_PATH = "ge.soLoadPath"; const char *const OPTION_EXEC_ENABLE_DUMP = "ge.exec.enableDump"; const char *const OPTION_EXEC_DUMP_PATH = "ge.exec.dumpPath"; const char *const OPTION_EXEC_DUMP_STEP = "ge.exec.dumpStep"; -const char *const OPTION_EXEC_ENABLE_INCRE_BUILD = "ge.exec.enableIncreBuild"; -const char *const OPTION_EXEC_INCRE_BUILD_CACHE_PATH = "ge.exec.increBuildCachePath"; // Hccl flag, if ge.exec.hcclFlag =1, it means load plugin for opskernel, else:ge.exec.hcclFlag =0 const char *const OPTION_EXEC_HCCL_FLAG = "ge.exec.hcclFlag"; const char *const OPTION_EXEC_ATOMIC_FLAG = "ge.exec.enable_atomic"; diff --git a/inc/external/graph/inference_context.h b/inc/external/graph/inference_context.h index 69079142..68a9ecf5 100644 --- a/inc/external/graph/inference_context.h +++ b/inc/external/graph/inference_context.h @@ -69,7 +69,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InferenceContext { static std::unique_ptr Create(); private: - explicit InferenceContext(std::unique_ptr &impl); + InferenceContext(std::unique_ptr &impl); std::shared_ptr inference_context_impl_; }; } // namespace ge diff --git a/inc/external/register/register.h b/inc/external/register/register.h index f96044de..045a1570 100644 --- a/inc/external/register/register.h +++ b/inc/external/register/register.h @@ -116,5 +116,27 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpReceiver { namespace ge { using OpRegistrationData = domi::OpRegistrationData; using OpReceiver = domi::OpReceiver; + +class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY HostCpuOp { + public: + HostCpuOp() = default; + virtual ~HostCpuOp() = default; + + virtual graphStatus Compute(Operator &op, const std::map &inputs, + std::map &outputs) = 0; +}; + +class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY HostCpuOpRegistrar { + public: + HostCpuOpRegistrar(const char *op_type, HostCpuOp *(*create_fn)()); +}; + +#define REGISTER_HOST_CPU_OP_BUILDER(name, op) REGISTER_HOST_CPU_OP_BUILDER_UNIQ_HELPER(__COUNTER__, name, op) + +#define REGISTER_HOST_CPU_OP_BUILDER_UNIQ_HELPER(ctr, name, op) REGISTER_HOST_CPU_OP_BUILDER_UNIQ(ctr, name, op) + +#define REGISTER_HOST_CPU_OP_BUILDER_UNIQ(ctr, name, op) \ + static ::ge::HostCpuOpRegistrar register_host_cpu_op##ctr __attribute__((unused)) = \ + ::ge::HostCpuOpRegistrar(name, []() -> ::ge::HostCpuOp * { return new (std::nothrow) op(); }) } // namespace ge #endif // INC_EXTERNAL_REGISTER_REGISTER_H_ diff --git a/inc/framework/common/debug/ge_log.h b/inc/framework/common/debug/ge_log.h index e2023cb8..f2df79a7 100644 --- a/inc/framework/common/debug/ge_log.h +++ b/inc/framework/common/debug/ge_log.h @@ -51,24 +51,24 @@ inline pid_t GetTid() { return tid; } -#define GE_TIMESTAMP_START(stage) uint64_t startUsec_##stage = ge::GetCurrentTimestap() +#define GE_TIMESTAMP_START(stage) uint64_t startUsec_##stage = domi::GetCurrentTimestap() #define GE_TIMESTAMP_END(stage, stage_name) \ do { \ - uint64_t endUsec_##stage = ge::GetCurrentTimestap(); \ + uint64_t endUsec_##stage = domi::GetCurrentTimestap(); \ GEEVENT("[GEPERFTRACE] The time cost of %s is [%lu] micro second.", (stage_name), \ (endUsec_##stage - startUsec_##stage)); \ } while (0); -#define GE_TIMESTAMP_CALLNUM_START(stage) \ - uint64_t startUsec_##stage = ge::GetCurrentTimestap(); \ - uint64_t call_num_of##stage = 0; \ +#define GE_TIMESTAMP_CALLNUM_START(stage) \ + uint64_t startUsec_##stage = domi::GetCurrentTimestap(); \ + uint64_t call_num_of##stage = 0; \ uint64_t time_of##stage = 0 -#define GE_TIMESTAMP_RESTART(stage) (startUsec_##stage = ge::GetCurrentTimestap()) +#define GE_TIMESTAMP_RESTART(stage) (startUsec_##stage = domi::GetCurrentTimestap()) -#define GE_TIMESTAMP_ADD(stage) \ - time_of##stage += ge::GetCurrentTimestap() - startUsec_##stage; \ +#define GE_TIMESTAMP_ADD(stage) \ + time_of##stage += domi::GetCurrentTimestap() - startUsec_##stage; \ call_num_of##stage++ #define GE_TIMESTAMP_CALLNUM_END(stage, stage_name) \ diff --git a/inc/framework/common/debug/log.h b/inc/framework/common/debug/log.h index b16aa3fa..9a192a82 100644 --- a/inc/framework/common/debug/log.h +++ b/inc/framework/common/debug/log.h @@ -103,17 +103,17 @@ using cce::ccStatus_t; } while (0); // If expr is not true, print the log and return the specified status -#define GE_CHK_BOOL_RET_STATUS(expr, _status, ...) \ - do { \ - bool b = (expr); \ - if (!b) { \ - std::string msg; \ - (void)msg.append(ge::StringUtils::FormatString(__VA_ARGS__)); \ - (void)msg.append( \ - ge::StringUtils::FormatString(" Error Code:0x%X(%s)", _status, GET_ERRORNO_STR(_status).c_str())); \ - DOMI_LOGE("%s", msg.c_str()); \ - return _status; \ - } \ +#define GE_CHK_BOOL_RET_STATUS(expr, _status, ...) \ + do { \ + bool b = (expr); \ + if (!b) { \ + std::string msg; \ + (void)msg.append(domi::StringUtils::FormatString(__VA_ARGS__)); \ + (void)msg.append( \ + domi::StringUtils::FormatString(" Error Code:0x%X(%s)", _status, GET_ERRORNO_STR(_status).c_str())); \ + DOMI_LOGE("%s", msg.c_str()); \ + return _status; \ + } \ } while (0); // If expr is not true, print the log and return the specified status diff --git a/inc/framework/common/ge_inner_error_codes.h b/inc/framework/common/ge_inner_error_codes.h index 4b5538d3..b563aef7 100644 --- a/inc/framework/common/ge_inner_error_codes.h +++ b/inc/framework/common/ge_inner_error_codes.h @@ -152,6 +152,7 @@ GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_RUN_GRAPH_INVALID, 11, "Get computeGraph by g GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_INSERT_DYN_OP_FAILED, 12, "Graph which insert dynamic op failed."); // 1343242252 GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_PREPROCESS_FAILED, 13, "Graph preprocess failed."); // 1343242253 GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_GRAPH_FUSION_FAILED, 14, "Graph fusion failed."); // 1343242254 +GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_TINY_CAL_CHECK_FAILED, 15, "Check tiny calibration failed."); // 1343242255 GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_CALIBRATION_FAILED, 16, "Calibration failed."); // 1343242256 GE_ERRORNO_GRAPH(GE_GRAPH_SUBGRAPH_NUM_ZERO, 17, "Graph partition success, but subGraph num is 0."); // 1343242257 GE_ERRORNO_GRAPH(GE_GRAPH_SUBGRAPH_ENGINENAME_REPEATED, 18, "Graph subGraph engine name is repeated."); // 1343242258 diff --git a/inc/framework/common/gflags_util.h b/inc/framework/common/gflags_util.h index 94d66ffb..4fb9511f 100644 --- a/inc/framework/common/gflags_util.h +++ b/inc/framework/common/gflags_util.h @@ -20,7 +20,7 @@ #include #include -namespace ge { +namespace domi { class GflagsUtils { public: static bool IsSetCommandTrue(const char *name) { @@ -66,6 +66,6 @@ class GflagsUtils { } } }; -} // namespace ge +} // namespace domi #endif // INC_FRAMEWORK_COMMON_GFLAGS_UTIL_H_ diff --git a/inc/framework/common/helper/model_helper.h b/inc/framework/common/helper/model_helper.h index c918c039..c16e3c23 100644 --- a/inc/framework/common/helper/model_helper.h +++ b/inc/framework/common/helper/model_helper.h @@ -26,7 +26,7 @@ #include "graph/model.h" #include "model/ge_model.h" -namespace ge { +namespace domi { class ModelHelper { public: ModelHelper() = default; @@ -65,8 +65,9 @@ class ModelHelper { Status LoadTask(OmFileLoadHelper& om_load_helper); Status LoadTBEKernelStore(OmFileLoadHelper& om_load_helper); Status ReleaseLocalModelData() noexcept; + Status SaveModelPartition(std::shared_ptr& om_file_save_helper, ModelPartitionType type, const uint8_t* data, size_t size); }; -} // namespace ge +} // namespace domi #endif // INC_FRAMEWORK_COMMON_HELPER_MODEL_HELPER_H_ diff --git a/inc/framework/common/helper/om_file_helper.h b/inc/framework/common/helper/om_file_helper.h index 1e4cee9b..7c301f97 100644 --- a/inc/framework/common/helper/om_file_helper.h +++ b/inc/framework/common/helper/om_file_helper.h @@ -26,10 +26,8 @@ #include "framework/common/ge_types.h" using ProcParam = struct PROC_PARAM; -using std::string; -using std::vector; -namespace ge { +namespace domi { struct ModelPartition { ModelPartitionType type; uint8_t *data = 0; @@ -90,5 +88,5 @@ class OmFileSaveHelper { ModelFileHeader model_header_; OmFileContext context_; }; -} // namespace ge +} // namespace domi #endif // INC_FRAMEWORK_COMMON_HELPER_OM_FILE_HELPER_H_ diff --git a/inc/framework/common/l2_cache_optimize.h b/inc/framework/common/l2_cache_optimize.h index c65f67b3..8aa0a5d1 100644 --- a/inc/framework/common/l2_cache_optimize.h +++ b/inc/framework/common/l2_cache_optimize.h @@ -30,7 +30,7 @@ using std::vector; -namespace ge { +namespace domi { // Size of RC memory alignment, 2M constexpr size_t ALIGN_SIZE = 2097152; @@ -118,6 +118,6 @@ class L2CacheOptimize { bool Cross(const RCMemoryBlock &l_block, const RCMemoryBlock &r_block); bool Connect(const RCMemoryBlock &l_block, const RCMemoryBlock &r_block); }; -} // namespace ge +} // namespace domi #endif // INC_FRAMEWORK_COMMON_L2_CACHE_OPTIMIZE_H_ \ No newline at end of file diff --git a/inc/framework/common/op/attr_value_util.h b/inc/framework/common/op/attr_value_util.h index 8a90cfa2..b55d3391 100644 --- a/inc/framework/common/op/attr_value_util.h +++ b/inc/framework/common/op/attr_value_util.h @@ -21,17 +21,11 @@ #include #include +#include "common/op/attr_define.h" #include "common/types.h" -#include "graph/debug/ge_attr_define.h" #include "proto/om.pb.h" -using domi::AttrDef; -using domi::AttrDef_ListValue; -using domi::ModelDef; -using domi::NamedAttrs; -using domi::OpDef; - -namespace ge { +namespace domi { using AttrDefMap = ::google::protobuf::Map<::std::string, ::domi::AttrDef>; using AttrDefPair = ::google::protobuf::MapPair; @@ -156,6 +150,6 @@ bool GetAttrDefListValue(const std::string &key, int idx, int32_t *value, const bool GetAttrDefListValue(const std::string &key, int idx, uint32_t *value, const AttrDefMap &attr); bool GetAttrDefListValue(const std::string &key, int idx, float *value, const AttrDefMap &attr); bool GetAttrDefListValue(const std::string &key, int idx, double *value, const AttrDefMap &attr); -} // namespace ge +} // namespace domi #endif // INC_FRAMEWORK_COMMON_OP_ATTR_VALUE_UTIL_H_ diff --git a/inc/framework/common/op/ge_op_utils.h b/inc/framework/common/op/ge_op_utils.h index 87cf54d8..b3730f16 100644 --- a/inc/framework/common/op/ge_op_utils.h +++ b/inc/framework/common/op/ge_op_utils.h @@ -62,8 +62,6 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t FOR_LIMIT FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t FOR_DELTA_INPUT; FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t FOR_DATA_INPUT; -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const int NORMAL_TENSOR_SIZE; - class OpUtils { public: /// diff --git a/inc/framework/common/op/op_parser_util.h b/inc/framework/common/op/op_parser_util.h index 49b4350a..e64ddc92 100644 --- a/inc/framework/common/op/op_parser_util.h +++ b/inc/framework/common/op/op_parser_util.h @@ -22,7 +22,7 @@ #include #include -namespace ge { +namespace domi { // general const float DEFAULT_ALPHA_VALUE = 1.0; const float DEFAULT_BETA_VALUE = 0.0; @@ -421,5 +421,5 @@ const uint32_t MULTI_SHAPE_INPUT_NUM = 2; // Shufflechannel const uint32_t SHUFFLECHANNEL_DEFAULT_GROUP = 1; -} // namespace ge +} // namespace domi #endif // INC_FRAMEWORK_COMMON_OP_OP_PARSER_UTIL_H_ diff --git a/inc/framework/common/op_types.h b/inc/framework/common/op_types.h index 4555d5c3..8d859169 100644 --- a/inc/framework/common/op_types.h +++ b/inc/framework/common/op_types.h @@ -20,7 +20,7 @@ #include #include -namespace ge { +namespace domi { class OpTypeContainer { public: static OpTypeContainer *Instance() { @@ -57,6 +57,6 @@ class OpTypeRegistrar { const OpTypeRegistrar g_##var_name##_reg(str_name); #define IS_OPTYPE_EXISTING(str_name) (OpTypeContainer::Instance()->IsExisting(str_name)) -} // namespace ge +} // namespace domi #endif // INC_FRAMEWORK_COMMON_OP_TYPES_H_ diff --git a/inc/framework/common/scope_guard.h b/inc/framework/common/scope_guard.h index 2154648d..6e5c4b4a 100644 --- a/inc/framework/common/scope_guard.h +++ b/inc/framework/common/scope_guard.h @@ -25,10 +25,10 @@ /// MAKE_GUARD([&] { Release Resource 1 }) /// Acquire Resource 2 // MAKE_GUARD([&] { Release Resource 2 }) -#define GE_MAKE_GUARD(var, callback) ScopeGuard make_guard_##var(callback) +#define GE_MAKE_GUARD(var, callback) domi::ScopeGuard make_guard_##var(callback) #define GE_DISMISS_GUARD(var) make_guard_##var.Dismiss() -namespace ge { +namespace domi { class ScopeGuard { public: // Noncopyable @@ -55,6 +55,6 @@ class ScopeGuard { std::function on_exit_scope_; bool dismissed_; }; -} // namespace ge +} // namespace domi #endif // INC_FRAMEWORK_COMMON_SCOPE_GUARD_H_ diff --git a/inc/framework/common/string_util.h b/inc/framework/common/string_util.h index b74eddcf..42d5a2cd 100644 --- a/inc/framework/common/string_util.h +++ b/inc/framework/common/string_util.h @@ -25,7 +25,7 @@ #include #include -namespace ge { +namespace domi { class StringUtils { public: static std::string &Ltrim(std::string &s) { @@ -151,6 +151,6 @@ class StringUtils { return ret > 0 ? buffer : ""; } }; -} // namespace ge +} // namespace domi #endif // INC_FRAMEWORK_COMMON_STRING_UTIL_H_ diff --git a/inc/framework/common/types.h b/inc/framework/common/types.h index 1cc2245b..d98f784c 100644 --- a/inc/framework/common/types.h +++ b/inc/framework/common/types.h @@ -26,7 +26,6 @@ #include #include #include - #include "framework/common/fmk_error_codes.h" #include "framework/common/fmk_types.h" #include "framework/common/op_types.h" @@ -47,7 +46,9 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_A FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_STATUS; FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_LAYER; FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_FILE_PATH; +} // namespace ge +namespace domi { // Supported public properties name FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROP_OME_START_TIME; // Start time FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROP_OME_DUMP_PATH; // Dump path @@ -67,6 +68,14 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROFIL FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::map PROFILE_COMPONENT_MAP; FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROFILE_CONFIG; +/// @brief Data structure definition related to task sinking +/// Build model +enum BuildMode { + GEN_TASK_WITHOUT_L2FUSION = 3, // Carrying task data (L2 convergence function disabled) + GEN_TASK_WITHOUT_FUSION = 4, // Carrying task data (all convergence functions disabled) + GEN_TASK_WITH_FUSION = 5 // Carrying task data (with UB/L1/L2 enabled for all convergence functions) +}; + FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string MODEL_ATTR_TASKS; FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string MODEL_ATTR_TASK_GEN_BASE_ADDR; FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string MODEL_ATTR_TASK_GEN_WEIGHT_ADDR; @@ -333,7 +342,7 @@ REGISTER_OPTYPE_DECLARE(BASICLSTMCELL, "BasicLSTMCell"); REGISTER_OPTYPE_DECLARE(GETNEXT, "GetNext"); REGISTER_OPTYPE_DECLARE(INITDATA, "InitData"); -// ANN dedicated operator +/***************ANN dedicated operator *************************/ REGISTER_OPTYPE_DECLARE(ANN_MEAN, "AnnMean"); REGISTER_OPTYPE_DECLARE(ANN_CONVOLUTION, "AnnConvolution"); REGISTER_OPTYPE_DECLARE(ANN_DEPCONVOLUTION, "AnnDepthConv"); @@ -350,7 +359,7 @@ REGISTER_OPTYPE_DECLARE(ANN_QUANTIZE, "AnnQuant"); REGISTER_OPTYPE_DECLARE(ANN_PAD, "AnnPad"); REGISTER_OPTYPE_DECLARE(ANN_RESIZE_BILINEAR, "AnnResizeBilinear"); -// Training operator +/********************Training operator ***********************/ REGISTER_OPTYPE_DECLARE(GATHERV2, "GatherV2"); REGISTER_OPTYPE_DECLARE(CONVGRADFILTER, "Conv2DBackpropFilter"); REGISTER_OPTYPE_DECLARE(CONV2D, "Conv2D"); @@ -434,7 +443,6 @@ REGISTER_OPTYPE_DECLARE(STREAMSWITCH, "StreamSwitch"); REGISTER_OPTYPE_DECLARE(STREAMSWITCHN, "StreamSwitchN"); REGISTER_OPTYPE_DECLARE(STREAMACTIVE, "StreamActive"); REGISTER_OPTYPE_DECLARE(MEMCPYASYNC, "MemcpyAsync"); -REGISTER_OPTYPE_DECLARE(MEMCPYADDRASYNC, "MemcpyAddrAsync"); REGISTER_OPTYPE_DECLARE(STREAMMERGE, "StreamMerge"); REGISTER_OPTYPE_DECLARE(ENDGRAPH, "EndGraph"); REGISTER_OPTYPE_DECLARE(SEND, "Send"); @@ -442,7 +450,6 @@ REGISTER_OPTYPE_DECLARE(RECV, "Recv"); REGISTER_OPTYPE_DECLARE(LABELSET, "LabelSet"); REGISTER_OPTYPE_DECLARE(LABELGOTO, "LabelGoto"); -REGISTER_OPTYPE_DECLARE(LABELGOTOEX, "LabelGotoEx"); REGISTER_OPTYPE_DECLARE(LABELSWITCH, "LabelSwitch"); REGISTER_OPTYPE_DECLARE(LABELSWITCHBYINDEX, "LabelSwitchByIndex"); @@ -821,6 +828,9 @@ static constexpr int32_t PARTITION_TYPE_TASK_INFO = 2; // number of partitions in the current model static constexpr uint32_t PARTITION_SIZE = 4; +#define SIZE_OF_MODEL_PARTITION_TABLE(table) \ + (sizeof(domi::ModelPartitionTable) + sizeof(domi::ModelPartitionMemInfo) * (table).num) + enum ModelPartitionType { MODEL_DEF = 0, WEIGHTS_DATA, TASK_INFO, TBE_KERNELS }; struct ModelPartitionMemInfo { @@ -834,8 +844,6 @@ struct ModelPartitionTable { ModelPartitionMemInfo partition[0]; }; -#define SIZE_OF_MODEL_PARTITION_TABLE(table) (sizeof(ModelPartitionTable) + sizeof(ModelPartitionMemInfo) * (table).num) - static constexpr int32_t PTHREAD_CREAT_SUCCESS = 0; // pthread_creat success // Filter format @@ -967,8 +975,8 @@ typedef enum tagDomiNanPropagation { // mode of cropandresize typedef enum tagDomiCropAndResizeMode { - DOMI_RESIZE_METHOD_BILINEAR = 0, // resize bilinear - DOMI_RESIZE_METHOD_NEAREST, // resize nearest + DOMI_RESIZE_METHOD_BILINEAR = 0, /**< resize bilinear */ + DOMI_RESIZE_METHOD_NEAREST, /**< resize nearest */ DOMI_RESIZE_RESERVED } domiCropAndResizeMode_t; @@ -1055,15 +1063,6 @@ struct BasicInfo { uint32_t total_size; // total memory size }; #pragma pack() // Cancels single-byte alignment -} // namespace ge - -namespace domi { -/// @brief Data structure definition related to task sinking -enum BuildMode { - GEN_TASK_WITHOUT_L2FUSION = 3, // Carrying task data (L2 convergence function disabled) - GEN_TASK_WITHOUT_FUSION = 4, // Carrying task data (all convergence functions disabled) - GEN_TASK_WITH_FUSION = 5 // Carrying task data (with UB/L1/L2 enabled for all convergence functions) -}; } // namespace domi #endif // INC_FRAMEWORK_COMMON_TYPES_H_ diff --git a/inc/framework/common/util.h b/inc/framework/common/util.h index 4c37c01e..6447340f 100644 --- a/inc/framework/common/util.h +++ b/inc/framework/common/util.h @@ -220,7 +220,7 @@ static constexpr int32_t OM_PROTO_VERSION = 2; */ #define CEIL(N, n) (((N) + (n)-1) / (n)) -namespace ge { +namespace domi { using google::protobuf::Message; /// @@ -390,6 +390,6 @@ bool CheckOutputPathValid(const std::string &file_path); /// @param [out] result /// bool ValidateStr(const std::string &filePath, const std::string &mode); -} // namespace ge +} // namespace domi #endif // INC_FRAMEWORK_COMMON_UTIL_H_ diff --git a/inc/framework/omg/omg_inner_types.h b/inc/framework/omg/omg_inner_types.h index d2599856..925aa9dd 100644 --- a/inc/framework/omg/omg_inner_types.h +++ b/inc/framework/omg/omg_inner_types.h @@ -28,16 +28,12 @@ #include "framework/common/types.h" #include "register/register_fmk_types.h" -using domi::DOMI_TENSOR_ND; -using domi::DOMI_TENSOR_RESERVED; -using domi::domiTensorFormat_t; -using domi::FrameworkType; using std::map; using std::string; using std::unordered_map; using std::vector; -namespace ge { +namespace domi { /** * @ingroup domi_omg * @brief run model @@ -97,7 +93,7 @@ struct OmgContext { std::string ddk_version; // preferential format used by the entire network domiTensorFormat_t net_format = DOMI_TENSOR_RESERVED; - domi::FrameworkType type = domi::FMK_TYPE_RESERVED; + FrameworkType type = FMK_TYPE_RESERVED; RunMode run_mode = ONLY_PRE_CHECK; bool train_flag = false; // whether to use FP16 high precision @@ -106,25 +102,23 @@ struct OmgContext { std::string output_type; // Save the name of the entire network: Some special operators are used to determine a network. Some operators in the - // network require special processing based on the specific network. e.g:faster-rcnn, the FirstStageProcessor module - // is determined as the Faster-R-CNN network based on the scope fusion. Then, the conv+reshape operators in the - // FirstStageBoxPredictor/BoxEncodingPredictor scope are combined. The convolution kernel rearrangement reshape - // operator needs to be deleted for the convolution kernel. + // network require special processing based on the specific network. + // e.g:faster-rcnn, the FirstStageProcessor module is determined as the Faster-R-CNN network based on the scope + // fusion. Then, the conv+reshape operators in the FirstStageBoxPredictor/BoxEncodingPredictor scope are combined. The + // convolution kernel rearrangement reshape operator needs to be deleted for the convolution kernel. std::string net_name; // Whether to use dynamic batch size or dynamic image size bool is_dynamic_input = false; std::string dynamic_batch_size; std::string dynamic_image_size; }; -} // namespace ge -namespace domi { /** * @ingroup domi_omg * @brief get OMG context * @return OmgContext context */ -ge::OmgContext &GetContext(); +OmgContext &GetContext(); struct TEBinInfo { // It is obsolete. It will be automatically obtained from the binfilename field of the JSON file later. diff --git a/inc/framework/omg/version.h b/inc/framework/omg/version.h index ac649d83..300f32eb 100644 --- a/inc/framework/omg/version.h +++ b/inc/framework/omg/version.h @@ -26,7 +26,7 @@ #include "common/string_util.h" #include "framework/common/debug/ge_log.h" -namespace ge { +namespace domi { class PlatformVersionManager { public: PlatformVersionManager() = delete; @@ -40,6 +40,6 @@ class PlatformVersionManager { return SUCCESS; } }; // class PlatformManager -} // namespace ge +} // namespace domi #endif // INC_FRAMEWORK_OMG_VERSION_H_ diff --git a/inc/graph/debug/ge_attr_define.h b/inc/graph/debug/ge_attr_define.h index 57d1c6c6..ed992a62 100644 --- a/inc/graph/debug/ge_attr_define.h +++ b/inc/graph/debug/ge_attr_define.h @@ -58,8 +58,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BIAS_TERM; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_HAS_BIAS_VALUE; - GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PAD; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PADS; @@ -76,7 +74,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CEIL_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STRIDE_SIZE; +// GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string +// ATTR_NAME_WEIGHTS; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RELUMODE; @@ -124,13 +123,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NAN_OPT; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AIPP; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string NEW_AIPP_CONV_OP; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SESSION_GRAPH_ID; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MULTISHAPE_BATCHLIST; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MULTISHAPE_BATCHLIST_SIZE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_BATCH_NUM; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_FORMAT; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_FORMAT; @@ -148,24 +140,12 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PRED_PERMUTE_DELETED; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IGNORE_PRED_FORMAT; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHTS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHTS_DATA; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BROACAST_REAL_DIM_CNT; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DIM_ALIGN; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RTSWITCH_RECV_EVENT_ID; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_START; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHTS_DATA; - GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SESSION_GRAPH_ID; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_BATCH_NUM; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_START; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_LABEL; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG; - // to be deleted GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_TO_BE_DELETED; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_RESHAPE_FUSION; @@ -178,15 +158,15 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_MBOX GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_MBOX_FUSION_BOX_TYPE_NUM; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_RESHAPE_SLICE_CONCAT_FUSION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIORBOX_CONCAT; + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_LOC_FUSION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_CONF_FUSION; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_FUSION_BOX_TYPE_NUM; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION; - GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIORBOX_CONCAT; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string NEED_INFER; // _Arg @@ -275,29 +255,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNOR GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_ESTIMATED_VARIANCE; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_SCALE; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_BIAS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_DATA_FORMAT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_IS_TRAINING; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_IS_TRAINING_FUSION; - -// Huberloss -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HUBER_LOSS_ATTR_DELTA; - -// SSDRealDivTileMul -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_REAL_DIV_TILE_MUL_ATTR_TILE_PARA; - -// SSDSumMulRealDivMean -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_REDUCTION_INDICES; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_AXIS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_MEAN_PARA; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_HAS_SUM; -/// ConcatFive2Four -/// ConcatFour2Five -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_CLASS_NUM; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_FEATURE_MAP_SIZE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TRANS_FOR_LOSS_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOX_TYPE_NUM; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_FEATURE_MAP_HIGH; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_FEATURE_MAP_WIDTH; + // Scale GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_ATTR_SCALE; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_ATTR_BIAS; @@ -334,6 +292,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_AT GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_ATTR_NAME_OUTPUT_FORMAT; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_ATTR_NAME_OUTPUT_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string L2_NORMALIZE_ATTR_EPS; // Roipooling @@ -346,7 +305,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIPOOLI // DetectionOutput GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_NUM_CLASSES; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_OCR_NUM_CLASSES; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_NMS_THRESHOLD; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_TOP_K; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_CONFIDENCE_THRESHOLD; @@ -405,7 +363,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SOFTMAX_ // Permute GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_ATTR_ORDER; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_ATTR_PERM; // SSD Normalize GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSDNORMALIZE_ATTR_ACCROSS_SPATIAL; @@ -446,15 +403,9 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_AT GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_ATTR_NAME_SCALE; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_ATTR_NAME_SHIFT; -// Log -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_ATTR_NAME_SCALE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_ATTR_NAME_SHIFT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_ATTR_NAME_BASE; // Pack GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PACK_ATTR_NAME_NUM; -// Dynamic stitch -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DYNAMIC_STITCH_ATTR_NAME_NUM; // Unpack GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UNPACK_ATTR_NAME_NUM; // Gathernd @@ -463,16 +414,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERND // Argmax GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_TOPK; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_REDUCESIZE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_REDUCESTRIDE; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_OUTMAX; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_AXIS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_AXISTYPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_KEEPDIMS; -// Upsample -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UPSAMPLE_ATTR_NAME_SCALE_H; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UPSAMPLE_ATTR_NAME_SCALE_W; // Relu GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NEGATIVE_SLOPE; @@ -511,7 +454,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_SAMPLING_RATIO; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_NAME_POOLED_H; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_NAME_POOLED_W; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_NAME_TF; // Generate_rpn_proposal GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GENERATE_RPN_PROPOSAL_ATTR_PRE_NMS_TOPK; @@ -544,7 +486,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REORG_AT GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MERGE_DEAD_INDEX; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MERGE_PRENODE_FLAG; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TO_BE_OUTPUT; -static const std::string NOT_NET_OUTPUT = "not_net_output"; // ENTER GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ENTER_ATTR_FRAME_NAME; @@ -570,9 +511,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_B GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_ALPHA; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_BETA; -// RetinaNet -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RETINANET_FILTER_BACKGROUND_TRUE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RETINANET_ANCHOR_FUSION; // MatMul GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MATMUL_TRANSPOSE_X; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MATMUL_TRANSPOSE_W; @@ -621,30 +559,10 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GRU_CELL GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_HT; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_XT_HT; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_BATCH_SIZE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_CELL_CLIP; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_PROJ_CLIP; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_ACTIVATE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_OUT_MAP; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_OUT_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_STATE_OUT_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_TIME_MAJOR; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_IS_INPUT_PRE_PROCESS; // Upsample GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UPSAMPLE_ATTR_NAME_SCALE; -// PadV2 -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_PADS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_T; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_PAD_FORMAT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_CONST_VALUE; - -// MirrorPad -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_PADS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_PAD_FORMAT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_CONST_VALUE; // Filler GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FILLER_TYPE; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FILLER_VALUE; @@ -665,6 +583,36 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PAD_LEFT GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_ALGO_ATTR; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_TYPE_ATTR; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_OFFSET; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_OFFSET; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_OFFSET; + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_CONST; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_GROUP; @@ -689,6 +637,14 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MOD GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_WEIGHT_ADDR; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_VAR_ADDR; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_VAR_SIZE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_INDEX_OP_NAME; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_CORE_TYPE; + // Public attribute GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IMPLY_TYPE; @@ -740,159 +696,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ATOMIC_INDEX; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_VAR_ADDR; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_LABEL; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_VAR_SIZE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_INDEX_OP_NAME; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_CORE_TYPE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_VALUE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_OFFSET; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_VALUE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_OFFSET; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_VALUE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_OFFSET; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_VALUE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_OFFSET; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_VALUE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_OFFSET; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_TYPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_OFFSET; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_VALUE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_OFFSET; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_VALUE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_OFFSET; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_VALUE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_OFFSET; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_VALUE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_OFFSET; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_VALUE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_OFFSET; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_VALUE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_OFFSET; - -// L2_normalize -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string L2_NORMALIZE_ATTR_AXIS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string L2_NORMALIZE_ATTR_EPS; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_WINDOW; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_CEIL_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_DATA_MODE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_GLOBAL_POOLING; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_NAN_OP; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_PAD_MOD; -// HCOM -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_ROOT_RANK; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_REDUCE_TYPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_RANK_SIZE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_REDUCTION; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_GROUP; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SR_TAG; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SRC_RANK; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_DEST_RANK; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_FUSION; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SHAPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_DATA_TYPE; - -// Log time stamp -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_TIME_STAMP_LOGID; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_TIME_STAMP_NOTIFY; -// SpaceToDepth/DepthToSpace -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BLOCK_SIZE; - -// SparseSoftmaxCrossEntropyWithLogits -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPARSE_SOFT_MAX_ATTR_TLABLES; - -// MaxPoolGradWithArgmax -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MAX_POOL_GRAD_OUTPUT_SHAPE; - -// AvgPoolGrad -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string AVG_POOL_GRAD_OUTPUT_SHAPE; - -// Varible -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_FORMAT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_NAME; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_FRACTALZ_FORMAT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_4D_FORMAT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_5D_FORMAT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_DATA_TYPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IN_NAME; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IN_INDEX; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_OUT_INDEX; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SHAPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HALF_VAR_NAME_END; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_CONTAINER; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SHARED_NAME; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_DTYPE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_ADDR_OFFSET; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IN_INDEX_KEY; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_OUT_INDEX_KEY; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SRC_VAR_NAME; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_SAVE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_RESTORE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_BROADCAST; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_SRC_VAR_NAME; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_PRE_PEER_OUT_INDEX; - -// Assign -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ASSIGN_VALIDATE_SHAPE; - -// ShapeN -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHAPEN_ATTR_N; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHAPEN_ATTR_IN_TYPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHAPEN_ATTR_OUT_TYPE; - -// Space2bacth batch2space -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCH_SPACE_ATTR_BLOCK; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCH_SPACE_ATTR_PADDING; -// Depth_to_space space_to_depth -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE; -// FakeQuantWithMinMaxVars -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FakeQuantWithMinMaxVars_ATTR_MAX; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FakeQuantWithMinMaxVars_ATTR_MIN; -// Mobilenet_ssd_conv_fusion -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOXPREDICTOR_BOXES_FUSION; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOXPREDICTOR_SCORES_FUSION; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOXPREDICTOR_FUSION_BOX_TYPE_NUM; - -// Lsh project -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSH_PROJ_TYPE; - -// Control flow -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ITERATORS_PER_LOOP; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TRUE_BRANCH_STREAM; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FLOW_CTRL_NODE_FLAG; - -// GatherV2 attr def -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERV2_ATTR_NAME_TAXIS; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERV2_ATTR_NAME_TINDICES; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERV2_ATTR_NAME_TPARAMS; - -// Reshape attr def -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_NAME_INPUT_DESC; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_NAME_OUTPUT_DESC; - -// Axis attr def -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AXIS_ORG_OP; -// The node link with SparseSoftmaxCrossEntropyWithLogits -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LINK_WITH_SPARE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NET_OUTPUT_FORMAT; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NET_OUTPUT_DATATYPE; -// For constant folding -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NO_NEED_CONSTANT_FOLDING; - // Used for mark the active label list to find stream of activated node GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ACTIVE_LABEL_LIST; @@ -905,6 +708,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM // Control flow GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_SWITCH_COND; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TRUE_BRANCH_STREAM; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ACTIVE_STREAM_LIST; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCHN_PRED_VALUE; @@ -979,33 +783,9 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_N_BATCH_SPILT; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NO_TASK_AND_DUMP_NEEDED; -// functional ops attr -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WHILE_COND; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WHILE_BODY; - // used for label switch GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LABEL_SWITCH_INDEX; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LABEL_SWITCH_LIST; - -// Varible -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_SRC_VAR_NAME; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SRC_VAR_NAME; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_PRE_PEER_OUT_INDEX; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_BROADCAST; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_RESTORE; - -// HCOM -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_ROOT_RANK; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_REDUCE_TYPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_RANK_SIZE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SHAPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_DATA_TYPE; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_DATATYPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_DATATYPE; - -// Dynamic stitch -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DYNAMIC_STITCH_ATTR_NAME_NUM; } // namespace ge #endif // INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_ diff --git a/inc/graph/detail/model_serialize_imp.h b/inc/graph/detail/model_serialize_imp.h index ad4e6475..1d50577c 100644 --- a/inc/graph/detail/model_serialize_imp.h +++ b/inc/graph/detail/model_serialize_imp.h @@ -22,7 +22,7 @@ #include #include #include "graph/anchor.h" -#include "graph/detail/attributes_holder.h" +#include "detail/attributes_holder.h" #include "graph/ge_tensor.h" #include "graph/graph.h" #include "graph/node.h" diff --git a/inc/graph/model.h b/inc/graph/model.h index 38ea501b..464a2401 100644 --- a/inc/graph/model.h +++ b/inc/graph/model.h @@ -25,7 +25,11 @@ #include "graph/ge_attr_value.h" #include "graph/graph.h" +namespace domi { +class ModelHelper; +} namespace ge { +using domi::ModelHelper; using std::map; using std::string; using std::vector; diff --git a/inc/graph/usr_types.h b/inc/graph/usr_types.h index 90e02001..796a70a3 100644 --- a/inc/graph/usr_types.h +++ b/inc/graph/usr_types.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef INC_GRAPH_USR_TYPES_H_ -#define INC_GRAPH_USR_TYPES_H_ +#ifndef INC_EXTERNAL_GRAPH_USR_TYPES_H_ +#define INC_EXTERNAL_GRAPH_USR_TYPES_H_ #include #include @@ -130,4 +130,4 @@ struct UsrQuantizeFactorParams { #undef USR_TYPE_BYTES_DEC } // namespace ge -#endif // INC_GRAPH_USR_TYPES_H_ +#endif // INC_EXTERNAL_GRAPH_USR_TYPES_H_ diff --git a/inc/graph/utils/graph_utils.h b/inc/graph/utils/graph_utils.h index fb979e3e..8066e8b5 100644 --- a/inc/graph/utils/graph_utils.h +++ b/inc/graph/utils/graph_utils.h @@ -262,8 +262,6 @@ class GraphUtils { static graphStatus MoveOutCtrlEdges(NodePtr &src_node, NodePtr &dst_node); static ComputeGraphPtr FindRootGraph(ComputeGraphPtr graph); - - static graphStatus TopologicalSortingByName(const ge::ComputeGraphPtr &compute_graph, vector &node_vec); }; class ComputeGraphBuilder { diff --git a/src/common/graph/CMakeLists.txt b/src/common/graph/CMakeLists.txt index c0f8ccaf..56b68c69 100755 --- a/src/common/graph/CMakeLists.txt +++ b/src/common/graph/CMakeLists.txt @@ -59,6 +59,7 @@ include_directories(${GE_SOURCE_DIR}/inc/graph) include_directories(${GE_SOURCE_DIR}/inc/common) include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/ops) +include_directories(${GE_SOURCE_DIR}/third_party/securec/include) include_directories(${CMAKE_BINARY_DIR}) include_directories(${CMAKE_BINARY_DIR}/proto/ge) include_directories(${GE_SOURCE_DIR}/build) diff --git a/src/common/graph/anchor.cc b/src/common/graph/anchor.cc index f02037e5..0b9eb00a 100644 --- a/src/common/graph/anchor.cc +++ b/src/common/graph/anchor.cc @@ -53,6 +53,7 @@ void Anchor::UnlinkAll() noexcept { if (Unlink(peer_anchor_ptr) != GRAPH_SUCCESS) { GELOGW("unlink peer_anchor_ptr failed."); } + } while (!peer_anchors_.empty()); } } diff --git a/src/common/graph/compute_graph.cc b/src/common/graph/compute_graph.cc index 2dcc7a54..a35747d4 100644 --- a/src/common/graph/compute_graph.cc +++ b/src/common/graph/compute_graph.cc @@ -54,34 +54,17 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t ComputeGraph::GetAllNodesS return s; } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor ComputeGraph::GetAllNodes() const { - if (sub_graph_.empty()) { - return Vistor(shared_from_this(), nodes_); - } - - std::vector all_nodes; - std::deque candidates; - - candidates.insert(candidates.begin(), nodes_.begin(), nodes_.end()); - - while (!candidates.empty()) { - NodePtr node = candidates.front(); - all_nodes.emplace_back(node); - candidates.pop_front(); - - OpDescPtr op_desc = node->GetOpDesc(); - if (op_desc == nullptr) { + vector all_nodes(nodes_.size()); + (void)std::copy(nodes_.begin(), nodes_.end(), all_nodes.begin()); + for (const auto &sub_graph : sub_graph_) { + if (sub_graph == nullptr) { + GELOGW("sub graph is nullptr"); continue; } - - const auto &subgraph_names = op_desc->GetSubgraphInstanceNames(); - for (auto name_iter = subgraph_names.rbegin(); name_iter != subgraph_names.rend(); ++name_iter) { - auto subgraph = GetSubgraph(*name_iter); - if (subgraph != nullptr) { - candidates.insert(candidates.begin(), subgraph->nodes_.begin(), subgraph->nodes_.end()); - } + for (const auto &node : sub_graph->GetAllNodes()) { + all_nodes.push_back(node); } } - return Vistor(shared_from_this(), all_nodes); } size_t ComputeGraph::GetDirectNodesSize() const { return nodes_.size(); } @@ -619,7 +602,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InsertE graphStatus ComputeGraph::DFSTopologicalSorting(std::vector &node_vec, std::map &map_in_edge_num, std::vector &stack) { - GELOGI("Runing_Dfs_Sort: %s", name_.c_str()); + GELOGI("Runing_Dfs_Sort"); // Record the number of non data nodes but no input nodes GE_CHK_BOOL_EXEC(SortNodes(stack, map_in_edge_num) == GRAPH_SUCCESS, return GRAPH_FAILED, "sort nodes failed"); @@ -664,7 +647,7 @@ graphStatus ComputeGraph::DFSTopologicalSorting(std::vector &node_vec, graphStatus ComputeGraph::BFSTopologicalSorting(std::vector &node_vec, std::map &map_in_edge_num, std::deque &stack) { - GELOGI("Runing_Bfs_Sort: %s", name_.c_str()); + GELOGI("Runing_Bfs_Sort"); std::vector stack_input; std::map breadth_node_map; // Record the number of non data nodes but no input nodes @@ -752,7 +735,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::Topolog use_BFS = true; } } else { - GELOGW("OPTION_GRAPH_RUN_MODE not set, use BFSTopologicalSorting by default."); + GELOGW("Get OPTION_GRAPH_RUN_MODE failed, use BFSTopologicalSorting by default."); } if (use_BFS) { diff --git a/src/common/graph/format_refiner.cc b/src/common/graph/format_refiner.cc index 2230dc1b..04294180 100644 --- a/src/common/graph/format_refiner.cc +++ b/src/common/graph/format_refiner.cc @@ -66,6 +66,7 @@ graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std anchor_points.clear(); // Get all anchor point nodes and switch nodes for (const auto &node_ptr : graph->GetAllNodes()) { + std::vector is_node_set_format; if (node_ptr == nullptr) { return GRAPH_FAILED; } diff --git a/src/common/graph/ge_attr_define.cc b/src/common/graph/ge_attr_define.cc index 961d3bc4..139bb4f3 100644 --- a/src/common/graph/ge_attr_define.cc +++ b/src/common/graph/ge_attr_define.cc @@ -42,8 +42,6 @@ const std::string ATTR_NAME_BIAS = "bias"; const std::string ATTR_NAME_BIAS_TERM = "bias_term"; -const std::string ATTR_NAME_HAS_BIAS_VALUE = "has_bias_value"; - const std::string ATTR_NAME_PAD = "pad"; const std::string ATTR_NAME_PADS = "pad"; @@ -85,7 +83,6 @@ const std::string ATTR_NAME_LRN_BETA = "lrn_beta"; const std::string ATTR_NAME_AXIS = "axis"; const std::string ATTR_NAME_BROADCAST = "broadcast"; -const std::string ATTR_NAME_OUTPUT = "output"; const std::string ATTR_NAME_OUTPUT_NUM = "output_num"; const std::string ATTR_NAME_TIDX = "t_idx"; @@ -106,13 +103,6 @@ const std::string ATTR_NAME_TSHAPE = "Tshape"; const std::string ATTR_NAME_NAN_OPT = "nan_opt"; const std::string ATTR_NAME_AIPP = "aipp"; -const std::string NEW_AIPP_CONV_OP = "new_conv_op_for_aipp"; - -const std::string ATTR_NAME_SESSION_GRAPH_ID = "_session_graph_id"; - -const std::string ATTR_NAME_MULTISHAPE_BATCHLIST = "multi_shape_batchlist"; -const std::string ATTR_NAME_MULTISHAPE_BATCHLIST_SIZE = "multi_shape_batchlist_size"; -const std::string ATTR_MODEL_BATCH_NUM = "batch_num"; const std::string ATTR_NAME_INPUT_FORMAT = "input_format"; const std::string ATTR_NAME_OUTPUT_FORMAT = "output_format"; @@ -121,7 +111,6 @@ const std::string ATTR_NAME_FRAMEWORK_NODE_DEF = "node_def"; const std::string ATTR_NAME_FRAMEWORK_OP_DEF = "op_def"; const std::string ATTR_NAME_FRAMEWORK_FWK_TYPE = "framework_type"; const std::string ATTR_NAME_FRAMEWORK_FUNC_DEF = "func_def"; -const std::string ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE = "original_type"; const std::string ATTR_NAME_INPUT_TENSOR_DESC = "input_tensor_desc"; const std::string ATTR_NAME_OUTPUT_TENSOR_DESC = "output_tensor_desc"; @@ -133,11 +122,9 @@ const std::string ATTR_NAME_WEIGHTS = "value"; const std::string ATTR_NAME_WEIGHTS_DATA = "weights_data"; const std::string ATTR_NAME_BROACAST_REAL_DIM_CNT = "broacast_real_dim_cnt"; const std::string ATTR_NAME_DIM_ALIGN = "dim_align"; -const std::string ATTR_NAME_STREAM_LABEL = "_stream_label"; -const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG = "need_stream_cycle_event"; -const std::string ATTR_NAME_RTSWITCH_RECV_EVENT_ID = "rtswitch_event_id"; -const std::string ATTR_NAME_AUTOMIC_ADD_START = "automic_add_addr_start"; -const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE = "automic_add_mem_size"; +const std::string ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE = "original_type"; + +const std::string ATTR_NAME_SESSION_GRAPH_ID = "_session_graph_id"; // To be deleted const std::string ATTR_TO_BE_DELETED = "to_be_deleted"; @@ -151,13 +138,15 @@ const std::string SSD_MBOX_OCR_FUSION = "permute_flatten_ocr_fusion"; const std::string SSD_MBOX_FUSION_BOX_TYPE_NUM = "ssd_mbox_fusion_box_type_num"; const std::string SSD_RESHAPE_SLICE_CONCAT_FUSION = "reshape_slice_concat_fusion"; +const std::string SSD_PRIORBOX_CONCAT = "ssd_mbox_conf_priorbox_concat_flag"; + // Refinedet const std::string REFINEDET_MBOX_LOC_FUSION = "permute_flatten_fusion"; - +const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION = "reshape_slice_concat_fusion"; const std::string REFINEDET_MBOX_CONF_FUSION = "permute_flatten_reshape_flatten_fusion"; const std::string REFINEDET_MBOX_FUSION_BOX_TYPE_NUM = "ssd_mbox_fusion_box_type_num"; -const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION = "reshape_slice_concat_fusion"; -const std::string SSD_PRIORBOX_CONCAT = "ssd_mbox_conf_priorbox_concat_flag"; +const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE = "variance"; +const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM = "variance_num"; // _Arg const std::string ATTR_NAME_INDEX = "index"; @@ -247,30 +236,6 @@ const std::string BATCHNORM_ATTR_ESTIMATED_MEAN = "estimated_mean"; const std::string BATCHNORM_ATTR_ESTIMATED_VARIANCE = "estimated_variance"; const std::string BATCHNORM_ATTR_SCALE = "scale"; const std::string BATCHNORM_ATTR_BIAS = "bias"; -const std::string BATCHNORM_ATTR_DATA_FORMAT = "data_format"; -const std::string BATCHNORM_ATTR_IS_TRAINING = "is_training"; -const std::string BATCHNORM_ATTR_IS_TRAINING_FUSION = "is_training_fusion"; - -// huberloss -const std::string HUBER_LOSS_ATTR_DELTA = "delta"; - -// SSDRealDivTileMul -const std::string SSD_REAL_DIV_TILE_MUL_ATTR_TILE_PARA = "tilepara"; - -// SSDSumMulRealDivMean -const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_REDUCTION_INDICES = "reduction_indices"; -const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_AXIS = "axis"; -const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_MEAN_PARA = "mean_para"; -const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_HAS_SUM = "has_sum"; - -// ConcatFive2Four -// ConcatFour2Five -const std::string SSD_BOX_TYPE_NUM = "box_type_num"; -const std::string SSD_CLASS_NUM = "class_num"; -const std::string TRANS_FOR_LOSS_MODE = "trans_for_loss_mode"; -const std::string SSD_FEATURE_MAP_SIZE = "feature_map_size"; -const std::string SSD_FEATURE_MAP_HIGH = "feature_map_high"; -const std::string SSD_FEATURE_MAP_WIDTH = "feature_map_width"; // Scale const std::string SCALE_ATTR_SCALE = "scale"; @@ -375,7 +340,6 @@ const std::string SOFTMAX_ATTR_AXIS = "axis"; // Permute const std::string PERMUTE_ATTR_ORDER = "order"; -const std::string PERMUTE_ATTR_PERM = "perm"; // SSD Normalize const std::string SSDNORMALIZE_ATTR_ACCROSS_SPATIAL = "across_spatial"; @@ -403,10 +367,6 @@ const std::string SSD_PRIOR_BOX_ATTR_ASPECT_RATIO_NUM = "aspect_ratio_num"; const std::string SSD_PRIOR_BOX_ATTR_VARIANCE = "variance"; const std::string SSD_PRIOR_BOX_ATTR_VARIANCE_NUM = "variance_num"; -// RefinedetDetectionOutput -const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM = "variance_num"; -const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE = "variance"; - // PRelu const std::string PRELU_ATTR_CHANNEL_SHARED = "channel_shared"; @@ -420,16 +380,11 @@ const std::string POWER_ATTR_NAME_POWER = "power"; const std::string POWER_ATTR_NAME_SCALE = "scale"; const std::string POWER_ATTR_NAME_SHIFT = "shift"; -// log -const std::string LOG_ATTR_NAME_SCALE = "scale"; -const std::string LOG_ATTR_NAME_SHIFT = "shift"; -const std::string LOG_ATTR_NAME_BASE = "base"; // Pack const std::string PACK_ATTR_NAME_NUM = "N"; // Unpack const std::string UNPACK_ATTR_NAME_NUM = "num"; -const std::string DYNAMIC_STITCH_ATTR_NAME_NUM = "DynamicStitchN_"; // Gathernd const std::string GATHERND_ATTR_NAME_TINDICES = "Tindices"; const std::string GATHERND_ATTR_NAME_TPARAMS = "Tparams"; @@ -439,13 +394,6 @@ const std::string ARGMAX_ATTR_NAME_TOPK = "topk"; const std::string ARGMAX_ATTR_NAME_REDUCESIZE = "reduce_size"; const std::string ARGMAX_ATTR_NAME_REDUCESTRIDE = "reduce_stride"; const std::string ARGMAX_ATTR_NAME_OUTMAX = "outmaxval"; -const std::string ARGMAX_ATTR_NAME_AXIS = "axis"; -const std::string ARGMAX_ATTR_NAME_AXISTYPE = "axis_type"; -const std::string ARGMAX_ATTR_NAME_KEEPDIMS = "keep_dims"; - -// upsample -const std::string UPSAMPLE_ATTR_NAME_SCALE_H = "scale_h"; -const std::string UPSAMPLE_ATTR_NAME_SCALE_W = "scale_w"; // Relu const std::string ATTR_NAME_NEGATIVE_SLOPE = "negative_slope"; @@ -485,7 +433,6 @@ const std::string ROIALIGN_ATTR_SPATIAL_SCALE = "spatial_scale"; const std::string ROIALIGN_ATTR_SAMPLING_RATIO = "sampling_ratio"; const std::string ROIALIGN_ATTR_NAME_POOLED_H = "pooled_h"; const std::string ROIALIGN_ATTR_NAME_POOLED_W = "pooled_w"; -const std::string ROIALIGN_ATTR_NAME_TF = "roialign_tf"; // Generate_rpn_proposal const std::string GENERATE_RPN_PROPOSAL_ATTR_PRE_NMS_TOPK = "pre_nms_topk"; @@ -584,42 +531,19 @@ const std::string CONV_GRAD_FILTER_OUTPUT_SHAPE = "conv_grad_filter_output_shape const std::string CONV_GRAD_INPUT_OUTPUT_SHAPE = "conv_grad_input_output_shape"; // Rnn -const std::string RNN_TENSORFLOW = "rnn_tensorflow"; -const std::string RNN_MODE_STATIC = "rnn_static"; -const std::string MUTI_RNN = "multi_rnn"; -const std::string CNN_RNN = "cnn_rnn"; const std::string RNN_MODE_ = "rnn_"; - +const std::string CNN_RNN = "cnn_rnn"; +const std::string MUTI_RNN = "multi_rnn"; const std::string CELL_MODE = "mode"; const std::string LSTM_CELL = "lstm_cell"; const std::string GRU_CELL = "gru_cell"; const std::string RNN_HT = "ht"; const std::string RNN_XT_HT = "xt_ht"; const std::string RNN_BATCH_SIZE = "batch_size"; -const std::string LSTM_CELL_CLIP = "lstm_cell_clip"; -const std::string LSTM_PROJ_CLIP = "lstm_proj_clip"; -const std::string LSTM_ACTIVATE = "lstm_activate"; -const std::string LSTM_OUT_MAP = "lstm_out_map"; -const std::string LSTM_OUT_MODE = "lstm_out_mode"; -const std::string LSTM_STATE_OUT_MODE = "lstm_state_out_mode"; -const std::string LSTM_TIME_MAJOR = "lstm_time_major"; -const std::string LSTM_IS_INPUT_PRE_PROCESS = "lstm_is_input_pre_process"; // Upsample const std::string UPSAMPLE_ATTR_NAME_SCALE = "scale"; -// PadV2 -const std::string PADV2_ATTR_NAME_MODE = "mode"; -const std::string PADV2_ATTR_NAME_PADS = "paddings"; -const std::string PADV2_ATTR_NAME_T = "T"; -const std::string PADV2_ATTR_NAME_PAD_FORMAT = "pad_format"; -const std::string PADV2_ATTR_NAME_CONST_VALUE = "const_value"; - -// MirrorPad -const std::string MIRRORPAD_ATTR_NAME_MODE = "mode"; -const std::string MIRRORPAD_ATTR_NAME_PADS = "paddings"; -const std::string MIRRORPAD_ATTR_NAME_PAD_FORMAT = "pad_format"; -const std::string MIRRORPAD_ATTR_NAME_CONST_VALUE = "const_value"; // Filler const std::string FILLER_TYPE = "filler_type"; const std::string FILLER_VALUE = "filler_value"; @@ -630,6 +554,9 @@ const std::string SHUFFLE_CHANNEL_GROUP = "group"; // TopKV2 const std::string TOPKV2_ATTR_K = "k"; +const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE = "block_size"; +const std::string L2_NORMALIZE_ATTR_EPS = "eps"; + // Calibaration const std::string STRIDE_H_INDEX = "STRIDE_H_INDEX"; const std::string STRIDE_W_INDEX = "STRIDE_W_INDEX"; @@ -733,121 +660,6 @@ const std::string TARGET_TYPE_TINY = "TINY"; const std::string TARGET_TYPE_LITE = "LITE"; -// l2_normalize -const std::string L2_NORMALIZE_ATTR_AXIS = "axis"; -const std::string L2_NORMALIZE_ATTR_EPS = "eps"; - -const std::string POOL_PARAMA_ATTR_WINDOW = "window"; -const std::string POOL_PARAMA_ATTR_CEIL_MODE = "ceil_mode"; -const std::string POOL_PARAMA_ATTR_DATA_MODE = "data_mode"; -const std::string POOL_PARAMA_ATTR_GLOBAL_POOLING = "global_pooling"; -const std::string POOL_PARAMA_ATTR_NAN_OP = "nan_opt"; -const std::string POOL_PARAMA_ATTR_PAD_MOD = "pad_mode"; - -// HCOM -const std::string HCOM_ATTR_ROOT_RANK = "root_rank"; -const std::string HCOM_ATTR_RANK_SIZE = "rank_size"; - -const std::string HCOM_ATTR_REDUCE_TYPE = "reduction"; -const std::string HCOM_ATTR_GROUP = "group"; -const std::string HCOM_ATTR_SR_TAG = "sr_tag"; -const std::string HCOM_ATTR_SRC_RANK = "src_rank"; -const std::string HCOM_ATTR_DEST_RANK = "dest_rank"; -const std::string HCOM_ATTR_FUSION = "fusion"; -const std::string HCOM_ATTR_SHAPE = "shape"; -const std::string HCOM_ATTR_DATA_TYPE = "dtype"; - -// SpaceToDepth/DepthToSpace -const std::string ATTR_NAME_BLOCK_SIZE = "block_size"; - -// SparseSoftmaxCrossEntropyWithLogits -const std::string SPARSE_SOFT_MAX_ATTR_TLABLES = "Tlabels"; - -// MaxPoolGradWithArgmax -const std::string MAX_POOL_GRAD_OUTPUT_SHAPE = "max_pool_grad_output_shape"; - -// AvgPoolGrad -const std::string AVG_POOL_GRAD_OUTPUT_SHAPE = "avg_pool_grad_output_shape"; - -// Pad -const std::string ATTR_PAD_FORMAT = "attr_pad_format"; - -// Varible -const std::string VAR_ATTR_FORMAT = "_var_format"; -const std::string VAR_ATTR_NAME = "var_name"; -const std::string VAR_ATTR_FRACTALZ_FORMAT = "FZ"; -const std::string VAR_ATTR_4D_FORMAT = "4D"; -const std::string VAR_ATTR_5D_FORMAT = "5D"; -const std::string VAR_ATTR_DATA_TYPE = "data_format"; -const std::string VAR_ATTR_VAR_IN_NAME = "var_in_name"; -const std::string VAR_ATTR_VAR_IN_INDEX = "var_in_index"; -const std::string VAR_ATTR_VAR_OUT_INDEX = "var_out_index"; -const std::string VAR_ATTR_SHAPE = "shape"; -const std::string HALF_VAR_NAME_END = "_fp16"; -const std::string VAR_ATTR_INITED = "var_is_inited"; - -const std::string VAR_ATTR_CONTAINER = "container"; -const std::string VAR_ATTR_SHARED_NAME = "shared_name"; -const std::string VAR_ATTR_DTYPE = "dtype"; - -const std::string VAR_ATTR_SRC_VAR_NAME = "_src_var_name"; -const std::string VAR_ATTR_VAR_IS_SAVE = "_var_is_save"; -const std::string VAR_ATTR_VAR_IS_RESTORE = "_var_is_restore"; -const std::string VAR_ATTR_VAR_IS_BROADCAST = "_var_is_broadcast"; -const std::string REF_VAR_SRC_VAR_NAME = "ref_var_src_var_name"; -const std::string REF_VAR_PRE_PEER_OUT_INDEX = "ref_var_pre_peer_out_index"; - -// Assign -const std::string ASSIGN_VALIDATE_SHAPE = "validate_shape"; - -// space2bacth batch2space -const std::string BATCH_SPACE_ATTR_BLOCK = "block"; -const std::string BATCH_SPACE_ATTR_PADDING = "padding"; - -// depth_to_space space_to_depth -const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE = "block_size"; - -// FakeQuantWithMinMaxVars -const std::string FakeQuantWithMinMaxVars_ATTR_MAX = "max"; -const std::string FakeQuantWithMinMaxVars_ATTR_MIN = "min"; - -// mobilenet_ssd_conv_fusion -const std::string SSD_BOXPREDICTOR_BOXES_FUSION = "ssd_boxpredictor_boxes_fusion"; -const std::string SSD_BOXPREDICTOR_SCORES_FUSION = "ssd_boxpredictor_scores_fusion"; -const std::string SSD_BOXPREDICTOR_FUSION_BOX_TYPE_NUM = "ssd_boxpredictor_fusion_box_type_num"; - -// lsh project -const std::string LSH_PROJ_TYPE = "lsh_project_type"; - -// log time stamp -const std::string LOG_TIME_STAMP_LOGID = "logid"; -const std::string LOG_TIME_STAMP_NOTIFY = "notify"; - -// ShapeN -const std::string SHAPEN_ATTR_N = "N"; -const std::string SHAPEN_ATTR_IN_TYPE = "in_type"; -const std::string SHAPEN_ATTR_OUT_TYPE = "dtype"; - -// GatherV2 attr def -const std::string GATHERV2_ATTR_NAME_TAXIS = "Taxis"; -const std::string GATHERV2_ATTR_NAME_TINDICES = "Tindices"; -const std::string GATHERV2_ATTR_NAME_TPARAMS = "Tparams"; - -// Reshape attr def -const std::string RESHAPE_ATTR_NAME_INPUT_DESC = "input_desc_reshape"; -const std::string RESHAPE_ATTR_NAME_OUTPUT_DESC = "output_desc_reshape"; - -// axis attr def -const std::string ATTR_NAME_AXIS_ORG_OP = "axis_org_op"; - -const std::string ATTR_NAME_LINK_WITH_SPARE = "link_with_sparse"; - -const std::string ATTR_NAME_NET_OUTPUT_FORMAT = "net_output_format"; -const std::string ATTR_NAME_NET_OUTPUT_DATATYPE = "net_output_datatype"; - -// For constant folding -const std::string ATTR_NO_NEED_CONSTANT_FOLDING = "no_need_constant_folding"; - const std::string ATTR_NAME_CONTINUOUS_INPUT = "continuous_input"; const std::string ATTR_NAME_CONTINUOUS_OUTPUT = "continuous_output"; @@ -882,8 +694,6 @@ const std::string ATTR_NAME_STREAM_SWITCH_COND = "switch_condition"; const std::string ATTR_NAME_TRUE_BRANCH_STREAM = "true_branch_stream"; const std::string ATTR_NAME_ACTIVE_STREAM_LIST = "active_stream_list"; const std::string ATTR_NAME_SWITCHN_PRED_VALUE = "switch_pred_value"; -const std::string ATTR_NAME_ITERATORS_PER_LOOP = "iterations_per_loop"; -const std::string ATTR_NAME_FLOW_CTRL_NODE_FLAG = "is_flow_ctrl_node"; const std::string ATTR_NAME_SWITCH_BRANCH_NODE_LABEL = "_switch_branch_node_label"; const std::string ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG = "_switch_true_branch_flag"; @@ -954,14 +764,7 @@ const std::string ATTR_NAME_DATA_DUMP_ORIGIN_OUTPUT_INDEX = "_datadump_origin_ou const std::string ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT = "_datadump_origin_format"; const std::string ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE = "_datadump_origin_data_type"; -// functional ops attr -const std::string ATTR_NAME_WHILE_COND = "cond"; -const std::string ATTR_NAME_WHILE_BODY = "body"; - // used for label switch const std::string ATTR_NAME_LABEL_SWITCH_INDEX = "_label_switch_index"; const std::string ATTR_NAME_LABEL_SWITCH_LIST = "_label_switch_list"; - -const std::string ATTR_NAME_INPUT_DATATYPE = "input_datatype"; -const std::string ATTR_NAME_OUTPUT_DATATYPE = "output_datatype"; } // namespace ge diff --git a/src/common/graph/shape_refiner.cc b/src/common/graph/shape_refiner.cc index 321786a9..da4388f9 100644 --- a/src/common/graph/shape_refiner.cc +++ b/src/common/graph/shape_refiner.cc @@ -21,8 +21,9 @@ #include #include #include - +#include "framework/common/types.h" #include "graph/utils/graph_utils.h" + #include "debug/ge_log.h" #include "debug/ge_op_types.h" #include "external/graph/operator.h" diff --git a/src/common/graph/utils/graph_utils.cc b/src/common/graph/utils/graph_utils.cc index 1886ee66..c5e45516 100644 --- a/src/common/graph/utils/graph_utils.cc +++ b/src/common/graph/utils/graph_utils.cc @@ -28,7 +28,6 @@ #include #include #include -#include #include "./ge_context.h" #include "debug/ge_util.h" @@ -2000,60 +1999,4 @@ void PartialGraphBuilder::BuildExistNodes(graphStatus &error_code, std::string & GELOGD("Build exist nodes succ."); } - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -GraphUtils::TopologicalSortingByName(const ge::ComputeGraphPtr &compute_graph, vector &node_vec) { - std::vector stack_input; - std::map map_in_edge_num; - graphStatus ret = compute_graph->SortNodes(stack_input, map_in_edge_num); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Sort nodes failed."); - return GRAPH_FAILED; - } - const size_t non_user_input_index = stack_input.size() - compute_graph->inputs_order_.size() - 1; - std::sort(stack_input.begin(), stack_input.begin() + non_user_input_index, - [](const NodePtr &a, const NodePtr &b) -> bool { return (a->GetName() > b->GetName()); }); - - std::queue stack; - NodePtr cur_node = nullptr; - std::map name_node_map; - vector nodes_name; - while (!stack_input.empty() || !stack.empty()) { - if (!stack.empty()) { - cur_node = stack.front(); - stack.pop(); - } else { - cur_node = stack_input.back(); - stack_input.pop_back(); - } - node_vec.emplace_back(cur_node); - compute_graph->CollectBreadthOutNode(cur_node, map_in_edge_num, name_node_map); - for (const auto &iter : name_node_map) { - nodes_name.emplace_back(iter.first); - } - std::sort(nodes_name.begin(), nodes_name.end()); - for (const auto &iter : nodes_name) { - stack.push(name_node_map[iter]); - } - name_node_map.clear(); - nodes_name.clear(); - } - // If they are not equal, there is a closed loop - if (node_vec.size() != compute_graph->nodes_.size()) { - std::set itered_nodes_set; - for (auto &node : node_vec) { - itered_nodes_set.insert(node.get()); - } - GE_LOGE("Failed to do topo sorting total %zu, itered %zu, exist closed loop in graph.", - compute_graph->nodes_.size(), node_vec.size()); - for (auto &node : compute_graph->nodes_) { - if (itered_nodes_set.count(node.get()) == 0) { - GE_LOGE("The node %s does not itered when topological sorting", node->GetName().c_str()); - } - } - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; -} - } // namespace ge diff --git a/src/common/graph/utils/tensor_utils.cc b/src/common/graph/utils/tensor_utils.cc index 072673c0..7b8ad3cd 100644 --- a/src/common/graph/utils/tensor_utils.cc +++ b/src/common/graph/utils/tensor_utils.cc @@ -282,7 +282,6 @@ static graphStatus CalcTensorElementCnt(const std::vector &dims, Format case FORMAT_FRACTAL_Z_3D: case FORMAT_FRACTAL_Z_3D_TRANSPOSE: case FORMAT_NDC1HWC0: - case FORMAT_FRACTAL_Z_C04: graph_status = CalcElementCntByDims(dims, element_cnt); break; default: diff --git a/src/ge/CMakeLists.txt b/src/ge/CMakeLists.txt index e2996cce..028baf60 100755 --- a/src/ge/CMakeLists.txt +++ b/src/ge/CMakeLists.txt @@ -41,9 +41,9 @@ include_directories(${GE_SOURCE_DIR}/inc/external/graph) include_directories(${GE_SOURCE_DIR}/inc/framework) include_directories(${GE_SOURCE_DIR}/inc/framework/common) include_directories(${GE_SOURCE_DIR}/inc/runtime) -include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib) include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/cce) +include_directories(${GE_SOURCE_DIR}/third_party/securec/include) include_directories(${CMAKE_BINARY_DIR}) include_directories(${CMAKE_BINARY_DIR}/proto/ge) @@ -55,7 +55,6 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "common/formats/utils/formats_trans_utils.cc" "common/fp16_t.cc" "common/ge/plugin_manager.cc" - "common/helper/model_cache_helper.cc" "common/profiling/profiling_manager.cc" "engine_manager/dnnengine_manager.cc" "ge_local_engine/engine/host_cpu_engine.cc" @@ -93,7 +92,6 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "graph/load/new_model_manager/task_info/kernel_task_info.cc" "graph/load/new_model_manager/task_info/label_goto_task_info.cc" "graph/load/new_model_manager/task_info/label_set_task_info.cc" - "graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc" "graph/load/new_model_manager/task_info/memcpy_async_task_info.cc" "graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" "graph/load/new_model_manager/task_info/stream_active_task_info.cc" @@ -196,7 +194,6 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "graph/passes/prevent_gradient_pass.cc" "graph/passes/print_op_pass.cc" "graph/passes/prune_pass.cc" - "graph/passes/replace_with_empty_const_pass.cc" "graph/passes/reshape_remove_pass.cc" "graph/passes/resource_pair_add_control_pass.cc" "graph/passes/resource_pair_remove_control_pass.cc" @@ -271,7 +268,6 @@ file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "common/formats/utils/formats_trans_utils.cc" "common/fp16_t.cc" "common/ge/plugin_manager.cc" - "common/helper/model_cache_helper.cc" "common/profiling/profiling_manager.cc" "engine_manager/dnnengine_manager.cc" "ge_local_engine/engine/host_cpu_engine.cc" @@ -308,7 +304,6 @@ file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "graph/load/new_model_manager/task_info/kernel_task_info.cc" "graph/load/new_model_manager/task_info/label_goto_task_info.cc" "graph/load/new_model_manager/task_info/label_set_task_info.cc" - "graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc" "graph/load/new_model_manager/task_info/memcpy_async_task_info.cc" "graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" "graph/load/new_model_manager/task_info/stream_active_task_info.cc" @@ -409,7 +404,6 @@ file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "graph/passes/prevent_gradient_pass.cc" "graph/passes/print_op_pass.cc" "graph/passes/prune_pass.cc" - "graph/passes/replace_with_empty_const_pass.cc" "graph/passes/reshape_remove_pass.cc" "graph/passes/resource_pair_add_control_pass.cc" "graph/passes/resource_pair_remove_control_pass.cc" @@ -474,7 +468,7 @@ target_link_libraries(ge_compiler ${slog} ${mmpa} ${msprof} - ${runtime_compiler} + ${runtime} ${resouce} rt dl) diff --git a/src/ge/client/CMakeLists.txt b/src/ge/client/CMakeLists.txt index a99b4eb1..c1111d8e 100755 --- a/src/ge/client/CMakeLists.txt +++ b/src/ge/client/CMakeLists.txt @@ -46,6 +46,7 @@ include_directories(${GE_SOURCE_DIR}/inc/framework) include_directories(${GE_SOURCE_DIR}/inc/graph) include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/cce) +include_directories(${GE_SOURCE_DIR}/third_party/securec/include) include_directories(${CMAKE_BINARY_DIR}) include_directories(${CMAKE_BINARY_DIR}/proto/ge) diff --git a/src/ge/client/ge_api.cc b/src/ge/client/ge_api.cc index 679b155b..9b9e5568 100644 --- a/src/ge/client/ge_api.cc +++ b/src/ge/client/ge_api.cc @@ -32,6 +32,8 @@ using domi::GetContext; using domi::OpRegistry; +using domi::RealPath; +using domi::StringUtils; using std::map; using std::string; using std::vector; diff --git a/src/ge/common/CMakeLists.txt b/src/ge/common/CMakeLists.txt index 1dce0b4d..a637888e 100755 --- a/src/ge/common/CMakeLists.txt +++ b/src/ge/common/CMakeLists.txt @@ -41,7 +41,6 @@ file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc" "formats/format_transfers/format_transfer_nc1hwc0_nchw.cc" "formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc" - "formats/format_transfers/format_transfer_nchw_fz_c04.cc" "formats/format_transfers/format_transfer_nchw_nc1hwc0.cc" "formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc" "formats/format_transfers/format_transfer_transpose.cc" @@ -80,6 +79,7 @@ include_directories(${GE_SOURCE_DIR}/inc/framework) include_directories(${GE_SOURCE_DIR}/inc/graph) include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/cce) +include_directories(${GE_SOURCE_DIR}/third_party/securec/include) include_directories(${CMAKE_BINARY_DIR}) include_directories(${CMAKE_BINARY_DIR}/proto/ge) diff --git a/src/ge/common/auth/file_saver.cc b/src/ge/common/auth/file_saver.cc index daa19448..04638ecf 100644 --- a/src/ge/common/auth/file_saver.cc +++ b/src/ge/common/auth/file_saver.cc @@ -17,6 +17,7 @@ #include "common/auth/file_saver.h" #include + #include #include #include @@ -28,6 +29,8 @@ #include "framework/common/debug/log.h" #include "framework/common/util.h" +using domi::CreateDirectory; +using domi::ModelEncryptType; using ge::ModelBufferData; namespace { @@ -267,4 +270,4 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status FileSaver::SaveToFile(co } return ret; } -} // namespace ge +} // namespace ge diff --git a/src/ge/common/auth/file_saver.h b/src/ge/common/auth/file_saver.h index d415746d..a4473050 100644 --- a/src/ge/common/auth/file_saver.h +++ b/src/ge/common/auth/file_saver.h @@ -26,26 +26,30 @@ #include "graph/buffer.h" #include "mmpa/mmpa_api.h" +using domi::ModelFileHeader; +using domi::ModelPartition; +using domi::ModelPartitionTable; + struct PROC_PARAM { uint8_t *model_name; - // ISV Ek buffer + /* ISV Ek buffer */ uint8_t *model_key; uint32_t model_key_len; - // ISV root certificate buffer + /* ISV root certificate buffer */ uint8_t *root_cert; uint32_t root_cert_len; - // ISV private key buffer + /* ISV private key buffer */ uint8_t *pri_key; uint32_t pri_key_len; - // Raw AI Module Image buffer + /* Raw AI Module Image buffer */ uint8_t *ai_image; uint32_t ai_image_len; - // ISV HW key buffer + /* ISV HW key buffer */ uint8_t *hw_key; uint32_t hw_key_len; }; @@ -62,11 +66,11 @@ using std::string; class FileSaver { public: - /// - /// @ingroup domi_common - /// @brief save model, no encryption - /// @return Status result - /// + /** + * @ingroup domi_common + * @brief save model, no encryption + * @return Status result + */ static Status SaveToFile(const string &file_path, const ge::ModelData &model, const ModelFileHeader *model_file_header = nullptr); @@ -80,26 +84,26 @@ class FileSaver { static Status SaveToFile(const string &file_path, const void *data, int len); protected: - /// - /// @ingroup domi_common - /// @brief Check validity of the file path - /// @return Status result - /// + /** + * @ingroup domi_common + * @brief Check validity of the file path + * @return Status result + */ static Status CheckPath(const string &file_path); static Status WriteData(const void *data, uint32_t size, int32_t fd); static Status OpenFile(int32_t &fd, const std::string &file_path); - /// - /// @ingroup domi_common - /// @brief save model to file - /// @param [in] file_path file output path - /// @param [in] file_header file header info - /// @param [in] data model data - /// @param [in] len model length - /// @return Status result - /// + /** + * @ingroup domi_common + * @brief save model to file + * @param [in] file_path file output path + * @param [in] file_header file header info + * @param [in] data model data + * @param [in] len model length + * @return Status result + */ static Status SaveWithFileHeader(const string &file_path, const ModelFileHeader &file_header, const void *data, int len); diff --git a/src/ge/common/context/ctx.cc b/src/ge/common/context/ctx.cc index f6ae364d..34ba5d25 100644 --- a/src/ge/common/context/ctx.cc +++ b/src/ge/common/context/ctx.cc @@ -16,7 +16,6 @@ #include "framework/omg/omg_inner_types.h" -using ge::OmgContext; namespace domi { FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OmgContext &GetContext() { static OmgContext context; diff --git a/src/ge/common/formats/format_transfers/datatype_transfer.cc b/src/ge/common/formats/format_transfers/datatype_transfer.cc index e5d21307..bac3a178 100644 --- a/src/ge/common/formats/format_transfers/datatype_transfer.cc +++ b/src/ge/common/formats/format_transfers/datatype_transfer.cc @@ -134,6 +134,10 @@ Status DataTypeTransfer::TransDataType(const CastArgs &args, TransResult &result } auto trans_mode = iter->second; + if (args.src_data_size == 0) { + GELOGE(PARAM_INVALID, "Invalid src data size %zu", args.src_data_size); + return PARAM_INVALID; + } int size = GetSizeByDataType(args.dst_data_type); if (size <= 0) { GELOGE(PARAM_INVALID, "Failed to calc size from data type %s", @@ -145,12 +149,6 @@ Status DataTypeTransfer::TransDataType(const CastArgs &args, TransResult &result return PARAM_INVALID; } size_t total_size = static_cast(args.src_data_size * size); - result.length = total_size; - if (total_size == 0) { - GELOGI("In TransDataType, total_size is zero, has no data."); - return SUCCESS; - } - std::shared_ptr dst(new (std::nothrow) uint8_t[total_size], std::default_delete()); if (dst == nullptr) { GELOGE(OUT_OF_MEMORY, "Failed to alloc the memory for dst buf %zu, data size %zu", total_size, args.src_data_size); @@ -164,6 +162,7 @@ Status DataTypeTransfer::TransDataType(const CastArgs &args, TransResult &result return INTERNAL_ERROR; } result.data = dst; + result.length = total_size; return SUCCESS; } diff --git a/src/ge/common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc b/src/ge/common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc index 40dc749d..28d713b5 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc +++ b/src/ge/common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc @@ -27,9 +27,7 @@ namespace ge { namespace formats { namespace { -bool CheckDataTypeSupported(const DataType &data_type) { - return (data_type == DT_FLOAT || data_type == DT_FLOAT16 || data_type == DT_INT8); -} +bool CheckDataTypeSupported(const DataType &data_type) { return (data_type == DT_FLOAT || data_type == DT_FLOAT16); } Status CheckArgsForC1hwncoc0ToHwcn(const TransArgs &args) { auto src_shape = args.src_shape; @@ -53,11 +51,10 @@ Status CheckArgsForC1hwncoc0ToHwcn(const TransArgs &args) { GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); return PARAM_INVALID; } - auto cube_size = GetCubeSizeByDataType(args.src_data_type); - if (src_shape.at(kC1hwncoc0C1) != (dst_shape.at(kHwcnC) - 1) / cube_size + 1 || + if (src_shape.at(kC1hwncoc0C1) != (dst_shape.at(kHwcnC) - 1) / kCubeSize + 1 || src_shape.at(kC1hwncoc0H) != dst_shape.at(kHwcnH) || src_shape.at(kC1hwncoc0W) != dst_shape.at(kHwcnW) || - src_shape.at(kC1hwncoc0N) != dst_shape.at(kHwcnN) || src_shape.at(kC1hwncoc0Co) != cube_size || - src_shape.at(kC1hwncoc0C0) != cube_size) { + src_shape.at(kC1hwncoc0N) != dst_shape.at(kHwcnN) || src_shape.at(kC1hwncoc0Co) != kCubeSize || + src_shape.at(kC1hwncoc0C0) != kCubeSize) { GELOGE(PARAM_INVALID, "Failed to check relationship between src and dst shape, src shape %s, dst shape %s", ShapeToString(src_shape).c_str(), ShapeToString(dst_shape).c_str()); return PARAM_INVALID; @@ -81,7 +78,6 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, int size auto c0 = args.src_shape.at(kC1hwncoc0C0); auto co = args.src_shape.at(kC1hwncoc0Co); auto c = args.dst_shape.at(kHwcnC); - auto cube_size = GetCubeSizeByDataType(args.src_data_type); int64_t cn = c * n; int64_t wcn = w * cn; int64_t coc0 = co * c0; @@ -97,8 +93,8 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, int size int64_t c_head_addr = w_head_addr + c_idx * n; for (int64_t n_idx = 0; n_idx < n; n_idx++) { int64_t dst_idx = c_head_addr + n_idx; - int64_t c1_idx = c_idx / cube_size; - int64_t c0_idx = c_idx % cube_size; + int64_t c1_idx = c_idx / kCubeSize; + int64_t c0_idx = c_idx % kCubeSize; int64_t co_idx = c0_idx; int64_t src_idx = c1_idx * hwncoc0 + h_idx * wncoc0 + w_idx * ncoc0 + n_idx * coc0 + co_idx * c0 + c0_idx; auto src_offset = src_idx * size; @@ -134,11 +130,6 @@ Status FormatTransferC1hwncoc0Hwcn::TransFormat(const TransArgs &args, TransResu int size = GetSizeByDataType(args.src_data_type); int64_t total_size = GetItemNumByShape(args.dst_shape) * size; if (total_size <= 0) { - int64_t src_size = GetItemNumByShape(args.src_shape); - if (total_size == 0 && src_size == 0) { - result.length = static_cast(total_size); - return SUCCESS; - } GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); return PARAM_INVALID; diff --git a/src/ge/common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc b/src/ge/common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc index dc8e1033..45808fa0 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc +++ b/src/ge/common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc @@ -88,11 +88,6 @@ Status TransFormatDhwckToFz3D(const TransArgs &args, TransResult &result) { dst_size *= dim; } dst_size *= data_size; - if (dst_size == 0) { - result.length = static_cast(dst_size); - return SUCCESS; - } - std::shared_ptr dst(new (std::nothrow) uint8_t[dst_size], std::default_delete()); if (dst == nullptr) { GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", diff --git a/src/ge/common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc b/src/ge/common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc index 11e3d270..86c6935d 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc +++ b/src/ge/common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc @@ -89,11 +89,6 @@ Status TransFormatDhwncToFz3DTranspose(const TransArgs &args, TransResult &resul dst_size *= dim; } dst_size *= data_size; - if (dst_size == 0) { - result.length = static_cast(dst_size); - return SUCCESS; - } - std::shared_ptr dst(new (std::nothrow) uint8_t[dst_size], std::default_delete()); if (dst == nullptr) { GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", diff --git a/src/ge/common/formats/format_transfers/format_transfer_fractal_nz.cc b/src/ge/common/formats/format_transfers/format_transfer_fractal_nz.cc index ff7b84a4..76834437 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_fractal_nz.cc +++ b/src/ge/common/formats/format_transfers/format_transfer_fractal_nz.cc @@ -116,11 +116,6 @@ Status CheckShapeRelation(const TransArgs &args, ShapeVector &hw_shape) { Status TransFormatFromNdToFracNz(const TransArgs &args, TransResult &result, const ShapeVector &hw_shape) { int size = GetSizeByDataType(args.src_data_type); int64_t dst_size = GetItemNumByShape(args.dst_shape) * size; - if (dst_size == 0) { - result.length = static_cast(dst_size); - return SUCCESS; - } - std::shared_ptr dst(new (std::nothrow) uint8_t[dst_size](), std::default_delete()); if (dst == nullptr) { GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", @@ -189,11 +184,6 @@ Status TransFormatFromNdToFracNz(const TransArgs &args, TransResult &result, con Status TransFormatFromFracNzToNd(const TransArgs &args, TransResult &result, const ShapeVector &dst_hw_shape) { int size = GetSizeByDataType(args.src_data_type); int64_t dst_size = GetItemNumByShape(args.dst_shape) * size; - if (dst_size == 0) { - result.length = static_cast(dst_size); - return SUCCESS; - } - std::shared_ptr dst(new (std::nothrow) uint8_t[dst_size], std::default_delete()); if (dst == nullptr) { GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", diff --git a/src/ge/common/formats/format_transfers/format_transfer_fractal_z.cc b/src/ge/common/formats/format_transfers/format_transfer_fractal_z.cc index f3d06496..aedc7589 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_fractal_z.cc +++ b/src/ge/common/formats/format_transfers/format_transfer_fractal_z.cc @@ -119,11 +119,6 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) { int64_t total_ele_cnt = hf_cnt * vf_cnt * fractal_ele_cnt; int size = GetSizeByDataType(args.src_data_type); int64_t dst_size = total_ele_cnt * size; - if (dst_size == 0) { - result.length = static_cast(dst_size); - return SUCCESS; - } - std::shared_ptr dst(new (std::nothrow) uint8_t[dst_size], std::default_delete()); if (dst == nullptr) { GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", @@ -199,11 +194,6 @@ Status TransFormatHwcnToFz(const TransArgs &args, TransResult &result) { dst_size *= dim; } dst_size *= data_size; - if (dst_size == 0) { - result.length = static_cast(dst_size); - return SUCCESS; - } - std::shared_ptr dst(new (std::nothrow) uint8_t[dst_size], std::default_delete()); if (dst == nullptr) { GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", @@ -269,11 +259,6 @@ Status TransFormatNhwcToFz(const TransArgs &args, TransResult &result) { dst_size *= dim; } dst_size *= data_size; - if (dst_size == 0) { - result.length = static_cast(dst_size); - return SUCCESS; - } - std::shared_ptr dst(new (std::nothrow) uint8_t[dst_size], std::default_delete()); if (dst == nullptr) { GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", diff --git a/src/ge/common/formats/format_transfers/format_transfer_fractal_zz.cc b/src/ge/common/formats/format_transfers/format_transfer_fractal_zz.cc index d5507765..be0c3abb 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_fractal_zz.cc +++ b/src/ge/common/formats/format_transfers/format_transfer_fractal_zz.cc @@ -117,11 +117,6 @@ Status CheckShapeRelation(const TransArgs &args, ShapeVector &hw_shape) { Status TransFormatFromNdToFracZz(const TransArgs &args, TransResult &result, const ShapeVector &hw_shape) { int size = GetSizeByDataType(args.src_data_type); int64_t dst_size = GetItemNumByShape(args.dst_shape) * size; - if (dst_size == 0) { - result.length = static_cast(dst_size); - return SUCCESS; - } - std::shared_ptr dst(new (std::nothrow) uint8_t[dst_size](), std::default_delete()); if (dst == nullptr) { GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", @@ -194,11 +189,6 @@ Status TransFormatFromNdToFracZz(const TransArgs &args, TransResult &result, con Status TransFormatFromFracZzToNd(const TransArgs &args, TransResult &result, const ShapeVector &dst_hw_shape) { int size = GetSizeByDataType(args.src_data_type); int64_t dst_size = GetItemNumByShape(args.dst_shape) * size; - if (dst_size == 0) { - result.length = static_cast(dst_size); - return SUCCESS; - } - std::shared_ptr dst(new (std::nothrow) uint8_t[dst_size](), std::default_delete()); if (dst == nullptr) { GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", diff --git a/src/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.cc b/src/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.cc index b0eebcfa..3453c232 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.cc +++ b/src/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.cc @@ -133,12 +133,6 @@ Status FormatTransferFracZHwcn::TransFormat(const TransArgs &args, TransResult & int size = GetSizeByDataType(args.src_data_type); auto total_size = GetItemNumByShape(args.dst_shape) * size; if (total_size <= 0) { - int64_t src_size = GetItemNumByShape(args.src_shape); - if (total_size == 0 && src_size == 0) { - result.length = static_cast(total_size); - return SUCCESS; - } - GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); return PARAM_INVALID; diff --git a/src/ge/common/formats/format_transfers/format_transfer_fracz_nchw.cc b/src/ge/common/formats/format_transfers/format_transfer_fracz_nchw.cc index 9f8d9e39..6f616051 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_fracz_nchw.cc +++ b/src/ge/common/formats/format_transfers/format_transfer_fracz_nchw.cc @@ -133,12 +133,6 @@ Status FormatTransferFracZNchw::TransFormat(const TransArgs &args, TransResult & int size = GetSizeByDataType(args.src_data_type); auto total_size = GetItemNumByShape(args.dst_shape) * size; if (total_size <= 0) { - int64_t src_size = GetItemNumByShape(args.src_shape); - if (total_size == 0 && src_size == 0) { - result.length = static_cast(total_size); - return SUCCESS; - } - GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); return PARAM_INVALID; @@ -146,7 +140,6 @@ Status FormatTransferFracZNchw::TransFormat(const TransArgs &args, TransResult & GELOGD("Begin to trans format from FracZ to NCHW, src shape %s, data type %s, dst shape %s, memory size %ld", ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), ShapeToString(args.dst_shape).c_str(), total_size); - if (GetDstDataAfterTrans(args, result, size, total_size) != SUCCESS) { GELOGE(INTERNAL_ERROR, "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld", ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), diff --git a/src/ge/common/formats/format_transfers/format_transfer_fracz_nhwc.cc b/src/ge/common/formats/format_transfers/format_transfer_fracz_nhwc.cc index 9a1e5f3b..57b840af 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_fracz_nhwc.cc +++ b/src/ge/common/formats/format_transfers/format_transfer_fracz_nhwc.cc @@ -132,12 +132,6 @@ Status FormatTransferFracZNhwc::TransFormat(const TransArgs &args, TransResult & int size = GetSizeByDataType(args.src_data_type); auto total_size = GetItemNumByShape(args.dst_shape) * size; if (total_size <= 0) { - int64_t src_size = GetItemNumByShape(args.src_shape); - if (total_size == 0 && src_size == 0) { - result.length = static_cast(total_size); - return SUCCESS; - } - GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); return PARAM_INVALID; diff --git a/src/ge/common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc b/src/ge/common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc index 7101256a..fbadb4c3 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc +++ b/src/ge/common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc @@ -27,20 +27,16 @@ namespace ge { namespace formats { namespace { -bool CheckDataTypeSupported(const DataType &data_type) { - return (data_type == DT_FLOAT || data_type == DT_FLOAT16 || data_type == DT_INT8); -} +bool CheckDataTypeSupported(const DataType &data_type) { return (data_type == DT_FLOAT || data_type == DT_FLOAT16); } -Status TransShapeHwcnToC1hwncoc0(const DataType &data_type, const std::vector &src_shape, - std::vector &dst_shape) { - auto cube_size = GetCubeSizeByDataType(data_type); +Status TransShapeHwcnToC1hwncoc0(const std::vector &src_shape, std::vector &dst_shape) { dst_shape.clear(); - dst_shape.push_back(Ceil(src_shape.at(kHwcnC), static_cast(cube_size))); + dst_shape.push_back((src_shape.at(kHwcnC) - 1) / kCubeSize + 1); dst_shape.push_back(src_shape.at(kHwcnH)); dst_shape.push_back(src_shape.at(kHwcnW)); dst_shape.push_back(src_shape.at(kHwcnN)); - dst_shape.push_back(cube_size); - dst_shape.push_back(cube_size); + dst_shape.push_back(kCubeSize); + dst_shape.push_back(kCubeSize); if (!CheckShapeValid(dst_shape, kC1hwncoc0DimsNum)) { GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); return PARAM_INVALID; @@ -69,7 +65,7 @@ Status CheckArgsForHwcnToC1hwncoc0(const TransArgs &args) { return PARAM_INVALID; } std::vector expect_dst_shape; - auto ret = TransShapeHwcnToC1hwncoc0(args.src_data_type, args.src_shape, expect_dst_shape); + auto ret = TransShapeHwcnToC1hwncoc0(args.src_shape, expect_dst_shape); if (ret != SUCCESS) { return ret; } @@ -169,12 +165,6 @@ Status FormatTransferHwcnC1hwncoc0::TransFormat(const TransArgs &args, TransResu int size = GetSizeByDataType(args.src_data_type); auto total_size = GetItemNumByShape(args.dst_shape) * size; if (total_size <= 0) { - int64_t src_size = GetItemNumByShape(args.src_shape); - if (total_size == 0 && src_size == 0) { - result.length = static_cast(total_size); - return SUCCESS; - } - GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); return PARAM_INVALID; @@ -198,7 +188,7 @@ Status FormatTransferHwcnC1hwncoc0::TransShape(Format src_format, const std::vec GELOGE(PARAM_INVALID, "Failed to check src shape %s", ShapeToString(src_shape).c_str()); return PARAM_INVALID; } - return TransShapeHwcnToC1hwncoc0(data_type, src_shape, dst_shape); + return TransShapeHwcnToC1hwncoc0(src_shape, dst_shape); } else { return UNSUPPORTED; } diff --git a/src/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nchw.cc b/src/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nchw.cc index 57ab1266..0a5af5ff 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nchw.cc +++ b/src/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nchw.cc @@ -58,7 +58,7 @@ Status CheckArgsForNc1hwc0ToNchw(const TransArgs &args) { } if (src_shape.at(kNc1hwc0H) != dst_shape.at(kNchwH) || src_shape.at(kNc1hwc0W) != dst_shape.at(kNchwW) || src_shape.at(kNc1hwc0N) != dst_shape.at(kNchwN) || src_shape.at(kNc1hwc0C0) != c0 || - src_shape.at(kNc1hwc0C1) != (Ceil(dst_shape.at(kNchwC), c0))) { + src_shape.at(kNc1hwc0C1) != (dst_shape.at(kNchwC) - 1) / c0 + 1) { GELOGE(PARAM_INVALID, "Failed to check relationship between src and dst shape, src shape %s, dst shape %s", ShapeToString(src_shape).c_str(), ShapeToString(dst_shape).c_str()); return PARAM_INVALID; @@ -130,12 +130,6 @@ Status FormatTransferNc1hwc0Nchw::TransFormat(const TransArgs &args, TransResult int size = GetSizeByDataType(args.src_data_type); auto total_size = GetItemNumByShape(args.dst_shape) * size; if (total_size <= 0) { - int64_t src_size = GetItemNumByShape(args.src_shape); - if (total_size == 0 && src_size == 0) { - result.length = static_cast(total_size); - return SUCCESS; - } - GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); return PARAM_INVALID; diff --git a/src/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc b/src/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc index e68e54de..92fd5772 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc +++ b/src/ge/common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc @@ -58,7 +58,7 @@ Status CheckArgsForNc1hwc0ToNhwc(const TransArgs &args) { } if (src_shape.at(kNc1hwc0H) != dst_shape.at(kNhwcH) || src_shape.at(kNc1hwc0W) != dst_shape.at(kNhwcW) || src_shape.at(kNc1hwc0N) != dst_shape.at(kNhwcN) || src_shape.at(kNc1hwc0C0) != c0 || - src_shape.at(kNc1hwc0C1) != (Ceil(dst_shape.at(kNhwcC), c0))) { + src_shape.at(kNc1hwc0C1) != (dst_shape.at(kNhwcC) - 1) / c0 + 1) { GELOGE(PARAM_INVALID, "Failed to check relationship between src and dst shape, src shape %s, dst shape %s", ShapeToString(src_shape).c_str(), ShapeToString(dst_shape).c_str()); return PARAM_INVALID; @@ -130,12 +130,6 @@ Status FormatTransferNc1hwc0Nhwc::TransFormat(const TransArgs &args, TransResult int size = GetSizeByDataType(args.src_data_type); auto total_size = GetItemNumByShape(args.dst_shape) * size; if (total_size <= 0) { - int64_t src_size = GetItemNumByShape(args.src_shape); - if (total_size == 0 && src_size == 0) { - result.length = static_cast(total_size); - return SUCCESS; - } - GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); return PARAM_INVALID; diff --git a/src/ge/common/formats/format_transfers/format_transfer_nchw_fz_c04.cc b/src/ge/common/formats/format_transfers/format_transfer_nchw_fz_c04.cc deleted file mode 100644 index 638cc9eb..00000000 --- a/src/ge/common/formats/format_transfers/format_transfer_nchw_fz_c04.cc +++ /dev/null @@ -1,314 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "common/formats/format_transfers/format_transfer_nchw_fz_c04.h" -#include "common/formats/format_transfers/format_transfer_transpose.h" - -#include -#include -#include - -#include "common/formats/utils/formats_definitions.h" -#include "common/formats/utils/formats_trans_utils.h" -#include "common/util.h" -#include "framework/common/debug/ge_log.h" -#include "graph/utils/type_utils.h" - -/** 【Explain about transfer from nchw to FZ_CO4】 - * First Step: Padding in N and C axis. Here C must be less or equal than 4 - * After Padding, it will be like (n = ceil(n,16)*16, 4, h, w) - * Second Step: transpose. It will be like (n = ceil(n,16)*16, h, w, 4) - * Third Step: View the 4D as 2D , first dim is N, second dim is h*w*c. - * Padding to (N, ceil(Z/16)*16) - * Last Step: View the (N, ceil(Z/16)*16) as 4D (N/16, 16, C/16, 16) and transpose to (C/16, N/16, 16, 16) - */ - -namespace ge { -namespace formats { -namespace { - -constexpr int64_t kMaxDimsNumC = 4; - -Status CheckDataTypeSupport(DataType data_type) { return GetSizeByDataType(data_type) > 0 ? SUCCESS : UNSUPPORTED; } - -Status TransShape(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_type, std::vector &dst_shape) { - auto c0 = GetCubeSizeByDataType(data_type); - if (c0 < 0) { - return UNSUPPORTED; - } - auto chw = c * h * w; - - auto first_dim = Ceil(chw, c0); - auto no = Ceil(n, static_cast(c0)); - - dst_shape.clear(); - dst_shape.push_back(first_dim); - dst_shape.push_back(no); - dst_shape.push_back(c0); - dst_shape.push_back(c0); - - if (!IsShapeValid(dst_shape)) { - GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); - return PARAM_INVALID; - } - return SUCCESS; -} - -Status TransShapeNchwToFzC04(const std::vector &src_shape, DataType data_type, - std::vector &dst_shape) { - if (!CheckShapeValid(src_shape, kNchwDimsNum)) { - return PARAM_INVALID; - } - - auto n = src_shape.at(kNchwN); - auto c = src_shape.at(kNchwC); - auto h = src_shape.at(kNchwH); - auto w = src_shape.at(kNchwW); - return TransShape(n, c, h, w, data_type, dst_shape); -} - -Status TransFormatFromNchwToFzC04(const TransArgs &args, TransResult &result) { - int64_t n = args.src_shape.at(kNchwN); - int64_t c = args.src_shape.at(kNchwC); - int64_t h = args.src_shape.at(kNchwH); - int64_t w = args.src_shape.at(kNchwW); - - int64_t c0 = GetCubeSizeByDataType(args.src_data_type); - int size = GetSizeByDataType(args.src_data_type); - - auto data = args.data; - TransResult trans_result_1; - std::vector perm_arg_1 = {0, 2, 3, 1}; - std::vector expect_shape = {n, h, w, c}; - auto ret = ge::formats::Transpose(data, args.src_shape, args.src_data_type, perm_arg_1, trans_result_1); - if (ret != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Failed to Transpose from NCHW to HWCN"); - return NOT_CHANGED; - } - - TransArgs args_tmp = args; - args_tmp.src_shape = expect_shape; - args_tmp.data = trans_result_1.data.get(); - // check size it should be same with original - size_t expect_size = n * c * h * w * size; // before has do check about mul - if (trans_result_1.length != expect_size) { - GELOGE(INTERNAL_ERROR, "size is not match after transpose!"); - return NOT_CHANGED; - } - - /* prepare for padding in chw*/ - int64_t tmp = h * w * c; - int64_t n_o = Ceil(n, static_cast(c0)); - int64_t c_o = c0; - int64_t h_o = Ceil(tmp, c0); - int64_t w_o = c0; - std::vector shape_o = {n_o, c_o, h_o, w_o}; - - // data overflow check totally - GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(h_o, w_o), - GELOGE(INTERNAL_ERROR, "int64 mul overflow.A[%lld], B[%lld]", h_o, w_o); - return INTERNAL_ERROR); - GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(n_o, c_o), - GELOGE(INTERNAL_ERROR, "int64 mul overflow.A[%lld], B[%lld]", n_o, c_o); - return INTERNAL_ERROR); - auto t1 = h_o * w_o; - auto t2 = n_o * c_o; - GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(t1, t2), GELOGE(INTERNAL_ERROR, "int64 mul overflow.A[%lld], B[%lld]", t1, t2); - return INTERNAL_ERROR); - - int64_t total_ele_cnt = n_o * c_o * h_o * w_o; - GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(total_ele_cnt, size), - GELOGE(INTERNAL_ERROR, "int64 mul overflow.A[%lld], B[%lld]", total_ele_cnt, size); - return INTERNAL_ERROR); - int64_t dst_size = total_ele_cnt * size; - if (dst_size == 0) { - result.length = 0; - return SUCCESS; - } - - std::shared_ptr dst(new (std::nothrow) uint8_t[dst_size], std::default_delete()); - if (dst == nullptr) { - GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", - TypeUtils::FormatToSerialString(args.src_format).c_str(), - TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); - return OUT_OF_MEMORY; - } - auto retMem = memset_s(dst.get(), dst_size, 0, dst_size); - if (retMem != EOK) { - GELOGE(INTERNAL_ERROR, "memst failed!"); - return INTERNAL_ERROR; - } - // copy data - auto block = c * h * w * size; - auto stride = h_o * w_o * size; - auto p_s = trans_result_1.data.get(); - auto p_d = dst.get(); - auto protectSize = dst_size; - for (auto k = 0; k < n; k++) { - ret = memcpy_s(p_d + k * stride, protectSize, p_s + k * block, block); - if (ret != EOK) { - GELOGE(INTERNAL_ERROR, "memcpy_s failed!"); - return INTERNAL_ERROR; - } - protectSize = protectSize - block; - } - - // transpose : 2,0,1,3 - std::vector perm_arg_2 = {2, 0, 1, 3}; - ret = ge::formats::Transpose(dst.get(), shape_o, args.src_data_type, perm_arg_2, result); - if (ret != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Failed to Transpose from NCHW to HWCN"); - return NOT_CHANGED; - } - - return SUCCESS; -} - -Status PaddingNC(const TransArgs &args, TransArgs &args_tmp, std::shared_ptr &dst) { - args_tmp = args; - auto src_shape = args_tmp.src_shape; - if (!CheckShapeValid(src_shape, kNchwDimsNum)) { - return PARAM_INVALID; - } - int64_t c0 = GetCubeSizeByDataType(args.src_data_type); - - auto n = src_shape.at(kNchwN); - auto c = src_shape.at(kNchwC); - auto h = src_shape.at(kNchwH); - auto w = src_shape.at(kNchwW); - - if (c > kMaxDimsNumC) { - GELOGE(PARAM_INVALID, "Invalie dim c num[%lu].It should be in (0,4]", c); - return PARAM_INVALID; - } - - auto n_o = Ceil(n, c0) * c0; - auto c_o = kMaxDimsNumC; - auto h_o = h; - auto w_o = w; - args_tmp.src_shape.at(kNchwN) = n_o; - args_tmp.src_shape.at(kNchwC) = c_o; - args_tmp.src_shape.at(kNchwH) = h_o; - args_tmp.src_shape.at(kNchwW) = w_o; - - // data overflow check - GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(h_o, w_o), - GELOGE(INTERNAL_ERROR, "int64 mul overflow.A[%lld], B[%lld]", h_o, w_o); - return INTERNAL_ERROR); - GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(n_o, c_o), - GELOGE(INTERNAL_ERROR, "int64 mul overflow.A[%lld], B[%lld]", n_o, c_o); - return INTERNAL_ERROR); - auto t1 = h_o * w_o; - auto t2 = n_o * c_o; - GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(t1, t2), GELOGE(INTERNAL_ERROR, "int64 mul overflow.A[%lld], B[%lld]", t1, t2); - return INTERNAL_ERROR); - - int64_t total_ele_cnt = n_o * c_o * h_o * w_o; - int size = GetSizeByDataType(args.src_data_type); - GE_IF_BOOL_EXEC(!CheckInt64MulOverflow(total_ele_cnt, size), - GELOGE(INTERNAL_ERROR, "int64 mul overflow.A[%lld], B[%lld]", total_ele_cnt, size); - return INTERNAL_ERROR); - - int64_t dst_size = total_ele_cnt * size; - if (dst_size == 0) { - return SUCCESS; - } - - dst.reset(new (std::nothrow) uint8_t[dst_size], std::default_delete()); - if (dst == nullptr) { - GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", - TypeUtils::FormatToSerialString(args.src_format).c_str(), - TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); - return OUT_OF_MEMORY; - } - auto ret = memset_s(dst.get(), dst_size, 0, dst_size); - if (ret != EOK) { - GELOGE(INTERNAL_ERROR, "memst failed!"); - return INTERNAL_ERROR; - } - - auto p_s = args.data; - auto p_d = dst.get(); - auto block = h * w * size; - auto protectSize = dst_size; - - for (int i = 0; i < n; i++) { - for (int j = 0; j < c; j++) { - ret = memcpy_s(p_d + (i * c_o * h_o * w_o + j * h_o * w_o) * size, protectSize, - p_s + (i * c * h * w + j * h * w) * size, block); - if (ret != EOK) { - GELOGE(INTERNAL_ERROR, "memcpy_s failed!"); - return INTERNAL_ERROR; - } - protectSize = protectSize - block; - } - } - args_tmp.data = dst.get(); - - return SUCCESS; -} -} // namespace - -Status FormatTransferNchwToFZC04::TransFormat(const TransArgs &args, TransResult &result) { - GELOGD("Begin to trans format from %s to %s, src shape %s, data type %s, dst shape %s", - TypeUtils::FormatToSerialString(args.src_format).c_str(), - TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.src_shape).c_str(), - TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), ShapeToString(args.dst_shape).c_str()); - TransArgs args_tmp = args; - std::shared_ptr dst = nullptr; - auto ret = PaddingNC(args, args_tmp, dst); - if (ret != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Padding in NC axis failed!"); - return ret; - } - - std::vector expect_shape; - ret = TransShape(args_tmp.src_format, args_tmp.src_shape, args_tmp.src_data_type, args_tmp.dst_format, expect_shape); - if (ret != SUCCESS) { - return ret; - } - - if (!args_tmp.dst_shape.empty() && args_tmp.dst_shape != expect_shape) { - GELOGE(PARAM_INVALID, "Failed to trans format from %s to %s, the dst shape %s is invalid, expect %s", - TypeUtils::FormatToSerialString(args_tmp.src_format).c_str(), - TypeUtils::FormatToSerialString(args_tmp.dst_format).c_str(), ShapeToString(args_tmp.dst_shape).c_str(), - ShapeToString(expect_shape).c_str()); - return PARAM_INVALID; - } - - if (args_tmp.src_format == FORMAT_NCHW && args_tmp.dst_format == FORMAT_FRACTAL_Z_C04) { - return TransFormatFromNchwToFzC04(args_tmp, result); - } - - return UNSUPPORTED; -} - -Status FormatTransferNchwToFZC04::TransShape(Format src_format, const std::vector &src_shape, - DataType data_type, Format dst_format, std::vector &dst_shape) { - if (CheckDataTypeSupport(data_type) != SUCCESS) { - return UNSUPPORTED; - } - if (src_format == FORMAT_NCHW && dst_format == FORMAT_FRACTAL_Z_C04) { - return TransShapeNchwToFzC04(src_shape, data_type, dst_shape); - } - - return UNSUPPORTED; -} - -REGISTER_FORMAT_TRANSFER(FormatTransferNchwToFZC04, FORMAT_NCHW, FORMAT_FRACTAL_Z_C04) - -} // namespace formats -} // namespace ge diff --git a/src/ge/common/formats/format_transfers/format_transfer_nchw_fz_c04.h b/src/ge/common/formats/format_transfers/format_transfer_nchw_fz_c04.h deleted file mode 100644 index a1232d47..00000000 --- a/src/ge/common/formats/format_transfers/format_transfer_nchw_fz_c04.h +++ /dev/null @@ -1,35 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef GE_COMMON_FORMATS_FORMAT_TRANSFERS_NCHW_FZC04_H_ -#define GE_COMMON_FORMATS_FORMAT_TRANSFERS_NCHW_FZC04_H_ - -#include - -#include "common/formats/format_transfers/format_transfer.h" - -namespace ge { -namespace formats { -class FormatTransferNchwToFZC04 : public FormatTransfer { - public: - Status TransFormat(const ge::formats::TransArgs &args, ge::formats::TransResult &result) override; - Status TransShape(Format src_format, const std::vector &src_shape, DataType data_type, Format dst_format, - std::vector &dst_shape) override; -}; -} // namespace formats -} // namespace ge - -#endif // GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_FRACTAL_Z_H_ diff --git a/src/ge/common/formats/format_transfers/format_transfer_nchw_nc1hwc0.cc b/src/ge/common/formats/format_transfers/format_transfer_nchw_nc1hwc0.cc index b4e92cbc..7b90c6a8 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_nchw_nc1hwc0.cc +++ b/src/ge/common/formats/format_transfers/format_transfer_nchw_nc1hwc0.cc @@ -40,7 +40,7 @@ Status TransShapeNchwToNc1hwc0(const std::vector &src_shape, DataType d } dst_shape.clear(); dst_shape.push_back(src_shape.at(kNchwN)); - dst_shape.push_back(Ceil(src_shape.at(kNchwC), c0)); + dst_shape.push_back((src_shape.at(kNchwC) - 1) / c0 + 1); dst_shape.push_back(src_shape.at(kNchwH)); dst_shape.push_back(src_shape.at(kNchwW)); dst_shape.push_back(c0); @@ -74,8 +74,25 @@ Status CheckArgsForNchwToNc1hwc0(const TransArgs &args) { return SUCCESS; } +} // namespace -Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const int size, const int64_t total_size) { +Status FormatTransferNchwNc1hwc0::TransFormat(const TransArgs &args, TransResult &result) { + if (CheckArgsForNchwToNc1hwc0(args) != SUCCESS) { + return PARAM_INVALID; + } + // Guarantee the validity of parameters in check function + int size = GetSizeByDataType(args.src_data_type); + auto total_size = GetItemNumByShape(args.dst_shape) * size; + if (total_size <= 0) { + GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, + ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); + return PARAM_INVALID; + } + GELOGD( + "Begin to trans format from NCHW to NC1HWC0, src shape %s, data type " + "%s, dst shape %s memory size %ld", + ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), + ShapeToString(args.dst_shape).c_str(), total_size); std::shared_ptr dst(new (std::nothrow) uint8_t[total_size], std::default_delete()); if (dst == nullptr) { GELOGE(OUT_OF_MEMORY, @@ -152,39 +169,6 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in result.length = static_cast(total_size); return SUCCESS; } -} // namespace - -Status FormatTransferNchwNc1hwc0::TransFormat(const TransArgs &args, TransResult &result) { - if (CheckArgsForNchwToNc1hwc0(args) != SUCCESS) { - return PARAM_INVALID; - } - // Guarantee the validity of parameters in check function - int size = GetSizeByDataType(args.src_data_type); - auto total_size = GetItemNumByShape(args.dst_shape) * size; - if (total_size <= 0) { - int64_t src_size = GetItemNumByShape(args.src_shape); - if (total_size == 0 && src_size == 0) { - result.length = static_cast(total_size); - return SUCCESS; - } - - GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, - ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); - return PARAM_INVALID; - } - GELOGD( - "Begin to trans format from NCHW to NC1HWC0, src shape %s, data type " - "%s, dst shape %s memory size %ld", - ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), - ShapeToString(args.dst_shape).c_str(), total_size); - if (GetDstDataAfterTrans(args, result, size, total_size) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld", - ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), - ShapeToString(args.dst_shape).c_str(), total_size); - return INTERNAL_ERROR; - } - return SUCCESS; -} Status FormatTransferNchwNc1hwc0::TransShape(Format src_format, const std::vector &src_shape, DataType data_type, Format dst_format, std::vector &dst_shape) { diff --git a/src/ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc b/src/ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc index a5be94ff..26e533fc 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc +++ b/src/ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc @@ -38,7 +38,7 @@ Status TransShapeNhwcToNc1hwc0(const std::vector &src_shape, DataType d } dst_shape.clear(); dst_shape.push_back(src_shape.at(kNhwcN)); - dst_shape.push_back(Ceil(src_shape.at(kNhwcC), c0)); + dst_shape.push_back((src_shape.at(kNhwcC) - 1) / c0 + 1); dst_shape.push_back(src_shape.at(kNhwcH)); dst_shape.push_back(src_shape.at(kNhwcW)); dst_shape.push_back(c0); @@ -161,12 +161,6 @@ Status FormatTransferNhwcNc1hwc0::TransFormat(const TransArgs &args, TransResult int size = GetSizeByDataType(args.src_data_type); auto total_size = GetItemNumByShape(args.dst_shape) * size; if (total_size <= 0) { - int64_t src_size = GetItemNumByShape(args.src_shape); - if (total_size == 0 && src_size == 0) { - result.length = static_cast(total_size); - return SUCCESS; - } - GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); return PARAM_INVALID; diff --git a/src/ge/common/formats/format_transfers/format_transfer_transpose.cc b/src/ge/common/formats/format_transfers/format_transfer_transpose.cc index ec309543..9b3457ca 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_transpose.cc +++ b/src/ge/common/formats/format_transfers/format_transfer_transpose.cc @@ -51,8 +51,8 @@ bool IsShapeArgValid(const std::vector &src_shape, const std::vector &src_shape, Data int64_t dst_ele_num = GetItemNumByShape(dst_shape); int64_t data_size = GetSizeByDataType(src_data_type); int64_t dst_size = data_size * dst_ele_num; + std::shared_ptr dst(new (std::nothrow) uint8_t[dst_size], std::default_delete()); GELOGD("Begin to transpose, src shape %s, perm arg %s, dst shape %s, data type %s", JoinToString(src_shape).c_str(), JoinToString(perm_arg).c_str(), JoinToString(dst_shape).c_str(), TypeUtils::DataTypeToSerialString(src_data_type).c_str()); - if (dst_ele_num == 0) { - result.length = static_cast(dst_size); - return SUCCESS; - } - std::shared_ptr dst(new (std::nothrow) uint8_t[dst_size], std::default_delete()); int64_t dst_index = 0; std::vector dst_indexes(dst_shape.size()); while (dst_index < dst_ele_num) { diff --git a/src/ge/common/formats/formats.cc b/src/ge/common/formats/formats.cc index d01d055b..938f0888 100644 --- a/src/ge/common/formats/formats.cc +++ b/src/ge/common/formats/formats.cc @@ -24,7 +24,6 @@ #include #include -#include "common/formats/utils/formats_trans_utils.h" #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" #include "graph/utils/type_utils.h" @@ -39,13 +38,10 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status TransFormat(const TransArg TypeUtils::FormatToSerialString(args.dst_format).c_str()); return UNSUPPORTED; } - - auto src_shape_size = GetItemNumByShape(args.src_shape); - if (args.data == nullptr && src_shape_size != 0) { + if (args.data == nullptr) { GELOGE(PARAM_INVALID, "Invalid input null data"); return PARAM_INVALID; } - return transfer->TransFormat(args, result); } @@ -75,12 +71,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status TransDataType(const CastAr TypeUtils::DataTypeToSerialString(args.dst_data_type).c_str()); return UNSUPPORTED; } - - if (args.data == nullptr && args.src_data_size != 0) { - GELOGE(PARAM_INVALID, "Invalid input null data"); - return PARAM_INVALID; - } - return transfer->TransDataType(args, result); } diff --git a/src/ge/common/formats/utils/formats_trans_utils.cc b/src/ge/common/formats/utils/formats_trans_utils.cc index 23da0f74..35a0a073 100644 --- a/src/ge/common/formats/utils/formats_trans_utils.cc +++ b/src/ge/common/formats/utils/formats_trans_utils.cc @@ -69,11 +69,11 @@ bool IsShapeValid(const std::vector &shape) { } int64_t num = 1; for (auto dim : shape) { - if (dim < 0) { - GELOGE(PARAM_INVALID, "Invalid negative dim in the shape %s", ShapeToString(shape).c_str()); + if (dim < 1) { + GELOGE(PARAM_INVALID, "Invalid zero dim in the shape %s", ShapeToString(shape).c_str()); return false; } - if (dim != 0 && kShapeItemNumMAX / dim < num) { + if (kShapeItemNumMAX / dim < num) { GELOGE(PARAM_INVALID, "Shape overflow, the total count should be less than %ld!", kShapeItemNumMAX); return false; } diff --git a/src/ge/common/formats/utils/formats_trans_utils.h b/src/ge/common/formats/utils/formats_trans_utils.h index a8fbd09b..310aaf38 100644 --- a/src/ge/common/formats/utils/formats_trans_utils.h +++ b/src/ge/common/formats/utils/formats_trans_utils.h @@ -64,9 +64,6 @@ bool IsShapeEqual(const GeShape &src, const GeShape &dst); template T Ceil(T n1, T n2) { - if (n1 == 0) { - return 0; - } return (n2 != 0) ? (n1 - 1) / n2 + 1 : 0; } diff --git a/src/ge/common/fp16_t.h b/src/ge/common/fp16_t.h index 34908b95..854df58f 100644 --- a/src/ge/common/fp16_t.h +++ b/src/ge/common/fp16_t.h @@ -601,4 +601,4 @@ int16_t GetManBitLength(T man) { return len; } }; // namespace ge -#endif // GE_COMMON_FP16_T_H_ +#endif // GE_COMMON_FP16_T_H_ \ No newline at end of file diff --git a/src/ge/common/ge/plugin_manager.cc b/src/ge/common/ge/plugin_manager.cc index 29cb8a83..f2eb8f5f 100644 --- a/src/ge/common/ge/plugin_manager.cc +++ b/src/ge/common/ge/plugin_manager.cc @@ -27,7 +27,6 @@ #include #include "framework/common/debug/log.h" -#include "framework/common/util.h" namespace ge { static const int kMaxNumOfSo = 64; @@ -101,7 +100,7 @@ Status PluginManager::LoadSo(const string &path, const vector &func_chec } std::string file_name = single_path.substr(single_path.rfind('/') + 1, string::npos); - string file_path_dlopen = RealPath(single_path.c_str()); + string file_path_dlopen = domi::RealPath(single_path.c_str()); if (file_path_dlopen.empty()) { GELOGW("Failed to get realpath of %s!", single_path.c_str()); continue; @@ -226,7 +225,7 @@ Status PluginManager::Load(const string &path, const vector &func_check_ } std::string canonical_path_str = (std::string(canonical_path) + "/" + file_name); - string file_path_dlopen = RealPath(canonical_path_str.c_str()); + string file_path_dlopen = domi::RealPath(canonical_path_str.c_str()); if (file_path_dlopen.empty()) { GELOGW("failed to get realpath of %s", canonical_path_str.c_str()); continue; diff --git a/src/ge/common/helper/model_cache_helper.cc b/src/ge/common/helper/model_cache_helper.cc deleted file mode 100644 index 58c82138..00000000 --- a/src/ge/common/helper/model_cache_helper.cc +++ /dev/null @@ -1,1707 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include -#include - -#include "common/ge/ge_util.h" -#include "common/helper/model_cache_helper.h" -#include "common/types.h" -#include "framework/common/debug/ge_log.h" -#include "framework/common/ge_types.h" -#include "framework/common/helper/model_helper.h" -#include "framework/common/util.h" -#include "graph/detail/attributes_holder.h" -#include "graph/detail/model_serialize_imp.h" -#include "graph/load/new_model_manager/davinci_model_parser.h" -#include "graph/model.h" -#include "graph/utils/graph_utils.h" -#include "graph/utils/tensor_utils.h" -#include "init/gelib.h" -#include "proto/ge_ir.pb.h" - -using namespace std; - -namespace { -const char *const kGraphName = "temp_name"; -const char *const kDpop = "DPOP"; -const char *const kDpopFunction = "dpop_function"; -// Keys of json -const char *const kNodeNum = "nodeNum"; -const char *const kEdgeNum = "edgeNum"; -const char *const kGraphHash = "graphHash"; -const char *const kNodeHash = "nodeHash"; -const char *const kHash = "hash"; -const char *const kSessionId = "sessionId"; -const char *const kDeviceId = "deviceId"; -const char *const kJobId = "jobId"; -const char *const kGraphMemMaxSize = "graphMemMaxSize"; -const char *const kVarMemMaxSize = "varMemMaxSize"; -const char *const kVarMemLogicBase = "varMemLogicBase"; -const char *const kUseMaxMemSize = "useMaxMemSize"; -const char *const kMemResourceMap = "memResourceMap"; -const char *const kMemType = "memType"; -const char *const kTotalSize = "totalSize"; -const char *const kVarMemSize = "varMemSize"; -const char *const kVarResource = "varResource"; -const char *const kVarAddrMgrMap = "varAddrMgrMap"; -const char *const kName = "name"; -const char *const kAddress = "address"; -const char *const kOffset = "offset"; -const char *const kMemoryType = "memoryType"; -const char *const kTensorDesc = "tensorDesc"; -const char *const kDataType = "dataType"; -const char *const kShape = "shape"; -const char *const kLayout = "layout"; -const char *const kOriginDataType = "originDataType"; -const char *const kOriginShape = "originShape"; -const char *const kOriginLayout = "originLayout"; -const char *const kRealDimCnt = "realDimCnt"; -const char *const kCurVarTensorDescMap = "curVarTensorDescMap"; -const char *const kTransRoads = "transRoads"; -const char *const kTransRoad = "transRoad"; -const char *const kNodeType = "nodeType"; -const char *const kInputTensorDesc = "inputTensorDesc"; -const char *const kOutputTensorDesc = "outputTensorDesc"; -const char *const kChangedGraphId = "changedGraphId"; -const char *const kAllocatedGraphId = "allocatedGraphId"; -const char *const kGraphId = "graphId"; -const char *const kVarBroadcastInfo = "varBroadcastInfo"; -const char *const kBroadcastName = "broadcastName"; -const char *const kIdx = "idx"; -const char *const kInputOffset = "inputOffset"; -const char *const kInputSize = "inputSize"; -const char *const kOutputOffset = "outputOffset"; -const char *const kOutputSize = "outputSize"; -// Suffix of cache files -const char *const kBeforeVarManagerSuffix = "_before_build_var_manager.json"; -const char *const kAfterVarManagerSuffix = "_after_build_var_manager.json"; -const char *const kManifestSuffix = ".manifest"; -const char *const kOmSuffix = ".om"; -} // namespace - -namespace ge { -map ModelCacheHelper::graph_id_run_times_; -ModelCacheHelper::ModelCacheHelper(uint64_t session_id, uint32_t graph_id, ComputeGraphPtr &compute_graph) - : session_id_(session_id), - graph_id_(graph_id), - compute_graph_(compute_graph), - is_cache_path_valid_for_output(false) { - if (graph_id_run_times_.count(graph_id) == 0) { - graph_id_run_times_[graph_id] = 1; - } else { - graph_id_run_times_[graph_id] = graph_id_run_times_[graph_id] + 1; - } - for (const auto &node : compute_graph_->GetDirectNode()) { - bool is_variable = (node->GetType() == VARIABLE) || (node->GetType() == VARIABLEV2) || - (node->GetType() == VARHANDLEOP) || (node->GetType() == CONSTANTOP); - if (!is_variable) { - continue; - } - var_names_.insert(node->GetName()); - } - std::shared_ptr instance_ptr = ge::GELib::GetInstance(); - if (instance_ptr != nullptr && instance_ptr->IsIncreBuild()) { - std::string cache_path = instance_ptr->GetIncreBuildCachePath(); - GELOGD("Incre build path conf: %s", cache_path.c_str()); - string fake_file_path = cache_path + to_string(graph_id_) + kManifestSuffix; - if (CheckOutputPathValid(fake_file_path)) { - is_cache_path_valid_for_output = true; - } else { - GELOGW("Invalid cache path for output."); - } - std::string real_cache_path = RealPath(cache_path.c_str()); - if (real_cache_path.empty()) { - GELOGW("Invalid incre build cache path conf: %s", cache_path.c_str()); - return; - } - cache_path_ = real_cache_path + '/'; - GELOGD("Try to use incre build cache path: %s", cache_path_.c_str()); - } -} - -bool ModelCacheHelper::IsModelCacheHit() const { - CacheInfo cache_info; - if (GetCacheInfo(cache_info) != SUCCESS) { - GELOGI("Get cache info of graph id[%u] failed.", graph_id_); - return false; - } - // Check number of nodes and edges first. - if (cache_info.node_num != compute_graph_->GetDirectNodesSize()) { - GELOGI("Graph id[%u] cache miss: the node number of the graph does not match the cache info.", graph_id_); - return false; - } - size_t edge_num = 0; - for (const auto &node : compute_graph_->GetDirectNode()) { - for (const auto &anchor : node->GetAllInAnchors()) { - edge_num += anchor->GetPeerAnchors().size(); - } - } - if (cache_info.edge_num != edge_num) { - GELOGI("Graph id[%u] cache miss: the edge number of the graph does not match the cache info.", graph_id_); - return false; - } - size_t compute_graph_hash; - auto ret = GetComputeGraphHash(compute_graph_hash); - if (ret != SUCCESS || cache_info.graph_hash != compute_graph_hash) { - GELOGI("Graph id[%u] cache miss: the hash code of the graph does not match the cache info.", graph_id_); - return false; - } - if (!IsNodeHashSameAsCache(cache_info.nodes_hash)) { - GELOGI("Graph id[%u] cache miss: the hash code of node does not match the cache info.", graph_id_); - return false; - } - - string var_manager_cache = to_string(graph_id_) + to_string(graph_id_run_times_[graph_id_]) + kBeforeVarManagerSuffix; - Json var_manager_json; - if (LoadJsonFromFile(var_manager_cache, var_manager_json) != SUCCESS) { - GELOGW("Fail to load json from cache file: %s", var_manager_cache.c_str()); - return false; - } - if (!IsVarManagerSameAsCache(var_manager_json)) { - GELOGI("Graph id[%u] cache miss: the VarManager dos not match the cache info.", graph_id_); - return false; - } - GELOGI("Graph id[%u] cache hit.", graph_id_); - return true; -} - -Status ModelCacheHelper::RefreshComputeGraph(const ComputeGraphPtr &compute_graph) { - if (compute_graph->IsValid()) { - compute_graph_ = compute_graph; - var_names_.clear(); - for (const auto &node : compute_graph_->GetDirectNode()) { - bool is_variable = (node->GetType() == VARIABLE) || (node->GetType() == VARIABLEV2) || - (node->GetType() == VARHANDLEOP) || (node->GetType() == CONSTANTOP); - if (!is_variable) { - continue; - } - var_names_.insert(node->GetName()); - } - return SUCCESS; - } else { - GELOGW("Invalid compute graph."); - return FAILED; - } -} - -Status ModelCacheHelper::ClearCache(uint32_t graph_id) const { - if (!is_cache_path_valid_for_output) { - GELOGW("Invalid cache path."); - return SUCCESS; - } - string manifest_file = cache_path_ + to_string(graph_id) + kManifestSuffix; - string manifest_file_path = RealPath(manifest_file.c_str()); - int ret; - if (!manifest_file_path.empty()) { - ret = remove(manifest_file_path.c_str()); - // If remove file failed, print the warning log - if (ret != 0) { - GELOGW("Clear cache [%s] failed.", manifest_file_path.c_str()); - } - } - string before_var_manager_file = cache_path_ + to_string(graph_id) + kManifestSuffix; - string before_var_manager_file_path = RealPath(before_var_manager_file.c_str()); - if (!before_var_manager_file_path.empty()) { - ret = remove(before_var_manager_file_path.c_str()); - if (ret != 0) { - GELOGW("Clear cache [%s] failed.", before_var_manager_file_path.c_str()); - } - } - string after_var_manager_file = cache_path_ + to_string(graph_id) + kManifestSuffix; - string after_var_manager_file_path = RealPath(after_var_manager_file.c_str()); - if (!after_var_manager_file_path.empty()) { - ret = remove(after_var_manager_file_path.c_str()); - if (ret != 0) { - GELOGW("Clear cache [%s] failed.", after_var_manager_file_path.c_str()); - } - } - string om_file = cache_path_ + to_string(graph_id) + kManifestSuffix; - string om_file_path = RealPath(om_file.c_str()); - if (!om_file_path.empty()) { - ret = remove(om_file_path.c_str()); - if (ret != 0) { - GELOGW("Clear cache [%s] failed.", om_file_path.c_str()); - } - } - return SUCCESS; -} - -Status ModelCacheHelper::RecoverVarManagerFromCache() const { - string var_manager_cache = to_string(graph_id_) + to_string(graph_id_run_times_[graph_id_]) + kAfterVarManagerSuffix; - Json var_manager_json; - if (LoadJsonFromFile(var_manager_cache, var_manager_json) != SUCCESS) { - GELOGW("Fail to load json from cache file: %s", var_manager_cache.c_str()); - return FAILED; - } - - Json mem_resource_json = move(var_manager_json[kMemResourceMap]); - auto ret = RecoverMemResource(mem_resource_json); - if (ret != SUCCESS) { - GELOGW("Recover VarManager from cache failed.[MemResource]"); - return FAILED; - } - Json var_resource_json = move(var_manager_json[kVarResource]); - ret = RecoverAllocatedGraphId(var_resource_json[kAllocatedGraphId]); - if (ret != SUCCESS) { - GELOGW("Recover VarManager from cache failed.[AllocatedGraphId]"); - return FAILED; - } - ret = RecoverChangedGraphId(var_resource_json[kChangedGraphId]); - if (ret != SUCCESS) { - GELOGW("Recover VarManager from cache failed.[ChangedGraphId]"); - return FAILED; - } - ret = RecoverBroadcastInfo(var_resource_json[kVarBroadcastInfo]); - if (ret != SUCCESS) { - GELOGW("Recover VarManager from cache failed.[VarBroadcastInfo]"); - return FAILED; - } - ret = RecoverVarAddrAndTensorDesc(var_resource_json[kVarAddrMgrMap]); - if (ret != SUCCESS) { - GELOGW("Recover VarManager from cache failed.[VarAddrMgrMap & CurVarTensorDesc]"); - return FAILED; - } - ret = RecoverTransRoads(var_resource_json[kTransRoads]); - if (ret != SUCCESS) { - GELOGW("Recover VarManager from cache failed.[TransRoads]"); - return FAILED; - } - GELOGI("Recover VarManager from cache[%s] success.", cache_path_.c_str()); - return SUCCESS; -} - -Status ModelCacheHelper::RecompileNodes(GeModelPtr &ge_model) { - std::shared_ptr instance = ge::GELib::GetInstance(); - if (instance == nullptr || !instance->InitFlag()) { - GELOGW("RecompileNodes failed."); - return ge::GE_CLI_GE_NOT_INITIALIZED; - } - auto compute_graph = GraphUtils::GetComputeGraph(ge_model->GetGraph()); - vector nodes; - for (auto &node : compute_graph->GetDirectNode()) { - if (node == nullptr) { - continue; - } - auto op_desc = node->GetOpDesc(); - if (op_desc == nullptr) { - continue; - } - - string kernel_lib_name = op_desc->GetOpKernelLibName(); - if (kernel_lib_name.empty()) { - // reset op kernel lib - (void)instance->DNNEngineManagerObj().GetDNNEngineName(op_desc); - kernel_lib_name = op_desc->GetOpKernelLibName(); - if (kernel_lib_name.empty()) { - GELOGW("Get node:%s, type:%s kernel lib failed.", node->GetName().c_str(), op_desc->GetType().c_str()); - return INTERNAL_ERROR; - } - } - OpsKernelInfoStorePtr kernel_info = instance->OpsKernelManagerObj().GetOpsKernelInfoStore(kernel_lib_name); - if (kernel_info == nullptr) { - GELOGW("Get op %s ops kernel info store failed", node->GetName().c_str()); - return INTERNAL_ERROR; - } - auto ge_desc = MakeShared(op_desc); - if (ge_desc == nullptr) { - GELOGE(GE_GRAPH_MEMORY_ALLOC_FAILED, "Fail to malloc op desc."); - return FAILED; - } - // TBE compile op - vector node_vec = {node}; - auto ret = kernel_info->CompileOp(node_vec); - if (ret != ge::SUCCESS) { - GELOGW("Compile single op failed, node name is %s", node->GetName().c_str()); - return ret; - } - } - // Reset TBE Kernels - TBEKernelStore tbe_kernel_store; - for (const ge::NodePtr &n : compute_graph->GetDirectNode()) { - auto node_op_desc = n->GetOpDesc(); - GE_IF_BOOL_EXEC(node_op_desc == nullptr, continue); - TBEKernelPtr tbe_kernel = node_op_desc->TryGetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, TBEKernelPtr()); - GE_IF_BOOL_EXEC(tbe_kernel == nullptr, continue); - tbe_kernel_store.AddTBEKernel(tbe_kernel); - GELOGD("Add tbe kernel bin %s", tbe_kernel->GetName().c_str()); - } - if (!tbe_kernel_store.Build()) { - GELOGW("TBE Kernels store build failed!"); - return FAILED; - } - ge_model->SetTBEKernelStore(tbe_kernel_store); - return SUCCESS; -} - -Status ModelCacheHelper::GetNodesHash(map &hash_map) const { - vector nodes; - GraphUtils::TopologicalSortingByName(compute_graph_, nodes); - ModelSerializeImp model_serialize_imp; - std::hash node_hash; - for (const auto &node : nodes) { - if (node == nullptr) { - continue; - } - proto::OpDef op_def; - bool is_framework_op = (node->GetType() == FRAMEWORKOP); - string type; - bool is_dpop = false; - string origin_dpop_name; - if (is_framework_op) { - if (AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type)) { - GELOGI("Get original type of framework op[%s], %s.", node->GetName().c_str(), type.c_str()); - if (type == kDpop) { - GELOGI("DPOP op found:%s.", node->GetName().c_str()); - origin_dpop_name = node->GetName(); - node->GetOpDesc()->SetName(kDpopFunction); - is_dpop = true; - } - } else { - GELOGW("Get original type of framework op[%s] failed.", node->GetName().c_str()); - } - } - bool ret = model_serialize_imp.SerializeNode(node, &op_def, is_framework_op); - op_def.set_id(0); - if (is_dpop) { - node->GetOpDesc()->SetName(origin_dpop_name); - } - if (!ret) { - GELOGW("Fail to serialize node[%].", node->GetName().c_str()); - return INTERNAL_ERROR; - } - string prototxt; - ret = google::protobuf::TextFormat::PrintToString(op_def, &prototxt); - if (!ret) { - GELOGW("Print OpDef to string failed."); - hash_map.clear(); - return INTERNAL_ERROR; - } - size_t hash_code = node_hash(prototxt); - if (is_dpop) { - hash_map[kDpopFunction] = hash_code; - } else { - hash_map[node->GetName()] = hash_code; - } - } - return SUCCESS; -} - -Status ModelCacheHelper::GetComputeGraphHash(size_t &hash) const { - proto::GraphDef graph_proto; - ModelSerializeImp model_serialize_imp; - // The name of compute graph may be generated randomly, so replace it temporarily. - const string origin_name = compute_graph_->GetName(); - compute_graph_->SetName(kGraphName); - bool serialize_ret = model_serialize_imp.SerializeGraph(compute_graph_, &graph_proto); - graph_proto.clear_op(); - if (!serialize_ret) { - GELOGW("Serialize graph failed."); - hash = 0; - return INTERNAL_ERROR; - } - compute_graph_->SetName(origin_name); - // Generate proto text of GraphDef - string prototxt; - bool print_ret = google::protobuf::TextFormat::PrintToString(graph_proto, &prototxt); - if (!print_ret) { - GELOGW("Print GraphDef to string failed."); - hash = 0; - return INTERNAL_ERROR; - } - // Get the hash code of proto text - std::hash graph_hash; - hash = graph_hash(prototxt); - return SUCCESS; -} - -Status ModelCacheHelper::SaveJsonToFile(const string &file_name, const Json &json) const { - if (!is_cache_path_valid_for_output) { - GELOGW("Invalid cache path."); - return PARAM_INVALID; - } - // Check whether the manifest exists, if not, create it. - string real_path = RealPath(cache_path_.c_str()); - if (real_path.empty()) { - GELOGW("File path is invalid. please check cache path: %s", cache_path_.c_str()); - return FAILED; - } - const string path = cache_path_ + file_name; - const int FILE_AUTHORITY = 0600; - int fd = open(path.c_str(), O_WRONLY | O_CREAT | O_TRUNC, FILE_AUTHORITY); - if (fd < 0) { - GELOGW("Fail to open the file: %s.", path.c_str()); - return INTERNAL_ERROR; - } - if (close(fd) != 0) { - GELOGW("Fail to close the file: %s.", path.c_str()); - return INTERNAL_ERROR; - } - - // Write json into cache file - ofstream ofs; - ofs.open(path); - if (!ofs.is_open()) { - GELOGW("Fail to open the file: %s.", path.c_str()); - return INTERNAL_ERROR; - } - ofs << json << std::endl; - ofs.close(); - return SUCCESS; -} - -Status ModelCacheHelper::LoadJsonFromFile(const string &file_name, Json &json) const { - if (!json.is_null()) { - GELOGW("Input param json type should be null."); - return PARAM_INVALID; - } - string real_path = RealPath(cache_path_.c_str()); - if (real_path.empty()) { - GELOGW("File path is invalid. please check cache path: %s", cache_path_.c_str()); - return FAILED; - } - const string path = cache_path_ + file_name; - if (!CheckInputPathValid(path)) { - GELOGW("Invalid cache path for input:%s.", path.c_str()); - return FAILED; - } - string cache_real_path = RealPath(path.c_str()); - if (real_path.empty()) { - GELOGI("File[%s] is not found.", path.c_str()); - return FAILED; - } - // Read json from cache file - ifstream ifs; - ifs.open(path); - if (!ifs.is_open()) { - GELOGW("Fail to open the file: %s.", path.c_str()); - return INTERNAL_ERROR; - } - ifs >> json; - if (!json.is_object()) { - GELOGW("Fail to load the json file: %s.", path.c_str()); - return INTERNAL_ERROR; - } - return SUCCESS; -} - -Status ModelCacheHelper::SaveCacheInfoToCache() const { - // Generate cache json - // example: {"edgeNum":6,"nodeNum":7,"graphCache":134714827475991356} - Json cache_json; - try { - cache_json[kNodeNum] = compute_graph_->GetDirectNodesSize(); - size_t edge_num = 0; - for (const auto &node : compute_graph_->GetDirectNode()) { - for (const auto &anchor : node->GetAllInAnchors()) { - edge_num += anchor->GetPeerAnchors().size(); - } - } - cache_json[kEdgeNum] = edge_num; - size_t hash = 0; - auto ret = GetComputeGraphHash(hash); - if (ret != SUCCESS) { - GELOGW("Error occur when generate graph hash code."); - return ret; - } - cache_json[kGraphHash] = hash; - Json nodes_hash_json; - ret = GetNodesHashMapJson(nodes_hash_json); - if (ret != SUCCESS) { - GELOGW("Error occur when generate nodes hash code."); - return ret; - } - cache_json[kNodeHash] = nodes_hash_json; - } catch (const std::exception &e) { - GELOGW("Fail to generate cache info json. Error message: %s", e.what()); - return INTERNAL_ERROR; - } - string cache_manifest = to_string(graph_id_) + to_string(graph_id_run_times_[graph_id_]) + kManifestSuffix; - - auto ret = SaveJsonToFile(cache_manifest, cache_json); - if (ret != SUCCESS) { - GELOGW("Fail to save cache info to json file, path: %s.", cache_path_.c_str()); - return ret; - } - return SUCCESS; -} - -Status ModelCacheHelper::GetCacheInfo(CacheInfo &cache_info) const { - string cache_manifest = to_string(graph_id_) + to_string(graph_id_run_times_[graph_id_]) + kManifestSuffix; - Json cache_json; - if (LoadJsonFromFile(cache_manifest, cache_json) != SUCCESS) { - GELOGW("Fail to load json from cache file: %s", cache_manifest.c_str()); - return INTERNAL_ERROR; - } - if (!cache_json.is_object()) { - GELOGW("Manifest should be a json object"); - return INTERNAL_ERROR; - } - try { - cache_info.node_num = cache_json[kNodeNum]; - cache_info.edge_num = cache_json[kEdgeNum]; - cache_info.graph_hash = cache_json[kGraphHash]; - Json nodes_hash_json = cache_json[kNodeHash]; - if (!(nodes_hash_json.is_null() || nodes_hash_json.is_array())) { - GELOGW("Nodes hash in cache be null or array."); - return FAILED; - } - for (const auto &iter : nodes_hash_json) { - cache_info.nodes_hash[iter[kName].get()] = iter[kHash].get(); - } - } catch (const std::exception &e) { - GELOGW("Fail to get info from json file. Error message: %s", e.what()); - return INTERNAL_ERROR; - } - return SUCCESS; -} - -bool ModelCacheHelper::IsAllocatedGraphIdSameAsCache(Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return false; - } - // Compare allocated graph id info between json and VarManager - std::unordered_map allocated_graph_id; - auto ret = ParseAllocatedGraphIdFromJson(json, allocated_graph_id); - if (ret != SUCCESS) { - GELOGW("Fail to parse AllocatedGraphId from Json."); - return false; - } - for (const auto &iter : allocated_graph_id) { - uint32_t graph_id = 0; - ret = VarManager::Instance(session_id_)->GetAllocatedGraphId(iter.first, graph_id); - if (ret != SUCCESS) { - GELOGW("Fail to find allocated graph id of var[%s].", iter.first.c_str()); - return false; - } - if (graph_id != iter.second) { - GELOGW("The allocated graph id of variable[%s] in cache is different from VarManager.", iter.first.c_str()); - return false; - } - } - return true; -} - -bool ModelCacheHelper::IsNodeHashSameAsCache(const map &hash_map) const { - map cur_hash_map; - GetNodesHash(cur_hash_map); - if (hash_map.size() != cur_hash_map.size()) { - GELOGI("The number of hash code is different from cache info."); - return false; - } - for (const auto &iter : cur_hash_map) { - if (hash_map.count(iter.first) == 0) { - GELOGI("Node[%s] is not found in cache info.", iter.first.c_str()); - return false; - } - if (hash_map.at(iter.first) != iter.second) { - GELOGI("The hash code of node[%s] is different from cache info.", iter.first.c_str()); - return false; - } - } - return true; -} - -bool ModelCacheHelper::IsMemResourceSameAsCache(Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return false; - } - // Compare var mem size info between json and VarManager - std::map var_mem_size; - auto ret = ParseMemResourceFromJson(json, var_mem_size); - if (ret != SUCCESS) { - GELOGW("Fail to parse MemResource from Json."); - return false; - } - for (const auto &iter : var_mem_size) { - int64_t mem_size = VarManager::Instance(session_id_)->GetVarMemSize(iter.first); - if (mem_size != iter.second) { - GELOGW("The var mem size of memory_type[%u] in cache is different from VarManager.", iter.first); - return false; - } - } - return true; -} - -bool ModelCacheHelper::IsChangedGraphIdSameAsCache(Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return false; - } - // Compare variable changed graph id info between json and VarManager - std::unordered_map changed_graph_id; - auto ret = ParseChangedGraphIdFromJson(json, changed_graph_id); - if (ret != SUCCESS) { - GELOGW("Fail to parse ChangedGraphId from Json."); - return false; - } - for (const auto &iter : changed_graph_id) { - uint32_t graph_id = 0; - ret = VarManager::Instance(session_id_)->GetChangedGraphId(iter.first, graph_id); - if (ret != SUCCESS) { - GELOGW("Fail to find changed graph id of var[%s].", iter.first.c_str()); - return false; - } - if (graph_id != iter.second) { - GELOGW("The changed graph id of variable[%s] in cache is different from VarManager.", iter.first.c_str()); - return false; - } - } - return true; -} - -bool ModelCacheHelper::IsCurVarTensorDescSameAsCache(Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return false; - } - // Compare variable tensor desc info between json and VarManager - std::unordered_map cur_var_tensor_desc; - auto ret = ParseCurVarTensorDescMapFromJson(json, cur_var_tensor_desc); - if (ret != SUCCESS) { - GELOGW("Fail to parse CurVarTensorDesc from Json."); - return false; - } - for (const auto &iter : cur_var_tensor_desc) { - GeTensorDesc tensor_desc; - ret = VarManager::Instance(session_id_)->GetCurVarDesc(iter.first, tensor_desc); - if (ret != SUCCESS) { - GELOGW("Fail to find tensor desc of var[%s].", iter.first.c_str()); - return false; - } - uint32_t l_real_dim_cnt = 0; - uint32_t r_real_dim_cnt = 0; - TensorUtils::GetRealDimCnt(tensor_desc, l_real_dim_cnt); - TensorUtils::GetRealDimCnt(iter.second, r_real_dim_cnt); - if ((tensor_desc.GetDataType() != iter.second.GetDataType()) || - (tensor_desc.GetOriginDataType() != iter.second.GetOriginDataType()) || - (tensor_desc.GetFormat() != iter.second.GetFormat()) || - (tensor_desc.GetOriginFormat() != iter.second.GetOriginFormat()) || - (tensor_desc.GetShape().ToString() != iter.second.GetShape().ToString()) || - (tensor_desc.GetOriginShape().ToString() != iter.second.GetOriginShape().ToString()) || - (l_real_dim_cnt != r_real_dim_cnt)) { - GELOGW("The var tensor desc of variable[%s] in cache is different from VarManager.", iter.first.c_str()); - return false; - } - } - return true; -} - -bool ModelCacheHelper::IsVarAddrMgrMapSameAsCache(Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return false; - } - // Compare variable address info between json and VarManager - std::vector> var_addr_mgr_vector; - std::unordered_set var_offset_set; - auto ret = ParseVarAddrMgrMapFromJson(json, var_addr_mgr_vector, var_offset_set); - if (ret != SUCCESS) { - GELOGW("Fail to parse VarAddrMgrMap from Json."); - return false; - } - for (const auto &iter : var_addr_mgr_vector) { - uint8_t *dev_ptr = nullptr; - rtMemType_t memory_type; - ret = VarManager::Instance(session_id_)->GetVarAddr(iter.first, iter.second.tensor_desc, &dev_ptr, memory_type); - if (ret != SUCCESS) { - GELOGW("Fail to find tensor desc of var[%s].", iter.first.c_str()); - return false; - } - // Compare memory type and logic address - if (iter.second.memory_type != memory_type || iter.second.address != dev_ptr) { - GELOGW("The VarAddrMgr of variable[%s] in cache is different from VarManager.", iter.first.c_str()); - return false; - } - } - return true; -} - -bool ModelCacheHelper::IsBroadcastInfoSameAsCache(Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return false; - } - // Compare broadcast info between json and VarManager - std::unordered_map var_broadcast_info; - auto ret = ParseBroadcastInfoFromJson(json, var_broadcast_info); - if (ret != SUCCESS) { - GELOGW("Fail to parse BroadcastInfo from Json."); - return false; - } - for (const auto &iter : var_broadcast_info) { - VarBroadCastInfo broadcast_info; - if (VarManager::Instance(session_id_)->GetBroadCastInfo(graph_id_, iter.first, broadcast_info) != SUCCESS) { - GELOGW("Fail to find broadcast info of var[%s].", iter.first.c_str()); - return false; - } - if (iter.second.var_name != broadcast_info.var_name || iter.second.idx != broadcast_info.idx || - iter.second.input_size != broadcast_info.input_size || - iter.second.input_offset != broadcast_info.input_offset || - iter.second.output_size != broadcast_info.output_size || - iter.second.output_offset != broadcast_info.output_offset) { - GELOGW("The BroadcastInfo of variable[%s] in cache is different from VarManager.", iter.first.c_str()); - return false; - } - } - return true; -} - -bool ModelCacheHelper::IsTransRoadsSameAsCache(Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return false; - } - // Compare trans road between json and VarManager - std::unordered_map> trans_roads; - auto ret = ParseTransRoadsFromJson(json, trans_roads); - if (ret != SUCCESS) { - GELOGW("Fail to parse TransRoads from Json."); - return false; - } - for (const auto &iter : trans_roads) { - VarTransRoad *trans_road; - trans_road = VarManager::Instance(session_id_)->GetTransRoad(iter.first); - if (trans_road == nullptr) { - GELOGW("Fail to find trans road of var[%s].", iter.first.c_str()); - return false; - } - if (trans_road->size() != iter.second.size()) { - GELOGW("The TransRoad of variable[%s] in cache is different from VarManager.", iter.first.c_str()); - return false; - } - // Compare every trans node in trans road. - for (size_t idx = 0; idx < trans_road->size(); idx += 1) { - if (!(trans_road->at(idx).node_type == iter.second.at(idx).node_type && - trans_road->at(idx).input == iter.second.at(idx).input && - trans_road->at(idx).output == iter.second.at(idx).output)) { - GELOGW("The TransRoad of variable[%s] in cache is different from VarManager.", iter.first.c_str()); - return false; - } - } - } - return true; -} - -bool ModelCacheHelper::IsVarManagerParamSameAsCache(Json &json) const { - if (!json.is_object()) { - GELOGW("Input param json type should be object."); - return false; - } - try { - if (json[kSessionId].get() != session_id_) { - GELOGW("Check VarManager cache failed.[sessionId]"); - return false; - } - if (json[kDeviceId].get() != VarManager::Instance(session_id_)->DeviceId()) { - GELOGW("Check VarManager cache failed.[deviceId]"); - return false; - } - if (json[kJobId].get() != VarManager::Instance(session_id_)->JobId()) { - GELOGW("Check VarManager cache failed.[jobId]"); - return false; - } - if (json[kGraphMemMaxSize].get() != VarManager::Instance(session_id_)->GetGraphMemoryMaxSize()) { - GELOGW("Check VarManager cache failed.[graphMemMaxSize]"); - return false; - } - if (json[kVarMemMaxSize].get() != VarManager::Instance(session_id_)->GetVarMemMaxSize()) { - GELOGW("Check VarManager cache failed.[varMemMaxSize]"); - return false; - } - if (json[kVarMemLogicBase].get() != VarManager::Instance(session_id_)->GetVarMemLogicBase()) { - GELOGW("Check VarManager cache failed.[varMemLogicBase]"); - return false; - } - if (json[kUseMaxMemSize].get() != VarManager::Instance(session_id_)->GetUseMaxMemorySize()) { - GELOGW("Check VarManager cache failed.[useMaxMemSize]"); - return false; - } - } catch (const std::exception &e) { - GELOGW("Fail to check VarManager json. Error message: %s", e.what()); - return false; - } - return true; -} - -bool ModelCacheHelper::IsVarManagerSameAsCache(Json &json) const { - if (!json.is_object()) { - GELOGW("Input param json type should be object."); - return false; - } - try { - if (!IsVarManagerParamSameAsCache(json)) { - GELOGW("Check VarManager cache failed.[Param]"); - return false; - } - Json mem_resource_json = move(json[kMemResourceMap]); - auto ret = IsMemResourceSameAsCache(mem_resource_json); - if (!ret) { - GELOGW("Check VarManager cache failed.[MemResource]"); - return false; - } - Json var_resource_json = move(json[kVarResource]); - ret = IsAllocatedGraphIdSameAsCache(var_resource_json[kAllocatedGraphId]); - if (!ret) { - GELOGW("Check VarManager cache failed.[AllocatedGraphId]"); - return false; - } - ret = IsChangedGraphIdSameAsCache(var_resource_json[kChangedGraphId]); - if (!ret) { - GELOGW("Check VarManager cache failed.[ChangedGraphId]"); - return false; - } - ret = IsBroadcastInfoSameAsCache(var_resource_json[kVarBroadcastInfo]); - if (!ret) { - GELOGW("Check VarManager cache failed.[VarBroadcastInfo]"); - return false; - } - ret = IsCurVarTensorDescSameAsCache(var_resource_json[kCurVarTensorDescMap]); - if (!ret) { - GELOGW("Check VarManager cache failed.[CurVarTensorDesc]"); - return false; - } - ret = IsVarAddrMgrMapSameAsCache(var_resource_json[kVarAddrMgrMap]); - if (!ret) { - GELOGW("Check VarManager cache failed.[VarAddrMgrMap]"); - return false; - } - ret = IsTransRoadsSameAsCache(var_resource_json[kTransRoads]); - if (!ret) { - GELOGW("Check VarManager cache failed.[TransRoads]"); - return false; - } - } catch (const std::exception &e) { - GELOGW("Fail to check VarManager json. Error message: %s", e.what()); - return false; - } - return true; -} - -Status ModelCacheHelper::RecoverMemResource(const Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - std::map var_mem_size; - auto ret = ParseMemResourceFromJson(json, var_mem_size); - if (ret != SUCCESS) { - GELOGW("Fail to parse MemResource from Json."); - return ret; - } - for (const auto &iter : var_mem_size) { - ret = VarManager::Instance(session_id_)->UpdateVarMemSize(iter.first, iter.second); - if (ret != SUCCESS) { - GELOGW("Fail to recover var mem size."); - return ret; - } - } - return SUCCESS; -} - -Status ModelCacheHelper::RecoverAllocatedGraphId(const Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - std::unordered_map allocated_graph_id; - auto ret = ParseAllocatedGraphIdFromJson(json, allocated_graph_id); - if (ret != SUCCESS) { - GELOGW("Fail to parse AllocatedGraphId from Json."); - return ret; - } - for (const auto &iter : allocated_graph_id) { - ret = VarManager::Instance(session_id_)->SetAllocatedGraphId(iter.first, iter.second); - if (ret != SUCCESS) { - GELOGW("Fail to recover allocated graph id."); - return ret; - } - } - return SUCCESS; -} - -Status ModelCacheHelper::RecoverChangedGraphId(const Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - std::unordered_map changed_graph_id; - auto ret = ParseChangedGraphIdFromJson(json, changed_graph_id); - if (ret != SUCCESS) { - GELOGW("Fail to parse AllocatedGraphId from Json."); - return ret; - } - for (const auto &iter : changed_graph_id) { - ret = VarManager::Instance(session_id_)->SetChangedGraphId(iter.first, iter.second); - if (ret != SUCCESS) { - GELOGW("Fail to recover changed graph id."); - return ret; - } - } - return SUCCESS; -} - -Status ModelCacheHelper::RecoverVarAddrAndTensorDesc(const Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - std::vector> var_addr_mgr_vector; - std::unordered_set var_offset_set; - auto ret = ParseVarAddrMgrMapFromJson(json, var_addr_mgr_vector, var_offset_set); - if (ret != SUCCESS) { - GELOGW("Fail to parse VarAddrMgrMap from Json."); - return ret; - } - for (const auto &iter : var_addr_mgr_vector) { - const VarAddrMgr &tensor_addr_mgr = iter.second; - const bool var_exist = VarManager::Instance(session_id_)->IsVarExist(iter.first, tensor_addr_mgr.tensor_desc); - // SaveVarVddr if var does not exist, the logic address will be recorded by VarManager - if (!var_exist) { - auto logic_address = reinterpret_cast(tensor_addr_mgr.address); - auto offset = (tensor_addr_mgr.offset); - // Check logic address and offset - if (logic_address - offset != VarManager::Instance(session_id_)->GetVarMemLogicBase()) { - GELOGW("Check logic_address[%u] and offset [%u] of %s failed, var mem logic base is %u, abandon", logic_address, - offset, iter.first.c_str(), VarManager::Instance(session_id_)->GetVarMemLogicBase()); - return PARAM_INVALID; - } - // Offset is needed by SaveVarVddr instead of logic address - ret = VarManager::Instance(session_id_) - ->SaveVarAddr(iter.first, tensor_addr_mgr.tensor_desc, reinterpret_cast(offset), - tensor_addr_mgr.memory_type); - if (ret != SUCCESS) { - GELOGW("Fail to recover VarAddr or TensorDesc of var[%s].", iter.first.c_str()); - return ret; - } - } - // SetVarAddr to update cur_var_tensor_desc_map_ - ret = VarManager::Instance(session_id_) - ->SetVarAddr(iter.first, tensor_addr_mgr.tensor_desc, tensor_addr_mgr.address, tensor_addr_mgr.memory_type); - if (ret != SUCCESS) { - GELOGW("Fail to recover VarAddr or TensorDesc desc of var[%s].", iter.first.c_str()); - return ret; - } - } - return SUCCESS; -} - -Status ModelCacheHelper::RecoverBroadcastInfo(const Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - std::unordered_map var_broadcast_info; - auto ret = ParseBroadcastInfoFromJson(json, var_broadcast_info); - if (ret != SUCCESS) { - GELOGW("Fail to parse BroadcastInfo from Json."); - return ret; - } - for (const auto &iter : var_broadcast_info) { - VarBroadCastInfo broadcast_info; - ret = VarManager::Instance(session_id_)->SaveBroadCastInfo(graph_id_, iter.second); - if (ret != SUCCESS) { - GELOGW("Fail to recover broadcast info of var[%s].", iter.first.c_str()); - return ret; - } - } - return SUCCESS; -} - -Status ModelCacheHelper::RecoverTransRoads(const Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - std::unordered_map> trans_roads; - auto ret = ParseTransRoadsFromJson(json, trans_roads); - if (ret != SUCCESS) { - GELOGW("Fail to parse TransRoads from Json."); - return ret; - } - for (const auto &iter : trans_roads) { - ret = VarManager::Instance(session_id_)->SetTransRoad(iter.first, iter.second); - if (ret != SUCCESS) { - GELOGW("Fail to find trans road of var[%s].", iter.first.c_str()); - return ret; - } - } - return SUCCESS; -} - -Status ModelCacheHelper::TensorDescToJson(const GeTensorDesc &ge_tensor_desc, Json &json) { - if (!(json.is_null() || json.is_object())) { - GELOGW("Input param json type should be null or object."); - return PARAM_INVALID; - } - try { - json[kDataType] = static_cast(ge_tensor_desc.GetDataType()); - json[kOriginDataType] = static_cast(ge_tensor_desc.GetOriginDataType()); - json[kLayout] = static_cast(ge_tensor_desc.GetFormat()); - json[kOriginLayout] = static_cast(ge_tensor_desc.GetOriginFormat()); - json[kShape] = ge_tensor_desc.GetShape().GetDims(); - json[kOriginShape] = ge_tensor_desc.GetOriginShape().GetDims(); - uint32_t real_dim_cnt = 0; - (void)TensorUtils::GetRealDimCnt(ge_tensor_desc, real_dim_cnt); // [No need to check value] - json[kRealDimCnt] = real_dim_cnt; - } catch (const std::exception &e) { - GELOGW("Fail to trans GeTensorDesc to json. Error message: %s", e.what()); - return INTERNAL_ERROR; - } - return SUCCESS; -} - -Status ModelCacheHelper::JsonToTensorDesc(const Json &json, ge::GeTensorDesc &ge_tensor_desc) { - if (!json.is_object()) { - GELOGW("Input param json type should be object."); - return PARAM_INVALID; - } - try { - ge_tensor_desc.SetDataType(static_cast(json[kDataType].get())); - ge_tensor_desc.SetOriginDataType(static_cast(json[kOriginDataType].get())); - ge_tensor_desc.SetFormat(static_cast(json[kLayout].get())); - ge_tensor_desc.SetOriginFormat(static_cast(json[kOriginLayout].get())); - GeShape shape(json[kShape].get>()); - ge_tensor_desc.SetShape(shape); - GeShape origin_shape(json[kOriginShape].get>()); - ge_tensor_desc.SetOriginShape(origin_shape); - auto real_dim_cnt = json[kRealDimCnt].get(); - (void)TensorUtils::SetRealDimCnt(ge_tensor_desc, real_dim_cnt); // [No need to check value] - } catch (const std::exception &e) { - GELOGW("Fail to trans Json to GeTensorDesc. Error message: %s", e.what()); - return INTERNAL_ERROR; - } - return SUCCESS; -} - -Status ModelCacheHelper::GetNodesHashMapJson(Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - map hash_map; - GetNodesHash(hash_map); - for (const auto &iter : hash_map) { - Json node_hash_json; - try { - node_hash_json[kName] = iter.first; - node_hash_json[kHash] = iter.second; - json.emplace_back(move(node_hash_json)); - } catch (const std::exception &e) { - GELOGW("Fail to trans node cache to json. Error message: %s", e.what()); - return INTERNAL_ERROR; - } - } - return SUCCESS; -} - -Status ModelCacheHelper::GetMemResourceMap(Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - const auto total_size = VarManager::Instance(session_id_)->GetVarMemMaxSize(); - const auto var_mem_size = VarManager::Instance(session_id_)->GetVarMemSize(RT_MEMORY_HBM); - Json mem_resource_json; - try { - mem_resource_json[kMemType] = RT_MEMORY_HBM; - mem_resource_json[kTotalSize] = total_size; - mem_resource_json[kVarMemSize] = var_mem_size; - json.emplace_back(move(mem_resource_json)); - } catch (const std::exception &e) { - GELOGW("Fail to trans MemResourceMap to json. Error message: %s", e.what()); - return INTERNAL_ERROR; - } - return SUCCESS; -} - -Status ModelCacheHelper::GetVarAddrMgrMapJson(Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - std::unordered_map var_addr_mgr_map; - VarManager::Instance(session_id_)->GetAllVarAddrMgr(var_addr_mgr_map); - try { - for (const auto &iter : var_addr_mgr_map) { - Json var_addr_json; - string name; - GetVarNameFromVarKey(iter.first, iter.second.tensor_desc, name); - var_addr_json[kName] = name; - var_addr_json[kAddress] = reinterpret_cast(iter.second.address); - var_addr_json[kMemoryType] = iter.second.memory_type; - var_addr_json[kOffset] = iter.second.offset; - - // Copy tensor desc to json. - Json tensor_desc_json; - auto ret = TensorDescToJson(iter.second.tensor_desc, tensor_desc_json); - if (ret != SUCCESS) { - GELOGW("Fail to trans tensor desc to json."); - return INTERNAL_ERROR; - } - var_addr_json[kTensorDesc] = move(tensor_desc_json); - - json.emplace_back(move(var_addr_json)); - } - } catch (const std::exception &e) { - GELOGW("Fail to trans VarAddrMgrMap to json. Error message: %s", e.what()); - return INTERNAL_ERROR; - } - return SUCCESS; -} - -Status ModelCacheHelper::GetCurVarTensorDescMapJson(Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - try { - for (const auto &name : var_names_) { - Json cur_tensor_desc_json; - GeTensorDesc tensor_desc; - auto ret = VarManager::Instance(session_id_)->GetCurVarDesc(name, tensor_desc); - if (ret != SUCCESS) { - GELOGI("Get variable[%s] current tensor desc failed. It will be skipped.", name.c_str()); - continue; - } - cur_tensor_desc_json[kName] = name; - - Json tensor_desc_json; - ret = TensorDescToJson(tensor_desc, tensor_desc_json); - if (ret != SUCCESS) { - GELOGW("Fail to trans tensor desc to json."); - return INTERNAL_ERROR; - } - cur_tensor_desc_json[kTensorDesc] = move(tensor_desc_json); - json.emplace_back(move(cur_tensor_desc_json)); - } - } catch (const std::exception &e) { - GELOGW("Fail to trans CurVarTensorDescMap to json. Error message: %s", e.what()); - return INTERNAL_ERROR; - } - return SUCCESS; -} - -Status ModelCacheHelper::GetTransRoadsJson(Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - try { - for (const auto &name : var_names_) { - auto trans_road = VarManager::Instance(session_id_)->GetTransRoad(name); - if (trans_road == nullptr) { - continue; - } - // Json object, variable name and trans road - Json trans_road_map_json; - trans_road_map_json[kName] = name; - - Json trans_road_json; - Status ret; - // Add nodes' info to json - for (const auto &trans_node_info : *trans_road) { - Json trans_node_info_json; - trans_node_info_json[kNodeType] = trans_node_info.node_type; - Json input_tensor_desc_json; - ret = TensorDescToJson(trans_node_info.input, input_tensor_desc_json); - if (ret != SUCCESS) { - GELOGW("Fail to trans tensor desc to json."); - return INTERNAL_ERROR; - } - trans_node_info_json[kInputTensorDesc] = move(input_tensor_desc_json); - Json output_tensor_desc_json; - ret = TensorDescToJson(trans_node_info.output, output_tensor_desc_json); - if (ret != SUCCESS) { - GELOGW("Fail to trans tensor desc to json."); - return INTERNAL_ERROR; - } - trans_node_info_json[kOutputTensorDesc] = move(output_tensor_desc_json); - trans_road_json.emplace_back(move(trans_node_info_json)); - } - trans_road_map_json[kTransRoad] = move(trans_road_json); - json.emplace_back(move(trans_road_map_json)); - } - } catch (const std::exception &e) { - GELOGW("Fail to trans VarToTransRoad to json. Error message: %s", e.what()); - return INTERNAL_ERROR; - } - return SUCCESS; -} - -Status ModelCacheHelper::GetChangedGraphIdJson(Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - for (const auto &name : var_names_) { - uint32_t changed_graph_id = 0; - Status ret = VarManager::Instance(session_id_)->GetChangedGraphId(name, changed_graph_id); - if (ret != SUCCESS) { - continue; - } - Json name_and_changed_graph_id; - try { - name_and_changed_graph_id[kName] = name; - name_and_changed_graph_id[kGraphId] = changed_graph_id; - json.emplace_back(move(name_and_changed_graph_id)); - } catch (const std::exception &e) { - GELOGW("Fail to trans ChangedGraphId to json. Error message: %s", e.what()); - return INTERNAL_ERROR; - } - } - return SUCCESS; -} - -Status ModelCacheHelper::GetAllocatedGraphIdJson(Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - for (const auto &name : var_names_) { - uint32_t allocated_graph_id = 0; - Status ret = VarManager::Instance(session_id_)->GetAllocatedGraphId(name, allocated_graph_id); - if (ret != SUCCESS) { - continue; - } - Json name_and_allocated_graph_id; - try { - name_and_allocated_graph_id[kName] = name; - name_and_allocated_graph_id[kGraphId] = allocated_graph_id; - json.emplace_back(move(name_and_allocated_graph_id)); - } catch (const std::exception &e) { - GELOGW("Fail to trans AllocatedGraphId to json. Error message: %s", e.what()); - return INTERNAL_ERROR; - } - } - return SUCCESS; -} - -Status ModelCacheHelper::GetBroadcastInfoJson(Json &json) const { - if (!(json.is_null() || json.is_array())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - for (const auto &name : var_names_) { - VarBroadCastInfo var_broadcast_info; - Status ret = VarManager::Instance(session_id_)->GetBroadCastInfo(graph_id_, name, var_broadcast_info); - if (ret != SUCCESS) { - continue; - } - Json var_broadcast_info_json; - try { - var_broadcast_info_json[kName] = name; - var_broadcast_info_json[kBroadcastName] = var_broadcast_info.broadcast_name; - var_broadcast_info_json[kIdx] = var_broadcast_info.idx; - var_broadcast_info_json[kInputOffset] = var_broadcast_info.input_offset; - var_broadcast_info_json[kInputSize] = var_broadcast_info.input_size; - var_broadcast_info_json[kOutputOffset] = var_broadcast_info.output_offset; - var_broadcast_info_json[kOutputSize] = var_broadcast_info.output_size; - json.emplace_back(move(var_broadcast_info_json)); - } catch (const std::exception &e) { - GELOGW("Fail to trans VarBroadcastInfo to json. Error message: %s", e.what()); - return INTERNAL_ERROR; - } - } - return SUCCESS; -} - -Status ModelCacheHelper::GetVarResourceJson(Json &json) const { - if (!(json.is_null() || json.is_object())) { - GELOGW("Input param json type should be null or object."); - return PARAM_INVALID; - } - Json var_addr_mgr_map_json; - Status ret = GetVarAddrMgrMapJson(var_addr_mgr_map_json); - if (ret != SUCCESS) { - GELOGW("GetVarAddrMgrMapJson failed."); - return INTERNAL_ERROR; - } - - Json cur_var_tensor_desc_map_json; - ret = GetCurVarTensorDescMapJson(cur_var_tensor_desc_map_json); - if (ret != SUCCESS) { - GELOGW("GetCurVarTensorDescMapJson failed."); - return INTERNAL_ERROR; - } - - Json trans_roads_json; - ret = GetTransRoadsJson(trans_roads_json); - if (ret != SUCCESS) { - GELOGW("GetTransRoadsJson failed."); - return INTERNAL_ERROR; - } - - Json changed_graph_id_json; - ret = GetChangedGraphIdJson(changed_graph_id_json); - if (ret != SUCCESS) { - GELOGW("GetChangedGraphIdJson failed."); - return INTERNAL_ERROR; - } - - Json allocated_graph_id_json; - ret = GetAllocatedGraphIdJson(allocated_graph_id_json); - if (ret != SUCCESS) { - GELOGW("GetAllocatedGraphIdJson failed."); - return INTERNAL_ERROR; - } - - Json var_broadcast_info_json; - ret = GetBroadcastInfoJson(var_broadcast_info_json); - if (ret != SUCCESS) { - GELOGW("GetBroadcastInfoJson failed."); - return INTERNAL_ERROR; - } - - try { - json[kVarAddrMgrMap] = move(var_addr_mgr_map_json); - json[kCurVarTensorDescMap] = move(cur_var_tensor_desc_map_json); - json[kTransRoads] = move(trans_roads_json); - json[kChangedGraphId] = move(changed_graph_id_json); - json[kAllocatedGraphId] = move(allocated_graph_id_json); - json[kVarBroadcastInfo] = move(var_broadcast_info_json); - } catch (const exception &e) { - GELOGW("Fail to generate VarResource json. Error message: %s", e.what()); - return INTERNAL_ERROR; - } - return SUCCESS; -} - -Status ModelCacheHelper::GetVarManagerJson(Json &json) const { - if (!(json.is_null() || json.is_object())) { - GELOGW("Input param json type should be null or object."); - return PARAM_INVALID; - } - - Json mem_resource_map_json; - auto ret = GetMemResourceMap(mem_resource_map_json); - if (ret != SUCCESS) { - GELOGW("GetMemResourceMap failed."); - return INTERNAL_ERROR; - } - - Json var_resource_json; - ret = GetVarResourceJson(var_resource_json); - if (ret != SUCCESS) { - GELOGW("GetVarResourceJson failed."); - return INTERNAL_ERROR; - } - - try { - json[kSessionId] = session_id_; - json[kDeviceId] = VarManager::Instance(session_id_)->DeviceId(); - json[kJobId] = VarManager::Instance(session_id_)->JobId(); - json[kGraphMemMaxSize] = VarManager::Instance(session_id_)->GetGraphMemoryMaxSize(); - json[kVarMemMaxSize] = VarManager::Instance(session_id_)->GetVarMemMaxSize(); - json[kVarMemLogicBase] = VarManager::Instance(session_id_)->GetVarMemLogicBase(); - json[kUseMaxMemSize] = VarManager::Instance(session_id_)->GetUseMaxMemorySize(); - json[kMemResourceMap] = move(mem_resource_map_json); - json[kVarResource] = move(var_resource_json); - } catch (const exception &e) { - GELOGW("Fail to generate VarManager json. Error message: %s", e.what()); - return INTERNAL_ERROR; - } - return SUCCESS; -} - -Status ModelCacheHelper::SaveVarManagerToCache(bool before_build) const { - if (!is_cache_path_valid_for_output) { - GELOGW("Invalid cache path."); - return FAILED; - } - Json var_manager_json; - auto ret = GetVarManagerJson(var_manager_json); - if (ret != SUCCESS) { - GELOGW("Fail to generate VarManager json."); - return FAILED; - } - string var_manager_path = to_string(graph_id_) + to_string(graph_id_run_times_[graph_id_]) + - (before_build ? kBeforeVarManagerSuffix : kAfterVarManagerSuffix); - ret = SaveJsonToFile(var_manager_path, var_manager_json); - if (ret != SUCCESS) { - GELOGW("Fail to save VarManager info to json file, path: %s.", cache_path_.c_str()); - return ret; - } - return SUCCESS; -} - -Status ModelCacheHelper::SaveOmModelToCache(const GeModelPtr &ge_model) const { - if (!is_cache_path_valid_for_output) { - GELOGW("Invalid cache path."); - return FAILED; - } - string om_path = RealPath(cache_path_.c_str()); - if (om_path.empty()) { - GELOGW("file path is invalid. please check path om: %s", cache_path_.c_str()); - return FAILED; - } - string cache_om_path = cache_path_; - cache_om_path += (to_string(graph_id_) + to_string(graph_id_run_times_[graph_id_]) + kOmSuffix); - GELOGI("SaveOmModelToCache: start to save om model : %s", cache_om_path.c_str()); - ModelHelper model_helper; - SaveParam save_param; - ModelBufferData model; - Status ret = model_helper.SaveToOmModel(ge_model, save_param, cache_om_path, model); - if (ret != SUCCESS) { - GELOGW("SaveOmModelToCache: save mode failed. ret = %u", ret); - return ret; - } - return SUCCESS; -} - -Status ModelCacheHelper::ParseMemResourceFromJson(const Json &json, map &mem_resource) { - if (!(json.is_array() || json.is_null())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - mem_resource.clear(); - for (const Json &mem_resource_json : json) { - MemResource var_addr_mgr; - try { - rtMemType_t mem_type = mem_resource_json[kMemType].get(); - uint64_t var_mem_size = mem_resource_json[kVarMemSize].get(); - mem_resource[mem_type] = var_mem_size; - } catch (const exception &e) { - GELOGW("Fail to trans Json to MemResource. Error message: %s", e.what()); - return INTERNAL_ERROR; - } - } - return SUCCESS; -} - -Status ModelCacheHelper::ParseVarAddrMgrMapFromJson( - const Json &json, std::vector> &var_addr_mgr_vector, - std::unordered_set &var_offset_set) { - if (!(json.is_array() || json.is_null())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - var_addr_mgr_vector.clear(); - var_offset_set.clear(); - for (const Json &var_addr_json : json) { - VarAddrMgr var_addr_mgr; - try { - auto logic_address = var_addr_json[kAddress].get(); - auto address = reinterpret_cast(logic_address); - var_addr_mgr.address = address; - var_addr_mgr.offset = var_addr_json[kOffset].get(); - var_addr_mgr.memory_type = var_addr_json[kMemoryType].get(); - auto ret = JsonToTensorDesc(var_addr_json[kTensorDesc], var_addr_mgr.tensor_desc); - if (ret != SUCCESS) { - GELOGW("Fail to trans json to tensor desc."); - return ret; - } - var_addr_mgr_vector.emplace_back(var_addr_json[kName].get(), move(var_addr_mgr)); - var_offset_set.insert(logic_address); - } catch (const exception &e) { - GELOGW("Fail to trans Json to VarAddrMgr. Error message: %s", e.what()); - return INTERNAL_ERROR; - } - } - return SUCCESS; -} - -Status ModelCacheHelper::ParseCurVarTensorDescMapFromJson( - const Json &json, std::unordered_map &cur_var_tensor_desc_map) { - if (!(json.is_array() || json.is_null())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - cur_var_tensor_desc_map.clear(); - for (const Json &tensor_desc_json : json) { - GeTensorDesc tensor_desc; - try { - auto ret = JsonToTensorDesc(tensor_desc_json[kTensorDesc], tensor_desc); - if (ret != SUCCESS) { - GELOGW("Fail to trans json to tensor desc."); - return ret; - } - cur_var_tensor_desc_map[tensor_desc_json[kName].get()] = move(tensor_desc); - } catch (const exception &e) { - GELOGW("Fail to trans Json to VarAddrMgr. Error message: %s", e.what()); - return INTERNAL_ERROR; - } - } - return SUCCESS; -} - -Status ModelCacheHelper::ParseTransRoadsFromJson( - const Json &json, std::unordered_map> &trans_roads) { - if (!(json.is_array() || json.is_null())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - trans_roads.clear(); - try { - for (const Json &name_trans_road_json : json) { - const Json &trans_road_json = name_trans_road_json[kTransRoad]; - if (!(trans_road_json.is_array() || trans_road_json.is_null())) { - GELOGW("%s json type should be null or object.", kTransRoad); - return PARAM_INVALID; - } - vector trans_road; - for (const Json &trans_node_json : trans_road_json) { - TransNodeInfo trans_node_info; - trans_node_info.node_type = trans_node_json[kNodeType]; - GeTensorDesc input_tensor_desc; - auto ret = JsonToTensorDesc(trans_node_json[kInputTensorDesc], input_tensor_desc); - if (ret != SUCCESS) { - GELOGW("Fail to trans json to tensor desc."); - return ret; - } - trans_node_info.input = move(input_tensor_desc); - GeTensorDesc output_tensor_desc; - ret = JsonToTensorDesc(trans_node_json[kOutputTensorDesc], output_tensor_desc); - if (ret != SUCCESS) { - GELOGW("Fail to trans json to tensor desc."); - return ret; - } - trans_node_info.output = move(output_tensor_desc); - trans_road.emplace_back(move(trans_node_info)); - } - trans_roads[name_trans_road_json[kName].get()] = move(trans_road); - } - } catch (const exception &e) { - GELOGW("Fail to trans Json to TransRoads. Error message: %s", e.what()); - return INTERNAL_ERROR; - } - return SUCCESS; -} - -Status ModelCacheHelper::ParseChangedGraphIdFromJson(const Json &json, - std::unordered_map &changed_graph_id) { - if (!(json.is_array() || json.is_null())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - changed_graph_id.clear(); - for (const Json &name_graph_id_json : json) { - try { - changed_graph_id[name_graph_id_json[kName].get()] = name_graph_id_json[kGraphId].get(); - } catch (const exception &e) { - GELOGW("Fail to trans Json to changed graph id. Error message: %s", e.what()); - return INTERNAL_ERROR; - } - } - return SUCCESS; -} - -Status ModelCacheHelper::ParseAllocatedGraphIdFromJson(const Json &json, - std::unordered_map &allocated_graph_id) { - if (!(json.is_array() || json.is_null())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - allocated_graph_id.clear(); - for (const Json &name_graph_id_json : json) { - try { - allocated_graph_id[name_graph_id_json[kName].get()] = name_graph_id_json[kGraphId].get(); - } catch (const exception &e) { - GELOGW("Fail to trans Json to allocated graph id. Error message: %s", e.what()); - return INTERNAL_ERROR; - } - } - return SUCCESS; -} - -Status ModelCacheHelper::ParseBroadcastInfoFromJson( - const Json &json, std::unordered_map &var_broadcast_info) { - if (!(json.is_array() || json.is_null())) { - GELOGW("Input param json type should be null or array."); - return PARAM_INVALID; - } - for (const Json &broadcast_info_json : json) { - VarBroadCastInfo broadcast_info; - try { - broadcast_info.var_name = broadcast_info_json[kName].get(); - broadcast_info.broadcast_name = broadcast_info_json[kBroadcastName].get(); - broadcast_info.idx = broadcast_info_json[kIdx].get(); - broadcast_info.input_offset = broadcast_info_json[kInputOffset].get(); - broadcast_info.input_size = broadcast_info_json[kInputSize].get(); - broadcast_info.output_offset = broadcast_info_json[kOutputOffset].get(); - broadcast_info.output_size = broadcast_info_json[kOutputSize].get(); - } catch (const exception &e) { - GELOGW("Fail to trans Json to VarBroadCastInfo. Error message: %s", e.what()); - return INTERNAL_ERROR; - } - var_broadcast_info[broadcast_info.var_name] = broadcast_info; - } - return SUCCESS; -} - -Status ModelCacheHelper::LoadOmModelFromCache(GeModelPtr &ge_model) const { - string cache_om = cache_path_ + to_string(graph_id_) + to_string(graph_id_run_times_[graph_id_]) + kOmSuffix; - if (!CheckInputPathValid(cache_om)) { - GELOGW("Invalid cache path for input:%s.", cache_om.c_str()); - return FAILED; - } - string om_path = RealPath(cache_om.c_str()); - if (om_path.empty()) { - GELOGW("file path is invalid. please check file om: %s", om_path.c_str()); - return FAILED; - } - GELOGI("load model data from file: %s", om_path.c_str()); - Status ret; - string key_path; - int32_t priority = 0; - ModelData model_data; - ret = DavinciModelParser::LoadFromFile(om_path.c_str(), key_path.c_str(), priority, model_data); - if (ret != SUCCESS) { - GELOGW("LoadOmModelFromCache: Load model from file fialed. ret = %u", ret); - return ret; - } - - ModelHelper model_helper; - ret = model_helper.LoadModel(model_data); - if (ret != SUCCESS) { - GELOGW("LoadOmModelFromCache: Load model from data failed. ret = %u", ret); - return ret; - } - ge_model = model_helper.GetGeModel(); - // Load TbeKernelBin to op desc from TBEKernelStore - const TBEKernelStore &tbekernel_store = ge_model->GetTBEKernelStore(); - const ComputeGraphPtr compute_graph_in_model = GraphUtils::GetComputeGraph(ge_model->GetGraph()); - for (const auto &node : compute_graph_in_model->GetDirectNode()) { - auto op_desc = node->GetOpDesc(); - tbekernel_store.LoadTBEKernelBinToOpDesc(op_desc); - GELOGI("LoadOmModelFromCache: Load tbe kernel bin to op desc."); - } - return SUCCESS; -} - -Status ModelCacheHelper::GetVarNameFromVarKey(const string &var_key, const GeTensorDesc &tensor_desc, - string &var_name) { - std::string::size_type underline_idx = var_key.rfind('_'); - if (underline_idx == std::string::npos) { - GELOGW("Invalid var key: underline not found"); - return FAILED; - } - std::string::size_type format_idx = - var_key.rfind(std::to_string(static_cast(tensor_desc.GetFormat())), underline_idx); - if (format_idx == std::string::npos) { - GELOGW("Invalid var key: format not found"); - return FAILED; - } - var_name = var_key.substr(0, format_idx); - return SUCCESS; -} -} // namespace ge diff --git a/src/ge/common/helper/model_cache_helper.h b/src/ge/common/helper/model_cache_helper.h deleted file mode 100644 index 91257282..00000000 --- a/src/ge/common/helper/model_cache_helper.h +++ /dev/null @@ -1,121 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef GE_COMMON_HELPER_MODEL_CACHE_HELPER_H_ -#define GE_COMMON_HELPER_MODEL_CACHE_HELPER_H_ - -#include -#include -#include - -#include "ge/ge_api_error_codes.h" -#include "graph/compute_graph.h" -#include "graph/manager/graph_var_manager.h" -#include "model/ge_model.h" - -namespace ge { -using Json = nlohmann::json; - -struct CacheInfo { - size_t node_num; - size_t edge_num; - size_t graph_hash; - map nodes_hash; - CacheInfo() : node_num(0), edge_num(0), graph_hash(0) {} -}; - -class ModelCacheHelper { - public: - ModelCacheHelper(uint64_t session_id, uint32_t graph_id, ComputeGraphPtr &compute_graph); - - Status SaveCacheInfoToCache() const; - Status SaveVarManagerToCache(bool before_build) const; - Status SaveOmModelToCache(const GeModelPtr &ge_model) const; - bool IsModelCacheHit() const; - Status RecoverVarManagerFromCache() const; - Status LoadOmModelFromCache(GeModelPtr &ge_model) const; - Status RefreshComputeGraph(const ComputeGraphPtr &compute_graph); - Status ClearCache(uint32_t graph_id) const; - - private: - Status GetComputeGraphHash(size_t &hash) const; - Status GetNodesHash(map &hash_map) const; - Status GetCacheInfo(CacheInfo &cache_info) const; - - Status RecoverMemResource(const Json &json) const; - Status RecoverAllocatedGraphId(const Json &json) const; - Status RecoverChangedGraphId(const Json &json) const; - Status RecoverVarAddrAndTensorDesc(const Json &json) const; - Status RecoverBroadcastInfo(const Json &json) const; - Status RecoverTransRoads(const Json &json) const; - static Status RecompileNodes(GeModelPtr &ge_model); - - bool IsNodeHashSameAsCache(const map &hash_map) const; - bool IsMemResourceSameAsCache(Json &json) const; - bool IsChangedGraphIdSameAsCache(Json &json) const; - bool IsAllocatedGraphIdSameAsCache(Json &json) const; - bool IsCurVarTensorDescSameAsCache(Json &json) const; - bool IsVarAddrMgrMapSameAsCache(Json &json) const; - bool IsBroadcastInfoSameAsCache(Json &json) const; - bool IsTransRoadsSameAsCache(Json &json) const; - bool IsVarManagerSameAsCache(Json &json) const; - bool IsVarManagerParamSameAsCache(Json &json) const; - - Status SaveJsonToFile(const string &file_name, const Json &json) const; - Status LoadJsonFromFile(const string &file_name, Json &json) const; - - Status GetNodesHashMapJson(Json &json) const; - Status GetMemResourceMap(Json &json) const; - Status GetVarAddrMgrMapJson(Json &json) const; - Status GetCurVarTensorDescMapJson(Json &json) const; - Status GetTransRoadsJson(Json &json) const; - Status GetChangedGraphIdJson(Json &json) const; - Status GetAllocatedGraphIdJson(Json &json) const; - Status GetBroadcastInfoJson(Json &json) const; - Status GetVarResourceJson(Json &json) const; - Status GetVarManagerJson(Json &json) const; - - static Status TensorDescToJson(const GeTensorDesc &ge_tensor_desc, Json &json); - static Status JsonToTensorDesc(const Json &json, GeTensorDesc &ge_tensor_desc); - static Status ParseMemResourceFromJson(const Json &json, map &mem_resource); - static Status ParseVarAddrMgrMapFromJson(const Json &json, - std::vector> &var_addr_mgr_vector, - std::unordered_set &var_offset_set); - static Status ParseCurVarTensorDescMapFromJson( - const Json &json, std::unordered_map &cur_var_tensor_desc_map); - static Status ParseTransRoadsFromJson(const Json &json, - std::unordered_map> &trans_roads); - static Status ParseChangedGraphIdFromJson(const Json &json, - std::unordered_map &changed_graph_id); - static Status ParseAllocatedGraphIdFromJson(const Json &json, - std::unordered_map &allocated_graph_id); - static Status ParseBroadcastInfoFromJson(const Json &json, - std::unordered_map &var_broadcast_info); - static Status GetVarNameFromVarKey(const string &var_key, const GeTensorDesc &tensor_desc, string &var_name); - - uint64_t session_id_; - uint32_t graph_id_; - string cache_path_; - ComputeGraphPtr compute_graph_; - std::set var_names_; - bool is_cache_path_valid_for_output; - static map graph_id_run_times_; -}; - -using ModelCacheHelperPtr = std::shared_ptr; -} // namespace ge - -#endif // GE_COMMON_HELPER_MODEL_CACHE_HELPER_H_ diff --git a/src/ge/common/helper/model_helper.cc b/src/ge/common/helper/model_helper.cc index 3f1c3f58..29b3ff7d 100644 --- a/src/ge/common/helper/model_helper.cc +++ b/src/ge/common/helper/model_helper.cc @@ -26,17 +26,15 @@ #include "graph/utils/attr_utils.h" #include "graph/utils/graph_utils.h" -using domi::ModelTaskDef; using ge::ModelBufferData; using ge::TBEKernelPtr; using ge::TBEKernelStore; using std::string; - namespace { const int64_t kOriginalOmPartitionNum = 1; } -namespace ge { +namespace domi { FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelHelper::~ModelHelper() { (void)ReleaseLocalModelData(); } Status ModelHelper::SaveModelPartition(std::shared_ptr &om_file_save_helper, ModelPartitionType type, @@ -508,4 +506,4 @@ Status ModelHelper::ReleaseLocalModelData() noexcept { } return result; } -} // namespace ge +} // namespace domi diff --git a/src/ge/common/helper/om_file_helper.cc b/src/ge/common/helper/om_file_helper.cc index 58477b4e..3f2fc833 100644 --- a/src/ge/common/helper/om_file_helper.cc +++ b/src/ge/common/helper/om_file_helper.cc @@ -25,10 +25,11 @@ #include "framework/common/ge_inner_error_codes.h" #include "framework/common/util.h" +using ge::FileSaver; using ge::ModelBufferData; using std::string; -namespace ge { +namespace domi { // For Load FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::Init(const ge::ModelData &model) { if (CheckModelValid(model) != SUCCESS) { @@ -225,4 +226,4 @@ Status OmFileSaveHelper::SaveModelToFile(const char *output_file, ModelBufferDat return SUCCESS; #endif } -} // namespace ge +} // namespace domi diff --git a/src/ge/common/math_util.h b/src/ge/common/math_util.h index 5e783e81..87364a2b 100644 --- a/src/ge/common/math_util.h +++ b/src/ge/common/math_util.h @@ -26,7 +26,7 @@ #include "framework/common/util.h" #include "mmpa/mmpa_api.h" -namespace ge { +namespace domi { /** * @ingroup domi_calibration @@ -68,6 +68,6 @@ Status NnSet(const int32_t n, const Dtype alpha, Dtype *output) { return SUCCESS; } -} // end namespace ge +} // end namespace domi #endif // GE_COMMON_MATH_UTIL_H_ diff --git a/src/ge/common/model_parser/base.cc b/src/ge/common/model_parser/base.cc index a9a21ec5..8485d799 100644 --- a/src/ge/common/model_parser/base.cc +++ b/src/ge/common/model_parser/base.cc @@ -22,9 +22,15 @@ #include #include -#include "framework/common/debug/ge_log.h" #include "framework/common/debug/log.h" #include "framework/common/util.h" +#include "framework/common/debug/ge_log.h" + +using domi::GetFileLength; +using domi::MODEL_FILE_MAGIC_NUM; +using domi::ModelEncryptType; +using domi::ModelFileHeader; +using domi::RealPath; namespace ge { FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelParserBase::ModelParserBase() {} diff --git a/src/ge/common/model_saver.cc b/src/ge/common/model_saver.cc index f68051f4..424d2f1c 100644 --- a/src/ge/common/model_saver.cc +++ b/src/ge/common/model_saver.cc @@ -63,7 +63,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelSaver::SaveJsonToFi const char *model_char = model_str.c_str(); uint32_t len = static_cast(model_str.length()); // Write data to file - mmSsize_t mmpa_ret = mmWrite(fd, const_cast((const void *)model_char), len); + int32_t mmpa_ret = mmWrite(fd, const_cast((const void *)model_char), len); if (mmpa_ret == EN_ERROR || mmpa_ret == EN_INVALID_PARAM) { // Need to both print the error info of mmWrite and mmClose, so return ret after mmClose GELOGE(FAILED, "Write to file failed. errno = %d, %s", mmpa_ret, strerror(errno)); diff --git a/src/ge/common/op/attr_value_util.cc b/src/ge/common/op/attr_value_util.cc index 5d74aa1d..77d81076 100644 --- a/src/ge/common/op/attr_value_util.cc +++ b/src/ge/common/op/attr_value_util.cc @@ -18,7 +18,7 @@ #include "framework/common/debug/log.h" #include "framework/common/util.h" -namespace ge { +namespace domi { #define DEFINE_SET_ATTR_VALUE_ONE(ARG_TYPE, FIELD) \ FMK_FUNC_DEV_VISIBILITY void SetAttrDef(ARG_TYPE value, AttrDef *out) { \ GE_CHECK_NOTNULL_JUST_RETURN(out); \ @@ -312,4 +312,4 @@ DEFINE_GET_ATTR_LIST_SIZE(const std::string &, uint32_t, u); DEFINE_GET_ATTR_LIST_SIZE(const std::string &, float, f); DEFINE_GET_ATTR_LIST_SIZE(const std::string &, double, f); DEFINE_GET_ATTR_LIST_SIZE(const std::string &, bool, b); -} // namespace ge +} // namespace domi diff --git a/src/ge/common/op/ge_op_utils.cc b/src/ge/common/op/ge_op_utils.cc index b3bed399..b8a17514 100644 --- a/src/ge/common/op/ge_op_utils.cc +++ b/src/ge/common/op/ge_op_utils.cc @@ -25,10 +25,10 @@ #include "framework/common/debug/log.h" #include "framework/common/fmk_error_codes.h" #include "framework/common/ge_inner_error_codes.h" +#include "framework/common/op/attr_define.h" #include "framework/common/op/attr_value_util.h" #include "framework/common/util.h" #include "graph/anchor.h" -#include "graph/debug/ge_attr_define.h" #include "graph/utils/op_desc_utils.h" #include "graph/utils/tensor_utils.h" #include "graph/utils/type_utils.h" @@ -69,8 +69,6 @@ const uint32_t FOR_LIMIT_INPUT = 1; const uint32_t FOR_DELTA_INPUT = 2; const uint32_t FOR_DATA_INPUT = 3; -const int NORMAL_TENSOR_SIZE = 4; - // Get the value of key from attr #define AIPP_GET_ATTR_VALUE(KEY, ATTR_TYPE) \ if (aipp_attr.GetItem(#KEY).GetValue(KEY) != SUCCESS) { \ @@ -179,7 +177,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OpUtils::TransferDim(con for (auto dim_temp : dim) { new_dim_list.push_back(dim_temp); } - if (input_shape_size > DIM_DEFAULT_SIZE) { + if (input_shape_size > domi::DIM_DEFAULT_SIZE) { dim_vector = dim; GELOGI("Dim_vector size is %zu, do not to transfer dim", input_shape_size); return SUCCESS; diff --git a/src/ge/common/profiling/profiling_manager.cc b/src/ge/common/profiling/profiling_manager.cc index b4bab921..603bdfb1 100644 --- a/src/ge/common/profiling/profiling_manager.cc +++ b/src/ge/common/profiling/profiling_manager.cc @@ -182,7 +182,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ge::Status ProfilingManager::In return SUCCESS; } else { std::string prof_options_str = std::string(prof_options); - profiling_opts_ = StringUtils::Split(prof_options_str, ':'); + profiling_opts_ = domi::StringUtils::Split(prof_options_str, ':'); is_profiling_ = true; } GELOGI("The profiling in options is %s, %s", is_profiling, prof_options); @@ -314,119 +314,122 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::StopProf } FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::ProfilingTaskDescInfo( - const std::vector &task_desc_info, const int32_t &device_id) { + const std::vector &task_desc_info) { #ifdef DAVINCI_SUPPORT_PROFILING Msprof::Engine::Reporter *reporter = PluginImpl::GetPluginReporter(); if (reporter == nullptr) { GELOGI("Profiling report is nullptr!"); return; } - std::string data; - for (const auto &task : task_desc_info) { - std::string op_name = task.op_name; - uint32_t block_dim = task.block_dim; - uint32_t task_id = task.task_id; - uint32_t stream_id = task.stream_id; - data = op_name.append(" ").append(std::to_string(block_dim) - .append(" ") - .append(std::to_string(task_id)) - .append(" ") - .append(std::to_string(stream_id)) - .append("\n")); - - Msprof::Engine::ReporterData reporter_data{}; - reporter_data.deviceId = device_id; - reporter_data.data = (unsigned char *)data.c_str(); - reporter_data.dataLen = data.size(); - int ret = memcpy_s(reporter_data.tag, MSPROF_ENGINE_MAX_TAG_LEN + 1, "task_desc_info", sizeof("task_desc_info")); - if (ret != EOK) { - GELOGE(ret, "Report data tag of task_desc_info memcpy error!"); - return; - } + for (size_t i = 0; i < device_id_.size(); ++i) { + for (const auto &task : task_desc_info) { + std::string op_name = task.op_name; + uint32_t block_dim = task.block_dim; + uint32_t task_id = task.task_id; + uint32_t stream_id = task.stream_id; + data = op_name.append(" ").append(std::to_string(block_dim) + .append(" ") + .append(std::to_string(task_id)) + .append(" ") + .append(std::to_string(stream_id)) + .append("\n")); + + Msprof::Engine::ReporterData reporter_data{}; + reporter_data.deviceId = device_id_[i]; + reporter_data.data = (unsigned char *)data.c_str(); + reporter_data.dataLen = data.size(); + int ret = memcpy_s(reporter_data.tag, MSPROF_ENGINE_MAX_TAG_LEN + 1, "task_desc_info", sizeof("task_desc_info")); + if (ret != EOK) { + GELOGE(ret, "Report data tag of task_desc_info memcpy error!"); + return; + } - ret = reporter->Report(&reporter_data); - if (ret != SUCCESS) { - GELOGE(ret, "Reporter data of task_desc_info fail!"); - return; + ret = reporter->Report(&reporter_data); + if (ret != SUCCESS) { + GELOGE(ret, "Reporter data of task_desc_info fail!"); + return; + } } - } - data.clear(); + data.clear(); + } #endif } FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::ProfilingGraphDescInfo( - const std::vector &compute_graph_desc_info, const int32_t &device_id) { + const std::vector &compute_graph_desc_info) { #ifdef DAVINCI_SUPPORT_PROFILING Msprof::Engine::Reporter *reporter = PluginImpl::GetPluginReporter(); GE_IF_BOOL_EXEC(reporter == nullptr, GELOGI("Profiling report is nullptr!"); return;); std::string data; - for (const auto &graph : compute_graph_desc_info) { - data.append("op_name:").append(graph.op_name).append(" op_type:").append(graph.op_type); - for (size_t i = 0; i < graph.input_format.size(); ++i) { - data.append(" input_id:") - .append(std::to_string(i)) - .append(" input_format:") - .append(std::to_string(graph.input_format.at(i))) - .append(" input_data_type:") - .append(std::to_string(graph.input_data_type.at(i))) - .append(" input_shape:\""); - size_t input_shape_len = graph.input_shape.at(i).size(); - if (input_shape_len == 0) { - data.append(""); - } else if (input_shape_len == 1) { - data.append(std::to_string(graph.input_shape.at(i).at(0))); - } else { - for (size_t j = 0; j < input_shape_len - 1; ++j) { - data.append(std::to_string(graph.input_shape.at(i).at(j))).append(","); + for (size_t idx = 0; idx < device_id_.size(); ++idx) { + for (const auto &graph : compute_graph_desc_info) { + data.append("op_name:").append(graph.op_name).append(" op_type:").append(graph.op_type); + for (size_t i = 0; i < graph.input_format.size(); ++i) { + data.append(" input_id:") + .append(std::to_string(i)) + .append(" input_format:") + .append(std::to_string(graph.input_format.at(i))) + .append(" input_data_type:") + .append(std::to_string(graph.input_data_type.at(i))) + .append(" input_shape:\""); + size_t input_shape_len = graph.input_shape.at(i).size(); + if (input_shape_len == 0) { + data.append(""); + } else if (input_shape_len == 1) { + data.append(std::to_string(graph.input_shape.at(i).at(0))); + } else { + for (size_t j = 0; j < input_shape_len - 1; ++j) { + data.append(std::to_string(graph.input_shape.at(i).at(j))).append(","); + } + data.append(std::to_string(graph.input_shape.at(i).at(input_shape_len - 1))); } - data.append(std::to_string(graph.input_shape.at(i).at(input_shape_len - 1))); - } - data.append("\""); - } + data.append("\""); + } - for (size_t i = 0; i < graph.output_format.size(); ++i) { - data.append(" output_id:") - .append(std::to_string(i)) - .append(" output_format:") - .append(std::to_string(graph.output_format.at(i))) - .append(" output_data_type:") - .append(std::to_string(graph.output_data_type.at(i))) - .append(" output_shape:\""); - size_t output_shape_len = graph.output_shape.at(i).size(); - if (output_shape_len == 0) { - data.append(""); - } else if (output_shape_len == 1) { - data.append(std::to_string(graph.output_shape.at(i).at(0))); - } else { - for (size_t j = 0; j < output_shape_len - 1; ++j) { - data.append(std::to_string(graph.output_shape.at(i).at(j))).append(","); + for (size_t i = 0; i < graph.output_format.size(); ++i) { + data.append(" output_id:") + .append(std::to_string(i)) + .append(" output_format:") + .append(std::to_string(graph.output_format.at(i))) + .append(" output_data_type:") + .append(std::to_string(graph.output_data_type.at(i))) + .append(" output_shape:\""); + size_t output_shape_len = graph.output_shape.at(i).size(); + if (output_shape_len == 0) { + data.append(""); + } else if (output_shape_len == 1) { + data.append(std::to_string(graph.output_shape.at(i).at(0))); + } else { + for (size_t j = 0; j < output_shape_len - 1; ++j) { + data.append(std::to_string(graph.output_shape.at(i).at(j))).append(","); + } + data.append(std::to_string(graph.output_shape.at(i).at(output_shape_len - 1))); } - data.append(std::to_string(graph.output_shape.at(i).at(output_shape_len - 1))); + data.append("\""); } - data.append("\""); - } - data.append("\n"); + data.append("\n"); - Msprof::Engine::ReporterData reporter_data{}; - Report(device_id, data, *reporter, reporter_data); + Msprof::Engine::ReporterData reporter_data{}; + Report(idx, data, *reporter, reporter_data); - data.clear(); + data.clear(); + } } #endif } FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::Report( - const int32_t &device_id, const string &data, Msprof::Engine::Reporter &reporter, + const size_t &idx, const string &data, Msprof::Engine::Reporter &reporter, Msprof::Engine::ReporterData &reporter_data) { #ifdef DAVINCI_SUPPORT_PROFILING size_t index = data.size() / kReportMaxLen; if (index >= 1) { - reporter_data.deviceId = device_id; + reporter_data.deviceId = device_id_[idx]; int ret = memcpy_s(reporter_data.tag, MSPROF_ENGINE_MAX_TAG_LEN + 1, "graph_desc_info", sizeof("graph_desc_info")); GE_IF_BOOL_EXEC(ret != EOK, GELOGE(ret, "Report data tag of graph_desc_info memcpy error!"); return;); for (size_t i = 0; i < index; ++i) { @@ -442,7 +445,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::Report( GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(ret, "Reporter data of graph_desc_info fail!"); return;); } } else { - reporter_data.deviceId = device_id; + reporter_data.deviceId = device_id_[idx]; reporter_data.data = (unsigned char *)data.c_str(); reporter_data.dataLen = data.size(); int ret = memcpy_s(reporter_data.tag, MSPROF_ENGINE_MAX_TAG_LEN + 1, "graph_desc_info", sizeof("graph_desc_info")); @@ -457,24 +460,10 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::Report( FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::ReportProfilingData( const std::vector &task_desc_info, const std::vector &compute_graph_desc_info) { #ifdef DAVINCI_SUPPORT_PROFILING - int32_t device_id = 0; - rtError_t rt_ret = rtGetDevice(&device_id); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(rt_ret, "runtime get device_id failed, current device_id:%d", device_id); - return; - } - GELOGI("current device_id:%d", device_id); - - auto ret = std::find(device_id_.begin(), device_id_.end(), device_id); - if (ret == device_id_.end()) { - GELOGE(FAILED, "get valid device_id failed, profiling report failed."); - return; - } - GELOGI("start ProfilingTaskDescInfo."); - ProfilingTaskDescInfo(task_desc_info, device_id); + ProfilingTaskDescInfo(task_desc_info); GELOGI("start ProfilingGraphDescInfo."); - ProfilingGraphDescInfo(compute_graph_desc_info, device_id); + ProfilingGraphDescInfo(compute_graph_desc_info); GELOGI("Report profiling data for GE end."); #endif } diff --git a/src/ge/common/profiling/profiling_manager.h b/src/ge/common/profiling/profiling_manager.h index d3bfec63..e56f514f 100644 --- a/src/ge/common/profiling/profiling_manager.h +++ b/src/ge/common/profiling/profiling_manager.h @@ -50,11 +50,10 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ProfilingManager { void ReportProfilingData(const std::vector &task_desc_info, const std::vector &compute_graph_desc_info); - void Report(const int32_t &device_id, const string &data, Msprof::Engine::Reporter &reporter, + void Report(const size_t &idx, const string &data, Msprof::Engine::Reporter &reporter, Msprof::Engine::ReporterData &reporter_data); - void ProfilingTaskDescInfo(const std::vector &task_desc_info, const int32_t &device_id); - void ProfilingGraphDescInfo(const std::vector &compute_graph_desc_info, - const int32_t &device_id); + void ProfilingTaskDescInfo(const std::vector &task_desc_info); + void ProfilingGraphDescInfo(const std::vector &compute_graph_desc_info); void SetProfilingConfig(const string &profiling_cfg); vector GetProfilingDeviceId() const { return device_id_; } diff --git a/src/ge/common/properties_manager.cc b/src/ge/common/properties_manager.cc index b34f9463..e44fc4eb 100644 --- a/src/ge/common/properties_manager.cc +++ b/src/ge/common/properties_manager.cc @@ -59,7 +59,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool PropertiesManager::Init(co // Load file contents bool PropertiesManager::LoadFileContent(const std::string &file_path) { // Normalize the path - string resolved_file_path = RealPath(file_path.c_str()); + string resolved_file_path = domi::RealPath(file_path.c_str()); if (resolved_file_path.empty()) { DOMI_LOGE("Invalid input file path [%s], make sure that the file path is correct.", file_path.c_str()); return false; diff --git a/src/ge/common/types.cc b/src/ge/common/types.cc index e8ae5257..8b4e3ed4 100644 --- a/src/ge/common/types.cc +++ b/src/ge/common/types.cc @@ -15,6 +15,7 @@ */ #include "framework/common/types.h" + #include "graph/types.h" namespace ge { @@ -26,13 +27,14 @@ const std::string DUMP_LAYER = "layer"; const std::string DUMP_FILE_PATH = "path"; } // namespace ge -using ge::OpTypeRegistrar; -namespace ge { +namespace domi { const int DEFAULT_FORMAT = static_cast(ge::FORMAT_NCHW); -// Supported public property names -const std::string PROP_OME_START_TIME = "ome_start_time"; // start time -const std::string PROP_OME_DUMP_PATH = "ome_dump_path"; // dump path -const std::string PROP_OME_LOG_PATH = "ome_log_path"; // log path +/** + * @brief Supported public property names + */ +const std::string PROP_OME_START_TIME = "ome_start_time"; /**< start time */ +const std::string PROP_OME_DUMP_PATH = "ome_dump_path"; /**< dump path */ +const std::string PROP_OME_LOG_PATH = "ome_log_path"; /**< log path */ // Profile related constant const uint32_t CCE_PROFILE_ON = 0; @@ -385,7 +387,6 @@ REGISTER_OPTYPE_DEFINE(STREAMSWITCH, "StreamSwitch"); REGISTER_OPTYPE_DEFINE(STREAMSWITCHN, "StreamSwitchN"); REGISTER_OPTYPE_DEFINE(STREAMACTIVE, "StreamActive"); REGISTER_OPTYPE_DEFINE(MEMCPYASYNC, "MemcpyAsync"); -REGISTER_OPTYPE_DEFINE(MEMCPYADDRASYNC, "MemcpyAddrAsync"); REGISTER_OPTYPE_DEFINE(STREAMMERGE, "StreamMerge"); REGISTER_OPTYPE_DEFINE(ENDGRAPH, "EndGraph"); REGISTER_OPTYPE_DEFINE(SEND, "Send"); @@ -393,7 +394,6 @@ REGISTER_OPTYPE_DEFINE(RECV, "Recv"); REGISTER_OPTYPE_DEFINE(LABELSET, "LabelSet"); REGISTER_OPTYPE_DEFINE(LABELGOTO, "LabelGoto"); -REGISTER_OPTYPE_DEFINE(LABELGOTOEX, "LabelGotoEx"); REGISTER_OPTYPE_DEFINE(LABELSWITCH, "LabelSwitch"); REGISTER_OPTYPE_DEFINE(LABELSWITCHBYINDEX, "LabelSwitchByIndex"); @@ -469,315 +469,315 @@ const uint64_t ALLOC_MEMORY_MAX_SIZE = 8589934592; // Max size of 8 GB. const uint64_t ALLOC_MEMORY_MAX_SIZE = 536870912; // Max size of 512M. #endif -/// -///@brief Magic number of model file -/// +/** + * @brief Magic number of model file + */ const uint32_t MODEL_FILE_MAGIC_NUM = 0x444F4D49; // magic number -/// -///@brief Model head length -/// +/** + * @brief Model head length + */ const uint32_t MODEL_FILE_HEAD_LEN = 256; -/// -///@ingroup domi_omg -///@brief Input node type -/// +/** + * @ingroup domi_omg + * @brief Input node type + */ const std::string INPUT_TYPE = "Input"; -/// -///@ingroup domi_omg -///@brief AIPP label, label AIPP conv operator -/// +/** + * @ingroup domi_omg + * @brief AIPP label, label AIPP conv operator + */ const std::string AIPP_CONV_FLAG = "Aipp_Conv_Flag"; -/// -///@ingroup domi_omg -///@brief AIPP label, label aipp data operator -/// +/** + * @ingroup domi_omg + * @brief AIPP label, label aipp data operator + */ const std::string AIPP_DATA_FLAG = "Aipp_Data_Flag"; -/// -///@ingroup domi_omg -///@brief Record the w dimension of model input corresponding to dynamic AIPP -/// +/** + * @ingroup domi_omg + * @brief Record the w dimension of model input corresponding to dynamic AIPP + */ const std::string AIPP_RELATED_DATA_DIM_W = "aipp_related_data_dim_w"; -/// -///@ingroup domi_omg -///@brief Record the H dimension of model input corresponding to dynamic AIPP -/// +/** + * @ingroup domi_omg + * @brief Record the H dimension of model input corresponding to dynamic AIPP + */ const std::string AIPP_RELATED_DATA_DIM_H = "aipp_related_data_dim_h"; -/// -///@ingroup domi_omg -///@brief The tag of the data operator. Mark this input to the dynamic AIPP operator -/// +/** + * @ingroup domi_omg + * @brief The tag of the data operator. Mark this input to the dynamic AIPP operator + */ const std::string INPUT_TO_DYNAMIC_AIPP = "input_to_dynamic_aipp"; -/// -///@ingroup domi_omg -///@brief DATA node type -/// +/** + * @ingroup domi_omg + * @brief DATA node type + */ const std::string DATA_TYPE = "Data"; -/// -///@ingroup domi_omg -///@brief DATA node type -/// +/** + * @ingroup domi_omg + * @brief DATA node type + */ const std::string AIPP_DATA_TYPE = "AippData"; -/// -///@ingroup domi_omg -///@brief Frame operator type -/// +/** + * @ingroup domi_omg + * @brief Frame operator type + */ const std::string FRAMEWORK_OP_TYPE = "FrameworkOp"; -/// -///@ingroup domi_omg -///@brief Data node type -/// +/** + * @ingroup domi_omg + * @brief Data node type + */ const std::string ANN_DATA_TYPE = "AnnData"; const std::string ANN_NETOUTPUT_TYPE = "AnnNetOutput"; const std::string ANN_DEPTHCONV_TYPE = "AnnDepthConv"; const std::string ANN_CONV_TYPE = "AnnConvolution"; const std::string ANN_FC_TYPE = "AnnFullConnection"; -/// -///@ingroup domi_omg -///@brief Convolution node type -/// +/** + * @ingroup domi_omg + * @brief Convolution node type + */ const std::string NODE_NAME_NET_OUTPUT = "Node_Output"; const std::string NODE_NAME_END_GRAPH = "Node_EndGraph"; -/// -///@ingroup domi_omg -///@brief Convolution node type -/// +/** + * @ingroup domi_omg + * @brief Convolution node type + */ const std::string OP_TYPE_CONVOLUTION = "Convolution"; -/// -///@ingroup domi_omg -///@brief Add convolution node name to AIPP -/// +/** + * @ingroup domi_omg + * @brief Add convolution node name to AIPP + */ const std::string AIPP_CONV_OP_NAME = "aipp_conv_op"; -/// -///@ingroup domi_omg -///@brief Operator configuration item separator -/// +/** + * @ingroup domi_omg + * @brief Operator configuration item separator + */ const std::string OP_CONF_DELIMITER = ":"; -/// -///@ingroup domi_omg -///@brief attr value name -/// +/** + * @ingroup domi_omg + * @brief attr value name + */ const std::string ATTR_NAME_VALUE1 = "value1"; -/// -///@ingroup domi_omg -///@brief attr value name, 6d_2_4d C -/// +/** + * @ingroup domi_omg + * @brief attr value name, 6d_2_4d C + */ const std::string ATTR_NAME_INPUT_CVALUE = "input_cvalue"; -/// -///@ingroup domi_omg -///@brief alpha default value -/// +/** + * @ingroup domi_omg + * @brief alpha default value + */ const float ALPHA_DEFAULT_VALUE = 1.0; -/// -///@ingroup domi_omg -///@brief beta default value -/// +/** + * @ingroup domi_omg + * @brief beta default value + */ const float BETA_DEFAULT_VALUE = 0.0; -/// -///@ingroup domi_omg -///@brief coef default value -/// +/** + * @ingroup domi_omg + * @brief coef default value + */ const float COEF_DEFAULT_VALUE = 0.0; -/// -///@ingroup domi_omg -///@brief Relu6 coef value -/// +/** + * @ingroup domi_omg + * @brief Relu6 coef value + */ const float RELU6_COEF = 6.0; -/// -///@ingroup domi_omg -///@brief stride default value -/// +/** + * @ingroup domi_omg + * @brief stride default value + */ const uint32_t STRIDE_DEFAULT_VALUE = 1; -/// -///@ingroup domi_omg -///@brief pad default value -/// +/** + * @ingroup domi_omg + * @brief pad default value + */ const uint32_t PAD_DEFAULT_VALUE = 0; -/// -///@ingroup domi_omg -///@brief dilation default value -/// +/** + * @ingroup domi_omg + * @brief dilation default value + */ const int DILATION_DEFAULT_VALUE = 1; -/// -///@ingroup domi_omg -///@brief kernel default value -/// +/** + * @ingroup domi_omg + * @brief kernel default value + */ const uint32_t KERNEL_DEFAULT_VALUE = 0; -/// -///@ingroup domi_omg -///@brief defaule convolution group size -/// +/** + * @ingroup domi_omg + * @brief defaule convolution group size + */ const uint32_t DEFAULT_CONV_GROUP = 1; -/// -///@ingroup domi_omg -///@brief Default deconvolution adj -/// +/** + * @ingroup domi_omg + * @brief Default deconvolution adj + */ const uint32_t DEFAULT_DECONV_ADJ = 0; -/// -///@ingroup domi_omg -///@brief Represents value 1 -/// +/** + * @ingroup domi_omg + * @brief Represents value 1 + */ const uint32_t NUM_ONE = 1; -/// -///@ingroup domi_omg -///@brief spatial dim size default value -/// +/** + * @ingroup domi_omg + * @brief spatial dim size default value + */ const int32_t SPATIAL_DIM_DEFAULT_SIZE = 2; -/// -///@ingroup domi_omg -///@brief dim extended default value -/// +/** + * @ingroup domi_omg + * @brief dim extended default value + */ const int32_t DIM_DEFAULT_VALUE = 1; -/// -///@ingroup domi_omg -///@brief The first weight list in opdef is filter -/// +/** + * @ingroup domi_omg + * @brief The first weight list in opdef is filter + */ const int32_t WEIGHT_FILTER_INDEX = 0; -/// -///@ingroup domi_omg -///@brief The second weight list in opdef is bias -/// +/** + * @ingroup domi_omg + * @brief The second weight list in opdef is bias + */ const int32_t WEIGHT_BIAS_INDEX = 1; const int32_t TENSOR_ND_SUPPORT_SIZE = 8; -/// -///@ingroup domi_omg -///@brief NCHW index default value -/// +/** + * @ingroup domi_omg + * @brief NCHW index default value + */ const uint32_t NCHW_DIM_N = 0; const uint32_t NCHW_DIM_C = 1; const uint32_t NCHW_DIM_H = 2; const uint32_t NCHW_DIM_W = 3; -/// -///@ingroup domi_omg -///@brief KCHW index default value -/// +/** + * @ingroup domi_omg + * @brief KCHW index default value + */ const uint32_t KCHW_DIM_K = 0; const uint32_t KCHW_DIM_C = 1; const uint32_t KCHW_DIM_H = 2; const uint32_t KCHW_DIM_W = 3; -/// -///@ingroup domi_omg -///@brief HWCK index default value -/// +/** + * @ingroup domi_omg + * @brief HWCK index default value + */ const uint32_t HWCK_DIM_H = 0; const uint32_t HWCK_DIM_W = 1; const uint32_t HWCK_DIM_C = 2; const uint32_t HWCK_DIM_K = 3; -/// -///@ingroup domi_omg -///@brief NHWC index default value -/// +/** + * @ingroup domi_omg + * @brief NHWC index default value + */ const uint32_t NHWC_DIM_N = 0; const uint32_t NHWC_DIM_H = 1; const uint32_t NHWC_DIM_W = 2; const uint32_t NHWC_DIM_C = 3; -/// -///@ingroup domi_omg -///@brief CHWN index default value -/// +/** + * @ingroup domi_omg + * @brief CHWN index default value + */ const uint32_t CHWN_DIM_N = 3; const uint32_t CHWN_DIM_C = 0; const uint32_t CHWN_DIM_H = 1; const uint32_t CHWN_DIM_W = 2; -/// -///@ingroup domi_omg -///@brief CHW index default value -/// +/** + * @ingroup domi_omg + * @brief CHW index default value + */ const uint32_t CHW_DIM_C = 0; const uint32_t CHW_DIM_H = 1; const uint32_t CHW_DIM_W = 2; -/// -///@ingroup domi_omg -///@brief HWC index default value -/// +/** + * @ingroup domi_omg + * @brief HWC index default value + */ const uint32_t HWC_DIM_H = 0; const uint32_t HWC_DIM_W = 1; const uint32_t HWC_DIM_C = 2; -/// -///@ingroup domi_omg -///@brief Pad index default value -/// +/** + * @ingroup domi_omg + * @brief Pad index default value + */ const uint32_t PAD_H_HEAD = 0; const uint32_t PAD_H_TAIL = 1; const uint32_t PAD_W_HEAD = 2; const uint32_t PAD_W_TAIL = 3; -/// -///@ingroup domi_omg -///@brief window index default value -/// +/** + * @ingroup domi_omg + * @brief window index default value + */ const uint32_t WINDOW_H = 0; const uint32_t WINDOW_W = 1; -/// -///@ingroup domi_omg -///@brief stride index default value -/// +/** + * @ingroup domi_omg + * @brief stride index default value + */ const uint32_t STRIDE_H = 0; const uint32_t STRIDE_W = 1; -/// -///@ingroup domi_omg -///@brief dilation index default value -/// +/** + * @ingroup domi_omg + * @brief dilation index default value + */ const uint32_t DILATION_H = 0; const uint32_t DILATION_W = 1; -/// -///@ingroup domi_omg -///@brief the num of XRBG channel -/// +/** + * @ingroup domi_omg + * @brief the num of XRBG channel + */ const uint32_t XRGB_CHN_NUM = 4; -/// -///@ingroup domi_omg -///@brief global pooling default value -/// +/** + * @ingroup domi_omg + * @brief global pooling default value + */ const bool DEFAULT_GLOBAL_POOLING = false; -const uint32_t MODEL_VERSION = 0x10000000; ///< Model version 1.0/// +const uint32_t MODEL_VERSION = 0x10000000; /**< Model version 1.0 */ // Eltwise's input size const int ELTWISE_MIN_INPUT_SIZE = 2; -// flowctrl +/* flowctrl */ const std::string NODE_NAME_STREAM_SWITCH = "IteratorCtrl_StreamSwitch"; const std::string NODE_NAME_STREAM_ACTIVE = "IteratorCtrl_StreamActive"; const std::string NODE_NAME_FLOWCTRL_LOOP_PER_ITER = "npu_runconfig/iterations_per_loop"; @@ -792,4 +792,4 @@ const uint32_t STREAM_SWITCH_INPUT_NUM = 2; const std::string NODE_NAME_GLOBAL_STEP = "ge_global_step"; const std::string NODE_NAME_GLOBAL_STEP_ASSIGNADD = "global_step_assignadd"; -}; // namespace ge +}; // namespace domi diff --git a/src/ge/common/util.cc b/src/ge/common/util.cc index f1a2fe6c..44a8586d 100644 --- a/src/ge/common/util.cc +++ b/src/ge/common/util.cc @@ -57,7 +57,7 @@ const int kWarningThreshold = 536870912 * 2; // 536870912 represent 512M const int kMaxFileSizeLimit = INT_MAX; } // namespace -namespace ge { +namespace domi { static bool ReadProtoFromCodedInputStream(CodedInputStream &coded_stream, Message *proto) { GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(proto == nullptr, return false, "incorrect parameter. nullptr == proto"); @@ -196,7 +196,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int CreateDirectory(const std:: GE_CHK_BOOL_EXEC(!directory_path.empty(), return -1, "directory path is empty."); auto dir_path_len = directory_path.length(); if (dir_path_len >= PATH_MAX) { - GELOGW("Directory path is too long."); + GELOGE(ge::FAILED, "Directory path is too long."); return -1; } char tmp_dir_path[PATH_MAX] = {0}; @@ -207,7 +207,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int CreateDirectory(const std:: int32_t ret = mmMkdir(tmp_dir_path, S_IRUSR | S_IWUSR | S_IXUSR); // 700 if (ret != 0) { if (errno != EEXIST) { - GELOGW("Cannot create directory %s. Make sure that the directory exists and writable.", + GELOGE(ge::FAILED, "Cannot create directory %s. Make sure that the directory exists and writable.", directory_path.c_str()); return ret; } @@ -218,7 +218,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int CreateDirectory(const std:: int32_t ret = mmMkdir(const_cast(directory_path.c_str()), S_IRUSR | S_IWUSR | S_IXUSR); // 700 if (ret != 0) { if (errno != EEXIST) { - GELOGW("Cannot create directory %s. Make sure that the directory exists and writable.", directory_path.c_str()); + GELOGE(ge::FAILED, "Cannot create directory %s. Make sure that the directory exists and writable.", + directory_path.c_str()); return ret; } } @@ -338,7 +339,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string RealPath(const char FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInputPathValid(const std::string &file_path) { // The specified path is empty if (file_path.empty()) { - GELOGW("Path is empty."); + GELOGE(ge::FAILED, "Path is empty."); return false; } @@ -357,23 +358,23 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInputPathValid(const std::string real_path = RealPath(file_path.c_str()); // Unable to get absolute path (does not exist or does not have permission to access) if (real_path.empty()) { - GELOGW("Can not get real path for %s, %s", file_path.c_str(), strerror(errno)); + GELOGE(ge::FAILED, "Can not get real path for %s, %s", file_path.c_str(), strerror(errno)); return false; } // The absolute path points to a file that is not readable if (access(real_path.c_str(), R_OK) != 0) { - GELOGW("Can not read file in %s, %s", file_path.c_str(), strerror(errno)); + GELOGE(ge::FAILED, "Can not read file in %s, %s", file_path.c_str(), strerror(errno)); return false; } return true; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckOutputPathValid(const std::string &file_path) { +FMK_FUNC_HOST_VISIBILITY bool CheckOutputPathValid(const std::string &file_path) { // The specified path is empty if (file_path.empty()) { - GELOGW("Path is empty."); + GELOGE(ge::FAILED, "Path is empty."); return false; } @@ -393,8 +394,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckOutputPathValid(const // Can get absolute path (file exists) if (!real_path.empty()) { // File is not readable or writable - if (access(real_path.c_str(), W_OK | F_OK) != 0) { - GELOGW("Path[ %s ] exists, but can not be write, %s", file_path.c_str(), strerror(errno)); + if (access(real_path.c_str(), R_OK | W_OK | F_OK) != 0) { + GELOGE(ge::FAILED, "Path[ %s ] exists, but can not be write, %s", file_path.c_str(), strerror(errno)); return false; } } else { @@ -412,7 +413,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckOutputPathValid(const std::string prefix_path = std::string(file_path).substr(0, static_cast(path_split_pos)); // Determine whether the specified path is valid by creating the path if (CreateDirectory(prefix_path) != 0) { - GELOGW("Can not create prefix path for path[ %s ].", file_path.c_str()); + GELOGE(ge::FAILED, "Can not create prefix path for path[ %s ].", file_path.c_str()); return false; } } @@ -435,4 +436,4 @@ FMK_FUNC_HOST_VISIBILITY bool ValidateStr(const std::string &str, const std::str return true; #endif } -} // namespace ge +} // namespace domi diff --git a/src/ge/executor/CMakeLists.txt b/src/ge/executor/CMakeLists.txt index fae29e75..7401b062 100755 --- a/src/ge/executor/CMakeLists.txt +++ b/src/ge/executor/CMakeLists.txt @@ -47,7 +47,6 @@ file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "../graph/load/new_model_manager/task_info/kernel_task_info.cc" "../graph/load/new_model_manager/task_info/label_goto_task_info.cc" "../graph/load/new_model_manager/task_info/label_set_task_info.cc" - "../graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc" "../graph/load/new_model_manager/task_info/memcpy_async_task_info.cc" "../graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" "../graph/load/new_model_manager/task_info/stream_active_task_info.cc" @@ -86,6 +85,7 @@ include_directories(${GE_SOURCE_DIR}/inc) include_directories(${GE_SOURCE_DIR}/inc/graph) include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/cce) +include_directories(${GE_SOURCE_DIR}/third_party/securec/include) include_directories(${CMAKE_BINARY_DIR}) include_directories(${CMAKE_BINARY_DIR}/proto/ge) diff --git a/src/ge/executor/ge_executor.cc b/src/ge/executor/ge_executor.cc index 7342f1a7..120187cc 100644 --- a/src/ge/executor/ge_executor.cc +++ b/src/ge/executor/ge_executor.cc @@ -193,8 +193,15 @@ Status GeExecutor::Initialize() { } // Start profiling + int32_t device_id = 0; + rtError_t rt_ret = rtGetDevice(&device_id); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(rt_ret, "runtime get device_id failed, current device_id:%d", device_id); + return FAILED; + } + GELOGI("current device_id:%d", device_id); Options profiling_options; - profiling_options.device_id = 0; + profiling_options.device_id = device_id; profiling_options.job_id = ""; ProfilingManager::Instance().Init(profiling_options); @@ -345,7 +352,7 @@ Status GeExecutor::LoadModelOffline(uint32_t &model_id, const std::string &path, return GE_EXEC_NOT_INIT; } - string filePath = RealPath(path.c_str()); + string filePath = domi::RealPath(path.c_str()); if (filePath.empty()) { GELOGE(ge::FAILED, "fileath is invalid. please check your text file '%s'.", path.c_str()); return ge::FAILED; @@ -396,6 +403,10 @@ Status GeExecutor::UnloadModel(uint32_t model_id) { return GE_EXEC_NOT_INIT; } + // stop profiling + if (!ProfilingManager::Instance().ProfilingOpTraceOn() && ProfilingManager::Instance().ProfilingOn()) { + ProfilingManager::Instance().StopProfiling(); + } return GraphLoader::UnloadModel(model_id); } @@ -554,7 +565,7 @@ Status GeExecutor::LoadDataFromFile(const std::string &path, ModelData &model_da return GE_EXEC_NOT_INIT; } - string filePath = RealPath(path.c_str()); + string filePath = domi::RealPath(path.c_str()); if (filePath.empty()) { GELOGE(ge::FAILED, "filePath is invalid. please check your text file '%s'.", path.c_str()); return ge::FAILED; diff --git a/src/ge/ge_local_engine/CMakeLists.txt b/src/ge/ge_local_engine/CMakeLists.txt index 80a3c335..559c782d 100755 --- a/src/ge/ge_local_engine/CMakeLists.txt +++ b/src/ge/ge_local_engine/CMakeLists.txt @@ -35,6 +35,7 @@ include_directories(${GE_SOURCE_DIR}/inc/external/graph) include_directories(${GE_SOURCE_DIR}/inc/framework) include_directories(${GE_SOURCE_DIR}/inc/graph) include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) +include_directories(${GE_SOURCE_DIR}/third_party/securec/include) include_directories(${CMAKE_BINARY_DIR}) include_directories(${CMAKE_BINARY_DIR}/proto/ge) diff --git a/src/ge/ge_local_engine/engine/host_cpu_engine.cc b/src/ge/ge_local_engine/engine/host_cpu_engine.cc index 9ee616ac..c4fe9ea1 100644 --- a/src/ge/ge_local_engine/engine/host_cpu_engine.cc +++ b/src/ge/ge_local_engine/engine/host_cpu_engine.cc @@ -237,7 +237,7 @@ Status HostCpuEngine::LoadLib(const std::string &lib_path) { } Status HostCpuEngine::GetRealPath(std::string &path) { - std::string real_path = RealPath(path.c_str()); + std::string real_path = domi::RealPath(path.c_str()); if (real_path.empty()) { GELOGW("File path %s is invalid.", path.c_str()); return INTERNAL_ERROR; diff --git a/src/ge/ge_local_engine/engine/host_cpu_engine.h b/src/ge/ge_local_engine/engine/host_cpu_engine.h index 1987138d..88985f87 100644 --- a/src/ge/ge_local_engine/engine/host_cpu_engine.h +++ b/src/ge/ge_local_engine/engine/host_cpu_engine.h @@ -21,7 +21,7 @@ #include "framework/common/ge_inner_error_codes.h" #include "graph/node.h" #include "graph/operator.h" -#include "inc/register/register.h" +#include "register/register.h" namespace ge { class HostCpuEngine { diff --git a/src/ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.cc b/src/ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.cc index cde6640f..4eae65c5 100644 --- a/src/ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.cc +++ b/src/ge/ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.cc @@ -17,6 +17,8 @@ #include "ge_local_engine/ops_kernel_store/ge_local_ops_kernel_info.h" #include #include "common/constant/constant.h" +#include "framework/common/debug/ge_log.h" +#include "common/ge_inner_error_codes.h" #include "common/ge/ge_util.h" #include "common/ge_inner_error_codes.h" #include "framework/common/debug/ge_log.h" diff --git a/src/ge/ge_runtime/CMakeLists.txt b/src/ge/ge_runtime/CMakeLists.txt index aa4e3470..b914b21b 100755 --- a/src/ge/ge_runtime/CMakeLists.txt +++ b/src/ge/ge_runtime/CMakeLists.txt @@ -26,6 +26,7 @@ include_directories(${GE_SOURCE_DIR}/inc/framework/common) include_directories(${GE_SOURCE_DIR}/inc/framework/ge_runtime) include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/cce) +include_directories(${GE_SOURCE_DIR}/third_party/securec/include) include_directories(${CMAKE_BINARY_DIR}) include_directories(${CMAKE_BINARY_DIR}/proto/ge) diff --git a/src/ge/ge_runtime/runtime_model.cc b/src/ge/ge_runtime/runtime_model.cc index 330ffc14..b60df61d 100644 --- a/src/ge/ge_runtime/runtime_model.cc +++ b/src/ge/ge_runtime/runtime_model.cc @@ -447,11 +447,8 @@ bool RuntimeModel::InitConstantInfo(std::shared_ptr &davinci_model /// The logic of GetShapeSize is wrong, the scaler tensor's GetShapeSize is zero /// and that of unknown shape is zero too. /// Unknown shape will not appear here, so we can use zero judge a tensor is scaler or not. - int64_t elem_num = constant->weight_tensors[0].GetShapeSize(); - if (elem_num == 0 && constant->weight_tensors[0].size == 0) { - elem_num = 1; - } - + int64_t elem_num = + (constant->weight_tensors[0].GetShapeSize() == 0) ? 1 : constant->weight_tensors[0].GetShapeSize(); if (constant->weight_data.size() < sizeof(uint64_t)) { GELOGE(FAILED, "weight_data size is smaller than sizeof(uint64_t)"); return false; diff --git a/src/ge/generator/ge_generator.cc b/src/ge/generator/ge_generator.cc index 8cae441c..3cc067c5 100644 --- a/src/ge/generator/ge_generator.cc +++ b/src/ge/generator/ge_generator.cc @@ -28,6 +28,11 @@ #include "graph/utils/graph_utils.h" #include "model/ge_model.h" +using domi::DATA; +using domi::ModelHelper; +using domi::NETOUTPUT; +using domi::NODE_NAME_NET_OUTPUT; +using domi::SaveParam; using ge::ModelBufferData; using std::map; using std::string; @@ -101,7 +106,7 @@ static void GetOpsProtoPath(string &opsproto_path) { const char *path_env = std::getenv("ASCEND_OPP_PATH"); if (path_env != nullptr) { string path = path_env; - string file_path = RealPath(path.c_str()); + string file_path = domi::RealPath(path.c_str()); if (file_path.empty()) { GELOGE(FAILED, "File path %s is invalid.", path.c_str()); return; @@ -143,7 +148,7 @@ Status GeGenerator::Initialize(const map &options) { GELOGI("opsproto_path is %s", opsproto_path.c_str()); OpsProtoManager *manager = OpsProtoManager::Instance(); map option_tmp; - option_tmp.emplace(std::pair(string("ge.opsProtoLibPath"), opsproto_path)); + option_tmp.insert(std::pair(string("ge.opsProtoLibPath"), opsproto_path)); (void)manager->Initialize(option_tmp); Status ret = impl_->graph_manager_.Initialize(options); @@ -258,7 +263,7 @@ Status GeGenerator::BuildSingleOpModel(OpDescPtr &op_desc, const vector op_attrs = op_desc->GetAllAttrs(); // 1. Create ComputeGraph. - string name = ge::CurrentTimeInStr() + "_" + model_file_name; + string name = domi::CurrentTimeInStr() + "_" + model_file_name; ge::ComputeGraphPtr compute_graph = MakeShared(name); if (compute_graph == nullptr) { return INTERNAL_ERROR; diff --git a/src/ge/generator/generator_api.cc b/src/ge/generator/generator_api.cc index 3f92f1a2..094baab8 100644 --- a/src/ge/generator/generator_api.cc +++ b/src/ge/generator/generator_api.cc @@ -116,7 +116,7 @@ Status_t OpTaskGernerator(const char *op_type, const OpTensor_t *in_tensor, int CHECK_PARAM_NOT_NULL(om_file); const std::string om_file_name(om_file); - std::string op_name = std::string(op_type) + "_" + std::to_string(ge::GetCurrentTimestap()); + std::string op_name = std::string(op_type) + "_" + std::to_string(domi::GetCurrentTimestap()); ge::OpDescPtr op_desc = ge::MakeShared(op_name, op_type); if (op_desc == nullptr) { return ge::FAILED; diff --git a/src/ge/graph/build/graph_builder.cc b/src/ge/graph/build/graph_builder.cc index 957ddc2d..de222c8c 100644 --- a/src/ge/graph/build/graph_builder.cc +++ b/src/ge/graph/build/graph_builder.cc @@ -18,15 +18,18 @@ #include "common/ge/ge_util.h" #include "common/helper/model_helper.h" #include "common/opskernel/ops_kernel_info_types.h" -#include "graph/build/run_context.h" #include "graph/build/stream_graph_optimizer.h" +#include "graph/build/run_context.h" #include "graph/manager/graph_var_manager.h" #include "graph/utils/node_utils.h" #include "graph/utils/type_utils.h" #include "init/gelib.h" #include "model/ge_model.h" +using domi::ATTR_MODEL_MEMORY_SIZE; +using domi::ATTR_MODEL_WEIGHT_SIZE; using domi::BuildMode; +using domi::DATA; namespace { const int32_t kInvalidPerfLevel = -1; @@ -98,10 +101,8 @@ Status GraphBuilder::Build(ComputeGraphPtr &comp_graph, std::vector &subgraph_ptr_list, uint64_t session_id) { GE_CHECK_NOTNULL(model_ptr); GE_CHECK_NOTNULL(comp_graph); @@ -192,7 +193,7 @@ Status GraphBuilder::GetTaskInfo(const ge::ModelBuilder &builder, const ModelPtr } StreamGraphOptimizer stream_optimizer; - ret = stream_optimizer.OptimizeStreamedSubGraph(comp_graph, subgraph_map, run_context.GetRunContext()); + ret = stream_optimizer.OptimizeStreamedSubGraph(comp_graph, subgraph_ptr_list, run_context.GetRunContext()); if (ret != SUCCESS) { GELOGE(ret, "Optimize streamed subGraph fail."); return ret; diff --git a/src/ge/graph/build/graph_builder.h b/src/ge/graph/build/graph_builder.h index d0bf26e6..c1c4f7b6 100644 --- a/src/ge/graph/build/graph_builder.h +++ b/src/ge/graph/build/graph_builder.h @@ -53,7 +53,7 @@ class GraphBuilder { private: Status CalcOpParam(const ge::ComputeGraphPtr &graph); Status GetTaskInfo(const ge::ModelBuilder &builder, const ModelPtr &model_ptr, ComputeGraphPtr &comp_graph, - Graph2SubGraphInfoList &subgraph_map, uint64_t session_id = INVALID_SESSION_ID); + std::vector &subgraph_ptr_list, uint64_t session_id = INVALID_SESSION_ID); Status SetInputSize(const ge::NodePtr &node_ptr); Status UpdateDataInputSize(const ge::NodePtr &node_ptr); Status SecondPartition(ge::ComputeGraphPtr &comp_graph, vector &subgraph_ptr_list); diff --git a/src/ge/graph/build/logical_stream_allocator.cc b/src/ge/graph/build/logical_stream_allocator.cc index 16c4935e..d57d5ac5 100644 --- a/src/ge/graph/build/logical_stream_allocator.cc +++ b/src/ge/graph/build/logical_stream_allocator.cc @@ -16,17 +16,22 @@ #include "graph/build/logical_stream_allocator.h" #include "common/ge/ge_util.h" +#include "common/op/attr_define.h" #include "framework/common/debug/ge_log.h" #include "framework/common/fmk_error_codes.h" #include "framework/common/types.h" #include "graph/utils/graph_utils.h" -#include "graph/debug/ge_attr_define.h" using std::map; using std::set; using std::string; using std::vector; +using domi::ATTR_NAME_STREAM_LABEL; +using domi::CONSTANT; +using domi::CONSTANTOP; +using domi::HCOMALLREDUCE; + namespace { const char *const kAICPUEngineName = "DNN_VM_AICPU"; const char *const kAttrNameParentOpType = "parentOpType"; @@ -70,7 +75,7 @@ bool LogicalStreamPass::HasNonConstInputNode(const Subgraph &subgraph) const { return false; } -Status AssignByLabelPass::Run(ComputeGraphPtr graph, const vector &subgraphs, Context &context) { +Status AssignByLabelPass::Run(ComputeGraphPtr whole_graph, const vector &subgraphs, Context &context) { bool changed = false; int64_t &next_stream = context.next_stream; map label_streams; @@ -97,7 +102,7 @@ Status AssignByLabelPass::Run(ComputeGraphPtr graph, const vector & return changed ? SUCCESS : NOT_CHANGED; } -Status IndependentStreamPass::Run(ComputeGraphPtr graph, const vector &subgraphs, Context &context) { +Status IndependentStreamPass::Run(ComputeGraphPtr whole_graph, const vector &subgraphs, Context &context) { bool changed = false; int64_t &next_stream = context.next_stream; @@ -129,7 +134,8 @@ Status IndependentStreamPass::Run(ComputeGraphPtr graph, const vector &subgraphs, Context &context) { +Status AssignByDependencyPass::Run(ComputeGraphPtr whole_graph, const vector &subgraphs, + Context &context) { bool changed = false; if (IsHeadNodeExceeded(subgraphs)) { int64_t &next_stream = context.next_stream; @@ -297,7 +303,7 @@ int64_t AssignByDependencyPass::AssignNewStream(SubgraphPtr subgraph) { subgraph->stream_id = stream_id; engine_next_streams_[engine_name] = stream_id + 1; - assigned_subgraphs_.emplace_back(subgraph); + assigned_subgraphs_.emplace(subgraph); if ((stream_id + 1) > engine_stream_num_[engine_name]) { engine_stream_num_[engine_name] = stream_id + 1; @@ -310,15 +316,6 @@ int64_t AssignByDependencyPass::AssignNewStream(SubgraphPtr subgraph) { } void AssignByDependencyPass::UpdateAssignedSubgraphs(Context &context) { - // If the parent stream is valid, the first assigned stream will reuse the parent stream id - // and other streams use new id. To ensure that the id of the new stream is continuous, - // we first subtract one from next_stream. - int64_t to_be_updated_stream = kInvalidStream; - if (context.parent_stream != kInvalidStream) { - context.next_stream--; - to_be_updated_stream = context.next_stream; - } - // Update the starting stream id for each engine. int64_t &next_stream = context.next_stream; map engine_start_streams; @@ -328,16 +325,10 @@ void AssignByDependencyPass::UpdateAssignedSubgraphs(Context &context) { next_stream += stream_count; } - // Update the subgraph streams assigned by engine. + // Update the subgraphs assigned by the engine. for (auto &subgraph : assigned_subgraphs_) { subgraph->stream_id += engine_start_streams[subgraph->engine_conf.id]; - if (subgraph->stream_id == to_be_updated_stream) { - subgraph->stream_id = context.parent_stream; - GELOGI("Subgraph %s of engine %s reuses parent stream %ld.", subgraph->name.c_str(), - subgraph->engine_conf.id.c_str(), context.parent_stream); - } else { - GELOGI("Stream of subgraph %s has been updated to %ld.", subgraph->name.c_str(), subgraph->stream_id); - } + GELOGI("Stream of subgraph %s has been updated to %ld.", subgraph->name.c_str(), subgraph->stream_id); } } @@ -351,7 +342,7 @@ void AssignByDependencyPass::UpdateReusedSubgraphs() { } } -Status NodeStreamUpdatePass::Run(ComputeGraphPtr graph, const vector &subgraphs, Context &context) { +Status NodeStreamUpdatePass::Run(ComputeGraphPtr whole_graph, const vector &subgraphs, Context &context) { // Check if all subgraphs have been assigned a stream. for (const SubgraphPtr &subgraph : subgraphs) { const string &engine_name = subgraph->engine_conf.id; @@ -367,7 +358,7 @@ Status NodeStreamUpdatePass::Run(ComputeGraphPtr graph, const vectorGetDirectNode()) { + for (NodePtr &node : whole_graph->GetDirectNode()) { GE_CHECK_NOTNULL(node->GetOpDesc()); node->GetOpDesc()->SetStreamId(kInvalidStream); } @@ -389,11 +380,76 @@ Status NodeStreamUpdatePass::Run(ComputeGraphPtr graph, const vector &subgraphs, Context &context) { + if (!context.hcom_parallel) { + return NOT_CHANGED; + } + + GELOGI("AllReduceParallelPass is enabled."); + GraphUtils::DumpGEGraph(whole_graph, "BeforeAllReduceParallel"); + + // All successors of HcomAllReduce. + set all_reduce_succs; + + for (const NodePtr &node : whole_graph->GetDirectNode()) { + if (node->GetType() != HCOMALLREDUCE || node->GetInDataNodes().size() <= 1) { + continue; + } + + string reduce_stream_label; + GE_CHECK_NOTNULL(node->GetOpDesc()); + // ATTR_NAME_STREAM_LABEL is optional. + (void)AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, reduce_stream_label); + + set cur_nodes = {node}; + while (!cur_nodes.empty()) { + set all_out_data_nodes; + for (auto &curr_node : cur_nodes) { + for (const NodePtr &out_node : curr_node->GetOutDataNodes()) { + string out_stream_label; + GE_CHECK_NOTNULL(out_node->GetOpDesc()); + // ATTR_NAME_STREAM_LABEL is optional. + (void)AttrUtils::GetStr(out_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, out_stream_label); + if (out_stream_label == reduce_stream_label) { + all_reduce_succs.emplace(out_node); + all_out_data_nodes.emplace(out_node); + } + } + } + cur_nodes = all_out_data_nodes; + } + } + + map old_stream_to_new; + for (const NodePtr &node : all_reduce_succs) { + GE_CHECK_NOTNULL(node->GetOpDesc()); + auto old_stream = node->GetOpDesc()->GetStreamId(); + if (old_stream != kInvalidStream) { + int64_t new_stream = kInvalidStream; + auto iter = old_stream_to_new.find(old_stream); + if (iter != old_stream_to_new.end()) { + new_stream = iter->second; + } else { + new_stream = context.next_stream; + context.next_stream++; + old_stream_to_new.emplace(old_stream, new_stream); + } + + GELOGI("Stream of node %s has been updated from %ld to %ld.", node->GetName().c_str(), old_stream, new_stream); + node->GetOpDesc()->SetStreamId(new_stream); + } + } + + return !all_reduce_succs.empty() ? SUCCESS : NOT_CHANGED; +} + int64_t NodeStreamUpdatePass::GetSingleInoutStream(const NodePtr &node) const { set stream_ids; @@ -421,11 +477,11 @@ int64_t NodeStreamUpdatePass::GetSingleInoutStream(const NodePtr &node) const { return kInvalidStream; } -Status NodeStreamUpdatePass::UpdateForSkippedEngine(const ComputeGraphPtr &graph, +Status NodeStreamUpdatePass::UpdateForSkippedEngine(const ComputeGraphPtr &whole_graph, const vector &subgraphs) { set nodes_to_be_updated; - // Check if subgraph is engine skipped and without stream label or not + // Check if sub graph is engine skipped and without stream label or not for (const SubgraphPtr &subgraph : subgraphs) { if (IsEngineSkip(*subgraph) && !HasStreamLabel(*subgraph)) { auto graph = subgraph->subgraph_info.GetSubGraph(); @@ -441,7 +497,7 @@ Status NodeStreamUpdatePass::UpdateForSkippedEngine(const ComputeGraphPtr &graph } // Try reassign the stream id - for (ge::NodePtr &node : graph->GetDirectNode()) { + for (ge::NodePtr &node : whole_graph->GetDirectNode()) { auto op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); int64_t stream_id = op_desc->GetStreamId(); @@ -458,7 +514,6 @@ Status NodeStreamUpdatePass::UpdateForSkippedEngine(const ComputeGraphPtr &graph } } } - return SUCCESS; } @@ -475,65 +530,40 @@ bool NodeStreamUpdatePass::AreAllPredStreamsInvalid(const NodePtr &node) const { return true; } -Status AllReduceParallelPass::Run(ComputeGraphPtr graph, const vector &subgraphs, Context &context) { - if (!context.hcom_parallel) { - return NOT_CHANGED; - } - - GELOGI("AllReduceParallelPass is enabled."); - GraphUtils::DumpGEGraph(graph, "BeforeAllReduceParallel"); - - // All successors of HcomAllReduce. - set all_reduce_succs; - - for (const NodePtr &node : graph->GetDirectNode()) { - if (node->GetType() != HCOMALLREDUCE || node->GetInDataNodes().size() <= 1) { - continue; - } - - string reduce_stream_label; - GE_CHECK_NOTNULL(node->GetOpDesc()); - (void)AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, reduce_stream_label); +void NodeStreamUpdatePass::RefreshContinuousStreams(ComputeGraphPtr whole_graph, Context &context) const { + int64_t stream_num = context.next_stream; + vector stream_has_node(stream_num); - set cur_nodes = {node}; - while (!cur_nodes.empty()) { - set all_out_data_nodes; - for (auto &curr_node : cur_nodes) { - for (const NodePtr &out_node : curr_node->GetOutDataNodes()) { - string out_stream_label; - GE_CHECK_NOTNULL(out_node->GetOpDesc()); - (void)AttrUtils::GetStr(out_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, out_stream_label); - if (out_stream_label == reduce_stream_label) { - all_reduce_succs.emplace(out_node); - all_out_data_nodes.emplace(out_node); - } + for (const NodePtr &node : whole_graph->GetDirectNode()) { + if (node != nullptr) { + auto op_desc = node->GetOpDesc(); + if (op_desc != nullptr) { + int64_t stream_id = op_desc->GetStreamId(); + if (stream_id != kInvalidStream && stream_id < stream_num) { + stream_has_node[stream_id] = true; } } - cur_nodes = all_out_data_nodes; } } - map old_stream_to_new; - for (const NodePtr &node : all_reduce_succs) { - GE_CHECK_NOTNULL(node->GetOpDesc()); - auto old_stream = node->GetOpDesc()->GetStreamId(); - if (old_stream != kInvalidStream) { - int64_t new_stream = kInvalidStream; - auto iter = old_stream_to_new.find(old_stream); - if (iter != old_stream_to_new.end()) { - new_stream = iter->second; - } else { - new_stream = context.next_stream; - context.next_stream++; - old_stream_to_new.emplace(old_stream, new_stream); - } - - GELOGI("Stream of node %s has been updated from %ld to %ld.", node->GetName().c_str(), old_stream, new_stream); - node->GetOpDesc()->SetStreamId(new_stream); + context.next_stream = 0; + vector old_to_new_streams(stream_num, kInvalidStream); + for (size_t old_stream = 0; old_stream < stream_has_node.size(); ++old_stream) { + if (stream_has_node[old_stream]) { + old_to_new_streams[old_stream] = context.next_stream; + ++context.next_stream; } } - return !all_reduce_succs.empty() ? SUCCESS : NOT_CHANGED; + for (const NodePtr &node : whole_graph->GetDirectNode()) { + auto op_desc = node->GetOpDesc(); + if (op_desc != nullptr) { + int64_t stream_id = op_desc->GetStreamId(); + if (stream_id != kInvalidStream && stream_id < stream_num) { + op_desc->SetStreamId(old_to_new_streams[stream_id]); + } + } + } } LogicalStreamAllocator::LogicalStreamAllocator(const map &scheduler_confs, @@ -542,10 +572,9 @@ LogicalStreamAllocator::LogicalStreamAllocator(const map context_.hcom_parallel = hcom_parallel; } -Status LogicalStreamAllocator::Assign(const ComputeGraphPtr &whole_graph, const Graph2SubGraphInfoList &subgraph_map, +Status LogicalStreamAllocator::Assign(const ComputeGraphPtr &whole_graph, const vector &subgraph_infos, int64_t &stream_num) { GE_CHECK_NOTNULL(whole_graph); - map engine_confs; GE_TIMESTAMP_START(InitEngineConfs); for (const auto &item : scheduler_confs_) { @@ -559,64 +588,16 @@ Status LogicalStreamAllocator::Assign(const ComputeGraphPtr &whole_graph, const } GE_TIMESTAMP_END(InitEngineConfs, "GraphBuilder::AssignStreamInitEngineConfs"); - Status status = DoAssign(whole_graph, subgraph_map, engine_confs); - if (status != SUCCESS) { - GELOGE(status, "Assign streams failed."); - return status; - } - - vector subgraphs = whole_graph->GetAllSubgraphs(); - for (const ComputeGraphPtr &subgraph : subgraphs) { - Status status = DoAssign(subgraph, subgraph_map, engine_confs); - if (status != SUCCESS) { - GELOGE(status, "Assign streams failed."); - return status; - } - } - - RefreshContinuousStreams(whole_graph); - - stream_num = context_.next_stream; - GELOGI("Assigned logical stream num: %ld.", stream_num); - - return SUCCESS; -} - -Status LogicalStreamAllocator::DoAssign(const ComputeGraphPtr &graph, const Graph2SubGraphInfoList &subgraph_map, - const map &engine_confs) { - GE_CHECK_NOTNULL(graph); - - NodePtr parent_node = graph->GetParentNode(); - if (parent_node == nullptr || parent_node->GetOpDesc() == nullptr) { - context_.parent_stream = kInvalidStream; - } else { - context_.parent_stream = parent_node->GetOpDesc()->GetStreamId(); - } - - auto iter = subgraph_map.find(graph); - if (iter == subgraph_map.end()) { - GELOGE(FAILED, "Graph %s not found.", graph->GetName().c_str()); - return FAILED; - } - - const vector &subgraph_info_list = iter->second; vector subgraphs; GE_TIMESTAMP_START(ConvertSubgraphs); - Status status = ConvertSubgraphs(subgraph_info_list, engine_confs, subgraphs); + Status status = ConvertSubgraphs(subgraph_infos, engine_confs, subgraphs); GE_TIMESTAMP_END(ConvertSubgraphs, "GraphBuilder::AssignStreamConvertSubgraphs"); if (status != SUCCESS) { GELOGE(status, "Create subgraphs failed."); return status; } - GELOGI("Subgraphs of graph %s:", graph->GetName().c_str()); - for (const auto &subgraph : subgraphs) { - if (subgraph != nullptr) { - GELOGI("subgraph: %s", subgraph->name.c_str()); - } - } - - return RunPasses(graph, subgraphs); + return RunPasses(whole_graph, subgraphs, stream_num); } Status LogicalStreamAllocator::ConvertSubgraphs(const vector &subgraph_infos, @@ -655,7 +636,8 @@ Status LogicalStreamAllocator::ConvertSubgraphs(const vector &s return SUCCESS; } -Status LogicalStreamAllocator::RunPasses(const ComputeGraphPtr &graph, const vector &subgraphs) { +Status LogicalStreamAllocator::RunPasses(const ComputeGraphPtr &whole_graph, const vector &subgraphs, + int64_t &stream_num) { vector passes; passes.emplace_back(MakeShared()); passes.emplace_back(MakeShared()); @@ -666,7 +648,7 @@ Status LogicalStreamAllocator::RunPasses(const ComputeGraphPtr &graph, const vec for (auto &pass : passes) { GE_CHECK_NOTNULL(pass); - Status status = pass->Run(graph, subgraphs, context_); + Status status = pass->Run(whole_graph, subgraphs, context_); if (status == SUCCESS) { GELOGI("Stream pass %s return SUCCESS.", pass->GetName().c_str()); } else if (status == NOT_CHANGED) { @@ -677,42 +659,9 @@ Status LogicalStreamAllocator::RunPasses(const ComputeGraphPtr &graph, const vec } } - return SUCCESS; -} - -void LogicalStreamAllocator::RefreshContinuousStreams(const ComputeGraphPtr &graph) { - int64_t stream_num = context_.next_stream; - vector stream_has_node(stream_num); - - for (const NodePtr &node : graph->GetAllNodes()) { - if (node != nullptr) { - auto op_desc = node->GetOpDesc(); - if (op_desc != nullptr) { - int64_t stream_id = op_desc->GetStreamId(); - if (stream_id != kInvalidStream && stream_id < stream_num) { - stream_has_node[stream_id] = true; - } - } - } - } - - context_.next_stream = 0; - vector old_to_new_streams(stream_num, kInvalidStream); - for (size_t old_stream = 0; old_stream < stream_has_node.size(); ++old_stream) { - if (stream_has_node[old_stream]) { - old_to_new_streams[old_stream] = context_.next_stream; - ++context_.next_stream; - } - } + stream_num = context_.next_stream; + GELOGI("Assigned logical stream num: %ld.", stream_num); - for (const NodePtr &node : graph->GetAllNodes()) { - auto op_desc = node->GetOpDesc(); - if (op_desc != nullptr) { - int64_t stream_id = op_desc->GetStreamId(); - if (stream_id != kInvalidStream && stream_id < stream_num) { - op_desc->SetStreamId(old_to_new_streams[stream_id]); - } - } - } + return SUCCESS; } } // namespace ge diff --git a/src/ge/graph/build/logical_stream_allocator.h b/src/ge/graph/build/logical_stream_allocator.h index 404d22f9..2265a0f3 100644 --- a/src/ge/graph/build/logical_stream_allocator.h +++ b/src/ge/graph/build/logical_stream_allocator.h @@ -60,7 +60,7 @@ class LogicalStreamPass { }; struct Context { - int64_t parent_stream = kInvalidStream; + // Next stream id. int64_t next_stream = 0; bool hcom_parallel = false; }; @@ -71,7 +71,7 @@ class LogicalStreamPass { virtual ~LogicalStreamPass() = default; const std::string &GetName() const; - virtual Status Run(ComputeGraphPtr graph, const std::vector &subgraphs, Context &context) = 0; + virtual Status Run(ComputeGraphPtr whole_graph, const std::vector &subgraphs, Context &context) = 0; protected: bool IsEngineSkip(const Subgraph &subgraph) const; @@ -93,21 +93,21 @@ using LogicalStreamPassPtr = std::shared_ptr; class AssignByLabelPass : public LogicalStreamPass { public: STREAM_PASS_DEFAULT_FUNC(AssignByLabelPass); - Status Run(ComputeGraphPtr graph, const std::vector &subgraphs, Context &context) override; + Status Run(ComputeGraphPtr whole_graph, const std::vector &subgraphs, Context &context) override; }; // Engines such as hccl require independent Stream. class IndependentStreamPass : public LogicalStreamPass { public: STREAM_PASS_DEFAULT_FUNC(IndependentStreamPass); - Status Run(ComputeGraphPtr graph, const std::vector &subgraphs, Context &context) override; + Status Run(ComputeGraphPtr whole_graph, const std::vector &subgraphs, Context &context) override; }; // Reuse streams or assign new streams based on dependencies. class AssignByDependencyPass : public LogicalStreamPass { public: STREAM_PASS_DEFAULT_FUNC(AssignByDependencyPass); - Status Run(ComputeGraphPtr graph, const std::vector &subgraphs, Context &context) override; + Status Run(ComputeGraphPtr whole_graph, const std::vector &subgraphs, Context &context) override; private: void InitEndSubgraphMap(const std::vector &subgraphs, std::map &end_subgraph_map); @@ -132,7 +132,7 @@ class AssignByDependencyPass : public LogicalStreamPass { std::map engine_stream_num_; // Subgraphs of assign stream by engine - std::vector assigned_subgraphs_; + std::set assigned_subgraphs_; // std::vector> reused_subgraphs_; @@ -142,7 +142,7 @@ class AssignByDependencyPass : public LogicalStreamPass { class NodeStreamUpdatePass : public LogicalStreamPass { public: STREAM_PASS_DEFAULT_FUNC(NodeStreamUpdatePass); - Status Run(ComputeGraphPtr graph, const std::vector &subgraphs, Context &context) override; + Status Run(ComputeGraphPtr whole_graph, const std::vector &subgraphs, Context &context) override; private: /// Optimize for case like: @@ -150,18 +150,19 @@ class NodeStreamUpdatePass : public LogicalStreamPass { /// To case: /// NodeA(stream1) -> Const(stream1) -> NodeB(stream1) /// Which could reduce event number (Const could be other type which belong to skipped engine subgraph) - Status UpdateForSkippedEngine(const ComputeGraphPtr &graph, const std::vector &subgraphs); + Status UpdateForSkippedEngine(const ComputeGraphPtr &whole_graph, const std::vector &subgraphs); int64_t GetSingleInoutStream(const NodePtr &node) const; // Judge if all predecessors' streams of node are INVALID_STREAM bool AreAllPredStreamsInvalid(const NodePtr &node) const; + void RefreshContinuousStreams(ComputeGraphPtr whole_graph, Context &context) const; }; // AllReduce and backward operators execute in parallel. class AllReduceParallelPass : public LogicalStreamPass { public: STREAM_PASS_DEFAULT_FUNC(AllReduceParallelPass); - Status Run(ComputeGraphPtr graph, const std::vector &subgraphs, Context &context) override; + Status Run(ComputeGraphPtr whole_graph, const std::vector &subgraphs, Context &context) override; }; // Assign logical streams which is not limited by the number of tasks. @@ -177,16 +178,13 @@ class LogicalStreamAllocator { LogicalStreamAllocator &operator=(const LogicalStreamAllocator &) = delete; ~LogicalStreamAllocator() = default; - Status Assign(const ComputeGraphPtr &whole_graph, const Graph2SubGraphInfoList &subgraph_map, int64_t &stream_num); + Status Assign(const ComputeGraphPtr &whole_graph, const std::vector &subgraphs, int64_t &stream_num); private: - Status DoAssign(const ComputeGraphPtr &graph, const Graph2SubGraphInfoList &subgraph_map, - const map &engine_confs); Status ConvertSubgraphs(const std::vector &subgraph_infos, const std::map &engine_confs, std::vector &subgraphs); - Status RunPasses(const ComputeGraphPtr &graph, const std::vector &subgraphs); - void RefreshContinuousStreams(const ComputeGraphPtr &graph); + Status RunPasses(const ComputeGraphPtr &whole_graph, const std::vector &subgraphs, int64_t &stream_num); const std::map &scheduler_confs_; const std::map &max_parallel_num_; diff --git a/src/ge/graph/build/memory/CMakeLists.txt b/src/ge/graph/build/memory/CMakeLists.txt index ea87b906..aa474dd8 100644 --- a/src/ge/graph/build/memory/CMakeLists.txt +++ b/src/ge/graph/build/memory/CMakeLists.txt @@ -33,6 +33,7 @@ include_directories(${GE_SOURCE_DIR}/inc/external) include_directories(${GE_SOURCE_DIR}/inc/external/graph) include_directories(${GE_SOURCE_DIR}/inc/framework) include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) +include_directories(${GE_SOURCE_DIR}/third_party/securec/include) include_directories(${CMAKE_BINARY_DIR}) include_directories(${CMAKE_BINARY_DIR}/proto/ge) diff --git a/src/ge/graph/build/memory/binary_block_mem_assigner.cc b/src/ge/graph/build/memory/binary_block_mem_assigner.cc index 8668e81e..67c04ef6 100644 --- a/src/ge/graph/build/memory/binary_block_mem_assigner.cc +++ b/src/ge/graph/build/memory/binary_block_mem_assigner.cc @@ -100,13 +100,13 @@ Status BinaryBlockMemAssigner::GetMemoryRanges(vector &range_ceils) { GELOGD("Origin ranges:"); for (auto &v : ranges) { - GELOGD("__%s", ToString(v).c_str()); + GELOGD("__%s", domi::ToString(v).c_str()); } PlanRanges(range_number_limit, ranges); GELOGD("Origin ranges:"); for (auto &v : ranges) { - GELOGD("__%s", ToString(v).c_str()); + GELOGD("__%s", domi::ToString(v).c_str()); } for (auto &range : ranges) { @@ -115,7 +115,7 @@ Status BinaryBlockMemAssigner::GetMemoryRanges(vector &range_ceils) { range_ceils.push_back(range.back()); } } - GELOGI("Range ceils: %s", ToString(range_ceils).c_str()); + GELOGI("Range ceils: %s", domi::ToString(range_ceils).c_str()); return SUCCESS; } diff --git a/src/ge/graph/build/memory/block_mem_assigner.cc b/src/ge/graph/build/memory/block_mem_assigner.cc index 4f55a569..e0fd3d9b 100644 --- a/src/ge/graph/build/memory/block_mem_assigner.cc +++ b/src/ge/graph/build/memory/block_mem_assigner.cc @@ -29,6 +29,7 @@ #include "graph/utils/op_desc_utils.h" #include "graph/utils/tensor_utils.h" +#include "common/op/attr_define.h" #include "graph/debug/ge_attr_define.h" #include "graph/optimize/common/params.h" @@ -46,6 +47,29 @@ const int kReuseMaxCharNum = 2000; } // namespace namespace ge { +using domi::AIPP_DATA_TYPE; +using domi::AIPPDATA; +using domi::ANN_DATA_TYPE; +using domi::APPLYMOMENTUM; +using domi::ASSIGN; +using domi::ASSIGNADD; +using domi::ASSIGNSUB; +using domi::CONSTANT; +using domi::CONSTANTOP; +using domi::DATA; +using domi::DATA_TYPE; +using domi::ENTER; +using domi::FASTRCNNPREDICTIONS; +using domi::HCOMALLREDUCE; +using domi::HCOMBROADCAST; +using domi::MULTISHAPE; +using domi::NETOUTPUT; +using domi::NEXTITERATION; +using domi::PROPOSAL; +using domi::REFENTER; +using domi::REFNEXTITERATION; +using domi::VARIABLE; +using domi::ZEROSLIKE; using std::map; using std::pair; using std::string; @@ -134,7 +158,7 @@ string ToString(ge::NodeTypeIndex &x) { string MemoryBlock::String() { stringstream ss; ss << "Block size: " << Size() << " from " << HeadOffset() << " to " << TailOffset() << ""; - ss << "real_size_list: " << ToString(real_size_list_) << ""; + ss << "real_size_list: " << domi::ToString(real_size_list_) << ""; ss << "ref_count: " << ref_count_ << ""; ss << "members: "; for (auto x : NodeTypeIndexList()) { @@ -175,7 +199,7 @@ void BlockMemAssigner::GetOutAndWorkSpaceMem(vector &all_memory_size) { all_memory_size.insert(all_memory_size.end(), temp.begin(), temp.end()); } sort(all_memory_size.begin(), all_memory_size.end()); - GELOGI("All memory size: %s", ToString(all_memory_size).c_str()); + GELOGI("All memory size: %s", domi::ToString(all_memory_size).c_str()); for (auto iter = all_memory_size.begin(); iter != all_memory_size.end();) { if (*iter == 0) { diff --git a/src/ge/graph/build/memory/graph_mem_assigner.cc b/src/ge/graph/build/memory/graph_mem_assigner.cc index bcae79ea..7fc07f42 100644 --- a/src/ge/graph/build/memory/graph_mem_assigner.cc +++ b/src/ge/graph/build/memory/graph_mem_assigner.cc @@ -18,6 +18,7 @@ #include #include #include "common/math/math_util.h" +#include "common/op/attr_define.h" #include "framework/common/debug/ge_log.h" #include "graph/build/memory/hybrid_mem_assigner.h" #include "graph/build/memory/var_mem_assign_util.h" @@ -28,6 +29,19 @@ #include "graph/utils/tensor_utils.h" #include "graph/utils/type_utils.h" +using domi::AIPP_DATA_TYPE; +using domi::ATOMICADDRCLEAN; +using domi::ATTR_NAME_AUTOMIC_ADD_MEM_SIZE; +using domi::ATTR_NAME_AUTOMIC_ADD_START; +using domi::CONCAT; +using domi::CONSTANTOP; +using domi::DATA_TYPE; +using domi::HCOMBROADCAST; +using domi::LABELSWITCHBYINDEX; +using domi::NODE_NAME_NET_OUTPUT; +using domi::STREAMMERGE; +using domi::VARIABLE; + namespace { const int kDataOutputIndex = 0; const int kAllInputAddrIsAtomic = -1; @@ -423,10 +437,8 @@ Status GraphMemoryAssigner::ReAssignReuseAndNoPaddingContinuousInputMemory() { pre_mem_offset, peer_op_desc->GetStreamId(), out_size, output_mem_size); } memory_offset_[0].mem_offset_ += extra_memory_size; - size_t after_mem_offset = memory_offset_[0].mem_offset_; - AlignMemOffset(MEM_ALIGN_SIZE); - GELOGI("After reassign virtual input node[name:%s, type:%s] memory, memory offset = %zu, align memory = %zu.", - op_desc->GetName().c_str(), op_desc->GetType().c_str(), after_mem_offset, memory_offset_[0].mem_offset_); + GELOGI("After reassign virtual input node[name:%s, type:%s] memory, memory offset = %zu.", + op_desc->GetName().c_str(), op_desc->GetType().c_str(), memory_offset_[0].mem_offset_); } } return SUCCESS; @@ -511,10 +523,8 @@ Status GraphMemoryAssigner::ReAssignReuseAndNoPaddingContinuousOutputMemory() { } op_desc->SetOutputOffset(output_list); memory_offset_[0].mem_offset_ += extra_memory_size; - size_t after_mem_offset = memory_offset_[0].mem_offset_; - AlignMemOffset(MEM_ALIGN_SIZE); - GELOGI("After reassign virtual output node[name:%s, type:%s] memory, memory offset = %zu, align memory = %zu.", - op_desc->GetName().c_str(), op_desc->GetType().c_str(), after_mem_offset, memory_offset_[0].mem_offset_); + GELOGI("After reassign virtual output node[name:%s, type:%s] memory, memory offset = %zu.", + op_desc->GetName().c_str(), op_desc->GetType().c_str(), memory_offset_[0].mem_offset_); } } return SUCCESS; diff --git a/src/ge/graph/build/memory/var_mem_assign_util.cc b/src/ge/graph/build/memory/var_mem_assign_util.cc index 0a2061f8..ff5f9798 100644 --- a/src/ge/graph/build/memory/var_mem_assign_util.cc +++ b/src/ge/graph/build/memory/var_mem_assign_util.cc @@ -16,6 +16,7 @@ #include "graph/build/memory/var_mem_assign_util.h" #include +#include "common/op/attr_define.h" #include "common/types.h" #include "framework/common/debug/ge_log.h" #include "graph/common/transop_util.h" @@ -50,10 +51,10 @@ Status VarMemAssignUtil::AssignMemory2VariableNode(ge::ComputeGraphPtr &compute_ Status VarMemAssignUtil::AssignStaticMemory2Node(ge::ComputeGraphPtr &compute_graph) { GE_IF_BOOL_EXEC(compute_graph == nullptr, return FAILED); for (const ge::NodePtr &n : compute_graph->GetDirectNode()) { - GE_IF_BOOL_EXEC((n->GetType() != VARIABLE) && (n->GetType() != CONSTANTOP), continue); + GE_IF_BOOL_EXEC((n->GetType() != domi::VARIABLE) && (n->GetType() != domi::CONSTANTOP), continue); string ref_var_src_var_name; GE_CHECK_NOTNULL(n->GetOpDesc()); - GE_IF_BOOL_EXEC(ge::AttrUtils::GetStr(n->GetOpDesc(), REF_VAR_SRC_VAR_NAME, ref_var_src_var_name), continue); + GE_IF_BOOL_EXEC(ge::AttrUtils::GetStr(n->GetOpDesc(), domi::REF_VAR_SRC_VAR_NAME, ref_var_src_var_name), continue); string node_name = n->GetName(); GE_IF_BOOL_EXEC(n->GetOpDesc()->GetAllOutputsDesc().empty(), GELOGE(FAILED, "node:%s has no OutputDesc.", n->GetName().c_str()); @@ -63,7 +64,7 @@ Status VarMemAssignUtil::AssignStaticMemory2Node(ge::ComputeGraphPtr &compute_gr if (!VarManager::Instance(compute_graph->GetSessionID())->IsVarExist(node_name, *tensor_desc)) { GE_CHK_STATUS_RET( VarManager::Instance(compute_graph->GetSessionID())->AssignVarMem(node_name, *tensor_desc, RT_MEMORY_HBM)); - GE_IF_BOOL_EXEC(n->GetType() == VARIABLE, + GE_IF_BOOL_EXEC(n->GetType() == domi::VARIABLE, GE_CHK_STATUS_RET(AssignData2Fp32Var(n, compute_graph->GetSessionID()))); GE_CHK_STATUS_RET(VarManager::Instance(compute_graph->GetSessionID()) ->SetAllocatedGraphId(node_name, compute_graph->GetGraphID())); @@ -84,7 +85,7 @@ Status VarMemAssignUtil::AssignStaticMemory2Node(ge::ComputeGraphPtr &compute_gr Status VarMemAssignUtil::AssignData2Fp32Var(const ge::NodePtr &node, uint64_t session_id) { string src_var_name; GE_CHECK_NOTNULL(node->GetOpDesc()); - if (ge::AttrUtils::GetStr(node->GetOpDesc(), VAR_ATTR_SRC_VAR_NAME, src_var_name)) { + if (ge::AttrUtils::GetStr(node->GetOpDesc(), domi::VAR_ATTR_SRC_VAR_NAME, src_var_name)) { ge::GeTensorDesc cur_tensor_desc; uint8_t *dev_ptr = nullptr; rtMemType_t memory_type = RT_MEMORY_HBM; @@ -99,10 +100,11 @@ Status VarMemAssignUtil::AssignData2Fp32Var(const ge::NodePtr &node, uint64_t se Status VarMemAssignUtil::AssignVarAttr2Nodes(ge::ComputeGraphPtr &compute_graph) { for (const ge::NodePtr &node : compute_graph->GetDirectNode()) { - GE_IF_BOOL_EXEC(node->GetType() != VARIABLE, continue); + GE_IF_BOOL_EXEC(node->GetType() != domi::VARIABLE, continue); string ref_var_src_var_name; GE_CHECK_NOTNULL(node->GetOpDesc()); - GE_IF_BOOL_EXEC(ge::AttrUtils::GetStr(node->GetOpDesc(), REF_VAR_SRC_VAR_NAME, ref_var_src_var_name), continue); + GE_IF_BOOL_EXEC(ge::AttrUtils::GetStr(node->GetOpDesc(), domi::REF_VAR_SRC_VAR_NAME, ref_var_src_var_name), + continue); GE_CHK_STATUS_RET(DealVariableNode(compute_graph->GetGraphID(), node, compute_graph->GetSessionID())); } return SUCCESS; @@ -140,7 +142,8 @@ Status VarMemAssignUtil::DealExportVariableNode(const ge::NodePtr &node, const g GE_IF_BOOL_EXEC(var_out_anchor == nullptr, return FAILED); for (const ge::InDataAnchorPtr &dst_in_var_anchor : var_out_anchor->GetPeerInDataAnchors()) { ge::NodePtr dst_node = dst_in_var_anchor->GetOwnerNode(); - if ((dst_node->GetType() == ASSIGN) || (dst_node->GetType() == ASSIGNADD) || (dst_node->GetType() == ASSIGNSUB)) { + if ((dst_node->GetType() == domi::ASSIGN) || (dst_node->GetType() == domi::ASSIGNADD) || + (dst_node->GetType() == domi::ASSIGNSUB)) { if (dst_in_var_anchor == dst_node->GetInDataAnchor(0)) { GE_CHK_STATUS_RET(DealExportVariableNode(dst_node, var_node, session_id)); } @@ -208,19 +211,20 @@ Status VarMemAssignUtil::DealVariableNode(uint32_t graph_id, const ge::NodePtr & for (const ge::OutDataAnchorPtr &var_out_data_anchor : node->GetAllOutDataAnchors()) { for (const ge::InDataAnchorPtr &dst_in_data_anchor : var_out_data_anchor->GetPeerInDataAnchors()) { ge::NodePtr dst_node = dst_in_data_anchor->GetOwnerNode(); - if (dst_node->GetType() == HCOMBROADCAST) { + if (dst_node->GetType() == domi::HCOMBROADCAST) { GE_CHK_STATUS_RET(DealBroadCastNode(graph_id, dst_node, dst_in_data_anchor, node, session_id)); continue; } - if ((dst_node->GetType() == ASSIGN) || (dst_node->GetType() == ASSIGNADD) || (dst_node->GetType() == ASSIGNSUB)) { + if ((dst_node->GetType() == domi::ASSIGN) || (dst_node->GetType() == domi::ASSIGNADD) || + (dst_node->GetType() == domi::ASSIGNSUB)) { if (dst_in_data_anchor == dst_node->GetInDataAnchor(0)) { GE_CHK_STATUS_RET(DealExportVariableNode(dst_node, node, session_id)); } } auto dst_type = dst_node->GetType(); - bool is_trans_node = - (dst_type == TRANSDATA) || (dst_type == CAST) || (dst_type == TRANSPOSE) || (dst_type == PERMUTE); + bool is_trans_node = (dst_type == domi::TRANSDATA) || (dst_type == domi::CAST) || (dst_type == domi::TRANSPOSE) || + (dst_type == domi::PERMUTE); if (is_trans_node) { NodePtr final_trans_node = GetFinalTransNode(dst_node); GE_CHK_STATUS_RET(DealTransNode(final_trans_node)); @@ -237,8 +241,8 @@ ge::NodePtr VarMemAssignUtil::GetFinalTransNode(const ge::NodePtr &trans_node) { for (const auto &dst_in_anchor : trans_out_data_anchor->GetPeerInDataAnchors()) { NodePtr dst_node = dst_in_anchor->GetOwnerNode(); auto dst_type = dst_node->GetType(); - bool is_trans_node = - (dst_type == TRANSDATA) || (dst_type == CAST) || (dst_type == TRANSPOSE) || (dst_type == PERMUTE); + bool is_trans_node = (dst_type == domi::TRANSDATA) || (dst_type == domi::CAST) || (dst_type == domi::TRANSPOSE) || + (dst_type == domi::PERMUTE); if (is_trans_node && (dst_in_anchor->GetIdx() == 0)) { final_ref_node = GetFinalTransNode(dst_node); } @@ -252,7 +256,8 @@ Status VarMemAssignUtil::DealTransNode(const ge::NodePtr &final_trans_node) { GE_IF_BOOL_EXEC(final_trans_out_anchor == nullptr, return SUCCESS); for (const ge::InDataAnchorPtr &dst_in_var_anchor : final_trans_out_anchor->GetPeerInDataAnchors()) { ge::NodePtr dst_node = dst_in_var_anchor->GetOwnerNode(); - if ((dst_node->GetType() == ASSIGN) || (dst_node->GetType() == ASSIGNADD) || (dst_node->GetType() == ASSIGNSUB)) { + if ((dst_node->GetType() == domi::ASSIGN) || (dst_node->GetType() == domi::ASSIGNADD) || + (dst_node->GetType() == domi::ASSIGNSUB)) { GE_CHK_STATUS_RET(DealExportTransNode(dst_node, final_trans_node)); } } @@ -264,7 +269,8 @@ Status VarMemAssignUtil::DealExportTransNode(const ge::NodePtr &node, const ge:: GE_CHECK_NOTNULL(node_out_anchor); for (const ge::InDataAnchorPtr &dst_in_var_anchor : node_out_anchor->GetPeerInDataAnchors()) { ge::NodePtr dst_node = dst_in_var_anchor->GetOwnerNode(); - if ((dst_node->GetType() == ASSIGN) || (dst_node->GetType() == ASSIGNADD) || (dst_node->GetType() == ASSIGNSUB)) { + if ((dst_node->GetType() == domi::ASSIGN) || (dst_node->GetType() == domi::ASSIGNADD) || + (dst_node->GetType() == domi::ASSIGNSUB)) { GE_CHK_STATUS_RET(DealExportTransNode(dst_node, final_trans_node)); } } @@ -300,7 +306,7 @@ Status VarMemAssignUtil::AssignMemory2HasRefAttrNode(ge::ComputeGraphPtr &comput for (const ge::NodePtr &n : compute_graph->GetDirectNode()) { string ref_var_src_var_name; GE_CHECK_NOTNULL(n->GetOpDesc()); - bool is_ref = ge::AttrUtils::GetStr(n->GetOpDesc(), REF_VAR_SRC_VAR_NAME, ref_var_src_var_name); + bool is_ref = ge::AttrUtils::GetStr(n->GetOpDesc(), domi::REF_VAR_SRC_VAR_NAME, ref_var_src_var_name); GE_IF_BOOL_EXEC(is_ref, GE_CHK_STATUS_RET(AssignData2VarRef(n, ref_var_src_var_name, compute_graph->GetSessionID()))); } @@ -323,7 +329,7 @@ Status VarMemAssignUtil::AssignData2VarRef(const ge::NodePtr &has_ref_attr_node, GE_CHECK_SIZE(ref_attr_node_output_list.size()); int out_index = 0; - bool is_get = ge::AttrUtils::GetInt(var_ref_src_var->GetOpDesc(), REF_VAR_PRE_PEER_OUT_INDEX, out_index); + bool is_get = ge::AttrUtils::GetInt(var_ref_src_var->GetOpDesc(), domi::REF_VAR_PRE_PEER_OUT_INDEX, out_index); if (!is_get) { GELOGI("%s failed to get attr [REF_VAR_PRE_PEER_OUT_INDEX]", var_ref_src_var->GetName().c_str()); } diff --git a/src/ge/graph/build/model_builder.cc b/src/ge/graph/build/model_builder.cc index ac61eeeb..750cc90b 100644 --- a/src/ge/graph/build/model_builder.cc +++ b/src/ge/graph/build/model_builder.cc @@ -17,7 +17,6 @@ #include "graph/build/model_builder.h" #include #include -#include #include "common/ge/ge_util.h" #include "framework/common/debug/ge_log.h" #include "graph/anchor.h" @@ -28,7 +27,6 @@ #include "graph/common/omg_util.h" #include "graph/debug/ge_attr_define.h" #include "graph/ge_attr_value.h" -#include "graph/ge_context.h" #include "graph/ge_error_codes.h" #include "graph/manager/graph_mem_allocator.h" #include "graph/manager/graph_var_manager.h" @@ -40,11 +38,46 @@ #include "graph/utils/op_desc_utils.h" #include "graph/utils/tensor_utils.h" #include "graph/utils/type_utils.h" +#include "graph/ge_context.h" #include "init/gelib.h" #include "memory/memory_assigner.h" #include "omg/version.h" #include "register/op_registry.h" +using domi::AIPP_CONV_FLAG; +using domi::AIPP_DATA_FLAG; +using domi::AIPP_DATA_TYPE; +using domi::AippOpParams; +using domi::ATOMICADDRCLEAN; +using domi::ATTR_NAME_AUTOMIC_ADD_MEM_SIZE; +using domi::ATTR_NAME_AUTOMIC_ADD_START; +using domi::CAST; +using domi::CHW_DIM_H; +using domi::CHW_DIM_W; +using domi::CONCAT; +using domi::CONSTANT; +using domi::CONSTANTOP; +using domi::CONVOLUTION; +using domi::DATA; +using domi::DATA_TYPE; +using domi::DEFAULT_FORMAT; +using domi::DIM_DEFAULT_SIZE; +using domi::DOMI_TENSOR_NC1HWC0; +using domi::HWC_DIM_H; +using domi::HWC_DIM_W; +using domi::LOOPCOND; +using domi::MODEL_ATTR_TASKS; +using domi::ModelTaskDef; +using domi::NCHW_DIM_H; +using domi::NCHW_DIM_N; +using domi::NCHW_DIM_W; +using domi::NETOUTPUT; +using domi::NHWC_DIM_H; +using domi::NHWC_DIM_W; +using domi::PlatformVersionManager; +using domi::STREAMMERGE; +using domi::VARIABLE; +using domi::XRGB_CHN_NUM; using ge::FAILED; using ge::PARAM_INVALID; using ge::SUCCESS; @@ -69,25 +102,25 @@ const char *const kVectorCore = "VectorCore"; const char *const kCoreType = "ge.engineType"; const std::string kEnableL1Fusion = "ge.l1Fusion"; -const set adjust_layer_type_ = {ge::CONVOLUTION}; +const set adjust_layer_type_ = {CONVOLUTION}; bool IsGeLocalOp(const ge::ConstOpDescPtr &op_desc) { auto type = op_desc->GetType(); - if (type == ge::CONSTANTOP) { + if (type == CONSTANTOP) { // constant op just has one output ge::GeTensorDesc output_desc = op_desc->GetOutputDesc(0); return !(output_desc.GetDataType() == ge::DT_STRING); } - const set ge_local_set = {ge::STREAMMERGE, ge::MEMCPYASYNC, ge::STREAMACTIVE, ge::STREAMSWITCH, - ge::VARIABLE, ge::NOOP, ge::CONSTANT, ge::ENTER, - ge::REFENTER, ge::LOOPCOND, ge::NEXTITERATION, ge::REFNEXTITERATION, - ge::EXIT, ge::REFEXIT, ge::MEMCPYADDRASYNC}; + const set ge_local_set = {domi::STREAMMERGE, domi::MEMCPYASYNC, domi::STREAMACTIVE, domi::STREAMSWITCH, + domi::VARIABLE, domi::NOOP, domi::CONSTANT, domi::ENTER, + domi::REFENTER, domi::LOOPCOND, domi::NEXTITERATION, domi::REFNEXTITERATION, + domi::EXIT, domi::REFEXIT}; return (ge_local_set.find(type) != ge_local_set.end()); } } // namespace namespace ge { -ModelBuilder::ModelBuilder(ge::ComputeGraphPtr compute_graph, const Graph2SubGraphInfoList &subgraphs, +ModelBuilder::ModelBuilder(ge::ComputeGraphPtr compute_graph, const vector &subgraphs, const map &stream_max_parallel_num, bool hcom_parallel, int mode) : mem_offset_(0), weight_offset_(kWeightsStartOffset), @@ -189,7 +222,7 @@ void ModelBuilder::SetInputIsConst(const ge::NodePtr &n) { } } - std::string input_const_info = ToString(is_input_const); + std::string input_const_info = domi::ToString(is_input_const); GELOGD("update opdesc:%s InputConst:%s", node_op_desc->GetName().c_str(), input_const_info.c_str()); node_op_desc->SetIsInputConst(is_input_const); } @@ -226,25 +259,6 @@ Status ModelBuilder::SetInputOutputDesc() { if (!is_loop_graph_ && node_op_desc->GetType() == LOOPCOND) { is_loop_graph_ = true; } - // if user set input node format ND, the expected node for data and netoutput format is ND in - // final graph. - if ((domi::GetContext().format == domi::DOMI_TENSOR_ND) && - ((node_op_desc->GetType() == DATA_TYPE) || (node_op_desc->GetType() == NETOUTPUT))) { - GELOGI("The node [%s] format should be set ND.", node_op_desc->GetName().c_str()); - auto inputDescsPtr = node_op_desc->GetAllInputsDescPtr(); - auto outputDescsPtr = node_op_desc->GetAllOutputsDescPtr(); - ge::Format format = ge::FORMAT_ND; - for (auto &inputDescPtr : inputDescsPtr) { - GE_CHECK_NOTNULL(inputDescPtr); - inputDescPtr->SetFormat(format); - inputDescPtr->SetOriginFormat(format); - } - for (auto &outputDescPtr : outputDescsPtr) { - GE_CHECK_NOTNULL(outputDescPtr); - outputDescPtr->SetFormat(format); - outputDescPtr->SetOriginFormat(format); - } - } if (node_op_desc->GetType() == DATA_TYPE || node_op_desc->GetType() == AIPP_DATA_TYPE) { GELOGD("Data node: %s.", n->GetName().c_str()); @@ -525,7 +539,7 @@ Status ModelBuilder::SaveDataToModel(ge::Model &model, ge::GeModel &ge_model) { return INTERNAL_ERROR; } int byte_size = static_cast(task_def_bytes.GetSize()); - std::shared_ptr task = ge::MakeShared(); + std::shared_ptr task = ge::MakeShared(); GE_CHECK_NOTNULL(task); GE_CHK_BOOL_EXEC(ReadProtoFromArray(task_def_bytes.GetData(), byte_size, task.get()), return INTERNAL_ERROR, "ReadProtoFromArray failed."); @@ -571,6 +585,12 @@ Status ModelBuilder::PreBuildModel() { Status ModelBuilder::BuildModelForGetTask(ge::Model &model) { GE_CHK_STATUS_RET(AdjustInputTensorFlag(), "AdjustInputTensorFlag failed!"); + // Assign functional op labels. + GE_TIMESTAMP_START(AssignFunctionalLabels); + LabelAllocator label_allocator(compute_graph_); + GE_CHK_STATUS_RET(label_allocator.AssignFunctionalLabels(label_num_), "Assign label failed."); + GE_TIMESTAMP_END(AssignFunctionalLabels, "ModelBuilder::AssignFunctionalLabels"); + // Assign logical streams. StreamAllocator stream_allocator(compute_graph_, subgraphs_); GE_TIMESTAMP_START(AssignLogicalStreams); @@ -578,12 +598,6 @@ Status ModelBuilder::BuildModelForGetTask(ge::Model &model) { "Assign logical streams failed."); GE_TIMESTAMP_END(AssignLogicalStreams, "GraphBuilder::AssignLogicalStreams"); - // Assign functional op labels. - GE_TIMESTAMP_START(AssignFunctionalLabels); - LabelAllocator label_allocator(compute_graph_); - GE_CHK_STATUS_RET(label_allocator.AssignFunctionalLabels(label_num_), "Assign label failed."); - GE_TIMESTAMP_END(AssignFunctionalLabels, "ModelBuilder::AssignFunctionalLabels"); - GE_TIMESTAMP_START(AssignMemory); MemoryAssigner mem_assigner(compute_graph_); GE_CHK_STATUS_RET(mem_assigner.AssignMemory(is_loop_graph_, mem_offset_), "Assign Memory Failed!"); diff --git a/src/ge/graph/build/model_builder.h b/src/ge/graph/build/model_builder.h index 072126e3..4bf03bdc 100644 --- a/src/ge/graph/build/model_builder.h +++ b/src/ge/graph/build/model_builder.h @@ -37,7 +37,7 @@ namespace ge { class ModelBuilder { public: - ModelBuilder(ge::ComputeGraphPtr whole_graph, const Graph2SubGraphInfoList &subgraphs, + ModelBuilder(ge::ComputeGraphPtr whole_graph, const std::vector &subgraphs, const std::map &stream_max_parallel_num, bool hcom_parallel, int mode = static_cast(domi::BuildMode::GEN_TASK_WITHOUT_FUSION)); @@ -85,7 +85,7 @@ class ModelBuilder { ge::ComputeGraphPtr compute_graph_; - const Graph2SubGraphInfoList &subgraphs_; + const std::vector &subgraphs_; int64_t stream_num_; diff --git a/src/ge/graph/build/run_context.cc b/src/ge/graph/build/run_context.cc index f2a41271..d0fab3bd 100644 --- a/src/ge/graph/build/run_context.cc +++ b/src/ge/graph/build/run_context.cc @@ -17,6 +17,7 @@ #include "graph/build/run_context.h" #include "common/util.h" +#include "framework/common/op/attr_define.h" #include "framework/common/debug/ge_log.h" #include "graph/debug/ge_attr_define.h" @@ -169,6 +170,7 @@ Status RunContextUtil::CreateRunContext(Model &model, const ComputeGraphPtr &gra run_context_ = {rt_model_, nullptr, session_id, data_mem_size_, data_mem_base_, weight_mem_size_, weight_mem_base_, buffer, stream_list_, event_list_, label_list_}; + return SUCCESS; } diff --git a/src/ge/graph/build/stream_allocator.cc b/src/ge/graph/build/stream_allocator.cc index 88c5e055..baa5e400 100644 --- a/src/ge/graph/build/stream_allocator.cc +++ b/src/ge/graph/build/stream_allocator.cc @@ -17,6 +17,7 @@ #include "graph/build/stream_allocator.h" #include #include "common/ge/ge_util.h" +#include "common/op/attr_define.h" #include "framework/common/debug/ge_log.h" #include "framework/common/fmk_error_codes.h" #include "framework/common/types.h" @@ -26,6 +27,15 @@ #include "graph/build/logical_stream_allocator.h" +using domi::ATTR_NAME_STREAM_LABEL; +using domi::HCOMALLGATHER; +using domi::HCOMALLREDUCE; +using domi::HCOMBROADCAST; +using domi::HCOMREDUCESCATTER; +using domi::RECV; +using domi::SEND; +using domi::STREAMACTIVE; +using domi::STREAMSWITCH; using std::map; using std::set; using std::string; @@ -40,7 +50,7 @@ const uint32_t kMaxSwitchStreamNum = 1; namespace ge { Status StreamAllocator::AssignLogicalStreams(const std::map &max_parallel_num, bool hcom_parallel) { - GELOGI("Assign logical streams start."); + GELOGI("AssignLogicalStreams start."); GE_CHECK_NOTNULL(whole_graph_); GraphUtils::DumpGEGraph(whole_graph_, "BeforeAssignedLogicalStreams"); GraphUtils::DumpGEGraphToOnnx(*whole_graph_, "BeforeAssignedLogicalStreams"); @@ -52,6 +62,7 @@ Status StreamAllocator::AssignLogicalStreams(const std::map &m } const map &scheduler_confs = gelib->DNNEngineManagerObj().GetSchedulers(); + LogicalStreamAllocator logical_allocator(scheduler_confs, max_parallel_num, hcom_parallel); Status status = logical_allocator.Assign(whole_graph_, subgraphs_, stream_num_); if (status != SUCCESS) { @@ -61,7 +72,7 @@ Status StreamAllocator::AssignLogicalStreams(const std::map &m GraphUtils::DumpGEGraph(whole_graph_, "AfterAssignedLogicalStreams"); GraphUtils::DumpGEGraphToOnnx(*whole_graph_, "AfterAssignedLogicalStreams"); - GELOGI("Assign logical streams success."); + GELOGI("AssignLogicalStreams success."); return SUCCESS; } @@ -135,7 +146,7 @@ Status StreamAllocator::RefreshRealStream(int64_t &stream_num, int64_t &event_nu GELOGI("None of nodes need to assign stream, stream num is 0, it will cause error, so change it to 1"); stream_num_ = 1; } - GELOGI("stream num: %ld, event num: %u.", stream_num_, event_num_); + GELOGI("stream_num_: %ld, event_num_: %u.", stream_num_, event_num_); GELOGI("RefreshRealStream successfully."); stream_num = stream_num_; @@ -147,7 +158,7 @@ Status StreamAllocator::RefreshRealStream(int64_t &stream_num, int64_t &event_nu // Split the stream according to the maximum number of nodes in the stream. Status StreamAllocator::SplitStreams() { if (stream_num_ == 0) { - GELOGI("The number of streams is 0 and no need to split."); + GELOGI("stream_num_ is 0"); return SUCCESS; } diff --git a/src/ge/graph/build/stream_allocator.h b/src/ge/graph/build/stream_allocator.h index a18e00d7..e3901205 100644 --- a/src/ge/graph/build/stream_allocator.h +++ b/src/ge/graph/build/stream_allocator.h @@ -30,7 +30,7 @@ namespace ge { class StreamAllocator { public: - StreamAllocator(ComputeGraphPtr whole_graph, const Graph2SubGraphInfoList &subgraphs) + StreamAllocator(ComputeGraphPtr whole_graph, const std::vector &subgraphs) : whole_graph_(std::move(whole_graph)), subgraphs_(subgraphs) {} StreamAllocator(const StreamAllocator &) = delete; StreamAllocator &operator=(const StreamAllocator &) = delete; @@ -75,7 +75,7 @@ class StreamAllocator { bool IsRecvNodeActivatedBySendNode(const NodePtr &send_node_ptr, const NodePtr &recv_node_ptr) const; ComputeGraphPtr whole_graph_; - const Graph2SubGraphInfoList &subgraphs_; + const std::vector &subgraphs_; int64_t stream_num_{0}; uint32_t event_num_{0}; diff --git a/src/ge/graph/build/stream_graph_optimizer.cc b/src/ge/graph/build/stream_graph_optimizer.cc index 42d1afc1..5af54783 100644 --- a/src/ge/graph/build/stream_graph_optimizer.cc +++ b/src/ge/graph/build/stream_graph_optimizer.cc @@ -17,6 +17,7 @@ #include "stream_graph_optimizer.h" #include "common/util.h" #include "framework/common/debug/ge_log.h" + #include "graph/utils/node_utils.h" #include "graph/utils/tensor_utils.h" #include "init/gelib.h" @@ -29,21 +30,19 @@ static const int64_t kInvalidStream = -1; namespace ge { StreamGraphOptimizer::~StreamGraphOptimizer() {} -void StreamGraphOptimizer::RefreshNodeId(const ComputeGraphPtr &comp_graph, Graph2SubGraphInfoList &subgraph_map) { +void StreamGraphOptimizer::RefreshNodeId(const ComputeGraphPtr &comp_graph, vector &subgraph_infos) { size_t node_size = comp_graph->GetDirectNodesSize(); GELOGI("Refresh placeholder and end nodeId start from node num: %zu", node_size); - for (const auto &subgraph_pair : subgraph_map) { - for (const auto &subgraph_info : subgraph_pair.second) { - ComputeGraphPtr subgraph = subgraph_info->GetSubGraph(); - if (subgraph == nullptr) { - continue; - } - for (ge::NodePtr &node : subgraph->GetDirectNode()) { - GE_CHECK_NOTNULL_EXEC(node->GetOpDesc(), return ); - if ((node->GetType() == END) || (node->GetType() == PLACEHOLDER)) { - node->GetOpDesc()->SetId(static_cast(node_size)); - node_size++; - } + for (const auto &sub_graph_info : subgraph_infos) { + ComputeGraphPtr sub_graph = sub_graph_info->GetSubGraph(); + if (sub_graph == nullptr) { + continue; + } + for (ge::NodePtr &node : sub_graph->GetDirectNode()) { + GE_CHECK_NOTNULL_EXEC(node->GetOpDesc(), return ); + if ((node->GetType() == domi::END) || (node->GetType() == domi::PLACEHOLDER)) { + node->GetOpDesc()->SetId(static_cast(node_size)); + node_size++; } } } @@ -73,71 +72,67 @@ bool StreamGraphOptimizer::IsSameStreamId(const ComputeGraphPtr &comp_graph) { } Status StreamGraphOptimizer::OptimizeStreamedSubGraph(const ComputeGraphPtr &comp_graph, - Graph2SubGraphInfoList &subgraph_map, + vector &subgraph_infos, struct RunContext &run_context) { - GELOGI("Optimize streamed subgraph start."); + Status ret = SUCCESS; + GELOGI("Begin to Get optimize streamed subgraph."); - RefreshNodeId(comp_graph, subgraph_map); + RefreshNodeId(comp_graph, subgraph_infos); std::shared_ptr instance = ge::GELib::GetInstance(); GE_CHECK_NOTNULL(instance); - for (const auto &subgraph_pair : subgraph_map) { - for (const auto &subgraph_info : subgraph_pair.second) { - ComputeGraphPtr subgraph = subgraph_info->GetSubGraph(); - GE_CHECK_NOTNULL(subgraph); - - GELOGI("Optimize subgraph %s", subgraph->GetName().c_str()); + for (auto &sub_graph_info : subgraph_infos) { + ComputeGraphPtr sub_graph = sub_graph_info->GetSubGraph(); + if (sub_graph == nullptr) { + continue; + } - std::string engine_name = subgraph_info->GetEngineName(); + std::string engine_name = sub_graph_info->GetEngineName(); - vector graph_optimizers; - if (instance->DNNEngineManagerObj().IsEngineRegistered(engine_name)) { - instance->OpsKernelManagerObj().GetGraphOptimizerByEngine(engine_name, graph_optimizers); - GELOGI("Subgraph: %s start optimize streamed graph. engineName: %s, graph Optimizer num: %zu.", - subgraph->GetName().c_str(), engine_name.c_str(), graph_optimizers.size()); + vector graph_optimizers; + if (instance->DNNEngineManagerObj().IsEngineRegistered(engine_name)) { + instance->OpsKernelManagerObj().GetGraphOptimizerByEngine(engine_name, graph_optimizers); + GELOGI("Subgraph: %s start optimize streamed graph. engineName: %s, subgraph num: %zu, graph Optimizer num: %zu.", + sub_graph->GetName().c_str(), engine_name.c_str(), subgraph_infos.size(), graph_optimizers.size()); - auto nodes = subgraph->GetDirectNode(); - if (nodes.empty()) { - continue; - } - if (!IsSameStreamId(subgraph)) { - GELOGI("There are more than one stream in subgraph %s", subgraph->GetName().c_str()); - continue; - } - OpDescPtr op_desc = nodes.at(0)->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - int64_t stream_id = op_desc->GetStreamId(); - if (static_cast(stream_id) >= run_context.graphStreamList.size()) { - GELOGE(FAILED, "stream_id %ld is bigger than run_context.graphStreamList.size() %zu", stream_id, - run_context.graphStreamList.size()); - return FAILED; - } - run_context.stream = run_context.graphStreamList[stream_id]; - GELOGD("Subgraph has same stream id, subgraph: %s, engine_name: %s, stream_id: %ld, rtstream: %lu.", - subgraph->GetName().c_str(), engine_name.c_str(), stream_id, - static_cast(reinterpret_cast(run_context.stream))); - for (auto iter = graph_optimizers.begin(); iter != graph_optimizers.end(); ++iter) { - GE_CHECK_NOTNULL(*iter); - Status ret = (*iter)->OptimizeStreamGraph(*subgraph, run_context); - if (ret != SUCCESS) { - GELOGE( - ret, - "[optimizeStreamedSubGraph]: optimize streamed subgraph failed, subgraph: %s, engine_name: %s, graph " - "Optimizer num: %zu, ret: %u", - subgraph->GetName().c_str(), engine_name.c_str(), graph_optimizers.size(), ret); - return ret; - } - GELOGI( - "[optimizeStreamedSubGraph]: optimize streamed subgraph success, subgraph: %s, engine_name: %s, graph " - "Optimizer num: %zu!", - subgraph->GetName().c_str(), engine_name.c_str(), graph_optimizers.size()); + auto nodes = sub_graph->GetDirectNode(); + if (nodes.empty()) { + continue; + } + if (!IsSameStreamId(sub_graph)) { + GELOGI("There are more than one stream in subgraph %s", sub_graph->GetName().c_str()); + continue; + } + OpDescPtr op_desc = nodes.at(0)->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + int64_t stream_id = op_desc->GetStreamId(); + if (static_cast(stream_id) >= run_context.graphStreamList.size()) { + GELOGE(FAILED, "stream_id is bigger than run_context.graphStreamList.size()"); + return FAILED; + } + run_context.stream = run_context.graphStreamList[stream_id]; + GELOGD("Subgraph has same stream id, subgraph: %s, engine_name: %s, stream_id: %ld, rtstream: %lu.", + sub_graph->GetName().c_str(), engine_name.c_str(), stream_id, + static_cast(reinterpret_cast(run_context.stream))); + for (auto iter = graph_optimizers.begin(); iter != graph_optimizers.end(); ++iter) { + GE_CHECK_NOTNULL(*iter); + ret = (*iter)->OptimizeStreamGraph(*sub_graph, run_context); + if (ret != SUCCESS) { + GELOGE(ret, + "[optimizeStreamedSubGraph]: optimize streamed subgraph failed, subgraph: %s, engine_name: %s, graph " + "Optimizer num: %zu, ret: %u", + sub_graph->GetName().c_str(), engine_name.c_str(), graph_optimizers.size(), ret); + return ret; } + GELOGI( + "[optimizeStreamedSubGraph]: optimize streamed subgraph success, subgraph: %s, engine_name: %s, graph " + "Optimizer num: %zu!", + sub_graph->GetName().c_str(), engine_name.c_str(), graph_optimizers.size()); } } } - GELOGI("Optimize streamed subgraph success."); - return SUCCESS; + return ret; } } // namespace ge diff --git a/src/ge/graph/build/stream_graph_optimizer.h b/src/ge/graph/build/stream_graph_optimizer.h index 3133d32d..a65f95f2 100644 --- a/src/ge/graph/build/stream_graph_optimizer.h +++ b/src/ge/graph/build/stream_graph_optimizer.h @@ -35,11 +35,11 @@ class StreamGraphOptimizer { virtual ~StreamGraphOptimizer(); - Status OptimizeStreamedSubGraph(const ComputeGraphPtr &comp_graph, Graph2SubGraphInfoList &subgraph_map, + Status OptimizeStreamedSubGraph(const ComputeGraphPtr &comp_graph, std::vector &subgraph_ptr_list, struct RunContext &run_context); private: - void RefreshNodeId(const ComputeGraphPtr &comp_graph, Graph2SubGraphInfoList &subgraph_map); + void RefreshNodeId(const ComputeGraphPtr &comp_graph, std::vector &subgraph_ptr_list); bool IsSameStreamId(const ComputeGraphPtr &comp_graph); }; diff --git a/src/ge/graph/build/task_generator.cc b/src/ge/graph/build/task_generator.cc index cc34e352..2266f137 100644 --- a/src/ge/graph/build/task_generator.cc +++ b/src/ge/graph/build/task_generator.cc @@ -21,6 +21,7 @@ #include "common/types.h" #include "framework/common/debug/ge_log.h" +#include "framework/common/op/attr_define.h" #include "graph/debug/ge_attr_define.h" #include "graph/ge_context.h" #include "graph/manager/graph_var_manager.h" @@ -30,8 +31,14 @@ #include "graph/utils/type_utils.h" #include "init/gelib.h" +using domi::CONSTANT; +using domi::HCOMALLREDUCE; using domi::LogTimeStampDef; +using domi::MODEL_ATTR_TASK_GEN_BASE_ADDR; +using domi::MODEL_ATTR_TASK_GEN_WEIGHT_ADDR; +using domi::MODEL_ATTR_TASKS; using domi::ModelTaskDef; +using domi::NETOUTPUT; using domi::TaskDef; using std::map; using std::string; @@ -221,8 +228,10 @@ Status TaskGenerator::SaveL1fusionNodes(map> &l1_f if (call_check) { auto input_group_id = *input_group_ids.begin(); if (group_id != input_group_id) { - GELOGW("L1Fusion: node[name:%s(%s) with group id:%ld and diff from it's input nodes's group id:%ld ", + GELOGE(INTERNAL_ERROR, + "L1Fusion: node[name:%s(%s) with group id:%ld and diff from it's input nodes's group id:%ld ", name.c_str(), type.c_str(), group_id, input_group_id); + return INTERNAL_ERROR; } } } @@ -237,7 +246,7 @@ Status TaskGenerator::GenerateTask(RunContext &run_context, ComputeGraphPtr &gra GELOGE(GE_CLI_GE_NOT_INITIALIZED, "GenerateTask failed."); return GE_CLI_GE_NOT_INITIALIZED; } - GE_CHK_STATUS_RET(MarkNodeAndSetIndex(graph), "MarkFirstAndLastNode failed."); + GE_CHK_STATUS_RET(MarkFirstAndLastNode(graph), "MarkFirstAndLastNode failed."); ProfilingPoint ppoint; vector ar_ppoint; GE_CHK_STATUS_RET(FindProfilingTaskIndex(graph, ppoint, ar_ppoint)); @@ -258,6 +267,7 @@ Status TaskGenerator::GenerateTask(RunContext &run_context, ComputeGraphPtr &gra for (auto &node : graph->GetAllNodes()) { OpDescPtr op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); + op_desc->SetId(node_index); node_index++; string name = node->GetName(); string type = node->GetType(); @@ -483,26 +493,23 @@ Status TaskGenerator::UpdateAnchorStatus(const NodePtr &node) { return SUCCESS; } -Status TaskGenerator::MarkNodeAndSetIndex(ComputeGraphPtr &graph) { +Status TaskGenerator::MarkFirstAndLastNode(ComputeGraphPtr &graph) { std::shared_ptr ge_lib = GELib::GetInstance(); if ((ge_lib == nullptr) || !ge_lib->InitFlag()) { GELOGE(GE_CLI_GE_NOT_INITIALIZED, "GE is not initialized or is finalized"); return GE_CLI_GE_NOT_INITIALIZED; } - int64_t node_index = 0; map>> engine_stream_stat; for (auto &node : graph->GetAllNodes()) { - const OpDescPtr &op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - string op_kernel_lib_name = op_desc->GetOpKernelLibName(); - int64_t stream_id = op_desc->GetStreamId(); - op_desc->SetId(node_index++); + GE_CHECK_NOTNULL(node->GetOpDesc()); + string op_kernel_lib_name = node->GetOpDesc()->GetOpKernelLibName(); + int64_t stream_id = node->GetOpDesc()->GetStreamId(); if (op_kernel_lib_name.empty()) { // Reset op kernel lib - (void)ge_lib->DNNEngineManagerObj().GetDNNEngineName(op_desc); - op_kernel_lib_name = op_desc->GetOpKernelLibName(); + (void)ge_lib->DNNEngineManagerObj().GetDNNEngineName(node->GetOpDesc()); + op_kernel_lib_name = node->GetOpDesc()->GetOpKernelLibName(); if (op_kernel_lib_name.empty()) { GELOGE(INTERNAL_ERROR, "node:%s(%s) get op kernel lib failed.", node->GetName().c_str(), node->GetType().c_str()); diff --git a/src/ge/graph/build/task_generator.h b/src/ge/graph/build/task_generator.h index 7fa48ea1..1f4a1f0d 100644 --- a/src/ge/graph/build/task_generator.h +++ b/src/ge/graph/build/task_generator.h @@ -104,7 +104,7 @@ class TaskGenerator { RunContext &run_context); // Mark first and last node according to the same stream and engine - Status MarkNodeAndSetIndex(ComputeGraphPtr &graph); + Status MarkFirstAndLastNode(ComputeGraphPtr &graph); // profiling interface Status FindProfilingTaskIndex(const ComputeGraphPtr &graph, ProfilingPoint &ppoint, diff --git a/src/ge/graph/common/omg_util.cc b/src/ge/graph/common/omg_util.cc index 00091c10..0a6d98d2 100644 --- a/src/ge/graph/common/omg_util.cc +++ b/src/ge/graph/common/omg_util.cc @@ -18,10 +18,14 @@ #include +#include "common/op/attr_define.h" #include "framework/common/debug/ge_log.h" #include "graph/debug/ge_attr_define.h" #include "graph/utils/graph_utils.h" +using domi::ATTR_NAME_STREAM_CYCLE_EVENT_FLAG; +using domi::ATTR_NAME_STREAM_LABEL; +using domi::FRAMEWORKOP; using ge::AttrUtils; using ge::OpDescPtr; @@ -57,7 +61,7 @@ Status SetStreamLabel(const ge::NodePtr &node, const std::string &label) { OpDescPtr tmp_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(tmp_desc); - if (!AttrUtils::SetStr(tmp_desc, ge::ATTR_NAME_STREAM_LABEL, label)) { + if (!AttrUtils::SetStr(tmp_desc, ATTR_NAME_STREAM_LABEL, label)) { GELOGE(FAILED, "Op: %s set ATTR_NAME_STREAM_LABEL failed", node->GetName().c_str()); return FAILED; } @@ -74,7 +78,7 @@ Status SetCycleEvent(const ge::NodePtr &node) { GE_CHECK_NOTNULL(node); OpDescPtr tmp_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(tmp_desc); - if (!AttrUtils::SetBool(tmp_desc, ge::ATTR_NAME_STREAM_CYCLE_EVENT_FLAG, true)) { + if (!AttrUtils::SetBool(tmp_desc, ATTR_NAME_STREAM_CYCLE_EVENT_FLAG, true)) { GELOGE(FAILED, "Op: %s set ATTR_NAME_STREAM_CYCLE_EVENT_FLAG failed", node->GetName().c_str()); return FAILED; } diff --git a/src/ge/graph/common/transop_util.cc b/src/ge/graph/common/transop_util.cc index 3250929d..8631529e 100644 --- a/src/ge/graph/common/transop_util.cc +++ b/src/ge/graph/common/transop_util.cc @@ -24,8 +24,8 @@ const int kInvalidTransopDataIndex = -1; namespace ge { TransOpUtil::TransOpUtil() { - transop_index_map_ = {{TRANSDATA, 0}, {TRANSPOSE, 0}, {TRANSPOSED, 0}, {RESHAPE, 0}, - {REFORMAT, 0}, {CAST, 0}, {SQUEEZE, 0}, {EXPANDDIMS, 0}}; + transop_index_map_ = {{domi::TRANSDATA, 0}, {domi::TRANSPOSE, 0}, {domi::TRANSPOSED, 0}, {domi::RESHAPE, 0}, + {domi::REFORMAT, 0}, {domi::CAST, 0}, {domi::SQUEEZE, 0}, {domi::EXPANDDIMS, 0}}; } TransOpUtil::~TransOpUtil() {} diff --git a/src/ge/graph/execute/graph_execute.cc b/src/ge/graph/execute/graph_execute.cc index 0f83a494..56e31de3 100644 --- a/src/ge/graph/execute/graph_execute.cc +++ b/src/ge/graph/execute/graph_execute.cc @@ -16,7 +16,6 @@ #include "graph/execute/graph_execute.h" -#include #include #include "common/ge_inner_error_codes.h" diff --git a/src/ge/graph/label/case_label_maker.cc b/src/ge/graph/label/case_label_maker.cc index 4d477bb7..2d024499 100644 --- a/src/ge/graph/label/case_label_maker.cc +++ b/src/ge/graph/label/case_label_maker.cc @@ -23,6 +23,8 @@ #include "graph/debug/ge_attr_define.h" #include "graph/utils/graph_utils.h" +using domi::CASE; + namespace ge { constexpr uint32_t kCasePredIndex = 0; constexpr uint32_t kMinCaseBranch = 1; @@ -55,8 +57,6 @@ Status CaseOpLabelMaker::Run(uint32_t &label_index) { return SUCCESS; } - NodePtr first_label = nullptr; - ComputeGraphPtr first_graph = nullptr; std::vector switch_labels; uint32_t last_label_index = label_index++; for (uint32_t index = 0; index < graph_num; ++index) { @@ -66,16 +66,11 @@ Status CaseOpLabelMaker::Run(uint32_t &label_index) { // all branch, add label node to head. uint32_t curr_label_index = label_index++; std::string label_set_name = parent_node_->GetName() + "/LabelSet_" + std::to_string(index); // rtLabelSet - NodePtr label = AddLabelSetEnter(graph, label_set_name, curr_label_index); - if (label == nullptr) { + if (AddLabelSetEnter(graph, label_set_name, curr_label_index) == nullptr) { GELOGE(INTERNAL_ERROR, "Subgraph: %s add label set failed.", graph->GetName().c_str()); return FAILED; } switch_labels.emplace_back(curr_label_index); - if (index == 0) { // save first subgraph node for switch. - first_label = label; - first_graph = graph; - } if (index + 1 < graph_num) { // middle node, add goto node to tail. @@ -95,7 +90,7 @@ Status CaseOpLabelMaker::Run(uint32_t &label_index) { } // Add Switch node for first branch. - GE_CHECK_NOTNULL(first_label); + ComputeGraphPtr first_graph = parent_graph_->GetSubgraph(graph_names[0]); GE_CHECK_NOTNULL(first_graph); GeTensorDesc pred_desc = case_desc->GetInputDesc(kCasePredIndex); @@ -109,12 +104,6 @@ Status CaseOpLabelMaker::Run(uint32_t &label_index) { return FAILED; } - // Link control edge to then branch head. - if (GraphUtils::AddEdge(switch_node->GetOutControlAnchor(), first_label->GetInControlAnchor()) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "LabelSwitchByIndex: Add ctrl edge to %s failed.", first_label->GetName().c_str()); - return FAILED; - } - uint32_t parent_index = 0; // Case cond input is first. const std::string data_name = parent_node_->GetName() + "/SwitchIndexData"; if (AddLabelSwitchIndex(first_graph, data_name, cond_desc, switch_node, parent_index) == nullptr) { diff --git a/src/ge/graph/label/if_label_maker.cc b/src/ge/graph/label/if_label_maker.cc index 5a7c028b..142cf625 100644 --- a/src/ge/graph/label/if_label_maker.cc +++ b/src/ge/graph/label/if_label_maker.cc @@ -23,11 +23,24 @@ #include "graph/debug/ge_attr_define.h" #include "graph/utils/graph_utils.h" +using domi::_IF; +using domi::IF; +using domi::STATELESSIF; + namespace ge { constexpr uint8_t kIfPredIndex = 0; constexpr uint8_t kThenBranchIndex = 0; constexpr uint8_t kElseBranchIndex = 1; +// All ---> Node ---> If ---> Node ---> +// | +// V +// { Data ---> Node ---> Node ---> NetOutput } +// +// All ---> Node ---> If ---> Node ---> +// \ / +// { Node ---> Node } + /** * @ingroup ge * @brief Make label node to functional call. @@ -63,8 +76,7 @@ Status IfOpLabelMaker::Run(uint32_t &label_index) { const std::string else_enter_name = parent_node_->GetName() + "/ElseLabelSet"; // rtLabelSet(1) const std::string else_leave_name = parent_node_->GetName() + "/LeaveLabelSet"; // rtLabelSet - NodePtr then_enter_label = AddLabelSetEnter(then_sub_graph, then_label_name, then_enter_index); - if (then_enter_label == nullptr) { + if (AddLabelSetEnter(then_sub_graph, then_label_name, then_enter_index) == nullptr) { GELOGE(INTERNAL_ERROR, "Subgraph: %s add label set failed.", then_sub_graph->GetName().c_str()); return FAILED; } @@ -95,12 +107,6 @@ Status IfOpLabelMaker::Run(uint32_t &label_index) { return FAILED; } - // Link control edge to then branch head. - if (GraphUtils::AddEdge(switch_node->GetOutControlAnchor(), then_enter_label->GetInControlAnchor()) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "LabelSwitchByIndex: Add ctrl edge to %s failed.", then_enter_label->GetName().c_str()); - return FAILED; - } - uint32_t parent_index = 0; // If cond input is first. const std::string data_name = parent_node_->GetName() + "/SwitchIndexData"; if (AddLabelSwitchIndex(then_sub_graph, data_name, cond_desc, switch_node, parent_index) == nullptr) { diff --git a/src/ge/graph/label/label_maker.cc b/src/ge/graph/label/label_maker.cc index bf8949f0..d3701f07 100644 --- a/src/ge/graph/label/label_maker.cc +++ b/src/ge/graph/label/label_maker.cc @@ -23,92 +23,12 @@ #include "graph/debug/ge_attr_define.h" #include "graph/utils/graph_utils.h" -namespace ge { -constexpr static int64_t kInvalidStreamId = -1; - -/** - * @ingroup ge - * @brief Set stream id for head node. - * @param [in] graph: graph for add node. - * @param [in] op_desc: OpDesc for set logical stream id. - * @return: void - */ -void LabelMaker::SetStreamIdEnter(const ComputeGraphPtr &graph, const OpDescPtr &op_desc) { - int64_t stream_id = kInvalidStreamId; - const auto &node_list = graph->GetDirectNode(); - for (size_t i = 0; i < node_list.size(); ++i) { - const auto &node = node_list.at(i); - GE_CHECK_NOTNULL_EXEC(node, continue); - - stream_id = node->GetOpDesc()->GetStreamId(); - if (stream_id != kInvalidStreamId) { - break; - } - } - - GELOGI("SetStreamId: Node %s assign stream is %ld.", op_desc->GetName().c_str(), stream_id); - op_desc->SetStreamId(stream_id); -} - -/** - * @ingroup ge - * @brief Set stream id for tail node. - * @param [in] graph: graph for add node. - * @param [in] op_desc: OpDesc for set logical stream id. - * @return: void - */ -void LabelMaker::SetStreamIdLeave(const ComputeGraphPtr &graph, const OpDescPtr &op_desc) { - int64_t stream_id = kInvalidStreamId; - const auto &node_list = graph->GetDirectNode(); - for (size_t i = node_list.size(); i > 0; --i) { - const auto &node = node_list.at(i - 1); // i from list size, need shift 1. - GE_CHECK_NOTNULL_EXEC(node, continue); - - stream_id = node->GetOpDesc()->GetStreamId(); - if (stream_id != kInvalidStreamId) { - break; - } - } - - GELOGI("SetStreamId: Node %s assign stream is %ld.", op_desc->GetName().c_str(), stream_id); - op_desc->SetStreamId(stream_id); -} - -/** - * @ingroup ge - * @brief Link Node to Graph head. - * @param [in] graph: graph for add node. - * @param [in] lb_node: Node for set link to head. - * @return: SUCCESS / FAILED - */ -Status LabelMaker::AddCtrlLink2Data(const ComputeGraphPtr &graph, const NodePtr &node) { - GE_CHECK_NOTNULL(graph); - GE_CHECK_NOTNULL(node); - - std::set linked_nodes; - for (const NodePtr &n : graph->GetDirectNode()) { - GE_CHECK_NOTNULL(n); - if (n->GetType() != DATA) { - continue; - } - - // Link control edge to graph head. - for (const NodePtr &out_node : n->GetOutAllNodes()) { - if (linked_nodes.count(out_node) > 0) { - continue; - } - - (void)linked_nodes.insert(out_node); - if (GraphUtils::AddEdge(node->GetOutControlAnchor(), out_node->GetInControlAnchor()) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "LabelSet: Add ctrl edge to %s failed.", node->GetName().c_str()); - return FAILED; - } - } - } - - return SUCCESS; -} +using domi::DATA; +using domi::LABELGOTO; +using domi::LABELSET; +using domi::LABELSWITCHBYINDEX; +namespace ge { /** * @ingroup ge * @brief Add LabelSet node at graph front. @@ -119,6 +39,8 @@ Status LabelMaker::AddCtrlLink2Data(const ComputeGraphPtr &graph, const NodePtr */ NodePtr LabelMaker::AddLabelSetEnter(const ComputeGraphPtr &graph, const std::string &name, uint32_t index) { GE_CHECK_NOTNULL_EXEC(graph, return nullptr); + GE_CHECK_NOTNULL_EXEC(parent_node_, return nullptr); + GE_CHECK_NOTNULL_EXEC(parent_graph_, return nullptr); const auto &node_list = graph->GetDirectNode(); auto it = node_list.begin(); @@ -126,10 +48,11 @@ NodePtr LabelMaker::AddLabelSetEnter(const ComputeGraphPtr &graph, const std::st GELOGE(INTERNAL_ERROR, "LabelSet: Graph %s node is empty.", graph->GetName().c_str()); return nullptr; } + const NodePtr &node = *it; + GE_CHECK_NOTNULL_EXEC(node, return nullptr); OpDescPtr op_desc = MakeShared(name, LABELSET); GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); - SetStreamIdEnter(graph, op_desc); GELOGI("LabelSet: Create node %s.", op_desc->GetName().c_str()); (void)AttrUtils::SetInt(op_desc, ATTR_NAME_LABEL_SWITCH_INDEX, index); @@ -137,8 +60,8 @@ NodePtr LabelMaker::AddLabelSetEnter(const ComputeGraphPtr &graph, const std::st GE_CHECK_NOTNULL_EXEC(label_set, return nullptr); // Link control edge to graph head. - if (AddCtrlLink2Data(graph, label_set) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "LabelSet: Add ctrl edge to %s failed.", graph->GetName().c_str()); + if (GraphUtils::AddEdge(label_set->GetOutControlAnchor(), node->GetInControlAnchor()) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "LabelSet: Add ctrl edge to %s failed.", node->GetName().c_str()); return nullptr; } @@ -155,6 +78,8 @@ NodePtr LabelMaker::AddLabelSetEnter(const ComputeGraphPtr &graph, const std::st */ NodePtr LabelMaker::AddLabelSetLeave(const ComputeGraphPtr &graph, const std::string &name, uint32_t index) { GE_CHECK_NOTNULL_EXEC(graph, return nullptr); + GE_CHECK_NOTNULL_EXEC(parent_node_, return nullptr); + GE_CHECK_NOTNULL_EXEC(parent_graph_, return nullptr); const auto &node_list = graph->GetDirectNode(); auto it = node_list.end(); @@ -168,11 +93,10 @@ NodePtr LabelMaker::AddLabelSetLeave(const ComputeGraphPtr &graph, const std::st OpDescPtr op_desc = MakeShared(name, LABELSET); GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); - SetStreamIdLeave(graph, op_desc); GELOGI("LabelSet: Create node %s.", op_desc->GetName().c_str()); (void)AttrUtils::SetInt(op_desc, ATTR_NAME_LABEL_SWITCH_INDEX, index); - NodePtr label_set = graph->AddNode(op_desc); + NodePtr label_set = graph->AddNodeFront(op_desc); GE_CHECK_NOTNULL_EXEC(label_set, return nullptr); // Link control edge to graph tail. @@ -194,6 +118,8 @@ NodePtr LabelMaker::AddLabelSetLeave(const ComputeGraphPtr &graph, const std::st */ NodePtr LabelMaker::AddLabelGotoEnter(const ComputeGraphPtr &graph, const std::string &name, uint32_t index) { GE_CHECK_NOTNULL_EXEC(graph, return nullptr); + GE_CHECK_NOTNULL_EXEC(parent_node_, return nullptr); + GE_CHECK_NOTNULL_EXEC(parent_graph_, return nullptr); const auto &node_list = graph->GetDirectNode(); auto it = node_list.begin(); @@ -201,16 +127,20 @@ NodePtr LabelMaker::AddLabelGotoEnter(const ComputeGraphPtr &graph, const std::s GELOGE(INTERNAL_ERROR, "LabelGoto: Graph %s node is empty.", graph->GetName().c_str()); return nullptr; } + const NodePtr &node = *it; + GE_CHECK_NOTNULL_EXEC(node, return nullptr); - OpDescPtr op_desc = MakeShared(name, LABELGOTOEX); + OpDescPtr op_desc = MakeShared(name, LABELGOTO); GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); - SetStreamIdEnter(graph, op_desc); GELOGI("LabelGoto: Create node %s.", op_desc->GetName().c_str()); (void)AttrUtils::SetInt(op_desc, ATTR_NAME_LABEL_SWITCH_INDEX, index); NodePtr label_goto = graph->AddNodeFront(op_desc); - if (label_goto == nullptr) { - GELOGE(INTERNAL_ERROR, "LabelGoto: Add to graph %s failed.", graph->GetName().c_str()); + GE_CHECK_NOTNULL_EXEC(label_goto, return nullptr); + + // Link control edge to graph head. + if (GraphUtils::AddEdge(label_goto->GetOutControlAnchor(), node->GetInControlAnchor()) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "LabelGoto: Add ctrl edge to %s failed.", node->GetName().c_str()); return nullptr; } @@ -227,6 +157,8 @@ NodePtr LabelMaker::AddLabelGotoEnter(const ComputeGraphPtr &graph, const std::s */ NodePtr LabelMaker::AddLabelGotoLeave(const ComputeGraphPtr &graph, const std::string &name, uint32_t index) { GE_CHECK_NOTNULL_EXEC(graph, return nullptr); + GE_CHECK_NOTNULL_EXEC(parent_node_, return nullptr); + GE_CHECK_NOTNULL_EXEC(parent_graph_, return nullptr); const auto &node_list = graph->GetDirectNode(); auto it = node_list.end(); @@ -238,9 +170,8 @@ NodePtr LabelMaker::AddLabelGotoLeave(const ComputeGraphPtr &graph, const std::s const NodePtr &node = *it; GE_CHECK_NOTNULL_EXEC(node, return nullptr); - OpDescPtr op_desc = MakeShared(name, LABELGOTOEX); + OpDescPtr op_desc = MakeShared(name, LABELGOTO); GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); - SetStreamIdLeave(graph, op_desc); GELOGI("LabelGoto: Create node %s.", op_desc->GetName().c_str()); (void)AttrUtils::SetInt(op_desc, ATTR_NAME_LABEL_SWITCH_INDEX, index); @@ -268,6 +199,8 @@ NodePtr LabelMaker::AddLabelGotoLeave(const ComputeGraphPtr &graph, const std::s NodePtr LabelMaker::AddLabelSwitchEnter(const ComputeGraphPtr &graph, const std::string &name, const GeTensorDesc &desc, const std::vector &labels) { GE_CHECK_NOTNULL_EXEC(graph, return nullptr); + GE_CHECK_NOTNULL_EXEC(parent_node_, return nullptr); + GE_CHECK_NOTNULL_EXEC(parent_graph_, return nullptr); const auto &node_list = graph->GetDirectNode(); auto it = node_list.begin(); @@ -275,10 +208,11 @@ NodePtr LabelMaker::AddLabelSwitchEnter(const ComputeGraphPtr &graph, const std: GELOGE(INTERNAL_ERROR, "LabelSwitchByIndex: Graph %s node is empty.", graph->GetName().c_str()); return nullptr; } + const NodePtr &node = *it; + GE_CHECK_NOTNULL_EXEC(node, return nullptr); OpDescPtr op_desc = MakeShared(name, LABELSWITCHBYINDEX); GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); - SetStreamIdEnter(graph, op_desc); GELOGI("LabelSwitchByIndex: Create node %s.", op_desc->GetName().c_str()); if (op_desc->AddInputDesc(desc) != GRAPH_SUCCESS) { @@ -292,8 +226,11 @@ NodePtr LabelMaker::AddLabelSwitchEnter(const ComputeGraphPtr &graph, const std: } NodePtr label_switch = graph->AddNodeFront(op_desc); - if (label_switch == nullptr) { - GELOGE(INTERNAL_ERROR, "LabelSwitchByIndex: Add to graph %s failed.", graph->GetName().c_str()); + GE_CHECK_NOTNULL_EXEC(label_switch, return nullptr); + + // Link control edge to graph head. + if (GraphUtils::AddEdge(label_switch->GetOutControlAnchor(), node->GetInControlAnchor()) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "LabelSwitchByIndex: Add ctrl edge to %s failed.", node->GetName().c_str()); return nullptr; } @@ -312,6 +249,8 @@ NodePtr LabelMaker::AddLabelSwitchEnter(const ComputeGraphPtr &graph, const std: NodePtr LabelMaker::AddLabelSwitchLeave(const ComputeGraphPtr &graph, const std::string &name, const GeTensorDesc &desc, const std::vector &labels) { GE_CHECK_NOTNULL_EXEC(graph, return nullptr); + GE_CHECK_NOTNULL_EXEC(parent_node_, return nullptr); + GE_CHECK_NOTNULL_EXEC(parent_graph_, return nullptr); const auto &node_list = graph->GetDirectNode(); auto it = node_list.end(); @@ -325,7 +264,6 @@ NodePtr LabelMaker::AddLabelSwitchLeave(const ComputeGraphPtr &graph, const std: OpDescPtr op_desc = MakeShared(name, LABELSWITCHBYINDEX); GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); - SetStreamIdLeave(graph, op_desc); GELOGI("LabelSwitchByIndex: Create node %s.", op_desc->GetName().c_str()); if (op_desc->AddInputDesc(desc) != GRAPH_SUCCESS) { @@ -363,10 +301,11 @@ NodePtr LabelMaker::AddLabelSwitchLeave(const ComputeGraphPtr &graph, const std: NodePtr LabelMaker::AddLabelSwitchIndex(const ComputeGraphPtr &graph, const std::string &name, const GeTensorDesc &desc, const NodePtr &sw_node, uint32_t parent_index) { GE_CHECK_NOTNULL_EXEC(graph, return nullptr); + GE_CHECK_NOTNULL_EXEC(parent_node_, return nullptr); + GE_CHECK_NOTNULL_EXEC(parent_graph_, return nullptr); OpDescPtr op_desc = MakeShared(name, DATA); GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); - op_desc->SetStreamId(kInvalidStreamId); GELOGI("Data: Create node %s.", op_desc->GetName().c_str()); if (op_desc->AddOutputDesc(desc) != GRAPH_SUCCESS) { diff --git a/src/ge/graph/label/label_maker.h b/src/ge/graph/label/label_maker.h index 6b5ccbf7..d5878bc9 100644 --- a/src/ge/graph/label/label_maker.h +++ b/src/ge/graph/label/label_maker.h @@ -55,11 +55,6 @@ class LabelMaker { protected: NodePtr parent_node_; ComputeGraphPtr parent_graph_; - - private: - Status AddCtrlLink2Data(const ComputeGraphPtr &graph, const NodePtr &node); - void SetStreamIdEnter(const ComputeGraphPtr &graph, const OpDescPtr &op_desc); - void SetStreamIdLeave(const ComputeGraphPtr &graph, const OpDescPtr &op_desc); }; } // namespace ge #endif // GE_GRAPH_PASSES_LABEL_MAKER_H_ diff --git a/src/ge/graph/label/partitioned_call_label_maker.cc b/src/ge/graph/label/partitioned_call_label_maker.cc index 39c88717..da36431e 100644 --- a/src/ge/graph/label/partitioned_call_label_maker.cc +++ b/src/ge/graph/label/partitioned_call_label_maker.cc @@ -22,6 +22,9 @@ #include "graph/debug/ge_attr_define.h" #include "graph/utils/graph_utils.h" +using domi::PARTITIONEDCALL; +using domi::STATEFULPARTITIONEDCALL; + namespace ge { constexpr int32_t kSubGraphIndex = 0; diff --git a/src/ge/graph/label/while_label_maker.cc b/src/ge/graph/label/while_label_maker.cc index c9efccd5..e2a6ddbd 100644 --- a/src/ge/graph/label/while_label_maker.cc +++ b/src/ge/graph/label/while_label_maker.cc @@ -23,6 +23,10 @@ #include "graph/debug/ge_attr_define.h" #include "graph/utils/graph_utils.h" +using domi::_WHILE; +using domi::STATELESSWHILE; +using domi::WHILE; + namespace ge { constexpr uint8_t kCondOutputNum = 1; constexpr uint8_t kCondOutputIndex = 0; @@ -84,7 +88,7 @@ Status WhileOpLabelMaker::Run(uint32_t &label_index) { return FAILED; } - NodePtr cond_out_node = cond_graph->FindNode(NODE_NAME_NET_OUTPUT); + NodePtr cond_out_node = cond_graph->FindNode(domi::NODE_NAME_NET_OUTPUT); GE_CHECK_NOTNULL(cond_out_node); OpDescPtr cond_out_desc = cond_out_node->GetOpDesc(); GE_CHECK_NOTNULL(cond_out_desc); diff --git a/src/ge/graph/load/graph_loader.cc b/src/ge/graph/load/graph_loader.cc index c58cdcb9..5f1704af 100644 --- a/src/ge/graph/load/graph_loader.cc +++ b/src/ge/graph/load/graph_loader.cc @@ -196,13 +196,13 @@ Status GraphLoader::LoadDataFromFile(const std::string &path, const std::string ModelData &model_data) { Status ret; try { - if (!CheckInputPathValid(path)) { + if (!domi::CheckInputPathValid(path)) { GELOGE(PARAM_INVALID, "model path is invalid: %s", path.c_str()); return PARAM_INVALID; } GELOGI("Load model begin, model path is: %s", path.c_str()); - if (!key_path.empty() && !CheckInputPathValid(key_path)) { + if (!key_path.empty() && !domi::CheckInputPathValid(key_path)) { GELOGE(PARAM_INVALID, "decrypt_key path is invalid: %s", key_path.c_str()); return PARAM_INVALID; } diff --git a/src/ge/graph/load/new_model_manager/cpu_queue_schedule.cc b/src/ge/graph/load/new_model_manager/cpu_queue_schedule.cc index 9b3c7a0f..c3de44c9 100644 --- a/src/ge/graph/load/new_model_manager/cpu_queue_schedule.cc +++ b/src/ge/graph/load/new_model_manager/cpu_queue_schedule.cc @@ -20,11 +20,11 @@ namespace { const uint32_t kCoreDim = 1; // for rtCpuKernelLaunch const char *const kCpuTaskModelEnqueue = "modelEnqueue"; +const char *const kCpuTaskPrepareInput = "modelPrepareInput"; const char *const kCpuTaskWaitEndGraph = "modelWaitEndGraph"; -const char *const kCpuTaskPrepareOutput = "bufferPrepareOutput"; +const char *const kCpuTaskPrepareOutput = "modelPrepareOutput"; const char *const kCpuTaskModelDequeue = "modelDequeue"; const char *const kCpuTaskModelRepeat = "modelRepeat"; -const char *const kCpuTaskZeroCopy = "zeroCpy"; } // namespace namespace ge { @@ -93,19 +93,19 @@ Status CpuTaskModelDequeue::Distribute() { /// /// @ingroup ge -/// @brief definiteness queue schedule, zero copy. -/// @param [in] mbuf_list: input/output mbuf addr list for input/output data. -/// @param [in] outside_addrs: model input/output memory addr +/// @brief definiteness queue schedule, bind output queue to task. +/// @param [in] addr: NetOutput Op input tensor address. +/// @param [in] size: NetOutput Op input tensor size. +/// @param [in] in_mbuf: input mbuf addr for input data. /// @return: 0 for success / others for failed /// -Status CpuTaskZeroCopy::Init(std::vector &mbuf_list, - std::map> &outside_addrs) { +Status CpuTaskPrepareInput::Init(uintptr_t addr, uint32_t size, uintptr_t in_mbuf) { if ((args_ != nullptr) || (args_size_ > 0)) { GELOGE(FAILED, "Task already initialized, size: %u", args_size_); return FAILED; } - args_size_ = sizeof(AddrMapInfo); + args_size_ = sizeof(PrepareInputInfo); rtError_t status = rtMalloc(&args_, args_size_, RT_MEMORY_HBM); if (status != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt malloc failed, status: 0x%x", status); @@ -113,99 +113,36 @@ Status CpuTaskZeroCopy::Init(std::vector &mbuf_list, } GE_PRINT_DYNAMIC_MEMORY(rtMalloc, "args data.", args_size_) - AddrMapInfo addr_map_info; - for (const auto &addrs : outside_addrs) { - addr_map_info.addr_num += addrs.second.size(); - } - GELOGI("addr_map_info.addr_num is %zu", addr_map_info.addr_num); - - // init src_addrs/dst_addrs - size_t index = 0; - vector src_addrs; - vector dst_addrs; - for (const auto &addrs : outside_addrs) { - for (size_t i = 0; i < addrs.second.size(); ++i) { - src_addrs.push_back(mbuf_list.at(index)); - dst_addrs.push_back(reinterpret_cast(addrs.second.at(i))); - } - index++; - } - - // malloc mem for src_addrs/dst_addrs, and copy data of src_addrs/dst_addrs - status = rtMalloc(&src_addr_, src_addrs.size() * sizeof(uint64_t), RT_MEMORY_HBM); - if (status != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt malloc failed, status: 0x%x", status); - return RT_FAILED; - } - status = rtMemcpy(src_addr_, src_addrs.size() * sizeof(uint64_t), src_addrs.data(), - src_addrs.size() * sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE); - if (status != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt memcpy failed, status: 0x%x", status); - return RT_FAILED; - } - - status = rtMalloc(&dst_addr_, dst_addrs.size() * sizeof(uint64_t), RT_MEMORY_HBM); - if (status != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt malloc failed, status: 0x%x", status); - return RT_FAILED; - } - status = rtMemcpy(dst_addr_, dst_addrs.size() * sizeof(uint64_t), dst_addrs.data(), - dst_addrs.size() * sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE); + PrepareInputInfo prepare; + prepare.in_mbuf = in_mbuf; + prepare.mbuf_offset = 0; + prepare.data_size = size; + prepare.data_addr = addr; + status = rtMemcpy(args_, args_size_, &prepare, args_size_, RT_MEMCPY_HOST_TO_DEVICE); if (status != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt memcpy failed, status: 0x%x", status); return RT_FAILED; } - // src_addr_list is init to src_addr, which is the point to src_addrs - if (!src_addrs.empty() && !dst_addrs.empty()) { - addr_map_info.src_addr_list = reinterpret_cast(src_addr_); - addr_map_info.dst_addr_list = reinterpret_cast(dst_addr_); - GELOGI("src_addr_list is %lu, dst_addr_list is %lu", addr_map_info.src_addr_list, addr_map_info.dst_addr_list); - } - - status = rtMemcpy(args_, args_size_, &addr_map_info, sizeof(AddrMapInfo), RT_MEMCPY_HOST_TO_DEVICE); - if (status != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt memcpy failed, status: 0x%x", status); - return RT_FAILED; - } return SUCCESS; } -Status CpuTaskZeroCopy::Distribute() { +Status CpuTaskPrepareInput::Distribute() { if ((args_ == nullptr) || (args_size_ == 0) || (stream_ == nullptr)) { GELOGE(FAILED, "Task not initialized, distribute failed, size: %u", args_size_); return FAILED; } - rtError_t status = rtCpuKernelLaunch(nullptr, kCpuTaskZeroCopy, kCoreDim, args_, args_size_, nullptr, stream_); + rtError_t status = rtCpuKernelLaunch(nullptr, kCpuTaskPrepareInput, kCoreDim, args_, args_size_, nullptr, stream_); if (status != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt CpuKernelLaunch ZeroCopy failed, status: 0x%X", status); + GELOGE(RT_FAILED, "Call rt CpuKernelLaunch PrepareInput failed, status: 0x%X", status); return RT_FAILED; } - GELOGI("Cpu kernel launch zero copy task success."); + GELOGI("Cpu kernel launch prepare input task success."); return SUCCESS; } -CpuTaskZeroCopy::~CpuTaskZeroCopy() { - if (src_addr_ == nullptr && dst_addr_ == nullptr) { - return; - } - if (src_addr_ != nullptr) { - rtError_t status = rtFree(src_addr_); - if (status != RT_ERROR_NONE) { - GELOGW("Call rt free failed, status: 0x%x", status); - } - } - if (dst_addr_ != nullptr) { - rtError_t status = rtFree(dst_addr_); - if (status != RT_ERROR_NONE) { - GELOGW("Call rt free failed, status: 0x%x", status); - } - } - src_addr_ = nullptr; - dst_addr_ = nullptr; -} /// /// @ingroup ge /// @brief definiteness queue schedule, bind output queue to task. diff --git a/src/ge/graph/load/new_model_manager/cpu_queue_schedule.h b/src/ge/graph/load/new_model_manager/cpu_queue_schedule.h index c4ae4df5..8a9af63f 100644 --- a/src/ge/graph/load/new_model_manager/cpu_queue_schedule.h +++ b/src/ge/graph/load/new_model_manager/cpu_queue_schedule.h @@ -47,13 +47,6 @@ struct PrepareOutputInfo { uintptr_t out_mbuf; // output mbuf addr }; -// For AICPU task "modelZeroCopy" -struct AddrMapInfo { - uint32_t addr_num = 0; - uint64_t src_addr_list; - uint64_t dst_addr_list; -}; - /// /// @ingroup ge /// @brief CpuTask base, inherit from TaskInfo used for manage. @@ -85,21 +78,17 @@ class CpuTaskModelDequeue : public CpuTaskInfo { /// /// @ingroup ge -/// @brief definiteness queue schedule, zero copy. +/// @brief definiteness queue schedule, bind output queue to task. /// -class CpuTaskZeroCopy : public CpuTaskInfo { +class CpuTaskPrepareInput : public CpuTaskInfo { public: - explicit CpuTaskZeroCopy(rtStream_t stream) : CpuTaskInfo(stream) {} - ~CpuTaskZeroCopy() override; + explicit CpuTaskPrepareInput(rtStream_t stream) : CpuTaskInfo(stream) {} + ~CpuTaskPrepareInput() override {} Status Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) override { return SUCCESS; } - Status Init(std::vector &mbuf_list, std::map> &outside_addrs); + Status Init(uintptr_t addr, uint32_t size, uintptr_t in_mbuf); Status Distribute() override; - - private: - void *src_addr_ = nullptr; - void *dst_addr_ = nullptr; }; /// diff --git a/src/ge/graph/load/new_model_manager/data_dumper.cc b/src/ge/graph/load/new_model_manager/data_dumper.cc index 824f6b18..85bbd5bc 100644 --- a/src/ge/graph/load/new_model_manager/data_dumper.cc +++ b/src/ge/graph/load/new_model_manager/data_dumper.cc @@ -15,20 +15,19 @@ */ #include "graph/load/new_model_manager/data_dumper.h" -#include -#include -#include #include "common/properties_manager.h" #include "framework/common/debug/ge_log.h" #include "framework/common/util.h" #include "graph/anchor.h" #include "graph/debug/ge_attr_define.h" #include "graph/utils/attr_utils.h" -#include "graph/load/new_model_manager/model_utils.h" +#include "model_utils.h" #include "proto/ge_ir.pb.h" #include "proto/op_mapping_info.pb.h" #include "runtime/mem.h" +using domi::ENDGRAPH; + namespace { const uint32_t kAicpuLoadFlag = 1; const uint32_t kAicpuUnloadFlag = 0; diff --git a/src/ge/graph/load/new_model_manager/data_dumper.h b/src/ge/graph/load/new_model_manager/data_dumper.h index 4400b127..823f7079 100644 --- a/src/ge/graph/load/new_model_manager/data_dumper.h +++ b/src/ge/graph/load/new_model_manager/data_dumper.h @@ -17,10 +17,8 @@ #ifndef GE_GRAPH_LOAD_NEW_MODEL_MANAGER_DATA_DUMPER_H_ #define GE_GRAPH_LOAD_NEW_MODEL_MANAGER_DATA_DUMPER_H_ -#include -#include #include -#include +#include #include "framework/common/ge_inner_error_codes.h" #include "graph/node.h" diff --git a/src/ge/graph/load/new_model_manager/davinci_model.cc b/src/ge/graph/load/new_model_manager/davinci_model.cc index 19c0ab16..64a106ef 100644 --- a/src/ge/graph/load/new_model_manager/davinci_model.cc +++ b/src/ge/graph/load/new_model_manager/davinci_model.cc @@ -16,20 +16,19 @@ #include "graph/load/new_model_manager/davinci_model.h" -#include #include #include #include #include #include #include -#include #include #include "common/debug/log.h" #include "common/formats/formats.h" #include "common/formats/utils/formats_trans_utils.h" #include "common/math/math_util.h" +#include "common/op/attr_define.h" #include "common/op/ge_op_utils.h" #include "common/profiling/profiling_manager.h" #include "common/properties_manager.h" @@ -78,7 +77,6 @@ namespace { const uint32_t kDataIndex = 0; const uint32_t kTrueBranchStreamNum = 1; const uint32_t kThreadNum = 16; -const uint32_t kAddrLen = sizeof(void *); const int kDecimal = 10; const int kBytes = 8; const uint32_t kDataMemAlignSizeCompare = 64; @@ -204,7 +202,7 @@ Status TransVarOnHost(uint8_t *var_data, const VarTransRoad &trans_road, formats formats::TransResult result_last_time{}; bool use_init_data = true; for (const auto &trans_info : trans_road) { - if (trans_info.node_type == RESHAPE || trans_info.node_type == REFORMAT) { + if (trans_info.node_type == domi::RESHAPE || trans_info.node_type == domi::REFORMAT) { GELOGD("Skip to trans variable data on the reshape/reformat node"); continue; } @@ -217,7 +215,7 @@ Status TransVarOnHost(uint8_t *var_data, const VarTransRoad &trans_road, formats } formats::TransResult tmp_result{}; - if (trans_info.node_type == TRANSDATA) { + if (trans_info.node_type == domi::TRANSDATA) { auto src_format = trans_info.input.GetFormat(); auto src_shape = trans_info.input.GetShape().GetDims(); auto dst_format = trans_info.output.GetFormat(); @@ -237,9 +235,9 @@ Status TransVarOnHost(uint8_t *var_data, const VarTransRoad &trans_road, formats TypeUtils::DataTypeToSerialString(data_type).c_str(), ret); return ret; } - } else if (trans_info.node_type == CAST) { + } else if (trans_info.node_type == domi::CAST) { auto input_shape = trans_info.input.GetShape(); - auto src_data_size = input_shape.GetShapeSize() == 0 ? 1 : input_shape.GetShapeSize(); + auto src_data_size = input_shape.GetShapeSize(); auto src_data_type = trans_info.input.GetDataType(); auto dst_data_type = trans_info.output.GetDataType(); GELOGD("Trans data type from %s to %s, input shape %s, data size %ld", @@ -303,7 +301,7 @@ Status TransVarData(const NodePtr &var, const VarTransRoad &trans_road, uint64_t GE_CHECK_NOTNULL(var); bool need_trans = false; for (auto &road : trans_road) { - if (road.node_type != RESHAPE && road.node_type != REFORMAT) { + if (road.node_type != domi::RESHAPE && road.node_type != domi::REFORMAT) { need_trans = true; break; } @@ -371,7 +369,7 @@ bool CheckDynamicBatchZeroCopyAddr(const void *addr, const vector &dynam } inline bool IsDataOp(const std::string &node_type) { - return node_type == DATA_TYPE || node_type == AIPP_DATA_TYPE || node_type == ANN_DATA_TYPE; + return node_type == domi::DATA_TYPE || node_type == domi::AIPP_DATA_TYPE || node_type == domi::ANN_DATA_TYPE; } inline bool IsCallDumpInputOp(const OpDescPtr &op_desc) { bool skip_task_generate = false; @@ -381,7 +379,7 @@ inline bool IsCallDumpInputOp(const OpDescPtr &op_desc) { } // namespace -SysMode DavinciModel::mode_ = INFERENCE; +domi::SysMode DavinciModel::mode_ = domi::INFERENCE; std::mutex DavinciModel::mutex_mode_; std::mutex DavinciModel::tvm_bin_mutex_; @@ -410,18 +408,16 @@ DavinciModel::DavinciModel(int32_t priority, const std::shared_ptrGetAllNodes()) { OpDescPtr op_desc = node->GetOpDesc(); GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGW("Node OpDesc is nullptr"); continue); - GE_IF_BOOL_EXEC(((op_desc->GetType() == HCOMBROADCAST) || (op_desc->GetType() == HCOMALLGATHER) || - (op_desc->GetType() == HCOMALLREDUCE) || (op_desc->GetType() == HCOMSEND) || - (op_desc->GetType() == HCOMRECEIVE) || (op_desc->GetType() == HCOMREDUCESCATTER)), + GE_IF_BOOL_EXEC(((op_desc->GetType() == domi::HCOMBROADCAST) || (op_desc->GetType() == domi::HCOMALLGATHER) || + (op_desc->GetType() == domi::HCOMALLREDUCE) || (op_desc->GetType() == domi::HCOMSEND) || + (op_desc->GetType() == domi::HCOMRECEIVE) || (op_desc->GetType() == domi::HCOMREDUCESCATTER)), uint32_t stream_id = static_cast(op_desc->GetStreamId()); (void)hcom_streams_.emplace(stream_id); GELOGD("hcom stream: %u.", stream_id); continue); @@ -648,16 +644,19 @@ Status DavinciModel::DoTaskSink() { if (model_task_def_) { GELOGI("do task_sink."); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(InitTaskInfo(*model_task_def_.get()) != SUCCESS, return FAILED, - "InitTaskInfo failed."); - GE_CHK_STATUS_RET(LoadWithQueue(), "LoadWithQueue failed."); // will adjust stream indication, load fist. + GE_CHK_STATUS_RET(LoadWithQueue(), "LoadWithQueue failed."); + for (size_t i = 0; i < stream_list_.size(); i++) { GE_IF_BOOL_EXEC(active_stream_indication_.count(i) > 0, GELOGI("rtModelBindStream[%zu]", i); GE_CHK_RT_RET(rtModelBindStream(rt_model_handle_, stream_list_[i], RT_INVALID_FLAG)); continue;); // bind rt_model_handel to all streams that relates to op - GE_CHK_RT_RET(rtModelBindStream(rt_model_handle_, stream_list_[i], RT_HEAD_STREAM)); + GE_CHK_RT_RET(rtModelBindStream(rt_model_handle_, stream_list_[i], 0)); } + + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(InitTaskInfo(*model_task_def_.get()) != SUCCESS, return FAILED, + "InitTaskInfo failed."); + GE_CHK_STATUS_RET(DistributeTask(), "Distribute failed."); GE_CHK_RT_RET(rtModelLoadComplete(rt_model_handle_)); @@ -738,7 +737,12 @@ Status DavinciModel::Init(void *dev_ptr, size_t mem_size, void *weight_ptr, size event_list_.push_back(rt_event); } - label_list_.resize(LabelNum(), nullptr); + for (uint32_t i = 0; i < LabelNum(); i++) { + rtLabel_t rt_label; + GE_CHK_RT_RET(rtLabelCreate(&rt_label)); + GE_CHK_BOOL_RET_STATUS(rt_label != nullptr, FAILED, "rt_label is nullptr."); + label_list_.push_back(rt_label); + } // create model_handle to load model GE_CHK_RT_RET(rtModelCreate(&rt_model_handle_, 0)); @@ -767,7 +771,7 @@ Status DavinciModel::Init(void *dev_ptr, size_t mem_size, void *weight_ptr, size GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue); GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != VARIABLE, continue); GE_IF_BOOL_EXEC(IsBroadCastOpData(node), - (void)ge::AttrUtils::SetStr(node->GetOpDesc(), VAR_ATTR_VAR_IS_BROADCAST, "var_is_restore");); + (void)ge::AttrUtils::SetStr(node->GetOpDesc(), domi::VAR_ATTR_VAR_IS_BROADCAST, "var_is_restore");); } // for profiling op_name_map_ = compute_graph->GetGraphOpName(); @@ -813,17 +817,11 @@ Status DavinciModel::Init(void *dev_ptr, size_t mem_size, void *weight_ptr, size /// Status DavinciModel::InitNodes(const ComputeGraphPtr &compute_graph) { uint32_t data_op_index = 0; + std::map> input_data_info; + GE_TIMESTAMP_CALLNUM_START(LoadTBEKernelBinToOpDesc); GE_TIMESTAMP_CALLNUM_START(InitTbeHandle); - typedef Status (DavinciModel::*OpDescCall)(const OpDescPtr &); - static std::map op_desc_handle = { - {VARIABLE, &DavinciModel::InitVariable}, {CONSTANTOP, &DavinciModel::InitConstant}, - {NETOUTPUT, &DavinciModel::InitNetOutput}, {ENDGRAPH, &DavinciModel::InitEndGraph}, - {STREAMACTIVE, &DavinciModel::InitStreamActive}, {STREAMSWITCH, &DavinciModel::InitStreamSwitch}, - {STREAMSWITCHN, &DavinciModel::InitStreamSwitchN}, {LABELSET, &DavinciModel::InitLabelSet}, - }; - auto nodes = compute_graph->GetAllNodes(); const TBEKernelStore &tbekernel_store = ge_model_->GetTBEKernelStore(); for (size_t i = 0; i < nodes.size(); i++) { @@ -841,7 +839,7 @@ Status DavinciModel::InitNodes(const ComputeGraphPtr &compute_graph) { GE_TIMESTAMP_ADD(LoadTBEKernelBinToOpDesc); if (IsDataOp(op_desc->GetType())) { - if (InitDataOp(node, data_op_index) != SUCCESS) { + if (InitDataOp(node, data_op_index, input_data_info) != SUCCESS) { GELOGE(PARAM_INVALID, "Data init failed, Name: %s", op_desc->GetName().c_str()); return PARAM_INVALID; } @@ -855,15 +853,32 @@ Status DavinciModel::InitNodes(const ComputeGraphPtr &compute_graph) { continue; } - auto it = op_desc_handle.find(op_desc->GetType()); - if (it != op_desc_handle.end()) { - if ((this->*it->second)(op_desc) != SUCCESS) { + if (op_desc->GetType() == VARIABLE) { + variable_op_list_.push_back(op_desc); + continue; + } + + if (op_desc->GetType() == NETOUTPUT) { + if (InitNetOutput(op_desc) != SUCCESS) { GELOGE(PARAM_INVALID, "NetOutput init failed, Name: %s", op_desc->GetName().c_str()); return PARAM_INVALID; } continue; } + // Initialize constant op, only applies to training, ignoring inference constant op + if (op_desc->GetType() == CONSTANTOP) { + if (InitConstant(op_desc) != SUCCESS) { + GELOGE(PARAM_INVALID, "Constant init failed. %s", op_desc->GetName().c_str()); + return PARAM_INVALID; + } + continue; + } + + if (op_desc->GetType() == ENDGRAPH) { + end_graph_op_ = op_desc; + } + GE_TIMESTAMP_RESTART(InitTbeHandle); uint32_t run_mode = static_cast(domi::ImplyType::INVALID); if (AttrUtils::GetInt(op_desc, ATTR_NAME_IMPLY_TYPE, run_mode) && @@ -882,11 +897,17 @@ Status DavinciModel::InitNodes(const ComputeGraphPtr &compute_graph) { } } GE_TIMESTAMP_ADD(InitTbeHandle); + + if (MarkActiveStream(op_desc) != SUCCESS) { + GELOGE(PARAM_INVALID, "MarkActiveStream failed, node:%s, opIndex:%zu", op_desc->GetName().c_str(), i); + return PARAM_INVALID; + } } + Status ret = CombineDataInfo(input_data_info); GE_TIMESTAMP_CALLNUM_END(LoadTBEKernelBinToOpDesc, "GraphLoader::LoadTBEKernelBinToOpDesc."); GE_TIMESTAMP_CALLNUM_END(InitTbeHandle, "GraphLoader::InitTbeHandle."); - return SUCCESS; + return ret; } /// @ingroup ge @@ -895,7 +916,8 @@ Status DavinciModel::InitNodes(const ComputeGraphPtr &compute_graph) { /// @param [in/out] data_op_index: NetOutput addr size info. /// @param [in/out] input_data_info: Data index and addr info {index, {size, addr}}. /// @return Status -Status DavinciModel::InitDataOp(const NodePtr &node, uint32_t &data_op_index) { +Status DavinciModel::InitDataOp(const NodePtr &node, uint32_t &data_op_index, + std::map> &input_data_info) { // op_desc Checked by Init: Data, valid. auto op_desc = node->GetOpDesc(); uint32_t parent_index = 0; // Ignore subgraph Data Node. @@ -917,20 +939,20 @@ Status DavinciModel::InitDataOp(const NodePtr &node, uint32_t &data_op_index) { // Make information for copy input data. const vector output_size_list = ModelUtils::GetOutputSize(op_desc); - const vector virtual_addr_list = ModelUtils::GetOutputDataAddrs(runtime_param_, op_desc, false); - if (output_size_list.empty() || virtual_addr_list.empty() || (output_size_list.size() != virtual_addr_list.size())) { + const vector output_addr_list = ModelUtils::GetOutputDataAddrs(runtime_param_, op_desc); + if (output_size_list.empty() || output_addr_list.empty() || (output_size_list.size() != output_addr_list.size())) { GELOGE(PARAM_INVALID, "Data[%s] init failed: Output size is %zu, Output addr is %zu", op_desc->GetName().c_str(), - output_size_list.size(), virtual_addr_list.size()); + output_size_list.size(), output_addr_list.size()); return PARAM_INVALID; } auto data_index = data_op_index; if (AttrUtils::GetInt(op_desc, ATTR_NAME_INDEX, data_index)) { - GELOGI("ge_train: get new index %u, old %u", data_index, data_op_index); + GELOGI("ge_train:get new index %u, old %u", data_index, data_op_index); } - input_data_info_[data_index] = {output_size_list[kDataIndex], virtual_addr_list[kDataIndex]}; - SetInputOutsideAddr(virtual_addr_list); + input_data_info[data_index] = {output_size_list[kDataIndex], output_addr_list[kDataIndex]}; + SetInputOutsideAddr(output_addr_list); data_op_index++; if (InitInputZeroCopyNodes(node) != SUCCESS) { GELOGE(PARAM_INVALID, "Input zero copy nodes init failed!"); @@ -993,78 +1015,43 @@ Status DavinciModel::InitNetOutput(const OpDescPtr &op_desc) { // Make information for copy output data. const vector input_size_list = ModelUtils::GetInputSize(op_desc); - const vector virtual_addr_list = ModelUtils::GetInputDataAddrs(runtime_param_, op_desc, false); - if (input_size_list.empty() && virtual_addr_list.empty()) { + const vector input_addr_list = ModelUtils::GetInputDataAddrs(runtime_param_, op_desc); + if (input_size_list.empty() && input_addr_list.empty()) { GELOGI("NetOutput[%s] is empty.", op_desc->GetName().c_str()); return SUCCESS; } - if (input_size_list.empty() || input_size_list.size() != virtual_addr_list.size() || + if (input_size_list.empty() || input_size_list.size() != input_addr_list.size() || input_size_list.size() != output_size_list.size()) { GELOGE(PARAM_INVALID, "NetOutput[%s] init failed: Input size is %zu, Input addr is %zu, Output size is %zu", - op_desc->GetName().c_str(), input_size_list.size(), virtual_addr_list.size(), output_size_list.size()); + op_desc->GetName().c_str(), input_size_list.size(), input_addr_list.size(), output_size_list.size()); return PARAM_INVALID; } - size_t num = output_data_info_.size(); - for (size_t idx = 0; idx < input_size_list.size(); ++idx) { - output_data_info_[num + idx] = {input_size_list[idx], virtual_addr_list[idx]}; - } - - SetOutputOutsideAddr(virtual_addr_list); + output_size_list_.insert(output_size_list_.end(), input_size_list.begin(), input_size_list.end()); + output_addr_list_.insert(output_addr_list_.end(), input_addr_list.begin(), input_addr_list.end()); + SetOutputOutsideAddr(input_addr_list); return SUCCESS; } /// @ingroup ge -/// @brief LabelSet Op Initialize. -/// @param [in] op_desc: LabelSet Op descriptor. +/// @brief Make Input and Output addr for feature use. +/// @param [in] input_data_info: Data index and addr info {index, {size, addr}}. /// @return Status -Status DavinciModel::InitLabelSet(const OpDescPtr &op_desc) { - uint32_t label_index = 0; - if (!AttrUtils::GetInt(op_desc, ATTR_NAME_LABEL_SWITCH_INDEX, label_index)) { - GELOGE(INTERNAL_ERROR, "InitLabelSet: %s attr [%s] not exist.", op_desc->GetName().c_str(), - ATTR_NAME_LABEL_SWITCH_INDEX.c_str()); - return INTERNAL_ERROR; - } - if (label_index >= LabelNum()) { - GELOGE(INTERNAL_ERROR, "InitLabelSet: label index: %u >= label size: %zu.", label_index, LabelNum()); - return INTERNAL_ERROR; - } - if (label_id_indication_.count(label_index) > 0) { - GELOGE(INTERNAL_ERROR, "InitLabelSet: %s label index: %u already used.", op_desc->GetName().c_str(), label_index); - return INTERNAL_ERROR; - } - - rtStream_t stream = nullptr; - uint32_t stream_id = static_cast(op_desc->GetStreamId()); - if (stream_list_.size() == 1) { - stream = stream_list_[0]; - } else if (stream_list_.size() > stream_id) { - stream = stream_list_[stream_id]; - } else { - GELOGE(INTERNAL_ERROR, "InitLabelSet: stream index: %u >= stream size: %zu.", stream_id, stream_list_.size()); - return INTERNAL_ERROR; - } - - rtLabel_t rt_label = nullptr; - rtError_t rt_error = rtLabelCreate(&rt_label); - if (rt_error != RT_ERROR_NONE || rt_label == nullptr) { - GELOGE(INTERNAL_ERROR, "InitLabelSet: %s create label failed, error=0x%x.", op_desc->GetName().c_str(), rt_error); - return INTERNAL_ERROR; +Status DavinciModel::CombineDataInfo(const std::map> &input_data_info) { + input_size_list_.resize(data_op_list_.size()); + input_addr_list_.resize(data_op_list_.size()); + for (size_t index = 0; index < data_op_list_.size(); ++index) { + auto it = input_data_info.find(index); + if (it == input_data_info.end()) { + GELOGE(PARAM_INVALID, "Data init failed: index %zu, Data Op size is %zu, Input addr is %zu", index, + data_op_list_.size(), input_data_info.size()); + return INTERNAL_ERROR; + } + input_size_list_[index] = it->second.first; + input_addr_list_[index] = it->second.second; } - GELOGI("InitLabelSet: label[%u]=%p stream[%u]=%p.", label_index, rt_label, stream_id, stream); - label_id_indication_.insert(label_index); - label_list_[label_index] = rt_label; - return SUCCESS; -} - -Status DavinciModel::InitVariable(const OpDescPtr &op_desc) { - variable_op_list_.push_back(op_desc); - return SUCCESS; -} - -Status DavinciModel::InitEndGraph(const OpDescPtr &op_desc) { - end_graph_op_ = op_desc; + GELOGI("Data init success, input size %zu, output size %zu", input_size_list_.size(), output_size_list_.size()); return SUCCESS; } @@ -1097,34 +1084,31 @@ Status DavinciModel::LoadWithQueue() { return SUCCESS; } - if (input_queue_ids_.size() != input_data_info_.size()) { + if (input_queue_ids_.size() != data_op_list_.size()) { GELOGE(PARAM_INVALID, "Input queue ids not match model: input_queue=%zu input_data=%zu", input_queue_ids_.size(), - input_data_info_.size()); + data_op_list_.size()); return PARAM_INVALID; } - if (output_queue_ids_.size() != output_data_info_.size()) { + if (output_queue_ids_.size() != output_size_list_.size()) { GELOGE(PARAM_INVALID, "Output queue ids not match model: output_queue=%zu output_data=%zu", - output_queue_ids_.size(), output_data_info_.size()); + output_queue_ids_.size(), output_size_list_.size()); return PARAM_INVALID; } // create stream instance which rt_model_handel is running on, this is S0. GE_CHK_RT_RET(rtStreamCreateWithFlags(&rt_model_stream_, priority_, RT_STREAM_AICPU)); is_inner_model_stream_ = true; - GE_CHK_RT_RET(rtModelBindStream(rt_model_handle_, rt_model_stream_, RT_HEAD_STREAM)); + GE_CHK_RT_RET(rtModelBindStream(rt_model_handle_, rt_model_stream_, 0)); // Binding input_queue and Data Op. GE_CHK_STATUS_RET(BindInputQueue(), "Launch bind input queue failed."); - GE_CHK_STATUS_RET(CpuTaskModelZeroCopy(input_mbuf_list_, input_outside_addrs_), "Launch zero copy failed."); - - // Binding output_queue and NetOutput Op. - GE_CHK_STATUS_RET(BindOutputQueue(), "Launch bind output queue failed."); - GE_CHK_STATUS_RET(CpuTaskModelZeroCopy(output_mbuf_list_, output_outside_addrs_), "Launch zero copy failed."); GE_CHK_STATUS_RET(BindActiveStream(), "Launch active entry stream failed."); GE_CHK_STATUS_RET(CpuWaitEndGraph(), "Launch wait end graph failed."); - GE_CHK_STATUS_RET(BindEnqueue(), "Launch enqueue failed.") + + // Binding output_queue and NetOutput Op. + GE_CHK_STATUS_RET(BindOutputQueue(), "Launch bind output queue failed."); GE_CHK_STATUS_RET(CpuModelRepeat(), "Launch model repeat failed."); return SUCCESS; @@ -1136,15 +1120,9 @@ Status DavinciModel::LoadWithQueue() { Status DavinciModel::BindInputQueue() { // Caller checked: input_queue_ids_.size() == input_size_list_.size() != input_addr_list_.size() for (size_t i = 0; i < input_queue_ids_.size(); ++i) { - auto it = input_data_info_.find(i); - if (it == input_data_info_.end()) { - GELOGE(FAILED, "Input not match: tensor num=%zu, Queue id index=%zu", input_data_info_.size(), i); - return FAILED; - } - uint32_t queue_id = input_queue_ids_[i]; - uint32_t data_size = static_cast(it->second.first); - uintptr_t data_addr = reinterpret_cast(it->second.second); + uint32_t data_size = input_size_list_[i]; + uintptr_t data_addr = reinterpret_cast(input_addr_list_[i]); GELOGI("BindInputToQueue: graph_%u index[%zu] queue id[%u] output addr[0x%lx] output size[%u]", runtime_param_.graph_id, i, queue_id, data_addr, data_size); @@ -1152,7 +1130,31 @@ Status DavinciModel::BindInputQueue() { return INTERNAL_ERROR; } - if (CpuModelDequeue(queue_id) != SUCCESS) { + if (CpuModelDequeue(queue_id, data_addr, data_size) != SUCCESS) { + return INTERNAL_ERROR; + } + } + + return SUCCESS; +} + +/// @ingroup ge +/// @brief queue schedule, bind output queue to NetOutput input address. +/// @return: 0 for success / others for failed +Status DavinciModel::BindOutputQueue() { + // Caller checked: input_queue_ids_.size() == input_size_list_.size() != input_addr_list_.size() + for (size_t i = 0; i < output_queue_ids_.size(); ++i) { + uint32_t queue_id = output_queue_ids_[i]; + uint32_t data_size = output_size_list_[i]; + uintptr_t data_addr = reinterpret_cast(output_addr_list_[i]); + GELOGI("BindOutputToQueue: graph_%u index[%zu] queue id[%u] input addr[0x%lx] input size[%u]", + runtime_param_.graph_id, i, queue_id, data_addr, data_size); + + if (rtModelBindQueue(rt_model_handle_, queue_id, RT_MODEL_OUTPUT_QUEUE) != RT_ERROR_NONE) { + return INTERNAL_ERROR; + } + + if (CpuModelEnqueue(queue_id, data_addr, data_size) != SUCCESS) { return INTERNAL_ERROR; } } @@ -1160,13 +1162,34 @@ Status DavinciModel::BindInputQueue() { return SUCCESS; } +/// @ingroup ge +/// @brief queue schedule, active stream will schedule by S0. +/// @return: 0 for success / others for failed +Status DavinciModel::BindActiveStream() { + // Stream not in active_stream_indication_ is active stream. + std::vector active_stream_list; + for (size_t i = 0; i < stream_list_.size(); ++i) { + if (active_stream_indication_.count(i) == 0) { + active_stream_list.push_back(stream_list_[i]); + active_stream_indication_.insert(i); // deactive all model stream. + } + } + + // Active stream add to active entry, will active by S0. + if (CpuActiveStream(active_stream_list) != SUCCESS) { + return INTERNAL_ERROR; + } + + return SUCCESS; +} + /// @ingroup ge /// @brief definiteness queue schedule, bind input queue to task. /// @param [in] queue_id: input queue id from user. /// @param [in] addr: Data Op output tensor address. /// @param [in] size: Data Op output tensor size. /// @return: 0 for success / others for failed -Status DavinciModel::CpuModelDequeue(uint32_t queue_id) { +Status DavinciModel::CpuModelDequeue(uint32_t queue_id, uintptr_t addr, uint32_t size) { GELOGI("Set CpuKernel model dequeue task enter."); std::shared_ptr dequeue_task = MakeShared(rt_model_stream_); if (dequeue_task == nullptr) { @@ -1180,55 +1203,20 @@ Status DavinciModel::CpuModelDequeue(uint32_t queue_id) { return FAILED; } - cpu_task_list_.push_back(dequeue_task); - input_mbuf_list_.push_back(in_mbuf); - GELOGI("Set CpuKernel model dequeue task success."); - return SUCCESS; -} - -Status DavinciModel::CpuTaskModelZeroCopy(std::vector &mbuf_list, - std::map> &outside_addrs) { - GELOGI("Set CpuKernel model zero_copy task enter."); - std::shared_ptr zero_copy = MakeShared(rt_model_stream_); - if (zero_copy == nullptr) { - GELOGE(FAILED, "Make CpuTaskZeroCopy task failed."); + std::shared_ptr prepare_input = MakeShared(rt_model_stream_); + if (dequeue_task == nullptr) { + GELOGE(FAILED, "Make CpuTaskPrepareInput task failed."); return FAILED; } - if (zero_copy->Init(mbuf_list, outside_addrs) != SUCCESS) { + if (prepare_input->Init(addr, size, in_mbuf) != SUCCESS) { return FAILED; } - cpu_task_list_.push_back(zero_copy); - GELOGI("Set CpuKernel model zero_copy task success."); - return SUCCESS; -} - -/// @ingroup ge -/// @brief queue schedule, bind output queue to NetOutput input address. -/// @return: 0 for success / others for failed -Status DavinciModel::BindOutputQueue() { - // Caller checked: input_queue_ids_.size() == input_size_list_.size() != input_addr_list_.size() - for (size_t i = 0; i < output_queue_ids_.size(); ++i) { - auto it = output_data_info_.find(i); - if (it == output_data_info_.end()) { - GELOGE(FAILED, "Output not match: tensor num=%zu, Queue id index=%zu", output_data_info_.size(), i); - return FAILED; - } - - uint32_t queue_id = output_queue_ids_[i]; - uint32_t data_size = static_cast(it->second.first); - uintptr_t data_addr = reinterpret_cast(it->second.second); - GELOGI("BindOutputToQueue: graph_%u index[%zu] queue id[%u] input addr[0x%lx] input size[%u]", - runtime_param_.graph_id, i, queue_id, data_addr, data_size); - - if (rtModelBindQueue(rt_model_handle_, queue_id, RT_MODEL_OUTPUT_QUEUE) != RT_ERROR_NONE) { - return INTERNAL_ERROR; - } - if (CpuModelPrepareOutput(data_addr, data_size) != SUCCESS) { - return INTERNAL_ERROR; - } - } + cpu_task_list_.push_back(dequeue_task); + cpu_task_list_.push_back(prepare_input); + input_mbuf_list_.push_back(in_mbuf); + GELOGI("Set CpuKernel model dequeue task success."); return SUCCESS; } @@ -1238,7 +1226,7 @@ Status DavinciModel::BindOutputQueue() { /// @param [in] addr: NetOutput Op input tensor address. /// @param [in] size: NetOutput Op input tensor size. /// @return: 0 for success / others for failed -Status DavinciModel::CpuModelPrepareOutput(uintptr_t addr, uint32_t size) { +Status DavinciModel::CpuModelEnqueue(uint32_t queue_id, uintptr_t addr, uint32_t size) { GELOGI("Set CpuKernel model enqueue task enter."); if (input_mbuf_list_.empty()) { GELOGE(FAILED, "Need input mbuf for fill output mbuf head info."); @@ -1256,30 +1244,20 @@ Status DavinciModel::CpuModelPrepareOutput(uintptr_t addr, uint32_t size) { return FAILED; } - cpu_task_list_.push_back(prepare_output); - output_mbuf_list_.push_back(out_mbuf); - GELOGI("Set CpuKernel model enqueue task success."); - return SUCCESS; -} - -/// @ingroup ge -/// @brief queue schedule, active stream will schedule by S0. -/// @return: 0 for success / others for failed -Status DavinciModel::BindActiveStream() { - // Stream not in active_stream_indication_ is active stream. - std::vector active_stream_list; - for (size_t i = 0; i < stream_list_.size(); ++i) { - if (active_stream_indication_.count(i) == 0) { - active_stream_list.push_back(stream_list_[i]); - active_stream_indication_.insert(i); // deactive all model stream. - } + std::shared_ptr model_enqueue = MakeShared(rt_model_stream_); + if (model_enqueue == nullptr) { + GELOGE(FAILED, "Make CpuTaskModelEnqueue task failed."); + return FAILED; } - // Active stream add to active entry, will active by S0. - if (CpuActiveStream(active_stream_list) != SUCCESS) { - return INTERNAL_ERROR; + if (model_enqueue->Init(queue_id, out_mbuf) != SUCCESS) { + return FAILED; } + cpu_task_list_.push_back(prepare_output); + cpu_task_list_.push_back(model_enqueue); + output_mbuf_list_.push_back(out_mbuf); + GELOGI("Set CpuKernel model enqueue task success."); return SUCCESS; } @@ -1329,38 +1307,6 @@ Status DavinciModel::CpuWaitEndGraph() { return SUCCESS; } -Status DavinciModel::BindEnqueue() { - for (size_t i = 0; i < output_queue_ids_.size(); ++i) { - auto it = output_data_info_.find(i); - if (it == output_data_info_.end()) { - GELOGE(FAILED, "Output not match: tensor num=%zu, Queue id index=%zu", output_data_info_.size(), i); - return FAILED; - } - - uint32_t queue_id = output_queue_ids_[i]; - if (CpuModelEnqueue(queue_id, output_mbuf_list_[i]) != SUCCESS) { - return INTERNAL_ERROR; - } - } - return SUCCESS; -} - -Status DavinciModel::CpuModelEnqueue(uint32_t queue_id, uintptr_t out_mbuf) { - GELOGI("Set CpuKernel model enqueue task enter."); - std::shared_ptr model_enqueue = MakeShared(rt_model_stream_); - if (model_enqueue == nullptr) { - GELOGE(FAILED, "Make CpuTaskModelEnqueue task failed."); - return FAILED; - } - - if (model_enqueue->Init(queue_id, out_mbuf) != SUCCESS) { - return FAILED; - } - cpu_task_list_.push_back(model_enqueue); - GELOGI("Set CpuKernel model enqueue task enter."); - return SUCCESS; -} - /// @ingroup ge /// @brief definiteness queue schedule, repeat run model. /// @return: 0 for success / others for failed @@ -1385,7 +1331,7 @@ Status DavinciModel::CpuModelRepeat() { /// @brief get sys mode /// @return SysMode required system mode /// @author -SysMode DavinciModel::GetSysMode() { +domi::SysMode DavinciModel::GetSysMode() { std::unique_lock lock(mutex_mode_); return mode_; } @@ -1395,8 +1341,8 @@ SysMode DavinciModel::GetSysMode() { /// @param [in] mode to be set /// @return Status mode set result /// @author -Status DavinciModel::SetSysMode(SysMode mode) { - GE_CHK_BOOL_RET_STATUS(mode < RESERVED, PARAM_INVALID, "DavinciModel::SetSysMode Para Error"); +Status DavinciModel::SetSysMode(domi::SysMode mode) { + GE_CHK_BOOL_RET_STATUS(mode < domi::RESERVED, PARAM_INVALID, "DavinciModel::SetSysMode Para Error"); std::unique_lock lock(mutex_mode_); mode_ = mode; @@ -1475,7 +1421,7 @@ Status DavinciModel::GetDynamicBatchInfo(std::vector> &batc return FAILED; } - if (op_desc->GetType() != STREAMSWITCHN) { + if (op_desc->GetType() != domi::STREAMSWITCHN) { continue; } @@ -1536,12 +1482,13 @@ Status DavinciModel::GetInputDescInfo(vector &input_desc, s GE_CHECK_NOTNULL(data_op_list_[index]); GE_CHECK_NOTNULL(data_op_list_[index]->GetInputDescPtr(0)); Format format = data_op_list_[index]->GetInputDescPtr(0)->GetFormat(); - n = format == FORMAT_NHWC ? NHWC_DIM_N : NCHW_DIM_N; - c = format == FORMAT_NHWC ? NHWC_DIM_C : NCHW_DIM_C; - h = format == FORMAT_NHWC ? NHWC_DIM_H : NCHW_DIM_H; - w = format == FORMAT_NHWC ? NHWC_DIM_W : NCHW_DIM_W; + n = format == FORMAT_NHWC ? domi::NHWC_DIM_N : domi::NCHW_DIM_N; + c = format == FORMAT_NHWC ? domi::NHWC_DIM_C : domi::NCHW_DIM_C; + h = format == FORMAT_NHWC ? domi::NHWC_DIM_H : domi::NCHW_DIM_H; + w = format == FORMAT_NHWC ? domi::NHWC_DIM_W : domi::NCHW_DIM_W; - if (data_op_list_[index]->GetInputDescPtr(0)->GetShape().GetDimNum() == static_cast(NORMAL_TENSOR_SIZE)) { + if (data_op_list_[index]->GetInputDescPtr(0)->GetShape().GetDimNum() == + static_cast(domi::NORMAL_TENSOR_SIZE)) { input.shape_info.num = data_op_list_[index]->GetInputDescPtr(0)->GetShape().GetDim(n); input.shape_info.height = data_op_list_[index]->GetInputDescPtr(0)->GetShape().GetDim(h); input.shape_info.width = data_op_list_[index]->GetInputDescPtr(0)->GetShape().GetDim(w); @@ -1578,11 +1525,11 @@ void DavinciModel::CreateOutput(uint32_t index, OpDescPtr &op_desc, InputOutputD for (size_t i = 0; i < shape.GetDimNum() && i < (sizeof(dims) / sizeof(dims[0])); i++) { dims[i] = shape.GetDim(i); } - } else { // FOR FORMAT_NHWC or FORMAT_NCHW - dims[0] = shape.GetDim(format == FORMAT_NHWC ? NHWC_DIM_N : NCHW_DIM_N); // 0: first dim - dims[1] = shape.GetDim(format == FORMAT_NHWC ? NHWC_DIM_C : NCHW_DIM_C); // 1: second dim - dims[2] = shape.GetDim(format == FORMAT_NHWC ? NHWC_DIM_H : NCHW_DIM_H); // 2: third dim - dims[3] = shape.GetDim(format == FORMAT_NHWC ? NHWC_DIM_W : NCHW_DIM_W); // 3: forth dim + } else { // FOR FORMAT_NHWC or FORMAT_NCHW + dims[0] = shape.GetDim(format == FORMAT_NHWC ? domi::NHWC_DIM_N : domi::NCHW_DIM_N); // 0: first dim + dims[1] = shape.GetDim(format == FORMAT_NHWC ? domi::NHWC_DIM_C : domi::NCHW_DIM_C); // 1: second dim + dims[2] = shape.GetDim(format == FORMAT_NHWC ? domi::NHWC_DIM_H : domi::NCHW_DIM_H); // 2: third dim + dims[3] = shape.GetDim(format == FORMAT_NHWC ? domi::NHWC_DIM_W : domi::NCHW_DIM_W); // 3: forth dim } output.shape_info.num = dims[0]; // 0: first dim output.shape_info.channel = dims[1]; // 1: second dim @@ -1657,42 +1604,24 @@ ge::Format DavinciModel::GetFormat() { } Status DavinciModel::CopyInputData(const InputData ¤t_data, bool device_data) { - rtMemcpyKind_t kind = device_data ? RT_MEMCPY_DEVICE_TO_DEVICE : RT_MEMCPY_HOST_TO_DEVICE; - const std::vector &blobs = current_data.blobs; - for (const auto &data : input_data_info_) { - if (data.first >= blobs.size()) { - GELOGE(FAILED, "Blobs not match: blobs=%zu, tensor=%zu, index=%u, size=%ld", blobs.size(), - input_data_info_.size(), data.first, data.second.first); - return FAILED; - } - - const DataBuffer &data_buf = blobs[data.first]; - // if data attr support zero copy, then update addrs info to flowtable - bool flag = data_buf.isDataSupportMemShare && support_mem_shared_flag_; - if (flag) { - GELOGI("No need to copy input data, user's input data buffer can be shared."); - continue; - } + Status ret = SUCCESS; + uint32_t data_op_index = 0; - void *mem_addr = data.second.second; - uint32_t mem_size = static_cast(data.second.first); - GE_CHK_BOOL_RET_STATUS(mem_size >= data_buf.length, PARAM_INVALID, - "input data size(%u) does not match model required size(%u), ret failed.", data_buf.length, - mem_size); + for (auto op_desc : data_op_list_) { + ret = CopyInputDataToModel(current_data.blobs, data_op_index, device_data); - GELOGI("[IMAS]CopyPlainData memcpy graph_%u type[F] output[%u] memaddr[%p] mem_size[%u] datasize[%u]", - runtime_param_.graph_id, data.first, mem_addr, mem_size, data_buf.length); - GE_CHK_RT_RET(rtMemcpy(mem_addr, mem_size, data_buf.data, data_buf.length, kind)); + GE_CHK_BOOL_EXEC(ret == SUCCESS, break, "Copy input data to model ret failed, index:%u, model id:%u", + current_data.index, current_data.model_id); + data_op_index++; } - - return SUCCESS; + return ret; } Status DavinciModel::SyncVarData() { GELOGI("Sync var data, model id:%u", model_id_); Status ret = SUCCESS; - OpDescPtr global_step = GetVariableOp(NODE_NAME_GLOBAL_STEP); + OpDescPtr global_step = GetVariableOp(domi::NODE_NAME_GLOBAL_STEP); if (global_step != nullptr) { auto v_output_size = ModelUtils::GetOutputSize(global_step); auto v_output_addr = ModelUtils::GetOutputDataAddrs(runtime_param_, global_step); @@ -2003,6 +1932,140 @@ void DavinciModel::SetProfileTime(ModelProcStage stage, int64_t endTime) { } return; } +/// +/// @ingroup domi_ome +/// @brief copy input data to Model's firat OP. Address already malloced when Load +/// @copy need datatype transfer: FLOAT to FP16, 4D to 5D; +/// @param [in] data data pointer to be copy +/// @return Status result +/// @author +/// +Status DavinciModel::CopyInputDataToModel(const std::vector &data, uint32_t data_op_index, + bool device_data) { + GE_CHK_BOOL_RET_STATUS(!data_op_list_.empty(), PARAM_INVALID, "data_op_list_ is empty!"); + + GE_CHK_BOOL_RET_STATUS(data_op_list_.size() == data.size(), PARAM_INVALID, + "The input data list size (%zu) does not match the model input list size (%zu)", data.size(), + data_op_list_.size()); + + GE_CHK_BOOL_RET_STATUS(data_op_index < data_op_list_.size(), PARAM_INVALID, + "input data op index(%zu) is invalid, exceeds input op size(%zu)", data_op_index, + data_op_list_.size()); + + /// input datatype conversion, converting FLOAT to FP16, 4D to 5D at the same time. + /// Choose respective mode in API parameters. + auto op_def = data_op_list_[data_op_index]; + GE_CHK_BOOL_EXEC(op_def != nullptr, return PARAM_INVALID, "op_def is null!"); + + auto data_index = data_op_index; + if (AttrUtils::GetInt(op_def, "index", data_index)) { + GELOGI("ge_train:get new index %u , old %u", data_index, data_op_index); + } + + GE_CHK_BOOL_EXEC(data_index < data.size(), return PARAM_INVALID, "index:%u >= size:%zu", data_index, data.size()); + GE_CHK_BOOL_RET_STATUS(op_def->GetInputsSize() == 1 && op_def->GetOutputsSize() == 1, PARAM_INVALID, + "Data Op has invalid input_desc_size(%zu) or output_desc_size(%zu)", op_def->GetInputsSize(), + op_def->GetOutputsSize()); + + // float to float16 + bool need_trans_flag = ModelUtils::IsInputTensorNeedTrans(data_op_list_[data_op_index], 0); + + int64_t output_size = 0; + GE_CHK_STATUS(TensorUtils::GetSize(*op_def->GetOutputDescPtr(0), output_size), "get output size failed."); + GE_CHK_BOOL_RET_STATUS(output_size >= data[data_index].length, PARAM_INVALID, + "input data size(%u) does not match model required size(%zu), ret failed.", + data[data_index].length, output_size); + + vector outputs = op_def->GetOutputOffset(); + if (device_data) { + return CopyPlainData(data, data_index, data_op_index, outputs, RT_MEMCPY_DEVICE_TO_DEVICE); + } else if (need_trans_flag) { + return CopyTransData(data, data_index, data_op_index, outputs); + } else { + return CopyPlainData(data, data_index, data_op_index, outputs, RT_MEMCPY_HOST_TO_DEVICE); + } +} + +Status DavinciModel::CopyTransData(const std::vector &data, uint32_t data_index, uint32_t data_op_index, + const std::vector &outputs) { + GE_CHECK_VECTOR_NOT_EMPTY(outputs); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(outputs[0] == -1, return PARAM_INVALID, "output offset is -1"); + GE_CHK_BOOL_EXEC(data_index < data.size(), return PARAM_INVALID, "index:%u >= size:%zu", data_index, data.size()); + + auto input_tensor_desc = data_op_input_tensor_desc_map_[data_op_list_[data_op_index]->GetName()]; + auto output_tensor_desc = data_op_output_tensor_desc_map_[data_op_list_[data_op_index]->GetName()]; + + uint8_t *src_data = reinterpret_cast(data[data_index].data); + + formats::TransResult tmp_result{}; + auto input_shape = input_tensor_desc->GetShape(); + auto src_data_size = input_shape.GetShapeSize(); + auto src_data_type = input_tensor_desc->GetDataType(); + auto dst_data_type = output_tensor_desc->GetDataType(); + GELOGD("Trans data type from %s to %s, input shape %s, data size %zu", + TypeUtils::DataTypeToSerialString(src_data_type).c_str(), + TypeUtils::DataTypeToSerialString(dst_data_type).c_str(), formats::ShapeToString(input_shape).c_str(), + src_data_size); + auto ret = + formats::TransDataType({src_data, static_cast(src_data_size), src_data_type, dst_data_type}, tmp_result); + if (ret != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Failed to trans data type from %s to %s, input shape %s, data size %zu, error code %u", + TypeUtils::DataTypeToSerialString(src_data_type).c_str(), + TypeUtils::DataTypeToSerialString(dst_data_type).c_str(), formats::ShapeToString(input_shape).c_str(), + src_data_size, ret); + return ret; + } + + void *mem_addr = mem_base_ + outputs[0]; + auto rt_ret = rtMemcpy(mem_addr, static_cast(runtime_param_.mem_size - outputs[0]), + reinterpret_cast(tmp_result.data.get()), static_cast(tmp_result.length), + RT_MEMCPY_HOST_TO_DEVICE); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Failed to copy memory to device, size %zu", tmp_result.length); + return RT_FAILED; + } + GELOGI("[IMAS]CopyTransData memcpy graph_%u type[F] name[%s] output[%d] memaddr[%p] datasize[%zu]", + runtime_param_.graph_id, data_op_list_[data_op_index]->GetName().c_str(), 0, mem_addr, tmp_result.length); + return SUCCESS; +} + +Status DavinciModel::CopyPlainData(const std::vector &data, uint32_t data_index, uint32_t data_op_index, + const std::vector &outputs, rtMemcpyKind_t kind) { + GE_CHK_BOOL_EXEC(data_index < data.size(), return PARAM_INVALID, "index:%u >= size:%zu", data_index, data.size()); + bool flag = data[data_index].isDataSupportMemShare && support_mem_shared_flag_; + // if data attr support zero cpy,then update addrs info to flowtable + if (flag) { + GELOGI("No need to copy input data, user's input data buffer can be shared."); + return SUCCESS; + } + + GE_CHECK_VECTOR_NOT_EMPTY(outputs); + // P2P memory space parameters + void *host_data_addr = data[data_index].data; + uint32_t copy_size = data[data_index].length; + GELOGD("data output tensor is aipp tensor,copy data only."); + + void *data_out_addr = nullptr; + if (VarManager::Instance(session_id_)->IsVarAddr(outputs[0])) { + data_out_addr = var_mem_base_ + outputs[0] - runtime_param_.logic_var_base; + GELOGI("[IMAS]CopyPlainData memcpy graph_%u type[V] name[%s] output[%d] memaddr[%p] mem_size[%u] datasize[%u]", + runtime_param_.graph_id, data_op_list_[data_op_index]->GetName().c_str(), 0, data_out_addr, copy_size, + copy_size); + } else { + data_out_addr = mem_base_ + outputs[0]; + GELOGI("output[0]=%ld, copy_size=%u, total_size=%zu", outputs[0], copy_size, TotalMemSize()); + + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(((uint64_t)outputs[0] + (uint64_t)copy_size) > TotalMemSize(), return INTERNAL_ERROR, + "input offset add size is large than total memory."); + GELOGI("[IMAS]CopyPlainData memcpy graph_%u type[F] name[%s] output[%d] memaddr[%p] mem_size[%u] datasize[%u]", + runtime_param_.graph_id, data_op_list_[data_op_index]->GetName().c_str(), 0, data_out_addr, copy_size, + copy_size); + } + + GE_CHK_RT_RET(rtMemcpy(data_out_addr, copy_size, host_data_addr, copy_size, kind)); + + return SUCCESS; +} /// /// @ingroup domi_ome @@ -2019,9 +2082,9 @@ Status DavinciModel::CopyOutputData(uint32_t data_id, OutputData &output_data) { } else { output_data.index = data_id; output_data.model_id = model_id_; - GE_CHK_BOOL_RET_STATUS(output_data.blobs.size() == output_data_info_.size(), INTERNAL_ERROR, + GE_CHK_BOOL_RET_STATUS(output_data.blobs.size() == output_size_list_.size(), INTERNAL_ERROR, "output buffer size[%zu] not equal output_size_list[%zu] size!", output_data.blobs.size(), - output_data_info_.size()); + output_size_list_.size()); // index of data in output_data uint32_t output_data_index = 0; @@ -2192,7 +2255,7 @@ Status DavinciModel::DumpOpInputOutput() { /// Status DavinciModel::DumpSingleOpInputOutput(const OpDescPtr &op_def) { GE_CHK_BOOL_EXEC(nullptr != op_def, return PARAM_INVALID, "op_def is null!"); - string op_name = ge::StringUtils::ReplaceAll(op_def->GetName(), "/", "-"); + string op_name = domi::StringUtils::ReplaceAll(op_def->GetName(), "/", "-"); GELOGI("dump op name:%s, type:%s, model_id: %u.", op_def->GetName().c_str(), op_def->GetType().c_str(), model_id_); string model_path = "./dump" + to_string(model_id_); if (mmAccess(model_path.c_str()) != EN_OK) { @@ -2378,6 +2441,7 @@ void *DavinciModel::Run(DavinciModel *model) { CsaInteract::GetInstance().WriteInternalErrorCode(); GELOGI("Model run end, model id:%u", model->model_id_); + GEEVENT("Model Run thread end, model_id:%u.", model->model_id_); return nullptr; } @@ -2411,8 +2475,8 @@ Status DavinciModel::DestroyThread() { /// @author /// Status DavinciModel::ModelRunStart() { - GE_CHK_BOOL_RET_STATUS((RESET != DavinciModel::GetSysMode()) && (STOP != DavinciModel::GetSysMode()), INTERNAL_ERROR, - "Model Start FAIL in wrong sys mode!"); + GE_CHK_BOOL_RET_STATUS((domi::RESET != DavinciModel::GetSysMode()) && (domi::STOP != DavinciModel::GetSysMode()), + INTERNAL_ERROR, "Model Start FAIL in wrong sys mode!"); GE_CHK_BOOL_RET_STATUS(data_inputer_ != nullptr, INTERNAL_ERROR, "data_inputer_ is nullptr."); @@ -2445,8 +2509,8 @@ Status DavinciModel::ModelRunStart() { /// @author /// Status DavinciModel::ModelRunStop() { - GE_CHK_BOOL_RET_STATUS((DavinciModel::GetSysMode() != RESET) && (DavinciModel::GetSysMode() != STOP), INTERNAL_ERROR, - "Model stop FAIL in wrong sys mode!"); + GE_CHK_BOOL_RET_STATUS((DavinciModel::GetSysMode() != domi::RESET) && (DavinciModel::GetSysMode() != domi::STOP), + INTERNAL_ERROR, "Model stop FAIL in wrong sys mode!"); LockRunFlg(); GE_MAKE_GUARD(tmp_lock, [&] { UnlockRunFlg(); }); @@ -2592,24 +2656,16 @@ Status DavinciModel::DistributeTask() { } } } - AddEndGraphToTaskList(); - // launch dump kernel to aicpu - GE_CHK_STATUS_RET(data_dumper_.LoadDumpInfo(), "Load dump info failed."); - return SUCCESS; -} - -void DavinciModel::AddEndGraphToTaskList() { auto all_dump_model = PropertiesManager::Instance().GetAllDumpModel(); if (all_dump_model.find(ge::DUMP_ALL_MODEL) != all_dump_model.end() || all_dump_model.find(name_) != all_dump_model.end()) { - if (end_graph_id_ != 0xFFFFFFFF && end_graph_op_ != nullptr) { - data_dumper_.SaveDumpTask(task_list_[end_graph_id_]->GetTaskID(), end_graph_op_, 0); - GELOGI("The type of op is %s and the task id is %u", end_graph_op_->GetType().c_str(), - task_list_[end_graph_id_]->GetTaskID()); - } else { - GELOGD("There are no end graph node in the graph"); - } + data_dumper_.SaveDumpTask(task_list_[end_graph_id_]->GetTaskID(), end_graph_op_, 0); + GELOGI("The type of op is %s and the task id is %u", end_graph_op_->GetType().c_str(), + task_list_[end_graph_id_]->GetTaskID()); } + // launch dump kernel to aicpu + GE_CHK_STATUS_RET(data_dumper_.LoadDumpInfo(), "Load dump info failed."); + return SUCCESS; } /// @@ -2710,7 +2766,7 @@ bool DavinciModel::CheckInputAndModelSize(const int64_t &input_size, const int64 } bool is_dynamic_aipp = false; for (const auto &op_desc : data_op_list_) { - if (op_desc->GetType() == AIPP_DATA_TYPE) { + if (op_desc->GetType() == domi::AIPP_DATA_TYPE) { GELOGI("This is dynamic aipp model."); is_dynamic_aipp = true; break; @@ -2738,20 +2794,20 @@ bool DavinciModel::CheckInputAndModelSize(const int64_t &input_size, const int64 /// /// @ingroup ge /// @brief Copy Inputs and Outputs addr to model for direct use. -/// @param [in] const InputData &input_data: model input data. -/// @param [in] OutputData &output_data: model output data. +/// @param [in] const domi::InputData &input_data: model input data. +/// @param [in] domi::OutputData &output_data: model output data. /// @param [in] bool is_dynamic_input: whether is dynamic input, true: is dynamic input; false: not is dynamic input /// @return SUCCESS handle successfully / PARAM_INVALID for failed /// Status DavinciModel::CopyModelData(const InputData &input_data, OutputData &output_data, bool is_dynamic_input) { - if (ZeroCopyBlobs(input_data_info_, input_data.blobs, is_dynamic_input, kInputZeroCopy, input_data.batch_label) != - SUCCESS) { + if (ZeroCopyBlobs(input_addr_list_, input_size_list_, input_data.blobs, is_dynamic_input, kInputZeroCopy, + input_data.batch_label) != SUCCESS) { GELOGE(PARAM_INVALID, "Copy input data to model failed."); return PARAM_INVALID; } - if (ZeroCopyBlobs(output_data_info_, output_data.blobs, is_dynamic_input, kOutputZeroCopy, input_data.batch_label) != - SUCCESS) { + if (ZeroCopyBlobs(output_addr_list_, output_size_list_, output_data.blobs, is_dynamic_input, kOutputZeroCopy, + input_data.batch_label) != SUCCESS) { GELOGE(PARAM_INVALID, "Copy output data to model failed."); return PARAM_INVALID; } @@ -2764,37 +2820,31 @@ Status DavinciModel::CopyModelData(const InputData &input_data, OutputData &outp /// /// @ingroup ge /// @brief Copy Data addr to model for direct use. -/// @param [in] const vstd::map> &data_info: model memory addr/size list. +/// @param [in] const vector &addrs: model input memory addr list. +/// @param [in] const vector &sizes: model input memory size list. /// @param [in] const std::vector &blobs: user input data list. /// @param [in] bool is_dynamic_input: whether is dynamic input, true: is dynamic input; false: not is dynamic input /// @param [in] ZeroCopyMode zero_copy_mode: input zero copy or output zero copy /// @param [in] string batch_label: batch label for multi-batch scenes /// @return SUCCESS handle successfully / others handle failed /// -Status DavinciModel::ZeroCopyBlobs(const std::map> &data_info, +Status DavinciModel::ZeroCopyBlobs(const std::vector &addr_list, const std::vector &size_list, const std::vector &blobs, bool is_dynamic_input, ZeroCopyMode zero_copy_mode, std::string batch_label) { - if (blobs.size() != data_info.size()) { - GELOGE(FAILED, "Blobs not match: blobs=%zu datas=%zu", blobs.size(), data_info.size()); + if ((blobs.size() != addr_list.size()) || (blobs.size() != size_list.size())) { + GELOGE(FAILED, "Blobs not match: blobs=%zu addr=%zu size=%zu", blobs.size(), addr_list.size(), size_list.size()); return FAILED; } - for (const auto &data : data_info) { - if (data.first >= blobs.size()) { - GELOGE(FAILED, "Blobs not match: blobs=%zu, tensor=%zu, index=%u", blobs.size(), data_info.size(), data.first); - return FAILED; - } - int64_t mem_size = data.second.first; - void *mem_addr = data.second.second; - - const DataBuffer &data_buf = blobs[data.first]; + for (size_t idx = 0; idx < size_list.size(); ++idx) { + const DataBuffer &data_buf = blobs[idx]; if (data_buf.data == nullptr) { - GELOGE(FAILED, "data_buf.data is nullptr, index=%u", data.first); + GELOGE(FAILED, "data_buf.data is nullptr, index=%zu", idx); return FAILED; } + GELOGI("Copy Blobs %zu: Input data length is %u, Op data size is %u.", idx, data_buf.length, size_list[idx]); - GELOGI("Copy Blobs %u: Input data length is %u, Op data size is %ld.", data.first, data_buf.length, mem_size); - if (!CheckInputAndModelSize(data_buf.length, mem_size, is_dynamic_input)) { + if (!CheckInputAndModelSize(data_buf.length, size_list[idx], is_dynamic_input)) { GELOGE(FAILED, "Check input size and model size failed"); return FAILED; } @@ -2804,14 +2854,14 @@ Status DavinciModel::ZeroCopyBlobs(const std::map(tensor->MutableData().data()); GE_CHK_BOOL_RET_STATUS(ge::CheckInt64Uint32MulOverflow(elem_num, kBytes) == SUCCESS, FAILED, "Shape size is invalid"); - uint64_t offset = static_cast(elem_num * kBytes); + int64_t offset = elem_num * kBytes; - uint64_t hbm_raw_data_base_addr = - reinterpret_cast(reinterpret_cast(v_output_addr[0])) + offset; + uint64_t hbm_raw_data_base_addr = reinterpret_cast(v_output_addr[0]) + offset; for (int64_t i = elem_num - 1; i >= 0; --i) { buff[i] = hbm_raw_data_base_addr + (buff[i] - buff[0]); } @@ -3126,48 +3162,45 @@ void DavinciModel::CleanTbeHandle() { /// @brief insert active_stream_indication_ /// @return Status /// -Status DavinciModel::InitStreamActive(const OpDescPtr &op_desc) { - if (op_desc->HasAttr(ATTR_NAME_SWITCH_BRANCH_NODE_LABEL)) { +Status DavinciModel::MarkActiveStream(const OpDescPtr &op_desc) { + GE_CHECK_NOTNULL(op_desc); + std::string type = op_desc->GetType(); + GE_IF_BOOL_EXEC( + type == domi::STREAMSWITCH, std::vector active_stream_list; + GE_LOGI_IF(!ge::AttrUtils::GetListInt(op_desc, ATTR_NAME_ACTIVE_STREAM_LIST, active_stream_list), + "GetInt ACTIVE_STREAM_LIST failed."); + if (active_stream_list.size() != kTrueBranchStreamNum) { + GELOGE(INTERNAL_ERROR, "Stream num of switch true branch must be %u.", kTrueBranchStreamNum); + return INTERNAL_ERROR; + } uint32_t true_stream_id = active_stream_list.front(); + active_stream_indication_.insert(true_stream_id); + GELOGI("flowctrl_op_index_map node:%s, true_stream_id=%u.", op_desc->GetName().c_str(), true_stream_id);); + GE_IF_BOOL_EXEC( + type == domi::STREAMACTIVE, if (op_desc->HasAttr(ATTR_NAME_SWITCH_BRANCH_NODE_LABEL)) { + std::vector active_stream_list; + GE_CHK_BOOL_EXEC(AttrUtils::GetListInt(op_desc, ATTR_NAME_ACTIVE_STREAM_LIST, active_stream_list), + return INTERNAL_ERROR, "StreamActiveOp get attr ACTIVE_STREAM failed."); + + for (size_t j = 0; j < active_stream_list.size(); ++j) { + active_stream_indication_.insert(active_stream_list[j]); + GELOGI("flowctrl_op_index_map node:%s, active_stream_id=%u.", op_desc->GetName().c_str(), + active_stream_list[j]); + } + }); + + if (type == domi::STREAMSWITCHN) { std::vector active_stream_list; - GE_CHK_BOOL_EXEC(AttrUtils::GetListInt(op_desc, ATTR_NAME_ACTIVE_STREAM_LIST, active_stream_list), - return INTERNAL_ERROR, "StreamActiveOp get attr ACTIVE_STREAM failed."); + if (!AttrUtils::GetListInt(op_desc, ATTR_NAME_ACTIVE_STREAM_LIST, active_stream_list)) { + GELOGE(INTERNAL_ERROR, "StreamSwitchNOp get attr ACTIVE_STREAM failed."); + return INTERNAL_ERROR; + } for (size_t j = 0; j < active_stream_list.size(); ++j) { active_stream_indication_.insert(active_stream_list[j]); - GELOGI("flowctrl_op_index_map node:%s, active_stream_id=%u.", op_desc->GetName().c_str(), active_stream_list[j]); - } - } - - return SUCCESS; -} - -Status DavinciModel::InitStreamSwitch(const OpDescPtr &op_desc) { - std::vector active_stream_list; - GE_LOGI_IF(!ge::AttrUtils::GetListInt(op_desc, ATTR_NAME_ACTIVE_STREAM_LIST, active_stream_list), - "GetInt ACTIVE_STREAM_LIST failed."); - if (active_stream_list.size() != kTrueBranchStreamNum) { - GELOGE(INTERNAL_ERROR, "Stream num of switch true branch must be %u.", kTrueBranchStreamNum); - return INTERNAL_ERROR; - } - - uint32_t true_stream_id = active_stream_list.front(); - active_stream_indication_.insert(true_stream_id); - GELOGI("flowctrl_op_index_map node:%s, true_stream_id=%u.", op_desc->GetName().c_str(), true_stream_id); - - return SUCCESS; -} - -Status DavinciModel::InitStreamSwitchN(const OpDescPtr &op_desc) { - std::vector active_stream_list; - if (!AttrUtils::GetListInt(op_desc, ATTR_NAME_ACTIVE_STREAM_LIST, active_stream_list)) { - GELOGE(INTERNAL_ERROR, "StreamSwitchNOp get attr ACTIVE_STREAM failed."); - return INTERNAL_ERROR; - } - - for (size_t j = 0; j < active_stream_list.size(); ++j) { - active_stream_indication_.insert(active_stream_list[j]); - GELOGI("StreamSwitchNOp node:%s, active_stream_id=%u.", op_desc->GetName().c_str(), active_stream_list[j]); + GELOGI("StreamSwitchNOp node:%s, active_stream_id=%u.", op_desc->GetName().c_str(), active_stream_list[j]); + }; } + GELOGI("Flow control: active_stream_indication_ size = %zu.", active_stream_indication_.size()); return SUCCESS; } @@ -3179,7 +3212,7 @@ bool DavinciModel::IsBroadCastOpData(const ge::NodePtr &var_node) { GE_RT_FALSE_CHECK_NOTNULL(in_anchor); ge::NodePtr dst_node = in_anchor->GetOwnerNode(); GE_RT_FALSE_CHECK_NOTNULL(dst_node); - if (dst_node->GetType() == HCOMBROADCAST) { + if (dst_node->GetType() == domi::HCOMBROADCAST) { return true; } } @@ -3191,11 +3224,12 @@ bool DavinciModel::IsBroadCastOpData(const ge::NodePtr &var_node) { /// @ingroup domi_ome /// @brief Init model stream for NN model. /// @param [in] stream user input model stream. +/// @param [in] async_mode is asynchronize mode. /// @return Status /// -Status DavinciModel::InitModelStream(rtStream_t stream) { +Status DavinciModel::InitModelStream(rtStream_t stream, bool async_mode) { // asynchronize mode, use user input stream. - if (is_async_mode_) { + if (async_mode) { rt_model_stream_ = stream; is_inner_model_stream_ = false; return SUCCESS; @@ -3230,12 +3264,16 @@ Status DavinciModel::InitModelStream(rtStream_t stream) { /// Status DavinciModel::NnExecute(rtStream_t stream, bool async_mode, const InputData &input_data, OutputData &output_data) { - is_async_mode_ = async_mode; - GELOGI("Model Run begin, model id:%u, data index:%u, flag:%d.", model_id_, input_data.index, is_async_mode_); - GE_CHK_STATUS_RET(InitModelStream(stream), "Init model stream failed."); + GELOGI("Model Run begin, model id:%u, data index:%u, flag:%d.", model_id_, input_data.index, async_mode); + GE_CHK_STATUS(InitModelStream(stream, async_mode), "Init model stream failed."); + + GELOGI("do rtModelExecute task sink, model id:%u", input_data.model_id); + auto enable_dump = false; auto dump_path = PropertiesManager::Instance().GetDumpOutputPath(); - auto enable_dump = !dump_path.empty(); + if (!dump_path.empty()) { + enable_dump = true; + } auto dump_op_env = std::getenv("DUMP_OP"); if (dump_op_env != nullptr) { @@ -3256,9 +3294,9 @@ Status DavinciModel::NnExecute(rtStream_t stream, bool async_mode, const InputDa output_use_zero_copy_ = false; } - // Empty task, Just copy input to output, need direct copy. - if (task_list_.empty() && (input_use_zero_copy_ || output_use_zero_copy_)) { - GELOGE(FAILED, "Empty task, Just copy input to output, need direct copy."); + // Asynchronous mode depends on zero copy. + if (async_mode && !input_use_zero_copy_ && !output_use_zero_copy_ && !task_list_.empty()) { + GELOGE(FAILED, "Asynchronous mode but zero copy disabled."); return FAILED; } @@ -3279,16 +3317,15 @@ Status DavinciModel::NnExecute(rtStream_t stream, bool async_mode, const InputDa GELOGI("rtModelExecute end"); } - if (!is_async_mode_) { - GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingOn(), SetProfileTime(MODEL_AFTER_PROC_START)); - ret = output_use_zero_copy_ ? SyncDataAndDump() : CopyOutputData(input_data.index, output_data); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, return INTERNAL_ERROR, "Copy Output data to user failed."); - GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingOn(), SetProfileTime(MODEL_AFTER_PROC_END)); - } + GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingOn(), SetProfileTime(MODEL_AFTER_PROC_START)); + ret = output_use_zero_copy_ ? SyncDataAndDump() : CopyOutputData(input_data.index, output_data); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, return INTERNAL_ERROR, "Copy Output data to user failed."); + GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingOn(), SetProfileTime(MODEL_AFTER_PROC_END)); // report model time data GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingOn(), (void)SinkTimeProfile(input_data)); GELOGI("Model run end, model id:%u", model_id_); + GEEVENT("Model Run thread end, model_id:%u", model_id_); return SUCCESS; } @@ -3471,9 +3508,9 @@ void DavinciModel::SetDataDumperArgs() { return nullptr; }; - data_dumper_.SetLoopAddr(get_var_addr(GetVariableOp(NODE_NAME_GLOBAL_STEP), runtime_param_), - get_var_addr(GetVariableOp(NODE_NAME_FLOWCTRL_LOOP_PER_ITER), runtime_param_), - get_var_addr(GetVariableOp(NODE_NAME_FLOWCTRL_LOOP_COND), runtime_param_)); + data_dumper_.SetLoopAddr(get_var_addr(GetVariableOp(domi::NODE_NAME_GLOBAL_STEP), runtime_param_), + get_var_addr(GetVariableOp(domi::NODE_NAME_FLOWCTRL_LOOP_PER_ITER), runtime_param_), + get_var_addr(GetVariableOp(domi::NODE_NAME_FLOWCTRL_LOOP_COND), runtime_param_)); GELOGI("SetDataDumperArgs end."); } @@ -3560,7 +3597,7 @@ Status DavinciModel::CopyVarData(ComputeGraphPtr &compute_graph) { string cp_from_node; bool copy_value = false; for (auto &node : compute_graph->GetAllNodes()) { - GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr || node->GetOpDesc()->GetType() != VARIABLE, continue); + GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr || node->GetOpDesc()->GetType() != domi::VARIABLE, continue); GE_IF_BOOL_EXEC(ge::AttrUtils::GetStr(node->GetOpDesc(), "_copy_from_var_node", cp_from_node), GELOGI("Get original type of cp_from_node")); if (cp_from_node.length() != 0) { diff --git a/src/ge/graph/load/new_model_manager/davinci_model.h b/src/ge/graph/load/new_model_manager/davinci_model.h index 76edd4a4..d5a7baf4 100644 --- a/src/ge/graph/load/new_model_manager/davinci_model.h +++ b/src/ge/graph/load/new_model_manager/davinci_model.h @@ -27,7 +27,7 @@ #include "common/ge_types.h" #include "common/helper/model_helper.h" #include "common/helper/om_file_helper.h" -#include "graph/debug/ge_attr_define.h" +#include "common/op/attr_define.h" #include "common/opskernel/ge_task_info.h" #include "common/types.h" #include "framework/common/util.h" @@ -47,6 +47,10 @@ #define WEIGHTS_ADDR_TO_CCE(var) namespace ge { +using domi::CONSTANTOP; +using domi::ENDGRAPH; +using domi::NETOUTPUT; +using domi::VARIABLE; using std::vector; enum ZeroCopyMode { kInputZeroCopy, @@ -146,14 +150,14 @@ class DavinciModel { /// @brief get sys mode /// @return SysMode /// - static SysMode GetSysMode(); + static domi::SysMode GetSysMode(); /// /// @ingroup domi_ome /// @brief set sys mode /// @return Status /// - static Status SetSysMode(SysMode mode); + static Status SetSysMode(domi::SysMode mode); /// /// @ingroup domi_ome @@ -193,7 +197,7 @@ class DavinciModel { vector GetOpDesc() { vector opDescVector; - GE_IF_BOOL_EXEC(ge::AttrUtils::GetListOpDesc(GetGeModel(), MODEL_ATTR_FUSION_MODEL_DEF, opDescVector), + GE_IF_BOOL_EXEC(ge::AttrUtils::GetListOpDesc(GetGeModel(), domi::MODEL_ATTR_FUSION_MODEL_DEF, opDescVector), GELOGI("get opDesc of opDescVector")); return opDescVector; } @@ -340,6 +344,13 @@ class DavinciModel { vector &output_desc, std::vector &inputFormats, std::vector &output_formats); + /// + /// @ingroup domi_ome + /// @brief copy input data to model + /// @return Status + /// + Status CopyInputDataToModel(const std::vector &data, uint32_t data_op_index, bool device_data); + Status ReturnResult(uint32_t data_id, const bool rslt_flg, const bool seq_end_flg, OutputData *output_data); Status ReturnNoOutput(uint32_t data_id); @@ -406,6 +417,20 @@ class DavinciModel { /// uint32_t GetDeviceId() const { return device_id_; } + /// + /// @ingroup domi_ome + /// @brief Set Train Mode + /// @return void + /// + void SetTrainMode(bool mode) { is_train_mode_ = mode; } + + /// + /// @ingroup domi_ome + /// @brief Get Train Mode + /// @return bool true + /// + bool GetTrainMode() { return is_train_mode_; } + GeModelPtr GetGeModel() { return ge_model_; } const RuntimeParam &GetRuntimeParam() { return runtime_param_; } @@ -498,14 +523,15 @@ class DavinciModel { /// /// @ingroup ge /// @brief Copy Data addr to model for direct use. - /// @param [in] const std::map> &data_info: model memory addr/size list. + /// @param [in] const vector &addrs: model input memory addr list. + /// @param [in] const vector &sizes: model input memory size list. /// @param [in] const std::vector &blobs: user input data list. /// @param [in] bool is_dynamic_input: whether is dynamic input, true: is dynamic input; false: not is dynamic input /// @param [in] ZeroCopyMode zero_copy_mode: input zero copy or output zero copy /// @param [in] string batch_label: batch label for multi-batch scenes /// @return SUCCESS handle successfully / others handle failed /// - Status ZeroCopyBlobs(const std::map> &data_info, + Status ZeroCopyBlobs(const std::vector &addr_list, const std::vector &size_list, const std::vector &blobs, bool is_dynamic_input, ZeroCopyMode zero_copy_mode, string batch_label); @@ -573,8 +599,6 @@ class DavinciModel { void UnbindTaskSinkStream(); - void AddEndGraphToTaskList(); - /// /// @ingroup ge /// @brief Travel all nodes and do some init. @@ -588,9 +612,11 @@ class DavinciModel { /// @brief Data Op Initialize. /// @param [in] NodePtr: Data Op. /// @param [in/out] data_op_index: NetOutput addr size info. + /// @param [in/out] input_data_info: Data index and addr info {index, {size, addr}}. /// @return Status /// - Status InitDataOp(const NodePtr &node, uint32_t &data_op_index); + Status InitDataOp(const NodePtr &node, uint32_t &data_op_index, + std::map> &input_data_info); /// /// @ingroup ge @@ -609,27 +635,19 @@ class DavinciModel { Status InitNetOutput(const OpDescPtr &op_desc); /// - /// @ingroup domi_ome - /// @brief Constant Op Init. + /// @ingroup ge + /// @brief Make Input and Output addr for feature use. + /// @param [in] input_data_info: Data index and addr info {index, {size, addr}}. /// @return Status /// - Status InitConstant(const OpDescPtr &op_desc); + Status CombineDataInfo(const std::map> &input_data_info); - Status InitVariable(const OpDescPtr &op_desc); - - Status InitEndGraph(const OpDescPtr &op_desc); - - /// @ingroup ge - /// @brief LabelSet Op Initialize. - /// @param [in] op_desc: LabelSet Op descriptor. + /// + /// @ingroup domi_ome + /// @brief Constant Op Init. /// @return Status - Status InitLabelSet(const OpDescPtr &op_desc); - - Status InitStreamSwitch(const OpDescPtr &op_desc); - - Status InitStreamActive(const OpDescPtr &op_desc); - - Status InitStreamSwitchN(const OpDescPtr &op_desc); + /// + Status InitConstant(const ConstOpDescPtr &op_desc) const; /// /// @ingroup domi_ome @@ -646,7 +664,7 @@ class DavinciModel { /// @brief Init model stream for NN model. /// @return Status /// - Status InitModelStream(rtStream_t stream); + Status InitModelStream(rtStream_t stream, bool async_mode); /// /// @ingroup ge @@ -662,16 +680,12 @@ class DavinciModel { /// Status BindInputQueue(); - Status CpuTaskModelZeroCopy(std::vector &mbuf_list, - std::map> &outside_addrs); - /// /// @ingroup ge /// @brief ACL, Bind NetOutput Op addr to output queue. /// @return: 0 for success / others for fail /// Status BindOutputQueue(); - Status CpuModelPrepareOutput(uintptr_t addr, uint32_t size); /// /// @ingroup ge @@ -680,6 +694,13 @@ class DavinciModel { /// Status BindActiveStream(); + /// + /// @ingroup domi_ome + /// @brief insert active_stream_indication_ + /// @return Status + /// + Status MarkActiveStream(const OpDescPtr &op_desc); + /// /// @ingroup ge /// @brief definiteness queue schedule, bind input queue to task. @@ -688,7 +709,7 @@ class DavinciModel { /// @param [in] size: Data Op output tensor size. /// @return: 0 for success / others for fail /// - Status CpuModelDequeue(uint32_t queue_id); + Status CpuModelDequeue(uint32_t queue_id, uintptr_t addr, uint32_t size); /// /// @ingroup ge @@ -715,8 +736,6 @@ class DavinciModel { /// Status CpuWaitEndGraph(); - Status BindEnqueue(); - Status CpuModelEnqueue(uint32_t queue_id, uintptr_t out_mbuf); /// /// @ingroup ge /// @brief definiteness queue schedule, repeat run model. @@ -766,8 +785,10 @@ class DavinciModel { vector variable_op_list_; - std::map> input_data_info_; // Init by Data Output Tensor - std::map> output_data_info_; // Init by NetOutput Input Tensor + vector output_size_list_; // Init by NetOutput Input Tensor + vector output_addr_list_; // Init by NetOutput Input Tensor + vector input_size_list_; // Init by Data Output Tensor + vector input_addr_list_; // Init by Data Output Tensor // output op: save cce op actual needed memory size vector output_memory_size_list_; @@ -780,7 +801,7 @@ class DavinciModel { std::mutex mux_run_flg_; - static SysMode mode_; + static domi::SysMode mode_; static std::mutex mutex_mode_; @@ -794,7 +815,6 @@ class DavinciModel { vector event_list_; vector label_list_; - set label_id_indication_; std::mutex outside_addrs_mutex_; std::map> input_outside_addrs_; @@ -812,8 +832,6 @@ class DavinciModel { bool is_inner_model_stream_; - bool is_async_mode_; // For NN execute, Async mode use rtMemcpyAsync on rt_model_stream_. - // ACL queue schedule, save queue ids for Init. std::vector cpu_task_list_; std::vector input_queue_ids_; // input queue ids created by caller. @@ -831,6 +849,8 @@ class DavinciModel { uint32_t device_id_; + bool is_train_mode_; + std::mutex flowctrl_op_index_internal_map_mutex_; std::map flowctrl_op_index_internal_map_; std::set active_stream_indication_; diff --git a/src/ge/graph/load/new_model_manager/davinci_model_parser.cc b/src/ge/graph/load/new_model_manager/davinci_model_parser.cc index b744f907..0c5d0073 100644 --- a/src/ge/graph/load/new_model_manager/davinci_model_parser.cc +++ b/src/ge/graph/load/new_model_manager/davinci_model_parser.cc @@ -35,14 +35,14 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelInfoParser(const Mo GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, GE_CHK_RT(rtDeviceReset(0)); return ret, "Parse model failed"); - auto *file_header = reinterpret_cast(model.model_data); + domi::ModelFileHeader *file_header = (domi::ModelFileHeader *)model.model_data; GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(file_header == nullptr, GE_CHK_RT(rtDeviceReset(0)); return PARAM_INVALID, "file_header is null."); model_info.version = file_header->version; model_info.is_encrypt = false; - GE_IF_BOOL_EXEC(ENCRYPTED == file_header->is_encrypt, model_info.is_encrypt = true); + GE_IF_BOOL_EXEC(domi::ENCRYPTED == file_header->is_encrypt, model_info.is_encrypt = true); std::shared_ptr davinci_model = std::shared_ptr(new (std::nothrow) DavinciModel(model.priority, nullptr)); diff --git a/src/ge/graph/load/new_model_manager/model_manager.cc b/src/ge/graph/load/new_model_manager/model_manager.cc index 1b6b30c2..8cf866d0 100644 --- a/src/ge/graph/load/new_model_manager/model_manager.cc +++ b/src/ge/graph/load/new_model_manager/model_manager.cc @@ -302,8 +302,8 @@ Status ModelManager::UnloadModeldef(uint32_t model_id) { Status ModelManager::DataInput(const InputData &input_data, OutputData &output_data) { GELOGI("calling the DataInput"); - SysMode mode = DavinciModel::GetSysMode(); - if ((mode == RESET) || (mode == STOP)) { + domi::SysMode mode = DavinciModel::GetSysMode(); + if ((mode == domi::RESET) || (mode == domi::STOP)) { GELOGE(domi::MODEL_NOT_READY, "System mode is reset or stop"); return domi::MODEL_NOT_READY; } @@ -344,8 +344,8 @@ Status ModelManager::DataInput(const InputData &input_data, OutputData &output_d /// Status ModelManager::DataInputTensor(uint32_t model_id, const std::vector &inputs, std::vector &outputs) { - SysMode mode = DavinciModel::GetSysMode(); - if ((mode == RESET) || (mode == STOP)) { + domi::SysMode mode = DavinciModel::GetSysMode(); + if ((mode == domi::RESET) || (mode == domi::STOP)) { GELOGE(domi::MODEL_NOT_READY, "System mode is reset or stop"); return domi::MODEL_NOT_READY; } @@ -358,17 +358,26 @@ Status ModelManager::DataInputTensor(uint32_t model_id, const std::vectorGetDataList()) { + GE_CHECK_NOTNULL(op); + GE_CHECK_GE(inputs.size(), 1); + GE_CHECK_GE(inputs.size() - 1, index); + DataBuffer data; - data.data = inputs[i].data.data; - data.length = inputs[i].data.length; + data.data = inputs[index].data.data; + data.length = inputs[index].data.length; input_data.blobs.push_back(data); + index++; } + CHECK_FALSE_EXEC(input_data.blobs.size() >= inputs.size(), + GELOGW("cur_inputs size = %zu, inputs size = %zu.", input_data.blobs.size(), inputs.size());); + OutputData output_data; output_data.model_id = model_id; output_data.index = 0; - for (size_t i = 0; i < outputs.size(); ++i) { + for (size_t i = 0; i < outputs.size(); i++) { DataBuffer data; data.data = outputs[i].data.data; data.length = outputs[i].data.length; @@ -463,7 +472,7 @@ Status ModelManager::HandleAclProfilingCommand(const Command &command) { std::string map_key = command.cmd_params[0]; std::string value = command.cmd_params[1]; - if (map_key == PROFILE_CONFIG) { + if (map_key == domi::PROFILE_CONFIG) { ProfilingManager::Instance().SetProfilingConfig(value); } @@ -481,17 +490,18 @@ Status ModelManager::HandleProfileCommand(const Command &command) { GELOGI("Profiling mode, Command key:%s , value:%s ", map_key.c_str(), value.c_str()); - auto iter = PROFILE_COMPONENT_MAP.find(map_key); - if (iter != PROFILE_COMPONENT_MAP.end()) { + auto iter = domi::PROFILE_COMPONENT_MAP.find(map_key); + if (iter != domi::PROFILE_COMPONENT_MAP.end()) { std::string property_value = (value == "on") ? "1" : "0"; PropertiesManager::Instance().SetPropertyValue(iter->second, property_value); } - if ((map_key == PROFILER_JOBCTX || map_key == PROFILER_TARGET_PATH || map_key == RTS_PROFILE_PATH)) { + if ((map_key == domi::PROFILER_JOBCTX || map_key == domi::PROFILER_TARGET_PATH || + map_key == domi::RTS_PROFILE_PATH)) { PropertiesManager::Instance().SetPropertyValue(map_key, value); } - if ((map_key == PROFILE_STOP_KEY) && (value == PROFILE_STOP_VALUE)) { + if ((map_key == domi::PROFILE_STOP_KEY) && (value == domi::PROFILE_STOP_VALUE)) { rtError_t rt_ret = rtProfilerStop(); if (rt_ret != RT_ERROR_NONE) { GELOGE(PARAM_INVALID, "Call rtProfilerStop ret:%d", rt_ret); @@ -512,6 +522,7 @@ Status ModelManager::HandleDumpCommand(const Command &command) { std::string dump_model(DUMP_ALL_MODEL); std::string dump_path("/"); std::set dump_layers; + std::string dump_layer_count; auto iter_dump_status = std::find(command.cmd_params.begin(), command.cmd_params.end(), DUMP_STATUS); if (iter_dump_status != command.cmd_params.end()) { @@ -666,15 +677,6 @@ Status ModelManager::LoadModelOffline(uint32_t &model_id, const ModelData &model break; } davinci_model->SetId(model_id); - - int32_t device_id = 0; - rtError_t rt_ret = rtGetDevice(&device_id); - if (rt_ret != RT_ERROR_NONE || device_id < 0) { - GELOGE(RT_FAILED, "Call rtGetDevice failed, ret = 0x%X, device_id = %d.", rt_ret, device_id); - return FAILED; - } - davinci_model->SetDeviceId(device_id); - ret = davinci_model->Init(dev_ptr, mem_size, weight_ptr, weight_size); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, break, "DavinciInit failed."); @@ -714,7 +716,7 @@ Status ModelManager::LoadModelWithQ(uint32_t &model_id, const ModelData &model_d GE_CHK_BOOL_RET_STATUS(model_data.key.empty() || access(model_data.key.c_str(), F_OK) == 0, PARAM_INVALID, "input key file path is not valid, %s", strerror(errno)); - ModelHelper model_helper; + domi::ModelHelper model_helper; Status ret = model_helper.LoadModel(model_data); if (ret != SUCCESS) { GELOGE(ret, "load model failed."); @@ -805,17 +807,17 @@ Status ModelManager::GetModelMemAndWeightSize(const ModelData &model, size_t &me Status ret = DavinciModelParser::ParseModelContent(model, model_data, model_len); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, return ret, "parse model content failed!"); - OmFileLoadHelper om_file_helper; + domi::OmFileLoadHelper om_file_helper; ret = om_file_helper.Init(model_data, model_len); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, return ret, "om file helperInit failed!"); - auto partition_table = reinterpret_cast(model_data); + auto partition_table = reinterpret_cast(model_data); if (partition_table->num == 1) { GELOGE(FAILED, "om model is error,please use executable om model"); return FAILED; } - ModelPartition task_partition; - if (om_file_helper.GetModelPartition(ModelPartitionType::TASK_INFO, task_partition) != SUCCESS) { + domi::ModelPartition task_partition; + if (om_file_helper.GetModelPartition(domi::ModelPartitionType::TASK_INFO, task_partition) != SUCCESS) { GELOGE(FAILED, "get task model partition failed."); return FAILED; } @@ -825,14 +827,14 @@ Status ModelManager::GetModelMemAndWeightSize(const ModelData &model, size_t &me return FAILED; } if (task_partition.size != 0) { - if (!ReadProtoFromArray(task_partition.data, static_cast(task_partition.size), model_task_def.get())) { + if (!domi::ReadProtoFromArray(task_partition.data, static_cast(task_partition.size), model_task_def.get())) { GELOGE(FAILED, "ReadProtoFromArray failed."); return FAILED; } } - ModelPartition partition_weight; - ret = om_file_helper.GetModelPartition(ModelPartitionType::WEIGHTS_DATA, partition_weight); + domi::ModelPartition partition_weight; + ret = om_file_helper.GetModelPartition(domi::ModelPartitionType::WEIGHTS_DATA, partition_weight); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, return ret, "Get weight partition failed. ret = %u", ret); mem_size = model_task_def->memory_size(); @@ -853,4 +855,5 @@ void ModelManager::GenModelId(uint32_t *id) { free_model_id_.pop_back(); } } + } // namespace ge diff --git a/src/ge/graph/load/new_model_manager/model_manager.h b/src/ge/graph/load/new_model_manager/model_manager.h index fe511c24..7ac4d822 100644 --- a/src/ge/graph/load/new_model_manager/model_manager.h +++ b/src/ge/graph/load/new_model_manager/model_manager.h @@ -23,7 +23,6 @@ #include #include #include -#include #include #include "cce/aicpu_engine_struct.h" #include "common/types.h" diff --git a/src/ge/graph/load/new_model_manager/model_utils.cc b/src/ge/graph/load/new_model_manager/model_utils.cc index dd2d20f6..df11c874 100644 --- a/src/ge/graph/load/new_model_manager/model_utils.cc +++ b/src/ge/graph/load/new_model_manager/model_utils.cc @@ -51,6 +51,27 @@ bool ModelUtils::IsOutput(ConstOpDescPtr op_desc) { return false; } +/// +/// @ingroup domi_ome +/// @brief Check is the Input need trans code. +/// @return bool +/// +bool ModelUtils::IsInputTensorNeedTrans(ConstOpDescPtr op_desc, size_t tensor_index) { + GE_CHECK_NOTNULL_EXEC(op_desc, return false); + const auto &input_desc = op_desc->MutableInputDesc(static_cast(tensor_index)); + const auto &output_desc = op_desc->MutableOutputDesc(static_cast(tensor_index)); + GE_CHECK_NOTNULL_EXEC(input_desc, return false); + GE_CHECK_NOTNULL_EXEC(output_desc, return false); + + if ((output_desc->GetFormat() == FORMAT_NC1HWC0) && (output_desc->GetDataType() == DT_INT8)) { + // AIPP input, add attribute in data op to tag aipp + return false; + } + + return (input_desc->GetFormat() != output_desc->GetFormat()) || + (input_desc->GetDataType() != output_desc->GetDataType()); +} + /// /// @ingroup domi_ome /// @brief Get input size. @@ -64,7 +85,7 @@ vector ModelUtils::GetInputSize(ConstOpDescPtr op_desc) { const vector v_is_input_const = op_desc->GetIsInputConst(); for (size_t i = 0; i < inputs_size; ++i) { - if ((i < v_is_input_const.size()) && v_is_input_const[i] && (op_type != NETOUTPUT)) { + if ((i < v_is_input_const.size()) && v_is_input_const[i] && (op_type != domi::NETOUTPUT)) { // TBE: add weights size to input GE_IF_BOOL_EXEC( true, GeTensorDesc tensor_desc = op_desc->GetInputDesc(i); int64_t tensor_size = 0; @@ -368,7 +389,7 @@ vector ModelUtils::GetInputDataAddrs(const RuntimeParam &model_param, Co return v_input_data_addr; } for (size_t i = 0; i < inputs_size; ++i) { - if ((i < v_is_input_const.size()) && v_is_input_const[i] && (op_type != NETOUTPUT)) { + if ((i < v_is_input_const.size()) && v_is_input_const[i] && (op_type != domi::NETOUTPUT)) { // TBE: add weights address to input GE_IF_BOOL_EXEC( true, GeTensorDesc tensor_desc = op_desc->GetInputDesc(i); int64_t tensor_size = 0; @@ -405,7 +426,7 @@ vector ModelUtils::GetInputDataAddrs(const RuntimeParam &model_param, Co uint8_t *mem_addr = nullptr; // l1 fusion if (has_mem_type_attr && v_memory_type[i] != RT_MEMORY_HBM) { - mem_addr = reinterpret_cast(reinterpret_cast(input_offset)); + mem_addr = reinterpret_cast(input_offset); v_input_data_addr.push_back(mem_addr); } else { mem_addr = static_cast(mem_base + input_offset - logic_mem_base); @@ -473,7 +494,7 @@ vector ModelUtils::GetOutputDataAddrs(const RuntimeParam &model_param, C uint8_t *mem_addr = nullptr; // l1 fusion if (has_mem_type_attr && v_memory_type[i] != RT_MEMORY_HBM) { - mem_addr = reinterpret_cast(reinterpret_cast(v_output_offset[i])); + mem_addr = reinterpret_cast(v_output_offset[i]); v_output_data_addr.push_back(mem_addr); } else { mem_addr = static_cast(mem_base + v_output_offset[i] - logic_mem_base); @@ -518,7 +539,7 @@ vector ModelUtils::GetWorkspaceDataAddrs(const RuntimeParam &model_param if (has_mem_type_attr && v_memory_type[i] != RT_MEMORY_HBM) { v_workspace_data_addr.push_back(reinterpret_cast(v_workspace_offset[i])); GELOGI("L1Fusion: op: %s, GetWorkspaceDataAddrs mem_addr[workspace index %zu]:%p", op_desc->GetName().c_str(), i, - reinterpret_cast(reinterpret_cast(v_workspace_offset[i]))); + reinterpret_cast(v_workspace_offset[i])); } else { int64_t workspace_offset = v_workspace_offset[i]; int64_t workspace_bytes = v_workspace_bytes[i]; diff --git a/src/ge/graph/load/new_model_manager/model_utils.h b/src/ge/graph/load/new_model_manager/model_utils.h index 479cc431..1a15c930 100644 --- a/src/ge/graph/load/new_model_manager/model_utils.h +++ b/src/ge/graph/load/new_model_manager/model_utils.h @@ -40,6 +40,13 @@ class ModelUtils { /// static bool IsOutput(ConstOpDescPtr op_desc); + /// + /// @ingroup domi_ome + /// @brief Check is the Input need trans code. + /// @return bool + /// + static bool IsInputTensorNeedTrans(ConstOpDescPtr op_desc, size_t tensor_index); + /// /// @ingroup domi_ome /// @brief Get input size. diff --git a/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.cc index 3fa5eee2..f65d05dd 100644 --- a/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.cc @@ -67,13 +67,13 @@ Status HcclTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_m GE_CHECK_NOTNULL(op_desc); Status dmrt = HcomOmeUtil::GetHcomDataType(op_desc, data_type); - if (dmrt != SUCCESS) { + if (dmrt != domi::SUCCESS) { GELOGE(FAILED, "davinci_model: GetHcomDataType fail! domi error: %u", dmrt); return FAILED; } - dmrt = HcomOmeUtil::GetHcomCount(op_desc, data_type, (hccl_type == HCOMALLGATHER), count); - if (dmrt != SUCCESS) { + dmrt = HcomOmeUtil::GetHcomCount(op_desc, data_type, (hccl_type == domi::HCOMALLGATHER), count); + if (dmrt != domi::SUCCESS) { GELOGE(FAILED, "davinci_model: GetHcomCount fail! domi error: %u", dmrt); return FAILED; } @@ -170,28 +170,28 @@ Status HcclTaskInfo::SetAddrs(const std::string &hccl_type, const std::shared_pt output_data_addr = output_data_addr_list[0]; } - if (hccl_type == HCOMBROADCAST) { + if (hccl_type == domi::HCOMBROADCAST) { int64_t root_id; dmrt = HcomOmeUtil::GetHcomRootId(op_desc, root_id); - if (dmrt != SUCCESS) { + if (dmrt != domi::SUCCESS) { GELOGE(FAILED, "davinci_model: GetHcomRootId fail! domi error: %u", dmrt); return FAILED; } root_id_ = root_id; - } else if (hccl_type == HCOMALLGATHER || hccl_type == HCOMRECEIVE) { + } else if (hccl_type == domi::HCOMALLGATHER || hccl_type == domi::HCOMRECEIVE) { output_data_addr_ = output_data_addr; - } else if (hccl_type == HCOMALLREDUCE) { + } else if (hccl_type == domi::HCOMALLREDUCE) { dmrt = HcomOmeUtil::GetHcomOperationType(op_desc, op_type); - if (dmrt != SUCCESS) { + if (dmrt != domi::SUCCESS) { GELOGE(FAILED, "davinci_model: GetHcomOperationType fail! domi error: %u", dmrt); return FAILED; } output_data_addr_ = output_data_addr; op_type_ = op_type; - } else if (hccl_type == HCOMREDUCESCATTER) { + } else if (hccl_type == domi::HCOMREDUCESCATTER) { dmrt = HcomOmeUtil::GetHcomOperationType(op_desc, op_type); - if (dmrt != SUCCESS) { + if (dmrt != domi::SUCCESS) { GELOGE(FAILED, "davinci_model: GetHcomOperationType fail! domi error: %u", dmrt); return FAILED; } diff --git a/src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.cc index faaa3f82..11b32be1 100644 --- a/src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.cc @@ -28,6 +28,7 @@ #include "graph/load/new_model_manager/model_manager.h" namespace ge { + Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) { GELOGI("KernelExTaskInfo Init Start."); if (davinci_model == nullptr) { @@ -41,7 +42,6 @@ Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin } auto kernel_ex_def = task_def.kernel_ex(); - const RuntimeParam &rts_param = davinci_model->GetRuntimeParam(); // 1. Copy context from kernelExDef.private to workspace uint32_t op_index = kernel_ex_def.op_index(); @@ -51,12 +51,12 @@ Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin return INTERNAL_ERROR; } - if (CopyTaskInfo(kernel_ex_def, rts_param, op_desc) != SUCCESS) { + if (CopyTaskInfo(kernel_ex_def, davinci_model->GetRuntimeParam(), op_desc) != SUCCESS) { GELOGE(FAILED, "copy task info to workspace failed."); return FAILED; } - const vector workspace_data_addrs = ModelUtils::GetWorkspaceDataAddrs(rts_param, op_desc); + vector workspace_data_addrs = ModelUtils::GetWorkspaceDataAddrs(davinci_model->GetRuntimeParam(), op_desc); if (workspace_data_addrs.empty()) { GELOGE(FAILED, "workspace_data_addrs is empty."); return FAILED; @@ -78,18 +78,18 @@ Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin // 2.1 get loop cond variable for tensor array write uint64_t step_id_addr = 0; - OpDescPtr step_id_node = davinci_model->GetVariableOp(NODE_NAME_GLOBAL_STEP); + OpDescPtr step_id_node = davinci_model->GetVariableOp(domi::NODE_NAME_GLOBAL_STEP); if (step_id_node != nullptr) { - vector v_step_id_addr = ModelUtils::GetOutputDataAddrs(rts_param, step_id_node); + vector v_step_id_addr = ModelUtils::GetOutputDataAddrs(davinci_model->GetRuntimeParam(), step_id_node); if (!v_step_id_addr.empty()) { step_id_addr = static_cast(reinterpret_cast(v_step_id_addr[0])); } } // 3. Set workspaceaddr, inputOutputDataAddr - uint64_t workspace_base_addr = reinterpret_cast(reinterpret_cast(workspace_data_addrs[0])); - const vector input_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc); - const vector output_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc); + uint64_t workspace_base_addr = reinterpret_cast(workspace_data_addrs[0]); + vector input_addrs = ModelUtils::GetInputDataAddrs(davinci_model->GetRuntimeParam(), op_desc); + vector output_addrs = ModelUtils::GetOutputDataAddrs(davinci_model->GetRuntimeParam(), op_desc); vector io_addrs; io_addrs.insert(io_addrs.end(), input_addrs.begin(), input_addrs.end()); io_addrs.insert(io_addrs.end(), output_addrs.begin(), output_addrs.end()); @@ -133,13 +133,7 @@ Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin rt_ret = rtMemcpy(kernel_buf_, sizeof(STR_FWK_OP_KERNEL), static_cast(&fwk_op_kernel), sizeof(STR_FWK_OP_KERNEL), RT_MEMCPY_HOST_TO_DEVICE); GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMemcpy error, ret: Ox%X", rt_ret); return FAILED;) - - vector virtual_io_addrs; // use virtual address for zero copy key. - const vector virtual_in_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc, false); - const vector virtual_out_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc, false); - virtual_io_addrs.insert(virtual_io_addrs.end(), virtual_in_addrs.begin(), virtual_in_addrs.end()); - virtual_io_addrs.insert(virtual_io_addrs.end(), virtual_out_addrs.begin(), virtual_out_addrs.end()); - davinci_model->SetZeroCopyAddr(op_desc, virtual_io_addrs, input_output_addr_); + davinci_model->SetZeroCopyAddr(op_desc, io_addrs, input_output_addr_); kernel_buf_size_ = sizeof(STR_FWK_OP_KERNEL); davinci_model_ = davinci_model; diff --git a/src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.h b/src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.h index a6419f9f..9aab55e7 100644 --- a/src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.h +++ b/src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.h @@ -25,7 +25,6 @@ class KernelExTaskInfo : public TaskInfo { public: KernelExTaskInfo() : task_id_(0), - stream_id_(0), dump_flag_(RT_KERNEL_DEFAULT), kernel_buf_size_(0), davinci_model_(nullptr), diff --git a/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.cc index 47956cf2..84710e41 100644 --- a/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.cc @@ -221,13 +221,13 @@ Status KernelTaskInfo::SuperKernelLaunch() { return RT_FAILED; } // Call the fuse API - skt::SuperKernel *superKernel = nullptr; + skt::SuperKernel *superKernel; if (factory->FuseKernels(skt_kernel_list, skt_arg_list, skt_info_.last_block_dim, superKernel) != SUCCESS) { GELOGE(RT_FAILED, "SuperKernelLaunch: fuse call failed"); return RT_FAILED; } // Launch a super kernel - if (superKernel->Launch(skt_info_.last_stream, RT_KERNEL_DUMPFLAG) != SUCCESS) { + if (superKernel->Launch(skt_info_.last_stream, true) != SUCCESS) { GELOGE(RT_FAILED, "SuperKernelLaunch: launch failed"); return RT_FAILED; } @@ -341,7 +341,6 @@ Status KernelTaskInfo::Distribute() { rtError_t rt_ret = RT_ERROR_NONE; char *skt_enable_env = getenv("SKT_ENABLE"); int64_t env_flag = (skt_enable_env != nullptr) ? strtol(skt_enable_env, nullptr, 10) : 0; - bool call_skt = ((env_flag != 0) || is_l1_fusion_enable_); if (kernel_type_ == cce::ccKernelType::AI_CPU) { // blockDim is reserved parameter, set to 1 rt_ret = rtCpuKernelLaunchWithFlag(reinterpret_cast(so_name_.c_str()), @@ -349,10 +348,11 @@ Status KernelTaskInfo::Distribute() { nullptr, stream_, dump_flag_); } else { /* default: not skt launch */ + bool call_skt = ((env_flag != 0) || is_l1_fusion_enable_); GELOGI( - "KernelTaskInfo Distribute Start, sktenable:%d taskid:%u sktid:%u last_sktid:%u stubfunc_name:%s " + "KernelTaskInfo Distribute Start, sktenable:%ld taskid:%u sktid:%u last_sktid:%u stubfunc_name:%s " "stubfunc:%p blockdim:%u stream:%p", - call_skt, task_id_, skt_id_, skt_info_.last_task_id, stub_func_name_.c_str(), stub_func_, block_dim_, stream_); + env_flag, task_id_, skt_id_, skt_info_.last_task_id, stub_func_name_.c_str(), stub_func_, block_dim_, stream_); // l1 fusion enable and env flag open (kCloseSkt for skt debug) if (call_skt && (env_flag != kCloseSkt)) { GE_RETURN_IF_ERROR(SuperKernelDistribute()); @@ -371,7 +371,7 @@ Status KernelTaskInfo::Distribute() { GELOGI( "KernelTaskInfo Distribute Success. sktenable:%d taskid:%d sktid:%d stubfunc_name:%s stubfunc:%p " "blockdim:%d stream:%p", - call_skt, task_id_, skt_id_, stub_func_name_.c_str(), stub_func_, block_dim_, stream_); + env_flag, task_id_, skt_id_, stub_func_name_.c_str(), stub_func_, block_dim_, stream_); return SUCCESS; } @@ -423,12 +423,12 @@ Status KernelTaskInfo::InitTVMTask(DavinciModel *davinci_model, uint16_t offset, stub_func_ = const_cast(bin_file_key); } - const RuntimeParam &rts_param = davinci_model->GetRuntimeParam(); - const vector input_data_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc); - const vector output_data_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc); - const vector workspace_data_addrs = ModelUtils::GetWorkspaceDataAddrs(rts_param, op_desc); - + const vector input_data_addrs = ModelUtils::GetInputDataAddrs(davinci_model->GetRuntimeParam(), op_desc); + const vector output_data_addrs = ModelUtils::GetOutputDataAddrs(davinci_model->GetRuntimeParam(), op_desc); + const vector workspace_data_addrs = + ModelUtils::GetWorkspaceDataAddrs(davinci_model->GetRuntimeParam(), op_desc); vector tensor_device_addrs; + tensor_device_addrs.insert(tensor_device_addrs.end(), input_data_addrs.begin(), input_data_addrs.end()); tensor_device_addrs.insert(tensor_device_addrs.end(), output_data_addrs.begin(), output_data_addrs.end()); tensor_device_addrs.insert(tensor_device_addrs.end(), workspace_data_addrs.begin(), workspace_data_addrs.end()); @@ -468,13 +468,7 @@ Status KernelTaskInfo::InitTVMTask(DavinciModel *davinci_model, uint16_t offset, reinterpret_cast(reinterpret_cast(args_) + offset + sizeof(void *) * input_data_addrs.size()); } - vector virtual_io_addrs; // use virtual address for zero copy key. - const vector virtual_in_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc, false); - const vector virtual_out_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc, false); - virtual_io_addrs.insert(virtual_io_addrs.end(), virtual_in_addrs.begin(), virtual_in_addrs.end()); - virtual_io_addrs.insert(virtual_io_addrs.end(), virtual_out_addrs.begin(), virtual_out_addrs.end()); - davinci_model_->SetZeroCopyAddr(op_desc, virtual_io_addrs, static_cast(args_) + offset); - + davinci_model_->SetZeroCopyAddr(op_desc, tensor_device_addrs, static_cast(args_) + offset); // update origin l2 data string sm_desc = kernel_def.sm_desc(); char *sm_contrl = nullptr; @@ -522,7 +516,6 @@ Status KernelTaskInfo::InitAICPUCustomTask(const std::mapsecond; - const RuntimeParam &rts_param = davinci_model_->GetRuntimeParam(); const domi::KernelContext &context = kernel_def.context(); const uint32_t kCustomAicpuArgsLen = 5; @@ -541,8 +534,11 @@ Status KernelTaskInfo::InitAICPUCustomTask(const std::map(const_cast(context.args_offset().data())))[i]; } - const std::vector input_data_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc); - const std::vector output_data_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc); + const std::vector input_data_addrs = + ModelUtils::GetInputDataAddrs(davinci_model_->GetRuntimeParam(), op_desc); + const std::vector output_data_addrs = + ModelUtils::GetOutputDataAddrs(davinci_model_->GetRuntimeParam(), op_desc); + Status ret = StoreInputOutputTensor(input_data_addrs, output_data_addrs, ModelUtils::GetInputDescs(op_desc), ModelUtils::GetOutputDescs(op_desc)); @@ -553,7 +549,7 @@ Status KernelTaskInfo::InitAICPUCustomTask(const std::map(args + ctx_.argsOffset[0])) = - reinterpret_cast(reinterpret_cast(custom_info_.input_descs)); // arg 0 + reinterpret_cast(custom_info_.input_descs); // arg 0 *(reinterpret_cast(args + ctx_.argsOffset[1])) = - reinterpret_cast(reinterpret_cast(custom_info_.input_addrs)); // arg 1 + reinterpret_cast(custom_info_.input_addrs); // arg 1 *(reinterpret_cast(args + ctx_.argsOffset[2])) = - reinterpret_cast(reinterpret_cast(custom_info_.output_descs)); // arg 2 + reinterpret_cast(custom_info_.output_descs); // arg 2 *(reinterpret_cast(args + ctx_.argsOffset[3])) = - reinterpret_cast(reinterpret_cast(custom_info_.output_addrs)); // arg 3 + reinterpret_cast(custom_info_.output_addrs); // arg 3 *(reinterpret_cast(args + ctx_.argsOffset[4])) = - reinterpret_cast(reinterpret_cast(custom_info_.attr_handle)); // arg 4 + reinterpret_cast(custom_info_.attr_handle); // arg 4 rt_ret = rtMalloc(&args_, args_size_, RT_MEMORY_HBM); if (rt_ret != RT_ERROR_NONE) { @@ -610,10 +606,8 @@ Status KernelTaskInfo::InitAICPUCustomTask(const std::map virtual_in_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc, false); - const vector virtual_out_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc, false); - davinci_model_->SetZeroCopyAddr(op_desc, virtual_in_addrs, custom_info_.input_addrs); - davinci_model_->SetZeroCopyAddr(op_desc, virtual_out_addrs, custom_info_.output_addrs); + davinci_model_->SetZeroCopyAddr(op_desc, input_data_addrs, custom_info_.input_addrs); + davinci_model_->SetZeroCopyAddr(op_desc, output_data_addrs, custom_info_.output_addrs); return SUCCESS; } @@ -720,10 +714,8 @@ Status KernelTaskInfo::InitAicpuTask(const std::map &op_lis } OpDescPtr op_desc = iter->second; - const RuntimeParam &rts_param = davinci_model_->GetRuntimeParam(); - - vector input_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc); - vector output_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc); + vector input_addrs = ModelUtils::GetInputDataAddrs(davinci_model_->GetRuntimeParam(), op_desc); + vector output_addrs = ModelUtils::GetOutputDataAddrs(davinci_model_->GetRuntimeParam(), op_desc); vector io_addrs; io_addrs.insert(io_addrs.end(), input_addrs.begin(), input_addrs.end()); io_addrs.insert(io_addrs.end(), output_addrs.begin(), output_addrs.end()); @@ -760,13 +752,7 @@ Status KernelTaskInfo::InitAicpuTask(const std::map &op_lis sizeof(void *) * input_addrs.size()); } - vector virtual_io_addrs; // use virtual address for zero copy key. - const vector virtual_in_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc, false); - const vector virtual_out_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc, false); - virtual_io_addrs.insert(virtual_io_addrs.end(), virtual_in_addrs.begin(), virtual_in_addrs.end()); - virtual_io_addrs.insert(virtual_io_addrs.end(), virtual_out_addrs.begin(), virtual_out_addrs.end()); - davinci_model_->SetZeroCopyAddr(op_desc, virtual_io_addrs, - static_cast(args_) + sizeof(aicpu::AicpuParamHead)); + davinci_model_->SetZeroCopyAddr(op_desc, io_addrs, static_cast(args_) + sizeof(aicpu::AicpuParamHead)); return SUCCESS; } @@ -918,7 +904,7 @@ Status KernelTaskInfo::CceUpdateKernelArgs(const domi::KernelContext &context, u std::string file_name = "libcce.so"; std::string path = PluginManager::GetPath(); path.append(file_name); - string canonicalPath = RealPath(path.c_str()); + string canonicalPath = domi::RealPath(path.c_str()); if (canonicalPath.empty()) { GELOGW("failed to get realpath of %s", path.c_str()); return FAILED; @@ -991,7 +977,7 @@ Status KernelTaskInfo::SetFlowtable(std::string &flowtable, const domi::KernelDe *(reinterpret_cast( args + (reinterpret_cast(const_cast(context.args_offset().data())))[0])) = - reinterpret_cast(reinterpret_cast(flowtable_)); + reinterpret_cast(flowtable_); } return SUCCESS; } diff --git a/src/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc deleted file mode 100644 index 9c5e4c29..00000000 --- a/src/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc +++ /dev/null @@ -1,149 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.h" - -#include "framework/common/debug/ge_log.h" -#include "graph/load/new_model_manager/davinci_model.h" - -namespace ge { -Status MemcpyAddrAsyncTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) { - GELOGI("MemcpyAddrAsyncTaskInfo Init Start."); - if (davinci_model == nullptr) { - GELOGE(PARAM_INVALID, "davinci_model is null!"); - return PARAM_INVALID; - } - - Status ret = SetStream(task_def.stream_id(), davinci_model->GetStreamList()); - if (ret != SUCCESS) { - return ret; - } - - auto memcpy_async_def = task_def.memcpy_async(); - - uint64_t logic_dst = memcpy_async_def.dst(); - uint64_t logic_src = memcpy_async_def.src(); - - dst_max_ = memcpy_async_def.dst_max(); - - uint64_t update_base_addr = 0; - ret = GetUpdateBaseAddr(davinci_model, logic_src, update_base_addr); - if (ret != SUCCESS) { - return ret; - } - src_ = reinterpret_cast(update_base_addr + logic_src); - if (src_ == nullptr) { - GELOGE(PARAM_INVALID, "src_ is null!"); - return PARAM_INVALID; - } - - uint64_t mem_base = reinterpret_cast(davinci_model->MemBase()); - uint64_t logic_mem_base = davinci_model->GetRtBaseAddr(); - dst_ = reinterpret_cast(mem_base + (logic_dst - logic_mem_base)); - if (dst_ == nullptr) { - GELOGE(PARAM_INVALID, "dst_ is null!"); - return PARAM_INVALID; - } - - count_ = memcpy_async_def.count(); - kind_ = memcpy_async_def.kind(); - - // malloc args memory - size_t args_size = sizeof(void *); - rtError_t rt_ret = rtMalloc(&args_, args_size * 2, RT_MEMORY_HBM); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return RT_FAILED; - } - - // copy orign src - GELOGI("src_args:%p, destMax:%zu, src_:%p, count=%zu, kind=%u", args_, args_size, src_, args_size, - RT_MEMCPY_HOST_TO_DEVICE); - rt_ret = rtMemcpy(args_, args_size, &src_, args_size, RT_MEMCPY_HOST_TO_DEVICE); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt api for src failed, ret: 0x%X", rt_ret); - return RT_FAILED; - } - - // copy orign dst - GELOGI("dst_args:%p, destMax:%zu, dst_:%p, count=%zu, kind=%u", - reinterpret_cast(reinterpret_cast(args_) + args_size), args_size, dst_, args_size, - RT_MEMCPY_HOST_TO_DEVICE); - rt_ret = rtMemcpy(reinterpret_cast(reinterpret_cast(args_) + args_size), args_size, &dst_, - args_size, RT_MEMCPY_HOST_TO_DEVICE); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt api for dst failed, ret: 0x%X", rt_ret); - return RT_FAILED; - } - - GELOGI("InitMemcpyAddrAsyncTaskInfo, logic_src:%p, logic_dst:%p, src:%p, dst:%p, src_args:%p, dst_args:%p", - reinterpret_cast(reinterpret_cast(logic_src)), - reinterpret_cast(reinterpret_cast(logic_dst)), src_, dst_, args_, - reinterpret_cast(reinterpret_cast(args_) + args_size)); - - return SUCCESS; -} - -Status MemcpyAddrAsyncTaskInfo::Distribute() { - GELOGI("MemcpyAddrAsyncTaskInfo Distribute Start."); - GELOGI("Distribute MemcpyAddrAsync, dst_max:%lu, count:%lu, kind:%u.", dst_max_, count_, kind_); - - rtError_t rt_ret = rtMemcpyAsync(reinterpret_cast(reinterpret_cast(args_) + sizeof(void *)), - dst_max_, args_, count_, static_cast(kind_), stream_); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return RT_FAILED; - } - - return SUCCESS; -} - -Status MemcpyAddrAsyncTaskInfo::GetUpdateBaseAddr(DavinciModel *davinci_model, uint64_t update_addr, - uint64_t &base_addr) { - GE_CHECK_NOTNULL(davinci_model); - uint64_t data_base_addr = - reinterpret_cast(reinterpret_cast(davinci_model->MemBase())) - davinci_model->GetRtBaseAddr(); - uint64_t weight_base_addr = reinterpret_cast(reinterpret_cast(davinci_model->WeightsMemBase())) - - davinci_model->GetRtWeightAddr(); - uint64_t var_base_addr = reinterpret_cast(reinterpret_cast(davinci_model->VarMemBase())) - - davinci_model->GetRtVarAddr(); - - uint64_t data_base_addr_start = davinci_model->GetRtBaseAddr(); - uint64_t data_base_addr_end = davinci_model->GetRtBaseAddr() + davinci_model->TotalMemSize(); - uint64_t wight_base_addr_start = davinci_model->GetRtWeightAddr(); - uint64_t wight_base_addr_end = davinci_model->GetRtWeightAddr() + davinci_model->TotalWeightsMemSize(); - uint64_t varible_base_addr_start = davinci_model->GetRtVarAddr(); - uint64_t varible_base_addr_end = davinci_model->GetRtVarAddr() + davinci_model->TotalVarMemSize(); - - if ((data_base_addr_start <= update_addr) && (update_addr <= data_base_addr_end)) { - base_addr = data_base_addr; - GELOGI("The update_addr is data address."); - } else if ((wight_base_addr_start <= update_addr) && (update_addr <= wight_base_addr_end)) { - base_addr = weight_base_addr; - GELOGI("The update_addr is weight address."); - } else if ((varible_base_addr_start <= update_addr) && (update_addr <= varible_base_addr_end)) { - base_addr = var_base_addr; - GELOGI("The update_addr is variable address."); - } else if (update_addr != 0) { - base_addr = 0; - GELOGE(PARAM_INVALID, "The update_addr is abnormal."); - return PARAM_INVALID; - } - return SUCCESS; -} - -REGISTER_TASK_INFO(RT_MODEL_TASK_MEMCPY_ADDR_ASYNC, MemcpyAddrAsyncTaskInfo); -} // namespace ge diff --git a/src/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.h b/src/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.h deleted file mode 100644 index 9252e43a..00000000 --- a/src/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.h +++ /dev/null @@ -1,55 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_MEMCPY_ADDR_ASYNC_TASK_INFO_H_ -#define GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_MEMCPY_ADDR_ASYNC_TASK_INFO_H_ -#include "graph/load/new_model_manager/task_info/task_info.h" - -namespace ge { -class MemcpyAddrAsyncTaskInfo : public TaskInfo { - public: - MemcpyAddrAsyncTaskInfo() : dst_(nullptr), dst_max_(0), src_(nullptr), args_(nullptr), count_(0), kind_(0) {} - - ~MemcpyAddrAsyncTaskInfo() override { - src_ = nullptr; - dst_ = nullptr; - - if (args_ != nullptr) { - rtError_t ret = rtFree(args_); - if (ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", ret); - } - } - - args_ = nullptr; - } - - Status Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; - - Status Distribute() override; - - private: - Status GetUpdateBaseAddr(DavinciModel *davinci_model, uint64_t update_addr, uint64_t &base_addr); - - void *dst_; - uint64_t dst_max_; - void *src_; - void *args_; - uint64_t count_; - uint32_t kind_; -}; -} // namespace ge -#endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_MEMCPY_ADDR_ASYNC_TASK_INFO_H_ diff --git a/src/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.cc index c783c718..f2621c52 100644 --- a/src/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.cc @@ -52,8 +52,7 @@ Status MemcpyAsyncTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *da count_ = memcpy_async_def.count(); kind_ = memcpy_async_def.kind(); GELOGI("MemcpyAsyncTaskInfo Init Success, logic_src:%p, logic_dst:%p, src:%p, dst:%p", - reinterpret_cast(reinterpret_cast(logic_src)), - reinterpret_cast(reinterpret_cast(logic_dst)), src_, dst_); + reinterpret_cast(logic_src), reinterpret_cast(logic_dst), src_, dst_); return SUCCESS; } diff --git a/src/ge/graph/load/new_model_manager/task_info/stream_switch_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/stream_switch_task_info.cc index a54bf012..4e37ab64 100644 --- a/src/ge/graph/load/new_model_manager/task_info/stream_switch_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/stream_switch_task_info.cc @@ -47,7 +47,7 @@ Status StreamSwitchTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *d auto op_desc = davinci_model->GetOpList()[op_index]; GE_CHECK_NOTNULL(op_desc); auto input_data_addr = ModelUtils::GetInputDataAddrs(davinci_model->GetRuntimeParam(), op_desc); - if (!input_data_addr.empty() && input_data_addr.size() >= STREAM_SWITCH_INPUT_NUM) { + if (!input_data_addr.empty() && input_data_addr.size() >= domi::STREAM_SWITCH_INPUT_NUM) { input_ptr_ = input_data_addr[0]; value_ptr_ = input_data_addr[1]; } @@ -60,9 +60,9 @@ Status StreamSwitchTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *d cond_ = static_cast(cond); size_t input_size = op_desc->GetInputsSize(); - if (input_data_addr.size() != STREAM_SWITCH_INPUT_NUM || input_size != STREAM_SWITCH_INPUT_NUM) { - GELOGE(INTERNAL_ERROR, "Input num should be %u. inputAddr size:%zu, inputDesc size:%zu.", STREAM_SWITCH_INPUT_NUM, - input_data_addr.size(), input_size); + if (input_data_addr.size() != domi::STREAM_SWITCH_INPUT_NUM || input_size != domi::STREAM_SWITCH_INPUT_NUM) { + GELOGE(INTERNAL_ERROR, "Input num should be %u inputAddr size:%zu, inputDesc size:%zu.", + domi::STREAM_SWITCH_INPUT_NUM, input_data_addr.size(), input_size); return INTERNAL_ERROR; } diff --git a/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc b/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc index b8fc77ac..38dbd8b3 100644 --- a/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc +++ b/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc @@ -19,17 +19,17 @@ namespace ge { namespace skt { -Status SuperKernel::Launch(rtStream_t stream, uint32_t dump_flag) { +Status SuperKernel::Launch(rtStream_t stream, bool dump_flag) { const void *func_stub_ = this->GetFuncStub(); - const void *args[] = {this->GetNavTablePtr(), - reinterpret_cast(reinterpret_cast(this->GetNavTableSize()))}; + const void *args[] = {this->GetNavTablePtr(), (const void *)this->GetNavTableSize()}; - rtError_t rt_ret = rtMalloc((void **)&(device_args_addr_), sizeof(args), RT_MEMORY_HBM); + void *device_args_addr = nullptr; + rtError_t rt_ret = rtMalloc((void **)&(device_args_addr), sizeof(args), RT_MEMORY_HBM); GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMalloc failied. error: 0x%X", rt_ret); return FAILED;) - rt_ret = rtMemcpy((void *)device_args_addr_, sizeof(args), (void *)args, sizeof(args), RT_MEMCPY_HOST_TO_DEVICE); + rt_ret = rtMemcpy((void *)device_args_addr, sizeof(args), (void *)args, sizeof(args), RT_MEMCPY_HOST_TO_DEVICE); GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMemcpy failied. error: 0x%X", rt_ret); return FAILED;) - rt_ret = rtKernelLaunchWithFlag((void *const)func_stub_, block_dim_, device_args_addr_, sizeof(args), NULL, stream, + rt_ret = rtKernelLaunchWithFlag((void *const)func_stub_, block_dim_, device_args_addr, sizeof(args), NULL, stream, dump_flag); GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtKernelLaunchWithFlag failied. error: 0x%X", rt_ret); return FAILED;) diff --git a/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.h b/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.h index 1c31acd1..b662d97b 100644 --- a/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.h +++ b/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.h @@ -25,7 +25,6 @@ namespace ge { namespace skt { class SuperKernel { private: - void *device_args_addr_ = nullptr; const void *func_stub_; void *dev_nav_table_; uint64_t nav_table_size_; @@ -34,18 +33,8 @@ class SuperKernel { public: SuperKernel(const void *stub, void *ptr, uint64_t sz, uint32_t dim) : func_stub_(stub), dev_nav_table_(ptr), nav_table_size_(sz), block_dim_(dim) {} - ~SuperKernel() { - // free memory when all releasing - if (device_args_addr_ != nullptr) { - GE_CHK_RT(rtFree(device_args_addr_)); - GELOGI("SKT: super_kernel args addr free."); - } - if (dev_nav_table_ != nullptr) { - GE_CHK_RT(rtFree(dev_nav_table_)); - GELOGI("SKT: super_kernel args addr free."); - } - } - Status Launch(rtStream_t stream, uint32_t dump_flag); + ~SuperKernel() {} + Status Launch(rtStream_t stream, bool dump_flag); const void *GetFuncStub() const { return func_stub_; } const void *GetNavTablePtr() const { return dev_nav_table_; } uint64_t GetNavTableSize() const { return nav_table_size_; } diff --git a/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc b/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc index 63107f5e..ab3f68f1 100644 --- a/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc +++ b/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc @@ -30,26 +30,26 @@ Status SuperKernelFactory::Init() { rt_ret = rtGetFunctionByName(this->sk_stub_name_.c_str(), &this->func_stub_); GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtGetFunctionByName " - "failed. stub_func: %s", + "failied. stub_func: %s", this->sk_stub_name_.c_str()); return FAILED;) rt_ret = rtGetAddrByFun(this->func_stub_, &this->func_ptr_); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtGetAddrByFun failed. error: 0x%X", rt_ret); + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtGetAddrByFun failied. error: 0x%X", rt_ret); return FAILED;) if (this->use_physical_address_ != nullptr) { void *skt_func = nullptr; rt_ret = rtKernelConfigTransArg(this->func_ptr_, sizeof(uint64_t), 0, &skt_func); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtKernelConfigTransArg failed. error: 0x%X", rt_ret); + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtKernelConfigTransArg failied. error: 0x%X", rt_ret); return FAILED;) GELOGD( "SKT: fuseKernels super_kernel_template subFunc %p, device func " "address %p, device physic PC %p", - this->func_stub_, this->func_ptr_, skt_func); + (uint64_t)this->func_stub_, (uint64_t)this->func_ptr_, (uint64_t)skt_func); } else { GELOGD( "SKT: fuseKernels super_kernel_template subFunc %p, device func " "address %p", - this->func_stub_, this->func_ptr_); + (uint64_t)this->func_stub_, (uint64_t)this->func_ptr_); } } is_init_ = true; @@ -94,66 +94,63 @@ Status SuperKernelFactory::FuseKernels(const std::vector &stub_func_list uint64_t nav_table_size = 2 * stub_func_list.size() * sizeof(int64_t); rtError_t rt_ret; - void *hbm_nav_table_addr = nullptr; if (this->use_physical_address_ != nullptr) { for (unsigned i = 0; i < stub_func_list.size(); i++) { void *sub_device_func = nullptr; rt_ret = rtGetAddrByFun(stub_func_list[i], &sub_device_func); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtGetAddrByFun failed. error: 0x%X", rt_ret); + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtGetAddrByFun failied. error: 0x%X", rt_ret); return FAILED;) void *sub_device_func_pys = nullptr; void *args_addr_pys = nullptr; rt_ret = rtKernelConfigTransArg(sub_device_func, sizeof(uint64_t), 0, &sub_device_func_pys); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtKernelConfigTransArg failed. error: 0x%X", rt_ret); + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtKernelConfigTransArg failied. error: 0x%X", rt_ret); return FAILED;) rt_ret = rtKernelConfigTransArg(args_addr_list[i], sizeof(uint64_t), 0, &args_addr_pys); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtKernelConfigTransArg failed. error: 0x%X", rt_ret); + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtKernelConfigTransArg failied. error: 0x%X", rt_ret); return FAILED;) GELOGD( "SKT: fuseKernels subFunc %p, device func address %p, device " "physic func address %p", - stub_func_list[i], sub_device_func, sub_device_func_pys); - // store two uint64_t address - // address divided by 4 because of 32bits encoding, call offset will *4 when calculating - nav_table[i * 2] = reinterpret_cast(reinterpret_cast(sub_device_func_pys)) / 4; - GELOGD("SKT: CALL offset %p", nav_table[i * 2]); - nav_table[i * 2 + 1] = reinterpret_cast(reinterpret_cast(args_addr_pys)); - + stub_func_list[i], (uint64_t)sub_device_func, (uint64_t)sub_device_func_pys); + nav_table[i * 2] = (uint64_t)sub_device_func_pys / 4; + GELOGD("SKT: CALL offet %p", nav_table[i * 2]); + nav_table[i * 2 + 1] = (uint64_t)args_addr_pys; GELOGD("SKT: fuseKernels args base address %p", nav_table[i * 2 + 1]); } + void *hbm_nav_table_addr = nullptr; void *hbm_nav_table_addr_pys = nullptr; rt_ret = rtMalloc((void **)&hbm_nav_table_addr, nav_table_size, RT_MEMORY_HBM); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMalloc failed. error: 0x%X", rt_ret); return FAILED;) + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMalloc failied. error: 0x%X", rt_ret); return FAILED;) rt_ret = rtMemcpy((void *)hbm_nav_table_addr, nav_table_size, (void *)nav_table, nav_table_size, RT_MEMCPY_HOST_TO_DEVICE); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMemcpy failed. error: 0x%X", rt_ret); return FAILED;) + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMemcpy failied. error: 0x%X", rt_ret); return FAILED;) rt_ret = rtKernelConfigTransArg(hbm_nav_table_addr, sizeof(uint64_t), 0, &hbm_nav_table_addr_pys); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtKernelConfigTransArg failed. error: 0x%X", rt_ret); + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtKernelConfigTransArg failied. error: 0x%X", rt_ret); return FAILED;) - GELOGD("SKT: hbm_nav_table_addr %p, hbm_nav_table_addr_pys %p", hbm_nav_table_addr, hbm_nav_table_addr_pys); + GELOGD("SKT: hbm_nav_table_addr %p, hbm_nav_table_addr_pys %p", (uint64_t)hbm_nav_table_addr, + (uint64_t)hbm_nav_table_addr_pys); // Create the necessary metadata for the super kernel h = new SuperKernel(this->func_stub_, hbm_nav_table_addr_pys, nav_table_size, block_dim); } else { for (unsigned i = 0; i < stub_func_list.size(); i++) { void *sub_device_func = nullptr; rt_ret = rtGetAddrByFun(stub_func_list[i], &sub_device_func); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtGetAddrByFun failed. error: 0x%X", rt_ret); + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtGetAddrByFun failied. error: 0x%X", rt_ret); return FAILED;) - GELOGD("SKT: fuseKernels subFunc %p, device func address %p", stub_func_list[i], sub_device_func); - // store two uint64_t address - // address divided by 4 because of 32bits encoding, call offset will *4 when calculating - nav_table[i * 2] = reinterpret_cast(reinterpret_cast(sub_device_func)) / 4; + GELOGD("SKT: fuseKernels subFunc %p, device func address %p", stub_func_list[i], (uint64_t)sub_device_func); + nav_table[i * 2] = (uint64_t)sub_device_func / 4; GELOGD("SKT: CALL offet %p", nav_table[i * 2]); - nav_table[i * 2 + 1] = reinterpret_cast(reinterpret_cast(args_addr_list[i])); + nav_table[i * 2 + 1] = (uint64_t)args_addr_list[i]; GELOGD("SKT: fuseKernels args base address %p", nav_table[i * 2 + 1]); } + void *hbm_nav_table_addr = nullptr; rt_ret = rtMalloc((void **)&hbm_nav_table_addr, nav_table_size, RT_MEMORY_HBM); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMalloc failed. error: 0x%X", rt_ret); return FAILED;) + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMalloc failied. error: 0x%X", rt_ret); return FAILED;) rt_ret = rtMemcpy((void *)hbm_nav_table_addr, nav_table_size, (void *)nav_table, nav_table_size, RT_MEMCPY_HOST_TO_DEVICE); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMemcpy failed. error: 0x%X", rt_ret); return FAILED;) + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMemcpy failied. error: 0x%X", rt_ret); return FAILED;) // Create the necessary metadata for the super kernel h = new SuperKernel(this->func_stub_, hbm_nav_table_addr, nav_table_size, block_dim); } diff --git a/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.h b/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.h index 7ceb5cfa..7b59d4bf 100644 --- a/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.h +++ b/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.h @@ -31,12 +31,12 @@ class SuperKernelFactory { const char *use_physical_address_ = getenv("GE_USE_PHYSICAL_ADDRESS"); bool is_init_ = false; SuperKernelFactory(){}; - ~SuperKernelFactory(){}; public: SuperKernelFactory(SuperKernelFactory const &) = delete; void operator=(SuperKernelFactory const &) = delete; static SuperKernelFactory &GetInstance(); + SuperKernelFactory(const std::string &sk_stub_name_, const std::string &bin_file); Status Init(); Status Uninitialize(); Status FuseKernels(const std::vector &stub_func_list, const std::vector &args_addr_list, diff --git a/src/ge/graph/load/output/output.h b/src/ge/graph/load/output/output.h index d93b8de9..4a3b0db2 100644 --- a/src/ge/graph/load/output/output.h +++ b/src/ge/graph/load/output/output.h @@ -21,14 +21,15 @@ #include #include "common/debug/log.h" +#include "common/op/attr_define.h" #include "common/op/attr_value_util.h" #include "common/op/ge_op_utils.h" +#include "common/op/op_parser_util.h" #include "common/types.h" #include "common/util.h" #include "common/ge_types.h" #include "graph/load/new_model_manager/davinci_model.h" #include "graph/op_desc.h" -#include "graph/debug/ge_attr_define.h" namespace ge { using std::string; diff --git a/src/ge/graph/manager/graph_manager.cc b/src/ge/graph/manager/graph_manager.cc index d4680d94..765b2302 100644 --- a/src/ge/graph/manager/graph_manager.cc +++ b/src/ge/graph/manager/graph_manager.cc @@ -33,7 +33,6 @@ #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" #include "framework/common/ge_types.h" -#include "graph/manager/util/rt_context_util.h" #include "graph/common/transop_util.h" #include "graph/debug/ge_attr_define.h" #include "graph/ge_context.h" @@ -60,11 +59,19 @@ #include "graph/passes/variable_op_pass.h" #include "graph/passes/variable_prepare_op_pass.h" #include "graph/passes/variable_ref_delete_op_pass.h" -#include "graph/passes/replace_with_empty_const_pass.h" #include "graph/utils/tensor_adapter.h" #include "inc/pass_manager.h" #include "init/gelib.h" +using domi::ASSIGN; +using domi::ATTR_MODEL_MEMORY_SIZE; +using domi::ATTR_MODEL_WEIGHT_SIZE; +using domi::ATTR_NAME_SESSION_GRAPH_ID; +using domi::CONSTANT; +using domi::CONSTANTOP; +using domi::HCOMBROADCAST; +using domi::VARIABLE; + namespace { const char *const kSummary = "Summary"; const char *const kSave = "Save"; @@ -118,7 +125,6 @@ Status GraphManager::Initialize(const std::map &options) { } graph_map_.clear(); - cache_helper_map_.clear(); init_flag_ = true; thread_run_flag_ = true; @@ -182,7 +188,6 @@ Status GraphManager::Finalize() { } } graph_map_.clear(); - cache_helper_map_.clear(); // graph context if (graph_context_ != nullptr) { @@ -429,13 +434,6 @@ Status GraphManager::PreRun(const GraphNodePtr &graph_node, const std::vectorSetSubGraph(merged_compute_graph); // set subgraphlist to graphnode graph_node->SetSubGraph(sub_graph_list); - // when set incre build, save om model and var manager - auto save_ret = SaveCacheAfterBuild(graph_node->GetGraphId(), merged_compute_graph, ge_model); - if (save_ret != SUCCESS) { - GELOGW("Fail to save cache."); - } - // release rts generate context - RtContextUtil::GetInstance().DestroyrtContexts(); GE_TIMESTAMP_END(PreRun, "GraphManager::PreRun"); GEEVENT("[GEPERFTRACE] GE PreRun End"); return ret; @@ -454,14 +452,10 @@ Status GraphManager::StartForRunGraph(const GraphNodePtr &graph_node, const std: return PARAM_INVALID; } GeModelPtr ge_model = nullptr; - // check need incre build. - ret = IncreBuild(graph_node, ge_model); + ret = PreRun(graph_node, inputs, ge_models, ge_model, session_id); if (ret != SUCCESS) { - ret = PreRun(graph_node, inputs, ge_models, ge_model, session_id); - if (ret != SUCCESS) { - GELOGE(ret, "PreRun Failed."); - return ret; - } + GELOGE(ret, "PreRun Failed."); + return ret; } ret = LoadGraph(ge_model, graph_node); if (ret != SUCCESS) { @@ -506,90 +500,6 @@ Status GraphManager::LoadGraph(const GeModelPtr &ge_model, const GraphNodePtr &g return SUCCESS; } -Status GraphManager::LoadFromCache(const GraphNodePtr &graph_node, const ModelCacheHelperPtr &cache_helper, - GeModelPtr &ge_model) { - auto graph_id = graph_node->GetGraphId(); - auto ret = cache_helper->LoadOmModelFromCache(ge_model); - if (ret != SUCCESS) { - GELOGW("Fail to load om model from cache."); - if (cache_helper->ClearCache(graph_id) != SUCCESS) { - GELOGW("Fail to clear cache of graph %u.", graph_id); - } - return FAILED; - } - ret = cache_helper->RecoverVarManagerFromCache(); - if (ret != SUCCESS) { - GELOGW("Fail to recover VarManager from cache."); - if (cache_helper->ClearCache(graph_id) != SUCCESS) { - GELOGW("Fail to clear cache of graph %u.", graph_id); - } - return FAILED; - } - ComputeGraphPtr compute_graph_in_model = GraphUtils::GetComputeGraph(ge_model->GetGraph()); - if (compute_graph_in_model == nullptr) { - GELOGW("Error occurred when get compute graph from om, abandon."); - return FAILED; - } else { - graph_node->SetComputeGraph(compute_graph_in_model); - graph_node->SetGeModel(ge_model); - GELOGI("Load model and graph form cache om file."); - } - return SUCCESS; -} - -Status GraphManager::SaveCacheBeforeBuild(uint32_t graph_id, const ModelCacheHelperPtr &cache_helper) { - auto ret = cache_helper->SaveCacheInfoToCache(); - if (ret != SUCCESS) { - GELOGW("Fail to save cache info of graph[%d] to cache.", graph_id); - return FAILED; - } - ret = cache_helper->SaveVarManagerToCache(true); - if (ret != SUCCESS) { - GELOGW("Fail to save var manager to cache."); - cache_helper->ClearCache(graph_id); - return FAILED; - } - GELOGI("Cache files have been saved."); - return SUCCESS; -} - -Status GraphManager::SaveCacheAfterBuild(uint32_t graph_id, ge::ComputeGraphPtr graph, GeModelPtr &ge_model) { - std::shared_ptr instance_ptr = ge::GELib::GetInstance(); - if ((instance_ptr == nullptr) || !instance_ptr->InitFlag()) { - GELOGW("GELib not initialized."); - return FAILED; - } - - if (instance_ptr->IsIncreBuild()) { - auto iter = cache_helper_map_.find(graph_id); - if (iter == cache_helper_map_.end()) { - GELOGW("Can not find ModelCacheHelper of graph[%u]", graph_id); - return FAILED; - } else { - ModelCacheHelperPtr cache_helper = iter->second; - auto ret = cache_helper->RefreshComputeGraph(graph); - if (ret != SUCCESS) { - cache_helper->ClearCache(graph_id); - GELOGW("Fail to refresh cache helper's compute graph"); - return FAILED; - } - ret = cache_helper->SaveVarManagerToCache(false); - if (ret != SUCCESS) { - cache_helper->ClearCache(graph_id); - GELOGW("Fail to save VarManager to cache"); - return FAILED; - } - ret = cache_helper->SaveOmModelToCache(ge_model); - if (ret != SUCCESS) { - cache_helper->ClearCache(graph_id); - GELOGW("Fail to save om model to cache"); - return FAILED; - } - } - } - return SUCCESS; -} - Status GraphManager::InnerRunGraph(GraphNodePtr &graph_node, const GraphId &graph_id, const std::vector &inputs, std::vector &outputs) { Status ret = graph_executor_.SetCondition(&sync_run_mutex_, &condition_, graph_run_listener_); @@ -649,9 +559,6 @@ Status GraphManager::RunGraph(const GraphId &graph_id, const std::vector ge_models; if (options_.local_fmk_op_flag) { @@ -684,7 +591,7 @@ Status GraphManager::RunGraph(const GraphId &graph_id, const std::vectorGetSubGraph(); if (IsCheckpointGraph(checkPointGraph)) { - ret = CheckpointHandle(graph_id, checkPointGraph, outputs); + ret = CheckpointHandle(graph_id, outputs); if (ret != SUCCESS) { GELOGE(ret, "[RunGraph] CheckpointHandle failed!"); } @@ -768,15 +675,6 @@ Status GraphManager::SaveParams(ge::GeModel &model, const std::string &type, con return SUCCESS; } -void GraphManager::RemoveModelCacheHelper(const GraphId &graph_id) { - auto iter = cache_helper_map_.find(graph_id); - if (iter != cache_helper_map_.end()) { - cache_helper_map_.erase(iter); - } else { - GELOGW("[GraphManager] cache helper does not exist, graph_id = %u", graph_id); - } -} - Status GraphManager::RemoveGraph(const GraphId &graph_id) { auto it = graph_map_.find(graph_id); if (it == graph_map_.end()) { @@ -826,9 +724,6 @@ Status GraphManager::RemoveGraph(const GraphId &graph_id) { } var_acc_ctrl_.RemoveGraph(graph_id); graph_map_.erase(it); - - RemoveModelCacheHelper(graph_id); - auto ge_model = graph_node->GetGeModel(); if (ge_model != nullptr) { GELOGI("Unload model %u.", ge_model->GetModelId()); @@ -1219,15 +1114,21 @@ Status GraphManager::SummaryHandle(const GraphId &graph_id, std::vector &outputs) { +Status GraphManager::CheckpointHandle(const GraphId &graph_id, const std::vector &outputs) { GELOGI("[GraphManager] CheckpointHandle, outputsSize=%zu.", outputs.size()); std::vector outputs_desc = graph_executor_.GetOutputsDesc(); GELOGI("[GraphManager] CheckpointHandle, outputsDescSize=%zu.", outputs_desc.size()); - + // find graph + GraphNodePtr graph_node = nullptr; + Status ret = GetGraphNode(graph_id, graph_node); + if (ret != SUCCESS) { + GELOGE(ret, "[CheckpointHandle] graph not exist, graph_id = %u.", graph_id); + return ret; + } + ComputeGraphPtr compute_graph_ptr = GraphUtils::GetComputeGraph(*(graph_node->GetGraph())); std::map save_results; NodePtr netoutput = nullptr; - for (const auto &node : compute_graph->GetDirectNode()) { + for (const auto &node : compute_graph_ptr->GetDirectNode()) { if (node->GetType() == kNetOutput) { netoutput = node; break; @@ -1355,8 +1256,6 @@ bool GraphManager::CheckTransOpForCheckpointGraph(NodePtr &node) { return true; } -static inline bool CheckConstanOpForCheckpointGraph(NodePtr &node) { return node->GetOutDataNodes().empty(); } - bool GraphManager::IsCheckpointGraph(ComputeGraphPtr &compute_graph) { if (compute_graph == nullptr) { GELOGE(GE_GRAPH_PARAM_NULLPTR, "[IsCheckpointGraph] computeGraph is nullptr."); @@ -1377,10 +1276,6 @@ bool GraphManager::IsCheckpointGraph(ComputeGraphPtr &compute_graph) { if (!CheckTransOpForCheckpointGraph(node)) { return false; } - } else if (op->GetType() == CONSTANTOP) { - if (!CheckConstanOpForCheckpointGraph(node)) { - return false; - } } else if (op->GetType() != kSend && op->GetType() != kRecv) { GELOGI("this node is not allow in checkpoint sub graph, node_type: %s, node_name: %s.", op->GetType().c_str(), op->GetName().c_str()); @@ -1407,7 +1302,7 @@ bool GraphManager::IsBroadCastOpData(const ge::NodePtr &var_node) { } void GraphManager::AdjustBroadCastOpData(const ge::NodePtr &var_node) { - if (!ge::AttrUtils::SetStr(var_node->GetOpDesc(), VAR_ATTR_VAR_IS_BROADCAST, "var_is_restore")) { + if (!ge::AttrUtils::SetStr(var_node->GetOpDesc(), domi::VAR_ATTR_VAR_IS_BROADCAST, "var_is_restore")) { GELOGW("set var_is_restore failed"); } } @@ -1425,7 +1320,7 @@ bool GraphManager::IsAssignOpData(const ge::NodePtr &var_node) { } void GraphManager::AdjustAssignOpData(const ge::NodePtr &var_node) { - if (!ge::AttrUtils::SetStr(var_node->GetOpDesc(), VAR_ATTR_VAR_IS_RESTORE, "var_is_restore")) { + if (!ge::AttrUtils::SetStr(var_node->GetOpDesc(), domi::VAR_ATTR_VAR_IS_RESTORE, "var_is_restore")) { GELOGW("SetStr var_is_restore failed"); } } @@ -1743,51 +1638,6 @@ Status GraphManager::RunGraphAsync(const GraphId &graph_id, const std::vector instance_ptr = ge::GELib::GetInstance(); - if (instance_ptr != nullptr && instance_ptr->IsIncreBuild()) { - auto iter = cache_helper_map_.find(graph_id); - if (iter == cache_helper_map_.end()) { - ModelCacheHelperPtr cache_helper = MakeShared(session_id, graph_id, compute_graph); - if (cache_helper != nullptr) { - cache_helper_map_.emplace(std::make_pair(graph_id, cache_helper)); - } else { - GELOGW("Cache helper make shared failed, graph_id = %u.", graph_id); - } - } - } -} - -Status GraphManager::IncreBuild(const GraphNodePtr &graph_node, GeModelPtr &ge_model) { - std::shared_ptr instance_ptr = ge::GELib::GetInstance(); - if (instance_ptr == nullptr || !instance_ptr->IsIncreBuild()) { - return FAILED; - } - const uint32_t graph_id = graph_node->GetGraphId(); - auto iter = cache_helper_map_.find(graph_id); - if (iter == cache_helper_map_.end()) { - GELOGW("Can not find ModelCacheHelper of graph[%u]", graph_id); - return FAILED; - } - ModelCacheHelperPtr cache_helper = iter->second; - if (cache_helper->IsModelCacheHit()) { - GEEVENT("Model cache hit."); - Status ret = LoadFromCache(graph_node, cache_helper, ge_model); - if (ret == SUCCESS) { - return SUCCESS; - } else { - GELOGW("Error occurred when load from cache, abandon."); - } - } else { - GEEVENT("Model cache miss."); - } - if (SaveCacheBeforeBuild(graph_node->GetGraphId(), cache_helper) != SUCCESS) { - GELOGW("Error occurred when save cache."); - } - return FAILED; -} - void GraphManager::PreRunThread(GraphManager *graph_manager) { if (prctl(PR_SET_NAME, ("GE_PreRun")) != 0) { GELOGW("Set thread name failed."); @@ -1841,8 +1691,6 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) { return; } } - // when set incre build, save cache helper. - graph_manager->AddModelCacheHelperToMap(args.graph_id, args.session_id, compute_graph_tmp); std::vector ge_models; @@ -1865,15 +1713,12 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) { return; } - // check need incre build. - if (graph_manager->IncreBuild(graph_node, ge_model) != SUCCESS) { - ret = graph_manager->PreRun(graph_node, ge_inputs, ge_models, ge_model, args.session_id); - if (ret != SUCCESS) { - graph_node->SetRunFlag(false); - ReturnError(graph_manager, args.callback, ret, "PreRun Failed, thread exit.."); - graph_node->Unlock(); - return; - } + ret = graph_manager->PreRun(graph_node, ge_inputs, ge_models, ge_model, args.session_id); + if (ret != SUCCESS) { + graph_node->SetRunFlag(false); + ReturnError(graph_manager, args.callback, ret, "PreRun failed, thread exit."); + graph_node->Unlock(); + return; } graph_node->SetBuildFlag(true); graph_manager->var_acc_ctrl_.SetGraphBuildEnd(graph_node->GetGraphId()); diff --git a/src/ge/graph/manager/graph_manager.h b/src/ge/graph/manager/graph_manager.h index 92ea48c5..5a296b91 100644 --- a/src/ge/graph/manager/graph_manager.h +++ b/src/ge/graph/manager/graph_manager.h @@ -27,7 +27,6 @@ #include "common/blocking_queue.h" #include "common/ge_inner_error_codes.h" -#include "common/helper/model_cache_helper.h" #include "external/graph/types.h" #include "ge/ge_api_types.h" #include "graph/build/graph_builder.h" @@ -212,8 +211,7 @@ class GraphManager { Status SummaryHandle(const GraphId &graph_id, std::vector &outputs); - Status CheckpointHandle(const GraphId &graph_id, const ComputeGraphPtr &compute_graph, - const std::vector &outputs); + Status CheckpointHandle(const GraphId &graph_id, const std::vector &outputs); // call the callback function of ME to push summary result data to ME Status PushSummaryData2ME(const GraphId &graph_id, const std::map &summary_data); @@ -262,13 +260,6 @@ class GraphManager { bool IsGraphNeedBuild(const GraphNodePtr &graph_node); - Status LoadFromCache(const GraphNodePtr &graph_node, const ModelCacheHelperPtr &cache_helper, GeModelPtr &ge_model); - Status SaveCacheBeforeBuild(uint32_t graph_id, const ModelCacheHelperPtr &cache_helper); - Status SaveCacheAfterBuild(uint32_t graph_id, ComputeGraphPtr graph, GeModelPtr &ge_model); - void AddModelCacheHelperToMap(const GraphId &graph_id, uint64_t session_id, ComputeGraphPtr &compute_graph); - Status IncreBuild(const GraphNodePtr &graph_node, GeModelPtr &ge_model); - void RemoveModelCacheHelper(const GraphId &graph_id); - static void PreRunThread(GraphManager *graph_manager); static void RunThread(GraphManager *graph_manager); static void StopQueue(GraphManager *graph_manager); @@ -283,8 +274,6 @@ class GraphManager { std::map graph_map_; - std::map cache_helper_map_; - // for run graph synchronous return std::mutex sync_run_mutex_; std::condition_variable condition_; diff --git a/src/ge/graph/manager/graph_manager_utils.cc b/src/ge/graph/manager/graph_manager_utils.cc index 021c0c47..a340ce35 100644 --- a/src/ge/graph/manager/graph_manager_utils.cc +++ b/src/ge/graph/manager/graph_manager_utils.cc @@ -21,8 +21,8 @@ #include "framework/common/debug/ge_log.h" #include "common/ge/ge_util.h" +#include "common/op/attr_define.h" #include "common/string_util.h" -#include "graph/debug/ge_attr_define.h" #include "graph/compute_graph.h" #include "graph/op_desc.h" #include "graph/optimize/common/params.h" @@ -148,7 +148,7 @@ bool HasCalcOp(const ComputeGraphPtr &graph) { return false; } - static const std::set calc_op_type = {CONVOLUTION, DECONVOLUTION, FULL_CONNECTION}; + static const std::set calc_op_type = {domi::CONVOLUTION, domi::DECONVOLUTION, domi::FULL_CONNECTION}; for (const auto &node : graph->GetAllNodes()) { OpDescPtr op_desc = node->GetOpDesc(); @@ -167,15 +167,15 @@ Status ParseOutNodes(const string &out_nodes) { domi::GetContext().out_nodes_map.clear(); domi::GetContext().user_out_nodes.clear(); - vector nodes_v = StringUtils::Split(out_nodes, ';'); + vector nodes_v = domi::StringUtils::Split(out_nodes, ';'); for (const string &node : nodes_v) { - vector key_value_v = StringUtils::Split(node, ':'); + vector key_value_v = domi::StringUtils::Split(node, ':'); if (key_value_v.size() != 2) { // must contain 2 items GELOGE(GE_GRAPH_PARAM_NULLPTR, "Invalid outNodes: %s", node.c_str()); return GE_GRAPH_PARAM_NULLPTR; } auto iter = domi::GetContext().out_nodes_map.find(key_value_v[0]); - int32_t index = std::stoi(StringUtils::Trim(key_value_v[1])); + int32_t index = std::stoi(domi::StringUtils::Trim(key_value_v[1])); if (iter != domi::GetContext().out_nodes_map.end()) { iter->second.emplace_back(index); } else { diff --git a/src/ge/graph/manager/graph_var_manager.cc b/src/ge/graph/manager/graph_var_manager.cc index f40ca7ce..5b76a597 100644 --- a/src/ge/graph/manager/graph_var_manager.cc +++ b/src/ge/graph/manager/graph_var_manager.cc @@ -19,11 +19,11 @@ #include #include "common/l2_cache_optimize.h" +#include "common/op/attr_define.h" #include "common/types.h" #include "framework/common/debug/ge_log.h" #include "framework/common/debug/log.h" #include "ge/ge_api_types.h" -#include "graph/debug/ge_attr_define.h" #include "graph/manager/graph_mem_allocator.h" #include "graph/manager/trans_var_data_utils.h" #include "graph/utils/attr_utils.h" @@ -64,10 +64,6 @@ ge::Status VarResource::GetVarAddr(const std::string &var_name, const ge::GeTens return SUCCESS; } -void VarResource::GetAllVarAddrMgr(std::unordered_map &var_addr_mgr_map) { - var_addr_mgr_map = var_addr_mgr_map_; -} - void VarResource::SetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *dev_ptr, rtMemType_t memory_type) { std::string var_key = VarKey(var_name, tensor_desc); @@ -174,14 +170,6 @@ void VarResource::SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastInfo &b var_broad_cast_info_[graph_id][broad_cast_info.var_name] = broad_cast_info; } -ge::Status VarResource::GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info) { - if (var_broad_cast_info_.count(graph_id) == 0 || var_broad_cast_info_[graph_id].count(var_name) == 0) { - return FAILED; - } - broad_cast_info = var_broad_cast_info_[graph_id][var_name]; - return SUCCESS; -} - ge::Status VarResource::SyncVarData2BroadCast(uint32_t graph_id, const std::string &var_name, const ge::ConstOpDescPtr &var_op_desc, uint8_t *base_ptr) { if (var_op_desc == nullptr) { @@ -204,7 +192,7 @@ ge::Status VarResource::SyncBroadCastData2Var(uint32_t graph_id, const std::stri GELOGI("SyncBroadCastData2Var var_name: %s", var_name.c_str()); GE_CHECK_NOTNULL(var_op_desc); string var_is_broadcast; - bool is_broadcast = AttrUtils::GetStr(var_op_desc, VAR_ATTR_VAR_IS_BROADCAST, var_is_broadcast); + bool is_broadcast = AttrUtils::GetStr(var_op_desc, domi::VAR_ATTR_VAR_IS_BROADCAST, var_is_broadcast); if (!is_broadcast) { return SUCCESS; } @@ -222,7 +210,7 @@ ge::Status VarResource::SyncVarData(uint32_t graph_id, const std::string &var_na const ge::ConstOpDescPtr &var_op_desc, uint8_t *base_ptr) { GE_CHECK_NOTNULL(var_op_desc); string var_is_broadcast; - bool is_broadcast = AttrUtils::GetStr(var_op_desc, VAR_ATTR_VAR_IS_BROADCAST, var_is_broadcast); + bool is_broadcast = AttrUtils::GetStr(var_op_desc, domi::VAR_ATTR_VAR_IS_BROADCAST, var_is_broadcast); if (!is_broadcast) { return SUCCESS; } @@ -303,8 +291,6 @@ Status MemResource::AssignVarMem(const std::string &var_name, uint64_t size, uin int64_t MemResource::GetVarMemSize() const { return var_mem_size_; } -void MemResource::UpdateVarMemSize(int64_t mem_size) { var_mem_size_ = mem_size; }; - VarManager::VarManager(uint64_t session_id) : version_(SessionVersion::OTHER_VERSION), session_id_(session_id), @@ -320,9 +306,9 @@ VarManager *VarManager::Instance(uint64_t session_id) { return VarManagerPool::Instance().GetVarManager(session_id); } -void VarManager::Destroy() { +void VarManager::Destory() { std::lock_guard lock(mutex_); - GELOGI("VarManager::Destroy, session id = %lu.", session_id_); + GELOGI("VarManager::Destory, session id = %lu.", session_id_); version_ = SessionVersion::OTHER_VERSION; device_id_ = 0; session_id_ = 0; @@ -381,21 +367,6 @@ ge::Status VarManager::SetVarAddr(const std::string &var_name, const ge::GeTenso return ge::SUCCESS; } -ge::Status VarManager::SaveVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *address, - rtMemType_t memory_type) { - GELOGI("VarManager::SaveVarAddr var_name = %s, data_type = %s, data_format = %s.", var_name.c_str(), - ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(), - ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str()); - - std::lock_guard lock(mutex_); - if (var_resource_ == nullptr) { - GELOGW("VarManager has not been init."); - return ge::INTERNAL_ERROR; - } - var_resource_->SaveVarAddr(var_name, tensor_desc, address, memory_type); - return ge::SUCCESS; -} - ge::Status VarManager::GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr, rtMemType_t &memory_type) { std::lock_guard lock(mutex_); @@ -421,10 +392,6 @@ ge::Status VarManager::GetVarAddr(const std::string &var_name, const ge::GeTenso return GetVarAddr(var_name, tensor_desc, dev_ptr, memory_type); } -void VarManager::GetAllVarAddrMgr(std::unordered_map &var_addr_mgr_map) { - var_resource_->GetAllVarAddrMgr(var_addr_mgr_map); -} - int64_t VarManager::GetVarMemSize(rtMemType_t memory_type) { std::lock_guard lock(mutex_); MemResource *mem_resource = nullptr; @@ -442,30 +409,6 @@ int64_t VarManager::GetVarMemSize(rtMemType_t memory_type) { return mem_resource->GetVarMemSize(); } -Status VarManager::UpdateVarMemSize(rtMemType_t memory_type, int64_t mem_size) { - std::lock_guard lock(mutex_); - MemResource *mem_resource = nullptr; - auto iter = mem_resource_map_.find(memory_type); - if (iter == mem_resource_map_.end()) { - mem_resource = new (std::nothrow) MemResource(); - if (mem_resource == nullptr) { - GELOGE(ge::INTERNAL_ERROR, "Alloc MemResource failed, memory_type = %u.", memory_type); - return ge::INTERNAL_ERROR; - } else { - mem_resource_map_[memory_type] = mem_resource; - } - } else { - mem_resource = iter->second; - } - - if (mem_resource == nullptr) { - GELOGE(ge::INTERNAL_ERROR, "MemResource is invalid."); - return FAILED; - } - mem_resource->UpdateVarMemSize(mem_size); - return SUCCESS; -} - ge::Status VarManager::AssignVarMem(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, rtMemType_t memory_type) { std::lock_guard lock(mutex_); @@ -608,16 +551,6 @@ ge::Status VarManager::SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastIn return SUCCESS; } -ge::Status VarManager::GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info) { - std::lock_guard lock(mutex_); - - if (var_resource_ == nullptr) { - GELOGW("VarManager has not been init."); - return ge::INTERNAL_ERROR; - } - return var_resource_->GetBroadCastInfo(graph_id, var_name, broad_cast_info); -} - ge::Status VarManager::RenewCurVarDesc(const std::string &var_name, ge::OpDescPtr op_desc) { std::lock_guard lock(mutex_); GELOGD("VarManager::RenewCurVarDesc var_name = %s.", var_name.c_str()); @@ -740,7 +673,6 @@ Status VarManager::SetMemoryMallocSize(const map &options) { GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "Parse graph memory manager malloc max size failed."); return ge::GE_GRAPH_OPTIONS_INVALID; } - GELOGI("The max size for graph mem is set to %zu", graph_mem_max_size_); } it = options.find(VARIABLE_MEMORY_MAX_SIZE); @@ -853,19 +785,19 @@ void VarManager::RemoveAllocatedGraphId(const std::string &var_name) { var_resource_->RemoveAllocatedGraphId(var_name); } -VarManagerPool::~VarManagerPool() { Destroy(); } +VarManagerPool::~VarManagerPool() { Destory(); } VarManagerPool &VarManagerPool::Instance() { static VarManagerPool var_manager_pool; return var_manager_pool; } -void VarManagerPool::Destroy() noexcept { +void VarManagerPool::Destory() noexcept { std::lock_guard lock(var_manager_mutex_); for (auto &it : var_manager_map_) { VarManager *var_manager = it.second; if (var_manager != nullptr) { - var_manager->Destroy(); + var_manager->Destory(); delete var_manager; var_manager = nullptr; } diff --git a/src/ge/graph/manager/graph_var_manager.h b/src/ge/graph/manager/graph_var_manager.h index 8b551e06..a23c45b6 100644 --- a/src/ge/graph/manager/graph_var_manager.h +++ b/src/ge/graph/manager/graph_var_manager.h @@ -101,8 +101,6 @@ class VarResource { ge::Status GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr, rtMemType_t &memory_type); - void GetAllVarAddrMgr(std::unordered_map &var_addr_mgr_map); - void SetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *dev_ptr, rtMemType_t rtMemType_t); @@ -115,8 +113,6 @@ class VarResource { void SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastInfo &broad_cast_info); - ge::Status GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info); - ge::Status SyncVarData2BroadCast(uint32_t graph_id, const std::string &var_name, const ge::ConstOpDescPtr &var_op_desc, uint8_t *base_ptr); @@ -179,8 +175,6 @@ class MemResource { int64_t GetVarMemSize() const; - void UpdateVarMemSize(int64_t mem_size); - private: uint64_t total_size_; uint64_t var_mem_size_; @@ -195,21 +189,16 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY VarManager { ge::Status Init(const uint32_t &version, const uint64_t &session_id, const uint32_t &device_id, const uint64_t &job_id); - void Destroy(); + void Destory(); ge::Status AssignVarMem(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, rtMemType_t memory_type); ge::Status SetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *dev_ptr, rtMemType_t memory_type); - ge::Status SaveVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *address, - rtMemType_t memory_type); - ge::Status GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr, rtMemType_t &memory_type); - void GetAllVarAddrMgr(std::unordered_map &var_addr_mgr_map); - ge::Status GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr); ge::Status SyncVarData(uint32_t graph_id, const std::string &var_name, ge::ConstOpDescPtr var_op_desc, @@ -217,8 +206,6 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY VarManager { ge::Status SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastInfo &broad_cast_info); - ge::Status GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info); - ge::Status SyncBroadCastData2Var(uint32_t graph_id, const std::string &var_name, ge::ConstOpDescPtr var_op_desc, uint8_t *base_ptr); @@ -264,8 +251,6 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY VarManager { int64_t GetVarMemSize(rtMemType_t memory_type); - Status UpdateVarMemSize(rtMemType_t memory_type, int64_t mem_size); - bool IsVarExist(const std::string &var_name, const ge::GeTensorDesc &tensor_desc); bool IsVarExist(const std::string &var_name); @@ -300,7 +285,7 @@ class VarManagerPool { VarManager *GetVarManager(uint64_t session_id); - void Destroy() noexcept; + void Destory() noexcept; ge::Status Init() const; diff --git a/src/ge/graph/manager/util/debug.cc b/src/ge/graph/manager/util/debug.cc index b2ef1c92..3ca4642d 100644 --- a/src/ge/graph/manager/util/debug.cc +++ b/src/ge/graph/manager/util/debug.cc @@ -31,7 +31,7 @@ Debug::Debug() = default; Debug::~Debug() = default; void Debug::DumpProto(const Message &proto, const char *file) { - std::string file_path = RealPath(file); + std::string file_path = domi::RealPath(file); int fd = open(file_path.c_str(), O_WRONLY | O_CREAT | O_TRUNC, S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH); if (fd == -1) { GELOGW("Write %s failed", file_path.c_str()); diff --git a/src/ge/graph/manager/util/hcom_util.cc b/src/ge/graph/manager/util/hcom_util.cc index a1c4d769..6319f985 100644 --- a/src/ge/graph/manager/util/hcom_util.cc +++ b/src/ge/graph/manager/util/hcom_util.cc @@ -23,6 +23,14 @@ #include "graph/utils/tensor_utils.h" #include "graph/utils/type_utils.h" +using domi::HCOM_ATTR_DATA_TYPE; +using domi::HCOM_ATTR_RANK_SIZE; +using domi::HCOM_ATTR_REDUCE_TYPE; +using domi::HCOM_ATTR_ROOT_RANK; +using domi::HCOM_ATTR_SHAPE; +using domi::HCOMRECEIVE; +using domi::HCOMREDUCESCATTER; + namespace ge { Status HcomOmeUtil::GetHcomDataType(const ge::ConstOpDescPtr &op_desc, hcclDataType_t &data_type) { GE_CHECK_NOTNULL(op_desc); diff --git a/src/ge/graph/manager/util/variable_accelerate_ctrl.cc b/src/ge/graph/manager/util/variable_accelerate_ctrl.cc index 522b5ee3..3bc61e84 100644 --- a/src/ge/graph/manager/util/variable_accelerate_ctrl.cc +++ b/src/ge/graph/manager/util/variable_accelerate_ctrl.cc @@ -23,9 +23,9 @@ namespace ge { namespace { inline bool IsVariable(const std::string &node_type) { - return node_type == VARIABLE || node_type == VARIABLEV2 || node_type == VARHANDLEOP; -} + return node_type == domi::VARIABLE || node_type == domi::VARIABLEV2 || node_type == domi::VARHANDLEOP; } +} // namespace bool VarAccelerateCtrl::IsVarPermitToChangeFormats(const std::string &var_name) { auto iter = var_names_to_change_times_.find(var_name); @@ -39,9 +39,10 @@ void VarAccelerateCtrl::SetVarChanged(const std::string &var_name) { auto times = ++var_names_to_change_times_[var_name]; for (auto &graph_id_to_var_names : graph_ids_to_var_names_) { if (graph_id_to_var_names.second.count(var_name) > 0) { - GELOGI("The format of var %s has been changed, total changed times %d, " - "the graph %u contains which should be re-build before next run", - var_name.c_str(), times, graph_id_to_var_names.first); + GELOGI( + "The format of var %s has been changed, total changed times %d, " + "the graph %u contains which should be re-build before next run", + var_name.c_str(), times, graph_id_to_var_names.first); /// The graph being compiled right now is also added to the rebuild-list /// and can be deleted by calling `SetGraphBuildEnd` at the end of compilation. graph_ids_need_rebuild_.insert(graph_id_to_var_names.first); diff --git a/src/ge/graph/optimize/common/params.h b/src/ge/graph/optimize/common/params.h index ee2a735b..403e1aa8 100644 --- a/src/ge/graph/optimize/common/params.h +++ b/src/ge/graph/optimize/common/params.h @@ -22,6 +22,10 @@ #include "common/singleton.h" #include "common/types.h" +using domi::TARGET_TYPE_LTTE_8BIT; +using domi::TARGET_TYPE_MINI_8BIT; +using domi::TARGET_TYPE_TINY_8BIT; + namespace ge { class Params : public Singleton { public: diff --git a/src/ge/graph/optimize/graph_optimize.cc b/src/ge/graph/optimize/graph_optimize.cc index 0be0aeee..f1fe27b9 100644 --- a/src/ge/graph/optimize/graph_optimize.cc +++ b/src/ge/graph/optimize/graph_optimize.cc @@ -26,6 +26,8 @@ #include "init/gelib.h" #include "opskernel_manager/ops_kernel_manager.h" +using domi::ATTR_NAME_FRAMEWORK_FWK_TYPE; +using domi::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE; using ge::ComputeGraph; using ge::OpDesc; @@ -77,7 +79,7 @@ void AddNodeInputProperty(ComputeGraphPtr &compute_graph) { src_index_list.emplace_back(peer_out_anchor->GetIdx()); node_op_desc->SetSrcName(src_name_list); node_op_desc->SetSrcIndex(src_index_list); - GE_IF_BOOL_EXEC(!(node_op_desc->GetType() == NETOUTPUT && domi::GetContext().type == domi::FMK_TYPE_T), + GE_IF_BOOL_EXEC(!(node_op_desc->GetType() == domi::NETOUTPUT && domi::GetContext().type == domi::FMK_TYPE_T), ge::NodePtr peer_owner_node = peer_out_anchor->GetOwnerNode(); input_name_list = node_op_desc->GetInputName(); input_name_list.emplace_back( peer_owner_node->GetName() + diff --git a/src/ge/graph/partition/graph_partition.cc b/src/ge/graph/partition/graph_partition.cc index bc8c9b9b..f459a7c2 100644 --- a/src/ge/graph/partition/graph_partition.cc +++ b/src/ge/graph/partition/graph_partition.cc @@ -22,8 +22,8 @@ #include #include "common/ge/ge_util.h" #include "common/op/ge_op_utils.h" +#include "framework/common/op/attr_define.h" #include "framework/common/types.h" -#include "graph/debug/ge_attr_define.h" #include "graph/manager/graph_manager_utils.h" #include "graph/utils/graph_utils.h" #include "graph/utils/op_desc_utils.h" @@ -31,6 +31,21 @@ #include "init/gelib.h" #include "opskernel_manager/ops_kernel_manager.h" +using domi::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE; +using domi::ATTR_NAME_SESSION_GRAPH_ID; +using domi::ATTR_NAME_STREAM_LABEL; +using domi::END; +using domi::NCHW_DIM_C; +using domi::NCHW_DIM_H; +using domi::NCHW_DIM_N; +using domi::NCHW_DIM_W; +using domi::NHWC_DIM_C; +using domi::NHWC_DIM_H; +using domi::NHWC_DIM_N; +using domi::NHWC_DIM_W; +using domi::PERMUTE_ATTR_ORDER; +using domi::PLACEHOLDER; + namespace { const char *const kEngineDefaultData = "ENGINE_DEFAULT_DATA"; const char *const kEndType = "End"; @@ -50,6 +65,12 @@ Status ge::GraphPartitioner::CheckIfEnd2PldEmpty(ge::ComputeGraphPtr &output_mer return FAILED; } output_merged_compute_graph = partition.first; + // flush all nodes' engine of merged graph + graph_info_.engine_placer_.SetComputeGraph(output_merged_compute_graph); + if (graph_info_.engine_placer_.Run() != SUCCESS) { + GELOGE(GE_GRAPH_INIT_FAILED, "[GraphPartitioner]: engine_placer run failed"); + return FAILED; + } } else { // if placeholder to end map is empty, it should be an exception condition GELOGE(GE_GRAPH_EMPTY_PARTITION, "[GraphPartitioner]: placeholder to end map is empty, partitions size is not 1."); return FAILED; @@ -510,8 +531,9 @@ void ge::GraphPartitioner::AddNewGraphToPartition(ge::ComputeGraphPtr &input_gra } bool ge::GraphPartitioner::IsDataLike(ge::NodePtr node) { - return (node->GetType() == CONSTANT) || (node->GetType() == DATA) || (node->GetType() == AIPPDATA) || - (node->GetType() == CONSTANTOP) || (node->GetType() == VARIABLE); + return (node->GetType() == domi::CONSTANT) || (node->GetType() == domi::DATA) || + (node->GetType() == domi::AIPPDATA) || (node->GetType() == domi::CONSTANTOP) || + (node->GetType() == domi::VARIABLE); } bool ge::GraphPartitioner::HasNoInput(ge::NodePtr node) { diff --git a/src/ge/graph/passes/addn_pass.cc b/src/ge/graph/passes/addn_pass.cc index c0592965..7bd32a38 100644 --- a/src/ge/graph/passes/addn_pass.cc +++ b/src/ge/graph/passes/addn_pass.cc @@ -30,7 +30,7 @@ Status AddNPass::Run(NodePtr &node) { return PARAM_INVALID; } - if (node->GetType() == ADDN) { + if (node->GetType() == domi::ADDN) { if (node->GetOpDesc() == nullptr) { GELOGE(PARAM_INVALID, "Param [node] op desc is null."); return PARAM_INVALID; diff --git a/src/ge/graph/passes/aicpu_constant_folding_pass.cc b/src/ge/graph/passes/aicpu_constant_folding_pass.cc index e1e6842f..667c22a2 100644 --- a/src/ge/graph/passes/aicpu_constant_folding_pass.cc +++ b/src/ge/graph/passes/aicpu_constant_folding_pass.cc @@ -571,8 +571,8 @@ void AicpuConstantFoldingPass::ReleaseMemory(const vector &input_ad bool AicpuConstantFoldingPass::IsSkipFold(const ge::NodePtr &node) { GE_CHECK_NOTNULL(node); string type = node->GetType(); - if (type == ge::FRAMEWORKOP) { - if (!ge::AttrUtils::GetStr(node->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type)) { + if (type == domi::FRAMEWORKOP) { + if (!ge::AttrUtils::GetStr(node->GetOpDesc(), domi::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type)) { GELOGW("Skip aicpu constant folding on frameworkop node [%s]", node->GetName().c_str()); return true; } diff --git a/src/ge/graph/passes/aicpu_constant_folding_pass.h b/src/ge/graph/passes/aicpu_constant_folding_pass.h index 02babd8e..1ff4722e 100644 --- a/src/ge/graph/passes/aicpu_constant_folding_pass.h +++ b/src/ge/graph/passes/aicpu_constant_folding_pass.h @@ -17,7 +17,6 @@ #ifndef GE_GRAPH_PASSES_AICPU_CONSTANT_FOLDING_PASS_H_ #define GE_GRAPH_PASSES_AICPU_CONSTANT_FOLDING_PASS_H_ -#include #include #include "common/opskernel/ops_kernel_info_store.h" diff --git a/src/ge/graph/passes/assert_pass.cc b/src/ge/graph/passes/assert_pass.cc index 725016a9..3207d7be 100644 --- a/src/ge/graph/passes/assert_pass.cc +++ b/src/ge/graph/passes/assert_pass.cc @@ -38,7 +38,7 @@ Status AssertPass::Run(NodePtr &node) { return PARAM_INVALID; } std::string op_type = node->GetOpDesc()->GetType(); - if (op_type == ASSERT) { + if (op_type == domi::ASSERT) { GELOGD("op type is assert."); std::vector nodes_unused; @@ -71,8 +71,8 @@ void AssertPass::CollectUnusedNode(const NodePtr &assert_node, vector & if (src_node != nullptr && src_node->GetOpDesc() != nullptr) { auto size = ++invalid_outdata_info[src_node.get()]; // src_node need to be deleted - if (src_node->GetOutDataNodesSize() == size && src_node->GetOpDesc()->GetType() != DATA && - src_node->GetOpDesc()->GetType() != AIPPDATA) { + if (src_node->GetOutDataNodesSize() == size && src_node->GetOpDesc()->GetType() != domi::DATA && + src_node->GetOpDesc()->GetType() != domi::AIPPDATA) { node_queue.push(src_node); } } diff --git a/src/ge/graph/passes/atomic_addr_clean_pass.cc b/src/ge/graph/passes/atomic_addr_clean_pass.cc index e95f0680..87b40170 100644 --- a/src/ge/graph/passes/atomic_addr_clean_pass.cc +++ b/src/ge/graph/passes/atomic_addr_clean_pass.cc @@ -28,6 +28,11 @@ #include "graph/debug/ge_attr_define.h" #include "init/gelib.h" +using domi::ATOMICADDRCLEAN; +using domi::ATTR_NAME_STREAM_LABEL; +using domi::LOOPCOND; +using domi::NODE_NAME_ATOMIC_ADDR_CLEAN; + namespace { bool is_loop_graph = false; } @@ -200,18 +205,7 @@ bool AtomicAddrCleanPass::IsAtomicOp(const NodePtr &node) { vector op_info_vec = ops_kernel_manager.GetOpsKernelInfo(op_desc->GetType()); for (const auto &op_info : op_info_vec) { if (op_info.isAtomic) { - GELOGI("Recognized atomic op %s from DNN_HCCL engine.", op_desc->GetName().c_str()); - // check peer input is DATA - for (auto &in_data_anchor : node->GetAllInDataAnchors()) { - if (in_data_anchor->GetPeerOutAnchor() != nullptr && - in_data_anchor->GetPeerOutAnchor()->GetOwnerNode() != nullptr) { - auto peer_in_node = in_data_anchor->GetPeerOutAnchor()->GetOwnerNode(); - if (peer_in_node->GetType() == DATA) { - GELOGI("Recognized atomic op %s from DNN_HCCL engine and input is DATA.", op_desc->GetName().c_str()); - return false; - } - } - } + GELOGI("Recognized atomic op %s from HCCL engine.", op_desc->GetName().c_str()); hcom_node_vec_.push_back(node); return true; } diff --git a/src/ge/graph/passes/cast_remove_pass.cc b/src/ge/graph/passes/cast_remove_pass.cc index a0742a03..00a9581e 100644 --- a/src/ge/graph/passes/cast_remove_pass.cc +++ b/src/ge/graph/passes/cast_remove_pass.cc @@ -22,6 +22,8 @@ #include "graph/debug/ge_attr_define.h" #include "graph/utils/type_utils.h" +using domi::CAST; + namespace ge { Status CastRemovePass::Run(NodePtr &node) { if (node == nullptr) { diff --git a/src/ge/graph/passes/cast_translate_pass.cc b/src/ge/graph/passes/cast_translate_pass.cc index 2d67b0a8..dfda5d10 100644 --- a/src/ge/graph/passes/cast_translate_pass.cc +++ b/src/ge/graph/passes/cast_translate_pass.cc @@ -23,13 +23,17 @@ #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" #include "graph/common/omg_util.h" -#include "graph/debug/ge_attr_define.h" #include "graph/passes/pass_utils.h" #include "graph/utils/node_utils.h" #include "graph/utils/type_utils.h" #include "init/gelib.h" #include "opskernel_manager/ops_kernel_manager.h" +using domi::ATTR_NAME_INPUT_DATATYPE; +using domi::ATTR_NAME_OUTPUT_DATATYPE; +using domi::CAST; +using domi::TRANSLATE; + namespace ge { bool CastTranslatePass::CheckInAndOutDataAnchor(NodePtr &node) const { if (node == nullptr) { diff --git a/src/ge/graph/passes/compile_nodes_pass.h b/src/ge/graph/passes/compile_nodes_pass.h index 70f8cbf5..56df7b87 100644 --- a/src/ge/graph/passes/compile_nodes_pass.h +++ b/src/ge/graph/passes/compile_nodes_pass.h @@ -19,9 +19,6 @@ #include #include -#include -#include - #include "inc/graph_pass.h" #include "init/gelib.h" diff --git a/src/ge/graph/passes/constant_fuse_same_pass.cc b/src/ge/graph/passes/constant_fuse_same_pass.cc index 69726e5d..f3ef6352 100644 --- a/src/ge/graph/passes/constant_fuse_same_pass.cc +++ b/src/ge/graph/passes/constant_fuse_same_pass.cc @@ -29,6 +29,9 @@ #include "graph/utils/op_desc_utils.h" #include "graph/utils/type_utils.h" +using domi::CONSTANT; +using domi::CONSTANTOP; + namespace ge { namespace { const size_t kCorrectNum = 1; diff --git a/src/ge/graph/passes/control_op_attr_pass.cc b/src/ge/graph/passes/control_op_attr_pass.cc index 7afa34a4..983f22f1 100644 --- a/src/ge/graph/passes/control_op_attr_pass.cc +++ b/src/ge/graph/passes/control_op_attr_pass.cc @@ -30,6 +30,12 @@ #include "graph/utils/graph_utils.h" #include "init/gelib.h" +using domi::ATTR_NAME_STREAM_LABEL; + +using domi::STREAMACTIVE; +using domi::STREAMSWITCH; +using domi::STREAMSWITCHN; + namespace { const uint32_t kMaxNodeNum = 350; } // namespace diff --git a/src/ge/graph/passes/control_trigger_pass.cc b/src/ge/graph/passes/control_trigger_pass.cc index ee2198af..a13a84b9 100644 --- a/src/ge/graph/passes/control_trigger_pass.cc +++ b/src/ge/graph/passes/control_trigger_pass.cc @@ -27,6 +27,19 @@ #include "graph/debug/ge_attr_define.h" #include "graph/utils/type_utils.h" +using domi::ATTR_NAME_WEIGHTS; + +using domi::CONSTANT; +using domi::CONTROLTRIGGER; +using domi::ENTER; +using domi::IDENTITY; +using domi::LOOPCOND; +using domi::MERGE; +using domi::REFENTER; +using domi::REFMERGE; +using domi::REFSWITCH; +using domi::SWITCH; + namespace ge { Status ControlTriggerPass::Run(ComputeGraphPtr graph) { GELOGD("ControlTriggerPass Enter"); diff --git a/src/ge/graph/passes/control_trigger_pass.h b/src/ge/graph/passes/control_trigger_pass.h index b9fff9b4..39ee515d 100644 --- a/src/ge/graph/passes/control_trigger_pass.h +++ b/src/ge/graph/passes/control_trigger_pass.h @@ -49,4 +49,4 @@ class ControlTriggerPass : public GraphPass { std::unordered_map>> control_trigger_map_; }; } // namespace ge -#endif // GE_GRAPH_PASSES_CONTROL_TRIGGER_PASS_H_ +#endif // GE_GRAPH_PASSES_CONTROL_TRIGGER_PASS_H_ \ No newline at end of file diff --git a/src/ge/graph/passes/dropout_pass.cc b/src/ge/graph/passes/dropout_pass.cc index ab88aa23..f1be5ba0 100644 --- a/src/ge/graph/passes/dropout_pass.cc +++ b/src/ge/graph/passes/dropout_pass.cc @@ -39,7 +39,7 @@ Status DropOutPass::Run(NodePtr &node) { return PARAM_INVALID; } std::string op_type = node->GetOpDesc()->GetType(); - if (op_type == DROPOUT) { + if (op_type == domi::DROPOUT) { GELOGD("op type is dropout."); return IsolateAndDeleteNode(node, {0}); } diff --git a/src/ge/graph/passes/end_graph_pass.cc b/src/ge/graph/passes/end_graph_pass.cc index 0a2790a8..8cd5c176 100644 --- a/src/ge/graph/passes/end_graph_pass.cc +++ b/src/ge/graph/passes/end_graph_pass.cc @@ -29,6 +29,10 @@ #include "common/ge/ge_util.h" #include "graph/debug/ge_attr_define.h" +using domi::ENDGRAPH; +using domi::NODE_NAME_END_GRAPH; +using domi::NODE_NAME_NET_OUTPUT; + namespace ge { Status EndGraphPass::Run(ge::ComputeGraphPtr graph) { GELOGI("EndGraphPass Run."); @@ -53,7 +57,7 @@ Status EndGraphPass::Run(ge::ComputeGraphPtr graph) { OpDescPtr op_desc = MakeShared(NODE_NAME_END_GRAPH, ENDGRAPH); GE_CHECK_NOTNULL(op_desc); GELOGI("Create EndGraph op:%s.", op_desc->GetName().c_str()); - (void) AttrUtils::SetListStr(op_desc, ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, std::move(std::vector())); + (void)AttrUtils::SetListStr(op_desc, ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, std::move(std::vector())); NodePtr end_graph_node = graph->AddNode(op_desc); if (end_graph_node == nullptr) { GELOGI("Add EndGraph:%s node to Graph fail.", op_desc->GetName().c_str()); @@ -69,4 +73,3 @@ Status EndGraphPass::Run(ge::ComputeGraphPtr graph) { return SUCCESS; } } // namespace ge - diff --git a/src/ge/graph/passes/enter_pass.cc b/src/ge/graph/passes/enter_pass.cc index bead855a..af3e4739 100644 --- a/src/ge/graph/passes/enter_pass.cc +++ b/src/ge/graph/passes/enter_pass.cc @@ -23,6 +23,11 @@ #include "framework/common/ge_inner_error_codes.h" #include "graph/utils/graph_utils.h" +using domi::CONSTANT; +using domi::CONSTANTOP; +using domi::ENTER; +using domi::REFENTER; + namespace ge { Status EnterPass::Run(NodePtr &node) { GELOGD("EnterPass running"); @@ -50,17 +55,16 @@ Status EnterPass::Run(NodePtr &node) { return SUCCESS; } - bool need_remove_flag = in_node->GetInControlNodes().empty() && - node->GetInControlNodes().empty() && - node->GetOutDataNodes().empty(); + bool need_remove_flag = + in_node->GetInControlNodes().empty() && node->GetInControlNodes().empty() && node->GetOutDataNodes().empty(); if (need_remove_flag) { for (auto &out_ctrl_node : node->GetOutControlNodes()) { if (out_ctrl_node == nullptr) { continue; } if (GraphUtils::RemoveEdge(node->GetOutControlAnchor(), out_ctrl_node->GetInControlAnchor()) != GRAPH_SUCCESS) { - GELOGE(FAILED, "Remove Enter ctrl output fail, %s->%s", - node->GetName().c_str(), out_ctrl_node->GetName().c_str()); + GELOGE(FAILED, "Remove Enter ctrl output fail, %s->%s", node->GetName().c_str(), + out_ctrl_node->GetName().c_str()); return FAILED; } } diff --git a/src/ge/graph/passes/flow_ctrl_pass.cc b/src/ge/graph/passes/flow_ctrl_pass.cc index d144351d..6e933708 100644 --- a/src/ge/graph/passes/flow_ctrl_pass.cc +++ b/src/ge/graph/passes/flow_ctrl_pass.cc @@ -29,6 +29,23 @@ namespace ge { // when namespace change to ge, please delete the using code. +using domi::NODE_NAME_FLOWCTRL_LOOP_ASSIGN; +using domi::NODE_NAME_FLOWCTRL_LOOP_ASSIGNADD; +using domi::NODE_NAME_FLOWCTRL_LOOP_COND; +using domi::NODE_NAME_FLOWCTRL_LOOP_INCREMENT; +using domi::NODE_NAME_FLOWCTRL_LOOP_PER_ITER; +using domi::NODE_NAME_FLOWCTRL_LOOP_RESETVALUE; +using domi::NODE_NAME_STREAM_SWITCH; + +using domi::ATTR_NAME_STREAM_CYCLE_EVENT_FLAG; +using domi::TRUE_STREAM_ID; + +using domi::ASSIGN; +using domi::ASSIGNADD; +using domi::STREAMACTIVE; +using domi::STREAMSWITCH; +using domi::VARIABLE; + Status FlowCtrlPass::Run(ComputeGraphPtr compute_graph) { GE_CHECK_NOTNULL(compute_graph); @@ -188,9 +205,9 @@ NodePtr FlowCtrlPass::AddVariableNode(ComputeGraphPtr &compute_graph, const stri } Status FlowCtrlPass::AddGlobalStepVariableNode(ComputeGraphPtr &compute_graph) { - NodePtr output_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT); + NodePtr output_node = compute_graph->FindNode(domi::NODE_NAME_NET_OUTPUT); if (output_node == nullptr) { - GELOGD("Node %s can't be found in graph %u", NODE_NAME_NET_OUTPUT.c_str(), compute_graph->GetGraphID()); + GELOGD("Node %s can't be found in graph %u", domi::NODE_NAME_NET_OUTPUT.c_str(), compute_graph->GetGraphID()); return SUCCESS; } @@ -203,16 +220,17 @@ Status FlowCtrlPass::AddGlobalStepVariableNode(ComputeGraphPtr &compute_graph) { return SUCCESS; } - NodePtr exist_node = compute_graph->FindNode(NODE_NAME_GLOBAL_STEP); + NodePtr exist_node = compute_graph->FindNode(domi::NODE_NAME_GLOBAL_STEP); if (exist_node != nullptr) { - GELOGD("Node %s already exist, no need add.", NODE_NAME_GLOBAL_STEP.c_str()); + GELOGD("Node %s already exist, no need add.", domi::NODE_NAME_GLOBAL_STEP.c_str()); return SUCCESS; } // set global step tensor desc GeTensorDesc tensor_desc(GeShape({1}), FORMAT_ND, DT_UINT64); std::vector input_desc_list = {}; std::vector output_desc_list = {tensor_desc}; - NodePtr global_step = InsertOp(compute_graph, VARIABLE, NODE_NAME_GLOBAL_STEP, input_desc_list, output_desc_list); + NodePtr global_step = + InsertOp(compute_graph, VARIABLE, domi::NODE_NAME_GLOBAL_STEP, input_desc_list, output_desc_list); if (global_step == nullptr) { GELOGE(FAILED, "Add global_step node failed, global_step is null."); return FAILED; diff --git a/src/ge/graph/passes/folding_kernel/add_kernel.cc b/src/ge/graph/passes/folding_kernel/add_kernel.cc index 89f99938..2d786a4a 100644 --- a/src/ge/graph/passes/folding_kernel/add_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/add_kernel.cc @@ -23,6 +23,7 @@ #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" +using domi::ADD; namespace ge { namespace { const size_t kAddFirstInput = 0; diff --git a/src/ge/graph/passes/folding_kernel/add_kernel.h b/src/ge/graph/passes/folding_kernel/add_kernel.h index f8fd272e..218bc12a 100644 --- a/src/ge/graph/passes/folding_kernel/add_kernel.h +++ b/src/ge/graph/passes/folding_kernel/add_kernel.h @@ -38,4 +38,4 @@ class AddKernel : public Kernel { std::vector &v_output); }; } // namespace ge -#endif // GE_GRAPH_PASSES_FOLDING_KERNEL_ADD_KERNEL_H_ +#endif // GE_GRAPH_PASSES_FOLDING_KERNEL_ADD_KERNEL_H_ \ No newline at end of file diff --git a/src/ge/graph/passes/folding_kernel/broadcast_args_kernel.cc b/src/ge/graph/passes/folding_kernel/broadcast_args_kernel.cc index 364fb415..212fd419 100644 --- a/src/ge/graph/passes/folding_kernel/broadcast_args_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/broadcast_args_kernel.cc @@ -26,6 +26,8 @@ #include "graph/passes/pass_utils.h" #include "inc/kernel_factory.h" +using domi::BROADCASTARGS; + namespace ge { namespace { const size_t kBCastArgsInputsSize = 2; diff --git a/src/ge/graph/passes/folding_kernel/broadcast_gradient_args_kernel.cc b/src/ge/graph/passes/folding_kernel/broadcast_gradient_args_kernel.cc index 5fd5d576..826d3471 100644 --- a/src/ge/graph/passes/folding_kernel/broadcast_gradient_args_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/broadcast_gradient_args_kernel.cc @@ -27,6 +27,8 @@ #include "graph/passes/pass_utils.h" #include "inc/kernel_factory.h" +using domi::BROADCASTGRADIENTARGS; + namespace ge { namespace { const size_t kBCastGradArgsInputsSize = 2; @@ -42,12 +44,13 @@ Status BroadcastGradientArgsKernel::Compute(const OpDescPtr op_desc_ptr, const s } // check input size bool size_check_fail = - (op_desc_ptr->GetAllInputsDesc().size() != kBCastGradArgsInputsSize || input.size() != kBCastGradArgsInputsSize || - op_desc_ptr->GetAllOutputsDesc().size() != kBCastGradArgsOutputsSize); + (op_desc_ptr->GetAllInputsDesc().size() != kBCastGradArgsInputsSize || input.size() != kBCastGradArgsInputsSize || + op_desc_ptr->GetAllOutputsDesc().size() != kBCastGradArgsOutputsSize); if (size_check_fail) { - GELOGW("input/output size error. InDesc size:%zu," - "OutDesc size:%zu, in size:%zu ", - op_desc_ptr->GetAllInputsDesc().size(), op_desc_ptr->GetAllOutputsDesc().size(), input.size()); + GELOGW( + "input/output size error. InDesc size:%zu," + "OutDesc size:%zu, in size:%zu ", + op_desc_ptr->GetAllInputsDesc().size(), op_desc_ptr->GetAllOutputsDesc().size(), input.size()); return NOT_CHANGED; } diff --git a/src/ge/graph/passes/folding_kernel/cast_kernel.cc b/src/ge/graph/passes/folding_kernel/cast_kernel.cc index 99944c20..54634737 100644 --- a/src/ge/graph/passes/folding_kernel/cast_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/cast_kernel.cc @@ -33,6 +33,11 @@ #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" +using domi::CAST; +using domi::PARAM_INVALID; +using domi::Status; +using domi::SUCCESS; + namespace ge { namespace { const size_t kCastInputSize = 1; @@ -49,11 +54,9 @@ Status CastKernel::Compute(const OpDescPtr op_desc_ptr, const std::vectorGetData().data(); - // src_data == nullptr is supported - if (op_desc_ptr == nullptr) { - GELOGE(PARAM_INVALID, "Parameter's invalid, Input opDescPtr is nullptr."); + if (op_desc_ptr == nullptr || src_data == nullptr) { + GELOGE(PARAM_INVALID, "Parameter's invalid, Input opDescPtr or src_data is nullptr."); return PARAM_INVALID; } GeTensorDesc op_desc = op_desc_ptr->GetOutputDesc(0); @@ -75,7 +78,7 @@ Status CastKernel::Compute(const OpDescPtr op_desc_ptr, const std::vectorGetData().GetSize() == 0 is supported + GE_CHECK_SIZE(const_weight_ptr->GetData().GetSize()); auto src_data_size = src_shape.GetShapeSize(); if (src_data_size == 0 && static_cast(const_weight_ptr->GetData().GetSize()) == GetSizeByDataType(src_data_type)) { @@ -115,6 +118,7 @@ Status CastKernel::Compute(const OpDescPtr op_desc_ptr, const std::vectorSetData(trans_result.data.get(), trans_result.length) != SUCCESS) { GELOGW("Compute: SetData failed"); + return FAILED; } v_output.push_back(output_ptr); return SUCCESS; diff --git a/src/ge/graph/passes/folding_kernel/concat_offset_kernel.cc b/src/ge/graph/passes/folding_kernel/concat_offset_kernel.cc index da552081..2217c58e 100644 --- a/src/ge/graph/passes/folding_kernel/concat_offset_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/concat_offset_kernel.cc @@ -25,6 +25,8 @@ #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" +using domi::CONCATOFFSET; + namespace ge { namespace { const size_t kConcatOffsetInputIndexZero = 0; @@ -100,4 +102,4 @@ Status ConcatOffsetKernel::Compute(const OpDescPtr op_desc_ptr, const vector &input, GeTen auto output_size = merged_shape.GetShapeSize(); int64_t data_size = GetSizeByDataType(data_type); auto step = merged_shape.GetDim(kMergedShapeSecondDim); - if (!CheckInt64MulOverflow(output_size, data_size) || !CheckInt64MulOverflow(step, data_size)) { + if (!domi::CheckInt64MulOverflow(output_size, data_size) || !domi::CheckInt64MulOverflow(step, data_size)) { GELOGW("Check int64 mul overflow failed. Output_size is %ld, data_size is %ld, step is %ld.", output_size, data_size, step); return NOT_CHANGED; @@ -193,7 +196,7 @@ Status DynamicStitchKernel::StitchDataFollowIndices(int64_t data_unit, const vec allowance += data_unit; } indices_set.insert(input_indices[j]); - if (!CheckInt64MulOverflow(input_indices[j], data_unit)) { + if (!domi::CheckInt64MulOverflow(input_indices[j], data_unit)) { GELOGW("Check int64 mul overflow failed. Indices is %ld, data_unit is %ld.", input_indices[j], data_unit); return NOT_CHANGED; } diff --git a/src/ge/graph/passes/folding_kernel/empty_kernel.cc b/src/ge/graph/passes/folding_kernel/empty_kernel.cc index 1b135b9c..6d882ef9 100644 --- a/src/ge/graph/passes/folding_kernel/empty_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/empty_kernel.cc @@ -28,6 +28,8 @@ #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" +using domi::EMPTY; + namespace ge { namespace { const size_t kEmptyFirstInput = 0; diff --git a/src/ge/graph/passes/folding_kernel/expanddims_kernel.cc b/src/ge/graph/passes/folding_kernel/expanddims_kernel.cc index 3d999a02..6abf6cfb 100644 --- a/src/ge/graph/passes/folding_kernel/expanddims_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/expanddims_kernel.cc @@ -25,6 +25,8 @@ #include "graph/passes/folding_kernel/kernel_utils.h" #include "inc/kernel_factory.h" +using domi::EXPANDDIMS; + namespace ge { namespace { const int kExpandDimsIndexZero = 0; @@ -50,8 +52,7 @@ Status ExpanddimsKernel::Compute(const NodePtr &node_ptr) { GELOGI("Expanddims dimension kernel success."); return SUCCESS; } -Status ExpanddimsKernel::Compute(const ge::OpDescPtr op_desc_ptr, - const std::vector &input, +Status ExpanddimsKernel::Compute(const ge::OpDescPtr op_desc_ptr, const std::vector &input, std::vector &v_output) { GELOGI("Expanddims folding kernel in."); if (op_desc_ptr == nullptr) { diff --git a/src/ge/graph/passes/folding_kernel/fill_kernel.cc b/src/ge/graph/passes/folding_kernel/fill_kernel.cc index 514a84c9..8c453e74 100644 --- a/src/ge/graph/passes/folding_kernel/fill_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/fill_kernel.cc @@ -27,6 +27,7 @@ #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" +using domi::FILL; using ge::fp16_t; using ge::Status; diff --git a/src/ge/graph/passes/folding_kernel/floordiv_kernel.cc b/src/ge/graph/passes/folding_kernel/floordiv_kernel.cc index 81595822..dc9602bb 100644 --- a/src/ge/graph/passes/folding_kernel/floordiv_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/floordiv_kernel.cc @@ -28,6 +28,8 @@ #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" +using domi::FLOORDIV; + namespace ge { namespace { const size_t kFloorDivInputX = 0; diff --git a/src/ge/graph/passes/folding_kernel/floordiv_kernel.h b/src/ge/graph/passes/folding_kernel/floordiv_kernel.h index c8505731..a692ff67 100644 --- a/src/ge/graph/passes/folding_kernel/floordiv_kernel.h +++ b/src/ge/graph/passes/folding_kernel/floordiv_kernel.h @@ -47,4 +47,4 @@ class FloorDivKernel : public Kernel { }; } // namespace ge -#endif // GE_GRAPH_PASSES_FOLDING_KERNEL_FLOORDIV_KERNEL_H_ +#endif // GE_GRAPH_PASSES_FOLDING_KERNEL_FLOORDIV_KERNEL_H_ \ No newline at end of file diff --git a/src/ge/graph/passes/folding_kernel/floormod_kernel.cc b/src/ge/graph/passes/folding_kernel/floormod_kernel.cc index d7fb3b1c..a7fbf1e3 100644 --- a/src/ge/graph/passes/folding_kernel/floormod_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/floormod_kernel.cc @@ -27,6 +27,8 @@ #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" +using domi::FLOORMOD; + namespace ge { namespace { const size_t kFloorModInputX = 0; diff --git a/src/ge/graph/passes/folding_kernel/gather_v2_kernel.cc b/src/ge/graph/passes/folding_kernel/gather_v2_kernel.cc index 92c9e035..916708f1 100644 --- a/src/ge/graph/passes/folding_kernel/gather_v2_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/gather_v2_kernel.cc @@ -29,6 +29,7 @@ #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" +using domi::GATHERV2; using ge::fp16_t; namespace ge { @@ -176,7 +177,7 @@ Status GatherV2Kernel::GenData(const int64_t data_num, ConstGeTensorPtr tensor_x if (data_num <= 0) { return PARAM_INVALID; } - if (!CheckInt64MulOverflow(data_num, sizeof(T))) { + if (!domi::CheckInt64MulOverflow(data_num, sizeof(T))) { GELOGE(PARAM_INVALID, "Int64MulOverflow, data_num:%ld, type_len:%zu.", data_num, sizeof(T)); return PARAM_INVALID; } @@ -220,7 +221,7 @@ Status GatherV2Kernel::CalcStride(std::vector &stride, std::vector= 0) { size_t index = static_cast(i) + kGatherV2DimOne; - if (!CheckInt64MulOverflow(stride[index], dims[index])) { + if (!domi::CheckInt64MulOverflow(stride[index], dims[index])) { GELOGE(PARAM_INVALID, "Int64MulOverflow, data_num(%ld) type_len(%ld)", stride[index], dims[index]); return PARAM_INVALID; } diff --git a/src/ge/graph/passes/folding_kernel/greater_kernel.cc b/src/ge/graph/passes/folding_kernel/greater_kernel.cc index 4b4caa3a..944cd1b2 100644 --- a/src/ge/graph/passes/folding_kernel/greater_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/greater_kernel.cc @@ -29,6 +29,7 @@ #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" +using domi::GREATER; using domi::Status; using domi::SUCCESS; using ge::fp16_t; diff --git a/src/ge/graph/passes/folding_kernel/kernel_utils.cc b/src/ge/graph/passes/folding_kernel/kernel_utils.cc index 2002643a..9448b232 100644 --- a/src/ge/graph/passes/folding_kernel/kernel_utils.cc +++ b/src/ge/graph/passes/folding_kernel/kernel_utils.cc @@ -113,26 +113,12 @@ bool KernelUtils::CheckSizeForTransOp(const ge::ConstGeTensorPtr &const_weight_p GELOGI("Const real value Size:%zu, op_desc Shape Size:%ld, data_type:%s.", data_size, cal_size, TypeUtils::DataTypeToSerialString(data_type).c_str()); - if (shape_size != 0) { - // Standard tensor - if (data_size != static_cast(cal_size) || data_size == 0) { - GELOGW("Const input data size is not equal with tensor desc shape"); - return false; - } - } else if (data_shape.GetDimNum() != 0) { - // Empty tensor, has zero in shape vector - if (data_size != 0) { - GELOGW("Const input data size is not equal with tensor desc shape"); - return false; - } - } else { - // Scalar tensor, has only one element in tensor - if (length != 0 && (data_size / static_cast(length) != 1)) { + if ((shape_size != 0) || (length != 0 && (data_size / static_cast(length) != 1))) { + if (!(data_size == static_cast(cal_size) && data_size != 0)) { GELOGW("Const input data size is not equal with tensor desc shape"); return false; } } - return true; } diff --git a/src/ge/graph/passes/folding_kernel/kernel_utils.h b/src/ge/graph/passes/folding_kernel/kernel_utils.h index 17b645aa..05f201e9 100644 --- a/src/ge/graph/passes/folding_kernel/kernel_utils.h +++ b/src/ge/graph/passes/folding_kernel/kernel_utils.h @@ -29,7 +29,6 @@ namespace ge { class KernelUtils { public: KernelUtils() = delete; - ~KernelUtils() = delete; static Status CheckDimensionNodeInfo(const NodePtr &node_ptr); static bool CheckFormatSupported(const NodePtr &node_ptr); static bool CheckSizeForTransOp(const ConstGeTensorPtr &const_weight_ptr, const OpDescPtr &op_desc_ptr); @@ -45,7 +44,7 @@ class KernelUtils { template static Status GenData(const int64_t data_num, const T value, const GeTensorPtr &output) { if (data_num > 0) { - if (!CheckInt64MulOverflow(data_num, static_cast(sizeof(T)))) { + if (!domi::CheckInt64MulOverflow(data_num, static_cast(sizeof(T)))) { GELOGE(PARAM_INVALID, "Int64MulOverflow, data_num(%ld) type_len(%zu)", data_num, sizeof(T)); return PARAM_INVALID; } @@ -93,7 +92,7 @@ class KernelUtils { vec_dim.clear(); break; } - if (!CheckInt64MulOverflow(data_num, dim)) { + if (!domi::CheckInt64MulOverflow(data_num, dim)) { GELOGE(PARAM_INVALID, "Int64MulOverflow, data_num(%ld) dim(%ld)", data_num, static_cast(dim)); return PARAM_INVALID; } diff --git a/src/ge/graph/passes/folding_kernel/maximum_kernel.cc b/src/ge/graph/passes/folding_kernel/maximum_kernel.cc index 9dd84f0a..89b3b159 100644 --- a/src/ge/graph/passes/folding_kernel/maximum_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/maximum_kernel.cc @@ -29,6 +29,7 @@ #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" +using domi::MAXIMUM; using ge::fp16_t; namespace ge { diff --git a/src/ge/graph/passes/folding_kernel/mul_kernel.cc b/src/ge/graph/passes/folding_kernel/mul_kernel.cc index 4ca740d1..4b1984e2 100644 --- a/src/ge/graph/passes/folding_kernel/mul_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/mul_kernel.cc @@ -29,6 +29,8 @@ #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" +using domi::MUL; + namespace ge { namespace { const std::set kMulSupportedType = {DT_INT8, DT_INT16, DT_INT32, DT_INT64, DT_UINT8, DT_UINT16, diff --git a/src/ge/graph/passes/folding_kernel/pack_kernel.cc b/src/ge/graph/passes/folding_kernel/pack_kernel.cc index 5db3b394..f9587771 100644 --- a/src/ge/graph/passes/folding_kernel/pack_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/pack_kernel.cc @@ -29,6 +29,7 @@ #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" +using domi::PACK; namespace { const int64_t kShapeItemNumMAX = 2000000000; } // namespace @@ -67,8 +68,8 @@ Status PackKernel::ValidateKernelParams(const ge::OpDescPtr &op_desc_ptr, return PARAM_INVALID; } if (!(AttrUtils::GetInt(op_desc_ptr, PACK_ATTR_NAME_NUM, n_))) { - n_ = 0; - GELOGD("Attr %s is not set, default value %ld is used.", PACK_ATTR_NAME_NUM.c_str(), n_); + GELOGE(PARAM_INVALID, "Attr %s is not exist.", PACK_ATTR_NAME_NUM.c_str()); + return PARAM_INVALID; } if (!(AttrUtils::GetInt(op_desc_ptr, ATTR_NAME_AXIS, axis_))) { GELOGE(PARAM_INVALID, "Attr %s is not exist.", ATTR_NAME_AXIS.c_str()); @@ -105,7 +106,11 @@ Status PackKernel::ValidateInputs(const ge::OpDescPtr &op_desc_ptr, const std::v GELOGW("Input %ld of pack kernel %s is null.", i, op_desc_ptr->GetName().c_str()); return PARAM_INVALID; } - + // check if tensor contains data + if (input[i]->GetData().size() == 0) { + GELOGW("Inputs %ld do not have value.", i); + return NOT_CHANGED; + } if (i == 0) { // get first input shape shape = input[0]->GetTensorDesc().GetShape(); @@ -123,8 +128,8 @@ Status PackKernel::ValidateInputs(const ge::OpDescPtr &op_desc_ptr, const std::v auto dst_shape = tensor_desc.GetShape(); int64_t num = 1; for (auto dim : dst_shape.GetDims()) { - if (dim < 0) { - GELOGW("Invalid dim ld% in the shape %s", dim, formats::ShapeToString(shape).c_str()); + if (dim < 1) { + GELOGW("Invalid zero dim in the shape %s", formats::ShapeToString(shape).c_str()); return NOT_CHANGED; } num *= dim; @@ -137,12 +142,6 @@ Status PackKernel::ValidateInputs(const ge::OpDescPtr &op_desc_ptr, const std::v GELOGW("Shape of input %ld is not equal wiht input 0.", i); return NOT_CHANGED; } - - // check tensor data size is zero ot not - if (input[i]->GetData().size() == 0 && num != 0) { - GELOGW("Inputs %ld do not have value.", i); - return NOT_CHANGED; - } } return SUCCESS; } @@ -169,13 +168,6 @@ void PackKernel::ExpandDims(const int64_t axis, const std::vector &input, ge::GeTensorPtr &output_ptr) { - output_ptr->MutableTensorDesc().SetShape(final_shape); - output_ptr->MutableTensorDesc().SetDataType(DataType(data_type_)); - if (final_shape.GetShapeSize() == 0 && final_shape.GetDims().size() != 0) { - // means has zero in shape list, output tnesor data is []. - return SUCCESS; - } - int64_t times = 1; int64_t unit = 1; // calculate data unit @@ -219,6 +211,8 @@ Status PackKernel::CopyOutputData(const GeShape &final_shape, const std::vector< if (output_ptr->SetData(buf.get(), static_cast(output_size * data_size)) != GRAPH_SUCCESS) { GELOGW("CopyOutputData: SetData failed"); } + output_ptr->MutableTensorDesc().SetShape(final_shape); + output_ptr->MutableTensorDesc().SetDataType(DataType(data_type_)); return SUCCESS; } diff --git a/src/ge/graph/passes/folding_kernel/permute_kernel.cc b/src/ge/graph/passes/folding_kernel/permute_kernel.cc index 551ef59e..ce0737f0 100644 --- a/src/ge/graph/passes/folding_kernel/permute_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/permute_kernel.cc @@ -33,6 +33,13 @@ #include "graph/passes/folding_kernel/kernel_utils.h" #include "framework/common/ge_inner_error_codes.h" +using domi::PARAM_INVALID; +using domi::PERMUTE; +using domi::Status; +using domi::SUCCESS; +using domi::TRANSPOSE; +using domi::TRANSPOSED; + namespace ge { namespace { const char *const kAttrOrder = "order"; diff --git a/src/ge/graph/passes/folding_kernel/range_kernel.cc b/src/ge/graph/passes/folding_kernel/range_kernel.cc index 8bcfa254..c284deab 100644 --- a/src/ge/graph/passes/folding_kernel/range_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/range_kernel.cc @@ -28,6 +28,8 @@ #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" +using domi::RANGE; + namespace ge { namespace { constexpr size_t kRangeInputNum = 3; diff --git a/src/ge/graph/passes/folding_kernel/rank_kernel.cc b/src/ge/graph/passes/folding_kernel/rank_kernel.cc index ae14354b..03531654 100644 --- a/src/ge/graph/passes/folding_kernel/rank_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/rank_kernel.cc @@ -25,6 +25,7 @@ #include "inc/kernel_factory.h" #include "omg/omg_inner_types.h" +using domi::RANK; using ge::Status; namespace { diff --git a/src/ge/graph/passes/folding_kernel/reduce_prod_kernel.cc b/src/ge/graph/passes/folding_kernel/reduce_prod_kernel.cc index b7fd11b1..8ffbb9b8 100644 --- a/src/ge/graph/passes/folding_kernel/reduce_prod_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/reduce_prod_kernel.cc @@ -28,6 +28,8 @@ #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" +using domi::REDUCEPROD; + namespace ge { namespace { const size_t kReduceProdDataIndex = 0; @@ -63,7 +65,10 @@ Status ReduceProdKernel::ReduceProdCheck(const ge::OpDescPtr &op_desc_ptr, GELOGE(PARAM_INVALID, "Axis must be at most rank 1, node node: %s", op_desc_ptr->GetName().c_str()); return PARAM_INVALID; } - + if (data_tensor->GetData().size() == 0 || axis_tensor->GetData().size() == 0) { + GELOGE(PARAM_INVALID, "ReduceProdKernel data size of inputs is 0, node node: %s", op_desc_ptr->GetName().c_str()); + return PARAM_INVALID; + } DataType data_type = data_tensor->GetTensorDesc().GetDataType(); if (kReduceProdSupportedType.find(data_type) == kReduceProdSupportedType.end()) { GELOGE(PARAM_INVALID, "ReduceProdKernel data type %s not support, node name: %s", @@ -148,6 +153,7 @@ Status ReduceProdKernel::DataCal(const std::vector &input, static_cast(head_dim_ * end_dim_ * sizeof(int32_t))) != GRAPH_SUCCESS, GELOGW("set data failed"); return INTERNAL_ERROR); + output_ptr->MutableTensorDesc().SetDataType(data_dtype); } return SUCCESS; } @@ -256,32 +262,19 @@ Status ReduceProdKernel::Compute(const ge::OpDescPtr op_desc_ptr, const std::vec if (ret != SUCCESS) { return NOT_CHANGED; } - } else if (input.at(kReduceProdAxisIndex)->GetData().size() == 0) { - // axis tensor value is [], means no process for input - output_ptr->MutableTensorDesc().SetShape(input.at(kReduceProdDataIndex)->GetTensorDesc().GetShape()); - output_ptr->MutableTensorDesc().SetDataType(input.at(kReduceProdDataIndex)->GetTensorDesc().GetDataType()); - if (output_ptr->SetData(input.at(kReduceProdDataIndex)->GetData()) != GRAPH_SUCCESS) { - GELOGW("Compute: SetData failed"); - } } else { // calculate axis to reduce ret = AxisCal(input); if (ret != SUCCESS) { return NOT_CHANGED; } - // calculate and set shape - ShapeCal(op_desc_ptr, input, output_ptr); - // set data type - output_ptr->MutableTensorDesc().SetDataType(input.at(kReduceProdDataIndex)->GetTensorDesc().GetDataType()); - - // data size == 0 means input tensor has zero in shape, and tensor value is []. - if (input.at(kReduceProdDataIndex)->GetData().size() != 0) { - // calculate data and data type - ret = DataCal(input, output_ptr); - if (ret != SUCCESS) { - return NOT_CHANGED; - } + // calculate data and data type + ret = DataCal(input, output_ptr); + if (ret != SUCCESS) { + return NOT_CHANGED; } + // calculate shape + ShapeCal(op_desc_ptr, input, output_ptr); } // print output tensor information, and will be deleted diff --git a/src/ge/graph/passes/folding_kernel/reduce_prod_kernel.h b/src/ge/graph/passes/folding_kernel/reduce_prod_kernel.h index 326dd2f5..4b858b4a 100644 --- a/src/ge/graph/passes/folding_kernel/reduce_prod_kernel.h +++ b/src/ge/graph/passes/folding_kernel/reduce_prod_kernel.h @@ -42,4 +42,4 @@ class ReduceProdKernel : public Kernel { }; } // namespace ge -#endif // GE_GRAPH_PASSES_FOLDING_KERNEL_REDUCE_PROD_KERNEL_H_ +#endif // GE_GRAPH_PASSES_FOLDING_KERNEL_REDUCE_PROD_KERNEL_H_ \ No newline at end of file diff --git a/src/ge/graph/passes/folding_kernel/reformat_kernel.cc b/src/ge/graph/passes/folding_kernel/reformat_kernel.cc index 8829d4c4..1e43a073 100644 --- a/src/ge/graph/passes/folding_kernel/reformat_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/reformat_kernel.cc @@ -26,6 +26,8 @@ #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" +using domi::REFORMAT; + namespace ge { namespace { const size_t kReFormatInputSize = 1; diff --git a/src/ge/graph/passes/folding_kernel/reshape_kernel.cc b/src/ge/graph/passes/folding_kernel/reshape_kernel.cc index 4e925836..525b4e03 100644 --- a/src/ge/graph/passes/folding_kernel/reshape_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/reshape_kernel.cc @@ -23,6 +23,8 @@ #include "graph/passes/folding_kernel/kernel_utils.h" #include "inc/kernel_factory.h" +using domi::RESHAPE; + namespace ge { namespace { const int kReshapeDataIndex = 0; diff --git a/src/ge/graph/passes/folding_kernel/rsqrt_kernel.cc b/src/ge/graph/passes/folding_kernel/rsqrt_kernel.cc index 44da2bef..809578eb 100644 --- a/src/ge/graph/passes/folding_kernel/rsqrt_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/rsqrt_kernel.cc @@ -28,6 +28,10 @@ #include "graph/passes/folding_kernel/kernel_utils.h" #include "inc/kernel_factory.h" +using domi::PARAM_INVALID; +using domi::RSQRT; +using domi::SUCCESS; + namespace ge { namespace { const size_t kRsqrtInputSize = 1; diff --git a/src/ge/graph/passes/folding_kernel/shape_kernel.cc b/src/ge/graph/passes/folding_kernel/shape_kernel.cc index f7475b91..9cb005c9 100644 --- a/src/ge/graph/passes/folding_kernel/shape_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/shape_kernel.cc @@ -24,6 +24,8 @@ #include "graph/passes/pass_utils.h" #include "inc/kernel_factory.h" +using domi::SHAPE; + namespace ge { namespace { const size_t kShapeInputSize = 1; diff --git a/src/ge/graph/passes/folding_kernel/shape_n_kernel.cc b/src/ge/graph/passes/folding_kernel/shape_n_kernel.cc index 8ed546de..b7844876 100644 --- a/src/ge/graph/passes/folding_kernel/shape_n_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/shape_n_kernel.cc @@ -24,6 +24,8 @@ #include "graph/passes/pass_utils.h" #include "inc/kernel_factory.h" +using domi::SHAPEN; + namespace ge { Status ShapeNKernel::Compute(const NodePtr &node, std::vector &v_output) { GELOGD("ShapeN kernel in"); diff --git a/src/ge/graph/passes/folding_kernel/size_kernel.cc b/src/ge/graph/passes/folding_kernel/size_kernel.cc index 3b121ba4..8f9ef8dd 100644 --- a/src/ge/graph/passes/folding_kernel/size_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/size_kernel.cc @@ -30,6 +30,7 @@ #include "inc/kernel_factory.h" #include "omg/omg_inner_types.h" +using domi::SIZE; namespace ge { namespace { const size_t kSizeInputSize = 1; @@ -62,7 +63,7 @@ Status SizeKernel::Compute(const NodePtr &node, std::vector &v_outp int64_t size = 1; // Calculate the number of elements of the sensor for (int64_t dim : op_desc->GetInputDesc(0).GetShape().GetDims()) { - if (!CheckInt64MulOverflow(size, dim)) { + if (!domi::CheckInt64MulOverflow(size, dim)) { GELOGE(INTERNAL_ERROR, "int64 overflow!"); return INTERNAL_ERROR; } diff --git a/src/ge/graph/passes/folding_kernel/slice_d_kernel.cc b/src/ge/graph/passes/folding_kernel/slice_d_kernel.cc index 2660537a..aaac2b44 100644 --- a/src/ge/graph/passes/folding_kernel/slice_d_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/slice_d_kernel.cc @@ -26,6 +26,7 @@ #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" +using domi::SLICED; using ge::fp16_t; namespace ge { diff --git a/src/ge/graph/passes/folding_kernel/slice_kernel.cc b/src/ge/graph/passes/folding_kernel/slice_kernel.cc index a1250367..30baa934 100644 --- a/src/ge/graph/passes/folding_kernel/slice_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/slice_kernel.cc @@ -25,6 +25,8 @@ #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" +using domi::SLICE; + namespace ge { namespace { const size_t kSliceInputSize = 3; diff --git a/src/ge/graph/passes/folding_kernel/squeeze_kernel.cc b/src/ge/graph/passes/folding_kernel/squeeze_kernel.cc index b253f9a9..dec5db50 100644 --- a/src/ge/graph/passes/folding_kernel/squeeze_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/squeeze_kernel.cc @@ -23,6 +23,8 @@ #include "graph/passes/folding_kernel/kernel_utils.h" #include "inc/kernel_factory.h" +using domi::SQUEEZE; + namespace { constexpr uint32_t kInputDescIndex = 0; constexpr uint32_t kOutputDescIndex = 0; diff --git a/src/ge/graph/passes/folding_kernel/ssd_prior_box_kernel.cc b/src/ge/graph/passes/folding_kernel/ssd_prior_box_kernel.cc index 15985c5d..42e97a7e 100644 --- a/src/ge/graph/passes/folding_kernel/ssd_prior_box_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/ssd_prior_box_kernel.cc @@ -24,12 +24,25 @@ #include "common/math/math_util.h" #include "common/math_util.h" #include "common/types.h" +#include "framework/common/op/attr_define.h" #include "framework/common/util.h" -#include "graph/debug/ge_attr_define.h" #include "graph/passes/pass_utils.h" #include "graph/utils/attr_utils.h" #include "inc/kernel_factory.h" +using domi::NnSet; +using domi::SSD_PRIOR_BOX_ATTR_ASPECT_RATIO; +using domi::SSD_PRIOR_BOX_ATTR_CLIP; +using domi::SSD_PRIOR_BOX_ATTR_FLIP; +using domi::SSD_PRIOR_BOX_ATTR_IMG_H; +using domi::SSD_PRIOR_BOX_ATTR_IMG_W; +using domi::SSD_PRIOR_BOX_ATTR_MAX_SIZE; +using domi::SSD_PRIOR_BOX_ATTR_MIN_SIZE; +using domi::SSD_PRIOR_BOX_ATTR_OFFSET; +using domi::SSD_PRIOR_BOX_ATTR_STEP_H; +using domi::SSD_PRIOR_BOX_ATTR_STEP_W; +using domi::SSD_PRIOR_BOX_ATTR_VARIANCE; +using domi::SSDPRIORBOX; namespace ge { namespace { const float kMinistBias = 1e-6; diff --git a/src/ge/graph/passes/folding_kernel/strided_slice_kernel.cc b/src/ge/graph/passes/folding_kernel/strided_slice_kernel.cc index 224cf7a8..fa89249d 100644 --- a/src/ge/graph/passes/folding_kernel/strided_slice_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/strided_slice_kernel.cc @@ -27,6 +27,13 @@ #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" +using domi::STRIDE_SLICE_ATTR_BEGIN_MASK; +using domi::STRIDE_SLICE_ATTR_ELLIPSIS_MASK; +using domi::STRIDE_SLICE_ATTR_END_MASK; +using domi::STRIDE_SLICE_ATTR_NEW_AXIS_MASK; +using domi::STRIDE_SLICE_ATTR_SHRINK_AXIS_MASK; +using domi::STRIDEDSLICE; + namespace ge { namespace { const int32_t kNumOne = 1; diff --git a/src/ge/graph/passes/folding_kernel/sub_kernel.cc b/src/ge/graph/passes/folding_kernel/sub_kernel.cc index 5934c6c1..c02f78f6 100644 --- a/src/ge/graph/passes/folding_kernel/sub_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/sub_kernel.cc @@ -27,6 +27,7 @@ #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" +using domi::SUB; namespace ge { namespace { const size_t kSubFirstInput = 0; diff --git a/src/ge/graph/passes/folding_kernel/transdata_kernel.cc b/src/ge/graph/passes/folding_kernel/transdata_kernel.cc index d3637169..a5a9ccf4 100644 --- a/src/ge/graph/passes/folding_kernel/transdata_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/transdata_kernel.cc @@ -33,6 +33,11 @@ #include "graph/utils/type_utils.h" #include "inc/kernel_factory.h" +using domi::PARAM_INVALID; +using domi::Status; +using domi::SUCCESS; +using domi::TRANSDATA; + namespace ge { namespace { const size_t kTransdataInputSize = 1; @@ -48,9 +53,8 @@ Status TransdataKernel::ValidateInput(const OpDescPtr &op_desc_ptr, const std::v GELOGE(PARAM_INVALID, "Input const_weight_ptr is nullptr."); return PARAM_INVALID; } - - // src_data == nullptr is supported - if (op_desc_ptr == nullptr) { + const uint8_t *src_data = const_weight_ptr->GetData().data(); + if (op_desc_ptr == nullptr || src_data == nullptr) { GELOGE(PARAM_INVALID, "Input opDescPtr is nullptr."); return PARAM_INVALID; } diff --git a/src/ge/graph/passes/folding_kernel/unpack_kernel.cc b/src/ge/graph/passes/folding_kernel/unpack_kernel.cc index 92ad140a..985f822b 100644 --- a/src/ge/graph/passes/folding_kernel/unpack_kernel.cc +++ b/src/ge/graph/passes/folding_kernel/unpack_kernel.cc @@ -22,6 +22,8 @@ #include "graph/debug/ge_attr_define.h" #include "inc/kernel_factory.h" +using domi::UNPACK; + namespace ge { namespace { const size_t kUnpackInputNum = 1; diff --git a/src/ge/graph/passes/folding_pass.cc b/src/ge/graph/passes/folding_pass.cc index 41528ec3..dedf095d 100644 --- a/src/ge/graph/passes/folding_pass.cc +++ b/src/ge/graph/passes/folding_pass.cc @@ -20,7 +20,6 @@ #include #include #include -#include #include "framework/common/debug/ge_log.h" #include "graph/utils/graph_utils.h" @@ -28,6 +27,7 @@ #include "inc/kernel.h" #include "inc/kernel_factory.h" #include "graph/debug/ge_attr_define.h" +#include "framework/common/op/attr_define.h" #include "ge_local_engine/engine/host_cpu_engine.h" namespace ge { @@ -39,8 +39,8 @@ shared_ptr GetKernelByType(const NodePtr &node) { } KernelFactory &factory = KernelFactory::Instance(); string type = node->GetType(); - if (type == FRAMEWORKOP) { - if (!ge::AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type)) { + if (type == domi::FRAMEWORKOP) { + if (!ge::AttrUtils::GetStr(node->GetOpDesc(), domi::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type)) { return nullptr; } } @@ -49,7 +49,7 @@ shared_ptr GetKernelByType(const NodePtr &node) { } bool IsNoNeedConstantFolding(const NodePtr &node) { auto node_desc = node->GetOpDesc(); - return node_desc == nullptr || node_desc->HasAttr(ATTR_NO_NEED_CONSTANT_FOLDING); + return node_desc == nullptr || node_desc->HasAttr(domi::ATTR_NO_NEED_CONSTANT_FOLDING); } } // namespace folding_pass @@ -100,7 +100,7 @@ NodePtr AddIdentityNodeToGraph(const std::string &name, const GeTensorDesc &tens } desc->SetName(name); - desc->SetType(IDENTITY); + desc->SetType(domi::IDENTITY); auto ret = desc->AddInputDesc(tensor); auto ret2 = desc->AddOutputDesc(tensor); if ((ret != GRAPH_SUCCESS) || (ret2 != GRAPH_SUCCESS)) { @@ -170,7 +170,7 @@ Status FoldingPass::DealWithInNodes(NodePtr &node) { if (in_node == nullptr) { continue; } - if ((in_node->GetType() == SWITCH) || (in_node->GetType() == REFSWITCH)) { + if ((in_node->GetType() == domi::SWITCH) || (in_node->GetType() == domi::REFSWITCH)) { GELOGI("The in_node name is %s, and node type is %s.", in_node->GetName().c_str(), in_node->GetType().c_str()); auto ret = in_node_anchor->Unlink(in_data_anchor); if (ret != SUCCESS) { @@ -257,9 +257,9 @@ Status FoldingPass::AddConstNode(NodePtr &node, IndexsToAnchors indexes_to_ancho } GE_CHECK_NOTNULL(node->GetOpDesc()); std::string stream_label; - if (AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, stream_label)) { + if (AttrUtils::GetStr(node->GetOpDesc(), domi::ATTR_NAME_STREAM_LABEL, stream_label)) { GE_CHECK_NOTNULL(const_node->GetOpDesc()); - if (!AttrUtils::SetStr(const_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, stream_label)) { + if (!AttrUtils::SetStr(const_node->GetOpDesc(), domi::ATTR_NAME_STREAM_LABEL, stream_label)) { GELOGE(INTERNAL_ERROR, "Failed to set stream label on dynamic const node %s, with stream label:%s.", const_node->GetName().c_str(), stream_label.c_str()); return INTERNAL_ERROR; diff --git a/src/ge/graph/passes/get_original_format_pass.cc b/src/ge/graph/passes/get_original_format_pass.cc index 5b7e84c2..d065d581 100644 --- a/src/ge/graph/passes/get_original_format_pass.cc +++ b/src/ge/graph/passes/get_original_format_pass.cc @@ -26,11 +26,19 @@ #include "graph/utils/op_desc_utils.h" #include "framework/omg/omg_inner_types.h" +using domi::AIPP_DATA_TYPE; +using domi::ATTR_NAME_FORMAT; +using domi::ATTR_NAME_IGNORE_PRED_FORMAT; +using domi::ATTR_NAME_INFERRED_FORMAT; +using domi::BIASADD; +using domi::DATA_TYPE; using domi::DOMI_TENSOR_NCHW; using domi::DOMI_TENSOR_NHWC; using domi::DOMI_TENSOR_RESERVED; using domi::FAILED; using domi::PARAM_INVALID; +using domi::PERMUTE; +using domi::PERMUTE_ATTR_ORDER; using domi::SUCCESS; using domi::GetContext; @@ -62,8 +70,8 @@ Status GetOriginalFormatPass::SetOriginalFormat(const ge::ComputeGraphPtr &graph GE_CHECK_NOTNULL(desc_ptr); auto is_data = (desc_ptr->GetType() == DATA_TYPE || desc_ptr->GetType() == AIPP_DATA_TYPE); if (is_data) { - GELOGI("Data node: %s,format :%d", node_ptr->GetName().c_str(), domi::GetContext().format); - ori_format = static_cast(domi::GetContext().format); + GELOGI("Data node: %s,format :%d", node_ptr->GetName().c_str(), GetContext().format); + ori_format = static_cast(GetContext().format); GE_IF_BOOL_EXEC(!AttrUtils::SetInt(desc_ptr, ATTR_NAME_FORMAT, ori_format), GELOGE(FAILED, "set ATTR_NAME_FORMAT failed"); return FAILED); diff --git a/src/ge/graph/passes/guarantee_const_pass.cc b/src/ge/graph/passes/guarantee_const_pass.cc index f099c01d..8c34b8f5 100644 --- a/src/ge/graph/passes/guarantee_const_pass.cc +++ b/src/ge/graph/passes/guarantee_const_pass.cc @@ -25,6 +25,8 @@ #include "graph/utils/attr_utils.h" #include "graph/utils/graph_utils.h" +using domi::GUARANTEECONST; + namespace ge { namespace { const uint32_t kGuaranteeConstInputsSize = 1; diff --git a/src/ge/graph/passes/hccl_memcpy_pass.cc b/src/ge/graph/passes/hccl_memcpy_pass.cc index ac037d62..60001e30 100644 --- a/src/ge/graph/passes/hccl_memcpy_pass.cc +++ b/src/ge/graph/passes/hccl_memcpy_pass.cc @@ -25,6 +25,9 @@ #include "framework/common/types.h" #include "graph/utils/graph_utils.h" +using domi::CONSTANTOP; +using domi::DATA; + namespace { const int32_t kAnchorSize = 1; const int kAnchorNum = 0; @@ -53,7 +56,7 @@ Status HcclMemcpyPass::Run(ge::ComputeGraphPtr graph) { NodePtr src_node = src_out_anchor->GetOwnerNode(); std::string src_type = src_node->GetType(); bool check_src_type = (src_type == CONSTANTOP) || (src_type == DATA); - if (check_src_type && node->GetType() == HCOMALLREDUCE) { + if (check_src_type && node->GetType() == domi::HCOMALLREDUCE) { Status ret = ModifyEdgeConnection(graph, src_out_anchor, hccl_in_anchor); if (ret != SUCCESS) { GELOGE(INTERNAL_ERROR, "Failed to modify the connection."); @@ -88,9 +91,9 @@ NodePtr HcclMemcpyPass::CreateMemcpyNode(const ComputeGraphPtr &graph, const Out return nullptr; } - std::string node_name = pre_node->GetName() + "_" + MEMCPYASYNC; + std::string node_name = pre_node->GetName() + "_" + domi::MEMCPYASYNC; node_name = CheckDuplicateName(node_name); - OpDescPtr op_desc = MakeShared(node_name.c_str(), MEMCPYASYNC); + OpDescPtr op_desc = MakeShared(node_name.c_str(), domi::MEMCPYASYNC); if (op_desc == nullptr) { GELOGE(INTERNAL_ERROR, "Create MemcpyAsync op: MakeShared op_desc fail."); return nullptr; @@ -141,8 +144,8 @@ std::string HcclMemcpyPass::CheckDuplicateName(const std::string &node_name) { /// @return bool /// bool HcclMemcpyPass::NeedInsertMemcpyOp(const ge::ConstOpDescPtr &op_desc) const { - return (op_desc->GetType() == HCOMALLGATHER || op_desc->GetType() == HCOMALLREDUCE || - op_desc->GetType() == HCOMREDUCESCATTER); + return (op_desc->GetType() == domi::HCOMALLGATHER || op_desc->GetType() == domi::HCOMALLREDUCE || + op_desc->GetType() == domi::HCOMREDUCESCATTER); } /// diff --git a/src/ge/graph/passes/identity_pass.cc b/src/ge/graph/passes/identity_pass.cc index 9b15f77a..fa6ff4ed 100644 --- a/src/ge/graph/passes/identity_pass.cc +++ b/src/ge/graph/passes/identity_pass.cc @@ -23,6 +23,9 @@ #include "framework/common/ge_inner_error_codes.h" #include "graph/common/omg_util.h" +using domi::IDENTITY; +using domi::IDENTITYN; + namespace ge { namespace { /// @@ -38,7 +41,7 @@ Status CheckIdentityUsable(const NodePtr &node, bool &usable) { GELOGE(ret, "Failed to get node type from node %s", node->GetName().c_str()); return ret; } - if ((node_type != SWITCH) && (node_type != REFSWITCH)) { + if ((node_type != domi::SWITCH) && (node_type != domi::REFSWITCH)) { GELOGD("skip identity %s connected to switch", node->GetName().c_str()); break; } @@ -54,7 +57,7 @@ Status CheckIdentityUsable(const NodePtr &node, bool &usable) { GELOGE(ret, "Failed to get node type from node %s", node->GetName().c_str()); return ret; } - if ((node_type != MERGE) && (node_type != REFMERGE)) { + if ((node_type != domi::MERGE) && (node_type != domi::REFMERGE)) { GELOGD("skip identity %s connected to merge", node->GetName().c_str()); break; } diff --git a/src/ge/graph/passes/isolated_op_remove_pass.cc b/src/ge/graph/passes/isolated_op_remove_pass.cc index 152104eb..c7e52a64 100644 --- a/src/ge/graph/passes/isolated_op_remove_pass.cc +++ b/src/ge/graph/passes/isolated_op_remove_pass.cc @@ -20,6 +20,9 @@ #include "common/types.h" #include "common/util.h" +using domi::SUCCESS; +using domi::TO_BE_OUTPUT; + namespace ge { Status IsolatedOpRemovePass::Run(ge::ComputeGraphPtr graph) { GE_CHECK_NOTNULL(graph); diff --git a/src/ge/graph/passes/iterator_op_pass.cc b/src/ge/graph/passes/iterator_op_pass.cc index 35bfe496..d1fe211c 100644 --- a/src/ge/graph/passes/iterator_op_pass.cc +++ b/src/ge/graph/passes/iterator_op_pass.cc @@ -30,6 +30,8 @@ #include "graph/utils/graph_utils.h" #include "graph/passes/pass_utils.h" +using domi::MEMCPYASYNC; + namespace ge { const char *const kGetNext = "GetNext"; diff --git a/src/ge/graph/passes/link_gen_mask_nodes_pass.cc b/src/ge/graph/passes/link_gen_mask_nodes_pass.cc index ff150a54..62f8c57a 100644 --- a/src/ge/graph/passes/link_gen_mask_nodes_pass.cc +++ b/src/ge/graph/passes/link_gen_mask_nodes_pass.cc @@ -17,6 +17,8 @@ #include "graph/passes/link_gen_mask_nodes_pass.h" #include +#include +#include #include "common/ge_inner_error_codes.h" #include "framework/common/debug/ge_log.h" @@ -26,6 +28,10 @@ using std::set; using std::vector; +using domi::CONSTANT; +using domi::CONSTANTOP; +using domi::DROPOUTDOMASK; + namespace ge { namespace { const size_t kGenMaskInputIndex = 1; @@ -68,8 +74,8 @@ Status LinkGenMaskNodesPass::Run(ComputeGraphPtr graph) { auto dest_anchor = dest_node->GetInControlAnchor(); GE_CHECK_NOTNULL(dest_anchor); - graphStatus status_link_to = src_anchor->LinkTo(dest_anchor); - if (status_link_to != GRAPH_SUCCESS) { + graphStatus status = src_anchor->LinkTo(dest_anchor); + if (status != GRAPH_SUCCESS) { GELOGE(FAILED, "Link from %s to %s failed.", src_node->GetName().c_str(), dest_node->GetName().c_str()); return FAILED; } diff --git a/src/ge/graph/passes/link_gen_mask_nodes_pass.h b/src/ge/graph/passes/link_gen_mask_nodes_pass.h index f9979ab1..decc2d30 100644 --- a/src/ge/graph/passes/link_gen_mask_nodes_pass.h +++ b/src/ge/graph/passes/link_gen_mask_nodes_pass.h @@ -17,10 +17,6 @@ #ifndef GE_GRAPH_PASSES_LINK_GEN_MASK_NODES_PASS_H_ #define GE_GRAPH_PASSES_LINK_GEN_MASK_NODES_PASS_H_ -#include -#include -#include - #include "graph/graph.h" #include "inc/graph_pass.h" diff --git a/src/ge/graph/passes/merge_pass.cc b/src/ge/graph/passes/merge_pass.cc index 768e5369..96dbf37f 100644 --- a/src/ge/graph/passes/merge_pass.cc +++ b/src/ge/graph/passes/merge_pass.cc @@ -28,6 +28,8 @@ #include "graph/utils/graph_utils.h" #include "graph/passes/pass_utils.h" +using domi::CONSTANT; +using domi::MERGE; using domi::PARAM_INVALID; using domi::SUCCESS; @@ -43,7 +45,7 @@ Status MergePass::Run(NodePtr &node) { std::string op_type; GE_CHK_STATUS_RET(GetOriginalType(node, op_type), "get original type failed"); - if (op_type != MERGE) { + if (op_type != domi::MERGE) { return SUCCESS; } @@ -97,9 +99,9 @@ bool MergePass::IsNeedChangeIndexToConstant(NodePtr &node) const { for (const auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) { if (peer_in_anchor != nullptr && peer_in_anchor->GetOwnerNode() != nullptr) { GELOGI( - "[%s] MergePass, value_index link to other node, " - "change it to be Constant.", - node->GetName().c_str()); + "[%s] MergePass, value_index link to other node, " + "change it to be Constant.", + node->GetName().c_str()); return true; } } @@ -159,14 +161,15 @@ Status MergePass::CreateConstByValue(NodePtr &node, int value_index, OpDescPtr & // 3. create attr value of Constant, is a tensor GeTensorPtr const_tensor_ptr = - MakeShared(original_out_tensor_desc, reinterpret_cast(&value_index), sizeof(int)); + MakeShared(original_out_tensor_desc, reinterpret_cast(&value_index), sizeof(int)); if (const_tensor_ptr == nullptr) { GELOGE(FAILED, "[%s] Make shared of Constant tensor failed.", constant_name.c_str()); return FAILED; } GE_IF_BOOL_EXEC(!AttrUtils::SetTensor(op_desc, ATTR_NAME_WEIGHTS, const_tensor_ptr), - GELOGE(FAILED, "get ATTR_NAME_WEIGHTS failed"); return FAILED); + GELOGE(FAILED, "get ATTR_NAME_WEIGHTS failed"); + return FAILED); // 4. set Constant output desc GE_CHK_STATUS_RET(op_desc->AddOutputDesc(original_out_tensor_desc), "add out put desc failed"); diff --git a/src/ge/graph/passes/multi_batch_pass.cc b/src/ge/graph/passes/multi_batch_pass.cc index 24941b17..428fada5 100644 --- a/src/ge/graph/passes/multi_batch_pass.cc +++ b/src/ge/graph/passes/multi_batch_pass.cc @@ -29,6 +29,14 @@ #include "graph/debug/ge_attr_define.h" #include "graph/utils/type_utils.h" +using domi::ATTR_NAME_STREAM_LABEL; + +using domi::NETOUTPUT; +using domi::STREAMACTIVE; +using domi::STREAMMERGE; +using domi::STREAMSWITCHN; +using domi::SWITCHN; + namespace ge { Status MultiBatchPass::Run(ComputeGraphPtr graph) { GELOGD("MultiBatchPass Enter"); diff --git a/src/ge/graph/passes/multi_batch_pass.h b/src/ge/graph/passes/multi_batch_pass.h index 6e3f5e46..fd4e6b57 100644 --- a/src/ge/graph/passes/multi_batch_pass.h +++ b/src/ge/graph/passes/multi_batch_pass.h @@ -47,4 +47,4 @@ class MultiBatchPass : public GraphPass { std::vector> batch_head_nodes_; }; } // namespace ge -#endif // GE_GRAPH_PASSES_MULTI_BATCH_PASS_H_ +#endif // GE_GRAPH_PASSES_MULTI_BATCH_PASS_H_ \ No newline at end of file diff --git a/src/ge/graph/passes/net_output_pass.cc b/src/ge/graph/passes/net_output_pass.cc index 7caf4990..31b7fb4e 100644 --- a/src/ge/graph/passes/net_output_pass.cc +++ b/src/ge/graph/passes/net_output_pass.cc @@ -30,6 +30,14 @@ #include "graph/utils/type_utils.h" #include "graph/debug/ge_attr_define.h" +using domi::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE; +using domi::ATTR_NAME_NET_OUTPUT_DATATYPE; +using domi::ATTR_NAME_NET_OUTPUT_FORMAT; +using domi::ATTR_NAME_TRUE_BRANCH_STREAM; +using domi::NETOUTPUT; +using domi::NODE_NAME_NET_OUTPUT; +using domi::RETVAL_ATTR_NAME_INDEX; + namespace ge { Status NetOutputPass::GetRetvalOutputInfo(const ge::NodePtr &node, std::map> &retval_node_index_map) { diff --git a/src/ge/graph/passes/next_iteration_pass.cc b/src/ge/graph/passes/next_iteration_pass.cc index f0da5346..fdea1f8a 100644 --- a/src/ge/graph/passes/next_iteration_pass.cc +++ b/src/ge/graph/passes/next_iteration_pass.cc @@ -30,6 +30,14 @@ #include "graph/common/omg_util.h" #include "graph/debug/ge_attr_define.h" +using domi::ENTER; +using domi::LOOPCOND; +using domi::MERGE; +using domi::NEXTITERATION; +using domi::REFENTER; +using domi::REFMERGE; +using domi::STREAMACTIVE; +using domi::SWITCH; namespace ge { Status NextIterationPass::Run(ComputeGraphPtr graph) { diff --git a/src/ge/graph/passes/no_use_reshape_remove_pass.cc b/src/ge/graph/passes/no_use_reshape_remove_pass.cc index 5ae422ca..c0f46e2a 100644 --- a/src/ge/graph/passes/no_use_reshape_remove_pass.cc +++ b/src/ge/graph/passes/no_use_reshape_remove_pass.cc @@ -40,7 +40,7 @@ Status NoUseReshapeRemovePass::Run(ge::NodePtr &node) { GELOGE(PARAM_INVALID, "NoUseReshapeRemovePass enter. OpDesc is null."); return PARAM_INVALID; } - if (op_desc_ptr->GetType() != RESHAPE) { + if (op_desc_ptr->GetType() != domi::RESHAPE) { return SUCCESS; } GELOGI("NoUseReshapeRemovePass enter."); diff --git a/src/ge/graph/passes/pass_manager.cc b/src/ge/graph/passes/pass_manager.cc index f62ea160..d690e9c1 100644 --- a/src/ge/graph/passes/pass_manager.cc +++ b/src/ge/graph/passes/pass_manager.cc @@ -21,6 +21,8 @@ #include "graph/utils/node_utils.h" #include "omg/omg_inner_types.h" +using domi::SUCCESS; + namespace ge { const vector &PassManager::GraphPasses() const { return graph_passes_; } diff --git a/src/ge/graph/passes/pass_utils.cc b/src/ge/graph/passes/pass_utils.cc index 9b3f6b5f..80b85774 100644 --- a/src/ge/graph/passes/pass_utils.cc +++ b/src/ge/graph/passes/pass_utils.cc @@ -27,8 +27,8 @@ #include "common/ge/ge_util.h" #include "common/op/ge_op_utils.h" #include "common/types.h" +#include "framework/common/op/attr_define.h" #include "graph/common/omg_util.h" -#include "graph/debug/ge_attr_define.h" #include "graph/ge_tensor.h" #include "graph/manager/graph_var_manager.h" #include "graph/utils/graph_utils.h" @@ -111,7 +111,7 @@ bool PassUtils::IsConstant(const ConstNodePtr &node) { } auto src_node_type = node->GetType(); - bool is_constant = (src_node_type == CONSTANT) || (src_node_type == CONSTANTOP); + bool is_constant = (src_node_type == domi::CONSTANT) || (src_node_type == domi::CONSTANTOP); return is_constant; } @@ -203,7 +203,7 @@ Status PassUtils::RemoveBranch(const NodePtr &node, std::vector &delete auto dst_node = dst_in_anchor->GetOwnerNode(); std::string node_type; GE_CHK_STATUS_RET(GetOriginalType(dst_node, node_type), "get original type failed"); - if (node_type == NETOUTPUT) { + if (node_type == domi::NETOUTPUT) { if (dst_in_anchor->IsTypeOf()) { GELOGE(INTERNAL_ERROR, "[%s] Inactive branch connected to " @@ -215,7 +215,7 @@ Status PassUtils::RemoveBranch(const NodePtr &node, std::vector &delete GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(src_out_anchor, dst_in_anchor), "remove edge failed"); end_nodes.push_back(dst_node); } - } else if (node_type == MERGE) { + } else if (node_type == domi::MERGE) { /// Unlink connection between the inactive branch and Merge/NetOutput. /// The removal of inactive nodes will be handled in PrunePass GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(src_out_anchor, dst_in_anchor), "remove edge failed"); @@ -254,7 +254,7 @@ bool PassUtils::IsNeedTrainIteFlowCtrl(const ComputeGraphPtr &compute_graph) { if (compute_graph == nullptr) { return false; } - if (!ge::VarManager::Instance(compute_graph->GetSessionID())->IsVarExist(NODE_NAME_FLOWCTRL_LOOP_PER_ITER)) { + if (!ge::VarManager::Instance(compute_graph->GetSessionID())->IsVarExist(domi::NODE_NAME_FLOWCTRL_LOOP_PER_ITER)) { return false; } return compute_graph->GetNeedIteration(); @@ -319,7 +319,7 @@ Status PassUtils::RemoveInactiveBranchToMerge(const OutDataAnchorPtr &inactive_o if (dst_node != nullptr) { std::string dst_node_type; GE_CHK_STATUS_RET(GetOriginalType(dst_node, dst_node_type), "get original type failed"); - if (dst_node_type == MERGE) { + if (dst_node_type == domi::MERGE) { GELOGD("[%s] Switch connected directly to Merge", inactive_output_anchor->GetOwnerNode()->GetName().c_str()); GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(inactive_output_anchor, dst_anchor), "remove edge failed"); continue; diff --git a/src/ge/graph/passes/pass_utils.h b/src/ge/graph/passes/pass_utils.h index b889a056..a8b1cfe3 100644 --- a/src/ge/graph/passes/pass_utils.h +++ b/src/ge/graph/passes/pass_utils.h @@ -26,7 +26,6 @@ namespace ge { class PassUtils { public: PassUtils() = delete; - ~PassUtils() = delete; static NodePtr GetInDataNode(const ConstNodePtr &node, int index); diff --git a/src/ge/graph/passes/permute_pass.cc b/src/ge/graph/passes/permute_pass.cc index c2ce5465..0847453f 100644 --- a/src/ge/graph/passes/permute_pass.cc +++ b/src/ge/graph/passes/permute_pass.cc @@ -25,9 +25,18 @@ #include "inc/kernel_factory.h" #include "framework/omg/omg_inner_types.h" +using domi::ATTR_NAME_FORMAT; +using domi::ATTR_NAME_PRED_PERMUTE_DELETED; +using domi::CONVOLUTION; +using domi::DECONVOLUTION; +using domi::DEPCONVOLUTION; using domi::DOMI_TENSOR_ND; using domi::DOMI_TENSOR_NHWC; using domi::GetContext; +using domi::INTERNAL_ERROR; +using domi::PAD; +using domi::PERMUTE; +using domi::PERMUTE_ATTR_ORDER; using domi::SUCCESS; namespace ge { diff --git a/src/ge/graph/passes/placeholder_with_default_pass.cc b/src/ge/graph/passes/placeholder_with_default_pass.cc index 7a72fc36..cf1f84a6 100644 --- a/src/ge/graph/passes/placeholder_with_default_pass.cc +++ b/src/ge/graph/passes/placeholder_with_default_pass.cc @@ -20,6 +20,8 @@ #include "framework/common/ge_inner_error_codes.h" #include "graph/common/omg_util.h" +using domi::PLACEHOLDERWITHDEFAULT; + namespace ge { Status PlaceholderWithDefaultPass::Run(NodePtr &node) { GE_CHECK_NOTNULL(node); diff --git a/src/ge/graph/passes/prevent_gradient_pass.cc b/src/ge/graph/passes/prevent_gradient_pass.cc index 87c1b3a1..049fece8 100644 --- a/src/ge/graph/passes/prevent_gradient_pass.cc +++ b/src/ge/graph/passes/prevent_gradient_pass.cc @@ -21,6 +21,8 @@ #include "framework/common/ge_inner_error_codes.h" #include "graph/common/omg_util.h" +using domi::PREVENTGRADIENT; + namespace ge { Status PreventGradientPass::Run(NodePtr &node) { GE_CHECK_NOTNULL(node); diff --git a/src/ge/graph/passes/print_op_pass.h b/src/ge/graph/passes/print_op_pass.h index 64bf6573..cf8db6c5 100644 --- a/src/ge/graph/passes/print_op_pass.h +++ b/src/ge/graph/passes/print_op_pass.h @@ -18,8 +18,8 @@ #define GE_GRAPH_PASSES_PRINT_OP_PASS_H_ #include "framework/common/debug/ge_log.h" +#include "framework/common/op/attr_define.h" #include "framework/common/types.h" -#include "graph/debug/ge_attr_define.h" #include "graph/common/omg_util.h" #include "graph/graph.h" #include "graph/passes/base_pass.h" diff --git a/src/ge/graph/passes/prune_pass.cc b/src/ge/graph/passes/prune_pass.cc index f7d09740..8122e6e2 100644 --- a/src/ge/graph/passes/prune_pass.cc +++ b/src/ge/graph/passes/prune_pass.cc @@ -24,6 +24,10 @@ #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" +using domi::AIPPDATA; +using domi::DATA; +using domi::NETOUTPUT; + namespace ge { Status PrunePass::Run(ge::ComputeGraphPtr graph) { GELOGD("PrunePass Start"); diff --git a/src/ge/graph/passes/replace_with_empty_const_pass.cc b/src/ge/graph/passes/replace_with_empty_const_pass.cc deleted file mode 100644 index b76b2cc9..00000000 --- a/src/ge/graph/passes/replace_with_empty_const_pass.cc +++ /dev/null @@ -1,156 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "graph/passes/replace_with_empty_const_pass.h" -#include -#include -#include "common/ge/ge_util.h" -#include "framework/common/debug/ge_log.h" -#include "framework/common/ge_inner_error_codes.h" -#include "graph/utils/graph_utils.h" - -namespace ge { -Status ReplaceWithEmptyConstPass::Run(NodePtr &node) { - GELOGD("ReplaceWithEmptyConstPass in."); - if (node == nullptr) { - GELOGE(PARAM_INVALID, "Parameter is null."); - return PARAM_INVALID; - } - if (node->GetOpDesc() == nullptr) { - GELOGE(PARAM_INVALID, "Param [opDesc] must not be null."); - return PARAM_INVALID; - } - // Node like no op, it has no output - if (node->GetOpDesc()->GetAllOutputsDescPtr().empty()) { - GELOGI("Node %s has no output desc. Ignore current pass.", node->GetName().c_str()); - return SUCCESS; - } - // If outputs of current node are all empty, replace it with empty const - bool is_all_output_empty = true; - for (const auto &output_desc_ptr : node->GetOpDesc()->GetAllOutputsDescPtr()) { - if (output_desc_ptr == nullptr) { - GELOGI("Node %s Got empty output_desc_ptr, ignore current pass.", node->GetName().c_str()); - return SUCCESS; - } - if (!IsEmptyTenor(output_desc_ptr->GetShape())) { - is_all_output_empty = false; - break; - } - } - if (is_all_output_empty) { - GELOGI("Node %s has empty tensor output. It will be replaced by empty const.", node->GetName().c_str()); - // Replace op which all output is empty with empty const - Status ret = ReplaceWithEmptyConst(node); - if (ret != SUCCESS) { - // If replace failed, it should not break whole process, so still return success - GELOGW("Failed to repalce node %s with empty const.", node->GetName().c_str()); - } - } - GELOGD("ReplaceWithEmptyConstPass end."); - return SUCCESS; -} - -Status ReplaceWithEmptyConstPass::ReplaceWithEmptyConst(NodePtr &node_to_replace) { - std::map> shape_out_idx_map; - auto op_desc = node_to_replace->GetOpDesc(); - // Collect out_idx follow different out shape - for (const auto &out_anchor : node_to_replace->GetAllOutDataAnchors()) { - auto out_desc = op_desc->GetOutputDesc(out_anchor->GetIdx()); - shape_out_idx_map[GetDimStr(out_desc.GetShape())].emplace_back(out_anchor->GetIdx()); - } - - for (const auto &shape_2_out_idx : shape_out_idx_map) { - // Create empty const - // The out_desc in one group should be same shape, so here only get first out_desc. its valid index. - auto out_desc = op_desc->GetOutputDesc(shape_2_out_idx.second[0]); - NodePtr const_node; - auto graph = node_to_replace->GetOwnerComputeGraph(); - Status ret = InsertEmptyConst(out_desc, const_node, graph); - if (ret != SUCCESS) { - GELOGE(FAILED, "Failed insert const node."); - return FAILED; - } - - // Repalce data anchors - if (GraphUtils::ReplaceNodeDataAnchors(const_node, node_to_replace, {}, shape_2_out_idx.second) != GRAPH_SUCCESS) { - GELOGE(FAILED, "[%s] ReplaceNodeAnchors failed.", node_to_replace->GetName().c_str()); - return FAILED; - } - // Copy in control edge - if (GraphUtils::CopyInCtrlEdges(node_to_replace, const_node) != GRAPH_SUCCESS) { - GELOGE(FAILED, "CopyInCtrlEdges from %s to %s failed.", node_to_replace->GetName().c_str(), - const_node->GetName().c_str()); - return FAILED; - } - // Copy out control edge - if (GraphUtils::CopyOutCtrlEdges(node_to_replace, const_node) != GRAPH_SUCCESS) { - GELOGE(FAILED, "CopyOutCtrlEdges from %s to %s failed.", node_to_replace->GetName().c_str(), - const_node->GetName().c_str()); - return FAILED; - } - GELOGI("Node %s has been replaced by empty const %s.", node_to_replace->GetName().c_str(), - const_node->GetName().c_str()); - } - // Unlink control edge from node_to_replace to graph - if (node_to_replace->GetInControlAnchor() != nullptr) { - node_to_replace->GetInControlAnchor()->UnlinkAll(); - } - if (node_to_replace->GetOutControlAnchor() != nullptr) { - node_to_replace->GetOutControlAnchor()->UnlinkAll(); - } - return SUCCESS; -} -Status ReplaceWithEmptyConstPass::InsertEmptyConst(const GeTensorDesc &out_desc, NodePtr &const_node, - ComputeGraphPtr &graph) { - GeTensorPtr empty_tensor = MakeShared(); - if (empty_tensor == nullptr) { - GELOGE(OUT_OF_MEMORY, "Failed create empty tensor."); - return OUT_OF_MEMORY; - } - empty_tensor->MutableTensorDesc().SetDataType(out_desc.GetDataType()); - empty_tensor->MutableTensorDesc().SetFormat(out_desc.GetFormat()); - empty_tensor->MutableTensorDesc().SetShape(out_desc.GetShape()); - auto const_desc = OpDescUtils::CreateConstOp(empty_tensor); - if (const_desc == nullptr) { - GELOGE(OUT_OF_MEMORY, "Failed to get const desc from tensor"); - return OUT_OF_MEMORY; - } - - const_node = graph->AddNode(const_desc); - if (const_node == nullptr) { - GELOGE(FAILED, "Failed insert const node."); - return FAILED; - } - return SUCCESS; -} - -bool ReplaceWithEmptyConstPass::IsEmptyTenor(const GeShape &shape) const { - for (auto dim : shape.GetDims()) { - if (dim == 0) { - return true; - } - } - return false; -} - -string ReplaceWithEmptyConstPass::GetDimStr(const GeShape &shape) { - std::stringstream dim_str; - for (auto dim : shape.GetDims()) { - dim_str << dim << '-'; - } - return dim_str.str(); -} -} // namespace ge diff --git a/src/ge/graph/passes/replace_with_empty_const_pass.h b/src/ge/graph/passes/replace_with_empty_const_pass.h deleted file mode 100644 index 495b75b3..00000000 --- a/src/ge/graph/passes/replace_with_empty_const_pass.h +++ /dev/null @@ -1,34 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef GE_GRAPH_PASSES_REPLACE_WITH_EMPTY_CONST_PASS_H_ -#define GE_GRAPH_PASSES_REPLACE_WITH_EMPTY_CONST_PASS_H_ - -#include "graph/passes/base_pass.h" - -namespace ge { -class ReplaceWithEmptyConstPass : public BaseNodePass { - public: - Status Run(NodePtr &node) override; - - private: - Status ReplaceWithEmptyConst(NodePtr &node_to_replace); - Status InsertEmptyConst(const GeTensorDesc &out_desc, NodePtr &const_node, ComputeGraphPtr &graph); - bool IsEmptyTenor(const GeShape &shape) const; - std::string GetDimStr(const GeShape &shape); -}; -} // namespace ge -#endif // GE_GRAPH_PASSES_REPLACE_WITH_EMPTY_CONST_PASS_H_ diff --git a/src/ge/graph/passes/reshape_remove_pass.cc b/src/ge/graph/passes/reshape_remove_pass.cc index 13865648..49945f38 100644 --- a/src/ge/graph/passes/reshape_remove_pass.cc +++ b/src/ge/graph/passes/reshape_remove_pass.cc @@ -28,7 +28,7 @@ Status ReshapeRemovePass::Run(NodePtr &node) { GELOGE(FAILED, "parameter is null."); return FAILED; } - if (node->GetType() != RESHAPE) { + if (node->GetType() != domi::RESHAPE) { return SUCCESS; } GELOGD("Remove reshape node %s", node->GetName().c_str()); diff --git a/src/ge/graph/passes/same_transdata_breadth_fusion_pass.cc b/src/ge/graph/passes/same_transdata_breadth_fusion_pass.cc index 638bfb06..0d22d557 100644 --- a/src/ge/graph/passes/same_transdata_breadth_fusion_pass.cc +++ b/src/ge/graph/passes/same_transdata_breadth_fusion_pass.cc @@ -28,6 +28,13 @@ #include "graph/utils/op_desc_utils.h" #include "init/gelib.h" +using domi::ATTR_NAME_STREAM_LABEL; +using domi::CAST; +using domi::RESHAPE; +using domi::TRANSDATA; +using domi::TRANSPOSE; +using domi::TRANSPOSED; + namespace { const char *const kRemainNode = "node_remain"; const int kNoTransOp = 1; diff --git a/src/ge/graph/passes/shape_operate_op_remove_pass.cc b/src/ge/graph/passes/shape_operate_op_remove_pass.cc index b701e065..b04c2a18 100644 --- a/src/ge/graph/passes/shape_operate_op_remove_pass.cc +++ b/src/ge/graph/passes/shape_operate_op_remove_pass.cc @@ -20,6 +20,7 @@ #include "common/util.h" #include "graph/utils/attr_utils.h" +using domi::ATTR_TO_BE_DELETED; using domi::SUCCESS; namespace ge { diff --git a/src/ge/graph/passes/snapshot_pass.cc b/src/ge/graph/passes/snapshot_pass.cc index 702cf4de..83510e19 100644 --- a/src/ge/graph/passes/snapshot_pass.cc +++ b/src/ge/graph/passes/snapshot_pass.cc @@ -20,6 +20,8 @@ #include "framework/common/ge_inner_error_codes.h" #include "graph/common/omg_util.h" +using domi::SNAPSHOT; + namespace ge { Status SnapshotPass::Run(NodePtr &node) { if (node == nullptr) { diff --git a/src/ge/graph/passes/stop_gradient_pass.cc b/src/ge/graph/passes/stop_gradient_pass.cc index bd5c0ea8..175c8756 100644 --- a/src/ge/graph/passes/stop_gradient_pass.cc +++ b/src/ge/graph/passes/stop_gradient_pass.cc @@ -17,6 +17,8 @@ #include "graph/passes/stop_gradient_pass.h" #include +using domi::STOPGRADIENT; + namespace ge { Status StopGradientPass::Run(NodePtr &node) { if (node == nullptr) { diff --git a/src/ge/graph/passes/switch_logic_remove_pass.cc b/src/ge/graph/passes/switch_logic_remove_pass.cc index be84a582..1ac25e13 100644 --- a/src/ge/graph/passes/switch_logic_remove_pass.cc +++ b/src/ge/graph/passes/switch_logic_remove_pass.cc @@ -37,7 +37,7 @@ char const *GetOutputNameFromIndex(int index) { return "UNKNOWN"; } -inline bool IsSwitch(const std::string &type) { return type == SWITCH || type == REFSWITCH; } +inline bool IsSwitch(const std::string &type) { return type == domi::SWITCH || type == domi::REFSWITCH; } Status GetPredNode(const NodePtr &switch_node, PredNodeAndOut &pred_node_index) { GE_CHECK_NOTNULL(switch_node); diff --git a/src/ge/graph/passes/switch_op_pass.cc b/src/ge/graph/passes/switch_op_pass.cc index b21f962b..1e1975d5 100644 --- a/src/ge/graph/passes/switch_op_pass.cc +++ b/src/ge/graph/passes/switch_op_pass.cc @@ -31,6 +31,29 @@ #include "graph/debug/ge_attr_define.h" #include "graph/utils/type_utils.h" +using domi::ATTR_NAME_STREAM_LABEL; +using domi::ATTR_NAME_WEIGHTS; +using domi::CAST_ATTR_DSTT; +using domi::CAST_ATTR_SRCT; + +using domi::CAST; +using domi::CONSTANT; +using domi::ENTER; +using domi::EXIT; +using domi::MEMCPYASYNC; +using domi::MERGE; +using domi::NETOUTPUT; +using domi::NEXTITERATION; +using domi::REFENTER; +using domi::REFEXIT; +using domi::REFMERGE; +using domi::REFNEXTITERATION; +using domi::REFSWITCH; +using domi::STREAMACTIVE; +using domi::STREAMMERGE; +using domi::STREAMSWITCH; +using domi::SWITCH; + namespace ge { Status SwitchOpPass::Run(ComputeGraphPtr graph) { GELOGD("SwitchOpPass Enter"); @@ -137,7 +160,7 @@ Status SwitchOpPass::ReplaceSwitchNode(ComputeGraphPtr &graph, NodePtr &switch_n NodePtr out_node = peer_in_anchor->GetOwnerNode(); GE_CHK_STATUS_RET(GetOriginalType(out_node, type), "Get node type fail."); if ((type == MERGE) || (type == REFMERGE)) { - NodePtr memcpy_node = CreateMemcpyAsyncNode(graph, peer_data_anchor, false); + NodePtr memcpy_node = CreateMemcpyAsyncNode(graph, peer_data_anchor); GE_CHK_BOOL_EXEC(memcpy_node != nullptr, return FAILED, "Create memcpy_async node fail."); GE_CHK_STATUS(GraphUtils::AddEdge(peer_data_anchor, memcpy_node->GetInDataAnchor(0)), "MemcpyAsync node add edge fail."); @@ -234,18 +257,16 @@ Status SwitchOpPass::ReplaceMergeNode(ComputeGraphPtr &graph, NodePtr &merge_nod need_label_nodes_.emplace_back(stream_merge); } - bool multi_batch_flag = false; if (merge_op_desc->HasAttr(ATTR_INSERT_BY_MBATCH)) { if (!ge::AttrUtils::SetBool(op_desc, ATTR_INSERT_BY_MBATCH, true)) { GELOGE(FAILED, "Set attr ATTR_INSERT_BY_MBATCH fail, StreamMerge:%s.", node_name.c_str()); return FAILED; } - multi_batch_flag = true; } (void)bypass_nodes_.insert(merge_node); - GE_CHK_STATUS_RET(AddMemcpyAsyncNodes(graph, stream_merge, multi_batch_flag), "StreamMerge add memcpy node fail."); + GE_CHK_STATUS_RET(AddMemcpyAsyncNodes(graph, stream_merge), "StreamMerge add memcpy node fail."); return SUCCESS; } @@ -304,20 +325,17 @@ NodePtr SwitchOpPass::CreateStreamSwitchNode(ComputeGraphPtr &graph, const NodeP /// @brief Add MemcpyAsync Node /// @param [in] graph /// @param [in] in_node -/// @param [in] multi_batch_flag /// @return ge::NodePtr /// -NodePtr SwitchOpPass::CreateMemcpyAsyncNode(ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor, - bool multi_batch_flag) { +NodePtr SwitchOpPass::CreateMemcpyAsyncNode(ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor) { GE_CHK_BOOL_EXEC(out_data_anchor != nullptr, return nullptr, "Param of input node is null."); OpDescPtr pre_op_desc = out_data_anchor->GetOwnerNode()->GetOpDesc(); GE_CHK_BOOL_EXEC(pre_op_desc != nullptr, return nullptr, "OpDesc of pre node is invalid."); - std::string memcpy_type = multi_batch_flag ? MEMCPYADDRASYNC : MEMCPYASYNC; - std::string node_name = pre_op_desc->GetName() + "_" + memcpy_type; + std::string node_name = pre_op_desc->GetName() + "_" + MEMCPYASYNC; node_name = CheckDuplicateName(node_name); GELOGI("Create MemcpyAsync op:%s.", node_name.c_str()); - OpDescPtr op_desc = MakeShared(node_name, memcpy_type); + OpDescPtr op_desc = MakeShared(node_name, MEMCPYASYNC); if (op_desc == nullptr) { GELOGE(FAILED, "Create op_desc fail, MemcpyAsync:%s.", node_name.c_str()); return nullptr; @@ -437,10 +455,9 @@ NodePtr SwitchOpPass::CreateActiveNode(ComputeGraphPtr &graph, NodePtr &node) { /// @brief Add MemcpyAsync Op as StreamMerge in_node /// @param [in] graph /// @param [in] node -/// @param [in] multi_batch_flag /// @return Status /// -Status SwitchOpPass::AddMemcpyAsyncNodes(ComputeGraphPtr &graph, NodePtr &node, bool multi_batch_flag) { +Status SwitchOpPass::AddMemcpyAsyncNodes(ComputeGraphPtr &graph, NodePtr &node) { GE_CHK_BOOL_EXEC(node != nullptr, return FAILED, "Param of pre node is null."); for (InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) { OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); @@ -453,7 +470,7 @@ Status SwitchOpPass::AddMemcpyAsyncNodes(ComputeGraphPtr &graph, NodePtr &node, continue); GE_IF_BOOL_EXEC(type != MEMCPYASYNC, { - in_node = CreateMemcpyAsyncNode(graph, peer_out_anchor, multi_batch_flag); + in_node = CreateMemcpyAsyncNode(graph, peer_out_anchor); GE_CHK_BOOL_EXEC(in_node != nullptr, return FAILED, "Create MemcpyAsync node fail."); GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_out_anchor, in_data_anchor), "MemcpyAsync node remove edge fail."); GE_CHK_STATUS(GraphUtils::AddEdge(peer_out_anchor, in_node->GetInDataAnchor(0)), diff --git a/src/ge/graph/passes/switch_op_pass.h b/src/ge/graph/passes/switch_op_pass.h index 7e107e3b..14cdd22c 100644 --- a/src/ge/graph/passes/switch_op_pass.h +++ b/src/ge/graph/passes/switch_op_pass.h @@ -103,13 +103,13 @@ class SwitchOpPass : public GraphPass { NodePtr CreateStreamSwitchNode(ComputeGraphPtr &graph, const NodePtr &switch_node, const std::string &suffix, OutDataAnchorPtr &peer_cond_anchor); - NodePtr CreateMemcpyAsyncNode(ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor, bool multi_batch_flag); + NodePtr CreateMemcpyAsyncNode(ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor); Status CombineSwitchNode(ComputeGraphPtr &graph); NodePtr CreateActiveNode(ComputeGraphPtr &graph, NodePtr &node); - Status AddMemcpyAsyncNodes(ComputeGraphPtr &graph, NodePtr &stream_merge_node, bool multi_batch_flag); + Status AddMemcpyAsyncNodes(ComputeGraphPtr &graph, NodePtr &stream_merge_node); Status BypassSwitchNode(NodePtr &switch_node, OutDataAnchorPtr &peer_data_anchor, OutDataAnchorPtr &peer_cond_anchor); diff --git a/src/ge/graph/passes/switch_pass.cc b/src/ge/graph/passes/switch_pass.cc index 8230d294..36fb4d81 100644 --- a/src/ge/graph/passes/switch_pass.cc +++ b/src/ge/graph/passes/switch_pass.cc @@ -25,6 +25,10 @@ #include "graph/passes/pass_utils.h" #include "graph/utils/graph_utils.h" +using domi::MERGE; +using domi::REFSWITCH; +using domi::SWITCH; + namespace ge { namespace { const std::vector::size_type kDataInputIndex = 0; diff --git a/src/ge/graph/passes/transop_breadth_fusion_pass.cc b/src/ge/graph/passes/transop_breadth_fusion_pass.cc index bcf7e72f..30ca6a53 100644 --- a/src/ge/graph/passes/transop_breadth_fusion_pass.cc +++ b/src/ge/graph/passes/transop_breadth_fusion_pass.cc @@ -24,6 +24,15 @@ #include "graph/common/transop_util.h" #include "graph/utils/node_utils.h" +using domi::ATTR_NAME_STREAM_LABEL; +using domi::CAST; +using domi::FAILED; +using domi::RESHAPE; +using domi::SUCCESS; +using domi::TRANSDATA; +using domi::TRANSPOSE; +using domi::TRANSPOSED; + namespace ge { Status TransOpBreadthFusionPass::Run(ge::ComputeGraphPtr graph) { GE_TIMESTAMP_START(TransOpBreadthFusionPass); diff --git a/src/ge/graph/passes/transop_depth_fusion_pass.cc b/src/ge/graph/passes/transop_depth_fusion_pass.cc index da16ddbd..39989580 100644 --- a/src/ge/graph/passes/transop_depth_fusion_pass.cc +++ b/src/ge/graph/passes/transop_depth_fusion_pass.cc @@ -26,6 +26,15 @@ #include "graph/utils/graph_utils.h" #include "graph/common/transop_util.h" +using domi::CAST; +using domi::EXPANDDIMS; +using domi::REFORMAT; +using domi::RESHAPE; +using domi::SQUEEZE; +using domi::TRANSDATA; +using domi::TRANSPOSE; +using domi::TRANSPOSED; + namespace ge { graphStatus TransOpDepthFusionPass::Run(ComputeGraphPtr graph) { GE_TIMESTAMP_START(TransOpDepthFusionPass); diff --git a/src/ge/graph/passes/transop_nearby_allreduce_fusion_pass.cc b/src/ge/graph/passes/transop_nearby_allreduce_fusion_pass.cc index 4b08e956..383ab285 100644 --- a/src/ge/graph/passes/transop_nearby_allreduce_fusion_pass.cc +++ b/src/ge/graph/passes/transop_nearby_allreduce_fusion_pass.cc @@ -28,7 +28,7 @@ Status TransOpNearbyAllreduceFusionPass::Run(NodePtr &node) { return SUCCESS; } - if (node->GetType() == HCOMALLREDUCE) { + if (node->GetType() == domi::HCOMALLREDUCE) { GELOGI("found allreduce op %s", node->GetName().c_str()); Status ret = RemoveNearbyPairedTransOps(node); if (ret != SUCCESS) { @@ -46,7 +46,7 @@ bool TransOpNearbyAllreduceFusionPass::IsSymmetricTransOps(const NodePtr &node1, return false; } - if (node1->GetType() != TRANSDATA || node2->GetType() != TRANSDATA) { + if (node1->GetType() != domi::TRANSDATA || node2->GetType() != domi::TRANSDATA) { return false; } diff --git a/src/ge/graph/passes/transop_without_reshape_fusion_pass.cc b/src/ge/graph/passes/transop_without_reshape_fusion_pass.cc index b1df8e09..80ed5d56 100644 --- a/src/ge/graph/passes/transop_without_reshape_fusion_pass.cc +++ b/src/ge/graph/passes/transop_without_reshape_fusion_pass.cc @@ -32,6 +32,14 @@ #include "graph/utils/type_utils.h" #include "init/gelib.h" +using domi::ATTR_NAME_INPUT_FORMAT; +using domi::ATTR_NAME_OUTPUT_FORMAT; +using domi::CAST; +using domi::RESHAPE; +using domi::TRANSDATA; +using domi::TRANSPOSE; +using domi::TRANSPOSED; + namespace { const char *const kRemainNode = "node_remain"; const int kInvalidFusionOpCount = -1; diff --git a/src/ge/graph/passes/transpose_transdata_pass.cc b/src/ge/graph/passes/transpose_transdata_pass.cc index b36dda6a..ebc068a9 100644 --- a/src/ge/graph/passes/transpose_transdata_pass.cc +++ b/src/ge/graph/passes/transpose_transdata_pass.cc @@ -26,6 +26,10 @@ #include "init/gelib.h" #include "opskernel_manager/ops_kernel_manager.h" +using domi::TRANSDATA; +using domi::TRANSPOSE; +using domi::TRANSPOSED; + namespace { const char *const kAttrNameSrcFormat = "src_format"; } // namespace diff --git a/src/ge/graph/passes/unused_const_pass.cc b/src/ge/graph/passes/unused_const_pass.cc index 386633b5..750c95f8 100644 --- a/src/ge/graph/passes/unused_const_pass.cc +++ b/src/ge/graph/passes/unused_const_pass.cc @@ -19,6 +19,8 @@ #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" +using domi::UNUSEDCONST; + namespace ge { /// /// run pass @@ -36,7 +38,7 @@ Status UnusedConstPass::Run(NodePtr &node) { } std::string op_type = node->GetOpDesc()->GetType(); - if (op_type == UNUSEDCONST) { + if (op_type == domi::UNUSEDCONST) { GELOGD("op type is unused const."); return IsolateAndDeleteNode(node, {-1}); } diff --git a/src/ge/graph/passes/unused_op_remove_pass.cc b/src/ge/graph/passes/unused_op_remove_pass.cc index 093d931a..9a56e3a2 100644 --- a/src/ge/graph/passes/unused_op_remove_pass.cc +++ b/src/ge/graph/passes/unused_op_remove_pass.cc @@ -29,7 +29,13 @@ #include "inc/pass_manager.h" #include "graph/passes/isolated_op_remove_pass.h" +using domi::ASSERT; +using domi::ATTENTIONDECODER; +using domi::DROPOUT; +; +using domi::PERMUTE; using domi::SUCCESS; +using domi::UNUSEDCONST; namespace ge { const std::set kRemoveOpSet = {DROPOUT, PERMUTE, UNUSEDCONST, ASSERT}; diff --git a/src/ge/graph/passes/var_is_initialized_op_pass.cc b/src/ge/graph/passes/var_is_initialized_op_pass.cc index c88db80c..0e5e4674 100644 --- a/src/ge/graph/passes/var_is_initialized_op_pass.cc +++ b/src/ge/graph/passes/var_is_initialized_op_pass.cc @@ -26,6 +26,10 @@ #include "graph/utils/graph_utils.h" #include "graph/utils/node_utils.h" +using domi::CONSTANT; +using domi::VARIABLE; +using domi::VARISINITIALIZEDOP; + namespace ge { namespace { const int kAssignVarRefIndex = 0; @@ -280,7 +284,7 @@ bool VarIsInitializedOpPass::IsVarInitedOnTheGraphAndNode(const NodePtr &node, i Status VarIsInitializedOpPass::CheckAndSetVarInited(const NodePtr &node, bool &inited, int64_t &inited_var) { GE_CHECK_NOTNULL(node); inited = false; - if (node->GetType() != ASSIGN) { + if (node->GetType() != domi::ASSIGN) { return SUCCESS; } auto ref_in_anchor = node->GetInDataAnchor(kAssignVarRefIndex); diff --git a/src/ge/graph/passes/variable_format_pass.cc b/src/ge/graph/passes/variable_format_pass.cc index 28f6a4f7..302011fe 100644 --- a/src/ge/graph/passes/variable_format_pass.cc +++ b/src/ge/graph/passes/variable_format_pass.cc @@ -26,7 +26,7 @@ Status VariableFormatPass::Run(ge::ComputeGraphPtr graph) { for (auto &node : graph->GetDirectNode()) { GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue); - GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != VARIABLE, continue); + GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != domi::VARIABLE, continue); ge::NodePtr use_node = nullptr; if (GetApplyMomentumOpByVariableInput(node, use_node)) { @@ -79,7 +79,7 @@ Status VariableFormatPass::UpdateVariableOutFormat(const ge::NodePtr &var_node, NodePtr in_node = use_node->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); if (in_node != nullptr) { string in_op_type = in_node->GetType(); - if ((in_op_type == VARIABLE) && (in_node->GetOpDesc() != nullptr) && + if ((in_op_type == domi::VARIABLE) && (in_node->GetOpDesc() != nullptr) && (in_node->GetOpDesc()->MutableOutputDesc(0) != nullptr)) { ge::Format format = in_node->GetOpDesc()->MutableOutputDesc(0)->GetFormat(); ge::OpDescPtr cur_op_desc_ptr = var_node->GetOpDesc(); @@ -104,7 +104,7 @@ Status VariableFormatPass::UpdateApplyMomentumInputFormat(const ge::NodePtr &nod NodePtr in_node = node->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); if (in_node != nullptr) { string in_op_type = in_node->GetType(); - if ((in_op_type == VARIABLE) && (in_node->GetOpDesc() != nullptr)) { + if ((in_op_type == domi::VARIABLE) && (in_node->GetOpDesc() != nullptr)) { ge::Format format = in_node->GetOpDesc()->MutableOutputDesc(0)->GetFormat(); op_desc_ptr->MutableInputDesc(0)->SetFormat(format); op_desc_ptr->MutableInputDesc(0)->SetOriginFormat(format); diff --git a/src/ge/graph/passes/variable_op_pass.cc b/src/ge/graph/passes/variable_op_pass.cc index eb8b5206..302598da 100644 --- a/src/ge/graph/passes/variable_op_pass.cc +++ b/src/ge/graph/passes/variable_op_pass.cc @@ -91,9 +91,9 @@ Status ByPassTransNode(NodePtr &trans_node, NodePtr &ref_node) { } bool IsTransSupport(const TransNodeInfo &trans_info) { - if (trans_info.node_type == RESHAPE || trans_info.node_type == REFORMAT) { + if (trans_info.node_type == domi::RESHAPE || trans_info.node_type == domi::REFORMAT) { return true; - } else if (trans_info.node_type == TRANSDATA) { + } else if (trans_info.node_type == domi::TRANSDATA) { formats::TransArgs args{nullptr, trans_info.input.GetFormat(), trans_info.output.GetFormat(), @@ -101,7 +101,7 @@ bool IsTransSupport(const TransNodeInfo &trans_info) { trans_info.output.GetShape().GetDims(), trans_info.input.GetDataType()}; return formats::IsTransFormatSupport(args); - } else if (trans_info.node_type == CAST) { + } else if (trans_info.node_type == domi::CAST) { formats::CastArgs datatype_args{nullptr, static_cast(trans_info.input.GetShape().GetShapeSize()), trans_info.input.GetDataType(), trans_info.output.GetDataType()}; return formats::IsTransDataTypeSupport(datatype_args); @@ -423,11 +423,11 @@ Status VariableOpPass::GenerateVariableVariableRefMap(const ComputeGraphPtr &com std::map> names_to_refs; GE_CHECK_NOTNULL(compute_graph); for (auto &node : compute_graph->GetAllNodes()) { - if (node->GetType() != VARIABLE) { + if (node->GetType() != domi::VARIABLE) { continue; } std::string ref_var_name; - if (!ge::AttrUtils::GetStr(node->GetOpDesc(), REF_VAR_SRC_VAR_NAME, ref_var_name)) { + if (!ge::AttrUtils::GetStr(node->GetOpDesc(), domi::REF_VAR_SRC_VAR_NAME, ref_var_name)) { names_to_var[node->GetName()] = node; } else { names_to_refs[ref_var_name].insert(node); @@ -583,8 +583,8 @@ Status VariableOpPass::RenewVarDesc(ge::ComputeGraphPtr &graph) { // renew var manager desc Status ret = SUCCESS; for (auto &node : graph->GetDirectNode()) { - bool is_var_node = - (node->GetType() == VARIABLE) || (node->GetType() == VARIABLEV2) || (node->GetType() == VARHANDLEOP); + bool is_var_node = (node->GetType() == domi::VARIABLE) || (node->GetType() == domi::VARIABLEV2) || + (node->GetType() == domi::VARHANDLEOP); if (is_var_node) { if (!ge::VarManager::Instance(graph->GetSessionID())->IsVarExist(node->GetName())) { GELOGD("var manager does not exist var node[%s]", node->GetName().c_str()); diff --git a/src/ge/graph/passes/variable_prepare_op_pass.cc b/src/ge/graph/passes/variable_prepare_op_pass.cc index 3a62082a..981e1235 100644 --- a/src/ge/graph/passes/variable_prepare_op_pass.cc +++ b/src/ge/graph/passes/variable_prepare_op_pass.cc @@ -22,14 +22,16 @@ #include "common/ge/ge_util.h" #include "external/graph/graph.h" #include "framework/common/debug/ge_log.h" -#include "graph/common/omg_util.h" #include "graph/debug/ge_attr_define.h" #include "graph/node.h" #include "graph/utils/tensor_utils.h" +using domi::ASSIGN; +using domi::REF_VAR_SRC_VAR_NAME; +using domi::VAR_ATTR_VAR_OUT_INDEX; +using domi::VARIABLE; + namespace ge { -std::map> VariablePrepareOpPass::ref_node_without_prototype_map_{ - {REFSWITCH, {{0, 0}, {0, 1}}}}; Status VariablePrepareOpPass::Run(ComputeGraphPtr graph) { GE_CHECK_NOTNULL(graph); for (const auto &node : graph->GetDirectNode()) { @@ -46,7 +48,9 @@ Status VariablePrepareOpPass::Run(ComputeGraphPtr graph) { for (auto &node : graph->GetDirectNode()) { GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue); - if (node->GetOpDesc()->GetType() == VARIABLE) { + bool is_variable = node->GetOpDesc()->GetType() == VARIABLE; + bool is_deal = has_dealed_variable_.find(node->GetName()) == has_dealed_variable_.end(); + if (is_variable && is_deal) { Status ret = DealVariableNode(node); if (ret != SUCCESS) { GELOGE(ret, "variable add back edge failed"); @@ -150,7 +154,7 @@ NodePtr VariablePrepareOpPass::GetFinalWritableNode(ge::NodePtr &writable_node, } } if (!found_writeable_node) { - GELOGD("final writable node is %s", current_node->GetName().c_str()); + GELOGI("final writable node is %s", current_node->GetName().c_str()); return current_node; } } @@ -160,54 +164,53 @@ Status VariablePrepareOpPass::AddVariableRef(ge::NodePtr &final_writable_node, g GE_CHECK_NOTNULL(final_writable_node); GE_CHECK_NOTNULL(var_node); - if (final_writable_node->GetType() == FRAMEWORKOP) { - GELOGD("No need to add variable_ref for frameworkop"); - return SUCCESS; - } - std::stringstream variable_ref_name; - variable_ref_name << "_TO_" << final_writable_node->GetName() << "_REF_" << index; - ge::NodePtr find_node = var_node->GetOwnerComputeGraph()->FindNode(var_node->GetName() + variable_ref_name.str()); - if (find_node != nullptr) { - GELOGD("The corresponding variable_ref [%s] has been added to this connection.", find_node->GetName().c_str()); - return SUCCESS; - } - NodePtr variable_ref_node = CreatVariableRef(var_node->GetName() + variable_ref_name.str(), var_node); - - GELOGI("Add variable_ref between [%s] and [%s]", var_node->GetName().c_str(), variable_ref_node->GetName().c_str()); - GE_CHECK_NOTNULL(variable_ref_node); - // add control anchor between variable_ref and final peer node - // variable_ref_node need to execute before other nodes + NodePtr var_ref_node = CreatVariableRef(final_writable_node, var_node); + GE_CHECK_NOTNULL(var_ref_node); + // add control anchor between var_ref_node and final peer node + // var_ref_node need to execute before other nodes auto final_writable_outAnchors = final_writable_node->GetAllOutAnchors(); for (auto &final_writable_outAnchor : final_writable_outAnchors) { GE_CHECK_NOTNULL(final_writable_outAnchor); for (auto &final_writable_peerAnchor : final_writable_outAnchor->GetPeerAnchors()) { GE_CHECK_NOTNULL(final_writable_peerAnchor); NodePtr peer_node = final_writable_peerAnchor->GetOwnerNode(); - graphStatus ret = - ge::GraphUtils::AddEdge(variable_ref_node->GetOutControlAnchor(), peer_node->GetInControlAnchor()); + graphStatus ret = ge::GraphUtils::AddEdge(var_ref_node->GetOutControlAnchor(), peer_node->GetInControlAnchor()); if (ret != GRAPH_SUCCESS) { - GELOGE(FAILED, "add control anchor between variable_ref and final_writable peer node failed"); + GELOGE(FAILED, "add control anchor between var_ref_node and final_writable peer_node failed"); return FAILED; } } } + // add edge final node:index ---> var_ref_node:0 graphStatus ret = - ge::GraphUtils::AddEdge(final_writable_node->GetOutDataAnchor(index), variable_ref_node->GetInDataAnchor(0)); + ge::GraphUtils::AddEdge(final_writable_node->GetOutDataAnchor(index), var_ref_node->GetInDataAnchor(0)); if (ret != GRAPH_SUCCESS) { - GELOGE(FAILED, "add data anchor between variable_ref and final_writable peer node failed"); + GELOGE(FAILED, "add data anchor between var_ref_node and final_writable peer_node failed"); return FAILED; } return SUCCESS; } -ge::NodePtr VariablePrepareOpPass::CreatVariableRef(const std::string &variable_ref_name, ge::NodePtr &var_node) { +ge::NodePtr VariablePrepareOpPass::CreatVariableRef(ge::NodePtr &final_writable_node, ge::NodePtr &var_node) { + if ((final_writable_node == nullptr) || (var_node == nullptr) || (var_node->GetOwnerComputeGraph() == nullptr)) { + GELOGE(FAILED, "parameter ptr is null."); + return nullptr; + } + GELOGD("Create VarRef Op: final_writable_node: [%s] var_node: [%s]>>>>", final_writable_node->GetName().c_str(), + var_node->GetName().c_str()); + + static uint32_t var_ref_count = 0; + std::stringstream var_ref_name; + var_ref_name << "_to_" << final_writable_node->GetName() << "_REF_" << var_ref_count++; + OpDescPtr var_op_desc = var_node->GetOpDesc(); if (var_op_desc == nullptr) { GELOGE(FAILED, "get var opdesc is nullptr"); return nullptr; } - OpDescPtr var_ref_op_desc = MakeShared(variable_ref_name.c_str(), var_op_desc->GetType()); + OpDescPtr var_ref_op_desc = + MakeShared(var_node->GetName() + var_ref_name.str().c_str(), var_op_desc->GetType()); if (var_ref_op_desc == nullptr) { GELOGE(FAILED, "var_ref opdesc is nullptr"); return nullptr; @@ -219,15 +222,15 @@ ge::NodePtr VariablePrepareOpPass::CreatVariableRef(const std::string &variable_ GE_IF_BOOL_EXEC(var_ref_op_desc->AddInputDesc(var_op_desc->GetOutputDesc(0)) != SUCCESS, GELOGW("add input desc edge failed"); return nullptr); - NodePtr variable_ref_node = var_node->GetOwnerComputeGraph()->AddNode(var_ref_op_desc); - GE_IF_BOOL_EXEC(variable_ref_node == nullptr, GELOGW("variable_ref_node is null"); return nullptr); + NodePtr var_ref_node = var_node->GetOwnerComputeGraph()->AddNode(var_ref_op_desc); + GE_IF_BOOL_EXEC(var_ref_node == nullptr, GELOGW("var_ref_node is null"); return nullptr); + has_dealed_variable_.insert(var_node->GetName()); bool is_set_str = ge::AttrUtils::SetStr(var_ref_op_desc, REF_VAR_SRC_VAR_NAME, var_op_desc->GetName()); if (is_set_str) { - GELOGD("Set node [%s] REF_VAR_SRC_VAR_NAME [%s]", variable_ref_node->GetName().c_str(), - var_op_desc->GetName().c_str()); + GELOGD("Set node [%s] REF_VAR_SRC_VAR_NAME [%s]", var_ref_node->GetName().c_str(), var_op_desc->GetName().c_str()); } - return variable_ref_node; + return var_ref_node; } int VariablePrepareOpPass::GetWritableNodeOutIndex(const NodePtr &node, int input_index) { @@ -242,13 +245,16 @@ int VariablePrepareOpPass::GetWritableNodeOutIndex(const NodePtr &node, int inpu } } - if (node_type == FRAMEWORKOP) { - std::string original_type; - GE_IF_BOOL_EXEC(GetOriginalType(node, original_type) != SUCCESS, GELOGW("Get node original type fail")); - GELOGI("find frameworkop: [%s], original type is %s", node->GetName().c_str(), original_type.c_str()); - return FindRefOutIndex(original_type, input_index, ref_node_without_prototype_map_); + auto node_iter = ref_input_output_map_.find(node_type); + if (node_iter == ref_input_output_map_.end()) { + return -1; } - return FindRefOutIndex(node_type, input_index, ref_input_output_map_); + + auto index_iter = node_iter->second.find(input_index); + if (index_iter == node_iter->second.end()) { + return -1; + } + return index_iter->second; } void VariablePrepareOpPass::GenerateRefTypeAndInputOutputMap(const NodePtr &node) { @@ -300,18 +306,4 @@ Status VariablePrepareOpPass::UpdateAssignOpDesc(const ge::NodePtr &node) { } return SUCCESS; } - -int VariablePrepareOpPass::FindRefOutIndex(const std::string &node_type, int input_index, - const std::map> &ref_map) { - auto node_iter = ref_map.find(node_type); - if (node_iter == ref_map.end()) { - return -1; - } - - auto index_iter = node_iter->second.find(input_index); - if (index_iter == node_iter->second.end()) { - return -1; - } - return index_iter->second; -} } // namespace ge diff --git a/src/ge/graph/passes/variable_prepare_op_pass.h b/src/ge/graph/passes/variable_prepare_op_pass.h index fb25d5db..0fbd311c 100644 --- a/src/ge/graph/passes/variable_prepare_op_pass.h +++ b/src/ge/graph/passes/variable_prepare_op_pass.h @@ -33,15 +33,13 @@ class VariablePrepareOpPass : public GraphPass { Status DealWritableNode(ge::NodePtr &writable_node, ge::NodePtr &var_node, int out_index); NodePtr GetFinalWritableNode(ge::NodePtr &writable_node, int &out_index); Status AddVariableRef(ge::NodePtr &node, ge::NodePtr &var_node, int index); - NodePtr CreatVariableRef(const std::string &variable_ref_name, ge::NodePtr &var_node); + NodePtr CreatVariableRef(ge::NodePtr &final_ref_type_node, ge::NodePtr &var_node); int GetWritableNodeOutIndex(const NodePtr &node, int input_index); Status UpdateAssignOpDesc(const ge::NodePtr &node); void GenerateRefTypeAndInputOutputMap(const NodePtr &node); - int FindRefOutIndex(const std::string &node_type, int input_index, - const std::map> &ref_map); std::map> ref_input_output_map_; - static std::map> ref_node_without_prototype_map_; + std::unordered_set has_dealed_variable_{}; }; } // namespace ge diff --git a/src/ge/graph/passes/variable_ref_delete_op_pass.cc b/src/ge/graph/passes/variable_ref_delete_op_pass.cc index 7bc767ee..1daa6e5c 100644 --- a/src/ge/graph/passes/variable_ref_delete_op_pass.cc +++ b/src/ge/graph/passes/variable_ref_delete_op_pass.cc @@ -18,6 +18,10 @@ #include #include "framework/common/debug/ge_log.h" +using domi::REF_VAR_PRE_PEER_OUT_INDEX; +using domi::REF_VAR_SRC_VAR_NAME; +using domi::VARIABLE; + namespace ge { Status VariableRefDeleteOpPass::Run(ge::ComputeGraphPtr graph) { GE_TIMESTAMP_START(VariableRefDeleteOpPass); @@ -31,8 +35,8 @@ Status VariableRefDeleteOpPass::Run(ge::ComputeGraphPtr graph) { for (auto &node : graph->GetDirectNode()) { GE_CHECK_NOTNULL(node->GetOpDesc()); std::string ref_var_src_var_name; - bool is_variable_ref = (node->GetOpDesc()->GetType() == VARIABLE) && - (ge::AttrUtils::GetStr(node->GetOpDesc(), REF_VAR_SRC_VAR_NAME, ref_var_src_var_name)); + bool is_variable_ref = (node->GetOpDesc()->GetType() == domi::VARIABLE) && + (ge::AttrUtils::GetStr(node->GetOpDesc(), domi::REF_VAR_SRC_VAR_NAME, ref_var_src_var_name)); if (!is_variable_ref) { continue; } @@ -83,7 +87,7 @@ Status VariableRefDeleteOpPass::DealVariableRef(ge::ComputeGraphPtr &graph, ge:: // add attr [REF_VAR_SRC_VAR_NAME] to the previous node of the variable_ref GE_CHECK_NOTNULL(peer_node->GetOpDesc()); - bool is_set_str = ge::AttrUtils::SetStr(peer_node->GetOpDesc(), REF_VAR_SRC_VAR_NAME, ref_var_src_var_name); + bool is_set_str = ge::AttrUtils::SetStr(peer_node->GetOpDesc(), domi::REF_VAR_SRC_VAR_NAME, ref_var_src_var_name); ge::NodePtr var_ref_src_var = graph->FindNode(ref_var_src_var_name); if (var_ref_src_var == nullptr) { @@ -92,7 +96,7 @@ Status VariableRefDeleteOpPass::DealVariableRef(ge::ComputeGraphPtr &graph, ge:: } GE_CHECK_NOTNULL(var_ref_src_var->GetOpDesc()); - bool is_set_index = ge::AttrUtils::SetInt(var_ref_src_var->GetOpDesc(), REF_VAR_PRE_PEER_OUT_INDEX, index); + bool is_set_index = ge::AttrUtils::SetInt(var_ref_src_var->GetOpDesc(), domi::REF_VAR_PRE_PEER_OUT_INDEX, index); if (is_set_str && is_set_index) { GELOGI("[%s]: add attr [REF_VAR_SRC_VAR_NAME: %s ] ", peer_node->GetName().c_str(), ref_var_src_var_name.c_str()); GELOGI("[%s]: add attr [ REF_VAR_PRE_PEER_OUT_INDEX: %d ]", var_ref_src_var->GetName().c_str(), index); diff --git a/src/ge/graph/preprocess/graph_preprocess.cc b/src/ge/graph/preprocess/graph_preprocess.cc index eacec6d1..8447552d 100644 --- a/src/ge/graph/preprocess/graph_preprocess.cc +++ b/src/ge/graph/preprocess/graph_preprocess.cc @@ -18,7 +18,6 @@ #include #include #include -#include #include "common/formats/format_transfers/format_transfer_nchw_nc1hwc0.h" #include "common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.h" #include "common/helper/model_helper.h" @@ -72,7 +71,6 @@ #include "graph/passes/var_is_initialized_op_pass.h" #include "graph/passes/variable_prepare_op_pass.h" #include "graph/passes/common_subexpression_elimination_pass.h" -#include "graph/passes/replace_with_empty_const_pass.h" #include "graph/preprocess/insert_op/util_insert_aipp_op.h" #include "graph/types.h" #include "graph/utils/tensor_utils.h" @@ -82,6 +80,13 @@ #include "multi_batch_copy_graph.h" #include "runtime/dev.h" +using domi::AIPP; +using domi::AIPPDATA; +using domi::ASSIGN; +using domi::CAST; +using domi::DATA; +using domi::MULTIPLY; +using domi::StringUtils; using ge::CheckInt64Uint32MulOverflow; namespace ge { @@ -132,7 +137,7 @@ OpDescPtr CreateTensorShape(const GeTensorDesc &data_tensor) { void AddTransNodeAttr(const std::string &node_type, const GeTensorDesc &input, const GeTensorDesc &output, OpDescPtr &op_desc) { // For format transfer node, the IR definition has src/dst format attrs - if (node_type == TRANSDATA) { + if (node_type == domi::TRANSDATA) { GE_IF_BOOL_EXEC( !AttrUtils::SetStr(op_desc, FORMAT_TRANSFER_SRC_FORMAT, TypeUtils::FormatToSerialString(input.GetFormat())), GELOGW("SetStr FORMAT_TRANSFER_SRC_FORMAT failed");) @@ -141,7 +146,7 @@ void AddTransNodeAttr(const std::string &node_type, const GeTensorDesc &input, c GELOGW("SetStr FORMAT_TRANSFER_DST_FORMAT failed");) } // For cast node, the IR definition has src/dst attrs - if (node_type == CAST) { + if (node_type == domi::CAST) { GE_IF_BOOL_EXEC(!AttrUtils::SetInt(op_desc, CAST_ATTR_SRCT, static_cast(input.GetDataType())), GELOGW("SetInt CAST_ATTR_SRCT failed");) GE_IF_BOOL_EXEC(!AttrUtils::SetInt(op_desc, CAST_ATTR_DSTT, static_cast(output.GetDataType())), @@ -196,7 +201,7 @@ NodePtr CreateTransNode(const std::string &name, const std::string &node_type, c AddTransNodeAttr(node_type, input, output, op_desc); NodePtr shape_node = nullptr; - if (node_type == RESHAPE) { + if (node_type == domi::RESHAPE) { auto shape_desc = CreateTensorShape(output); if (shape_desc == nullptr) { GELOGE(INTERNAL_ERROR, "Failed to add shape for reshape %s, can not create the shape input", @@ -222,7 +227,7 @@ NodePtr CreateTransNode(const std::string &name, const std::string &node_type, c return nullptr; } - if (node_type == RESHAPE) { + if (node_type == domi::RESHAPE) { if (GraphUtils::AddEdge(shape_node->GetOutDataAnchor(0), trans_node->GetInDataAnchor(1)) != GRAPH_SUCCESS) { GELOGE(INTERNAL_ERROR, "Failed to add shape node for reshape %s, can not add the edge", name.c_str()); return nullptr; @@ -377,10 +382,10 @@ VarNamesToRefs CollectVarNamesToRefs(const ComputeGraphPtr &graph) { return names_to_refs; } for (auto &node : graph->GetAllNodes()) { - if (node->GetType() != VARIABLE) { + if (node->GetType() != domi::VARIABLE) { continue; } - if (AttrUtils::GetStr(node->GetOpDesc(), REF_VAR_SRC_VAR_NAME, var_name)) { + if (AttrUtils::GetStr(node->GetOpDesc(), domi::REF_VAR_SRC_VAR_NAME, var_name)) { (void)names_to_refs[var_name].insert(node); } } @@ -422,7 +427,7 @@ NodePtr CreateCastOp(const ge::GeShape &shape, const ge::DataType input_data_typ output.SetOriginDataType(output_data_type); ge::TensorUtils::SetRealDimCnt(output, static_cast(shape.GetDims().size())); - auto cast_node = CreateTransNode(name, CAST, input, output, node); + auto cast_node = CreateTransNode(name, domi::CAST, input, output, node); GELOGD("Create cast node success."); return cast_node; } @@ -487,7 +492,7 @@ NodePtr CreateTransdataNode(const ge::GeShape &in_shape, const ge::Format input_ output.SetOriginShape(out_shape); output.SetOriginDataType(dt); - return CreateTransNode(name, TRANSDATA, input, output, node); + return CreateTransNode(name, domi::TRANSDATA, input, output, node); } Status TransferShape2NC1HWC0(Format src_format, const std::vector &src_shape, DataType dt, Format dst_format, @@ -736,35 +741,6 @@ Status ProcessNetoutputNode(NodePtr &node, std::string &output_type) { } return SUCCESS; } - -Status CheckIfNeedSetNdFormat(const NodePtr &node_ptr) { - auto op = node_ptr->GetOpDesc(); - GE_CHECK_NOTNULL(op); - auto inputDescsPtr = op->GetAllInputsDescPtr(); - auto outputDescsPtr = op->GetAllOutputsDescPtr(); - ge::Format format = ge::FORMAT_ND; - // if user set shape larger than 4, inferformat may set NCHW or NHWC, GE should set ND before FE - // process, otherwise fe will insert transdata. - for (auto &inputDescPtr : inputDescsPtr) { - GE_CHECK_NOTNULL(inputDescPtr); - if ((inputDescPtr->GetShape().GetDims().size() > ge::DIM_DEFAULT_SIZE) && - ((inputDescPtr->GetFormat() == ge::FORMAT_NCHW) || (inputDescPtr->GetFormat() == ge::FORMAT_NHWC))) { - GELOGI("The node inputdesc [%s] format need to be set ND", op->GetName().c_str()); - inputDescPtr->SetFormat(format); - inputDescPtr->SetOriginFormat(format); - } - } - for (auto &outputDescPtr : outputDescsPtr) { - GE_CHECK_NOTNULL(outputDescPtr); - if ((outputDescPtr->GetShape().GetDims().size() > ge::DIM_DEFAULT_SIZE) && - ((outputDescPtr->GetFormat() == ge::FORMAT_NCHW) || (outputDescPtr->GetFormat() == ge::FORMAT_NHWC))) { - GELOGI("The node outputdesc [%s] format need to be set ND", op->GetName().c_str()); - outputDescPtr->SetFormat(format); - outputDescPtr->SetOriginFormat(format); - } - } - return SUCCESS; -} } // namespace GraphPrepare::GraphPrepare() : compute_graph_(nullptr) {} @@ -782,7 +758,7 @@ Status GraphPrepare::UpdateVariableFormats(ComputeGraphPtr &graph) { if (node == nullptr) { continue; } - if (node->GetType() != VARIABLE) { + if (node->GetType() != domi::VARIABLE) { continue; } auto trans_road = VarManager::Instance(graph->GetSessionID())->GetTransRoad(node->GetName()); @@ -855,12 +831,9 @@ Status GraphPrepare::CheckGraph() { Status GraphPrepare::CheckRefInputNode(const NodePtr &node, const std::string &input_name, const std::unordered_set &ref_nodes) { - // Acceptable input types should be ref node, variable or Switch operator, which is issued by ME for dynamic - // lossscale and would be optimized in SwitchOpPass. Since ME dont differentiate between RefSwitch and Switch, - // and only issue Switch. - static std::unordered_set acceptable_types = {ge::VARIABLE, ge::VARIABLEV2, ge::VARHANDLEOP, - ge::REFSWITCH, ge::REFMERGE, ge::REFENTER, - ge::REFNEXTITERATION, ge::REFEXIT, ge::SWITCH}; + static std::unordered_set acceptable_types = { + domi::VARIABLE, domi::VARIABLEV2, domi::VARHANDLEOP, domi::REFSWITCH, + domi::REFMERGE, domi::REFENTER, domi::REFNEXTITERATION, domi::REFEXIT}; GE_CHECK_NOTNULL(node); const auto &op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); @@ -879,7 +852,7 @@ Status GraphPrepare::CheckRefInputNode(const NodePtr &node, const std::string &i return SUCCESS; } auto input_type = input_op_desc->GetType(); - if (input_type == ge::FRAMEWORKOP) { + if (input_type == domi::FRAMEWORKOP) { if (!ge::AttrUtils::GetStr(input_op_desc, ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, input_type)) { GELOGE(PARAM_INVALID, "Get original type failed."); return PARAM_INVALID; @@ -971,7 +944,7 @@ Status GraphPrepare::UpdateInput(const std::vector &user_input) { GE_CHECK_NOTNULL(input_node); OpDescPtr op = input_node->GetOpDesc(); GE_CHECK_NOTNULL(op); - if (op->GetType() == DATA) { + if (op->GetType() == domi::DATA) { GeAttrValue::INT index = 0; if ((!(AttrUtils::GetInt(op, ATTR_NAME_INDEX, index))) || (domi::GetContext().is_dynamic_input)) { GELOGW("Get index from data attr failed"); @@ -1004,7 +977,7 @@ Status GraphPrepare::UpdateInput(const std::vector &user_input) { int64_t desc_shape = desc.GetShape().GetShapeSize(); FMK_INT64_UINT32_MULCHECK(desc_shape, length); int64_t shape_size = desc_shape * length; - GE_IF_BOOL_EXEC(shape_size == 0 && desc.GetShape().GetDimNum() == 0, shape_size = static_cast(length)); + GE_IF_BOOL_EXEC(shape_size == 0, shape_size = static_cast(length)); int64_t size = 0; GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(desc, size) != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "TensorUtils GetSize failed"); @@ -1138,10 +1111,6 @@ Status GraphPrepare::OptimizeAfterInfershapeByAtcParams() { GE_RETURN_IF_ERROR(InsertNewOpUtil::Instance().UpdateDataNodeByAipp(compute_graph_)); for (auto &node_ptr : compute_graph_->GetDirectNode()) { GE_CHECK_NOTNULL(node_ptr); - if (CheckIfNeedSetNdFormat(node_ptr) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Set node [%s] format ND failed", node_ptr->GetName().c_str()); - return FAILED; - } if (node_ptr->GetType() == DATA) { if (ProcessDataNode(node_ptr) != SUCCESS) { GELOGE(INTERNAL_ERROR, "Process data node failed"); @@ -1149,7 +1118,7 @@ Status GraphPrepare::OptimizeAfterInfershapeByAtcParams() { } } - if (node_ptr->GetType() == ge::NETOUTPUT) { + if (node_ptr->GetType() == domi::NETOUTPUT) { if (ProcessNetoutputNode(node_ptr, options_.output_datatype) != SUCCESS) { GELOGE(INTERNAL_ERROR, "Process netoutput node failed"); return FAILED; @@ -1408,10 +1377,10 @@ Status GraphPrepare::Prepare(ConstGraphPtr graph, const std::vector &u Status GraphPrepare::CheckConstOp() { for (auto &node_ptr : compute_graph_->GetAllNodes()) { GE_CHECK_NOTNULL(node_ptr); - if (node_ptr->GetType() == CONSTANT) { + if (node_ptr->GetType() == domi::CONSTANT) { Status ret = VerifyConstOp(node_ptr); GE_CHK_BOOL_RET_STATUS(ret == SUCCESS, ret, "Const Op Check failed"); - } else if (node_ptr->GetType() == FRAMEWORKOP) { + } else if (node_ptr->GetType() == domi::FRAMEWORKOP) { auto op_desc = node_ptr->GetOpDesc(); if (op_desc == nullptr) { GELOGE(PARAM_INVALID, "Get op desc failed"); @@ -1421,7 +1390,7 @@ Status GraphPrepare::CheckConstOp() { GE_IF_BOOL_EXEC(ge::AttrUtils::GetStr(op_desc, ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, original_type), GELOGI("Get FrameWorkOp original type [%s]", original_type.c_str())); GELOGI("original type is %s", original_type.c_str()); - if (original_type == CONSTANT) { + if (original_type == domi::CONSTANT) { Status ret = VerifyConstOp(node_ptr); GE_CHK_BOOL_RET_STATUS(ret == SUCCESS, ret, "Const Op Check failed"); } @@ -1452,17 +1421,9 @@ Status GraphPrepare::VerifyConstOp(const NodePtr &node) { FMK_INT64_UINT32_MULCHECK(shape_size, length); GELOGI("Const real value Size:%zu, op_desc Shape Size:%ld, data_type:%s.", data_size, shape_size * length, TypeUtils::DataTypeToSerialString(data_type).c_str()); - if (shape_size == 0) { - if (ge_tensor_desc.GetShape().GetDims().size() == 0) { - // shape = [], means it's a sclar tensor. - GE_CHK_BOOL_EXEC(data_size / length == 1, return PARAM_INVALID, "Const is invalid scalar tensor."); - } else { - // shape = [x, y, 0,...], means it's a vector tensor that value is []. - GE_CHK_BOOL_EXEC(data_size == 0, return PARAM_INVALID, "Const is invalid vector scalar."); - } - } else { - GE_CHK_BOOL_EXEC(data_size == static_cast(shape_size * length) && data_size != 0, return PARAM_INVALID, - "Const input data size is not equal with tensor desc shape"); + if ((shape_size != 0) || (data_size / length != 1)) { + GE_CHK_BOOL_EXEC(data_size == static_cast(shape_size * length) && data_size != 0, + return GRAPH_PARAM_INVALID, "Const input data size is not equal with tensor desc shape"); } return SUCCESS; } @@ -1478,7 +1439,7 @@ Status GraphPrepare::CheckUserInput(const std::vector &user_input) { OpDescPtr op = input_node->GetOpDesc(); GE_CHECK_NOTNULL(op); node_num++; - if (op->GetType() == DATA || op->GetType() == AIPPDATA) { + if (op->GetType() == domi::DATA || op->GetType() == domi::AIPPDATA) { data_num++; GeAttrValue::INT index = 0; if (!(AttrUtils::GetInt(op, ATTR_NAME_INDEX, index))) { @@ -1492,8 +1453,8 @@ Status GraphPrepare::CheckUserInput(const std::vector &user_input) { GeTensorDesc desc(user_input[index].GetTensorDesc()); for (size_t i = 0; i < desc.GetShape().GetDimNum(); ++i) { - if (desc.GetShape().GetDim(i) < 0) { - GELOGE(GE_GRAPH_INIT_FAILED, "data dim %zu is not supported, need >= 0, real:%ld.", i, + if (desc.GetShape().GetDim(i) <= 0) { + GELOGE(GE_GRAPH_INIT_FAILED, "data dim %zu is not supported, need > 0, real:%ld.", i, desc.GetShape().GetDim(i)); return GE_GRAPH_INIT_FAILED; } @@ -1659,8 +1620,8 @@ Status GraphPrepare::OptimizeForPreprocess() { if (options_.train_graph_flag) { for (ge::NodePtr &n : compute_graph_->GetAllNodes()) { // This can ensure that n is not a null pointer - if (n->GetOpDesc()->GetType() == CONSTANT) { - n->GetOpDesc()->SetType(CONSTANTOP); + if (n->GetOpDesc()->GetType() == domi::CONSTANT) { + n->GetOpDesc()->SetType(domi::CONSTANTOP); } } } diff --git a/src/ge/graph/preprocess/insert_op/ge_aipp_op.cc b/src/ge/graph/preprocess/insert_op/ge_aipp_op.cc index d35bd84c..b14aa4b9 100644 --- a/src/ge/graph/preprocess/insert_op/ge_aipp_op.cc +++ b/src/ge/graph/preprocess/insert_op/ge_aipp_op.cc @@ -39,6 +39,8 @@ #include "external/graph/operator_factory.h" #include "base_insert_op.h" +using domi::AIPP; + #define SAVE_AIPP_ATTR(KEY, SAVE_TYPE) \ do { \ (void)aipp_attrs.SetAttr(#KEY, GeAttrValue::CreateFrom(aipp_params_->KEY())); \ @@ -98,13 +100,13 @@ Status GetDataDimN(const ge::NodePtr &data_node, ge::Format format, int64_t &bat batch = 1; return SUCCESS; } - if (shape.size() == DIM_DEFAULT_SIZE) { + if (shape.size() == domi::DIM_DEFAULT_SIZE) { switch (format) { case FORMAT_NCHW: - batch = shape[NCHW_DIM_N]; + batch = shape[domi::NCHW_DIM_N]; return SUCCESS; case FORMAT_NHWC: - batch = shape[NHWC_DIM_N]; + batch = shape[domi::NHWC_DIM_N]; return SUCCESS; default: GELOGE(PARAM_INVALID, "Not support data format: %s", TypeUtils::FormatToSerialString(format).c_str()); @@ -208,7 +210,7 @@ NodePtr AippOp::CreateAipp(const ComputeGraphPtr &graph, const OutDataAnchorPtr const std::string &aippConfigPath, const uint32_t &index) { std::string current_name = out_anchor->GetOwnerNode()->GetName() + "_" + std::to_string(out_anchor->GetIdx()) + "_huawei_aipp"; - auto aipp_opdesc_ptr = MakeShared(current_name, AIPP); + auto aipp_opdesc_ptr = MakeShared(current_name, domi::AIPP); if (aipp_opdesc_ptr == nullptr) { GELOGE(OUT_OF_MEMORY, "Failed to alloc aipp desc, name %s", current_name.c_str()); return nullptr; @@ -222,7 +224,7 @@ NodePtr AippOp::CreateAipp(const ComputeGraphPtr &graph, const OutDataAnchorPtr GELOGE(INTERNAL_ERROR, "Set config file path attr for aipp node failed"); return nullptr; } - if (!AttrUtils::SetNamedAttrs(aipp_opdesc_ptr, ATTR_NAME_AIPP, aipp_attr)) { + if (!AttrUtils::SetNamedAttrs(aipp_opdesc_ptr, domi::ATTR_NAME_AIPP, aipp_attr)) { GELOGE(INTERNAL_ERROR, "Set name attrs for aipp node failed"); return nullptr; } @@ -289,7 +291,7 @@ domi::AippOpParams::AippMode AippOp::GetAippMode() { return aipp_params_->aipp_m NodePtr AippOp::FindDataByIndex(const ComputeGraphPtr &graph, int rank) { int64_t data_index = 0; for (auto &node : graph->GetDirectNode()) { - if (node->GetType() != DATA) { + if (node->GetType() != domi::DATA) { continue; } // There is no `index` attribute on the `Data` node when compile in inference scene @@ -534,7 +536,7 @@ Status AippOp::GenerateOpDesc(OpDescPtr op_desc) { static int op_idx = 0; op_desc->SetName(std::string("aipp_node").append(std::to_string(op_idx++))); - op_desc->SetType(AIPP); + op_desc->SetType(domi::AIPP); // Add two InputDesc, add the second after the first one is added successfully. if ((op_desc->AddInputDesc(GeTensorDesc()) != GRAPH_SUCCESS) || @@ -550,7 +552,7 @@ Status AippOp::GenerateOpDesc(OpDescPtr op_desc) { GeAttrValue::NamedAttrs aipp_attrs; ConvertParamToAttr(aipp_attrs); - GE_IF_BOOL_EXEC(!AttrUtils::SetNamedAttrs(op_desc, ATTR_NAME_AIPP, aipp_attrs), + GE_IF_BOOL_EXEC(!AttrUtils::SetNamedAttrs(op_desc, domi::ATTR_NAME_AIPP, aipp_attrs), GELOGE(FAILED, "failed to set ATTR_NAME_AIPP"); return FAILED); @@ -652,7 +654,7 @@ Status AippOp::CreateAippData(const ComputeGraphPtr &graph, const NodePtr &aipp_ TensorUtils::SetSize(input_tensor, max_dynamic_aipp_size); // new add aipp_data ops for dynamic aipp param input - OpDescPtr op_desc_ptr_data = MakeShared(kDynamicAippData, AIPPDATA); + OpDescPtr op_desc_ptr_data = MakeShared(kDynamicAippData, domi::AIPPDATA); GE_CHECK_NOTNULL(op_desc_ptr_data); auto stat1 = op_desc_ptr_data->AddInputDesc(input_tensor); diff --git a/src/ge/graph/preprocess/insert_op/util_insert_aipp_op.cc b/src/ge/graph/preprocess/insert_op/util_insert_aipp_op.cc index 680e40c9..218fc7f7 100644 --- a/src/ge/graph/preprocess/insert_op/util_insert_aipp_op.cc +++ b/src/ge/graph/preprocess/insert_op/util_insert_aipp_op.cc @@ -18,6 +18,7 @@ #include #include #include "common/ge/ge_util.h" +#include "common/op/attr_define.h" #include "common/op/ge_op_utils.h" #include "common/util.h" #include "framework/common/debug/ge_log.h" @@ -34,7 +35,15 @@ #include "inc/common/dynamic_aipp.h" #include "common/formats/utils/formats_trans_utils.h" +using domi::AIPPDATA; using domi::AippOpParams; +using domi::DATA; +using domi::DEFAULT_FORMAT; +using domi::DIM_DEFAULT_SIZE; +using domi::NCHW_DIM_C; +using domi::NCHW_DIM_H; +using domi::NCHW_DIM_N; +using domi::NCHW_DIM_W; namespace ge { namespace { @@ -121,24 +130,25 @@ Status InsertNewOpUtil::CheckGraph(const ComputeGraphPtr &graph) { domi::AippOpParams::AippMode aippMode = domi::AippOpParams::undefined; for (const auto &node : graph->GetDirectNode()) { - if (node->GetType() != DATA) { + if (node->GetType() != domi::DATA) { continue; } - size_t next_nodes_cnt = 0; + std::vector aippNodes; for (const auto &anchor : node->GetAllOutDataAnchors()) { for (const auto &inAnchor : anchor->GetPeerInDataAnchors()) { const std::string &nodeType = inAnchor->GetOwnerNode()->GetType(); - next_nodes_cnt++; - if (nodeType == AIPP) { + + GE_CHK_BOOL_RET_STATUS(aippNodes.size() == 0 || nodeType == domi::AIPP, PARAM_INVALID, + "Can not config part of outputs of Data node to support AIPP, config all of the " + "outputs of Data to support AIPP, or config none of them"); + + if (nodeType == domi::AIPP) { aippNodes.push_back(inAnchor->GetOwnerNode()); continue; } } } - GE_CHK_BOOL_RET_STATUS((aippNodes.size() == 0) || (aippNodes.size() == next_nodes_cnt), PARAM_INVALID, - "Can not config part of outputs of Data node to support AIPP, config all " - "of the outputs of Data to support AIPP, or config none of them"); std::unique_ptr aippParams(new (std::nothrow) domi::AippOpParams()); GE_CHECK_NOTNULL(aippParams); @@ -177,7 +187,7 @@ Status InsertNewOpUtil::GetAippParams(const std::unique_ptr ge::GeAttrValue::NamedAttrs aipp_attr; const OpDescPtr tmpOpPtr = aipp_node->GetOpDesc(); GE_CHECK_NOTNULL(tmpOpPtr); - GE_CHK_BOOL_RET_STATUS(AttrUtils::GetNamedAttrs(tmpOpPtr, ATTR_NAME_AIPP, aipp_attr), FAILED, + GE_CHK_BOOL_RET_STATUS(AttrUtils::GetNamedAttrs(tmpOpPtr, domi::ATTR_NAME_AIPP, aipp_attr), FAILED, "Aipp node should contain param aipp!"); GE_CHK_STATUS_RET(OpUtils::ConvertAippParams(aipp_attr, aippParams.get()), "get aipp params failed"); @@ -188,13 +198,13 @@ Status InsertNewOpUtil::UpdateDataNodeByAipp(const ComputeGraphPtr &graph) { std::set updated_switchn; for (auto &node : graph->GetDirectNode()) { - if (node->GetType() == DATA) { + if (node->GetType() == domi::DATA) { std::string switchn_name; if (AttrUtils::GetStr(node->GetOpDesc(), kMbatchSwitchnName, switchn_name)) { switchn_names_to_data[switchn_name] = node; } } - if (node->GetType() == AIPP) { + if (node->GetType() == domi::AIPP) { GE_RETURN_IF_ERROR(UpdatePrevNodeByAipp(node, updated_switchn)); } } @@ -262,7 +272,7 @@ Status InsertNewOpUtil::UpdatePrevNodeByAipp(NodePtr &node, std::set &s output->SetShape(aipp_shape); output->SetOriginShape(aipp_shape); ge::TensorUtils::SetSize(*output, size); - if (src_node->GetType() == SWITCHN) { + if (src_node->GetType() == domi::SWITCHN) { switchns.insert(src_node); } GELOGI("Set node %s output %d size %ld by aipp.", src_node->GetName().c_str(), peer_out_anchor->GetIdx(), size); diff --git a/src/ge/graph/preprocess/multi_batch_copy_graph.cc b/src/ge/graph/preprocess/multi_batch_copy_graph.cc index 523c41cb..9edd1d0a 100644 --- a/src/ge/graph/preprocess/multi_batch_copy_graph.cc +++ b/src/ge/graph/preprocess/multi_batch_copy_graph.cc @@ -46,7 +46,9 @@ const int kMergeDataOutIndex = 0; const size_t kMaxShapesCount = 100; const size_t kMinShapesCount = 2; -inline bool IsDataLikeType(const std::string &node_type) { return (node_type == DATA) || (node_type == AIPP); } +inline bool IsDataLikeType(const std::string &node_type) { + return (node_type == domi::DATA) || (node_type == domi::AIPP); +} NodePtr InsertMergeNodeToGraph(const std::string &name, size_t input_num, const ComputeGraphPtr &graph) { OpDescPtr desc = MakeShared(); @@ -55,7 +57,7 @@ NodePtr InsertMergeNodeToGraph(const std::string &name, size_t input_num, const return nullptr; } desc->SetName(name); - desc->SetType(MERGE); + desc->SetType(domi::MERGE); GeTensorDesc tensor_desc; for (size_t i = 0; i < input_num; ++i) { auto ret = desc->AddInputDesc("x" + std::to_string(i), tensor_desc); @@ -142,7 +144,7 @@ Status CalcShape(const std::vector &batch_shape, GeShape &data_shape) { bool IsAllDimsPositive(const std::vector &dims) { for (auto dim : dims) { - if (dim < 0) { + if (dim <= 0) { return false; } } @@ -156,7 +158,7 @@ NodePtr InsertConst(const std::string &name, const ComputeGraphPtr &graph) { return nullptr; } desc->SetName(name); - desc->SetType(CONSTANT); + desc->SetType(domi::CONSTANT); GeTensor tensor; tensor.SetData(std::vector({0})); if (!AttrUtils::SetTensor(desc, ATTR_NAME_WEIGHTS, tensor)) { @@ -176,7 +178,7 @@ NodePtr InsertConst(const std::string &name, const ComputeGraphPtr &graph) { bool IsOnlyOutputToAipp(const NodePtr &node) { for (const auto &out_node : node->GetOutDataNodes()) { - if (out_node->GetType() != AIPP) { + if (out_node->GetType() != domi::AIPP) { return false; } } @@ -186,7 +188,7 @@ bool IsOnlyOutputToAipp(const NodePtr &node) { Status CheckDataShape(const std::vector &nodes) { size_t unknown_shape_count = 0; for (const auto &node : nodes) { - if (node->GetType() != DATA) { + if (node->GetType() != domi::DATA) { continue; } for (auto dim : NodeUtils::GetOutputDesc(*node, kDataOutIndex).GetShape().GetDims()) { @@ -288,7 +290,7 @@ Status MultiBatchGraphCopyer::CreateNewNodes() { return SUCCESS; } NodeStatus MultiBatchGraphCopyer::GetNodeStatus(const NodePtr &node) { - if (node->GetType() == NETOUTPUT) { + if (node->GetType() == domi::NETOUTPUT) { return kNodeOutBatchBranch; } if (IsDataLikeType(node->GetType()) && !IsOnlyOutputToAipp(node)) { @@ -427,7 +429,7 @@ NodePtr MultiBatchGraphCopyer::InsertShapeDataNode() { return nullptr; } desc->SetName("ascend_mbatch_shape_data"); - desc->SetType(DATA); + desc->SetType(domi::DATA); GeTensorDesc tensor_desc; tensor_desc.SetFormat(FORMAT_ND); @@ -622,7 +624,7 @@ Status MultiBatchGraphCopyer::InsertSwitchNForData(const NodePtr &data) { return OUT_OF_MEMORY; } switchn_desc->SetName(data->GetName() + "_ascend_mbatch_switchn"); - switchn_desc->SetType(SWITCHN); + switchn_desc->SetType(domi::SWITCHN); GeTensorDesc tensor(NodeUtils::GetOutputDesc(*data, kDataOutIndex)); if (switchn_desc->AddInputDesc(tensor) != GRAPH_SUCCESS) { // data @@ -872,7 +874,7 @@ Status ProcessMultiBatch(ComputeGraphPtr &graph) { std::vector> shapes; if (!domi::GetContext().dynamic_batch_size.empty()) { GELOGD("Found dynamic batch option, value %s", domi::GetContext().dynamic_batch_size.c_str()); - std::vector dims = ge::StringUtils::Split(domi::GetContext().dynamic_batch_size, ','); + std::vector dims = domi::StringUtils::Split(domi::GetContext().dynamic_batch_size, ','); for (const auto &dim : dims) { if (dim.empty()) { continue; @@ -883,13 +885,13 @@ Status ProcessMultiBatch(ComputeGraphPtr &graph) { } if (!domi::GetContext().dynamic_image_size.empty()) { GELOGD("Found dynamic image size option, value %s", domi::GetContext().dynamic_image_size.c_str()); - std::vector shape_strs = ge::StringUtils::Split(domi::GetContext().dynamic_image_size, ';'); + std::vector shape_strs = domi::StringUtils::Split(domi::GetContext().dynamic_image_size, ';'); for (const auto &shape_str : shape_strs) { if (shape_str.empty()) { continue; } std::vector shape; - std::vector dims = ge::StringUtils::Split(shape_str, ','); + std::vector dims = domi::StringUtils::Split(shape_str, ','); for (const auto &dim : dims) { if (dim.empty()) { continue; diff --git a/src/ge/inc/graph_pass.h b/src/ge/inc/graph_pass.h index 8eb241c8..fb2a6238 100644 --- a/src/ge/inc/graph_pass.h +++ b/src/ge/inc/graph_pass.h @@ -73,13 +73,13 @@ class GraphPass : public Pass { static bool IsConstNode(const ge::NodePtr &node) { GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, GELOGE(FAILED, "Node GetOpDesc is nullptr"); return false); - if (node->GetOpDesc()->GetType() == CONSTANTOP) { + if (node->GetOpDesc()->GetType() == domi::CONSTANTOP) { return true; - } else if (node->GetOpDesc()->GetType() == FRAMEWORKOP) { + } else if (node->GetOpDesc()->GetType() == domi::FRAMEWORKOP) { string type; - GE_CHK_BOOL_EXEC(ge::AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type), + GE_CHK_BOOL_EXEC(ge::AttrUtils::GetStr(node->GetOpDesc(), domi::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type), return false, "Get original_type for op %s fail!", node->GetName().c_str()); - GE_IF_BOOL_EXEC(type == CONSTANT, GELOGI("Is const op"); return true); + GE_IF_BOOL_EXEC(type == domi::CONSTANT, GELOGI("Is const op"); return true); return false; } else { return false; diff --git a/src/ge/init/gelib.cc b/src/ge/init/gelib.cc index 1b449521..84ecc506 100644 --- a/src/ge/init/gelib.cc +++ b/src/ge/init/gelib.cc @@ -40,6 +40,7 @@ #include "runtime/kernel.h" using Json = nlohmann::json; +using domi::StringUtils; namespace ge { namespace { @@ -142,35 +143,6 @@ Status GELib::InnerInitialize(const map &options) { return SUCCESS; } -void GELib::SetIncreBuild(const map &options) { - auto iter = options.find(OPTION_EXEC_ENABLE_INCRE_BUILD); - if (iter != options.end()) { - const std::string enable_incre_build = "true"; - const std::string disable_incre_build = "false"; - if (iter->second == enable_incre_build) { - is_incre_build_ = true; - GELOGI("Enable incre build."); - auto path_iter = options.find(OPTION_EXEC_INCRE_BUILD_CACHE_PATH); - if (path_iter != options.end()) { - std::string cache_path = path_iter->second; - if (!cache_path.empty() && cache_path[cache_path.size() - 1] != '/') { - cache_path += "/"; - } - incre_build_cache_path_ = cache_path; - } else { - incre_build_cache_path_ = ".ge_cache/"; - } - GELOGD("Using incre build cache path: %s.", incre_build_cache_path_.c_str()); - } else if (iter->second == disable_incre_build) { - is_incre_build_ = false; - GELOGI("Disable incre build."); - } else { - is_incre_build_ = false; - GELOGW("Invalid ENABLE_INCRE_BUILD option, it should be true or false."); - } - } -} - Status GELib::SystemInitialize(const map &options) { Status status = FAILED; auto iter = options.find(OPTION_GRAPH_RUN_MODE); @@ -203,8 +175,6 @@ Status GELib::SystemInitialize(const map &options) { PropertiesManager::Instance().SetDumpStep(dump_step); } } - // check incre build flag - SetIncreBuild(options); if (is_train_mode_) { InitOptions(options); @@ -396,7 +366,7 @@ Status GELib::Finalize() { } GELOGI("VarManagerPool finalization."); - VarManagerPool::Instance().Destroy(); + VarManagerPool::Instance().Destory(); GELOGI("MemManager finalization."); MemManager::Instance().Finalize(); @@ -438,6 +408,6 @@ void GELib::RollbackInit() { (void)sessionManager_.Finalize(); } MemManager::Instance().Finalize(); - VarManagerPool::Instance().Destroy(); + VarManagerPool::Instance().Destory(); } } // namespace ge diff --git a/src/ge/init/gelib.h b/src/ge/init/gelib.h index 3db32dd2..0945907a 100644 --- a/src/ge/init/gelib.h +++ b/src/ge/init/gelib.h @@ -65,12 +65,6 @@ class GELib { // add head stream to model bool HeadStream() const { return head_stream_; } - // get incre build flag - bool IsIncreBuild() const { return is_incre_build_; } - - // get incre build cache path - const std::string &GetIncreBuildCachePath() const { return incre_build_cache_path_; } - Status InitSystemWithoutOptions(); Status InitSystemWithOptions(Options &options); Status SystemShutdownWithOptions(const Options &options); @@ -82,7 +76,6 @@ class GELib { Status SystemInitialize(const map &options); void RollbackInit(); void InitOptions(const map &options); - void SetIncreBuild(const map &options); DNNEngineManager engineManager_; OpsKernelManager opsManager_; @@ -94,9 +87,8 @@ class GELib { bool is_system_inited = false; bool is_shutdown = false; bool is_use_hcom = false; - bool is_incre_build_ = false; + bool head_stream_ = false; - std::string incre_build_cache_path_; }; } // namespace ge diff --git a/src/ge/ir_build/ge_ir_build.cc b/src/ge/ir_build/ge_ir_build.cc index 671a34af..2c871559 100644 --- a/src/ge/ir_build/ge_ir_build.cc +++ b/src/ge/ir_build/ge_ir_build.cc @@ -36,11 +36,13 @@ #include "framework/omg/omg_inner_types.h" using domi::GetContext; +using domi::StringUtils; using ge::FileSaver; using ge::GRAPH_PARAM_INVALID; using ge::GRAPH_SUCCESS; using ge::ParseInputShape; using std::string; + using namespace std; namespace ge { @@ -89,9 +91,9 @@ class Impl { GetContext().user_out_nodes.clear(); GetContext().net_format = domi::DOMI_TENSOR_RESERVED; GetContext().type = domi::FRAMEWORK_RESERVED; - GetContext().run_mode = ONLY_PRE_CHECK; + GetContext().run_mode = domi::ONLY_PRE_CHECK; GetContext().train_flag = false; - GetContext().fp16_high_precision = HIGH_PRECISION_DEFAULT; + GetContext().fp16_high_precision = domi::HIGH_PRECISION_DEFAULT; GetContext().output_type.clear(); GetContext().net_name.clear(); GetContext().is_dynamic_input = false; @@ -187,7 +189,7 @@ graphStatus Impl::CreateInputsForIRBuild(const ge::Graph &graph, vectorGetOpDesc(); GE_CHECK_NOTNULL(op); - if (op->GetType() == DATA) { + if (op->GetType() == domi::DATA) { GELOGI("Data op inputDesc size is: %zu", op->GetAllInputsDesc().size()); ge::GeTensorDesc tensor = op->GetInputDesc(0); string data_op_name = op->GetName(); diff --git a/src/ge/omm/csa_interact.cc b/src/ge/omm/csa_interact.cc index dd3f6240..075da863 100644 --- a/src/ge/omm/csa_interact.cc +++ b/src/ge/omm/csa_interact.cc @@ -25,6 +25,8 @@ #include "mmpa/mmpa_api.h" #include "nlohmann/json.hpp" +using domi::CurrentTimeInStr; + namespace ge { namespace { const char FMK_STATUS_FILE_DIR_ENV[] = "FMK_STATUS_FILE_DIR"; diff --git a/src/ge/session/session_manager.cc b/src/ge/session/session_manager.cc index ebe0b188..aa34441a 100644 --- a/src/ge/session/session_manager.cc +++ b/src/ge/session/session_manager.cc @@ -23,6 +23,7 @@ #include "graph/load/new_model_manager/model_manager.h" #include "graph/ge_context.h" +using domi::ATTR_NAME_SESSION_GRAPH_ID; using std::map; using std::string; using std::vector; diff --git a/src/ge/single_op/single_op_model.cc b/src/ge/single_op/single_op_model.cc index f2d2da88..22e46008 100644 --- a/src/ge/single_op/single_op_model.cc +++ b/src/ge/single_op/single_op_model.cc @@ -29,10 +29,20 @@ #include "runtime/rt.h" #include "task/tbe_task_builder.h" -using domi::TaskDef; using std::unique_ptr; using std::vector; +using domi::AIPP_DATA_TYPE; +using domi::ALLOC_MEMORY_MAX_SIZE; +using domi::DATA_TYPE; +using domi::MODEL_ATTR_TASK_GEN_BASE_ADDR; +using domi::MODEL_ATTR_TASK_GEN_WEIGHT_ADDR; +using domi::ModelFileHeader; +using domi::ModelHelper; +using domi::NETOUTPUT; +using domi::OmFileLoadHelper; +using domi::TaskDef; + namespace ge { namespace { const size_t kDataOutputNum = 1; @@ -76,11 +86,8 @@ void SingleOpModel::ParseOpModelParams(ModelHelper &model_helper, SingleOpModelP param.base_addr = ret ? static_cast(value) : 0; ret = ge::AttrUtils::GetInt(model, MODEL_ATTR_TASK_GEN_WEIGHT_ADDR, value); param.weight_addr = ret ? static_cast(value) : 0; - ret = ge::AttrUtils::GetInt(model, ATTR_MODEL_CORE_TYPE, value); - param.core_type = ret ? value : 0; - GELOGI("ParseOpModelParams(), memory_size:%lu, weight_size:%lu. core_type = %lu", param.memory_size, - param.weight_size, param.core_type); + GELOGI("ParseOpModelParams(), memory_size:%lu, weight_size:%lu.", param.memory_size, param.weight_size); } Status SingleOpModel::InitModelMem(StreamResource &res) { diff --git a/src/ge/single_op/single_op_model.h b/src/ge/single_op/single_op_model.h index c1a63758..528004b8 100644 --- a/src/ge/single_op/single_op_model.h +++ b/src/ge/single_op/single_op_model.h @@ -39,7 +39,6 @@ struct SingleOpModelParam { uint8_t *weight_base = nullptr; std::map addr_mapping_; - int64_t core_type = 0; }; class SingleOpModel { @@ -63,14 +62,14 @@ class SingleOpModel { Status BuildTaskList(SingleOp &single_op); Status BuildKernelTask(const domi::KernelDef &kernel_def, SingleOp &single_op, OpTask **task); - static void ParseOpModelParams(ModelHelper &model_helper, SingleOpModelParam ¶m); + static void ParseOpModelParams(domi::ModelHelper &model_helper, SingleOpModelParam ¶m); void ParseArgTable(TbeOpTask *task, SingleOp &op); std::string model_name_; const void *ori_model_data_; uint32_t ori_model_size_; - ModelHelper model_helper_; + domi::ModelHelper model_helper_; map op_list_; SingleOpModelParam model_params_; diff --git a/src/ge/single_op/task/tbe_task_builder.cc b/src/ge/single_op/task/tbe_task_builder.cc index c0f6877f..1a47402e 100644 --- a/src/ge/single_op/task/tbe_task_builder.cc +++ b/src/ge/single_op/task/tbe_task_builder.cc @@ -22,13 +22,15 @@ #include "common/helper/model_helper.h" #include "framework/common/debug/ge_log.h" +#include "framework/common/op/attr_define.h" #include "graph/load/new_model_manager/model_utils.h" -#include "graph/debug/ge_attr_define.h" #include "graph/load/new_model_manager/task_info/task_info.h" #include "graph/manager/graph_var_manager.h" #include "runtime/rt.h" #include "single_op/task/build_task_utils.h" +using domi::TVM_ATTR_NAME_METADATA; + namespace ge { namespace { std::mutex g_reg_mutex; @@ -89,17 +91,16 @@ TbeTaskBuilder::TbeTaskBuilder(const std::string &model_name, const OpDescPtr &o const domi::KernelDef &kernel_def) : op_desc_(op_desc), kernel_def_(kernel_def), stub_name_(model_name + "/" + op_desc->GetName() + "_tvmbin") {} -Status TbeTaskBuilder::DoRegisterBinary(const OpKernelBin &kernel_bin, void **bin_handle, - const SingleOpModelParam ¶m) const { +Status TbeTaskBuilder::DoRegisterBinary(const OpKernelBin &kernel_bin, void **bin_handle) const { rtDevBinary_t binary; binary.version = 0; binary.data = kernel_bin.GetBinData(); binary.length = kernel_bin.GetBinDataSize(); - binary.magic = param.core_type == 0 ? RT_DEV_BINARY_MAGIC_ELF : RT_DEV_BINARY_MAGIC_ELF_AIVEC; + binary.magic = RT_DEV_BINARY_MAGIC_ELF; auto ret = rtDevBinaryRegister(&binary, bin_handle); if (ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "rtDevBinaryRegister failed, bin key = %s, core_type = %ld, rt ret = %d", stub_name_.c_str(), - param.core_type, static_cast(ret)); + GELOGE(RT_FAILED, "rtDevBinaryRegister failed, bin key = %s, rt ret = %d", stub_name_.c_str(), + static_cast(ret)); return RT_FAILED; } @@ -133,13 +134,13 @@ Status TbeTaskBuilder::DoRegisterFunction(void *bin_handle, const char *stub_nam return SUCCESS; } -Status TbeTaskBuilder::DoRegisterKernel(const ge::OpKernelBin &tbe_kernel, const char *bin_file_key, void **bin_handle, - const SingleOpModelParam ¶m) { +Status TbeTaskBuilder::DoRegisterKernel(const ge::OpKernelBin &tbe_kernel, const char *bin_file_key, + void **bin_handle) { std::string kernel_name; GetKernelName(op_desc_, kernel_name); void *handle = nullptr; - auto ret = DoRegisterBinary(tbe_kernel, &handle, param); + auto ret = DoRegisterBinary(tbe_kernel, &handle); if (ret != SUCCESS) { return ret; } @@ -161,7 +162,7 @@ Status TbeTaskBuilder::DoRegisterKernel(const ge::OpKernelBin &tbe_kernel, const return SUCCESS; } -Status TbeTaskBuilder::RegisterKernel(TbeOpTask &task, const SingleOpModelParam ¶m) { +Status TbeTaskBuilder::RegisterKernel(TbeOpTask &task) { KernelBinRegistry ®istry = KernelBinRegistry::GetInstance(); // check if already registered const char *stub_func = registry.GetStubFunc(stub_name_); @@ -191,7 +192,7 @@ Status TbeTaskBuilder::RegisterKernel(TbeOpTask &task, const SingleOpModelParam } void *bin_handle = nullptr; - auto ret = DoRegisterKernel(*tbe_kernel, stub_func, &bin_handle, param); + auto ret = DoRegisterKernel(*tbe_kernel, stub_func, &bin_handle); if (ret == SUCCESS) { holder->SetBinHandle(bin_handle); if (!registry.AddKernel(stub_name_, holder)) { @@ -286,7 +287,7 @@ Status TbeTaskBuilder::BuildTask(TbeOpTask &task, const SingleOpModelParam ¶ return ret; } - ret = RegisterKernel(task, param); + ret = RegisterKernel(task); if (ret != SUCCESS) { return ret; } diff --git a/src/ge/single_op/task/tbe_task_builder.h b/src/ge/single_op/task/tbe_task_builder.h index 5e0965bf..25441289 100644 --- a/src/ge/single_op/task/tbe_task_builder.h +++ b/src/ge/single_op/task/tbe_task_builder.h @@ -74,10 +74,9 @@ class TbeTaskBuilder { Status SetKernelArgs(TbeOpTask &task, const SingleOpModelParam ¶m); Status GetSmDesc(void **sm_desc, const SingleOpModelParam ¶m) const; - Status RegisterKernel(TbeOpTask &task, const SingleOpModelParam ¶m); - Status DoRegisterKernel(const OpKernelBin &kernel_bin, const char *bin_file_key, void **bin_handle, - const SingleOpModelParam ¶m); - Status DoRegisterBinary(const OpKernelBin &kernel_bin, void **bin_handle, const SingleOpModelParam ¶m) const; + Status RegisterKernel(TbeOpTask &task); + Status DoRegisterKernel(const OpKernelBin &kernel_bin, const char *bin_file_key, void **bin_handle); + Status DoRegisterBinary(const OpKernelBin &kernel_bin, void **bin_handle) const; Status DoRegisterMeta(void *bin_handle); static Status DoRegisterFunction(void *bin_handle, const char *stub_name, const char *kernel_name); diff --git a/src/proto/fusion_model.proto b/src/proto/fusion_model.proto index 6220963c..2ff6b77a 100644 --- a/src/proto/fusion_model.proto +++ b/src/proto/fusion_model.proto @@ -17,10 +17,9 @@ syntax = "proto3"; import "om.proto"; - package domi; message FusionModelDef { string version = 1; repeated OpDef fusion_op = 2; -} \ No newline at end of file +} diff --git a/src/proto/task.proto b/src/proto/task.proto index 8ef5c2e2..3eb8de5c 100644 --- a/src/proto/task.proto +++ b/src/proto/task.proto @@ -31,7 +31,7 @@ message ModelTaskDef { repeated bytes op = 15; // input/output opdef in bytes - uint64 base_addr = 16; // base addr + uint64 base_addr = 16; // base addr uint64 weight_addr = 17; // weight addr uint32 batch_num = 18; } @@ -58,10 +58,6 @@ message TaskDef { bytes private_def = 34; uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future StreamSwitchNDef stream_switch_n = 36; - - LabelSetDef label_set = 37; - LabelGotoExDef label_goto_ex = 38; - LabelSwitchByIndexDef label_switch_by_index = 39; } message KernelDef { @@ -123,7 +119,6 @@ message MemcpyAsyncDef { uint64 src = 3; uint64 count = 4; uint32 kind = 5; - uint32 op_index = 6; } message StreamSwitchDef { @@ -147,20 +142,3 @@ message StreamSwitchNDef { uint32 element_size = 5; uint32 data_type = 6; } - -message LabelSetDef { - uint32 op_index = 1; - uint32 label_id = 2; - uint32 model_id = 3; -} - -message LabelGotoExDef { - uint32 op_index = 1; - uint32 label_id = 2; - uint32 model_id = 3; -} - -message LabelSwitchByIndexDef { - uint32 op_index = 1; - uint32 label_max = 2; -} diff --git a/tests/depends/cce/CMakeLists.txt b/tests/depends/cce/CMakeLists.txt index 885a5ca2..70516146 100644 --- a/tests/depends/cce/CMakeLists.txt +++ b/tests/depends/cce/CMakeLists.txt @@ -29,8 +29,10 @@ include_directories(${GE_SOURCE_DIR}/src/common/graph) include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/cce) include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/ops) +include_directories(${GE_SOURCE_DIR}/third_party/securec/include) include_directories(${CMAKE_BINARY_DIR}) include_directories(${CMAKE_BINARY_DIR}/proto/ge) +include_directories(${GE_SOURCE_DIR}/third_party/securec/include) file(GLOB_RECURSE PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "${GE_SOURCE_DIR}/src/proto/om.proto" "${GE_SOURCE_DIR}/src/proto/ge_ir.proto" diff --git a/tests/depends/mmpa/CMakeLists.txt b/tests/depends/mmpa/CMakeLists.txt index 4688eb04..6185c8fe 100644 --- a/tests/depends/mmpa/CMakeLists.txt +++ b/tests/depends/mmpa/CMakeLists.txt @@ -25,6 +25,7 @@ include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) include_directories(${GE_SOURCE_DIR}/inc) include_directories(${GE_SOURCE_DIR}/inc/framework) include_directories(${GE_SOURCE_DIR}/inc/external) +include_directories(${GE_SOURCE_DIR}/third_party/securec/include) add_library(mmpa_stub SHARED ${SRCS}) target_link_libraries(mmpa_stub protobuf::protobuf) diff --git a/tests/depends/omg/CMakeLists.txt b/tests/depends/omg/CMakeLists.txt index 03915f5c..158a1ff4 100644 --- a/tests/depends/omg/CMakeLists.txt +++ b/tests/depends/omg/CMakeLists.txt @@ -29,6 +29,7 @@ include_directories(${GE_SOURCE_DIR}/inc/external/graph) include_directories(${GE_SOURCE_DIR}/src/ge) include_directories(${CMAKE_BINARY_DIR}) include_directories(${CMAKE_BINARY_DIR}/proto/ge) +include_directories(${GE_SOURCE_DIR}/third_party/securec/include) file(GLOB_RECURSE PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "${GE_SOURCE_DIR}/src/proto/om.proto" "${GE_SOURCE_DIR}/src/proto/task.proto" diff --git a/tests/depends/omg/src/omg_stub.cc b/tests/depends/omg/src/omg_stub.cc index 224d4128..7197dac7 100644 --- a/tests/depends/omg/src/omg_stub.cc +++ b/tests/depends/omg/src/omg_stub.cc @@ -122,7 +122,6 @@ struct OmFileContext { class SubGraphInfo; using SubGraphInfoPtr = std::shared_ptr; -using Graph2SubGraphInfoList = std::unordered_map>; using GeModelPartitionPtr = std::shared_ptr; using ModelPtr = std::shared_ptr; @@ -221,7 +220,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void OmFileSaveHelper::AddParti } class ModelBuilder { public: - ModelBuilder(ge::ComputeGraphPtr compute_graph, const Graph2SubGraphInfoList &subgraphs, + ModelBuilder(ge::ComputeGraphPtr compute_graph, const std::vector &subgraphs, const std::map &stream_max_parallel_num, bool hcom_parallel, int mode); virtual ~ModelBuilder(); Status BuildModel(ge::Model &model_def); @@ -236,7 +235,7 @@ class ModelBuilder { ge::Buffer weight_buffer_; }; -ModelBuilder::ModelBuilder(ge::ComputeGraphPtr compute_graph, const Graph2SubGraphInfoList &subgraphs, +ModelBuilder::ModelBuilder(ge::ComputeGraphPtr compute_graph, const std::vector &subgraphs, const std::map &stream_max_parallel_num, bool hcom_parallel, int mode) { weight_buffer_ = ge::Buffer(4100000); } diff --git a/tests/depends/runtime/CMakeLists.txt b/tests/depends/runtime/CMakeLists.txt index 80cc14e4..dbbaa8fc 100644 --- a/tests/depends/runtime/CMakeLists.txt +++ b/tests/depends/runtime/CMakeLists.txt @@ -23,5 +23,6 @@ file(GLOB_RECURSE SRCS RELATIVE ${CMAKE_CURRENT_LIST_DIR} include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) include_directories(${GE_SOURCE_DIR}/inc/framework) +include_directories(${GE_SOURCE_DIR}/third_party/securec/include) add_library(runtime_stub SHARED ${SRCS}) diff --git a/tests/ut/common/graph/CMakeLists.txt b/tests/ut/common/graph/CMakeLists.txt index eda5df28..064abc16 100644 --- a/tests/ut/common/graph/CMakeLists.txt +++ b/tests/ut/common/graph/CMakeLists.txt @@ -35,6 +35,7 @@ include_directories(${GE_SOURCE_DIR}/inc/external) include_directories(${GE_SOURCE_DIR}/inc/external/graph) include_directories(${GE_SOURCE_DIR}/inc/graph) include_directories(${GE_SOURCE_DIR}/inc/common) +include_directories(${GE_SOURCE_DIR}/third_party/securec/include) include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/ops) include_directories(${CMAKE_BINARY_DIR}) @@ -93,4 +94,4 @@ file(GLOB_RECURSE SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ) add_executable(ut_libgraph ${UT_FILES} ${SRC_FILES} ${PROTO_SRCS} ${PROTO_HDRS}) -target_link_libraries(ut_libgraph graphengine::gtest graphengine::gtest_main slog_stub protobuf::protobuf graphengine::securec rt dl) +target_link_libraries(ut_libgraph graphengine::gtest graphengine::gtest_main slog_stub protobuf::protobuf ${c_sec} rt dl) diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index c636362c..5ed130c7 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -43,7 +43,7 @@ include_directories(${GE_SOURCE_DIR}/inc/external/graph) include_directories(${GE_SOURCE_DIR}/inc/graph) include_directories(${GE_SOURCE_DIR}/inc/framework) include_directories(${GE_SOURCE_DIR}/inc/common) -include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib) +include_directories(${GE_SOURCE_DIR}/third_party/securec/include) include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/cce) include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/ops) @@ -491,7 +491,7 @@ file(GLOB_RECURSE OTHERS_TEST_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} list(APPEND COMMON_SHARED_LIBRARIES omg_stub - graphengine::securec + ${c_sec} slog_stub cce_ge_stub runtime_stub diff --git a/tests/ut/ge/common/datatype_transfer_unittest.cc b/tests/ut/ge/common/datatype_transfer_unittest.cc index e0f258a9..5f11b272 100644 --- a/tests/ut/ge/common/datatype_transfer_unittest.cc +++ b/tests/ut/ge/common/datatype_transfer_unittest.cc @@ -368,20 +368,14 @@ TEST_F(UtestDataTypeTransfer, invalid_src_data_type) { EXPECT_EQ(transfer.TransDataType(args, result), UNSUPPORTED); } -TEST_F(UtestDataTypeTransfer, src_shape_empty) { - uint8_t data[1*4*4*1] = {0}; - constexpr int64_t kShapeItemNumMAX = 1024UL * 1024UL * 1024UL * 1024UL; +TEST_F(UtestDataTypeTransfer, src_shape_empry) { + uint8_t data[1 * 4 * 4 * 1] = {0}; DataTypeTransfer transfer; - CastArgs args { - reinterpret_cast(data), - 0, - DT_UINT8, - DT_INT32 - }; + CastArgs args{reinterpret_cast(data), 0, DT_UINT8, DT_INT32}; TransResult result; - EXPECT_EQ(transfer.TransDataType(args, result), SUCCESS); + EXPECT_EQ(transfer.TransDataType(args, result), PARAM_INVALID); } TEST_F(UtestDataTypeTransfer, unsupprot_trans) { diff --git a/tests/ut/ge/common/format_transfer_nhwc_5d_unittest.cc b/tests/ut/ge/common/format_transfer_nhwc_5d_unittest.cc index b4beb6ce..8d1ff256 100644 --- a/tests/ut/ge/common/format_transfer_nhwc_5d_unittest.cc +++ b/tests/ut/ge/common/format_transfer_nhwc_5d_unittest.cc @@ -701,7 +701,7 @@ TEST_F(UtestFormatTransferNhwc5d, invalid_src_shape2) { EXPECT_EQ(transfer.TransFormat(args, result), PARAM_INVALID); Status status = transfer.TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, args.dst_shape); - EXPECT_EQ(status, SUCCESS); + EXPECT_EQ(status, PARAM_INVALID); } TEST_F(UtestFormatTransferNhwc5d, invalid_src_format) { diff --git a/tests/ut/ge/graph/build/logical_stream_allocator_unittest.cc b/tests/ut/ge/graph/build/logical_stream_allocator_unittest.cc index e49005e8..f9799b49 100644 --- a/tests/ut/ge/graph/build/logical_stream_allocator_unittest.cc +++ b/tests/ut/ge/graph/build/logical_stream_allocator_unittest.cc @@ -21,7 +21,6 @@ #define protected public #define private public #include "graph/manager/graph_manager_utils.h" -#include "common/op/attr_value_util.h" #undef protected #undef private @@ -190,20 +189,18 @@ class UtestLogicalStreamAllocator : public testing::Test { bool ExpectStreamEq(SubGraphInfoPtr subgraph, int64_t expect) { return GetStream(subgraph) == expect; } bool ExpectStreamNe(SubGraphInfoPtr subgraph, int64_t expect) { return GetStream(subgraph) != expect; } - Status AssignLogicalStreams(Graph2SubGraphInfoList &subgraph_map, vector &confs, + Status AssignLogicalStreams(vector subgraphs, vector &confs, std::map &max_parallel_num, ComputeGraphPtr &whole_graph) { SchedulerConf scheduler_conf; if (confs.empty()) { - for (const auto &subgraph_pair : subgraph_map) { - for (const auto &subgraph : subgraph_pair.second) { - EngineConfPtr conf = make_shared(); - conf->id = subgraph->GetEngineName(); - if (conf->id == "ge_local") { - conf->skip_assign_stream = true; - conf->attach = true; - } - scheduler_conf.cal_engines[conf->id] = conf; + for (const auto &subgraph : subgraphs) { + EngineConfPtr conf = make_shared(); + conf->id = subgraph->GetEngineName(); + if (conf->id == "ge_local") { + conf->skip_assign_stream = true; + conf->attach = true; } + scheduler_conf.cal_engines[conf->id] = conf; } } else { for (auto &conf : confs) { @@ -220,33 +217,24 @@ class UtestLogicalStreamAllocator : public testing::Test { scheduler_confs["scheduler"] = scheduler_conf; LogicalStreamAllocator allocator(scheduler_confs, max_parallel_num); int64_t stream_num = 0; - return allocator.Assign(whole_graph, subgraph_map, stream_num); + return allocator.Assign(whole_graph, subgraphs, stream_num); } - Status AssignLogicalStreams(vector subgraphs, - vector &confs, - std::map &max_parallel_num, - ComputeGraphPtr &whole_graph) { - Graph2SubGraphInfoList subgraph_map; - subgraph_map[whole_graph] = subgraphs; - return AssignLogicalStreams(subgraph_map, confs, max_parallel_num, whole_graph); - } - - Status AssignLogicalStreams(vector subgraphs, vector& confs, - std::map& max_parallel_num) { - ComputeGraphPtr whole_graph = make_shared < ComputeGraph > ("whole_graph"); + Status AssignLogicalStreams(vector subgraphs, std::map &max_parallel_num, + vector &confs) { + ComputeGraphPtr whole_graph = make_shared("whole_graph"); return AssignLogicalStreams(subgraphs, confs, max_parallel_num, whole_graph); } Status AssignLogicalStreams(vector subgraphs, vector confs = vector()) { std::map max_parallel_num; - return AssignLogicalStreams(subgraphs, confs, max_parallel_num); + return AssignLogicalStreams(subgraphs, max_parallel_num, confs); } - Status AssignLogicalStreams(vector subgraphs, std::map& max_parallel_num) { - vector < EngineConfPtr > confs; - return AssignLogicalStreams(subgraphs, confs, max_parallel_num); + Status AssignLogicalStreams(vector subgraphs, std::map &max_parallel_num) { + vector confs; + return AssignLogicalStreams(subgraphs, max_parallel_num, confs); } /// typical case @@ -306,8 +294,8 @@ class UtestLogicalStreamAllocator : public testing::Test { max_parallel_num["aicpu"] = parallel_num; Status status = AssignLogicalStreams({const1, const2, get_next, genmask1, genmask2, domask, subgraph4, subgraph5, - subgraph6, allreduce1, allreduce2, apply1, apply2}, confs, - max_parallel_num); + subgraph6, allreduce1, allreduce2, apply1, apply2}, + max_parallel_num, confs); EXPECT_EQ(status, ge::SUCCESS); EXPECT_EQ(GetStream(get_next), 0); @@ -336,7 +324,7 @@ class UtestLogicalStreamAllocator : public testing::Test { /// E --> F(AllReduce) --- G /// stream id: 2 2 2 /// - void MakeGraphWithAllreduce(ge::ComputeGraphPtr graph) { + void make_graph_with_allreduce(ge::ComputeGraphPtr graph) { ge::OpDescPtr op_a = make_shared("A", DATA); auto desc_temp_ptr = make_shared(); auto desc_temp = *desc_temp_ptr; @@ -349,7 +337,6 @@ class UtestLogicalStreamAllocator : public testing::Test { ge::OpDescPtr op_c = make_shared("C", "HcomAllReduce"); op_c->AddInputDesc(desc_temp); - op_c->AddInputDesc(desc_temp); op_c->AddOutputDesc(desc_temp); ge::OpDescPtr op_d = make_shared("D", "testa"); @@ -362,21 +349,12 @@ class UtestLogicalStreamAllocator : public testing::Test { ge::OpDescPtr op_f = make_shared("F", "HcomAllReduce"); op_f->AddInputDesc(desc_temp); - op_f->AddInputDesc(desc_temp); op_f->AddOutputDesc(desc_temp); ge::OpDescPtr op_g = make_shared("G", "testa"); op_g->AddInputDesc(desc_temp); op_g->AddOutputDesc(desc_temp); - ge::OpDescPtr op_h = make_shared("H", "testa"); - op_h->AddInputDesc(desc_temp); - op_h->AddOutputDesc(desc_temp); - - ge::OpDescPtr op_i = make_shared("I", "testa"); - op_i->AddInputDesc(desc_temp); - op_i->AddOutputDesc(desc_temp); - // add node ge::NodePtr node_a = graph->AddNode(op_a); ge::NodePtr node_b = graph->AddNode(op_b); @@ -385,18 +363,14 @@ class UtestLogicalStreamAllocator : public testing::Test { ge::NodePtr node_e = graph->AddNode(op_e); ge::NodePtr node_f = graph->AddNode(op_f); ge::NodePtr node_g = graph->AddNode(op_g); - ge::NodePtr node_h = graph->AddNode(op_h); - ge::NodePtr node_i = graph->AddNode(op_i); // add edge - node_a->GetOutDataAnchor(0)->LinkTo(node_b->GetInDataAnchor(0)); - node_a->GetOutDataAnchor(0)->LinkTo(node_e->GetInDataAnchor(0)); - node_b->GetOutDataAnchor(0)->LinkTo(node_c->GetInDataAnchor(0)); - node_c->GetOutDataAnchor(0)->LinkTo(node_d->GetInDataAnchor(0)); - node_e->GetOutDataAnchor(0)->LinkTo(node_f->GetInDataAnchor(0)); - node_f->GetOutDataAnchor(0)->LinkTo(node_g->GetInDataAnchor(0)); - node_h->GetOutDataAnchor(0)->LinkTo(node_c->GetInDataAnchor(1)); - node_i->GetOutDataAnchor(0)->LinkTo(node_f->GetInDataAnchor(1)); + ge::GraphUtils::AddEdge(node_a->GetOutDataAnchor(0), node_b->GetInDataAnchor(0)); + ge::GraphUtils::AddEdge(node_a->GetOutDataAnchor(0), node_e->GetInDataAnchor(0)); + ge::GraphUtils::AddEdge(node_b->GetOutDataAnchor(0), node_c->GetInDataAnchor(0)); + ge::GraphUtils::AddEdge(node_c->GetOutDataAnchor(0), node_d->GetInDataAnchor(0)); + ge::GraphUtils::AddEdge(node_e->GetOutDataAnchor(0), node_f->GetInDataAnchor(0)); + ge::GraphUtils::AddEdge(node_f->GetOutDataAnchor(0), node_g->GetInDataAnchor(0)); // add stream id node_a->GetOpDesc()->SetStreamId(0); @@ -406,14 +380,6 @@ class UtestLogicalStreamAllocator : public testing::Test { node_e->GetOpDesc()->SetStreamId(2); node_f->GetOpDesc()->SetStreamId(2); node_g->GetOpDesc()->SetStreamId(2); - - // add stream label - string stream_label1 = "1"; - (void) AttrUtils::SetStr(node_c->GetOpDesc(), ATTR_NAME_STREAM_LABEL, stream_label1); - (void) AttrUtils::SetStr(node_d->GetOpDesc(), ATTR_NAME_STREAM_LABEL, stream_label1); - string stream_label2 = "2"; - (void) AttrUtils::SetStr(node_f->GetOpDesc(), ATTR_NAME_STREAM_LABEL, stream_label2); - (void) AttrUtils::SetStr(node_g->GetOpDesc(), ATTR_NAME_STREAM_LABEL, stream_label2); } }; @@ -686,7 +652,7 @@ TEST_F(UtestLogicalStreamAllocator, test_independent) { vector confs = {conf1, conf2}; Status status = - AssignLogicalStreams({subgraph1, subgraph2, subgraph3, subgraph4, subgraph5}, confs, max_parallel_num); + AssignLogicalStreams({subgraph1, subgraph2, subgraph3, subgraph4, subgraph5}, max_parallel_num, confs); EXPECT_EQ(status, ge::SUCCESS); EXPECT_EQ(GetStream(subgraph1), 0); EXPECT_EQ(GetStream(subgraph2), 0); @@ -729,7 +695,7 @@ TEST_F(UtestLogicalStreamAllocator, test_independent_switch_label) { vector confs = {conf1, conf2, conf3}; Status status = - AssignLogicalStreams({subgraph1, subgraph2, subgraph3, subgraph4, subgraph5}, confs, max_parallel_num); + AssignLogicalStreams({subgraph1, subgraph2, subgraph3, subgraph4, subgraph5}, max_parallel_num, confs); EXPECT_EQ(status, ge::SUCCESS); EXPECT_EQ(GetStream(subgraph1), 4); EXPECT_EQ(GetStream(subgraph2), 0); @@ -867,9 +833,9 @@ TEST_F(UtestLogicalStreamAllocator, test_reassign_stream) { auto node1_1 = whole_graph->AddNode(node1->GetOpDesc()); auto node1_2 = whole_graph->AddNode(node2->GetOpDesc()); auto node1_3 = whole_graph->AddNode(node3->GetOpDesc()); - node1_1->GetOutControlAnchor()->LinkTo(node1_2->GetInControlAnchor()); - node1_2->GetOutDataAnchor(0)->LinkTo(node1_3->GetInDataAnchor(0)); - node1->GetOutControlAnchor()->LinkTo(node2->GetInControlAnchor()); + GraphUtils::AddEdge(node1_1->GetOutControlAnchor(), node1_2->GetInControlAnchor()); + GraphUtils::AddEdge(node1_2->GetOutDataAnchor(0), node1_3->GetInDataAnchor(0)); + GraphUtils::AddEdge(node1->GetOutControlAnchor(), node2->GetInControlAnchor()); std::map max_parallel_num; vector subgraphs = {subgraph1, const2, subgraph3}; @@ -887,7 +853,7 @@ TEST_F(UtestLogicalStreamAllocator, test_all_reduce_parallel_pass) { ge::ComputeGraphPtr graph = make_shared(""); graph->SetName("TestAllReduceParallelPass"); - MakeGraphWithAllreduce(graph); + make_graph_with_allreduce(graph); std::map max_parallel_num; LogicalStreamPass::Context context; @@ -897,13 +863,7 @@ TEST_F(UtestLogicalStreamAllocator, test_all_reduce_parallel_pass) { LogicalStreamPassPtr allreduce_pass = std::make_shared(); ret = allreduce_pass->Run(graph, subgraphs, context); - EXPECT_EQ(ret, SUCCESS); - - ge::NodePtr node_d = graph->FindNode("D"); - ge::NodePtr node_g = graph->FindNode("G"); - int64_t stream_d = node_d->GetOpDesc()->GetStreamId(); - int64_t stream_g = node_g->GetOpDesc()->GetStreamId(); - EXPECT_EQ(stream_d + stream_g, 11); + EXPECT_EQ(ret, NOT_CHANGED); } } // namespace ge diff --git a/tests/ut/ge/graph/load/new_model_manager_davinci_model_unittest.cc b/tests/ut/ge/graph/load/new_model_manager_davinci_model_unittest.cc index a51299b3..f8deff7f 100644 --- a/tests/ut/ge/graph/load/new_model_manager_davinci_model_unittest.cc +++ b/tests/ut/ge/graph/load/new_model_manager_davinci_model_unittest.cc @@ -315,7 +315,7 @@ TEST_F(UtestModelManagerDavinciModel, success_GetInputOutputDescInfo_without_net auto node = compute_graph->AddNode(op_desc); model.data_op_list_.push_back(op_desc); - model.output_data_info_[0] = {32, (void *)0x70002010}; + model.output_size_list_.push_back(32); model.op_list_[0] = op_desc; @@ -419,7 +419,7 @@ TEST_F(UtestModelManagerDavinciModel, success_get_input_output_descInfo_with_net model.op_list_[0] = op_desc; model.output_op_list_.push_back(op_desc); - model.output_data_info_[0] = {32, (void *)0x70002010}; + model.output_size_list_.push_back(32); vector input_shapes; vector output_shapes; @@ -463,7 +463,7 @@ TEST_F(UtestModelManagerDavinciModel, success_get_input_output_desc_info_for_zer model.op_list_[0] = op_desc; model.output_op_list_.push_back(op_desc); - model.output_data_info_[0] = {32, (void *)0x70002010}; + model.output_size_list_.push_back(32); model.output_memory_size_list_.push_back(64); vector input_shapes; @@ -508,7 +508,7 @@ TEST_F(UtestModelManagerDavinciModel, success_get_input_output_desc_info_dim_siz model.op_list_[0] = op_desc; model.output_op_list_.push_back(op_desc); - model.output_data_info_[0] = {32, (void *)0x70002010}; + model.output_size_list_.push_back(32); vector input_shapes; vector output_shapes; @@ -1282,7 +1282,7 @@ TEST_F(UtestModelManagerDavinciModel, success_get_output_desc_info_with_netoutpu model.op_list_[0] = op_desc; model.output_op_list_.push_back(op_desc); - model.output_data_info_[0] = {32, (void *)0x70002010}; + model.output_size_list_.push_back(32); model.output_memory_size_list_.push_back(64); vector output_shapes; diff --git a/tests/ut/ge/graph/load/output_net_output_unittest.cc b/tests/ut/ge/graph/load/output_net_output_unittest.cc index ca0eb871..52fdebfa 100644 --- a/tests/ut/ge/graph/load/output_net_output_unittest.cc +++ b/tests/ut/ge/graph/load/output_net_output_unittest.cc @@ -131,6 +131,25 @@ TEST_F(UtestNetOutput, true_is_output) { delete model_utils; } +// test ModelUtils::IsInputTensorNeedTrans +TEST_F(UtestNetOutput, success_is_output_tensor_need_trans) { + ModelUtils *model_utils = new ModelUtils(); + std::shared_ptr op_desc = std::make_shared(); + OmeTestOpDescBuilder builder(op_desc); + builder.SetType("NetOutput"); + size_t tensor_index = 1; + vector outputs_desc; + std::shared_ptr desc = std::make_shared(); + outputs_desc.push_back(desc); + op_desc->outputs_desc_ = outputs_desc; + op_desc->inputs_desc_ = outputs_desc; + + bool ret = model_utils->IsInputTensorNeedTrans(op_desc, tensor_index); + EXPECT_EQ(false, ret); + + delete model_utils; +} + // test ModelUtils::GetOutputSize TEST_F(UtestNetOutput, success_get_output_size) { vector v_output_size; diff --git a/third_party/fwkacllib/inc/ops/all_ops.h b/third_party/fwkacllib/inc/ops/all_ops.h index 37315c74..36c991ff 100644 --- a/third_party/fwkacllib/inc/ops/all_ops.h +++ b/third_party/fwkacllib/inc/ops/all_ops.h @@ -35,6 +35,7 @@ #include "decode_wheels_target.h" #include "elewise_calculation_ops.h" #include "fastrcnn_predictions.h" +#include "fsrdetectionoutput_ops.h" #include "functional_ops.h" #include "get_data_ops.h" #include "hcom_ops.h" @@ -57,6 +58,7 @@ #include "outfeed_ops.h" #include "pad_ops.h" #include "parsing_ops.h" +#include "power_ops.h" #include "quantize_ops.h" #include "ragged_conversion_ops.h" #include "random_ops.h" diff --git a/third_party/fwkacllib/inc/ops/array_ops.h b/third_party/fwkacllib/inc/ops/array_ops.h index 7febad77..0d1126aa 100644 --- a/third_party/fwkacllib/inc/ops/array_ops.h +++ b/third_party/fwkacllib/inc/ops/array_ops.h @@ -595,9 +595,6 @@ REG_OP(ExpandDims) *@par Outputs: *y: A tensor. - -*@par Attention: -*This operator cannot be directly called by the acllopExecute API. */ REG_OP(Reshape) .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, DT_INT32, @@ -851,7 +848,6 @@ REG_OP(Copy) `farmhash::fingerprint64`. *@par Outputs: -y: A two-dimensional `Tensor` of type `uint8`. The first dimension equals to \n `data`'s first dimension, and the second dimension size depends on the \n fingerprint algorithm. diff --git a/third_party/fwkacllib/inc/ops/data_flow_ops.h b/third_party/fwkacllib/inc/ops/data_flow_ops.h index fee5e67d..08cbd1ff 100644 --- a/third_party/fwkacllib/inc/ops/data_flow_ops.h +++ b/third_party/fwkacllib/inc/ops/data_flow_ops.h @@ -259,7 +259,7 @@ match this name to the matching Unstage Op. REG_OP(Stage) .DYNAMIC_INPUT(values, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8, \ DT_INT16, DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_BOOL, \ - DT_DOUBLE, DT_UINT32, DT_UINT64})) + DT_DOUBLE})) .ATTR(capacity, Int, 0) .ATTR(memory_limit, Int, 0) .ATTR(container, String, "") @@ -312,7 +312,7 @@ REG_OP(StagePeek) .INPUT(index, TensorType({DT_INT32})) .DYNAMIC_OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8, DT_INT16, \ DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_BOOL, \ - DT_DOUBLE, DT_UINT32, DT_UINT64})) + DT_DOUBLE})) .ATTR(capacity, Int, 0) .ATTR(memory_limit, Int, 0) .ATTR(container, String, "") @@ -363,7 +363,7 @@ REG_OP(StackPop) .INPUT(handle, TensorType({DT_RESOURCE})) .OUTPUT(element, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8, DT_INT16, \ DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_BOOL, \ - DT_DOUBLE, DT_UINT32, DT_UINT64})) + DT_DOUBLE})) .REQUIRED_ATTR(elem_type, Type) .OP_END_FACTORY_REG(StackPop) @@ -388,10 +388,10 @@ REG_OP(StackPush) .INPUT(handle, TensorType({DT_RESOURCE})) .INPUT(element, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8, DT_INT16, \ DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_BOOL, \ - DT_DOUBLE, DT_UINT32, DT_UINT64})) + DT_DOUBLE})) .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8, DT_INT16, \ DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_BOOL, \ - DT_DOUBLE, DT_UINT32, DT_UINT64})) + DT_DOUBLE})) .ATTR(swap_memory, Bool, false) .OP_END_FACTORY_REG(StackPush) @@ -540,7 +540,6 @@ REG_OP(ParallelDynamicStitch) *@par Attributes:An optional int that is >= 0. Defaults to "0". *@li capacity: An optional int that is >= 0. Defaults to "0". *@li memory_limit: An optional int that is >= 0. Defaults to "0". -*@li dtypes: A list of DTypes. *@li container: An optional string. Defaults to "". *@li shared_name: An optional string. Defaults to "". @@ -563,7 +562,6 @@ REG_OP(MapClear) *@par Attributes: *@li capacity: An optional int that is >= 0. Defaults to "0". *@li memory_limit: An optional int that is >= 0. Defaults to "0". -*@li dtypes: A list of DTypes. *@li container: An optional string. Defaults to "". *@li shared_name: An optional string. Defaults to "". @@ -602,7 +600,7 @@ REG_OP(MapIncompleteSize) REG_OP(Unstage) .DYNAMIC_OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8, DT_INT16, \ DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_BOOL, \ - DT_DOUBLE, DT_UINT32, DT_UINT64})) + DT_DOUBLE})) .ATTR(capacity, Int, 0) .ATTR(memory_limit, Int, 0) .ATTR(container, String, "") @@ -630,7 +628,6 @@ DT_QINT8, DT_QUINT8, DT_QINT16, DT_QUINT16, DT_QINT32. Maximum number of elements in the Staging Area. If > 0, \n inserts on the container will block when the capacity is reached. *@li memory_limit: An optional int that is >= 0. Defaults to "0". -*@li dtypes: A list of DTypes. *@li container: An optional string. Defaults to "". \n If non-empty, this queue is placed in the given container. \n Otherwise, a default container is used. @@ -752,7 +749,6 @@ REG_OP(MapUnstageNoKey) *@par Attributes: *@li capacity: An optional int that is >= 0. Defaults to "0". *@li memory_limit: An optional int that is >= 0. Defaults to "0". -*@li dtypes: A list of DTypes that has length >= 1. *@li container: An optional string. Defaults to "". *@li shared_name: An optional string. Defaults to "". @@ -789,7 +785,6 @@ REG_OP(MapPeek) *@par Attributes: *@li capacity: An optional int that is >= 0. Defaults to "0". *@li memory_limit: An optional int that is >= 0. Defaults to "0". -*@li dtypes: A list of DTypes. *@li container: An optional string. Defaults to "". *@li shared_name: An optional string. Defaults to "". @@ -1183,7 +1178,6 @@ REG_OP(PaddingFIFOQueue) *@brief A queue that produces elements sorted by the first component value. *@par Attributes: -*@li component_types: An optional list of DTypes. Defaults to {}. \n The type of each component in a value. *@li shapes: A list of shapes for each component of a queue element. The length of this attr must be either 0 or the same as the length of \n @@ -1451,7 +1445,6 @@ REG_OP(OrderedMapUnstageNoKey) *@par Attributes: *@li capacity: An optional int that is >= 0. Defaults to "0". *@li memory_limit: An optional int that is >= 0. Defaults to "0". -*@li dtypes: A list of DTypes that has length >= 1. *@li container: An optional string. Defaults to "". *@li shared_name: An optional string. Defaults to "". @@ -1876,7 +1869,7 @@ REG_OP(SparseAccumulatorApplyGradient) .INPUT(local_step, TensorType({DT_INT64})) .INPUT(indices, TensorType({DT_INT64})) .INPUT(values, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_UINT16, \ - DT_INT32, DT_INT64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_UINT32, \ + DT_INT32, DT_INT64, DT_DOUBLE, DT_FLOAT,DT_FLOAT16, DT_UINT32, \ DT_UINT64, DT_COMPLEX64, DT_COMPLEX128, DT_QINT16, DT_QUINT16, \ DT_QINT8, DT_QUINT8, DT_QINT32})) .INPUT(shape, TensorType({DT_INT64})) diff --git a/third_party/fwkacllib/inc/ops/elewise_calculation_ops.h b/third_party/fwkacllib/inc/ops/elewise_calculation_ops.h index 3eff2cbe..d5272805 100644 --- a/third_party/fwkacllib/inc/ops/elewise_calculation_ops.h +++ b/third_party/fwkacllib/inc/ops/elewise_calculation_ops.h @@ -886,10 +886,7 @@ REG_OP(BesselI1e) * y: A Tensor of type UnaryDataType. * @attention Constraints: -* @li "base" is supposed to be greater than 0. Retaining the default \n -* value "-1" sets "base" to "e". -* @li If the input value of operator Log is within the range (0, 0.01] or \n -* [0.95, 1.05], the output accuracy is subject to change. +* @li base > 0 or if base is set to default (-1), base is set to e; */ REG_OP(Log) .INPUT(x, TensorType::UnaryDataType()) @@ -2059,7 +2056,6 @@ REG_OP(ArgMinWithValue) * "0": product, "1": sum, "2": max. *@li coeff: A required attribute. Must met all of following rules: * size of "coeff" must be equal to len("x") or is null. -* the absolute value of “coeff” must less than or equal to 1. */ REG_OP(Eltwise) .DYNAMIC_INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT})) diff --git a/third_party/fwkacllib/inc/ops/fsrdetectionoutput_ops.h b/third_party/fwkacllib/inc/ops/fsrdetectionoutput_ops.h new file mode 100644 index 00000000..2b3e206d --- /dev/null +++ b/third_party/fwkacllib/inc/ops/fsrdetectionoutput_ops.h @@ -0,0 +1,67 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_OP_FSRDETECTIONOUTPUT_OPS_H_ +#define GE_OP_FSRDETECTIONOUTPUT_OPS_H_ +#include "graph/operator_reg.h" + +namespace ge { +/** +*@brief Returns detection result. + +*@par Inputs: +* Four inputs, including: +*@li rois: An NCHW tensor of type floa16 or float32, output from operator proposal_d at the preceding layer, used as the input of operator FSRDetectionOutput. +*@li prior_box: An NCHWC0 tensor of type floa16 or float32, specifying the prediction offset, used to update the coordinates [x1, y1, x2, y2] of each ROI. +*@li score: An NCHWC0 tensor of type floa16 or float32, specifying the probability of each class. Class 0 is the background class. +*@li actual_rois_num: An NCHW tensor of type int32, specifying the number of valid boxes per batch. +*@par Attributes: +*@li batch_rois: An optional int32, specifying the number of images to be predicted. Defaults to "1024". The value range is [1, 1024]. +*@li im_info: An optional list of two ints. Defaults to (375, 1024). The value range is [1, 1024]. +*@li num_classes: An optional int32, specifying the number of classes to be predicted. Defaults to "80". The value must be greater than 0. +*@li max_rois_num: An optional int32, specifying the maximum number of ROIs per batch. Defaults to "1024". The value must be a multiple of 16. +*@li score_thresh: An optional float32, specifying the threshold for box filtering. Defaults to 0.45. The value range is [0.0, 1.0]. +*@li nms_thresh: An optional float32, specifying the confidence threshold for box filtering, which is the output "obj" of operator Region. Defaults to 0.7. The value range is (0.0, 1.0). +*@li bbox_reg_weights: An optional list of four ints. Defaults to (1, 1, 1, 1). Must not have value "0". +*@li post_nms_topn: An optional int, specifying the number of output boxes. Defaults to "304". The value must be less than or equal to 1024 and must be a multiple of 16. +*@li kernel_name: An optional string, specifying the operator name. Defaults to "fsr_detection_output". +*@par Outputs: +*box: An NCHW tensor of type float16, describing the information of each output box, including the coordinates, class, and confidence. +*actual_bbox_num: An NCHW tensor of type int32, specifying the number of output boxes. + +*@attention Constraints:\n +*@li totalnum < max_rois_num * batch_rois. +*@li "score" must be with shape (total_num, (num_classes+15)//16, 1, 1, 16), where "total_num" indicates the number of valid input boxes of all images. +*@li "prior_box" must be with shape (total_num, (num_classes*4+15)//16, 1, 1, 16), where "total_num" indicates the number of valid input boxes of all images. +*/ +REG_OP(FSRDetectionOutput) + .INPUT(rois, TensorType({DT_FLOAT, DT_FLOAT16})) + .INPUT(prior_box, TensorType({DT_FLOAT, DT_FLOAT16})) + .INPUT(score, TensorType({DT_FLOAT, DT_FLOAT16})) + .INPUT(actual_rois_num, TensorType({DT_INT32})) + .OUTPUT(actual_bbox_num, TensorType({DT_INT32})) + .OUTPUT(box, TensorType({DT_FLOAT, DT_FLOAT16})) + .ATTR(batch_rois, Int, 1024) + .ATTR(im_info, ListInt, {375,1024}) + .ATTR(num_classes, Int, 80) + .ATTR(max_rois_num, Int, 1024) + .ATTR(score_thresh, Float, 0.45) + .ATTR(nms_thresh, Float, 0.7) + .ATTR(bbox_reg_weights, ListInt, {1,1,1,1}) + .ATTR(post_nms_topn, Int, 304) + .OP_END_FACTORY_REG(FSRDetectionOutput) +} +#endif diff --git a/third_party/fwkacllib/inc/ops/image_ops.h b/third_party/fwkacllib/inc/ops/image_ops.h index aaad03c6..2ac7a70e 100644 --- a/third_party/fwkacllib/inc/ops/image_ops.h +++ b/third_party/fwkacllib/inc/ops/image_ops.h @@ -525,7 +525,8 @@ REG_OP(ResizeBilinearV2) .INPUT(x, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_UINT16, DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) .INPUT(size, TensorType({DT_INT32})) - .OUTPUT(y, TensorType({DT_FLOAT})) + .OUTPUT(y, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_UINT16, + DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) .ATTR(align_corners, Bool, false) .ATTR(half_pixel_centers, Bool, false) .OP_END_FACTORY_REG(ResizeBilinearV2) @@ -924,7 +925,7 @@ images[3] <= 2048. */ REG_OP(ResizeBilinearV2D) .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT})) - .OUTPUT(y, TensorType({DT_FLOAT})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT})) .ATTR(align_corners, Bool, false) .ATTR(half_pixel_centers, Bool, false) .REQUIRED_ATTR(size, ListInt) diff --git a/third_party/fwkacllib/inc/ops/math_ops.h b/third_party/fwkacllib/inc/ops/math_ops.h index cc97a337..aa318c94 100644 --- a/third_party/fwkacllib/inc/ops/math_ops.h +++ b/third_party/fwkacllib/inc/ops/math_ops.h @@ -22,29 +22,6 @@ namespace ge { -/** -*@brief Computes the output as (shift + scale * x) ^ power. - -*@par Inputs: -* x: A Tensor of type float16 or float32. - -*@par Attributes: -*@li power: Optional. Defaults to 1.0. -*@li scale: Optional. Defaults to 1.0. -*@li shift: Optional. Defaults to 0.0. - -*@par Outputs: -* y: A Tensor. Has the same type and shape as "x". -*/ - -REG_OP(Power) - .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT})) - .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT})) - .ATTR(power, Float, 1.0) - .ATTR(scale, Float, 1.0) - .ATTR(shift, Float, 0.0) - .OP_END_FACTORY_REG(Power); - /** *@brief Compute the lower regularized incomplete Gamma function P(a, x). diff --git a/third_party/fwkacllib/inc/ops/matrix_calculation_ops.h b/third_party/fwkacllib/inc/ops/matrix_calculation_ops.h index dd2ce56c..597a8982 100644 --- a/third_party/fwkacllib/inc/ops/matrix_calculation_ops.h +++ b/third_party/fwkacllib/inc/ops/matrix_calculation_ops.h @@ -492,7 +492,7 @@ REG_OP(DiagPart) *@brief Also known as a "fully-connected" layer, computes an inner product with a set of learned weights, and (optionally) adds biases. *@par Inputs: -* Four inputs, including: +* Two inputs, including: *@li x: A Tensor of type float16, int8. *@li w: A weight matrix of type float16, int8. *@li b: A Tensor of type float16, int32. @@ -501,13 +501,14 @@ REG_OP(DiagPart) *@par Attributes: *@li num_output: Reserved. *@li transpose: A bool, specifying whether to transpose, either "true" or "false". Defaults to "false". -*@li axis: Reserved. -*@li offset_x: Reserved. +*@li bias_term: A bool, specifying whether to learn and apply a set of additive biases to the filter outputs, either "true" or "false". Defaults to "true". +*@li axis: only support axis is 1. Defaults to "1". +*@li offset_a: A type of Int, Defaults to "1". *@par Outputs: *y: The result tensor of type float16, int8. */ -REG_OP(FullyConnection) +REG_OP(InnerProduct) .INPUT(x, TensorType({DT_FLOAT16, DT_INT8})) .INPUT(w, TensorType({DT_FLOAT16, DT_INT8})) .OPTIONAL_INPUT(b, TensorType({DT_FLOAT16, DT_INT32})) @@ -515,9 +516,10 @@ REG_OP(FullyConnection) .OUTPUT(y, TensorType({DT_FLOAT16, DT_INT32})) .REQUIRED_ATTR(num_output, Int) .ATTR(transpose, Bool, false) + .ATTR(bias_term, Bool, true) .ATTR(axis, Int, 1) - .ATTR(offset_x, Int, 0) - .OP_END_FACTORY_REG(FullyConnection) + .ATTR(offset_a, Int, 0) + .OP_END_FACTORY_REG(InnerProduct) /** *@brief Computes the confusion matrix from predictions and labels. diff --git a/third_party/fwkacllib/inc/ops/nn_calculation_ops.h b/third_party/fwkacllib/inc/ops/nn_calculation_ops.h index bc492e1b..1be85a0e 100644 --- a/third_party/fwkacllib/inc/ops/nn_calculation_ops.h +++ b/third_party/fwkacllib/inc/ops/nn_calculation_ops.h @@ -62,7 +62,7 @@ namespace ge { * data is 5D with shape [N, C1, Ho, Wo, C0], * where C is the same as that of the feature map and C0 is 16.\n * Limited by Tiling and L1 / L0 buffer memory: 512 * ceil(Wo, 16) + (480 * -* stride_h + 32 * filter_h) * ceil(Wi, 16) ≤ l1_size and Hf*Wf ≤ l0b_size/512.\n +* stride_h + 32 * filter_h) * ceil(Wi, 16) ?l1_size and Hf*Wf ?l0b_size/512.\n */ REG_OP(DepthwiseConv2DBackpropFilter) .INPUT(input, TensorType({float16})) @@ -115,7 +115,7 @@ REG_OP(DepthwiseConv2DBackpropFilter) * data is 5D with shape [N, C1, Ho, Wo, C0], * where C is the same as that of the feature map and C0 is 16.\n * Limited by Tiling and L1 / L0 buffer memory: 512 * ceil(Wo, 16) + (480 * -* stride_h + 32 * filter_h) * ceil(Wi, 16) ≤ l1_size and Hf*Wf ≤ l0b_size/512.\n +* stride_h + 32 * filter_h) * ceil(Wi, 16) ?l1_size and Hf*Wf ?l0b_size/512.\n */ REG_OP(DepthwiseConv2DBackpropFilterD) .INPUT(input, TensorType({float16})) @@ -170,7 +170,7 @@ REG_OP(DepthwiseConv2DBackpropFilterD) * Output backprop is 4D with shape [N, C, Ho, Wo] or [N, Ho, Wo, C], but the * data is 5D with shape [N, C1, Ho, Wo, C0], * where C is the same as that of the feature map and C0 is 16.\n -* Limited by Tiling: max_h_in_l1 ≥ C0, where max_h_in_l1 = (l1_size - Hf * +* Limited by Tiling: max_h_in_l1 ?C0, where max_h_in_l1 = (l1_size - Hf * * Wf * C0 * C0 * 2) / (2 * Wo *C0).\n */ REG_OP(DepthwiseConv2DBackpropInput) @@ -223,7 +223,7 @@ REG_OP(DepthwiseConv2DBackpropInput) * Output backprop is 4D with shape [N, C, Ho, Wo] or [N, Ho, Wo, C], but the * data is 5D with shape [N, C1, Ho, Wo, C0], * where C is the same as that of the feature map and C0 is 16.\n -* Limited by Tiling: max_h_in_l1 ≥ C0, where max_h_in_l1 = (l1_size - Hf * +* Limited by Tiling: max_h_in_l1 ?C0, where max_h_in_l1 = (l1_size - Hf * * Wf * C0 * C0 * 2) / (2 * Wo *C0).\n */ REG_OP(DepthwiseConv2DBackpropInputD) @@ -439,17 +439,13 @@ REG_OP(Conv2DBackpropInputD) * One optional input: * @li bias: An optional tensor of type int8 *@par Attributes: - * Five attributes: + * Three attributes: * @li strides: A tuple or list of 2 integers. The stride of the sliding window * for H/W dimension. * @li pads: A tuple or list of 4 integers. The [top, bottom, left, right] * padding on the feature map * @li dilations: A tuple or list of 4 integers. The dilation factor for each * dimension of input. Must be [1, 1, 1, 1]. - * @li groups: Number of blocked connections from input channels to \n - output channels. - * @li data_format: An optional string from: "NHWC", "NCHW". Defaults to "NHWC".\n - Specify the data format of the input and output data. *@par Outputs: * y: A Tensor. Has the same type as "filter". 4D tensor with shape * [batch, height, width, channels] or [batch, channels, height, width]. @@ -462,8 +458,6 @@ REG_OP(Deconvolution) .ATTR(strides, ListInt, {1, 1, 1, 1}) .ATTR(pads, ListInt, {0, 0, 0, 0}) .ATTR(dilations, ListInt, {1, 1, 1, 1}) - .ATTR(groups, Int, 1) - .ATTR(data_format, String, "NHWC") .OP_END_FACTORY_REG(Deconvolution) /** *@brief Computes the gradients of convolution with respect to the filter @@ -637,6 +631,7 @@ REG_OP(Conv2D) *@par Attributes: *@li strides: A list of 5 ints. Specifies the stride of the sliding window for each dimension of "x". The N and C dimensions must be 1. Has the same format as "x". *@li pads: A list of 6 ints. Supports only padding along the D, H and W dimensions in sequence of head, tail, top, bottom, left and right. +*@li padding_mode: An optional string from: "zeros", "circular". Defaults to "zeros". *@li data_format: An optional string from: "NDHWC", "NCDHW". Defaults to "NDHWC". Specify the data format of the input and output data. *@li dilations: A list of 5 ints. Specifies the dilation factor for each dimension of "x". The N and C dimensions must be 1. Has the same format as "x". @@ -654,6 +649,7 @@ REG_OP(Conv3D) .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) .ATTR(strides, ListInt, {1, 1, 1, 1, 1}) .ATTR(pads, ListInt, {0, 0, 0, 0, 0, 0}) + .ATTR(padding_mode, String, "zeros") .ATTR(data_format, String, "NDHWC") .ATTR(dilations, ListInt, {1, 1, 1, 1, 1}) .OP_END_FACTORY_REG(Conv3D) @@ -675,7 +671,7 @@ REG_OP(Conv3D) * @li data_format: An optional string from: "NDHWC", "NCHWD". Defaults to "NDHWC". Specify the data format of the input and output data. *@par Outputs: * y: A Tensor. Has the same type as filter,and has same format as input_size -*/ +*/ REG_OP(Conv3DBackpropInput) .INPUT(input_sizes, TensorType({DT_INT32, DT_INT64})) .INPUT(filters, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) @@ -702,7 +698,7 @@ REG_OP(Conv3DBackpropInput) * @li data_format: An optional string from: "NDHWC", "NCHWD". Defaults to "NDHWC". Specify the data format of the input and output data. *@par Outputs: * y: A Tensor. Has the same type as filter -*/ +*/ REG_OP(Conv3DBackpropInputD) .INPUT(filters, TensorType({DT_FLOAT16})) .INPUT(grads, TensorType({DT_FLOAT16})) diff --git a/third_party/fwkacllib/inc/ops/nn_detect_ops.h b/third_party/fwkacllib/inc/ops/nn_detect_ops.h index f1d6e420..1d8f0ae5 100644 --- a/third_party/fwkacllib/inc/ops/nn_detect_ops.h +++ b/third_party/fwkacllib/inc/ops/nn_detect_ops.h @@ -310,358 +310,6 @@ REG_OP(PSROIPooling) .ATTR(spatial_scale, Float, 0.0625) .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16})) .OP_END_FACTORY_REG(PSROIPooling) - -/** -*@brief Returns detection result. - -*@par Inputs: -* Four inputs, including: -*@li rois: An NCHW tensor of type floa16 or float32, output from operator proposal_d at the preceding layer, used as the input of operator FSRDetectionOutput. -*@li prior_box: An NCHWC0 tensor of type floa16 or float32, specifying the prediction offset, used to update the coordinates [x1, y1, x2, y2] of each ROI. -*@li score: An NCHWC0 tensor of type floa16 or float32, specifying the probability of each class. Class 0 is the background class. -*@li actual_rois_num: An NCHW tensor of type int32, specifying the number of valid boxes per batch. -*@par Attributes: -*@li batch_rois: An optional int32, specifying the number of images to be predicted. Defaults to "1024". The value range is [1, 1024]. -*@li im_info: An optional list of two ints. Defaults to (375, 1024). The value range is [1, 1024]. -*@li num_classes: An optional int32, specifying the number of classes to be predicted. Defaults to "80". The value must be greater than 0. -*@li max_rois_num: An optional int32, specifying the maximum number of ROIs per batch. Defaults to "1024". The value must be a multiple of 16. -*@li score_thresh: An optional float32, specifying the threshold for box filtering. Defaults to 0.45. The value range is [0.0, 1.0]. -*@li nms_thresh: An optional float32, specifying the confidence threshold for box filtering, which is the output "obj" of operator Region. Defaults to 0.7. The value range is (0.0, 1.0). -*@li bbox_reg_weights: An optional list of four ints. Defaults to (1, 1, 1, 1). Must not have value "0". -*@li post_nms_topn: An optional int, specifying the number of output boxes. Defaults to "304". The value must be less than or equal to 1024 and must be a multiple of 16. -*@li kernel_name: An optional string, specifying the operator name. Defaults to "fsr_detection_output". -*@par Outputs: -*box: An NCHW tensor of type float16, describing the information of each output box, including the coordinates, class, and confidence. -*actual_bbox_num: An NCHW tensor of type int32, specifying the number of output boxes. - -*@attention Constraints:\n -*@li totalnum < max_rois_num * batch_rois. -*@li "score" must be with shape (total_num, (num_classes+15)//16, 1, 1, 16), where "total_num" indicates the number of valid input boxes of all images. -*@li "prior_box" must be with shape (total_num, (num_classes*4+15)//16, 1, 1, 16), where "total_num" indicates the number of valid input boxes of all images. -*/ -REG_OP(FSRDetectionOutput) - .INPUT(rois, TensorType({DT_FLOAT, DT_FLOAT16})) - .INPUT(prior_box, TensorType({DT_FLOAT, DT_FLOAT16})) - .INPUT(score, TensorType({DT_FLOAT, DT_FLOAT16})) - .INPUT(actual_rois_num, TensorType({DT_INT32})) - .OUTPUT(actual_bbox_num, TensorType({DT_INT32})) - .OUTPUT(box, TensorType({DT_FLOAT, DT_FLOAT16})) - .ATTR(batch_rois, Int, 1024) - .ATTR(im_info, ListInt, {375,1024}) - .ATTR(num_classes, Int, 80) - .ATTR(max_rois_num, Int, 1024) - .ATTR(score_thresh, Float, 0.45) - .ATTR(nms_thresh, Float, 0.7) - .ATTR(bbox_reg_weights, ListInt, {1,1,1,1}) - .ATTR(post_nms_topn, Int, 304) - .OP_END_FACTORY_REG(FSRDetectionOutput) - -/** -*@brief Normalizes data. It is called Region on YOLO v2 and Yolo on YOLO v3. - -*@par Inputs: -*x: An NCHW tensor of type float16 or float32. The data is with shape (N, boxes*(coords+obj+classes), H, W),where, "obj" indicates the confidence of an object, and only one confidence is supported. Boxes are arranged as xx...xyy...yww...whh...hbb...bc0c0..c0c1c1...c1......cncn...cn. - -*@par Attributes: -*@li boxes: A required int32, specifying the number of anchor boxes. Defaults to "5" for V2 or "3" for V3. -*@li coords: An int32, specifying the number of parameters required for locating an object. The value is fixed at "4", corresponding to (x,y,w,h). -*@li classes: An int32, specifying the number of prediction classes. Defaults to "80". The value range is [1, 1024]. -*@li yolo_version: A string, specifying the YOLO version, either "V2" or "V3". -*@li softmax: A bool, specifying whether to perform softmax, valid only when "yolo_version = V2". -*@li background: A bool, specifying the operation types of the obj and classes, used in conjunction with "softmax" and valid only when "yolo_version = V2". -*@li background: A bool. - -*@par Outputs: -*@li coord_data: A float16 or float32 with shape [N, boxes*coords, ceilx(height*width*2+32, 32)/2], where "ceil" indicates that a detected box is aligned upwards with the second parameter. Specifies the coordinates of a detected box. -*@li obj_prob: A float16 or float32 with shape [N, ceilx(boxes*height*width *2+32, 32)/2], where "ceil" indicates that a detected box is aligned upwards with the second parameter. Specifies the confidence. -*@li classes_prob: A float16 or float32 with shape [N, classes, ceilx(boxes*height*width *2+32, 32)/2], where "ceil" indicates that a detected box is aligned upwards with the second parameter. Specifies the prediction classes. - -*@attention Constraints: -*@li This operator applies to YOLO v2 and v3 networks. -*@li The succeeding layer of the Yolo operator must be operator Yolov3DetectionOutput. -*/ -REG_OP(Yolo) - .INPUT(x, TensorType({DT_FLOAT16,DT_FLOAT})) - .OUTPUT(coord_data, TensorType({DT_FLOAT16,DT_FLOAT})) - .OUTPUT(obj_prob, TensorType({DT_FLOAT16,DT_FLOAT})) - .OUTPUT(classes_prob, TensorType({DT_FLOAT16,DT_FLOAT})) - .ATTR(boxes, Int, 3) - .ATTR(coords, Int, 4) - .ATTR(classes, Int, 80) - .ATTR(yolo_version, String, "V3") - .ATTR(softmax, Bool, false) - .ATTR(background, Bool, false) - .ATTR(softmaxtree, Bool, false) - .OP_END_FACTORY_REG(Yolo) - -/** -*@brief Performs YOLO V2 detection. - -*@par Inputs: -* Four inputs, including: -*@li The outputs of operator Yolo at the preceding layer (that is, one Yolo operator on YOLO v2) are used as the inputs of operator Yolov3DetectionOutput. \n -Each Yolo operator has three outputs: "coords", "obj", and "class". For details, see the description of operator Yolo. -*@li imginfo: A float16, describing the image information including the required image height and width \n -and the actual image height and width. -* -*@par Attributes: -*@li biases: A required float. "biases = Number of Yolo operators at the preceding layer x 2 x boxes" -*@li boxes: A required int32, specifying the number of anchor boxes predicted for each Yolo layer. -*@li coords: Specifies the number of coordinate parameters. Must be 4. -*@li classes: A required int32, specifying the number of classes to be predicted. The value range is [1, 80]. -*@li relative: An optional bool. Defaults to and must be "true". -*@li obj_threshold: A required float, specifying the confidence threshold for box filtering, which is the output "obj" of operator Yolo). The value range is [0.0, 1.0]. - -*@li post_nms_topn: An optional int32. This attribute is reserved. -*@li score_threshold: A required float, specifying the class score threshold for box filtering, which is the output "class" of operator Yolo). The value range is [0.0, 1.0]. -*@li iou_threshold: A required float, specifying the intersection-over-union (IOU) threshold for box filtering. The value range is [0.0, 1.0].\n -*@li pre_nms_topn: An optional int, specifying the number of boxes for non-maximum suppression (NMS). Defaults to "1024". -* -*@par Outputs: -*@li boxout: An NCHW tensor of type float16, describing the information of each output box, including the coordinates, class, and confidence. -*@li boxoutnum: An NCHW tensor of type int32, specifying the number of output boxes. - -*@attention Constraints:\n -*@li This operator applies only to the YOLO v2 network. -*@li The preceding layer of operator Yolov2DetectionOutput must be one Yolo operator. - -*@see Yolo() -*/ -REG_OP(YoloV2DetectionOutput) - .INPUT(coord_data, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(obj_prob, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(classes_prob, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(img_info, TensorType({DT_FLOAT16,DT_FLOAT})) - .REQUIRED_ATTR(biases, ListFloat) - .ATTR(boxes, Int, 5) - .ATTR(coords, Int, 4) - .ATTR(classes, Int, 80) - .ATTR(relative, Bool, true) - .ATTR(obj_threshold, Float, 0.5) - .ATTR(post_nms_topn, Int, 1024) - .ATTR(score_threshold, Float, 0.5) - .ATTR(iou_threshold, Float, 0.45) - .ATTR(pre_nms_topn, Int, 512) - .OUTPUT(box_out, TensorType({DT_FLOAT16,DT_FLOAT})) - .OUTPUT(box_out_num, TensorType({DT_INT32})) - .OP_END_FACTORY_REG(YoloV2DetectionOutput) - -/** -*@brief Performs YOLO V2 detection. - -*@par Inputs: -*Six inputs, including: -*@li The outputs of operator Yolo at the preceding layer (that is, one Yolo operator on YOLO v2) are used as the inputs of operator Yolov2DetectionOutput. \n -Each Yolo operator has three outputs: "coords", "obj", and "class". For details, see the description of operator Yolo. -*@li imginfo: A float16, describing the image information including the required image height and width \n -and the actual image height and width. -*@li windex: A windex tensor with shape [height, weight]. Has the same type as the inputs. [[0,1,2...(weight-1)],[0,1,2...(w-1)]...[0,1,2...(weight-1)]] consisting of h groups of [0, 1, 2...(weight-1)] is formed. \n - -*@li hindex: A hindex tensor with shape [height, weight]. Has the same type as the inputs. [[0,0...0],[1,1...1],[2,2...2]...[height-1,height-1...,height-1]]. \n - -* -*@par Attributes: -*@li biases: A required float. "biases = Number of Yolo operators at the preceding layer x 2 x boxes" -*@li boxes: A required int32, specifying the number of anchor boxes predicted for each Yolo layer. -*@li coords: Specifies the number of coordinate parameters. Must be 4. -*@li classes: A required int32, specifying the number of classes to be predicted. The value range is [1, 80]. -*@li relative: An optional bool. Defaults to and must be "true". -*@li obj_threshold: A required float, specifying the confidence threshold for box filtering, which is the output "obj" of operator Yolo). The value range is [0.0, 1.0]. -*@li post_nms_topn: An optional int32. This attribute is reserved. -*@li score_threshold: A required float, specifying the class score threshold for box filtering, which is the output "class" of operator Yolo). The value range is [0.0, 1.0]. - -*@li iou_threshold: A required float, specifying the intersection-over-union (IOU) threshold for box filtering. The value range is [0.0, 1.0].\n -*@li pre_nms_topn: An optional int, specifying the number of boxes for non-maximum suppression (NMS). Defaults to "1024". -* -*@par Outputs: -*@li boxout: An NCHW tensor of type float16, describing the information of each output box, including the coordinates, class, and confidence. -*@li boxoutnum: An NCHW tensor of type int32, specifying the number of output boxes. -* -*@attention Constraints:\n -*@li This operator applies only to the YOLO v2 network. -*@li The preceding layer of operator Yolov2DetectionOutput must be one Yolo operator. - -*@see Yolo() -*/ -REG_OP(YoloV2DetectionOutputD) - .INPUT(coord_data, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(obj_prob, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(classes_prob, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(img_info, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(windex, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(hindex, TensorType({DT_FLOAT16,DT_FLOAT})) - .REQUIRED_ATTR(biases, ListFloat) - .ATTR(boxes, Int, 5) - .ATTR(coords, Int, 4) - .ATTR(classes, Int, 80) - .ATTR(relative, Bool, true) - .ATTR(obj_threshold, Float, 0.5) - .ATTR(post_nms_topn, Int, 1024) - .ATTR(score_threshold, Float, 0.5) - .ATTR(iou_threshold, Float, 0.45) - .ATTR(pre_nms_topn, Int, 512) - .OUTPUT(box_out, TensorType({DT_FLOAT16,DT_FLOAT})) - .OUTPUT(box_out_num, TensorType({DT_INT32})) - .OP_END_FACTORY_REG(YoloV2DetectionOutputD) - -/** -*@brief Performs YOLO V3 detection. - -*@par Inputs: -*Ten inputs, including: -*@li Operator Yolov3DetectionOutput takes the outputs of operator Yolo as its inputs. A Yolo operator has three outputs: "coords", "obj", and "class". \n -There are three Yolo operators at Yolov3DetectionOutput's preceding layer on Yolo v3. For details, see the description of operator Yolo. -*@li imginfo: A float16, describing the image information including the required image height and width \n -and the actual image height and width. -* -*@par Attributes: -*@li biases: A required float. "biases = Number of Yolo operators at the preceding layer x 2 x boxes" -*@li boxes: A required int32, specifying the number of anchor boxes predicted for each Yolo layer. -*@li coords: Specifies the number of coordinate parameters. Must be 4. -*@li classes: A required int32, specifying the number of classes to be predicted. The value range is [1, 80]. -*@li relative: An optional bool. Defaults to and must be "true". -*@li obj_threshold: A required float, specifying the confidence threshold for box filtering, which is the output "obj" of operator Yolo). The value range is [0.0, 1.0]. - -*@li post_nms_topn: An optional int32. This attribute is reserved. -*@li score_threshold: A required float, specifying the class score threshold for box filtering, which is the output "class" of operator Yolo). The value range is [0.0, 1.0]. - -*@li iou_threshold: A required float, specifying the intersection-over-union (IOU) threshold for box filtering. The value range is [0.0, 1.0].\n - -*@li pre_nms_topn: An optional int, specifying the number of boxes for non-maximum suppression (NMS). Defaults to "1024". -* -*@par Outputs: -*@li boxout: An NCHW tensor of type float16, describing the information of each output box, including the coordinates, class, and confidence. -*@li boxoutnum: An NCHW tensor of type int32, specifying the number of output boxes. - -*@attention Constraints:\n -*@li This operator applies only to the YOLO v3 network. -*@li The preceding layer of operator Yolov3DetectionOutput must be three Yolo operators. - -*@see Yolo() -*/ -REG_OP(YoloV3DetectionOutput) - .INPUT(coord_data_low, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(coord_data_mid, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(coord_data_high, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(obj_prob_low, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(obj_prob_mid, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(obj_prob_high, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(classes_prob_low, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(classes_prob_mid, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(classes_prob_high, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(img_info, TensorType({DT_FLOAT16,DT_FLOAT})) - .REQUIRED_ATTR(biases_low, ListFloat) - .REQUIRED_ATTR(biases_mid, ListFloat) - .REQUIRED_ATTR(biases_high, ListFloat) - .ATTR(boxes, Int, 3) - .ATTR(coords, Int, 4) - .ATTR(classes, Int, 80) - .ATTR(relative, Bool, true) - .ATTR(obj_threshold, Float, 0.5) - .ATTR(post_nms_topn, Int, 1024) - .ATTR(score_threshold, Float, 0.5) - .ATTR(iou_threshold, Float, 0.45) - .ATTR(pre_nms_topn, Int, 512) - .OUTPUT(box_out, TensorType({DT_FLOAT16,DT_FLOAT})) - .OUTPUT(box_out_num, TensorType({DT_INT32})) - .OP_END_FACTORY_REG(YoloV3DetectionOutput) - -/** -*@brief Performs YOLO V3 detection. - -*@par Inputs: -*16 Input, including: -*@li The outputs of operator Yolo at the preceding layer (that is, three Yolo operators on YOLO v3) are used as the inputs of operator Yolov3DetectionOutput. \n -A Yolo operator has three outputs: "coords", "obj", and "class". For details, see the description of operator Yolo. -*@li imginfo: A float16, describing the image information including the required image height and width \n -and the actual image height and width. -*@li windex: A windex tensor with shape [height,weight]. Has the same type as the inputs. [[0,1,2...(weight-1)],[0,1,2...(w-1)]...[0,1,2...(weight-1)]] consisting of h groups of [0, 1, 2...(weight-1)] is formed for the three Yolo outputs, respectively. - -*@li hindex: A hindex tensor with shape [height,weight]. Has the same type as the inputs. [[0,0...0],[1,1...1],[2,2...2]...[height-1,height-1...,height-1]] is formed for the three Yolo outputs, respectively. - -* -*@par Attributes: -*@li biases: A required float32. "biases = Number of Yolo operators at the preceding layer x 2 x boxes" -*@li boxes: A required int32, specifying the number of anchor boxes predicted for each Yolo layer. -*@li coords: Specifies the number of coordinate parameters. Must be 4. -*@li classes: A required int32, specifying the number of classes to be predicted. The value range is [1, 80]. -*@li relative: An optional bool. Defaults to and must be "true". -*@li obj_threshold: A required float, specifying the confidence threshold for box filtering, which is the output "obj" of operator Yolo). The value range is [0.0, 1.0]. -*@li post_nms_topn: An optional int32. This attribute is reserved. -*@li score_threshold: A required float, specifying the class score threshold for box filtering, which is the output "class" of operator Yolo). The value range is [0.0, 1.0]. -*@li iou_threshold: A required float, specifying the intersection-over-union (IOU) threshold for box filtering. The value range is [0.0, 1.0].\n -*@li pre_nms_topn: An optional int, specifying the number of boxes for non-maximum suppression (NMS). Defaults to "1024". -* -*@par Outputs: -*@li boxout: An NCHW tensor of type float16, describing the information of each output box, including the coordinates, class, and confidence. -*@li boxoutnum: An NCHW tensor of type int32, specifying the number of output boxes. - -*@attention Constraints:\n -*@li This operator applies only to the YOLO v3 network. -*@li The preceding layer of operator Yolov3DetectionOutput must be three Yolo operators. -*@see Yolo() -*/ -REG_OP(YoloV3DetectionOutputD) - .INPUT(coord_data_low, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(coord_data_mid, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(coord_data_high, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(obj_prob_low, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(obj_prob_mid, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(obj_prob_high, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(classes_prob_low, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(classes_prob_mid, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(classes_prob_high, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(img_info, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(windex1, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(windex2, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(windex3, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(hindex1, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(hindex2, TensorType({DT_FLOAT16,DT_FLOAT})) - .INPUT(hindex3, TensorType({DT_FLOAT16,DT_FLOAT})) - .REQUIRED_ATTR(biases_low, ListFloat) - .REQUIRED_ATTR(biases_mid, ListFloat) - .REQUIRED_ATTR(biases_high, ListFloat) - .ATTR(boxes, Int, 3) - .ATTR(coords, Int, 4) - .ATTR(classes, Int, 80) - .ATTR(relative, Bool, true) - .ATTR(obj_threshold, Float, 0.5) - .ATTR(post_nms_topn, Int, 1024) - .ATTR(score_threshold, Float, 0.5) - .ATTR(iou_threshold, Float, 0.45) - .ATTR(pre_nms_topn, Int, 512) - .OUTPUT(box_out, TensorType({DT_FLOAT16,DT_FLOAT})) - .OUTPUT(box_out_num, TensorType({DT_INT32})) - .OP_END_FACTORY_REG(YoloV3DetectionOutputD) - -/** -*@brief Spatial Pyramid Pooling, multi-level pooling. -* Pooling out(n, sigma(c*2^i*2^i)) tensor, i in range[0,pyramid_height). - -*@par Inputs: -*x: An NCHW tensor, support float16 or float32 type. - -*@par Attributes: -* @li pyramid_height: An required int32. -* Multi-level pooling out from 2^0 to 2^(pyramid_height-1). -* @li pool_method: An optional int32, pooling method: 0-MAX, 1-AVE. -* Defaults to "0". - -*@par Outputs: -*y: A NCHW tensor, support float16 or float32 type. - -*@attention Constraints: -* @li pyramid_height: pyramid_heigjt should be in range [0,7). -* @li feature_size:input feture map h and w should be [1, 510]. - -*/ -REG_OP(SPP) - .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16})) - .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16})) - .REQUIRED_ATTR(pyramid_height, Int) - .ATTR(pool_method, Int, 0) - .OP_END_FACTORY_REG(SPP) - } // namespace ge #endif // GE_OP_NN_DETECT_OPS_H_ diff --git a/third_party/fwkacllib/inc/ops/nn_pooling_ops.h b/third_party/fwkacllib/inc/ops/nn_pooling_ops.h index 87cc004c..10f3f369 100644 --- a/third_party/fwkacllib/inc/ops/nn_pooling_ops.h +++ b/third_party/fwkacllib/inc/ops/nn_pooling_ops.h @@ -487,6 +487,34 @@ REG_OP(Upsample) .ATTR(stride_h, Int, 2) .ATTR(stride_w, Int, 2) .OP_END_FACTORY_REG(Upsample) + +/** +*@brief Spatial Pyramid Pooling, multi-level pooling. +* Pooling out(n, sigma(c*2^i*2^i)) tensor, i in range[0,pyramid_height). + +*@par Inputs: +*x: An NCHW tensor, support float16 or float32 type. + +*@par Attributes: +* @li pyramid_height: An required int32. +* Multi-level pooling out from 2^0 to 2^(pyramid_height-1). +* @li pool_method: An optional int32, pooling method: 0-MAX, 1-AVE. +* Defaults to "0". + +*@par Outputs: +*y: A NCHW tensor, support float16 or float32 type. + +*@attention Constraints: +* @li pyramid_height: pyramid_heigjt should be in range [0,7). +* @li feature_size:input feture map h and w should be [1, 510]. + +*/ +REG_OP(SPP) + .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16})) + .REQUIRED_ATTR(pyramid_height, Int) + .ATTR(pool_method, Int, 0) + .OP_END_FACTORY_REG(SPP) } // namespace ge #endif // GE_OP_NN_POOLING_OPS_H diff --git a/third_party/fwkacllib/inc/ops/nn_training_ops.h b/third_party/fwkacllib/inc/ops/nn_training_ops.h index 88d1a913..d800d075 100644 --- a/third_party/fwkacllib/inc/ops/nn_training_ops.h +++ b/third_party/fwkacllib/inc/ops/nn_training_ops.h @@ -67,57 +67,6 @@ REG_OP(ApplyAdaMax) .ATTR(use_locking, Bool, false) .OP_END_FACTORY_REG(ApplyAdaMax) -/** -*@brief Updates "var" according to the AdaMax algorithm.\n -* t-1 mean previous period. -* m_t <- beta1 * m{t-1} + (1 - beta1) * grad\n -* v_t <- max(beta2 * v{t-1}, abs(grad))\n -* var <- var - lr / (1 - beta1^t) * m_t / (v_t + epsilon) -* -*@attention Constraints:\n -* the input tensors must have the same shape. -* -*@par Inputs: -*@li var: A mutable tensor. Must be one of the following types: TensorType::NumberType(). -* Should be from a Variable(). -*@li m: A mutable tensor. Has the same type as "var". -* Should be from a Variable(). -*@li v: A mutable tensor. Has the same type as "var". -* Should be from a Variable(). -*@li beta1_power: A scalar. Has the same type as "var". -*@li lr: learning_rate. A scalar. Has the same type as "var". -*@li beta1: A scalar. Has the same type as "var". -*@li beta2: A scalar. Has the same type as "var". -*@li epsilon: A scalar. Has the same type as "var". -*@li grad: A tensor for the gradient. Has the same type as "var". -* -*@par Attributes:\n -* use_locking: An optional bool. Defaults to "False". -* If "True", updating of the "var", "ms", and "mom" tensors is protected -* by a lock; otherwise the behavior is undefined, but may exhibit less -* contention. -* -*@par Outputs: -* var: A mutable tensor. Has the same type as input "var". -* -* -*/ -REG_OP(ApplyAdaMaxD) - .INPUT(var, TensorType::NumberType()) - .INPUT(m, TensorType::NumberType()) - .INPUT(v, TensorType::NumberType()) - .INPUT(beta1_power, TensorType::NumberType()) - .INPUT(lr, TensorType::NumberType()) - .INPUT(beta1, TensorType::NumberType()) - .INPUT(beta2, TensorType::NumberType()) - .INPUT(epsilon, TensorType::NumberType()) - .INPUT(grad, TensorType::NumberType()) - .OUTPUT(var, TensorType::NumberType()) - .OUTPUT(m, TensorType::NumberType()) - .OUTPUT(v, TensorType::NumberType()) - .ATTR(use_locking, Bool, false) - .OP_END_FACTORY_REG(ApplyAdaMaxD) - /** *@brief Updates relevant entries in "var" and "accum" according to the adagrad scheme. @@ -164,8 +113,7 @@ REG_OP(SparseApplyAdagrad) *@li update_slots: An optional bool. Defaults to "True". If "True", the calcution will be different as "False". *@par Outputs: -*@li var: A Tensor. Has the same type and format as input "var". -*@li accum: A Tensor. Has the same type and format as input "var". +*var: A Tensor. Has the same type and format as input "var". */ REG_OP(SparseApplyAdagradD) @@ -174,7 +122,6 @@ REG_OP(SparseApplyAdagradD) .INPUT(grad, TensorType({DT_FLOAT})) .INPUT(indices, TensorType({DT_INT32})) .OUTPUT(var, TensorType({DT_FLOAT})) - .OUTPUT(accum, TensorType({DT_FLOAT})) .REQUIRED_ATTR(lr, Float) .ATTR(use_locking, Bool, false) .ATTR(update_slots, Bool, true) @@ -184,7 +131,7 @@ REG_OP(SparseApplyAdagradD) *@brief Updates relevant entries in "var" and "accum" according to the adagrad scheme. *@par Inputs: -*Six inputs, including: +* Five inputs, including: *@li var: An NCHW, NHWC, or ND Tensor of type float32. *@li accum: An NCHW, NHWC, or ND Tensor of type float32. *@li lr: An NCHW, NHWC, or ND Tensor of type float32. @@ -216,7 +163,7 @@ REG_OP(SparseApplyAdagradV2) *@brief Updates relevant entries in "var" and "accum" according to the adagrad scheme. *@par Inputs: -*Four inputs, including: +* Four inputs, including: *@li var: An NCHW, NHWC, or ND Tensor of type float32. *@li accum: An NCHW, NHWC, or ND Tensor of type float32. *@li grad: An NCHW, NHWC, or ND Tensor of type float32. @@ -229,8 +176,8 @@ REG_OP(SparseApplyAdagradV2) *@li update_slots: An optional bool. Defaults to "True". If "True", the calcution will be different as "False". *@par Outputs: -*@li var: A Tensor. Has the same type and format as input "var". -*@li accum: A Tensor. Has the same type and format as input "accum". +*var: A Tensor. Has the same type and format as input "var". +*accum: A Tensor. Has the same type and format as input "accum". */ REG_OP(SparseApplyAdagradV2D) @@ -300,51 +247,6 @@ REG_OP(ApplyMomentumCCE) .ATTR(use_locking, Bool, false) .OP_END_FACTORY_REG(ApplyMomentumCCE) -/** -*@brief Updates "var" according to the momentum scheme. Set use_nesterov = True if you -* want to use Nesterov momentum.\n -* computing process: \n -* accum = accum * momentum + grad\n -* var -= lr * accum -* -*@attention Constraints:\n -* the input tensors must have the same shape. -* -*@par Inputs: -*@li var: A mutable tensor. Should be from a Variable(). -*@li accum: A mutable tensor. Has the same type as "var". -* Should be from a Variable(). -*@li lr: A scalar. Has the same type as "var". -*@li grad: A tensor for the gradient. Has the same type as "var". -* -*@par Attributes: -*@li use_nesterov: An optional bool. Defaults to "False". -* If "True", the tensor passed to compute grad will be -* var - lr * momentum * accum, so in the end, the var you get is actually -* var - lr * momentum * accum. -* -*@li use_locking: An optional bool. Defaults to "False".\n -* If "True", updating of the "var", "ms", and "mom" tensors is protected by a lock; -* otherwise the behavior is undefined, but may exhibit less contention. -* -*@par Outputs: -* var: A mutable tensor. Has the same type as input "var". -* accum: A mutable tensor. Has the same type as input "accum". -* -*/ - -REG_OP(ApplyMomentumD) - .INPUT(var, TensorType::NumberType()) - .INPUT(accum, TensorType::NumberType()) - .INPUT(lr, TensorType::NumberType()) - .INPUT(grad, TensorType::NumberType()) - .INPUT(momentum, TensorType::NumberType()) - .OUTPUT(var, TensorType::NumberType()) - .OUTPUT(accum, TensorType::NumberType()) - .ATTR(use_nesterov, Bool, false) - .ATTR(use_locking, Bool, false) - .OP_END_FACTORY_REG(ApplyMomentumD) - /** *@brief Updates "var" according to the AddSign update.\n * t-1 mean previous period. @@ -387,51 +289,6 @@ REG_OP(ApplyPowerSign) .ATTR(use_locking, Bool, false) .OP_END_FACTORY_REG(ApplyPowerSign) -/** -*@brief Updates "var" according to the AddSign update.\n -* t-1 mean previous period. -* m_t <- beta1 * m_{t-1} + (1 - beta1) * grad\n -* update <- exp(logbase * sign_decay * sign(grad) * sign(m_t)) * grad\n -* var <- var - lr * update -* -*@attention Constraints:\n -* the input tensors must have the same shape. -* -*@par Inputs: -*@li var: A mutable tensor. Should be from a Variable(). -*@li m: A mutable tensor. Has the same type as "var". -* Should be from a Variable(). -*@li lr: A scalar. Has the same type as "var". -*@li logbase: A scalar. Has the same type as "var". -*@li sign_decay: A scalar. Has the same type as "var". -*@li beta: A scalar. Has the same type as "var". -*@li grad: A tensor for the gradient. Has the same type as "var". -* -*@par Attributes: -* use_locking: An optional bool. Defaults to "False". -* If "True", updating of the "var", "ms", and "mom" tensors is protected -* by a lock; otherwise the behavior is undefined, but may exhibit less -* contention. -* -*@par Outputs: -*@li var: A mutable tensor. Has the same type as input "var". -*@li m: A mutable tensor. Has the same type as input "var". -* -* -*/ -REG_OP(ApplyPowerSignD) - .INPUT(var, TensorType::NumberType()) - .INPUT(m, TensorType::NumberType()) - .INPUT(lr, TensorType::NumberType()) - .INPUT(logbase, TensorType::NumberType()) - .INPUT(sign_decay, TensorType::NumberType()) - .INPUT(beta, TensorType::NumberType()) - .INPUT(grad, TensorType::NumberType()) - .OUTPUT(var, TensorType::NumberType()) - .OUTPUT(m, TensorType::NumberType()) - .ATTR(use_locking, Bool, false) - .OP_END_FACTORY_REG(ApplyPowerSignD) - /** *@brief Updates "var" as FOBOS algorithm with fixed learning rate.\n * prox_v = var - alpha * delta\n @@ -504,46 +361,6 @@ REG_OP(ApplyAddSign) .ATTR(use_locking, Bool, false) .OP_END_FACTORY_REG(ApplyAddSign) -/** -*@brief Updates "var" according to the AddSign update. - -*@par Inputs: -*Seven inputs, including: -* @li var: A mutable Tensor of type TensorType::NumberType(). -* Should be a Variable Tensor. -* @li m: A mutable Tensor of the same type as "var". -* Should be a Variable Tensor. -* @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar. -* @li alpha: A Tensor of the same type as "var". Must be a scalar. -* @li sign_decay: A Tensor of the same type as "var". Must be a scalar. -* @li beta: A Tensor of the same type as "var". Must be a scalar. -* @li grad: A Tensor of the same type as "var", for the gradient. - - -*@par Attributes: -*use_locking: An optional bool. Defaults to "False". -* If "True", updating of the "var" and "m" tensors will be -* protected by a lock; otherwise the behavior is undefined, -* but may exhibit less contention. - -*@par Outputs: -*@li var: A mutable Tensor. Has the same type as "var". -*@li m: A mutable Tensor. Has the same type as "m". - -*/ -REG_OP(ApplyAddSignD) - .INPUT(var, TensorType::NumberType()) - .INPUT(m, TensorType::NumberType()) - .INPUT(lr, TensorType::NumberType()) - .INPUT(alpha, TensorType::NumberType()) - .INPUT(sign_decay, TensorType::NumberType()) - .INPUT(beta, TensorType::NumberType()) - .INPUT(grad, TensorType::NumberType()) - .OUTPUT(var, TensorType::NumberType()) - .OUTPUT(m, TensorType::NumberType()) - .ATTR(use_locking, Bool, false) - .OP_END_FACTORY_REG(ApplyAddSignD) - /** *@brief Updates "var" according to the centered RMSProp algorithm.\n * The centered RMSProp algorithm uses an estimate of the centered second moment @@ -600,70 +417,6 @@ REG_OP(ApplyCenteredRMSProp) .OUTPUT(var, TensorType::NumberType()) .ATTR(use_locking, Bool, false) .OP_END_FACTORY_REG(ApplyCenteredRMSProp) - -/** -*@brief Updates "var" according to the centered RMSProp algorithm.\n -* The centered RMSProp algorithm uses an estimate of the centered second moment -* (i.e., the variance) for normalization, as opposed to regular RMSProp, which -* uses the (uncentered) second moment. This often helps with training, but is -* slightly more expensive in terms of computation and memory. -* -* t-1 mean previous period. -* mg <- rho * mg{t-1} + (1-rho) * grad\n -* ms <- rho * ms{t-1} + (1-rho) * grad * grad\n -* mom <- momentum * mom{t-1} + lr * grad / sqrt(ms - mg * mg + epsilon)\n -* var <- var - mom\n -* -*@attention Constraints:\n -*@li in dense implementation of this algorithm, mg, ms, and mom will -* update even if the grad is zero, but in this sparse implementation, mg, ms, -* and mom will not update in iterations during which the grad is zero. -*@li the input tensors must have the same shape. -* -*@par Inputs: -*@li var: A mutable tensor. Should be from a Variable(). -*@li mg: A mutable tensor. Has the same type as "var". -* Should be from a Variable(). -*@li ms: A mutable tensor. Has the same type as "var". -* Should be from a Variable(). -*@li mom: A mutable tensor. Has the same type as "var". -* Should be from a Variable(). -*@li lr: A scalar. Has the same type as "var". -*@li rho: A scalar. Has the same type as "var". -*@li momentum: A tensor. Has the same type as "var". -*@li epsilon: A scalar. Has the same type as "var". -*@li grad: A tensor for the gradient. Has the same type as "var". -* -*@par Attributes: -* use_locking: An optional bool. Defaults to "False". -* If "True", updating of the "var", "ms", and "mom" tensors is protected -* by a lock; otherwise the behavior is undefined, but may exhibit less -* contention. -* -*@par Outputs: -*@li var: A mutable Tensor. Has the same type as "var". -*@li mg: A mutable Tensor. Has the same type as "mg". -*@li ms: A mutable Tensor. Has the same type as "ms". -*@li mom: A mutable Tensor. Has the same type as "mom". - -* -*/ -REG_OP(ApplyCenteredRMSPropD) - .INPUT(var, TensorType::NumberType()) - .INPUT(mg, TensorType::NumberType()) - .INPUT(ms, TensorType::NumberType()) - .INPUT(mom, TensorType::NumberType()) - .INPUT(lr, TensorType::NumberType()) - .INPUT(rho, TensorType::NumberType()) - .INPUT(momentum, TensorType::NumberType()) - .INPUT(epsilon, TensorType::NumberType()) - .INPUT(grad, TensorType::NumberType()) - .OUTPUT(var, TensorType::NumberType()) - .OUTPUT(mg, TensorType::NumberType()) - .OUTPUT(ms, TensorType::NumberType()) - .OUTPUT(mom, TensorType::NumberType()) - .ATTR(use_locking, Bool, false) - .OP_END_FACTORY_REG(ApplyCenteredRMSPropD) /** *@brief Updates "var" by subtracting 'alpha' * 'delta' from it.\n @@ -730,73 +483,34 @@ REG_OP(ApplyAdagrad) .ATTR(use_locking, Bool, false) .OP_END_FACTORY_REG(ApplyAdagrad) -/** -*@brief Updates "var" according to the adagrad scheme.\n -* accum += grad * grad\n -* var -= lr * grad * (1 / sqrt(accum)) -* -*@attention Constraints:\n -* the input tensors must have the same shape. -* -*@par Inputs: -*@li var: A mutable tensor. Should be from a Variable(). -*@li accum: A mutable tensor. Has the same type as "var". -* Should be from a Variable(). -*@li lr: A scalar. Has the same type as "var". -*@li grad: A tensor for the gradient. Has the same type as "var". -* -*@par Attributes: -* use_locking: An optional bool. Defaults to "False". -* If "True", updating of the "var", "ms", and "mom" tensors is protected -* by a lock; otherwise the behavior is undefined, but may exhibit less -* contention. -* -*@par Outputs: -*@li var: A mutable tensor. Has the same type as input "var". -*@li accum: A mutable tensor. Has the same type as input "var". -* -* -*/ -REG_OP(ApplyAdagradD) - .INPUT(var, TensorType::NumberType()) - .INPUT(accum, TensorType::NumberType()) - .INPUT(lr, TensorType::NumberType()) - .INPUT(grad, TensorType::NumberType()) - .OUTPUT(var, TensorType::NumberType()) - .OUTPUT(accum, TensorType::NumberType()) - .ATTR(update_slots, Bool, true) - .ATTR(use_locking, Bool, false) - .OP_END_FACTORY_REG(ApplyAdagradD) - /** * @brief Updates "var" according to the adagradv2 scheme.\n * accum += grad * grad \n * var -= lr * grad * (1 / sqrt(accum) + epsilon) * +* @attention Constraints: +* the input tensors must have the same shape. +* * @par Inputs: * @li var: A mutable tensor. Must be one of the data types defined in -* TensorType::NumberType(). Should be from a Variable(). +* TensorType::NumberType(). Should be from a Variable(). * @li accum: A mutable tensor. Has the same type as "var". Should be from a -* Variable(). +* Variable(). * @li lr: A tensor for the learning rate. Has the same type as "var". Should be -* from a Variable(). +* from a Variable(). * @li grad: A tensor for the gradient. Has the same type as "var". Should be -* from a Variable(). +* from a Variable(). * @li epsilon: A scalar. Has the same type as "var". * * @par Attributes: * @li update_slots: An optional bool. Defaults to "True". -* If "True", accum will be updated +* If "True", accum will be updated * @li use_locking: An optional bool. Defaults to "False". -* If "True", updating of the "var" tensor is protected by a lock; -* otherwise the behavior is undefined, but may exhibit less contention. +* If "True", updating of the "var" tensor is protected by a lock; +* otherwise the behavior is undefined, but may exhibit less contention. * * @par Outputs: -* var: A mutable tensor. Has the same type as input "var". -* -* @attention Constraints: -* The input tensors must have the same shape. -* +* var: A mutable tensor. Has the same type as input "var". * */ REG_OP(ApplyAdagradV2) @@ -813,33 +527,32 @@ REG_OP(ApplyAdagradV2) /** * @brief Updates "var" according to the adagradv2 scheme.\n -* accum += grad * grad \n -* var -= lr * grad * (1 / sqrt(accum) + epsilon) +* accum += grad * grad \n +* var -= lr * grad * (1 / sqrt(accum) + epsilon) +* +* @attention Constraints: +* the input tensors must have the same shape. * * @par Inputs: * @li var: A mutable tensor. Must be one of the data types defined in -* TensorType::NumberType(). Should be from a Variable(). +* TensorType::NumberType(). Should be from a Variable(). * @li accum: A mutable tensor. Has the same type as "var". Should be from a -* Variable(). +* Variable(). * @li lr: A tensor for the learning rate. Has the same type as "var". Should be -* from a Variable(). +* from a Variable(). * @li grad: A tensor for the gradient. Has the same type as "var". Should be -* from a Variable(). +* from a Variable(). * * @par Attributes: * @li epsilon: A scalar. Has the same type as "var". * @li update_slots: An optional bool. Defaults to "True". -* If "True", accum will be updated +* If "True", accum will be updated * @li use_locking: An optional bool. Defaults to "False". -* If "True", updating of the "var" tensor is protected by a lock; -* otherwise the behavior is undefined, but may exhibit less contention. +* If "True", updating of the "var" tensor is protected by a lock; +* otherwise the behavior is undefined, but may exhibit less contention. * * @par Outputs: -* var: A mutable tensor. Has the same type as input "var". -* -* @attention Constraints: -* The input tensors must have the same shape. -* +* var: A mutable tensor. Has the same type as input "var". * */ REG_OP(ApplyAdagradV2D) @@ -897,54 +610,6 @@ REG_OP(ApplyAdagradDA) .ATTR(use_locking, Bool, false) .OP_END_FACTORY_REG(ApplyAdagradDA) -/** -*@brief Updates "var" according to the proximal adagrad scheme. - -*@par Inputs: -*Eight inputs, including: -* @li var: A mutable Tensor. Must be one of the following types: -* TensorType::NumberType(). Should be a Variable Tensor. -* @li gradient_accumulator: A mutable Tensor. Must have the same -* type as "var". Should be a Variable Tensor. -* @li gradient_squared_accumulator: A mutable Tensor of the same type as "var". -* Should be a Variable Tensor. -* @li grad: A Tensor of the same type as "var", for the gradient. -* @li lr: A Tensor of the same type as "var". -* Scaling factor. Must be a scalar. -* @li l1: A Tensor of the same type as "var". -* L1 regulariation. Must be a scalar. -* @li l2: A Tensor of the same type as "var". -* L2 regulariation. Must be a scalar. -* @li global_step: A Tensor of type int32 or int64. -* Training step number. Must be a scalar. - -*@par Attributes: -*use_locking: An optional bool. Defaults to "False". -* If "True", updating of the var and accum tensors will be -* protected by a lock; otherwise the behavior is undefined, -* but may exhibit less contention. - -*@par Outputs: -*var: A mutable Tensor. Has the same type as "var". -*gradient_accumulator: A mutable Tensor. Has the same type as "var". -*gradient_squared_accumulator: A mutable Tensor. Has the same type as "var". - -*/ -REG_OP(ApplyAdagradDAD) - .INPUT(var, TensorType::NumberType()) - .INPUT(gradient_accumulator, TensorType::NumberType()) - .INPUT(gradient_squared_accumulator, TensorType::NumberType()) - .INPUT(grad, TensorType::NumberType()) - .INPUT(lr, TensorType::NumberType()) - .INPUT(l1, TensorType::NumberType()) - .INPUT(l2, TensorType::NumberType()) - .INPUT(global_step, TensorType({DT_INT32, DT_INT64})) - .OUTPUT(var, TensorType::NumberType()) - .OUTPUT(gradient_accumulator, TensorType::NumberType()) - .OUTPUT(gradient_squared_accumulator, TensorType::NumberType()) - .ATTR(use_locking, Bool, false) - .OP_END_FACTORY_REG(ApplyAdagradDAD) - /** *@brief Returns the dimension index in the destination data format given the one in * the source data format. @@ -1133,9 +798,7 @@ REG_OP(ApplyRMSPropD) *use_locking: An optional bool. Defaults to "False". If "True", updating of the "var" and "accum" *tensors will be protected by a lock; otherwise the behavior is undefined, but may exhibit less *contention. *@par Outputs: -* @li var: A mutable tensor. Must have the same type as input "var". -* @li ms: A mutable tensor. Must have the same type as input "ms". -* @li mom: A mutable tensor. Must have the same type as input "mom". +*var: A mutable Tensor. Has the same type as "var". */ REG_OP(ApplyProximalAdagrad) .INPUT(var, TensorType::NumberType()) @@ -1149,213 +812,52 @@ REG_OP(ApplyProximalAdagrad) .OP_END_FACTORY_REG(ApplyProximalAdagrad) /** -*@brief Update "var" and "accum" according to FOBOS with Adagrad learning rate. - -*@par Inputs: -*Six inputs, including: -* @li var: A mutable Tensor of type TensorType::NumberType(). -* Should be from a Variable(). -* @li accum: A mutable Tensor of the same type as "var". Should be from a Variable(). -* @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar. -* @li l1: A Tensor of the same type as "var", for L1 regulariation. Must be a scalar. -* @li l2: A Tensor of the same type as "var", for L2 regulariation. Must be a scalar. -* @li grad: A Tensor of the same type as "var", for the gradient. - -*@par Attributes: -*use_locking: An optional bool. Defaults to "False". If "True", updating of the "var" and "accum" *tensors will be protected by a lock; otherwise the behavior is undefined, but may exhibit less *contention. - -*@par Outputs: -* @li var: A mutable Tensor. Has the same type as "var". -* @li accum: A mutable Tensor. Has the same type as "var". - -*/ -REG_OP(ApplyProximalAdagradD) - .INPUT(var, TensorType::NumberType()) - .INPUT(accum, TensorType::NumberType()) - .INPUT(lr, TensorType::NumberType()) - .INPUT(l1, TensorType::NumberType()) - .INPUT(l2, TensorType::NumberType()) - .INPUT(grad, TensorType::NumberType()) - .OUTPUT(var, TensorType::NumberType()) - .OUTPUT(accum, TensorType::NumberType()) - .ATTR(use_locking, Bool, false) - .OP_END_FACTORY_REG(ApplyProximalAdagradD) - -/** -*@brief Updates entries in 'var' and 'accum' according to the Proximal Adagrad algorithm.\ n -* Compared with op ApplyProximalAdagrad, an additional index tensor is input, -* Only the indices into the first dimensions of "var" and "accum" are updated. - -*@par Inputs: -* Seven inputs, including:\n -* @li var: A mutable Tensor.\n -* TensorType::NumberType(). Should be a Variable Tensor. -* @li accum: A mutable Tensor of the same type as "var".\n -* Should be a Variable Tensor. -* @li lr: A Tensor of the same type as "var".\n -* Scaling factor. Must be a scalar. -* @li l1: A Tensor of the same type as "var".\n -* L1 regulariation. Must be a scalar. -* @li l2: A Tensor of the same type as "var".\n -* L2 regulariation. Must be a scalar. -* @li grad: A Tensor. Has the same type as "var". \n -* The gradient. -* @li indices: A vector of indices into the first dimension of "var" and "accum".\n -* TensorType::IndexNumberType(). - -*@par Attributes: -*use_locking: An optional bool. Defaults to "False".\n -* If "True", updating of the var and accum tensors will be protected by a lock; \n -* If "False", the behavior is undefined, but may exhibit less contention. - -*@par Outputs: -*var: A mutable Tensor. Has the same type as "var". -*/ -REG_OP(SparseApplyProximalAdagrad) - .INPUT(var, TensorType::NumberType()) - .INPUT(accum, TensorType::NumberType()) - .INPUT(lr, TensorType::NumberType()) - .INPUT(l1, TensorType::NumberType()) - .INPUT(l2, TensorType::NumberType()) - .INPUT(grad, TensorType::NumberType()) - .INPUT(indices, TensorType::IndexNumberType()) - .OUTPUT(var, TensorType::NumberType()) - .ATTR(use_locking, Bool, false) - .OP_END_FACTORY_REG(SparseApplyProximalAdagrad) - -/** -*@brief Updates entries in 'var' and 'accum' according to the Proximal Adagrad algorithm.\ n -* Compared with op ApplyProximalAdagrad, an additional index tensor is input, -* Only the indices into the first dimensions of "var" and "accum" are updated. - -*@par Inputs: -* Seven inputs, including:\n -* @li var: A mutable Tensor.\n -* TensorType::NumberType(). Should be a Variable Tensor. -* @li accum: A mutable Tensor of the same type as "var".\n -* Should be a Variable Tensor. -* @li lr: A Tensor of the same type as "var".\n -* Scaling factor. Must be a scalar. -* @li l1: A Tensor of the same type as "var".\n -* L1 regulariation. Must be a scalar. -* @li l2: A Tensor of the same type as "var".\n -* L2 regulariation. Must be a scalar. -* @li grad: A Tensor. Has the same type as "var". \n -* The gradient. -* @li indices: A vector of indices into the first dimension of "var" and "accum".\n -* TensorType::IndexNumberType(). - -*@par Attributes: -*use_locking: An optional bool. Defaults to "False".\n -* If "True", updating of the var and accum tensors will be protected by a lock; \n -* If "False", the behavior is undefined, but may exhibit less contention. - -*@par Outputs: -*@li var: A mutable Tensor. Has the same type as "var". -*@li accum: A mutable Tensor. Has the same type as "var". - -*/ -REG_OP(SparseApplyProximalAdagradD) - .INPUT(var, TensorType::NumberType()) - .INPUT(accum, TensorType::NumberType()) - .INPUT(lr, TensorType::NumberType()) - .INPUT(l1, TensorType::NumberType()) - .INPUT(l2, TensorType::NumberType()) - .INPUT(grad, TensorType::NumberType()) - .INPUT(indices, TensorType::IndexNumberType()) - .OUTPUT(var, TensorType::NumberType()) - .OUTPUT(accum, TensorType::NumberType()) - .ATTR(use_locking, Bool, false) - .OP_END_FACTORY_REG(SparseApplyProximalAdagradD) - -/** -*@brief Updates "var" according to the Ftrl-proximal scheme. - -*@par Inputs: -*Eight inputs, including: -* @li var: A mutable Tensor. Must be of type TensorType::NumberType(). -* Should be a Variable Tensor. -* @li accum: A mutable Tensor of the same type as "var". -* Should be a Variable Tensor. -* @li linear: A mutable Tensor of the same type as "var". -* Should be a Variable Tensor. -* @li grad: A Tensor of the same type as "var", for the gradient. -* @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar. -* @li l1: A Tensor of the same type as "var", for L1 regulariation. Must be a scalar. -* @li l2: A Tensor of the same type as "var", for L2 regulariation. Must be a scalar. -* @li lr_power: A Tensor of the same type as "var", for the scaling factor. Must be a scalar. - -*@par Attributes: -*use_locking: An optional bool. Defaults to "False". -* If "True", updating of the "var" and "accum" tensors will be -* protected by a lock; otherwise the behavior is undefined, -* but may exhibit less contention. - -*@par Outputs: -*var: A mutable Tensor. Has the same type as "var". -*/ -REG_OP(ApplyFtrl) - .INPUT(var, TensorType::NumberType()) - .INPUT(accum, TensorType::NumberType()) - .INPUT(linear, TensorType::NumberType()) - .INPUT(grad, TensorType::NumberType()) - .INPUT(lr, TensorType::NumberType()) - .INPUT(l1, TensorType::NumberType()) - .INPUT(l2, TensorType::NumberType()) - .INPUT(lr_power, TensorType::NumberType()) - .OUTPUT(var, TensorType::NumberType()) - .ATTR(use_locking, Bool, false) - .OP_END_FACTORY_REG(ApplyFtrl) - -/** -*@brief Updates "var" according to the Ftrl-proximal scheme. +*@brief Updates entries in 'var' and 'accum' according to the Proximal Adagrad algorithm.\ n +* Compared with op ApplyProximalAdagrad, an additional index tensor is input, +* Only the indices into the first dimensions of "var" and "accum" are updated. *@par Inputs: -*Eight inputs, including: -* @li var: A mutable Tensor. Must be of type TensorType::NumberType(). -* Should be a Variable Tensor. -* @li accum: A mutable Tensor of the same type as "var". -* Should be a Variable Tensor. -* @li linear: A mutable Tensor of the same type as "var". +* Seven inputs, including:\n +* @li var: A mutable Tensor.\n +* TensorType::NumberType(). Should be a Variable Tensor. +* @li accum: A mutable Tensor of the same type as "var".\n * Should be a Variable Tensor. -* @li grad: A Tensor of the same type as "var", for the gradient. -* @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar. -* @li l1: A Tensor of the same type as "var", for L1 regulariation. Must be a scalar. -* @li l2: A Tensor of the same type as "var", for L2 regulariation. Must be a scalar. -* @li lr_power: A Tensor of the same type as "var", for the scaling factor. Must be a scalar. +* @li lr: A Tensor of the same type as "var".\n +* Scaling factor. Must be a scalar. +* @li l1: A Tensor of the same type as "var".\n +* L1 regulariation. Must be a scalar. +* @li l2: A Tensor of the same type as "var".\n +* L2 regulariation. Must be a scalar. +* @li grad: A Tensor. Has the same type as "var". \n +* The gradient. +* @li indices: A vector of indices into the first dimension of "var" and "accum".\n +* TensorType::IndexNumberType(). *@par Attributes: -*use_locking: An optional bool. Defaults to "False". -* If "True", updating of the "var" and "accum" tensors will be -* protected by a lock; otherwise the behavior is undefined, -* but may exhibit less contention. +*use_locking: An optional bool. Defaults to "False".\n +* If "True", updating of the var and accum tensors will be protected by a lock; \n +* If "False", the behavior is undefined, but may exhibit less contention. *@par Outputs: -*@li var: A mutable Tensor. Has the same type as "var". -*@li accum: A mutable Tensor. Has the same type as "accum". -*@li linear: A mutable Tensor. Has the same type as "linear". - +*var: A mutable Tensor. Has the same type as "var". */ -REG_OP(ApplyFtrlD) +REG_OP(SparseApplyProximalAdagrad) .INPUT(var, TensorType::NumberType()) .INPUT(accum, TensorType::NumberType()) - .INPUT(linear, TensorType::NumberType()) - .INPUT(grad, TensorType::NumberType()) .INPUT(lr, TensorType::NumberType()) .INPUT(l1, TensorType::NumberType()) .INPUT(l2, TensorType::NumberType()) - .INPUT(lr_power, TensorType::NumberType()) + .INPUT(grad, TensorType::NumberType()) + .INPUT(indices, TensorType::IndexNumberType()) .OUTPUT(var, TensorType::NumberType()) - .OUTPUT(accum, TensorType::NumberType()) - .OUTPUT(linear, TensorType::NumberType()) .ATTR(use_locking, Bool, false) - .OP_END_FACTORY_REG(ApplyFtrlD) + .OP_END_FACTORY_REG(SparseApplyProximalAdagrad) /** -*@brief Update "var" according to the Ftrl-proximal scheme. +*@brief Updates "var" according to the Ftrl-proximal scheme. *@par Inputs: -*Nine inputs, including: +*Eight inputs, including: * @li var: A mutable Tensor. Must be of type TensorType::NumberType(). * Should be a Variable Tensor. * @li accum: A mutable Tensor of the same type as "var". @@ -1366,7 +868,6 @@ REG_OP(ApplyFtrlD) * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar. * @li l1: A Tensor of the same type as "var", for L1 regulariation. Must be a scalar. * @li l2: A Tensor of the same type as "var", for L2 regulariation. Must be a scalar. -* @li l2_shrinkage: A Tensor of the same type as "var". * @li lr_power: A Tensor of the same type as "var", for the scaling factor. Must be a scalar. *@par Attributes: @@ -1378,7 +879,7 @@ REG_OP(ApplyFtrlD) *@par Outputs: *var: A mutable Tensor. Has the same type as "var". */ -REG_OP(ApplyFtrlV2) +REG_OP(ApplyFtrl) .INPUT(var, TensorType::NumberType()) .INPUT(accum, TensorType::NumberType()) .INPUT(linear, TensorType::NumberType()) @@ -1386,11 +887,10 @@ REG_OP(ApplyFtrlV2) .INPUT(lr, TensorType::NumberType()) .INPUT(l1, TensorType::NumberType()) .INPUT(l2, TensorType::NumberType()) - .INPUT(l2_shrinkage, TensorType::NumberType()) .INPUT(lr_power, TensorType::NumberType()) .OUTPUT(var, TensorType::NumberType()) .ATTR(use_locking, Bool, false) - .OP_END_FACTORY_REG(ApplyFtrlV2) + .OP_END_FACTORY_REG(ApplyFtrl) /** *@brief Update "var" according to the Ftrl-proximal scheme. @@ -1418,11 +918,9 @@ REG_OP(ApplyFtrlV2) *@par Outputs: *var: A mutable Tensor. Has the same type as "var". -*accum: A mutable Tensor. Has the same type as "accum". -*linear: A mutable Tensor. Has the same type as "linear". */ -REG_OP(ApplyFtrlV2D) +REG_OP(ApplyFtrlV2) .INPUT(var, TensorType::NumberType()) .INPUT(accum, TensorType::NumberType()) .INPUT(linear, TensorType::NumberType()) @@ -1433,10 +931,8 @@ REG_OP(ApplyFtrlV2D) .INPUT(l2_shrinkage, TensorType::NumberType()) .INPUT(lr_power, TensorType::NumberType()) .OUTPUT(var, TensorType::NumberType()) - .OUTPUT(accum, TensorType::NumberType()) - .OUTPUT(linear, TensorType::NumberType()) .ATTR(use_locking, Bool, false) - .OP_END_FACTORY_REG(ApplyFtrlV2D) + .OP_END_FACTORY_REG(ApplyFtrlV2) /** *@brief Updates "var" according to the Adam algorithm.\n @@ -1475,61 +971,6 @@ REG_OP(ApplyFtrlV2D) * var: A mutable Tensor. Has the same type as intput "var". */ REG_OP(ApplyAdam) - .INPUT(var, TensorType::NumberType()) - .INPUT(m, TensorType::NumberType()) - .INPUT(v, TensorType::NumberType()) - .INPUT(beta1_power, TensorType::NumberType()) - .INPUT(beta2_power, TensorType::NumberType()) - .INPUT(lr, TensorType::NumberType()) - .INPUT(beta1, TensorType::NumberType()) - .INPUT(beta2, TensorType::NumberType()) - .INPUT(epsilon, TensorType::NumberType()) - .INPUT(grad, TensorType::NumberType()) - .OUTPUT(var, TensorType::NumberType()) - .ATTR(use_locking, Bool, false) - .ATTR(use_nesterov, Bool, false) - .OP_END_FACTORY_REG(ApplyAdam) - -/** -*@brief Updates "var" according to the Adam algorithm.\n -* lr_t <- text{learning\_rate} * sqrt{1 - beta_2^t} / (1 - beta_1^t)\n -* m_t <- beta_1 * m_{t-1} + (1 - beta_1) * g\n -* v_t <- max(beta2 * v{t-1}, abs(g))\n -* variable <- variable - lr_t * m_t / (sqrt{v_t} + epsilon) -* -*@attention Constraints:\n -* *The input tensors must have the same shape.* -* -*@par Inputs: -*@li var: A mutable Tensor of the type TensorType::NumberType(). -* Should be from a Variable(). -*@li m: A mutable Tensor of the same type as "var". -* Should be from a Variable(). -*@li v: A mutable Tensor of the same type as "var". -* Should be from a Variable(). -*@li beta1_power: A scalar of the same type as "var". -*@li beta2_power: A scalar of the same type as "var". -*@li lr: learning_rate. A scalar of the same type as "var". -*@li beta1: A scalar of the same type as "var". -*@li beta2: A scalar of the same type as "var". -*@li epsilon: A scalar of the same type as "var". -*@li grad: A Tensor of the same type as "var", for the gradient. -* -*@par Attributes:\n -*@li use_locking: An optional bool. Defaults to "False". -* If "True", updating of the "var", m", and "v" tensors will be protected -* by a lock; otherwise the behavior is undefined, but may exhibit less -* contention. -*@li use_nesterov: An optional bool. Defaults to "False". - If "True", uses the nesterov update. -* -*@par Outputs: -*@li var: A mutable tensor. Has the same type as input "var". -*@li m: A mutable tensor. Has the same type as input "m". -*@li v: A mutable tensor. Has the same type as input "v". - -*/ -REG_OP(ApplyAdamD) .INPUT(var, TensorType::NumberType()) .INPUT(m, TensorType::NumberType()) .INPUT(v, TensorType::NumberType()) @@ -1545,7 +986,7 @@ REG_OP(ApplyAdamD) .OUTPUT(v, TensorType::NumberType()) .ATTR(use_locking, Bool, false) .ATTR(use_nesterov, Bool, false) - .OP_END_FACTORY_REG(ApplyAdamD) + .OP_END_FACTORY_REG(ApplyAdam) /** *@brief Updates "var" according to the proximal adadelta scheme. @@ -1584,48 +1025,6 @@ REG_OP(ApplyAdadelta) .ATTR(use_locking, Bool, false) .OP_END_FACTORY_REG(ApplyAdadelta) -/** -*@brief Updates "var" according to the proximal adadelta scheme. - -*@par Inputs: -*Seven inputs, including: -* @li var: A mutable Tensor of type TensorType::NumberType(). -* Should be a Variable Tensor. -* @li accum: A mutable Tensor of the same type as "var". -* Should be a Variable Tensor. -* @li accum_update: A mutable Tensor of the same type as "var". -* Should be a Variable Tensor. -* @li lr: A scalar of the same type as "var", for the scaling factor. -* @li rho: A scalar of the same type as "var", for the decay factor. -* @li epsilon: A scalar of the same type as "var", for the constant factor. -* @li grad: A Tensor of the same type as "var", for the gradient. - -*@par Attributes: -*use_locking: An optional bool. Defaults to "False". -* If "True", updating of the "var", "accum" and "accum_update" tensors will be -* protected by a lock; otherwise the behavior is undefined, -* but may exhibit less contention. - -*@par Outputs: -*@li var: A mutable Tensor. Has the same type as "var". -*@li accum: A mutable Tensor. Has the same type as "var". -*@li accum_update: A mutable Tensor. Has the same type as "var". - -*/ -REG_OP(ApplyAdadeltaD) - .INPUT(var, TensorType::NumberType()) - .INPUT(accum, TensorType::NumberType()) - .INPUT(accum_update, TensorType::NumberType()) - .INPUT(lr, TensorType::NumberType()) - .INPUT(rho, TensorType::NumberType()) - .INPUT(epsilon, TensorType::NumberType()) - .INPUT(grad, TensorType::NumberType()) - .OUTPUT(var, TensorType::NumberType()) - .OUTPUT(accum, TensorType::NumberType()) - .OUTPUT(accum_update, TensorType::NumberType()) - .ATTR(use_locking, Bool, false) - .OP_END_FACTORY_REG(ApplyAdadeltaD) - /** * @brief Updates "var" according to the ApplyMomentum algorithm. \n * accum = accum * momentum + x1 * x2 \n @@ -1794,11 +1193,11 @@ REG_OP(LarsV2Update) * @par Inputs: * Nine inputs, including: * @li var: A mutable Tensor. Must be of type TensorType::NumberType(). -* Should be a Variable Tensor. +* Should be a Variable Tensor. * @li accum: A mutable Tensor of the same type as "var". -* Should be a Variable Tensor. +* Should be a Variable Tensor. * @li linear: A mutable Tensor of the same type as "var". -* Should be a Variable Tensor. +* Should be a Variable Tensor. * @li grad: A Tensor of the same type as "var", for the gradient. * @li indices: A vector of indices into the first dimension of var and accum. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar. @@ -1808,9 +1207,9 @@ REG_OP(LarsV2Update) * @par Attributes: * use_locking: An optional bool. Defaults to "False". -* If "True", updating of the "var" and "accum" tensors will be -* protected by a lock; otherwise the behavior is undefined, -* but may exhibit less contention. +* If "True", updating of the "var" and "accum" tensors will be +* protected by a lock; otherwise the behavior is undefined, +* but may exhibit less contention. * @par Outputs: * var: A Tensor. Has the same type and format as input "var". @@ -1834,13 +1233,13 @@ REG_OP(SparseApplyFtrl) * @brief Update relevant entries in '*var' according to the Ftrl-proximal scheme. * @par Inputs: -* Five inputs, including: +* Nine inputs, including: * @li var: A mutable Tensor. Must be of type TensorType::NumberType(). -* Should be a Variable Tensor. +* Should be a Variable Tensor. * @li accum: A mutable Tensor of the same type as "var". -* Should be a Variable Tensor. +* Should be a Variable Tensor. * @li linear: A mutable Tensor of the same type as "var". -* Should be a Variable Tensor. +* Should be a Variable Tensor. * @li grad: A Tensor of the same type as "var", for the gradient. * @li indices: A vector of indices into the first dimension of var and accum. * @li lr: A Tensor of the same type as "var", for the scaling factor. Must be a scalar. @@ -1850,14 +1249,14 @@ REG_OP(SparseApplyFtrl) * @par Attributes: * use_locking: An optional bool. Defaults to "False". -* If "True", updating of the "var" and "accum" tensors will be -* protected by a lock; otherwise the behavior is undefined, -* but may exhibit less contention. +* If "True", updating of the "var" and "accum" tensors will be +* protected by a lock; otherwise the behavior is undefined, +* but may exhibit less contention. * @par Outputs: -* @li var: A Tensor. Has the same type and format as input "var". -* @li accum: A Tensor. Has the same type and format as input "accum". -* @li linear: A Tensor. Has the same type and format as input "linear". +* var: A Tensor. Has the same type and format as input "var". +* accum: A Tensor. Has the same type and format as input "accum". +* linear: A Tensor. Has the same type and format as input "linear". */ REG_OP(SparseApplyFtrlD) @@ -1926,13 +1325,13 @@ REG_OP(SparseApplyFtrlV2) * That is for rows we have grad for, we update var, accum and linear * @par Inputs: -* Five inputs, including: +* Ten inputs, including: * @li var: A mutable Tensor. Must be of type TensorType::NumberType(). -* Should be a Variable Tensor. +* Should be a Variable Tensor. * @li accum: A mutable Tensor of the same type as "var". -* Should be a Variable Tensor. +* Should be a Variable Tensor. * @li linear: A mutable Tensor of the same type as "var". -* Should be a Variable Tensor. +* Should be a Variable Tensor. * @li grad: A Tensor of the same type as "var", for the gradient. * @li indices: A vector of indices into the first dimension of var and accum. @@ -1943,14 +1342,14 @@ REG_OP(SparseApplyFtrlV2) * @li l2_shrinkage: A Tensor of the same type as "var", L2 shrinkage regulariation. Must be a scalar. * @li lr_power: A Tensor of the same type as "var", for the scaling factor. Must be a scalar. * @li use_locking: An optional bool. Defaults to "False". -* If "True", updating of the "var" and "accum" tensors will be -* rotected by a lock; otherwise the behavior is undefined, -* but may exhibit less contention. +* If "True", updating of the "var" and "accum" tensors will be +* rotected by a lock; otherwise the behavior is undefined, +* but may exhibit less contention. * @par Outputs: -* @li var: A Tensor. Has the same type and format as input "var". -* @li accum: A Tensor. Has the same type and format as input "accum". -* @li linear: A Tensor. Has the same type and format as input "linear". +* var: A Tensor. Has the same type and format as input "var". +* accum: A Tensor. Has the same type and format as input "accum". +* linear: A Tensor. Has the same type and format as input "linear". */ REG_OP(SparseApplyFtrlV2D) @@ -1970,109 +1369,6 @@ REG_OP(SparseApplyFtrlV2D) .ATTR(use_locking, Bool, false) .OP_END_FACTORY_REG(SparseApplyFtrlV2D) -/** -* @brief Updates "var" in specified index according to the RMSProp algorithm. -* mean_square = decay * mean_square + (1-decay) * gradient ** 2\n -* Delta = learning_rate * gradient / sqrt(mean_square + epsilon)\n -* ms <- rho * ms_{t-1} + (1-rho) * grad * grad\n -* mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)\n -* var <- var - mom\n -* -* @par Inputs: -* @li var: A mutable tensor. Must be one of the data types defined in\n -* TensorType::NumberType(). Should be from a Variable(). -* @li ms: A mutable tensor. Must have the same type as "var". Should be from a -* Variable(). -* @li mom: A mutable tensor. Must have the same type as "var". Should be from a -* Variable(). -* @li lr: A scalar. Must have the same type as "var". -* @li rho: A scalar. Must have the same type as "var". -* @li momentum: A scalar. Must have the same type as "var". -* @li epsilon: A scalar. Must have the same type as "var". -* @li grad: A tensor, specifying the gradient. -* @li indices: A vector of indices into the first dimension of var, mom and ms. -* -* @par Attributes: -* use_locking: An optional "bool". Defaults to "False". If "True", updating of -* the "var", "ms", and "mom" tensors will be protected by a lock; otherwise the -* behavior is undefined, but may exhibit less contention. -* -* @par Outputs: -* var: A mutable tensor. Has the same type as input "var". -* -* @attention Constraints: -* @li Note that in this sparse implementation, "ms" and "mom" will not update -* in iterations during which "grad" is 0. -* @li The input tensors "var", "ms", "mom" must have the same shape. -* -*/ -REG_OP(SparseApplyRMSProp) - .INPUT(var, TensorType::NumberType()) - .INPUT(ms, TensorType::NumberType()) - .INPUT(mom, TensorType::NumberType()) - .INPUT(lr, TensorType::NumberType()) - .INPUT(rho, TensorType::NumberType()) - .INPUT(momentum, TensorType::NumberType()) - .INPUT(epsilon, TensorType::NumberType()) - .INPUT(grad, TensorType::NumberType()) - .INPUT(indices, TensorType::IndexNumberType()) - .OUTPUT(var, TensorType::NumberType()) - .ATTR(use_locking, Bool, false) - .OP_END_FACTORY_REG(SparseApplyRMSProp) - -/** -* @brief Updates "var" in specified index according to the RMSProp algorithm. -* a const input will be considered as an attribute.\n -* mean_square = decay * mean_square + (1-decay) * gradient ** 2\n -* Delta = learning_rate * gradient / sqrt(mean_square + epsilon)\n -* ms <- rho * ms_{t-1} + (1-rho) * grad * grad\n -* mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)\n -* var <- var - mom -* -* @par Inputs: -* @li var: A mutable tensor. Must be one of the data types defined in -* TensorType::NumberType(). Should be from a Variable(). -* @li ms: A mutable tensor. Must have the same type as "var". Should be from a -* Variable(). -* @li mom: A mutable tensor. Must have the same type as "var". Should be from a -* Variable(). -* @li lr: A scalar. Must have the same type as "var". -* @li grad: A tensor, specifying the gradient. -* -* @par Attributes: -* @li use_locking: An optional "bool". Defaults to "False". If "True", -* updating of the "var", "ms", and "mom" tensors will be protected by a lock; -* otherwise the behavior is undefined, but may exhibit less contention. -* @li rho: A required scalar. Must have the same type as "var". -* @li momentum: A required scalar. Must have the same type as "var". -* @li epsilon: A required scalar. Must have the same type as "var". -* -* @par Outputs: -* @li var: A mutable tensor. Must have the same type as input "var". -* @li ms: A mutable tensor. Must have the same type as input "ms". -* @li mom: A mutable tensor. Must have the same type as input "mom". -* -* @attention Constraints: -* @li Note that in this sparse implementation, "ms" and "mom" will not update -* in iterations during which "grad" is 0. -* @li The input tensors "var", "ms" and "mom" must have the same shape. -*/ -REG_OP(SparseApplyRMSPropD) - .INPUT(var, TensorType::NumberType()) - .INPUT(ms, TensorType::NumberType()) - .INPUT(mom, TensorType::NumberType()) - .INPUT(lr, TensorType::NumberType()) - .INPUT(grad, TensorType::NumberType()) - .INPUT(indices, TensorType::IndexNumberType()) - .OUTPUT(var, TensorType::NumberType()) - .OUTPUT(ms, TensorType::NumberType()) - .OUTPUT(mom, TensorType::NumberType()) - .REQUIRED_ATTR(rho, Float) - .REQUIRED_ATTR(momentum, Float) - .REQUIRED_ATTR(epsilon, Float) - .ATTR(use_locking, Bool, false) - .OP_END_FACTORY_REG(SparseApplyRMSPropD) - /** *@brief Clean memory of workspace list. diff --git a/third_party/fwkacllib/inc/ops/nonlinear_fuc_ops.h b/third_party/fwkacllib/inc/ops/nonlinear_fuc_ops.h index 46d29b8d..992077ad 100644 --- a/third_party/fwkacllib/inc/ops/nonlinear_fuc_ops.h +++ b/third_party/fwkacllib/inc/ops/nonlinear_fuc_ops.h @@ -172,6 +172,24 @@ REG_OP(SigmoidGrad) .OUTPUT(z, TensorType(UnaryDataType)) .OP_END_FACTORY_REG(SigmoidGrad) +REG_OP(Activation) + .INPUT(x, TensorType::ALL()) + .OUTPUT(y, TensorType::ALL()) + /* + 0: sigmod, 1: relu, 2: tanh, 3: clipped ReLU, 4: Elu, + 5: leaky relu, 6: abs, 7: relu1, 8: softsign, 9: softplus + */ + .ATTR(mode, Int, 1) + .ATTR(coef, Float, 0) + .OP_END_FACTORY_REG(Activation) + +REG_OP(ActivationGrad) + .INPUT(dy, TensorType{DT_FLOAT}) + .INPUT(x, TensorType{DT_FLOAT}) + .OUTPUT(dx, TensorType{DT_FLOAT}) + .ATTR(mode, Int, 1) + .OP_END_FACTORY_REG(ActivationGrad) + /** *@brief Computes the binomial normal log likelihood (BNLL) output:\n *if x>0, x+log(1+exp(-x)); otherwise log(1+exp(x)). diff --git a/third_party/fwkacllib/inc/ops/power_ops.h b/third_party/fwkacllib/inc/ops/power_ops.h new file mode 100644 index 00000000..b1f5bc24 --- /dev/null +++ b/third_party/fwkacllib/inc/ops/power_ops.h @@ -0,0 +1,49 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + #ifndef GE_OP_POWER_H + #define GE_OP_POWER_H + + #include "../graph/operator_reg.h" + + namespace ge { + +/** +*@brief Computes the output as (shift + scale * x) ^ power. + +*@par Inputs: +* x: A Tensor of type float16 or float32. + +*@par Attributes: +*@li power: Optional. Defaults to 1.0. +*@li scale: Optional. Defaults to 1.0. +*@li shift: Optional. Defaults to 0.0. + +*@par Outputs: +* y: A Tensor. Has the same type and shape as "x". +*/ + + REG_OP(Power) + .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT})) + .ATTR(power, Float, 1.0) + .ATTR(scale, Float, 1.0) + .ATTR(shift, Float, 0.0) + .OP_END_FACTORY_REG(Power); + + } // namespace ge + + #endif // GE_OP_POWER_H diff --git a/third_party/fwkacllib/inc/ops/quantize_ops.h b/third_party/fwkacllib/inc/ops/quantize_ops.h index e44ae888..235f2645 100644 --- a/third_party/fwkacllib/inc/ops/quantize_ops.h +++ b/third_party/fwkacllib/inc/ops/quantize_ops.h @@ -19,6 +19,22 @@ #include "../graph/operator_reg.h" namespace ge { +REG_OP(QuantizedInnerProduct) + .INPUT(x, TensorType({DT_UINT8})) + .INPUT(w, TensorType({DT_INT8})) + .OPTIONAL_INPUT(b, TensorType({DT_INT32})) + .OPTIONAL_INPUT(scale_q, TensorType({DT_FLOAT16})) + .OPTIONAL_INPUT(offset_q, TensorType({DT_FLOAT16})) + .OPTIONAL_INPUT(scale_deq_req, TensorType({DT_FLOAT16})) + .OPTIONAL_INPUT(offset_req, TensorType({DT_FLOAT16})) + .OUTPUT(y, TensorType({DT_FLOAT16})) + .REQUIRED_ATTR(quant_algo, ListInt) + .REQUIRED_ATTR(scale_sqrt, ListInt) + .REQUIRED_ATTR(num_output, Int) + .ATTR(transpose, Bool, false) + .ATTR(bias_term, Bool, false) + .ATTR(axis, Int, 1) + .OP_END_FACTORY_REG(QuantizedInnerProduct) /** * @brief Dequantizes the input tensor into a float tensor.\n diff --git a/third_party/fwkacllib/inc/ops/ragged_array_ops.h b/third_party/fwkacllib/inc/ops/ragged_array_ops.h index 4f3cf97e..245f3551 100644 --- a/third_party/fwkacllib/inc/ops/ragged_array_ops.h +++ b/third_party/fwkacllib/inc/ops/ragged_array_ops.h @@ -45,10 +45,12 @@ namespace ge { REG_OP(RaggedGather) .DYNAMIC_INPUT(params_nested_splits, TensorType({DT_INT32, DT_INT64})) - .INPUT(params_dense_values, TensorType({DT_INT32, DT_INT64})) + .INPUT(params_dense_values, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT8, DT_UINT8, DT_INT16, \ + DT_UINT16, DT_INT32, DT_UINT32, DT_INT64, DT_UINT64, DT_BOOL})) .INPUT(indices, TensorType({DT_INT32, DT_INT64})) .DYNAMIC_OUTPUT(output_nested_splits, TensorType({DT_INT32, DT_INT64})) - .OUTPUT(output_dense_values, TensorType({DT_INT32, DT_INT64})) + .OUTPUT(output_dense_values, TensorType({DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT8, DT_UINT8, DT_INT16, \ + DT_UINT16, DT_INT32, DT_UINT32, DT_INT64, DT_UINT64, DT_BOOL})) .REQUIRED_ATTR(Tsplits, Type) .ATTR(PARAMS_RAGGED_RANK, Int, 1) .ATTR(OUTPUT_RAGGED_RANK, Int, 0) diff --git a/third_party/fwkacllib/inc/ops/ragged_conversion_ops.h b/third_party/fwkacllib/inc/ops/ragged_conversion_ops.h index 7a42e4d9..8e07bdc5 100644 --- a/third_party/fwkacllib/inc/ops/ragged_conversion_ops.h +++ b/third_party/fwkacllib/inc/ops/ragged_conversion_ops.h @@ -50,43 +50,5 @@ REG_OP(RaggedTensorToSparse) .ATTR(RAGGED_RANK, Int, 1) .ATTR(Tsplits, Type, DT_INT64) .OP_END_FACTORY_REG(RaggedTensorToSparse) - -/** -*@brief Create a dense tensor from a ragged tensor, possibly altering its shape. - -*@par Inputs: -*Six inputs, including: -*@li shape:A `Tensor`. Must be one of the following types: `int64`, `int32`. -*@li values:A 1D tensor representing the values of the ragged tensor. -*@li default_value:A `Tensor`. Must have the same type as `values`. -*@li row_partition_tensors:A list of at least 1 `Tensor` objects with the same \n -type in: `int64`, `int32`. - -*@par Attributes: -*@li num_row_partition_tensors:Numbers of row partition tensors. -*@li row_partition_types: A list of `strings`. \n -The types of the row partition tensors. At present, these can be: \n -* "ROW_SPLITS": the row_splits tensor from the ragged tensor. \n -* "VALUE_ROWIDS": the value_rowids tensor from the ragged tensor. \n -* "FIRST_DIM_SIZE": if value_rowids is used for the first dimension, then it \n -is preceeded by "FIRST_DIM_SIZE". - -*@par Outputs: -*@li result: A `Tensor`. Has the same type as `values`. -*/ -REG_OP(RaggedTensorToTensor) - .INPUT(shape, TensorType({DT_INT32, DT_INT64})) - .INPUT(values, TensorType({DT_BOOL, DT_INT8, DT_UINT8, DT_INT16, DT_UINT16, - DT_INT32, DT_INT64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16})) - .INPUT(default_value, TensorType({DT_BOOL, DT_INT8, DT_UINT8, DT_INT16, - DT_UINT16, DT_INT32, DT_INT64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16})) - .DYNAMIC_INPUT(row_partition_tensors, TensorType({DT_INT32, DT_INT64})) - .OUTPUT(result, TensorType({DT_BOOL, DT_INT8, DT_UINT8, DT_INT16, DT_UINT16, - DT_INT32, DT_INT64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16})) - .REQUIRED_ATTR(num_row_partition_tensors, Int) - .REQUIRED_ATTR(row_partition_types, ListString) - .OP_END_FACTORY_REG(RaggedTensorToTensor) - - } // namespace ge #endif // GE_OP_RAGGED_CONVERSION_OPS_H \ No newline at end of file diff --git a/third_party/fwkacllib/inc/ops/ragged_math_ops.h b/third_party/fwkacllib/inc/ops/ragged_math_ops.h index 80669f0f..51797ff8 100644 --- a/third_party/fwkacllib/inc/ops/ragged_math_ops.h +++ b/third_party/fwkacllib/inc/ops/ragged_math_ops.h @@ -41,11 +41,11 @@ namespace ge { */ REG_OP(RaggedRange) - .INPUT(starts, TensorType({DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64})) - .INPUT(limits, TensorType({DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64})) - .INPUT(deltas, TensorType({DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64})) + .INPUT(starts, TensorType({DT_FLOAT16,DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64})) + .INPUT(limits, TensorType({DT_FLOAT16,DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64})) + .INPUT(deltas, TensorType({DT_FLOAT16,DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64})) .OUTPUT(rt_nested_splits, TensorType({DT_INT32, DT_INT64})) - .OUTPUT(rt_dense_values, TensorType({DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64})) + .OUTPUT(rt_dense_values, TensorType({DT_FLOAT16,DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64})) .REQUIRED_ATTR(Tsplits, Type) .OP_END_FACTORY_REG(RaggedRange) diff --git a/third_party/fwkacllib/inc/ops/rnn.h b/third_party/fwkacllib/inc/ops/rnn.h index 7a6aaa9e..abd98695 100644 --- a/third_party/fwkacllib/inc/ops/rnn.h +++ b/third_party/fwkacllib/inc/ops/rnn.h @@ -180,15 +180,15 @@ REG_OP(RNN) .OPTIONAL_INPUT(x_static, TensorType({DT_FLOAT16})) .OPTIONAL_INPUT(h_0, TensorType({DT_FLOAT16, DT_FLOAT})) .INPUT(w_xh, TensorType({DT_FLOAT16})) - .INPUT(bias_h, TensorType({DT_FLOAT16, DT_FLOAT})) .INPUT(w_sh, TensorType({DT_FLOAT16})) .INPUT(w_hh, TensorType({DT_FLOAT16})) .INPUT(w_ho, TensorType({DT_FLOAT16})) + .INPUT(bias_h, TensorType({DT_FLOAT16, DT_FLOAT})) .INPUT(bias_o, TensorType({DT_FLOAT16, DT_FLOAT})) .OUTPUT(o, TensorType({DT_FLOAT16, DT_FLOAT})) .OUTPUT(h_t, TensorType({DT_FLOAT16, DT_FLOAT})) - .ATTR(num_output, Int, 0) .ATTR(expose_hidden, Bool, false) + .ATTR(num_output, Int, 0) .OP_END_FACTORY_REG(RNN) /** @@ -220,9 +220,9 @@ REG_OP(BasicRNNCell) .OPTIONAL_INPUT(w_xh_x_static, TensorType({DT_FLOAT16, DT_FLOAT})) .OPTIONAL_INPUT(h_0, TensorType({DT_FLOAT16, DT_FLOAT})) .INPUT(w_xh, TensorType({DT_FLOAT16})) - .INPUT(bias_h, TensorType({DT_FLOAT16, DT_FLOAT})) .OPTIONAL_INPUT(w_hh, TensorType({DT_FLOAT16})) .INPUT(w_ho, TensorType({DT_FLOAT16})) + .INPUT(bias_h, TensorType({DT_FLOAT16, DT_FLOAT})) .INPUT(bias_o, TensorType({DT_FLOAT16, DT_FLOAT})) .OUTPUT(o_t, TensorType({DT_FLOAT16, DT_FLOAT})) .OUTPUT(h_t, TensorType({DT_FLOAT16, DT_FLOAT})) diff --git a/third_party/fwkacllib/inc/ops/sdca_ops.h b/third_party/fwkacllib/inc/ops/sdca_ops.h index 15428d2b..3f1e938a 100644 --- a/third_party/fwkacllib/inc/ops/sdca_ops.h +++ b/third_party/fwkacllib/inc/ops/sdca_ops.h @@ -64,7 +64,7 @@ REG_OP(SdcaOptimizerV2) .INPUT(example_weights, TensorType({DT_FLOAT})) .INPUT(example_labels, TensorType({DT_FLOAT})) .DYNAMIC_INPUT(sparse_indices, TensorType({DT_INT64})) - .DYNAMIC_INPUT(sparse_weights, TensorType({DT_FLOAT})) + .DYNAMIC_INPUT(sparse_weights, TensorType({DT_INT64})) .DYNAMIC_INPUT(dense_weights, TensorType({DT_FLOAT})) .INPUT(example_state_data, TensorType({DT_FLOAT})) .OUTPUT(out_example_state_data, TensorType({DT_FLOAT})) diff --git a/third_party/fwkacllib/inc/ops/selection_ops.h b/third_party/fwkacllib/inc/ops/selection_ops.h index c7b59caa..dab71025 100644 --- a/third_party/fwkacllib/inc/ops/selection_ops.h +++ b/third_party/fwkacllib/inc/ops/selection_ops.h @@ -240,7 +240,7 @@ REG_OP(GatherV2D) REG_OP(StridedSlice) .INPUT(x, TensorType::BasicType()) .INPUT(begin, TensorType::IndexNumberType()) - .INPUT(end, TensorType::IndexNumberType()) + .INPUT(end, TensorType::IndexNumberTypeT()) .INPUT(strides, TensorType::IndexNumberType()) .ATTR(begin_mask, Int, 0) .ATTR(end_mask, Int, 0) @@ -571,7 +571,7 @@ REG_OP(SegmentMax) *@par Outputs: *y:A Tensor with same type as "x". -*/ +*/ REG_OP(SegmentMaxD) .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32})) .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32})) @@ -703,7 +703,6 @@ REG_OP(SliceD) * @attention Constraints: * @li k =< 4096 * @li Size of the last dimension =< 65500 -* @li sorted = true */ REG_OP(TopKD) .INPUT(x, TensorType::RealNumberType()) @@ -1309,6 +1308,174 @@ REG_OP(UnsortedSegmentProdD) .REQUIRED_ATTR(num_segments, Int) .OP_END_FACTORY_REG(UnsortedSegmentProdD) +/** +*@brief Normalizes data. It is called Region on YOLO v2 and Yolo on YOLO v3. + +*@par Inputs: +*x: An NCHW tensor of type float16 or float32. The data is with shape (N, boxes*(coords+obj+classes), H, W),where, "obj" indicates the confidence of an object, and only one confidence is supported. Boxes are arranged as xx...xyy...yww...whh...hbb...bc0c0..c0c1c1...c1......cncn...cn. + +*@par Attributes: +*@li boxes: A required int32, specifying the number of anchor boxes. Defaults to "5" for V2 or "3" for V3. +*@li coords: An int32, specifying the number of parameters required for locating an object. The value is fixed at "4", corresponding to (x,y,w,h). +*@li classes: An int32, specifying the number of prediction classes. Defaults to "80". The value range is [1, 1024]. +*@li yolo_version: A string, specifying the YOLO version, either "V2" or "V3". +*@li softmax: A bool, specifying whether to perform softmax, valid only when "yolo_version = V2". +*@li background: A bool, specifying the operation types of the obj and classes, used in conjunction with "softmax" and valid only when "yolo_version = V2". +*@li background: A bool. + +*@par Outputs: +*@li coord_data: A float16 or float32 with shape [N, boxes*coords, ceilx(height*width*2+32, 32)/2], where "ceil" indicates that a detected box is aligned upwards with the second parameter. Specifies the coordinates of a detected box. +*@li obj_prob: A float16 or float32 with shape [N, ceilx(boxes*height*width *2+32, 32)/2], where "ceil" indicates that a detected box is aligned upwards with the second parameter. Specifies the confidence. +*@li classes_prob: A float16 or float32 with shape [N, classes, ceilx(boxes*height*width *2+32, 32)/2], where "ceil" indicates that a detected box is aligned upwards with the second parameter. Specifies the prediction classes. + +*@attention Constraints: +*@li This operator applies to YOLO v2 and v3 networks. +*@li The succeeding layer of the Yolo operator must be operator Yolov3DetectionOutput. +*/ +REG_OP(Yolo) + .INPUT(x, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(coord_data, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(obj_prob, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(classes_prob, TensorType({DT_FLOAT16,DT_FLOAT})) + .ATTR(boxes, Int, 3) + .ATTR(coords, Int, 4) + .ATTR(classes, Int, 80) + .ATTR(yolo_version, String, "V3") + .ATTR(softmax, Bool, false) + .ATTR(background, Bool, false) + .ATTR(softmaxtree, Bool, false) + .OP_END_FACTORY_REG(Yolo) + +/** +*@brief Performs YOLO V3 detection. + +*@par Inputs: +*Ten inputs, including: +*@li Operator Yolov3DetectionOutput takes the outputs of operator Yolo as its inputs. A Yolo operator has three outputs: "coords", "obj", and "class". \n +There are three Yolo operators at Yolov3DetectionOutput's preceding layer on Yolo v3. For details, see the description of operator Yolo. +*@li imginfo: A float16, describing the image information including the required image height and width \n +and the actual image height and width. +* +*@par Attributes: +*@li biases: A required float. "biases = Number of Yolo operators at the preceding layer x 2 x boxes" +*@li boxes: A required int32, specifying the number of anchor boxes predicted for each Yolo layer. +*@li coords: Specifies the number of coordinate parameters. Must be 4. +*@li classes: A required int32, specifying the number of classes to be predicted. The value range is [1, 80]. +*@li relative: An optional bool. Defaults to and must be "true". +*@li obj_threshold: A required float, specifying the confidence threshold for box filtering, which is the output "obj" of operator Yolo). The value range is [0.0, 1.0]. + +*@li post_nms_topn: An optional int32. This attribute is reserved. +*@li score_threshold: A required float, specifying the class score threshold for box filtering, which is the output "class" of operator Yolo). The value range is [0.0, 1.0]. + +*@li iou_threshold: A required float, specifying the intersection-over-union (IOU) threshold for box filtering. The value range is [0.0, 1.0].\n + +*@li pre_nms_topn: An optional int, specifying the number of boxes for non-maximum suppression (NMS). Defaults to "1024". +* +*@par Outputs: +*@li boxout: An NCHW tensor of type float16, describing the information of each output box, including the coordinates, class, and confidence. +*@li boxoutnum: An NCHW tensor of type int32, specifying the number of output boxes. + +*@attention Constraints:\n +*@li This operator applies only to the YOLO v3 network. +*@li The preceding layer of operator Yolov3DetectionOutput must be three Yolo operators. + +*@see Yolo() +*/ +REG_OP(YoloV3DetectionOutput) + .INPUT(coord_data_low, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(coord_data_mid, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(coord_data_high, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(obj_prob_low, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(obj_prob_mid, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(obj_prob_high, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(classes_prob_low, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(classes_prob_mid, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(classes_prob_high, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(img_info, TensorType({DT_FLOAT16,DT_FLOAT})) + .REQUIRED_ATTR(biases_low, ListFloat) + .REQUIRED_ATTR(biases_mid, ListFloat) + .REQUIRED_ATTR(biases_high, ListFloat) + .ATTR(boxes, Int, 3) + .ATTR(coords, Int, 4) + .ATTR(classes, Int, 80) + .ATTR(relative, Bool, true) + .ATTR(obj_threshold, Float, 0.5) + .ATTR(post_nms_topn, Int, 1024) + .ATTR(score_threshold, Float, 0.5) + .ATTR(iou_threshold, Float, 0.45) + .ATTR(pre_nms_topn, Int, 512) + .OUTPUT(box_out, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(box_out_num, TensorType({DT_INT32})) + .OP_END_FACTORY_REG(YoloV3DetectionOutput) + +/** +*@brief Performs YOLO V3 detection. + +*@par Inputs: +*16 Input, including: +*@li The outputs of operator Yolo at the preceding layer (that is, three Yolo operators on YOLO v3) are used as the inputs of operator Yolov3DetectionOutput. \n +A Yolo operator has three outputs: "coords", "obj", and "class". For details, see the description of operator Yolo. +*@li imginfo: A float16, describing the image information including the required image height and width \n +and the actual image height and width. +*@li windex: A windex tensor with shape [height,weight]. Has the same type as the inputs. [[0,1,2...(weight-1)],[0,1,2...(w-1)]...[0,1,2...(weight-1)]] consisting of h groups of [0, 1, 2...(weight-1)] is formed for the three Yolo outputs, respectively. + +*@li hindex: A hindex tensor with shape [height,weight]. Has the same type as the inputs. [[0,0...0],[1,1...1],[2,2...2]...[height-1,height-1...,height-1]] is formed for the three Yolo outputs, respectively. + +* +*@par Attributes: +*@li biases: A required float32. "biases = Number of Yolo operators at the preceding layer x 2 x boxes" +*@li boxes: A required int32, specifying the number of anchor boxes predicted for each Yolo layer. +*@li coords: Specifies the number of coordinate parameters. Must be 4. +*@li classes: A required int32, specifying the number of classes to be predicted. The value range is [1, 80]. +*@li relative: An optional bool. Defaults to and must be "true". +*@li obj_threshold: A required float, specifying the confidence threshold for box filtering, which is the output "obj" of operator Yolo). The value range is [0.0, 1.0]. +*@li post_nms_topn: An optional int32. This attribute is reserved. +*@li score_threshold: A required float, specifying the class score threshold for box filtering, which is the output "class" of operator Yolo). The value range is [0.0, 1.0]. +*@li iou_threshold: A required float, specifying the intersection-over-union (IOU) threshold for box filtering. The value range is [0.0, 1.0].\n +*@li pre_nms_topn: An optional int, specifying the number of boxes for non-maximum suppression (NMS). Defaults to "1024". +* +*@par Outputs: +*@li boxout: An NCHW tensor of type float16, describing the information of each output box, including the coordinates, class, and confidence. +*@li boxoutnum: An NCHW tensor of type int32, specifying the number of output boxes. + +*@attention Constraints:\n +*@li This operator applies only to the YOLO v3 network. +*@li The preceding layer of operator Yolov3DetectionOutput must be three Yolo operators. +*@see Yolo() +*/ +REG_OP(YoloV3DetectionOutputD) + .INPUT(coord_data_low, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(coord_data_mid, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(coord_data_high, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(obj_prob_low, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(obj_prob_mid, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(obj_prob_high, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(classes_prob_low, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(classes_prob_mid, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(classes_prob_high, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(img_info, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(windex1, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(windex2, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(windex3, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(hindex1, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(hindex2, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(hindex3, TensorType({DT_FLOAT16,DT_FLOAT})) + .REQUIRED_ATTR(biases_low, ListFloat) + .REQUIRED_ATTR(biases_mid, ListFloat) + .REQUIRED_ATTR(biases_high, ListFloat) + .ATTR(boxes, Int, 3) + .ATTR(coords, Int, 4) + .ATTR(classes, Int, 80) + .ATTR(relative, Bool, true) + .ATTR(obj_threshold, Float, 0.5) + .ATTR(post_nms_topn, Int, 1024) + .ATTR(score_threshold, Float, 0.5) + .ATTR(iou_threshold, Float, 0.45) + .ATTR(pre_nms_topn, Int, 512) + .OUTPUT(box_out, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(box_out_num, TensorType({DT_INT32})) + .OP_END_FACTORY_REG(YoloV3DetectionOutputD) + /** *@brief Performs object detection. @@ -1387,6 +1554,116 @@ REG_OP(ProposalD) .ATTR(nms_thresh, Float, 0.7) .OP_END_FACTORY_REG(ProposalD) +/** +*@brief Performs YOLO V2 detection. + +*@par Inputs: +* Four inputs, including: +*@li The outputs of operator Yolo at the preceding layer (that is, one Yolo operator on YOLO v2) are used as the inputs of operator Yolov3DetectionOutput. \n +Each Yolo operator has three outputs: "coords", "obj", and "class". For details, see the description of operator Yolo. +*@li imginfo: A float16, describing the image information including the required image height and width \n +and the actual image height and width. +* +*@par Attributes: +*@li biases: A required float. "biases = Number of Yolo operators at the preceding layer x 2 x boxes" +*@li boxes: A required int32, specifying the number of anchor boxes predicted for each Yolo layer. +*@li coords: Specifies the number of coordinate parameters. Must be 4. +*@li classes: A required int32, specifying the number of classes to be predicted. The value range is [1, 80]. +*@li relative: An optional bool. Defaults to and must be "true". +*@li obj_threshold: A required float, specifying the confidence threshold for box filtering, which is the output "obj" of operator Yolo). The value range is [0.0, 1.0]. + +*@li post_nms_topn: An optional int32. This attribute is reserved. +*@li score_threshold: A required float, specifying the class score threshold for box filtering, which is the output "class" of operator Yolo). The value range is [0.0, 1.0]. +*@li iou_threshold: A required float, specifying the intersection-over-union (IOU) threshold for box filtering. The value range is [0.0, 1.0].\n +*@li pre_nms_topn: An optional int, specifying the number of boxes for non-maximum suppression (NMS). Defaults to "1024". +* +*@par Outputs: +*@li boxout: An NCHW tensor of type float16, describing the information of each output box, including the coordinates, class, and confidence. +*@li boxoutnum: An NCHW tensor of type int32, specifying the number of output boxes. + +*@attention Constraints:\n +*@li This operator applies only to the YOLO v2 network. +*@li The preceding layer of operator Yolov2DetectionOutput must be one Yolo operator. + +*@see Yolo() +*/ +REG_OP(YoloV2DetectionOutput) + .INPUT(coord_data, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(obj_prob, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(classes_prob, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(img_info, TensorType({DT_FLOAT16,DT_FLOAT})) + .REQUIRED_ATTR(biases, ListFloat) + .ATTR(boxes, Int, 5) + .ATTR(coords, Int, 4) + .ATTR(classes, Int, 80) + .ATTR(relative, Bool, true) + .ATTR(obj_threshold, Float, 0.5) + .ATTR(post_nms_topn, Int, 1024) + .ATTR(score_threshold, Float, 0.5) + .ATTR(iou_threshold, Float, 0.45) + .ATTR(pre_nms_topn, Int, 512) + .OUTPUT(box_out, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(box_out_num, TensorType({DT_INT32})) + .OP_END_FACTORY_REG(YoloV2DetectionOutput) + +/** +*@brief Performs YOLO V2 detection. + +*@par Inputs: +*Six inputs, including: +*@li The outputs of operator Yolo at the preceding layer (that is, one Yolo operator on YOLO v2) are used as the inputs of operator Yolov2DetectionOutput. \n +Each Yolo operator has three outputs: "coords", "obj", and "class". For details, see the description of operator Yolo. +*@li imginfo: A float16, describing the image information including the required image height and width \n +and the actual image height and width. +*@li windex: A windex tensor with shape [height, weight]. Has the same type as the inputs. [[0,1,2...(weight-1)],[0,1,2...(w-1)]...[0,1,2...(weight-1)]] consisting of h groups of [0, 1, 2...(weight-1)] is formed. \n + +*@li hindex: A hindex tensor with shape [height, weight]. Has the same type as the inputs. [[0,0...0],[1,1...1],[2,2...2]...[height-1,height-1...,height-1]]. \n + +* +*@par Attributes: +*@li biases: A required float. "biases = Number of Yolo operators at the preceding layer x 2 x boxes" +*@li boxes: A required int32, specifying the number of anchor boxes predicted for each Yolo layer. +*@li coords: Specifies the number of coordinate parameters. Must be 4. +*@li classes: A required int32, specifying the number of classes to be predicted. The value range is [1, 80]. +*@li relative: An optional bool. Defaults to and must be "true". +*@li obj_threshold: A required float, specifying the confidence threshold for box filtering, which is the output "obj" of operator Yolo). The value range is [0.0, 1.0]. +*@li post_nms_topn: An optional int32. This attribute is reserved. +*@li score_threshold: A required float, specifying the class score threshold for box filtering, which is the output "class" of operator Yolo). The value range is [0.0, 1.0]. + +*@li iou_threshold: A required float, specifying the intersection-over-union (IOU) threshold for box filtering. The value range is [0.0, 1.0].\n +*@li pre_nms_topn: An optional int, specifying the number of boxes for non-maximum suppression (NMS). Defaults to "1024". +* +*@par Outputs: +*@li boxout: An NCHW tensor of type float16, describing the information of each output box, including the coordinates, class, and confidence. +*@li boxoutnum: An NCHW tensor of type int32, specifying the number of output boxes. +* +*@attention Constraints:\n +*@li This operator applies only to the YOLO v2 network. +*@li The preceding layer of operator Yolov2DetectionOutput must be one Yolo operator. + +*@see Yolo() +*/ +REG_OP(YoloV2DetectionOutputD) + .INPUT(coord_data, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(obj_prob, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(classes_prob, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(img_info, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(windex, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(hindex, TensorType({DT_FLOAT16,DT_FLOAT})) + .REQUIRED_ATTR(biases, ListFloat) + .ATTR(boxes, Int, 5) + .ATTR(coords, Int, 4) + .ATTR(classes, Int, 80) + .ATTR(relative, Bool, true) + .ATTR(obj_threshold, Float, 0.5) + .ATTR(post_nms_topn, Int, 1024) + .ATTR(score_threshold, Float, 0.5) + .ATTR(iou_threshold, Float, 0.45) + .ATTR(pre_nms_topn, Int, 512) + .OUTPUT(box_out, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(box_out_num, TensorType({DT_INT32})) + .OP_END_FACTORY_REG(YoloV2DetectionOutputD) + /** *@brief Performs plane or channel conversion on YoloV2. * If reverse=true: (N, H, W, C)->(N, H*stride, W*stride, C/(stride*stride)) diff --git a/third_party/fwkacllib/inc/ops/sparse_ops.h b/third_party/fwkacllib/inc/ops/sparse_ops.h index 5c50298c..abb1361c 100644 --- a/third_party/fwkacllib/inc/ops/sparse_ops.h +++ b/third_party/fwkacllib/inc/ops/sparse_ops.h @@ -215,7 +215,7 @@ REG_OP(SparseDenseCwiseMul) REG_OP(AddSparseToTensorsMap) .INPUT(indices, TensorType({DT_INT64})) .INPUT(values, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, \ - DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_BOOL, DT_DOUBLE, \ + DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_BOOL, DT_DOUBLE \ DT_COMPLEX64, DT_COMPLEX128, DT_RESOURCE, DT_STRING})) .INPUT(shape, TensorType({DT_INT64})) .OUTPUT(handle, TensorType({DT_INT64})) @@ -410,6 +410,7 @@ REG_OP(SparseToDense) * @li y_indices:A `Tensor` of type `int64`. * @li y_values:A `Tensor`. Has the same type as `values`. * @li y_shape:A `Tensor` of type `int64`. + */ REG_OP(SparseConcat) .DYNAMIC_INPUT(indices, TensorType({DT_INT64})) @@ -450,6 +451,7 @@ REG_OP(SparseConcat) * @li sum_indices:A `Tensor` of type `int64`. * @li sum_values:A `Tensor`. Has the same type as `x1_values`. * @li sum_shape:A `Tensor` of type `int64`. + */ REG_OP(SparseAdd) .INPUT(x1_indices, TensorType({DT_INT64})) @@ -485,6 +487,7 @@ REG_OP(SparseAdd) * @li y_values:A `Tensor`. Has the same type as `values`. * @li empty_row_indicator:A `Tensor` of type `bool`. * @li reverse_index_map:A `Tensor` of type `int64`. + */ REG_OP(SparseFillEmptyRows) .INPUT(indices, TensorType({DT_INT64})) @@ -523,6 +526,7 @@ REG_OP(SparseFillEmptyRows) *@par Outputs: * @li y_indices:A `Tensor` of type `int64`. * @li y_values:A `Tensor`. Has the same type as `x1_values`. + */ REG_OP(SparseSparseMaximum) .INPUT(x1_indices, TensorType({DT_INT64})) @@ -556,6 +560,7 @@ REG_OP(SparseSparseMaximum) *@par Outputs: * @li y_indices:A `Tensor` of type `int64`. * @li y_values:A `Tensor`. Has the same type as `x1_values`. + */ REG_OP(SparseSparseMinimum) .INPUT(x1_indices, TensorType({DT_INT64})) @@ -594,6 +599,7 @@ REG_OP(SparseSparseMinimum) *@par Outputs: * y:A `Tensor`. Has the same type as `input_values`. + */ REG_OP(SparseReduceMax) .INPUT(x_indices, TensorType({DT_INT64})) @@ -628,6 +634,7 @@ REG_OP(SparseReduceMax) * @li y_indices:A `Tensor` of type `int64`. * @li y_values:A `Tensor`. Has the same type as `input_values`. * @li y_shape:A `Tensor` of type `int64`. + */ REG_OP(SparseReduceMaxSparse) .INPUT(x_indices, TensorType({DT_INT64})) @@ -840,7 +847,6 @@ REG_OP(AddManySparseToTensorsMap) * The "N" serialized SparseTensor objects. *@par Attributes: -* @li dtype: A DType. The "dtype" of the SparseTensor objects stored in the "SparseTensorsMap". * @li container: An optional string. Defaults to "". \n *The container name for the "SparseTensorsMap" read by this op. * @li shared_name: An optional string. Defaults to "". \n diff --git a/third_party/fwkacllib/inc/ops/stateful_random_ops.h b/third_party/fwkacllib/inc/ops/stateful_random_ops.h index 9ba09dd6..929481d5 100644 --- a/third_party/fwkacllib/inc/ops/stateful_random_ops.h +++ b/third_party/fwkacllib/inc/ops/stateful_random_ops.h @@ -87,9 +87,9 @@ smaller than the range of the output (either `2^32` or `2^64`). REG_OP(StatefulRandomBinomial) .INPUT(x, TensorType({DT_RESOURCE})) .INPUT(algorithm, TensorType({DT_INT64})) - .INPUT(shape, TensorType({DT_INT32})) - .INPUT(counts, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) - .INPUT(probs, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) + .INPUT(shape, TensorType({DT_INT32, DT_INT64})) + .INPUT(counts, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64})) + .INPUT(probs, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64})) .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64})) .REQUIRED_ATTR(dtype, Type) .OP_END_FACTORY_REG(StatefulRandomBinomial) @@ -111,7 +111,7 @@ REG_OP(StatefulRandomBinomial) REG_OP(StatefulStandardNormalV2) .INPUT(x, TensorType({DT_RESOURCE})) .INPUT(algorithm, TensorType({DT_INT64})) - .INPUT(shape, TensorType({DT_INT32,DT_INT64})) + .INPUT(shape, TensorType({DT_INT64})) .OUTPUT(y, TensorType({DT_FLOAT})) .OP_END_FACTORY_REG(StatefulStandardNormalV2) @@ -134,7 +134,7 @@ REG_OP(StatefulStandardNormalV2) REG_OP(StatefulTruncatedNormal) .INPUT(x, TensorType({DT_RESOURCE})) .INPUT(algorithm, TensorType({DT_INT64})) - .INPUT(shape, TensorType({DT_INT32,DT_INT64})) + .INPUT(shape, TensorType({DT_INT64})) .OUTPUT(y, TensorType({DT_FLOAT})) .OP_END_FACTORY_REG(StatefulTruncatedNormal) @@ -156,7 +156,7 @@ lower bound 0 is included in the range, while the upper bound 1 is excluded. \n REG_OP(StatefulUniform) .INPUT(x, TensorType({DT_RESOURCE})) .INPUT(algorithm, TensorType({DT_INT64})) - .INPUT(shape, TensorType({DT_INT32,DT_INT64})) + .INPUT(shape, TensorType({DT_INT64})) .OUTPUT(y, TensorType({DT_FLOAT})) .OP_END_FACTORY_REG(StatefulUniform) @@ -177,8 +177,8 @@ The generated values are uniform integers covering the whole range of `dtype`. REG_OP(StatefulUniformFullInt) .INPUT(x, TensorType({DT_RESOURCE})) .INPUT(algorithm, TensorType({DT_INT64})) - .INPUT(shape, TensorType({DT_INT32,DT_INT64})) - .OUTPUT(y, TensorType({DT_UINT64})) + .INPUT(shape, TensorType({DT_INT64})) + .OUTPUT(y, TensorType({DT_INT64})) .OP_END_FACTORY_REG(StatefulUniformFullInt) /** @@ -205,7 +205,7 @@ smaller than the range of the output (either `2^32` or `2^64`). REG_OP(StatefulUniformInt) .INPUT(x, TensorType({DT_RESOURCE})) .INPUT(algorithm, TensorType({DT_INT64})) - .INPUT(shape, TensorType({DT_INT32,DT_INT64})) + .INPUT(shape, TensorType({DT_INT64})) .INPUT(minval, TensorType({DT_INT64})) .INPUT(maxval, TensorType({DT_INT64})) .OUTPUT(y, TensorType({DT_INT64})) diff --git a/third_party/fwkacllib/inc/ops/string_ops.h b/third_party/fwkacllib/inc/ops/string_ops.h index 1b88fbd0..0b4701b2 100644 --- a/third_party/fwkacllib/inc/ops/string_ops.h +++ b/third_party/fwkacllib/inc/ops/string_ops.h @@ -127,7 +127,6 @@ include: \n *inputs are trusted or unimportant. There is a risk of adversaries\n *constructing inputs that all hash to the same bucket.\n *To prevent this problem, use a strong hash function with\n -*string_to_hash_bucket_strong. *@see Substr() @@ -155,7 +154,6 @@ include: \n *This function may be used when CPU time is scarce and inputs are trusted or\n *unimportant. There is a risk of adversaries constructing inputs that all hash\n *to the same bucket. To prevent this problem, use a strong hash function with\n -*string_to_hash_bucket_strong. *@see StringToHashBucketFast() @@ -187,7 +185,6 @@ include: \n * hash value distribution over buckets. This requires that the hash function\ *is seeded by a high-entropy (random) "key" unknown to the adversary. *@li The additional robustness comes at a cost of roughly 4x higher\n -*compute time than string_to_hash_bucket_fast. *@see StringToHashBucketStrong() diff --git a/third_party/fwkacllib/inc/ops/transformation_ops.h b/third_party/fwkacllib/inc/ops/transformation_ops.h index 69dd450f..689cde4e 100644 --- a/third_party/fwkacllib/inc/ops/transformation_ops.h +++ b/third_party/fwkacllib/inc/ops/transformation_ops.h @@ -400,44 +400,13 @@ REG_OP(Unpack) * "ksizes", "strides" and "rates" are lists of integers. */ REG_OP(ExtractImagePatches) - .INPUT(x, TensorType::RealNumberType()) - .OUTPUT(y, TensorType::RealNumberType()) - .REQUIRED_ATTR(ksizes, ListInt) - .REQUIRED_ATTR(strides, ListInt) - .REQUIRED_ATTR(rates, ListInt) - .REQUIRED_ATTR(padding, String) - .OP_END_FACTORY_REG(ExtractImagePatches) - -/** -* @brief Extract "patches" from "input" and put them in the "depth" -* dimension of the output. - -* @par Inputs: -* x: A 5D Tensor with shape [batch, in_planes, in_rows, in_cols, depth]. - -* @par Attributes: -* @li ksizes: A required list or tuple. The size of the sliding window for each -* dimension of "x". -* @li strides: A required list or tuple. How far the centers of two consecutive -* patches are in "x". Must be: [1, stride_planes, stride_rows, stride_cols, 1]. -* @li padding: A required string. The type of padding algorithm to use. - -* @par Outputs: -* Output: A 5D Tensor with shape [batch, out_planes, out_rows, out_cols, ksize_planes * \n -* ksize_rows * ksize_cols * depth] containing patches with size (ksize_rows * ksize_cols\n -* * depth) vectorized in the "depth" dimension. Note "out_planes", "out_rows" and "out_cols"\n -* are the dimensions of the output patches. - -* @attention Constraints: -* "ksizes" and "strides" are lists of integers. -*/ -REG_OP(ExtractVolumePatches) .INPUT(x, TensorType::REALNUMBERTYPE()) .OUTPUT(y, TensorType::REALNUMBERTYPE()) .REQUIRED_ATTR(ksizes, ListInt) .REQUIRED_ATTR(strides, ListInt) + .REQUIRED_ATTR(rates, ListInt) .REQUIRED_ATTR(padding, String) - .OP_END_FACTORY_REG(ExtractVolumePatches) + .OP_END_FACTORY_REG(ExtractImagePatches) /** *@brief Confuse reshape and transpose. @@ -497,7 +466,7 @@ REG_OP(ConfusionTranspose) *y: The flattened ND tensor. All data types are supported. *@attention Constraints: -* "axis" and "end_axis" must be within the dimension range of the input. This operator cannot be directly called by the acllopExecute API. +* "axis" and "end_axis" must be within the dimension range of the input. */ REG_OP(FlattenV2) .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT8, DT_UINT8, DT_INT16, DT_UINT16, diff --git a/third_party/fwkacllib/inc/register/op_kernel_registry.h b/third_party/fwkacllib/inc/register/op_kernel_registry.h index 2c479e92..cc8924b5 100644 --- a/third_party/fwkacllib/inc/register/op_kernel_registry.h +++ b/third_party/fwkacllib/inc/register/op_kernel_registry.h @@ -18,8 +18,7 @@ #define INC_REGISTER_OP_KERNEL_REGISTRY_H_ #include #include -#include "register/register_types.h" -#include "register.h" +#include "register/register.h" namespace ge { class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpKernelRegistry { diff --git a/third_party/fwkacllib/inc/register/register.h b/third_party/fwkacllib/inc/register/register.h deleted file mode 100644 index 27da0b0b..00000000 --- a/third_party/fwkacllib/inc/register/register.h +++ /dev/null @@ -1,53 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef INC_REGISTER_REGISTRY_H_ -#define INC_REGISTER_REGISTRY_H_ - -#include "external/register/register.h" - -namespace ge { -class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY HostCpuOp { - public: - HostCpuOp() = default; - virtual ~HostCpuOp() = default; - - virtual graphStatus Compute(Operator &op, - const std::map &inputs, - std::map &outputs) = 0; -}; - -class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY HostCpuOpRegistrar { - public: - HostCpuOpRegistrar(const char *op_type, HostCpuOp *(*create_fn)()); - ~HostCpuOpRegistrar() = default; -}; - -#define REGISTER_HOST_CPU_OP_BUILDER(name, op) \ - REGISTER_HOST_CPU_OP_BUILDER_UNIQ_HELPER(__COUNTER__, name, op) - -#define REGISTER_HOST_CPU_OP_BUILDER_UNIQ_HELPER(ctr, name, op) \ - REGISTER_HOST_CPU_OP_BUILDER_UNIQ(ctr, name, op) - -#define REGISTER_HOST_CPU_OP_BUILDER_UNIQ(ctr, name, op) \ - static ::ge::HostCpuOpRegistrar register_host_cpu_op##ctr \ - __attribute__((unused)) = \ - ::ge::HostCpuOpRegistrar(name, []()->::ge::HostCpuOp* { \ - return new (std::nothrow) op(); \ - }) -} // namespace ge - -#endif //INC_REGISTER_REGISTRY_H_ diff --git a/third_party/fwkacllib/inc/runtime/kernel.h b/third_party/fwkacllib/inc/runtime/kernel.h index c99eb96f..1609519f 100644 --- a/third_party/fwkacllib/inc/runtime/kernel.h +++ b/third_party/fwkacllib/inc/runtime/kernel.h @@ -448,7 +448,7 @@ RTS_API rtError_t rtSubscribeReport(uint64_t threadId, rtStream_t stream); * @param [in] stream subscribed stream * @return RT_ERROR_NONE for ok, errno for failed */ -RTS_API rtError_t rtCallbackLaunch(rtCallback_t callBackFunc, void *fnData, rtStream_t stream, bool isBlock); +RTS_API rtError_t rtCallbackLaunch(rtCallback_t callBackFunc, void *fnData, rtStream_t stream); /** * @ingroup rt_kernel diff --git a/third_party/fwkacllib/inc/runtime/mem.h b/third_party/fwkacllib/inc/runtime/mem.h index 93b7585a..1597c436 100644 --- a/third_party/fwkacllib/inc/runtime/mem.h +++ b/third_party/fwkacllib/inc/runtime/mem.h @@ -75,8 +75,6 @@ typedef enum tagRtMemcpyKind { RT_MEMCPY_DEVICE_TO_HOST, // device to host RT_MEMCPY_DEVICE_TO_DEVICE, // device to device, 1P && P2P RT_MEMCPY_MANAGED, // managed memory - RT_MEMCPY_ADDR_DEVICE_TO_DEVICE, - RT_MEMCPY_HOST_TO_DEVICE_EX, // host to device ex (only used for 8 bytes) RT_MEMCPY_RESERVED, } rtMemcpyKind_t; diff --git a/third_party/fwkacllib/inc/runtime/rt_model.h b/third_party/fwkacllib/inc/runtime/rt_model.h index d4e5682b..1e03e853 100644 --- a/third_party/fwkacllib/inc/runtime/rt_model.h +++ b/third_party/fwkacllib/inc/runtime/rt_model.h @@ -45,8 +45,7 @@ typedef enum tagModelTaskType { RT_MODEL_TASK_EVENT_RESET = 18, RT_MODEL_TASK_MODEL_END_GRAPH, RT_MODEL_TASK_STREAM_SWITCH_N, - RT_MODEL_TASK_RDMA_DB_SEND, - RT_MODEL_TASK_MEMCPY_ADDR_ASYNC + RT_MODEL_TASK_RDMA_DB_SEND } rtModelTaskType_t; typedef enum tagModelStreamType { diff --git a/third_party/fwkacllib/inc/toolchain/slog.h b/third_party/fwkacllib/inc/toolchain/slog.h index 2728c812..1fb9aff2 100644 --- a/third_party/fwkacllib/inc/toolchain/slog.h +++ b/third_party/fwkacllib/inc/toolchain/slog.h @@ -168,7 +168,6 @@ enum { DSS, PROCMGR, // Process Manager, Base Platform BBOX, - AIVECTOR, INVLID_MOUDLE_ID }; @@ -242,7 +241,6 @@ static DCODE g_moduleIdName[] = {SET_MOUDLE_ID_MAP_NAME(SLOG), SET_MOUDLE_ID_MAP_NAME(DSS), SET_MOUDLE_ID_MAP_NAME(PROCMGR), SET_MOUDLE_ID_MAP_NAME(BBOX), - SET_MOUDLE_ID_MAP_NAME(AIVECTOR), { NULL, -1 }}; #endif // MODULE_ID_NAME diff --git a/third_party/fwkacllib/version.info b/third_party/fwkacllib/version.info deleted file mode 100644 index 8bc7f6e0..00000000 --- a/third_party/fwkacllib/version.info +++ /dev/null @@ -1 +0,0 @@ -Version=1.71.T6.0.B070 diff --git a/third_party/patch/securec/securec.patch001 b/third_party/patch/securec/securec.patch001 deleted file mode 100644 index 8376784b..00000000 --- a/third_party/patch/securec/securec.patch001 +++ /dev/null @@ -1,23 +0,0 @@ -diff -Npur bounds_checking_function/CMakeLists.txt securec/CMakeLists.txt ---- bounds_checking_function/CMakeLists.txt 1970-01-01 08:00:00.000000000 +0800 -+++ securec/CMakeLists.txt 2020-05-11 17:10:49.406735400 +0800 -@@ -0,0 +1,19 @@ -+cmake_minimum_required(VERSION 3.14) -+project(Securec) -+set(CMAKE_BUILD_TYPE "Debug") -+set(CMAKE_C_FLAGS_DEBUG "$ENV{CFLAGS} -fPIC -O0 -Wall -Wno-deprecated-declarations -g2 -ggdb -fno-inline-functions -fno-omit-frame-pointer -D_LIBCPP_INLINE_VISIBILITY='' -D'_LIBCPP_EXTERN_TEMPLATE(...)='") -+set(CMAKE_C_FLAGS_RELEASE "$ENV{CFLAGS} -fPIC -O3 -Wall -Wno-deprecated-declarations") -+set(CMAKE_EXPORT_COMPILE_COMMANDS ON) -+ -+#add flags -+set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -I/usr/local/include -Werror") -+ -+include_directories(./include) -+aux_source_directory(./src SECUREC_SRCS) -+add_library(c_sec SHARED ${SECUREC_SRCS}) -+ -+install(TARGETS c_sec -+ DESTINATION lib) -+install(FILES "./include/securec.h" -+ "./include/securectype.h" -+ DESTINATION include) diff --git a/third_party/prebuild/x86_64/libc_sec.so b/third_party/prebuild/x86_64/libc_sec.so new file mode 100755 index 0000000000000000000000000000000000000000..8290bbcce1531065b32b8fae4bcb2f7ed1959fe4 GIT binary patch literal 80080 zcmeEv34ByVw)gF%8=Cb-ArKWIVxtM0VKsq(2_)o38;yVjM-VUsK~_mO34+kr9h0_g zKnIs`8PU-ZXK-P^ab0Q%m%H}Ch+P^nX= z&Z&JrRp+kAc8-a$SS0-yD~*)Mm3ThKWC)MAff1SvDOGAEWk?rD=QCQHplQzP+lKS$ zHbQ1(glz@0rXBZR)9xaM*XIU*bz_&-oJDb3WZpAp-n#+s>U*-Z!op zR=|13IHIxLH&Y#tJ?uP~9O1q~*w@0o7-lRC73noF^cw@ypFxt820K^4y@XBH zc$h9Q{a}W}B*5rDO7AF`t6>zF?l8$PLtw_i=w){f;3OFOb%wc=L6JX00S<<_T%ZOB zJ6+hn)}nfEeHx~=CV#!&bW}SX5cYYE+~n+P!nHBM1YRQCorK*Hwk+U8*wjof6YfE< zd&8u_bPzbQ;$bd;=?2pmrVot%yAZ%Y-2t1L6ZHhqj3p=F5inQ5WWWrA$%N?vlMU0- zFNlUk?W*x_Iud-Y)K02@nfvS&2_3@d+esHtW5sVB`zuNNVH5rNf9Hrt1)jQlbK!Sj z;R_;@wH`2qvziJ4zdN`kEfDd(BGN;Bf_`&hh<-D50{pz(lHP_*_?1*}_`S$2X%Gr5 z4kcX`cagwwt6qd}`H;K+B79>-Idl;L=;3S^_=Zn7;W*LoM+^LSA9MT$5zZmllpd8j z41P-xAN`(zA$*+(_-Tg4I3U8AC<^jRL4UQt$G^l0HwyY!1^y2rJ<|n#kI2`L^EqPJ zIozhc+EdV&tr$jc#&DRTqja7l%5$M0p?C=2Am)!r0{@Um&()$F^fpk==mn9<8Y9yA zYAQ$4xJAF;wUC|@@S}LACj!7vua}-J(9?K!E=J1_zGCFYpXXbkS4H~6dpYU*f}^*< z4|$9u{w@ZRk`}^wrv-dE!Vy&=Lgv=vu=KJrd;Fe-(@ttPnPWeJn|?Jdg#VmK=QL6N zK~64tTA+Ui{#2gp-{XW8f<8yY8x-y5W`SQW=-0l$5qAju^cKQ@tpz-Fx90Lo#d@xU zQQbV~0!Q7S(mDTc0PxfMS4s=vFBj$T(=JYbkXzDjfuFaRBkmRWEp)W7Q%nOg?z<`@a>`@ z)QNPK3H&!=0O$sO`aLPq`Sld;K3|LjG;JkvRMg*vBAicZIC`g;Pj7;g(yui($uLp&s#8mYLRRDBG=TZ($v}WXS<}SGhs_Q zI-_t=;mxy)T!o7!j2%69!TiDr({Gwv$o-o4nOaaXodTFXclK?C(&D0pi)PPv&74{! zjV~;iSTvtIbLTCbD;3WxoHuX5;zIZVDk^kIMXp8j3#Pk(AZTvkeDWw*xCFrnMtDVF z*CIEO7tSE;%;KVg>GLU~sg#|9Ia3Q}&6zrL`s}&j5IzW%2fw&TN9hR!SAlEE!osOH z7rLg-yJ>2Hd(qT+(@UgzH@OxR7fAD_&z-xVK$^Lzu#jh@xKT!sN<>{;P$V)?T;N(T z?H^AQb#ohhtEh0|w9GhKyq=ThjS=N3+1lvP;d8s(muPl(+43*D~q z`O@s7;-ZDq3ks!~g$rg%Gm8of=TMXs6Du;;f*DkDt_5>vh?1LHbj?CnX3;#eb3y@1 zE0TYs+*qP`nhF{xBa#XViE0fI&6qtOWh+X8=bjhCVk1mF$B|T^o>`)h3YOG4da>yR z=vqLioI;mUjI6sC6^@=*xTt9Mg832(s|a3HXe_%FS>d8Z3l>pGMg>F~P?Z0CPB3@& zO$%MKsJ#0ZE$E*j!HtX*^yx#+q2!!rKsOf@Of4eU0Kyk={M5o3(_Pb1u{RYVG1PCQ zsfF`rFp8-~g+wLgICDpho;tAqU;`RtKm#Jkf&GV|U&crl_AkbOxV!1!2nF$&4Y}xF ztO-y2R^V(RbaLx)@vwCl56K+5?n66E`ltUW#AG%na7^dfvoTWHqG|Lgb{1b$s3{EOFezV(gpGsbcJ z3ytts2>ixI_~!+_suBKY5&rf@_$Mdv@OL)C9}xIY8sR?>;qPsPzvA~i9JLYth`=9i zgnw7y>l@)mtmWZI;`~D|x4Q(ry%9b{6Ec2rjqv>+;`oF{_?z_b#X3e0UlQ(5LoAA?2_@O5JEEE266MmryKg@(LG2w@s@E#L>qzPYc!jCfH z*O>65P51{*csiTZf9p(m^Lf*H6Ml?N4EqHWUNPY}n($Ye@Kq)}opI^E?IyfaN5bA| z!e4E|E8cJIO2GC_eWIk4R=I3yEJhE-yVYL9iNZq$9#53Qy$0j2+#>@Jxo455wl|Dl zuioUQp-9`w+&#&S)xD%uG57i8Mxc`R0&|~BZW@ZTbZ;l>dR`P9*nU z=I%{y8uGNA%-xgRJ;`0g+~5R*X~@&! zn7bP8nQOwmh;HT@B{<>_NYHgk&~^`OC9PiZf8#n2UFI+Fr!LcBN_g+g{&Ei=f{>&s zHPlNYZt!1Ee8C}9?$;^FC&905IT@??4=Y>LD-=tu^4>{T0s@>Z0&GtKwiKRcL>q)7 zxi3?^Bf5eCK_WfvmB5IBaH%6uL#l#5wYIO|wor0BNS1W820>zgBJwHPOi(DLyO594 zu&c8(I3~Oim~igOdQ|y}#TOo}QM}(2YpoP|6i2I z=MhS-zt-vB@AMy1ABW4C7Ak6=8|>UlvFy%G`*5)g#NyGp!3j20?fb=P$jp_&^g+34 zzbtCI@qR?*O#7EBYhwkV(1?xHD=F@LDhu+hA%Id}GdAdGD0zuM{lUJd^g4v(s}UMy_{jTHI}#a{ZR>wJnLUO}ZOBUt&j zQC2pvdbz+{D|tkAmAO{7BhO0OM~g2|f{QFp{}#leZf*#Nz4a&E=d#KP$pZZp(9Kuh znKz3oC;`Ka1Y9?s)wBPo;{Qd>=Ha)dTD@UP4W+;q932iFs>Ni;lk!jAT$GjHvMTjY zaHJ@Hrz6$r&vIlq0V%#J*X43@mLpG2b~>iX$rBt4<>W^k9+Z`w{G?+Ivo|=_G5aOQ z3(Vf^sDeF130Qm4AfX_%gaK7`RtF8)by*H;3Somj$6n!USAtoNM52pR{QJ~Cs1qgN zbR^7B2DN8Fxb4cyt(3+ZbqBI#B-rIhQvU6!rit#fsot-a)>En{uvB+$oazKhb@1x& z#-5brl_#hz$z^mhO*Lpkjk~+*ZCm!`0&lc$-l?c!bNwHPUUfZ!%k^(l{NF3dwTkzo zUB2^f1kBoSrYGGY$vzqj>BLmZ^el5&Bvv*tVV7Lp&?oG+j9}kBEt-Pw$-wLCHtOUl` z{i%v~jNR&uvDDPX&?HgA+eD`S7ynV#DBdkSw){NBe-K&E!Z5oTG~U}KG%Rn(;y%E~ zz5lpk_~*Z~9F=Bt#~bD z&#cR43r{7G<*;KA;^UZGMq^9RCy~slHR`D7M6UdnM6kJjtS>MEcXpUk>~m1}!yxE$ zP)DU7KXEmvgOB)Ot0XsrtR*vX^vpl6v) zYAu%y2R-x7EwK8Wmc8M<-h+wWhB&$GJ?t<1yOdyu5wD>r3as}imhJGpn`Izo6Us-@ zZbOO%#=X77MAp#TH6XxSMQN2g&wC&O7bp)dP`h9`qkY&`KZ&Eh5jZ`19?dXG9PqIzN|Vkyp(?qS+Ed z+kxE~V$z%TChDWw4G4nrl;+*+V9ncoJ@|=+^RIpzT#fQZzME)@(k_h3TUO(|5g}rv za4Bzyf<<1^ZzYF#eikAI9^4KrO{u7&^7FDtGx8DUWKV-EII7ZHKR_-kM8N(kG=b%> zvL>*SH-V+rxvsXN0OLa4(A?~zI&sZI3boT6IFjDPHreA5rh&9dARTBM)VxCGL(y%DnQvs(24uyn8KZ z<;OoMt&;Dn(jniuNR+7q%jMJ@n(f5eUudVmN^_}wu>0@GS$o!zLeJ{hf$m2M>-lDd zpQn`j2bq6prFW@ymaBs(n?&swRH3c%haC(-p#?U^6n+2j>gGz`RMYQvgp5Un0 zy!AumvisYM%GmUZsWQ->s3@LEK~V0t-tCE^&{(Hv?XAD8c!IxLiOrA0n#JEb&=&RW z0X^kWu==aL`;Pn9vS8LHx-Un*UVtx~8nDvIpi!^{Y&0wpo7MpdQ38X}kS3yIj6%m4 zV-G!5ehM~lN!mv^A31$$XPHu~1lk68y8Ju6dynfm!?L~PSaqB^x5s!XB)S&R=4Pnq zWZwF)E7f1CU@k0CTOY;*$uqCfkixq{zQ{q2J#-Hyp1b6kNVAXTv zImN$SIlfczsumGYD9cnX6!SQm?n{tp>WZ|>>E=z7TvsxFl%+*U+v#fKOOm*tBE&= z{TDT?JW4A|bm^Yu7X^#3wyw3X;>rxPFQeT?UU~Z%UUaqYkWzq97`8SmCagE7V_U6q zX&06x%mmk1SDqljzFc-E(&~SUHU2Ge{%fp9$*$RsEJ$ds&WYI1ZA1A%w{KrRL+_Zwl{8 z%-fHy#!6TFGrW0AFGfFgq40tevC7+bSP5Q`uOJ-w+!{)+ILDUpwWj+Dr~6lrjV zP29Vz4I(VwkJo%stUs`UbzyRC;tAvA89q2@ob|?NIBDWrq7_B`nic;sL;bfuw zdPW?Ar9j}N2jFd(?YabimUr_uERcN+P-^`KZtb9aZ#51VO!@ZQwBk_C0J$V9^b2MIT))1 zt@kRHDkXVKrd(ldQ%(f$(Mjr#&=b*JSnliT8Ae<)k5%xzr(FRamQd{;q*C;D5@4BnI}M7yD#{T}&)KD6ESe6Xq+Vhx zt3EF|JV3*G5|;lZ$2wSIyl(HdgLn1cfkB~=Kgwg(Vc5a+Qv8)N(?-gET9gDc*c67& zoponA4Y7$z@WxbHJ~9_u3nn=I#{$FE^|WsA)>!EvBrN-$0Jk!HvLn&GN4p=e$U`3v zQk-nWx*QojwiUx20jR!(yAv_Uqdq|I0g(+RviH!=VN^~>y3>E;jYRmV;}ESj6lj*z z$#i<;@}?`n(W$xE9|i1wEQoQoi6tzKgAcn8W+smmdA1-E0I7+2_qL}tQ;7N92dCt| zJ$i}szP<<&r#vy=Bh{Q0tG6G&W2L?hj40Ms-@vH^&1rn@pjbm`Z<bdma#Wa5nFW*6wW3 zV|}`rooV8DG@VO3#zICz2cV`CQg-+g`hZ*Sa;uhGZ*Xg)zOlI^9d1E^;tVUv3pq$hZXM=aA@8Ef zBAHS*p575yhbKWY#tx$}&ITujacEn*t0bPbQ1;pwEYxw_Y4K_n_uf!-UDjWJW@l1g)EaZbs1BLoVx&K8R&L_Fjk>Bb)4_uF3_j z@a9;IG8ASe+h%yLkX$%>t+F`NYF({IdDF`z_hDLg`Dm5NNrp)->~Uzfy~Wkan|=rF zw$nZ)Ax(RX(Ct?%X>Tp=2+ZOe;JQOe`$zG(I%)$D%Ua8nwCY72g;=&$_EGO-vYHxQ zdgWTiiMeoAz?8t6T>oM1R+JqE;Ls{As>}^`gV6V52*A?5D87&rBjwFVcxn_ss{a?r zc*-EXX8lba{RJHeJxWOv;@kI+)<6<6AP_@`5HZ+&_D}+&>>1?{po~IAS}YK$yUsPd zmwZ)cL4v|Hj4ji^ms|$PuQTXuZBcNTj}WrtidI?WXb%X5w&2dR9jFbWPGV|14I;mJSkOU(IJ5Oo9iWj&+6W-iNno*-ogiUdW)fI*zSpZlU{Mvc=3)7_hL~ua z*6d)4ES3Ko1QtCFfyJ}UZEl_srwl6_xWu9W@w16ZdkvBSv_V>u(vM3lW+M#bX_$V{ z;QpOBXb}9XUjbJzI&*AMOY0=94@-f*pl1>bvH4=s7$QWBr;%7pMGRaN@mX_wl%^7k zmq8ZHtMoKLVsSIt3B*JYSp@z}4WKimR&Ry(d6rrli!5S9eM)i7iY(~fmrj9=%9)5P z-tz45h;wwgV!XxE0GWlSfn*j}uJ;%+i&r8ti-U^yh{d}PO?}t#-I)CDgP@K>zBWiz zX29sQy3C@Pj;mE*2k{EBEqRv4aG6CcKP7%v$8woPY(!?k{e#Ruv=W75mRbCaN<;@V z$|n&Os{*H&1~&HECvdYsnvq#tg7gIQUh_0SW^q@0y*jo&WvULb>J>r@&pURp(_$Ti z2`%zXVt~^LEuM)(lH8X=0I>n3Kr2N%W`DLn(B&4uOWH5($d|U{rc?hXAns_9jY?Ik z1zXDrMq_UUDT`e8d(cv%Co)M(DmZ{p67gYBU-Ea4uutk;vwUl2CxRh8HOT7a`5mOG$Ea zu`W4z_aoy>F?26jUw&6-eUF`upGHWc-GCS|p0I)DHwY|92q_M-u-91umk?Myj8XVZ z1r{Xrh;A7WScE-G!(3ot$@D}77G`0^HN+QePFT@1N?0LRv~L|Pt-x0A*GVga$WC+8 zio1XpgSc5*@n@j*QCcp&6|Tl-y_1d^^@eKopj!Xy(u%v$@kJXE(u#PLP()hchCieg z!#+SG!|DS^`k391AZl`zXl2~{0^!Z270alZ!aGlU3f{bjo>E%zH-y(riF^)pq(t~} zOpDTr;arqM6Z7e%6?~J7gAEUnGjZax(L#&l_bKJe971SuKAdKu#knS-#Z{~W@`T}J zmrf#@ks>^d@WsM`(x-PKCY(AGp~Vf2_%ATU+DvqP8zZUu;y+sMdiFK48EZ{|InS|nT1Y(?B z%4N61X*$1bMnZ7_OUxD}6n4c%3J_RFWO9lf{IH&KPCwJkA$mCaZut{d*4L;6kNt2Bhsf83{nvE*b z9vTTL`ZO0sQz1obkfGmFtA>!`8B7$>LW*oAq+lm%!4Y@Cn+}|D*hs5oJ}~rYSya(L zq6!B~$gdVvJVJ3qgq5chRnVbwGvucfRiuH3iz-$jm7>GgUnoH00B$gZ{VSLPFLd1Ul&745!Y^g%NvoIf1qgexfY; zn94#sgu#W$I|#s3w|7e-hQ1c%9d9x!vn=6ul6QO$G`$_cc{7#=%VN`Xd52tKFJ}S| zoFdWzY`t5{Os4x-RCpFTu&3*8raXLZjvxg@YB!3Yme3@60ILUBk~sX6A~i`H5|p60 zryF#UKxVEZF$+2Ai+2OL{^Nl`>^;tQCLS2#^q+vlhMwshQjg&HTnLH$2$~5AASg<3 zzvB#07^uy@S}`JvIO*^W%`D8G!+C#dx4e)46Php*n26#eNi5X+5f!!F1QJtZCU>_? zMsDdj9BwxCz`GRGRBH=~C*+E*iCg3dO7tAHYx#1c0mAKDcJWdS<9fWqGwLJ7zyNHTFhmrV3xl8K>s&A^VD z8Xe5`7GjA&dq?nEA(6o41TK*X?W)b{6P7lj+`^$+bxR$V!G|Hu>=o{T)z1&!uUfNH zi}&~0JGZX!EFUYmWjrMDEO$!FTk+#{Jcq!IUGIv%k~=PtpX6<@ERVyMKT+|=!6zS8 z7FPqLs~z^c9m%7d;S;%c;!I}eBr31`f&b3KrM<v9I8r=qJ`JtUI#gTPzg;OY?Q z8XUJ1hm3@&3CO`~cX}Iai#0(W+KaRx=4yy@(`kckgtKKYBxcot)b_e7h;op5s# zr-fcMANOrVXw01w`aTlG2y+lT=AacaGDHocgwrKCp-GYxT1YZ$*IN#@7+*q-cSVZi znh?lAkTYDk*f-nNNjW|>!MkFxN0l zynw3_Va8LKaV$K$9v(dzT#+D^#542{3Ts`6X9WdT6YUJ`VKE-BqlXUe3&!KA-h^%8 zuw3zz5|HFdYst)Bl3lK>x&g8gyuL_kB6kbm4$#$hEVv;vqeo+S!aNbTd?q;J=vHkO zo%IlmYqgNd1!9l{d%)@-ztku{=7koBks-U2=NQBc}cyL`t=YK~jJj&TNOTwu*z`7^ukT}G!Y5F^p!q?f@=Z>VSsPE(K) zxDM(66c0ueY|$@;FRlmVDuLSx%MO7^+WKWR!)E}Wz;ED<1*OwJ4`TLtam2jkdt}zl z1yaZc-rlj-Ni?+f1SzYO5PJ>T$#8zmu{*i-rnk4n8xFg!3A9>op?9fc;*{ZI?2E6! z-CA@DxV&Q$tTpTHu4KcnLzExTNywN>$0pURx4L6kO`@I_C#ov4L5ZvJwz34e9;>#v zdQ<0T@hjejSozKhnt21)#04-6Z$YfrAl8Gh^Al?HXNEr=tB!G>&v=eiTX796up&|Y zLWB&#DjQBHu(CU-e7!5|lI&YYAi8>YURQzfaX1``Q3B)b`c2m1qZ1Y#Py$yaDuJa5 zO6*emE0Ps&RcU-tQ)-az5ww-zH6o8jI~$Gu2!kLL|M)m>JA17avfIFTb9QEO zWJ-5hVUvJlu21eWE;>xHO?S9POPsWqU-qBUp7W%(Uz6mK>}v*=^rVwH=I zJKDPx8EEBQnl3HB%!#8A+@-~h>#*7m>B#hN^~a)Kn*MSDPXG9zblkzbX4x zQKpVhO{DQ#_8lZwT}~n;LLc_iY1qG-q9#A2Y!ej~ewZ4sNesPdgtH^6W#Yx{FLd)H&;vr` zXHYx2uR}b=zS231EOJ6$|t3d0rb?13mZBJ#{mC(h7T z@4O^SZMHKUs;J9OA^~V!Hf@8X(4QN*?{DP36EA7zy1D{O1C8K9NEVF-I^h;P6cxJ* z&ir%XUzkU{kse~HJe#VqShL- z=#CG%FSe^xyl5cVXpYqK1a>CY76%o08s z!dA$})i+#(X78l$!zhG7vUCy^42`9DNj+I|cfy^Ly@$N@qvf?#!7M3Ujxgaf1>58P zpbWJ2@9^$O^lvAk?TOy{0z8*Fxl}Htdv^d_xk~9CS4Zz~iSR{o=?R(}5FwT@->X9) z@KCU`OT^f+mprhOqcx1|V^atvPJI(+$yyCn=ZU4V?-j;A%I>|*dZWvVl^9H()tfp6 z*Sf0%ahSzBdmE;(+eB98w882A0kYcCBBV&W4gBhbuQs|~LYie*$HMy1RMkGmBzW-J z&QO9rM->$r4*6do8=G`CdkA

>H^*fnN`pekk^+iV1Wt-6Q+l?Sx6^WO0|yf%_c^(N+Y{9lEZgXegb%n=2+1SO$Z%**{E`OQ?D83?e8LN$CjyoWkg67~Wtk&GedGkI+yv_UmGH*-7mEzN zb~6*iV72xF@T{?T^mnmc0&(7bhrA6K7pj93&J9Dl@(73m5`WQ1X+%e%^m*&&%f7Z~ z&NMg_dFz+D=9?RH2NGl%jk&ra(wH$ADK>UbxMzF(o$PKOx z@kg$i-mMlCot27E>rP85AhdV@Ml*hb#8PO+p{JPN39RJw%(}avoG9C@26)L6sCOv! zC%nT!=2_{=rPNvJkR&c}4n1RDNK&3m_73G7Y*y)@i6o`2IW5p|5|XvL=E@bkch1+^L**j>bZmiVMrAA{s;q2$!{J;4hGy50EdLBgBAxdoVS5|8@eL^vU8nA4riTZG z;C_bH$84;V!D}bpa0p9AZ1ZcZuTSBTsei^z=-0?$4(;<04(7I5{a?_#G#@DZpJxW9 zNEp#(cz>GWn(FI>-KYLI6bb81)F!9B{S_? z_xD*`Im0lZDD|Pgz1~kme<@JHx0~*}HY>m~rqW4Z5*1a3Hx)DE3v8AW54k3(tw7Im zSccVoFz>jx!7AVRHp2DRCt{QW5NTNpP|`FdPQyr*NHB)*NqVo7SA79ue=L=Y;+^Z5 zhV}q&y%FOAC~vt~{x0W_UP%lnW{oH5JuO;4%^TN3Dom})+fd;0dK-#lUm;o>#5>q% z4e>TCb*E5m_wiTbQNzl8bdo1mxEzBr2M+E+w`24pJ=}gOJtCKm1h3vvFd^;4S-o7r zMHNx>7tNv>=4#x48k0jL(lAAX4?QGM`lNbVLScrcET zNfc5Rk7)2x0;p6j5KlY27)Q#Yn3L4Ua1PCiIT0foJ&eF?m8f81`wl^&m3=_hI|eHj zY#}I!Wb`C$Gis_1D=(C^q`d@3Bt_@ohOrU)BsCMEN2=*Zbd5|L25TVz2*4P*DImQp z)LNYBBd)w)-iODlF%pH{Io|qyt|dl=*}WUs5NlC_lWBTKGd`sL1TLy5;r!Nc?xvZ& zPF+Z$>SOC{*Km|?UWyy5%o0~7U46m0?xcAek4ITo$@Oo|_3zNyBWYqXV7VpxaLn(yR3^@l4TuiHNj>A{BFu>>G{L z=!L;vXTXaF{n@S_dT}D_ZuBIu(=5jF&>@nC1;B++#SrAaP*2?0mi*cuGAkqLxf zu5W&#(cdy5Z1n||dI4JIwO~i%f0~`}%RbtLd4E{|Rcq)4yofry^aK^r7Ir((`_tue zSr&*){!a>jI@~uZXA>Ne{AseLJroOXbZ?M+s;T%0MYYb7YDy2MzOEth(EQE|3Y)mb6s zy-A(@fcA-*=kg}?^(}fOO~e;A@hL)+p@Pdof~g5>2ZCs$el% zeOF>d$)uyqSXixCSS|-E$z#A$u_x?FW%D3|XUd1(JR<*Y`%(EEyHG}eRZUu|J09ze z_VoUsJ&xajvTB=th$KwigPc+nh?_6LABHdA+5)9-YN>_LbZIAB`v@etfowb4N64=4 zp0NG0tAW45r!Je3(Zm#X(PJ}vI|R-IjkI~p@?-yaH`}7)NJVJn8-W*Jw3{~{>md59 zHuc!3VF4#+)rV>+o!I}%s};%t0>&O#`w7b^_>anuWkY*-4bpz|ulawNPjDso6H zs?}A;TYs)gXvIC086Hv6km>?OWPb9a))So$1%H&l%V}VDqgAwN-geD<0}S z@MK#)5$;On#_1MCZZFTVbFy0^=*kQ`!WQ^(CVU?7`H-({*%79`Om^eRMPtpyUNOYh@mM^V=iVj^#NWd91H7B5W*apW2$# zXmR{gz&itWCs>_+C=F$ajVsTIbGq;#pJ_s@ zAX$(hi_H6*@y~-l%UnWvP67|+rk2SnGZK(CK{xCaba|l5PeAIO4^AcSEYXm*;Rvt9D&; zOpu!Laz(mCVf1T4=L{%uXoiq*Ujqp@D~!bQoJ5}TlQ@Zp&Sn$x$_!i^(dl=b0ev3m z^Ks!Lqi1m`$$2*7nSpe-@u;-epWfD zs$@+y6p1c?c&z6{Rz*JZBhB)~nd4{MyptxeH{Z8i{iM_E7Rt zd6Lr0bJBT|lK!iKRc554yhLcbac1%xGsu%xnTOEw(@|<8%X3EZB$ob`KyjAK+ZEb| zsLT-v2iJ_s@l}8foPVsqds~x32`WJ51Bo4g63m3-HaOapw>|4~whaD1ZU=IOrJV0| z`9`0<_0XX0CUGAJwFp=r0b7%ZloND`fW0E%;0QQ60uVOJ_0%-;GhUNDgs^`0jEX41rcz01Pn&NM*#AH zE&?&(K7B2TpPWpjlM|fM?4nZ|c2M#$J1D_1(Z`EUb5IhPYdYR&dQeh`6JGN{$tMr! zscCXhl0DB9(%FL&>~Qtdj*5)42PG95XAeqDM_9%Y$k~IEd>rjHI#)S+P(mjcXAer6 zHjT3fC1(#x&K{J|;dA5TBYi{5*NSHkN}`VlF>bJ9ptA=hWKRy9b&g?WWvL!Yjl)M``68S6wh9Y1?1Z)#IgGh*gmqx&$5imOfPKba75pYoi zTp0n^M8L-);Byi1wFvl51l$t=4@N*zPaqCT?2!YFt`YE(2sk7H(j6qhHa-I06ajCI zfSw4r8ep>rCI4YxJ=|7`?6stWw*F3ddcRdn?}2b@7>|MIc@n$FOs_%i!?haHaJ)jn zZ6w@#_ohnj<4STpmD-jQ_zsSS59ZRFryj-A>#TIWl)e0bR1W}@FyL$g7qIBMTOvaN zNYa562H?r*yfV5H#@Wrl829x3#E8VGe~Q`#z_NV(J<&&{iWQs37lmu;CBQ36_wL zsZF?X1y$l!Lv6y?H|T>nY7WH(oeAukGS&H5TnQ~Lz-35E+$)ela&-xGUHK|bOBxem zO&Sx1g4m=np%7GB2MPeMyyZI+SWu(wp43BfDP6LJc7#{uQVWA(-jGX=)0C;}NjON8 zCcMHm?5{?86FBs$0lf&Qi6(_IBs9{nn`lxX5lFNqg`6|cq!34%6i91Dzla)z*RlS= z&17AT!n;C^0;!5)jXLz|uTY~f09>b1qkz|K+4ih*2WnQRQMiId-?AEoD-oiuMj-=F z$^OgHZ33eu$DfhL7r zcp>;JRVm=R7U{iQ-SBXbR4HIEhbjf&2nNoY4pj=iW>hIGq1ov_M3url7>VfZ&?FS9 zf=7tDDh0WMbSb2_ph`im7>EV>o4929rvETkr;twC6rlE?v?D?SAhaoXm^OtI(57$< zH9#++umt&DU5Quvx;6#&ti{l#@DD~t+7#O0wY63SXNCt_8hjPZV^oaiHVap$VDY3* z!HaHG8H8A2cc~K&rGN1fjl-M{g!QB`d!zp=U(& zk(>#nM9cUOktT-2f~H>3 z5dB`jx+Vry;IXfoMj-4NWXl!(GkjP=bi&&Kp@4zC*ML8+Qb*pXz>T#6;zFF$S)QO^ zh#WTwM`CCOR*#MuY+3lFx(#*U-=?c(_yDENR5ReIE;3KHGyMHVe+8Ed?1gK}h8|ROKQ@l$_8w1li z^#cN-zVj&I19(USRa9IXL%#62oBMEW3}ws#668Ws2XhkNt%R)`%{I{5405pLA&N$O zoLW3;brwq%fQ33x!U3PdLjZor3M~sMWPo=Q3O|6P=zx9z*D-*BPe#D?0z`}5Kp^aA z$p!;Gx1}GzH5?-FK><>IqNUI~PA6B(K&w-xZU7JV))5tZuF6s%ufkU{QAlVec)!OY za^uwr>1FVk^)iqTj=Z?SLi8C*C@e}O)6MXJod8`ogSj;yCxr<{Yo^CzdTYkeG7Wnp zrkjBUBy=kIvfk z95dYv>%b~%f-+B_-lNoSkT8@0E8>W5hJ+S%Gt5Gva@`Eg&~V)h3yFm3W}wpJdKnf& zfz$FyTrUH`D@Zp36{bly!$Pi`0go7>bu+O1^9Qg!=opbsotn~UvE-+>P7(T2S2wET zS?u9E0J4*uPzN(j`vS%AYe!KhR2JlAdeR*j{x{%#fj+8cLaPrZfgZVd9Jdjfl89ys zJx~jijX^R7;*1x1g+QnIlOphKFCx zlq>o=*vR{5&|n~&0@dpr>gV2p3I45YxniyZpLh6WitBZXEGtzlz7A~fBL0-3L2`wf znY=ACn2i-qRaS7W4J(Anq!3`TBgy|g4TT*M5mV>zog11s;fM1w3?r0=qu+|Hisr)9 z#Z)^eTyw)A^<(%U8YrPVq`qy?IYH-y_6k;F5|qTS36p(Fd6{FD`U){od%e=6*@52k zG0hG^0Ficif+`D-85d&I8;fTQ6!R>-l@vA1!uO_Wj1a9|_RWN^mW4k|L4OM4jKqt}T)456o`nr(! zm&@hVw}CL4WECd~6`A7Flj0wPL#7NJ?QJX%82bo0qQb(a5Pe@NBL_0z&_a^i2D^YPTMcmX;rWqY< zCF-4~T9t=~0N09DG3!Y)QAxB?>qdd=UTiUk8ayC8!)R4T1w<)1fUQ z&P*KXy%^jNQ-08YC@DCh>rZM22xc?!UI1!mji+W6!f2A3>qTmAZ9P_$(5b#E!4n$#4p+=NCctCE{ z!>Nk4uxZh5LKD(8I_yO_6b;7DFVz{SJ42&Gl#dD@v<=^*n6p^xzmRERFtTz9Wrc}P z8@EdBWx--}ihZQ9fR$6~oN##0?Gl>_O-dWaO4_?%V!Hw?`qSzfD{jcMXpgc$c>L;m z);Ue<^A+%IqPW41tkUM$O1Abg1pr;ZwHOqk>zi-1uv4?EyJo|uZZspOF`p}Lu(m^| zZmhV0P68r|8x}xu105Qej*-xnn<#Fmh2jQ$hLd$IW8)qqJfRlVeEQOp&< zi-Ap)^&$0@{V3~BI2~3KD5bhU+UZDjqK#xYv4&8v9J*XirofY(j%ji-eLpDq5r+pk zmy@4l+cr6QgJT_YbCo7?GS_!#I_iK}7oR1yGNqnWa48^Hg8X!Xmhe#BhE)eaIK25v z_1{cI%1%oLb>@;?JBs8JaMh@DarD5A(#(wrG;$giXy3~EJpojO3_}rP(6{*~C&I!Z5;b}dB+ zPtbuARO%i;lZ1|Ebt{r%zW6YKU3?JTqdl6o!i8m_f2z5ZIMxXeL}%7nFbRZnSJq=R z?6EkJ^r(D~yBH_QS=^5zN)n%fCb}9DU#eNyDdCfQF*nqhzR?}}fr;U2l;G9jjXf!` z{6dFZwvMG7GG0Ust@A>?i|i{0UW841K`25x=9Bf*AM6m5L`7*rhI=D_~hR+4LRpgV4X3rfwd2@B9GWoqV%AK29n!w=t#@l`1)rD zP*bsa3GhOiGAQF%U2D1Yg7&m4ccqOWwc4_8IW4{6EtlO!gux*ThQTrv;IQa7Oc&8> zo50MeQyF#qHb!l|i+u(eA4*}V)b0hDTxpvzQi`a#S1AP?=Dw?hZ>8;SKnPlRKgzuq7~b=QH^P|#(9n|O;)G*3 zVZfFRzz8LudnW6iN!$}*^aNz6%}L>8F*=z=kX71#N@s}@3Uut8Za0v~Wx)0+0K|oI zd0)^WdB5pRU!gUe8_vx*Z=qgq5}% z;6lL(T{JT|Yz?{j|wKlxfY`EQs$CK!)lcn#K|Iu{0Lnak7j&{=i%m-yuMV51G{f zA+tgi(0-iZfqNqn2I4m{dOfv&XD;G?5fJD8D zj-1PoIaomJs)A+{6B~Vx5EP6*F$7kDF`n{KUmFEaB>cxzidu|S{W)$g$00SJ05ZnY zIQRx^H9+a~Cvhr=d@u~CH!x~3WbR{r_i?{U+Zq7KqVU|tJWF)XBH@WJZUJbFr+)+) z$}tZ(WPBsdR@z-0m(FqE7{)lnfJtgbz|md!R@%A&LcGFzEcXsy&H_aCe*r|^24S0LyR^6vl|y0aby1XC1SLY>2+$Z$_hOYs`R6>pkMIO+tKcxg-pJ{0;dFD% zbTc^}LY@Z57*Fry9I2cm&CD@aa8%m*140JGcsic5*f>ioGmFJw`3XIRSPJoR#>A)IMw6jMrssh?m%oDP7-c$&r8j}2xyI!>06qaT=yjcHJoK@ z6wB)d%Vy4k_$mMynq#E7^A5Hc&q)B8EMLWF_% zTt=^Rk7h37&IH8A)8T-K8>tuw&=^mz=HW=h5{S2uW#sy&K@9mbS!kM(@$@_}k(QMe z7XFtMuMl(_R89Qn!$KW~HlEH7p$j7mt9-H*VLOGLC+rErW*2urH(9u+2^%{W#*c#s zX3rHi2xA-|WBnoaa|3m<0fMa3)&~$0 z6}_KptJAeH=wbyOLT*5n8|x3TpZkQGK_LE9viLa4*B@d(w~2ed67fbjFTrK3Kg52H zw0$F-drX8_Bf*6|>dD5ncbggyzNvHlSIxq+N#P=qI7>ko$!w%E_L<#g@L zbgeiYLY4p->kqM?`=lSwKUqBg0ow;~AY?3{b^t;KqW5!8ah5-uSspi79u_Q!?_Pk$ z`a|sJW)W8)el}Ugyj8$ldO3;x+(=G&MHFS4phW0H0UGNMv7c+pdD=zs*mRysTMR(t zCwf2kVPBS{kH|8Tw3E45A-oHam=NbyfX4bm?C0($%0T=Bj6t8#?_n;Q(N_a9Vifzi z0?slciseRwWin?$d=mi~>kqM?yM*Wh@s~1kJ-&;Wi{k46h_4{d1B7%(@8{G$j9Vki z$n4k5MciKj;@o=x5jRq?3!t(75c|1x#1e>qoGc@phdJ^QhBU7~n(XJ&dV{Z}{hR~* z#&|0Bb8#Gp)Yt(T<0;?I9YtvR*v9vBUoi}*KLcPWhUWXZZQQTYRs{fAjM~qw*FFCv zJQ2o20FCif?B_hdA>(|o5c|2g90!hD7)SJe?rPy%X>$TXywUr)%M9;+!W-c@02qu9?Ky@av+NS09u-!m7nd<%#MwF=vQK*o3~ z_Hz|P7l^N9P!)!svDGBIY9QTLE$IIe>^8skjNC zF`kP3Tne!S;s=mrWT6j7_GL)(czSj}NBfiiw*A~?4wQQIe(wL`t}a?_3cCRzo~NQRz0A`x93 z-Tc?c&{L3{GnAp*C~@@EcNcQXK6d#g1MuDVEV)9?_Pyzv3ZQ$oZx3F9X8U(%r@rl) zh<72lF>?*Q35KU7J9WFumW?M`-_xI-f{}mFetfRKAuF3ZA485LN8-uTQR$H9@b92)g*m^*bv+HiC$@G6hp?^u9xG>(9lEzZA zpQq+ab80?hsR`|1{LqvjVrV37?je`a`veja_n?ae=R?eoBZD5reQ=}3Jy1y4>4@35 zPQwdI_}ouE0)B#E_T^xFh-F8ygO%moK)wbe(ItS(*gfyG@9{0T^uOQ|K&U%m3zhYO{U_WJc$VT(UmlJ7tJc2~%=sbg z#;1l^B%$@Dw1&PBjtY>5K4IUSIiouRe;@Hx(n%uTfAI|F~uG~XGx4*6@Y zGchs!T4W^`N;ybNM>$}nZVRe2eB?Rh_{&bsxF%Kv#xv3VM3g z96e8!^b6Q^2PbY8&q2ERO^468^`ZJ8I9X@gK@n3bw=p2>x+*ImXX6&v4OxNi*?2d_ zyIR`YF5FR}uI5DJwq^w^*=f|%a2fFs{ke#9MR&A`qN6=4IvUt0X!naoFQ9%_m05>l zUKAI;S=($^L-3!GXLg`1C2UVN5_5njX0ItRAG5@Sb~fs0)TL}Q>sq(9;ZpII3nN`4 zI0ldXaUW+z@dEZ%@LxoUt17n1sc2S#IMu<*sLGWuDv!5f|Cgvl`T{&|!%RG>qe+CqfYw$BiChgZG;U+9*ikn(YEl# zDS7~h3s8i?b){qO?|E6m54UQFfyS)Rk6e+%FS-67761E+Q9*K9=T_88!}5LA^px+u zDqpqilL(jgrR+OS0y5lzoxqZqXK|+OlgsuIGVODAqp1_V^;XmsF4sLYb~BOV1NV4Q!zpiqzg!tR6VFIE;@Sv8t{5Fto-IBTl?Hqu)*$PxJ3(YE&lX#L zk_a*)xpFkV38kZdfv9X#YZN3>iL@Sll1H_lNzh^yBhU{c2n1^{&lWrS*G%}z(fAyc z?(rlA+^$K$bQMequz+zJ*PvcZ)E$~oPbGB6CTNAwauf6b5q4la?w#T7jsH54$b_aI z$V6Ouw%A?QGf`PA6muAiC?$O7NhkOi1e-cFiLk0EUtn~c!Sy(C#W&%)i_l$~po<9I zwF!E1(>QRC86&-(w&BEgUK2hCq0euE?m%cX&_?xr45O*J89c?C!Iy-_@YjfXJE6Nb zLBB|7e0aAJ^+P~Ym77YJat-5#sk4jvyUDTF;h7keJV>zpF`l~{3 z@jPJ9fQ$2o*~%M1+~LIM54TWQ&a~Ra?5=%R5m?#|7R)P8Ai;8WHM@?=ZtbS8#;Tls zAXV*wauNx}-P5BrjrJk57{W5B8|^{5@;%6k;)`ex(u<~PDhgu*;>t3ndVLD7p&9-% zlnPDzeQDay#rrx!cDyoZMadm@*~Vf$_YDgPkgN648jyz$=^=y;MYIX=Wf|!dGA#uj zKoV#v;9JE2>=#Q3pfxpL%`HLwEWVhQX#S7Y3}hE7w47;QFYbL+(ApQ%e(^}S7F`s= zp8aDrk^3uYs(VjutXPL+gP^Dr65J;@ZELYic-Gwd0q_0+iid={h4+Vs!+5yahnGvJ zqtm}nyk)%paPXxiyDpNCtP!Ugg1xZ1@%dE zW_t2RGx*Ivtb1FrC!$hpE2&kCWVPu(BS^bsx7#JG1zXb}Nut9e{9!Ny;cahiXTAU= zz!Lo@zyJLoB@W3)9vOfE+iJBkj1(Jdw~?0=6BBE*5fB?IfS8yVPScA0;MJP_n&c;C zz`#L+hYY_-Bq9os~UiQX~I;xIhx0k5?No!zYEq4KOmUdep#- zyCEFj0{miJ9wE~TemdNoy;IQcGx7Wgcn`wbg$pWVc9>utcP92@zme}@5${Wae|a_{r9QF&a zYhXVG`x)2|!Cnn}HEcI*AM6`oFM*v2dp_(-VHd#e0DBVbqxXcvV_|;^dlc+9U|$aV zDcF5s2Vh?W`&QUpVCTbb2YUo;3HHUX|An-*fqfA6p*7*~m#{yC{Q>M(VZRCcaoB5N z%3#h1-=DnU@ZVsnVLpX90P{1 z<}sM9Fb83}-ibWG+z4|g%*!x)VC-e#@KBgJFu#YXhB*S0T#kIf+znF=6Nb6W9}dri z`6JBdFrDtgS`B6u%=<9y0^#tLFn7Vc2P5Ad4o`r22<8CHr9q?}rUJ&cDjdEF=0TYK zFvC}4UHl}{^aR#Tu(!ZYhW#|`?y!TfJHal3ZHIk5>|g$bGJ$;rb}!fmV7G()8SEb( zM;c&%47(QgYp`F1{U_Kj!1lv_3ibln55c|;_G;L}Vf$cT1bYeW*0ATp{%#%G2<-P^ zPlCM(_E^}D!5#&>6!zt?=fLg@`&!r+!M+@J7uXlVZU;LSwgg*y3~dBy{RixWuwRC~ z9%eZFwxS)r5A!ul1B@MQ`bon58gJz*NC} z1Je$DD-~uE%nF#NVLpQS87BFD)HlrCFfYM;4b$lX)F(^{%(F0G!z4b4G{BU>ya96( zX7C@-7hs-;`4Oh?LnsHBzrq}WNqHE359U=E`y<#S!K{Y)80L~k!{NCwn_yc1F&v%@ z^W3xHaNKkMbLQBFNL-bsxTg!)?Emjf*X1IgRrX|LgvI zR(|{Z*4k@7&fe?n^B8#efFyqHa)MqNg^k#jA(o#>dB45bc=g^5P?uL&H!tnqpO0LK ze7cPo5B?F1gKvVh;LG3~@F}ngdmhwe~LJSR3I(L66Aj5b)@&x)P>AIE<)}{UPpRA z!#{R|EI@8WUPnUD?%zKZxfppA`3ExYIp#ZU3V}x=Ly)OR6j^{=i#&$>7Wo9}`#kF% zITN`ExfS^x@)J2ERmRBKwfjkq?nm zk=@8pWG8YgvK=`L>4yZ7!L*?ZZSI2{gOnjNkRKw;kOz=wk@u0V%$@HelaM5`2HA>y zi1cS}pN1rnHON-P+0OMIsYliz+mJ(k%e4Tx2zdlK=vA)g^kEzMH;}&|hrkU)PC;fP zKSov}_aiSOUmycsqdue_xdz#Yyp8nS0gu!p*C1Pv?!V)@fYc*vk=K#n@0mNuPmq_8 zv9uwCEMrW%ffs`g*baV%oR929&PDbh(~ws-UYzk=()_rP7q?~!jx-dK_?KcOw!5o-)jkJYp#V&Ssz%ppZZ(~63Qlobt`Udf-L zu^dqFA0M#)Lbf8^kjIga_UzyPEAVCTPViUY4d7bvYVgP4gkp%fK-r;=hW^N-TNEvb#(u!P)tVT8<&my~!zaznS znWxBTWCqfPT#ej?Y(d^ZK1cez$Gk@-BIh9;$TH++WE1iVl6QFEJg^2;= z?-)e=VE!@t_b&&hf*Zioj%9xEitRD52pn=8e!)M03&GU|oKLXN3Cs`haj+Y&Bi{kT z;DG-ihJnw5^T6i8%oFhWQsM(QzH`SCN5GBXW#Ct3%o}i6Ir9o!1nvP##^IO8QSsUM z0Y3(B1G}Hid;`Pa4loMt1@8pqxzNX880%c#Qo59D= zWgT>PoD(CgIWRv;Iq)WMHCR;oB}7X*9`Mg_I^O;@?`^^!d=A_I)<42|1%CtX1sD99 z{)9LWk5UhKJ~#%v4LlFLjO*A!@XXig2Pm<7Gx!T|7r62b>hI$?_rA$^f@i-)|G@#f z7*Fska3dIcn|i_PK&LP7Z~j34!8^c_;Fo(S2OjxH%7OQQ>p)5UTyXL(nC9fo3*{Z& z>)>w7c?dZOo)`Si!XxPm?>C)bs3JJv^h0`fbX(}0cG$@$6c0K|ii%A+GIStoL5K^2 zp+)(Vdjz|c_B=hBIl!@Ec%(jP|Rqj2APV^9OM#Y75crPGdCDomOnN)U}<1% zFubHoX|Q1NL8ZZ>MP2jf;O?m{TVI$XM~e?y(q(C2S$=hoU_ogxOo6e%&{(PUY&8>S zZu@qF#~nrd7vevqeeLP?4Th4B(SCxyRrIufaeirVz@k8TFg(3Sa6oA=R4RUFh+os+ z66~MF{$8;kAKaCfzux$hL+Ni7`q@Ww?H9u6Z%^xGtWxOrqL+V2-I*T@EtUQ*3DAwj z(%(e~l?RIg1M=LS(G=UWS#Hn7tEor+LC%54WN(L@*PZC^LqE`zaR2aIRA(Rj!|?o2 z*Ku&(ss6@bI9;t<%+xoSd9X1@eUq@8gnlBi^k(&Yh4gEgoX4eI#s&+P$a!4cwKO<# z(ZS`x%KS@v1V_TqV=O!;(9{^`?^^u*9DnD?w+OWJYy6E5F6~Bq;N26Z}>1+2)~)-%j*{P9+9> zoAF!THSj_A;5x={dDr|DnK+p$#nV+0c=YSWZ&DY>`EVNdAJWdu@tfx9uV_u&xe)u) zrn7f7jQdCWy8^xxUOpyj#=y-g_xBkUeKNkJ2hv4Ct493C@TMGM&L}HuV25Oi4S!-R;o0@ z?dK-^)#A^8ez(KRKN5V4^iz}6{~`R5?5%zz{5RDq1xUS~cn0Zts7go{l zM1Nuv_ei2IM}PW(c_#WV(WlU##NI*lfh|e*}J({fK8|R{wCOel+@3Jip8q zf28~z^aacLFj^c+#J@T1-d+FbUqnAZ^qOQn%a1mN@3j8mUw|*Q#-LgM@PCKzwEp2w z_*v)qu@8O*{4qX1z1Z_s!~6XVf?t`%UVa37BYdaz4}ap7Jfjf*%nP&rwJsT#6#9a% z(XT*%H2Pfe;g{&wq2DY%bH#^6off26oAKI5`S4Yl@nJN&UhL(Q@Yl$EZViT($UbPX z>`NAPIm?X@W``9x);WfHe#;7;ad7;=~I49JTVF7bkY( zZxR0dapFt(yW#zDq8}?@^eXm?-!4u}`=&Tii@)a8+&_QIII#lzBd_Jz$hV0T8x9H# z3w?c@*p0u1ckm2A&WqWPUY}lwG9SD2;@tm2_s8k|d49Q#Bl7>WG^5Y8&dbs7M}KSD z|Fs$ZYtb*hGk5)$pznKEdcIfYFG&|SacDjIE6|@S{udm$zu>a=gjxUSx1!%D`uvN7 zAuhMlubuEsck}E`c!@tAzf$f?^nXWh=0$xl#17xo*N=O|qwe7uhg&YxdZ0N*UxI!N z`dy%jo8Rx9wvc+Ognxv4GOPdhOnnM|!*kj152XAG^cBx%ACJN4e~JDH^nWDp2aEF0G%ac4P?S7dKz3ulfaj=3$o$KXdN2I>%RN{AsH{XOM3QzK!98$vYxeOI z{b=-mLjQ;qknz1PJq(Vta}N5;csBc<=<{dWcGgMeY-5^Cw^Po?66{aqIj-KCN-5)K zJ^V=cr9ybGv*G-Cp5}|zw0k@DE3v;*2;Y2{diJ3oxQ*x2qSqwt?ZrL&`S8=&CrDiN z#=pSmuAT_!)-xLWB|NtrDB~K)%L|6iG556MQtFw5ekso{uM|R)oR1Xz!8>?fC0wrf z{kXZW_1}i6mw&;(M&SQv1orC9*apRWf8f?rJI%$@6z@^|QuXWg!?E`jw_E$G+;V?a z+^TrW3YY(>9{=eY_qgy{7pqk-OG}dcR|1mkaeg54Sbk+L$?6aH9Urg9EPa#C(Bocu z?C55bP-)Wkw_TJUC`ek~bD%-9rz6J}(jJ@f%^bT&R@$iyy6w?naVoVUi93?m+M8te&>Q-?$ChuEy#>j(bBsD- zYiN?K&!W=K7rZ`{ymrU=N{_?pZ@*8!SC13KJW2nf$KP=e|LK3jz1(r}J07kBQK2|f zu}-mFak1jnink~}sQ8rPtBUU_eyZ422kcPAV-<%hRw&L?tW#`PT&#Gt;w_2~Dn6z7 zs^WW!pDK3M345sGv5LbLD->rc)+x3tE>^r+@fO7g6`xXkRq;K=PZhh`1>rvdxsF5G z&zChXXl{%jZx6Cy0bDcEhQw98_65yt@vn1W{pYg#CjaBh%1#OwjGr>&`0$9r;=&@D z@Qo?J>ap>_YaDR$({Uh^4>&enWb$1cyRI_%gB-gaGx@HLUH6&%!H$g&nLN=!=XECE z-3e#qdpMyi=Re@|bi!HrUXG2knf`k_yWaLPZa&}yoyt$M@rO7I|0f$CO2=)lHsJJe zY`o6o`#S41j%V`Uak9mg0J}UJhcoSY;F@J#2b@El%~|=*gDLT~i}ROG-`>jgDnHA3 z%5_cn^L*nht4nx$J?GG*5&X^bv`6ZbRDtAW8J{M2@niP~Vk>DrlIL_D=hg7p+JAEn z{=OW%{7@^KpEtPPW#doc0-KHBlEcrq9QLo}u>S~N#-XnaEJ?xng1qp>N(?hvUYjwn z!Z#>y?xE2gsr+}3at&o%Bn?yk*UFoBAMo;D0_8cqorUijGjo3-x#`L~yIuYV980P( ze1GRWop1I!FaM=ip3~nMqkPy{JI+tl&v`!oYt=sN^Rody+c!@xBL2Ilz&xuyYIJtK33lD|E+yjHm1^kyYIhA{R~jP zW4cS6p#0Iw-=MsC)`@PI@~2k0hTGMCyz+f#xWruLtCU}$ym_~QE~5OjS+3y?YX3v! zKT+P!pG%Zqf2M0#ruNq;Z|`gDv%=ezUpm`0H18D1J`A77OZd%t9no#E^NRA@eEc5e z?RCf6cOkxspAA0ydzt6i=Fj)Am+}9H1ZvW97aaMGeKz|ObH&dbyq%vfUgX*SLPnGj6go=&#JTUB3DDL*@7S_{)`d^nT0E z+nbaR`FLq|wt09<4!#gS+3X+9VK4s}U^e?Fa_}$Z;N_ihHa{Qa;N@Mith8<>(eVP25IsEs>K3jVR=HLhC;78`*%X9FRIry`4@Qpcmc{iV} z{fQj*3v%$6<=_d2S^dj9^gQPfr$BxJ?@QU*b0hZb);ir^JYf9v@3fzP-0+?63!Yc| z&3fIpRljEV&i4x+!S}UBUgW`0U*^?kIOfnv{P?+hg91plertQ>AOTNbUEq;gt6O_(+$q@#<3SL+n5O z3NBUT|;|bp#PFT(_pi)hFWB3plls{MI7{!6uzQ90~CQ2X0nat+P5 zgJjP(?dj<()%jDee%9#tuh`%ccAZYn;lC~iKR*Y*((wHGH?2^iuePW8K9{iTYfTP6 zcNxB?^To@qy?JLz_Ce*FcDekOE;vu-@UuhhyNz-UD>SVC0Y1cYjAwlD_IwyQZ(Dgp+XD(^i#nI~Ldfi~| z)iNi-OZ)xta8wTeXQ-cFXqrK@NVowtuL`NjndhYrW^{dDO7zT&w&v z?U#)kH|OyGP!9e@_0vr!!qeKX-z)#F1_1L88{KZ@-#N!MG~ZN_{TM#%QW<}H_Hg6L zjwf8&yqhOmp!0dpJuYwO^LN$0aG}eaZ=lE?2OoB+jK9Nk@J>x4l}xqO)fLt_wXsC3 zzA>4KB_gS&NKL$@IhJ%Hk=m9>eY~YQ8jsYbS`x`fv~8YK)6&!$kELR@g~Lac6lFJx z)HOCYMxu#Cbbch(oJ!1h>JrhWSfsYCscAk%JVy~MQyHf1$<{<;bE+eve{wEp4gRwp648Km1B%R-!Q#D{F`*W~GlNx5ZP9 zX{w?n(G*Qp#9N}Nv_(~HUaGW?ahY75YUwb3rW&u)XF2te##BpmMr&&IMe=2}K%b}Z2%?I=wq8|#~=v^7=75^`J~ zt7)l?xyR%io1$?)Js~#FPfe{!b+#K`lEcC9lC*sVQ!v@!RyVvv+)hlI6MCSAvUn_- zC{5H0Gj3ig)?90*$)s4bFf*FR$5Jwl|{A;No07M2tj5BD4Ect`;zXr?1GB%&#nN+-vfm~Z0}iI#-Z zp6p0cySpeU#ZSK6ldN-XVsd(Gv?k_vXe%oo#k9&=8JZ%$M*34cYUoJeGx}3JYGg4@ zNjHR6d4fPkP13H9_Bu~<2_eat8#UZsapd-@J`#<`%`7yziVCOQQ;yY}YMFarz3JL( zuk$wGK%VNYxjR5LZK=jMOR+wdO4|5k8z^nnNJcfymX1U!rj8fY;hac_PmeR+Yf|%DV-flhX{wHJ z?judnd6~}In$gx|a-I;KuFS1QL$|3W6041-qE32LC9Ig?Fc}#mDPtC3c`TV4+g9h; z6KzdnW{T1uQ%?IstLH`KLKKZRUJ&!Q%$+uV)?QLli8cwZv05rWa5-{V+7mix43ocX zW-O6xY-!exWjd0JmfMAnnp8_ub*8q-TuCxHbHQ{PS!A+`GIi&{I$aKdHcZYAh z&3vu(7behVQf}i8a6?SU^A7MV8DVL7@_?1h7xa5?Z)uQK|_3N!qT1ng8Yoo*7qn3$f zm$U>i8!pToGLy@6l~pd4Mmvq3Za7#2)kzj{Dw;?+b+MMZ%pSETH4K>M#bdq^mjUw( ziZoIv-{_`8k7v^((@VBLnbnyS9rv~MK!kQjK%&)w88gd?Ostw5Ik3+oV8|}dY*x~% zlciHtBS&^2aTJNvH?=goU6+_7ZtP|;!--RvoZpm+R)YjigAF#<5KT5Xg|+jWF>x`K za0=_2+X~xt9eEET=n}DbR0?RWHJ)+`Wvl2Ek}a%nA(LWn<`lBR8!>7zQKT@|pqtBv zTAbJ;x2U_zbW7O0l#;q|VA>UJYOEn65vhn8i1 zg~bL{SUtCWCh?fkq=EeHiCDe;E-Q`PFGva?atmt94-k@cJR&+<-hQ`X(SFw@wvxn$ zHx1!HzM-<^?e`rP?@(i__oF@!j`w+i&z|^3aj_` zf0mXXs}AjVB^K>>YEr%*;x8YClQH2o)J^vJgQ9&eYSs3<*<*P;<1KH$Q?Y1%%W925 z9iBAS`^wwzS<>h!D_PrqOSh8As9zgwzkj*U-q>mZTfZ&e?km6jZC7KF+uMxf@Bf8p zrG9&U>~}O4TV>EaDeTEh_7cQf-hOXmQGQ$RP5%1j_T5|FewSnMPW=wh`tXn6S`b{=nA-r{Z8NENoc{my6gw=6G@i=-x7-oBSzr{x!-^CtiKdkm$_U0dFM zHx%yWw!cVEjMdwATR-xC&|AO#zNk_@Bq7On7v5ycTYL_sx4dJYi-p{Ro4tj}hkeJ} zI2K>FefGK3W-V{)vBiANapbX-a7ETY(U$My<865_+PUjB-C^rL)MdX5tU_#iZ2Xn= lVcYInx_@^a>()@oa~f~5^;%Bam#uu!5ZCfeYpCci|G#@~rMmzC literal 0 HcmV?d00001 diff --git a/third_party/securec/CMakeLists.txt b/third_party/securec/CMakeLists.txt new file mode 100644 index 00000000..e360a6eb --- /dev/null +++ b/third_party/securec/CMakeLists.txt @@ -0,0 +1,11 @@ +SET(CMAKE_BUILD_TYPE "Debug") +SET(CMAKE_C_FLAGS_DEBUG "$ENV{CFLAGS} -fPIC -O0 -Wall -Wno-deprecated-declarations -g2 -ggdb -fno-inline-functions -fno-omit-frame-pointer -D_LIBCPP_INLINE_VISIBILITY='' -D'_LIBCPP_EXTERN_TEMPLATE(...)='") +SET(CMAKE_C_FLAGS_RELEASE "$ENV{CFLAGS} -fPIC -O3 -Wall -Wno-deprecated-declarations") +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +#add flags +set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -I/usr/local/include -Werror") + + +include_directories(./include) +add_subdirectory(src) diff --git a/third_party/securec/include/securec.h b/third_party/securec/include/securec.h new file mode 100644 index 00000000..b627a3c3 --- /dev/null +++ b/third_party/securec/include/securec.h @@ -0,0 +1,634 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __SECUREC_H__5D13A042_DC3F_4ED9_A8D1_882811274C27 +#define __SECUREC_H__5D13A042_DC3F_4ED9_A8D1_882811274C27 + +#include "securectype.h" +#include + +#ifndef SECUREC_HAVE_ERRNO_H +#if SECUREC_IN_KERNEL +#define SECUREC_HAVE_ERRNO_H 0 +#else +#define SECUREC_HAVE_ERRNO_H 1 +#endif +#endif + +/* EINVAL ERANGE may defined in errno.h */ +#if SECUREC_HAVE_ERRNO_H +#include +#endif + +/* define error code */ +#if defined(SECUREC_NEED_ERRNO_TYPE) || !defined(__STDC_WANT_LIB_EXT1__) || \ + (defined(__STDC_WANT_LIB_EXT1__) && (__STDC_WANT_LIB_EXT1__ == 0)) +#ifndef SECUREC_DEFINED_ERRNO_TYPE +#define SECUREC_DEFINED_ERRNO_TYPE +/* just check whether macrodefinition exists. */ +#ifndef errno_t +typedef int errno_t; +#endif +#endif +#endif + +/* success */ +#ifndef EOK +#define EOK 0 +#endif + +#ifndef EINVAL +/* The src buffer is not correct and destination buffer cant not be reset */ +#define EINVAL 22 +#endif + +#ifndef EINVAL_AND_RESET +/* Once the error is detected, the dest buffer must be reseted! */ +#define EINVAL_AND_RESET (22 | 128) +#endif + +#ifndef ERANGE +/* The destination buffer is not long enough and destination buffer can not be reset */ +#define ERANGE 34 +#endif + +#ifndef ERANGE_AND_RESET +/* Once the error is detected, the dest buffer must be reseted! */ +#define ERANGE_AND_RESET (34 | 128) +#endif + +#ifndef EOVERLAP_AND_RESET +/* Once the buffer overlap is detected, the dest buffer must be reseted! */ +#define EOVERLAP_AND_RESET (54 | 128) +#endif + +/* if you need export the function of this library in Win32 dll, use __declspec(dllexport) */ +#ifndef SECUREC_API +#if defined(SECUREC_DLL_EXPORT) +#define SECUREC_API __declspec(dllexport) +#elif defined(SECUREC_DLL_IMPORT) +#define SECUREC_API __declspec(dllimport) +#else +/* Standardized function declaration . If a security function is declared in the your code, + * it may cause a compilation alarm,Please delete the security function you declared + * Adding extern under windows will cause the system to have inline functions to expand, + * so do not add the extern in default + */ +#if defined(_MSC_VER) +#define SECUREC_API +#else +#define SECUREC_API extern +#endif +#endif +#endif + +#ifdef __cplusplus +extern "C" { +#endif + /* + * Description: The GetHwSecureCVersion function get SecureC Version string and version number. + * Parameter: verNumber - to store version number + * Return: version string + */ + SECUREC_API const char *GetHwSecureCVersion(unsigned short *verNumber); + +#if SECUREC_ENABLE_MEMSET + /* + * Description: The memset_s function copies the value of c (converted to an unsigned char) into each of + * the first count characters of the object pointed to by dest. + * Parameter: dest - destination address + * Parameter: destMax -The maximum length of destination buffer + * Parameter: c - the value to be copied + * Parameter: count -copies fisrt count characters of dest + * Return: EOK if there was no runtime-constraint violation + */ + SECUREC_API errno_t memset_s(void *dest, size_t destMax, int c, size_t count); +#endif + +#ifndef SECUREC_ONLY_DECLARE_MEMSET +#define SECUREC_ONLY_DECLARE_MEMSET 0 +#endif + +#if SECUREC_ONLY_DECLARE_MEMSET == 0 + +#if SECUREC_ENABLE_MEMMOVE + /* + * Description: The memmove_s function copies n characters from the object pointed to by src + * into the object pointed to by dest. + * Parameter: dest - destination address + * Parameter: destMax -The maximum length of destination buffer + * Parameter: src -source address + * Parameter: count -copies count wide characters from the src + * Return: EOK if there was no runtime-constraint violation + */ + SECUREC_API errno_t memmove_s(void *dest, size_t destMax, const void *src, size_t count); +#endif + +#if SECUREC_ENABLE_MEMCPY + /* + * Description: The memcpy_s function copies n characters from the object pointed to + * by src into the object pointed to by dest. + * Parameter: dest - destination address + * Parameter: destMax -The maximum length of destination buffer + * Parameter: src -source address + * Parameter: count -copies count characters from the src + * Return: EOK if there was no runtime-constraint violation + */ + SECUREC_API errno_t memcpy_s(void *dest, size_t destMax, const void *src, size_t count); +#endif + +#if SECUREC_ENABLE_STRCPY + /* + * Description: The strcpy_s function copies the string pointed to by strSrc (including + * the terminating null character) into the array pointed to by strDest + * Parameter: strDest - destination address + * Parameter: destMax -The maximum length of destination buffer(including the terminating null character) + * Parameter: strSrc -source address + * Return: EOK if there was no runtime-constraint violation + */ + SECUREC_API errno_t strcpy_s(char *strDest, size_t destMax, const char *strSrc); +#endif + +#if SECUREC_ENABLE_STRNCPY + /* + * Description: The strncpy_s function copies not more than n successive characters (not including + * the terminating null character) + * from the array pointed to by strSrc to the array pointed to by strDest + * Parameter: strDest - destination address + * Parameter: destMax -The maximum length of destination buffer(including the terminating null character) + * Parameter: strSrc -source address + * Parameter: count -copies count characters from the src + * Return: EOK if there was no runtime-constraint violation + */ + SECUREC_API errno_t strncpy_s(char *strDest, size_t destMax, const char *strSrc, size_t count); +#endif + +#if SECUREC_ENABLE_STRCAT + /* + * Description: The strcat_s function appends a copy of the string pointed to by strSrc (including + * the terminating null character) + * to the end of the string pointed to by strDest + * Parameter: strDest - destination address + * Parameter: destMax -The maximum length of destination buffer(including the terminating null wide character) + * Parameter: strSrc -source address + * Return: EOK if there was no runtime-constraint violation + */ + SECUREC_API errno_t strcat_s(char *strDest, size_t destMax, const char *strSrc); +#endif + +#if SECUREC_ENABLE_STRNCAT + /* + * Description: The strncat_s function appends not more than n successive characters (not including + * the terminating null character) + * from the array pointed to by strSrc to the end of the string pointed to by strDest. + * Parameter: strDest - destination address + * Parameter: destMax -The maximum length of destination buffer(including the terminating null character) + * Parameter: strSrc -source address + * Parameter: count -copies count characters from the src + * Return: EOK if there was no runtime-constraint violation + */ + SECUREC_API errno_t strncat_s(char *strDest, size_t destMax, const char *strSrc, size_t count); +#endif + +#if SECUREC_ENABLE_VSPRINTF + /* + * Description: The vsprintf_s function is equivalent to the vsprintf function except for the Parameter: destMax + * and the explicit runtime-constraints violation + * Parameter: strDest - produce output according to a format ,write to the character string strDest + * Parameter: destMax - The maximum length of destination buffer(including the terminating null wide characte) + * Parameter: format - fromat string + * Parameter: argList - instead of a variable number of arguments + * Return: the number of characters printed(not including the terminating null byte ('\0')), + * If an error occurred Return: -1. + */ + SECUREC_API int vsprintf_s(char *strDest, size_t destMax, const char *format, + va_list argList) SECUREC_ATTRIBUTE(3, 0); +#endif + +#if SECUREC_ENABLE_SPRINTF + /* + * Description: The sprintf_s function is equivalent to the sprintf function except for the Parameter: destMax + * and the explicit runtime-constraints violation + * Parameter: strDest - produce output according to a format ,write to the character string strDest + * Parameter: destMax - The maximum length of destination buffer(including the terminating null byte ('\0')) + * Parameter: format - fromat string + * Return: the number of characters printed(not including the terminating null byte ('\0')), + * If an error occurred Return: -1. + */ + SECUREC_API int sprintf_s(char *strDest, size_t destMax, const char *format, ...) SECUREC_ATTRIBUTE(3, 4); +#endif + +#if SECUREC_ENABLE_VSNPRINTF + /* + * Description: The vsnprintf_s function is equivalent to the vsnprintf function except for the Parameter: + * destMax/count and the explicit runtime-constraints violation + * Parameter: strDest - produce output according to a format ,write to the character string strDest + * Parameter: destMax - The maximum length of destination buffer(including the terminating null byte ('\0')) + * Parameter: count - do not write more than count bytes to strDest(not including the terminating null byte ('\0')) + * Parameter: format - fromat string + * Parameter: argList - instead of a variable number of arguments + * Return: the number of characters printed(not including the terminating null byte ('\0')), + * If an error occurred Return: -1.Pay special attention to returning -1 when truncation occurs + */ + SECUREC_API int vsnprintf_s(char *strDest, size_t destMax, size_t count, const char *format, + va_list argList) SECUREC_ATTRIBUTE(4, 0); +#endif + +#if SECUREC_ENABLE_SNPRINTF + /* + * Description: The snprintf_s function is equivalent to the snprintf function except for the Parameter: + * destMax/count and the explicit runtime-constraints violation + * Parameter: strDest - produce output according to a format ,write to the character string strDest + * Parameter: destMax - The maximum length of destination buffer(including the terminating null byte ('\0')) + * Parameter: count - do not write more than count bytes to strDest(not including the terminating null byte ('\0')) + * Parameter: format - fromat string + * Return: the number of characters printed(not including the terminating null byte ('\0')), + * If an error occurred Return: -1.Pay special attention to returning -1 when truncation occurs + */ + SECUREC_API int snprintf_s(char *strDest, size_t destMax, size_t count, const char *format, + ...) SECUREC_ATTRIBUTE(4, 5); +#endif + +#if SECUREC_SNPRINTF_TRUNCATED + /* + * Description: The vsnprintf_truncated_s function is equivalent to the vsnprintf_s function except + * no count Parameter: and Return: value + * Parameter: strDest - produce output according to a format ,write to the character string strDest + * Parameter: destMax - The maximum length of destination buffer(including the terminating null byte ('\0')) + * Parameter: format - fromat string + * Parameter: argList - instead of a variable number of arguments + * Return: the number of characters printed(not including the terminating null byte ('\0')), + * If an error occurred Return: -1.Pay special attention to returning destMax - 1 when truncation occurs + */ + SECUREC_API int vsnprintf_truncated_s(char *strDest, size_t destMax, const char *format, + va_list argList) SECUREC_ATTRIBUTE(3, 0); + + /* + * Description: The snprintf_truncated_s function is equivalent to the snprintf_2 function except + * no count Parameter: and Return: value + * Parameter: strDest - produce output according to a format ,write to the character string strDest + * Parameter: destMax - The maximum length of destination buffer(including the terminating null byte ('\0')) + * Parameter: format - fromat string + * Return: the number of characters printed(not including the terminating null byte ('\0')), + * If an error occurred Return: -1.Pay special attention to returning destMax - 1 when truncation occurs + */ + SECUREC_API int snprintf_truncated_s(char *strDest, size_t destMax, + const char *format, ...) SECUREC_ATTRIBUTE(3, 4); +#endif + +#if SECUREC_ENABLE_SCANF + /* + * Description: The scanf_s function is equivalent to fscanf_s with the argument stdin + * interposed before the arguments to scanf_s + * Parameter: format - fromat string + * Return: the number of input items assigned, If an error occurred Return: -1. + */ + SECUREC_API int scanf_s(const char *format, ...); +#endif + +#if SECUREC_ENABLE_VSCANF + /* + * Description: The vscanf_s function is equivalent to scanf_s, with the variable argument list replaced by argList + * Parameter: format - fromat string + * Parameter: argList - instead of a variable number of arguments + * Return: the number of input items assigned, If an error occurred Return: -1. + */ + SECUREC_API int vscanf_s(const char *format, va_list argList); +#endif + +#if SECUREC_ENABLE_SSCANF + /* + * Description: The sscanf_s function is equivalent to fscanf_s, except that input is obtained from a + * string (specified by the argument buffer) rather than from a stream + * Parameter: buffer - read character from buffer + * Parameter: format - fromat string + * Return: the number of input items assigned, If an error occurred Return: -1. + */ + SECUREC_API int sscanf_s(const char *buffer, const char *format, ...); +#endif + +#if SECUREC_ENABLE_VSSCANF + /* + * Description: The vsscanf_s function is equivalent to sscanf_s, with the variable argument list + * replaced by argList + * Parameter: buffer - read character from buffer + * Parameter: format - fromat string + * Parameter: argList - instead of a variable number of arguments + * Return: the number of input items assigned, If an error occurred Return: -1. + */ + SECUREC_API int vsscanf_s(const char *buffer, const char *format, va_list argList); +#endif + +#if SECUREC_ENABLE_FSCANF + /* + * Description: The fscanf_s function is equivalent to fscanf except that the c, s, and [ conversion specifiers + * apply to a pair of arguments (unless assignment suppression is indicated by a*) + * Parameter: stream - stdio file stream + * Parameter: format - fromat string + * Return: the number of input items assigned, If an error occurred Return: -1. + */ + SECUREC_API int fscanf_s(FILE *stream, const char *format, ...); +#endif + +#if SECUREC_ENABLE_VFSCANF + /* + * Description: The vfscanf_s function is equivalent to fscanf_s, with the variable argument list + * replaced by argList + * Parameter: stream - stdio file stream + * Parameter: format - fromat string + * Parameter: argList - instead of a variable number of arguments + * Return: the number of input items assigned, If an error occurred Return: -1. + */ + SECUREC_API int vfscanf_s(FILE *stream, const char *format, va_list argList); +#endif + +#if SECUREC_ENABLE_STRTOK + /* + * Description: The strtok_s function parses a string into a sequence of strToken, + * replace all characters in strToken string that match to strDelimit set with 0. + * On the first call to strtok_s the string to be parsed should be specified in strToken. + * In each subsequent call that should parse the same string, strToken should be NULL + * Parameter: strToken - the string to be delimited + * Parameter: strDelimit -specifies a set of characters that delimit the tokens in the parsed string + * Parameter: context -is a pointer to a char * variable that is used internally by strtok_s function + * Return: On the first call returns the address of the first non \0 character, otherwise NULL is returned. + * In subsequent calls, the strtoken is set to NULL, and the context set is the same as the previous call, + * return NULL if the *context string length is equal 0, otherwise return *context. + */ + SECUREC_API char *strtok_s(char *strToken, const char *strDelimit, char **context); +#endif + +#if SECUREC_ENABLE_GETS && SECUREC_IN_KERNEL == 0 + /* + * Description: The gets_s function reads at most one less than the number of characters specified + * by destMax from the stream pointed to by stdin, into the array pointed to by buffer + * Parameter: buffer - destination address + * Parameter: destMax -The maximum length of destination buffer(including the terminating null character) + * Return: buffer if there was no runtime-constraint violation,If an error occurred Return: NULL. + */ + SECUREC_API char *gets_s(char *buffer, size_t destMax); +#endif + + +#if SECUREC_ENABLE_WCHAR_FUNC +#if SECUREC_ENABLE_MEMCPY + /* + * Description: The wmemcpy_s function copies n successive wide characters from the object pointed to + * by src into the object pointed to by dest. + * Parameter: dest - destination address + * Parameter: destMax -The maximum length of destination buffer + * Parameter: src -source address + * Parameter: count -copies count wide characters from the src + * Return: EOK if there was no runtime-constraint violation + */ + SECUREC_API errno_t wmemcpy_s(wchar_t *dest, size_t destMax, const wchar_t *src, size_t count); +#endif + +#if SECUREC_ENABLE_MEMMOVE + /* + * Description: The wmemmove_s function copies n successive wide characters from the object + * pointed to by src into the object pointed to by dest. + * Parameter: dest - destination address + * Parameter: destMax -The maximum length of destination buffer + * Parameter: src -source address + * Parameter: count -copies count wide characters from the src + * Return: EOK if there was no runtime-constraint violation + */ + SECUREC_API errno_t wmemmove_s(wchar_t *dest, size_t destMax, const wchar_t *src, size_t count); +#endif + +#if SECUREC_ENABLE_STRCPY + /* + * Description: The wcscpy_s function copies the wide string pointed to by strSrc (including theterminating + * null wide character) into the array pointed to by strDest + * Parameter: strDest - destination address + * Parameter: destMax -The maximum length of destination buffer + * Parameter: strSrc -source address + * Return: EOK if there was no runtime-constraint violation + */ + SECUREC_API errno_t wcscpy_s(wchar_t *strDest, size_t destMax, const wchar_t *strSrc); +#endif + +#if SECUREC_ENABLE_STRNCPY + /* + * Description: The wcsncpy_s function copies not more than n successive wide characters (not including the + * terminating null wide character) from the array pointed to by strSrc to the array pointed to by strDest + * Parameter: strDest - destination address + * Parameter: destMax -The maximum length of destination buffer(including the terminating wide character) + * Parameter: strSrc -source address + * Parameter: count -copies count wide characters from the src + * Return: EOK if there was no runtime-constraint violation + */ + SECUREC_API errno_t wcsncpy_s(wchar_t *strDest, size_t destMax, const wchar_t *strSrc, size_t count); +#endif + +#if SECUREC_ENABLE_STRCAT + /* + * Description: The wcscat_s function appends a copy of the wide string pointed to by strSrc (including the + * terminating null wide character) to the end of the wide string pointed to by strDest + * Parameter: strDest - destination address + * Parameter: destMax -The maximum length of destination buffer(including the terminating wide character) + * Parameter: strSrc -source address + * Return: EOK if there was no runtime-constraint violation + */ + SECUREC_API errno_t wcscat_s(wchar_t *strDest, size_t destMax, const wchar_t *strSrc); +#endif + +#if SECUREC_ENABLE_STRNCAT + /* + * Description: The wcsncat_s function appends not more than n successive wide characters (not including the + * terminating null wide character) from the array pointed to by strSrc to the end of the wide string pointed to + * by strDest. + * Parameter: strDest - destination address + * Parameter: destMax -The maximum length of destination buffer(including the terminating wide character) + * Parameter: strSrc -source address + * Parameter: count -copies count wide characters from the src + * Return: EOK if there was no runtime-constraint violation + */ + SECUREC_API errno_t wcsncat_s(wchar_t *strDest, size_t destMax, const wchar_t *strSrc, size_t count); +#endif + +#if SECUREC_ENABLE_STRTOK + /* + * Description: The wcstok_s function is the wide-character equivalent of the strtok_s function + * Parameter: strToken - the string to be delimited + * Parameter: strDelimit -specifies a set of characters that delimit the tokens in the parsed string + * Parameter: context -is a pointer to a char * variable that is used internally by strtok_s function + * Return: a pointer to the first character of a token, or a null pointer if there is no token + * or there is a runtime-constraint violation. + */ + SECUREC_API wchar_t *wcstok_s(wchar_t *strToken, const wchar_t *strDelimit, wchar_t **context); +#endif + +#if SECUREC_ENABLE_VSPRINTF + /* + * Description: The vswprintf_s function is the wide-character equivalent of the vsprintf_s function + * Parameter: strDest - produce output according to a format ,write to the character string strDest + * Parameter: destMax - The maximum length of destination buffer(including the terminating null ) + * Parameter: format - fromat string + * Parameter: argList - instead of a variable number of arguments + * Return: the number of characters printed(not including the terminating null wide characte), + * If an error occurred Return: -1. + */ + SECUREC_API int vswprintf_s(wchar_t *strDest, size_t destMax, const wchar_t *format, va_list argList); +#endif + +#if SECUREC_ENABLE_SPRINTF + + /* + * Description: The swprintf_s function is the wide-character equivalent of the sprintf_s function + * Parameter: strDest - produce output according to a format ,write to the character string strDest + * Parameter: destMax - The maximum length of destination buffer(including the terminating null ) + * Parameter: format - fromat string + * Return: the number of characters printed(not including the terminating null wide characte), + * If an error occurred Return: -1. + */ + SECUREC_API int swprintf_s(wchar_t *strDest, size_t destMax, const wchar_t *format, ...); +#endif + +#if SECUREC_ENABLE_FSCANF + /* + * Description: The fwscanf_s function is the wide-character equivalent of the fscanf_s function + * Parameter: stream - stdio file stream + * Parameter: format - fromat string + * Return: the number of input items assigned, If an error occurred Return: -1. + */ + SECUREC_API int fwscanf_s(FILE *stream, const wchar_t *format, ...); +#endif + +#if SECUREC_ENABLE_VFSCANF + /* + * Description: The vfwscanf_s function is the wide-character equivalent of the vfscanf_s function + * Parameter: stream - stdio file stream + * Parameter: format - fromat string + * Parameter: argList - instead of a variable number of arguments + * Return: the number of input items assigned, If an error occurred Return: -1. + */ + SECUREC_API int vfwscanf_s(FILE *stream, const wchar_t *format, va_list argList); +#endif + +#if SECUREC_ENABLE_SCANF + /* + * Description: The wscanf_s function is the wide-character equivalent of the scanf_s function + * Parameter: format - fromat string + * Return: the number of input items assigned, If an error occurred Return: -1. + */ + SECUREC_API int wscanf_s(const wchar_t *format, ...); +#endif + +#if SECUREC_ENABLE_VSCANF + /* + * Description: The vwscanf_s function is the wide-character equivalent of the vscanf_s function + * Parameter: format - fromat string + * Parameter: argList - instead of a variable number of arguments + * Return: the number of input items assigned, If an error occurred Return: -1. + */ + SECUREC_API int vwscanf_s(const wchar_t *format, va_list argList); +#endif + +#if SECUREC_ENABLE_SSCANF + /* + * Description: The swscanf_s function is the wide-character equivalent of the sscanf_s function + * Parameter: buffer - read character from buffer + * Parameter: format - fromat string + * Return: the number of input items assigned, If an error occurred Return: -1. + */ + SECUREC_API int swscanf_s(const wchar_t *buffer, const wchar_t *format, ...); +#endif + +#if SECUREC_ENABLE_VSSCANF + /* + * Description: The vswscanf_s function is the wide-character equivalent of the vsscanf_s function + * Parameter: buffer - read character from buffer + * Parameter: format - fromat string + * Parameter: argList - instead of a variable number of arguments + * Return: the number of input items assigned, If an error occurred Return: -1. + */ + SECUREC_API int vswscanf_s(const wchar_t *buffer, const wchar_t *format, va_list argList); +#endif +#endif /* SECUREC_ENABLE_WCHAR_FUNC */ +#endif + + /* those functions are used by macro ,must declare hare , also for without function declaration warning */ + extern errno_t strncpy_error(char *strDest, size_t destMax, const char *strSrc, size_t count); + extern errno_t strcpy_error(char *strDest, size_t destMax, const char *strSrc); + +#if SECUREC_WITH_PERFORMANCE_ADDONS + /* those functions are used by macro */ + extern errno_t memset_sOptAsm(void *dest, size_t destMax, int c, size_t count); + extern errno_t memset_sOptTc(void *dest, size_t destMax, int c, size_t count); + extern errno_t memcpy_sOptAsm(void *dest, size_t destMax, const void *src, size_t count); + extern errno_t memcpy_sOptTc(void *dest, size_t destMax, const void *src, size_t count); + +/* strcpy_sp is a macro, NOT a function in performance optimization mode. */ +#define strcpy_sp(dest, destMax, src) ((__builtin_constant_p((destMax)) && \ + __builtin_constant_p((src))) ? \ + SECUREC_STRCPY_SM((dest), (destMax), (src)) : \ + strcpy_s((dest), (destMax), (src))) + +/* strncpy_sp is a macro, NOT a function in performance optimization mode. */ +#define strncpy_sp(dest, destMax, src, count) ((__builtin_constant_p((count)) && \ + __builtin_constant_p((destMax)) && \ + __builtin_constant_p((src))) ? \ + SECUREC_STRNCPY_SM((dest), (destMax), (src), (count)) : \ + strncpy_s((dest), (destMax), (src), (count))) + +/* strcat_sp is a macro, NOT a function in performance optimization mode. */ +#define strcat_sp(dest, destMax, src) ((__builtin_constant_p((destMax)) && \ + __builtin_constant_p((src))) ? \ + SECUREC_STRCAT_SM((dest), (destMax), (src)) : \ + strcat_s((dest), (destMax), (src))) + +/* strncat_sp is a macro, NOT a function in performance optimization mode. */ +#define strncat_sp(dest, destMax, src, count) ((__builtin_constant_p((count)) && \ + __builtin_constant_p((destMax)) && \ + __builtin_constant_p((src))) ? \ + SECUREC_STRNCAT_SM((dest), (destMax), (src), (count)) : \ + strncat_s((dest), (destMax), (src), (count))) + +/* memcpy_sp is a macro, NOT a function in performance optimization mode. */ +#define memcpy_sp(dest, destMax, src, count) (__builtin_constant_p((count)) ? \ + (SECUREC_MEMCPY_SM((dest), (destMax), (src), (count))) : \ + (__builtin_constant_p((destMax)) ? \ + (((size_t)(destMax) > 0 && \ + (((unsigned long long)(destMax) & \ + (unsigned long long)(-2)) < SECUREC_MEM_MAX_LEN)) ? \ + memcpy_sOptTc((dest), (destMax), (src), (count)) : ERANGE) : \ + memcpy_sOptAsm((dest), (destMax), (src), (count)))) + +/* memset_sp is a macro, NOT a function in performance optimization mode. */ +#define memset_sp(dest, destMax, c, count) (__builtin_constant_p((count)) ? \ + (SECUREC_MEMSET_SM((dest), (destMax), (c), (count))) : \ + (__builtin_constant_p((destMax)) ? \ + (((size_t)(destMax) > 0 && \ + (((unsigned long long)(destMax) & \ + (unsigned long long)(-2)) < SECUREC_MEM_MAX_LEN)) ? \ + memset_sOptTc((dest), (destMax), (c), (count)) : ERANGE) : \ + memset_sOptAsm((dest), (destMax), (c), (count)))) +#else +#define strcpy_sp strcpy_s +#define strncpy_sp strncpy_s +#define strcat_sp strcat_s +#define strncat_sp strncat_s +#define memcpy_sp memcpy_s +#define memset_sp memset_s +#endif + +#ifdef __cplusplus +} +#endif /* __cplusplus */ +#endif /* __SECUREC_H__5D13A042_DC3F_4ED9_A8D1_882811274C27 */ + diff --git a/third_party/securec/include/securectype.h b/third_party/securec/include/securectype.h new file mode 100644 index 00000000..0aed2a67 --- /dev/null +++ b/third_party/securec/include/securectype.h @@ -0,0 +1,542 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __SECURECTYPE_H__A7BBB686_AADA_451B_B9F9_44DACDAE18A7 +#define __SECURECTYPE_H__A7BBB686_AADA_451B_B9F9_44DACDAE18A7 + +#ifndef SECUREC_USING_STD_SECURE_LIB +#if defined(_MSC_VER) && _MSC_VER >= 1400 +#if defined(__STDC_WANT_SECURE_LIB__) && __STDC_WANT_SECURE_LIB__ == 0 +/* Security functions have been provided since vs2005, default use of system library functions */ +#define SECUREC_USING_STD_SECURE_LIB 0 +#else +#define SECUREC_USING_STD_SECURE_LIB 1 +#endif +#else +#define SECUREC_USING_STD_SECURE_LIB 0 +#endif +#endif + + +/* Compatibility with older Secure C versions, shielding VC symbol redefinition warning */ +#if defined(_MSC_VER) && _MSC_VER >= 1400 && SECUREC_USING_STD_SECURE_LIB == 0 +#ifndef SECUREC_DISABLE_CRT_FUNC +#define SECUREC_DISABLE_CRT_FUNC 1 +#endif +#ifndef SECUREC_DISABLE_CRT_IMP +#define SECUREC_DISABLE_CRT_IMP 1 +#endif +#else /* MSC VER */ +#ifndef SECUREC_DISABLE_CRT_FUNC +#define SECUREC_DISABLE_CRT_FUNC 0 +#endif +#ifndef SECUREC_DISABLE_CRT_IMP +#define SECUREC_DISABLE_CRT_IMP 0 +#endif +#endif + +#if SECUREC_DISABLE_CRT_FUNC +#ifdef __STDC_WANT_SECURE_LIB__ +#undef __STDC_WANT_SECURE_LIB__ +#endif +#define __STDC_WANT_SECURE_LIB__ 0 +#endif + +#if SECUREC_DISABLE_CRT_IMP +#ifdef _CRTIMP_ALTERNATIVE +#undef _CRTIMP_ALTERNATIVE +#endif +#define _CRTIMP_ALTERNATIVE /* comment microsoft *_s function */ +#endif + +/* Compile in kernel under macro control */ +#ifndef SECUREC_IN_KERNEL +#ifdef __KERNEL__ +#define SECUREC_IN_KERNEL 1 +#else +#define SECUREC_IN_KERNEL 0 +#endif +#endif + +#if SECUREC_IN_KERNEL +#ifndef SECUREC_ENABLE_SCANF_FILE +#define SECUREC_ENABLE_SCANF_FILE 0 +#endif +#ifndef SECUREC_ENABLE_WCHAR_FUNC +#define SECUREC_ENABLE_WCHAR_FUNC 0 +#endif +#else /* SECUREC_IN_KERNEL */ +#ifndef SECUREC_ENABLE_SCANF_FILE +#define SECUREC_ENABLE_SCANF_FILE 1 +#endif +#ifndef SECUREC_ENABLE_WCHAR_FUNC +#define SECUREC_ENABLE_WCHAR_FUNC 1 +#endif +#endif + + +/* Default secure function declaration, default declarations for non-standard functions */ +#ifndef SECUREC_SNPRINTF_TRUNCATED +#define SECUREC_SNPRINTF_TRUNCATED 1 +#endif + +#if SECUREC_USING_STD_SECURE_LIB +#if defined(_MSC_VER) && _MSC_VER >= 1400 +/* Declare secure functions that are not available in the vs compiler */ +#ifndef SECUREC_ENABLE_MEMSET +#define SECUREC_ENABLE_MEMSET 1 +#endif +/* vs 2005 have vsnprintf_s function */ +#ifndef SECUREC_ENABLE_VSNPRINTF +#define SECUREC_ENABLE_VSNPRINTF 0 +#endif +#ifndef SECUREC_ENABLE_SNPRINTF +/* vs 2005 have vsnprintf_s function Adapt the snprintf_s of the security function */ +#define snprintf_s _snprintf_s +#define SECUREC_ENABLE_SNPRINTF 0 +#endif +/* befor vs 2010 do not have v functions */ +#if _MSC_VER <= 1600 || defined(SECUREC_FOR_V_SCANFS) +#ifndef SECUREC_ENABLE_VFSCANF +#define SECUREC_ENABLE_VFSCANF 1 +#endif +#ifndef SECUREC_ENABLE_VSCANF +#define SECUREC_ENABLE_VSCANF 1 +#endif +#ifndef SECUREC_ENABLE_VSSCANF +#define SECUREC_ENABLE_VSSCANF 1 +#endif +#endif + +#else /* _MSC_VER */ +#ifndef SECUREC_ENABLE_MEMSET +#define SECUREC_ENABLE_MEMSET 0 +#endif +#ifndef SECUREC_ENABLE_SNPRINTF +#define SECUREC_ENABLE_SNPRINTF 0 +#endif +#ifndef SECUREC_ENABLE_VSNPRINTF +#define SECUREC_ENABLE_VSNPRINTF 0 +#endif +#endif + +#ifndef SECUREC_ENABLE_MEMMOVE +#define SECUREC_ENABLE_MEMMOVE 0 +#endif +#ifndef SECUREC_ENABLE_MEMCPY +#define SECUREC_ENABLE_MEMCPY 0 +#endif +#ifndef SECUREC_ENABLE_STRCPY +#define SECUREC_ENABLE_STRCPY 0 +#endif +#ifndef SECUREC_ENABLE_STRNCPY +#define SECUREC_ENABLE_STRNCPY 0 +#endif +#ifndef SECUREC_ENABLE_STRCAT +#define SECUREC_ENABLE_STRCAT 0 +#endif +#ifndef SECUREC_ENABLE_STRNCAT +#define SECUREC_ENABLE_STRNCAT 0 +#endif +#ifndef SECUREC_ENABLE_SPRINTF +#define SECUREC_ENABLE_SPRINTF 0 +#endif +#ifndef SECUREC_ENABLE_VSPRINTF +#define SECUREC_ENABLE_VSPRINTF 0 +#endif +#ifndef SECUREC_ENABLE_SSCANF +#define SECUREC_ENABLE_SSCANF 0 +#endif +#ifndef SECUREC_ENABLE_VSSCANF +#define SECUREC_ENABLE_VSSCANF 0 +#endif +#ifndef SECUREC_ENABLE_SCANF +#define SECUREC_ENABLE_SCANF 0 +#endif +#ifndef SECUREC_ENABLE_VSCANF +#define SECUREC_ENABLE_VSCANF 0 +#endif + +#ifndef SECUREC_ENABLE_FSCANF +#define SECUREC_ENABLE_FSCANF 0 +#endif +#ifndef SECUREC_ENABLE_VFSCANF +#define SECUREC_ENABLE_VFSCANF 0 +#endif +#ifndef SECUREC_ENABLE_STRTOK +#define SECUREC_ENABLE_STRTOK 0 +#endif +#ifndef SECUREC_ENABLE_GETS +#define SECUREC_ENABLE_GETS 0 +#endif + +#else /* SECUREC_USE_STD_SECURE_LIB */ + +#ifndef SECUREC_ENABLE_MEMSET +#define SECUREC_ENABLE_MEMSET 1 +#endif +#ifndef SECUREC_ENABLE_MEMMOVE +#define SECUREC_ENABLE_MEMMOVE 1 +#endif +#ifndef SECUREC_ENABLE_MEMCPY +#define SECUREC_ENABLE_MEMCPY 1 +#endif +#ifndef SECUREC_ENABLE_STRCPY +#define SECUREC_ENABLE_STRCPY 1 +#endif +#ifndef SECUREC_ENABLE_STRNCPY +#define SECUREC_ENABLE_STRNCPY 1 +#endif +#ifndef SECUREC_ENABLE_STRCAT +#define SECUREC_ENABLE_STRCAT 1 +#endif +#ifndef SECUREC_ENABLE_STRNCAT +#define SECUREC_ENABLE_STRNCAT 1 +#endif +#ifndef SECUREC_ENABLE_SPRINTF +#define SECUREC_ENABLE_SPRINTF 1 +#endif +#ifndef SECUREC_ENABLE_VSPRINTF +#define SECUREC_ENABLE_VSPRINTF 1 +#endif +#ifndef SECUREC_ENABLE_SNPRINTF +#define SECUREC_ENABLE_SNPRINTF 1 +#endif +#ifndef SECUREC_ENABLE_VSNPRINTF +#define SECUREC_ENABLE_VSNPRINTF 1 +#endif +#ifndef SECUREC_ENABLE_SSCANF +#define SECUREC_ENABLE_SSCANF 1 +#endif +#ifndef SECUREC_ENABLE_VSSCANF +#define SECUREC_ENABLE_VSSCANF 1 +#endif +#ifndef SECUREC_ENABLE_SCANF +#if SECUREC_ENABLE_SCANF_FILE +#define SECUREC_ENABLE_SCANF 1 +#else +#define SECUREC_ENABLE_SCANF 0 +#endif +#endif +#ifndef SECUREC_ENABLE_VSCANF +#if SECUREC_ENABLE_SCANF_FILE +#define SECUREC_ENABLE_VSCANF 1 +#else +#define SECUREC_ENABLE_VSCANF 0 +#endif +#endif + +#ifndef SECUREC_ENABLE_FSCANF +#if SECUREC_ENABLE_SCANF_FILE +#define SECUREC_ENABLE_FSCANF 1 +#else +#define SECUREC_ENABLE_FSCANF 0 +#endif +#endif +#ifndef SECUREC_ENABLE_VFSCANF +#if SECUREC_ENABLE_SCANF_FILE +#define SECUREC_ENABLE_VFSCANF 1 +#else +#define SECUREC_ENABLE_VFSCANF 0 +#endif +#endif + +#ifndef SECUREC_ENABLE_STRTOK +#define SECUREC_ENABLE_STRTOK 1 +#endif +#ifndef SECUREC_ENABLE_GETS +#define SECUREC_ENABLE_GETS 1 +#endif +#endif /* SECUREC_USE_STD_SECURE_LIB */ + +#if SECUREC_ENABLE_SCANF_FILE == 0 +#if SECUREC_ENABLE_FSCANF +#undef SECUREC_ENABLE_FSCANF +#define SECUREC_ENABLE_FSCANF 0 +#endif +#if SECUREC_ENABLE_VFSCANF +#undef SECUREC_ENABLE_VFSCANF +#define SECUREC_ENABLE_VFSCANF 0 +#endif +#if SECUREC_ENABLE_SCANF +#undef SECUREC_ENABLE_SCANF +#define SECUREC_ENABLE_SCANF 0 +#endif +#if SECUREC_ENABLE_FSCANF +#undef SECUREC_ENABLE_FSCANF +#define SECUREC_ENABLE_FSCANF 0 +#endif + +#endif + +#if SECUREC_IN_KERNEL +#include +#include +#else +#include +#include +#include +#endif + +/* If you need high performance, enable the SECUREC_WITH_PERFORMANCE_ADDONS macro, default is enable . + * The macro is automatically closed on the windows platform and linux kernel + */ +#ifndef SECUREC_WITH_PERFORMANCE_ADDONS +#if SECUREC_IN_KERNEL +#define SECUREC_WITH_PERFORMANCE_ADDONS 0 +#else +#define SECUREC_WITH_PERFORMANCE_ADDONS 1 +#endif +#endif + +/* if enable SECUREC_COMPATIBLE_WIN_FORMAT, the output format will be compatible to Windows. */ +#if (defined(_WIN32) || defined(_WIN64) || defined(_MSC_VER)) && !defined(SECUREC_COMPATIBLE_LINUX_FORMAT) +#if !defined(SECUREC_COMPATIBLE_WIN_FORMAT) +#define SECUREC_COMPATIBLE_WIN_FORMAT +#endif +#endif + +#if defined(SECUREC_COMPATIBLE_WIN_FORMAT) +/* in windows platform, can't use optimized function for there is no __builtin_constant_p like function */ +/* If need optimized macro, can define this: define __builtin_constant_p(x) 0 */ +#ifdef SECUREC_WITH_PERFORMANCE_ADDONS +#undef SECUREC_WITH_PERFORMANCE_ADDONS +#define SECUREC_WITH_PERFORMANCE_ADDONS 0 +#endif +#endif + +#if defined(__VXWORKS__) || defined(__vxworks) || defined(__VXWORKS) || defined(_VXWORKS_PLATFORM_) || \ + defined(SECUREC_VXWORKS_VERSION_5_4) +#if !defined(SECUREC_VXWORKS_PLATFORM) +#define SECUREC_VXWORKS_PLATFORM +#endif +#endif + +/* if enable SECUREC_COMPATIBLE_LINUX_FORMAT, the output format will be compatible to Linux. */ +#if !(defined(SECUREC_COMPATIBLE_WIN_FORMAT) || defined(SECUREC_VXWORKS_PLATFORM)) +#if !defined(SECUREC_COMPATIBLE_LINUX_FORMAT) +#define SECUREC_COMPATIBLE_LINUX_FORMAT +#endif +#endif + +#ifdef SECUREC_COMPATIBLE_LINUX_FORMAT +#include +#endif + +/* add the -DSECUREC_SUPPORT_FORMAT_WARNING compiler option to supoort -Wformat. + * default does not check the format is that the same data type in the actual code + * in the product is different in the original data type definition of VxWorks and Linux. + */ +#ifndef SECUREC_SUPPORT_FORMAT_WARNING +#define SECUREC_SUPPORT_FORMAT_WARNING 0 +#endif + +/* SECUREC_PCLINT for tool do not recognize __attribute__ just for pclint */ +#if SECUREC_SUPPORT_FORMAT_WARNING && !defined(SECUREC_PCLINT) +#define SECUREC_ATTRIBUTE(x, y) __attribute__((format(printf, (x), (y)))) +#else +#define SECUREC_ATTRIBUTE(x, y) +#endif + +/* SECUREC_PCLINT for tool do not recognize __builtin_expect, just for pclint */ +#if defined(__GNUC__) && \ + ((__GNUC__ > 3 || (__GNUC__ == 3 && __GNUC_MINOR__ > 3))) && \ + !defined(SECUREC_PCLINT) +/* This is a built-in function that can be used without a declaration, if you encounter an undeclared compilation alarm, + * you can add -DSECUREC_NEED_BUILTIN_EXPECT_DECLARE to complier options + */ +#if defined(SECUREC_NEED_BUILTIN_EXPECT_DECLARE) +long __builtin_expect(long exp, long c); +#endif +#define SECUREC_LIKELY(x) __builtin_expect(!!(x), 1) +#define SECUREC_UNLIKELY(x) __builtin_expect(!!(x), 0) +#else +#define SECUREC_LIKELY(x) (x) +#define SECUREC_UNLIKELY(x) (x) +#endif + +/* define the max length of the string */ +#ifndef SECUREC_STRING_MAX_LEN +#define SECUREC_STRING_MAX_LEN (0x7fffffffUL) +#endif +#define SECUREC_WCHAR_STRING_MAX_LEN (SECUREC_STRING_MAX_LEN / sizeof(wchar_t)) + +/* add SECUREC_MEM_MAX_LEN for memcpy and memmove */ +#ifndef SECUREC_MEM_MAX_LEN +#define SECUREC_MEM_MAX_LEN (0x7fffffffUL) +#endif +#define SECUREC_WCHAR_MEM_MAX_LEN (SECUREC_MEM_MAX_LEN / sizeof(wchar_t)) + +#if SECUREC_STRING_MAX_LEN > 0x7fffffff +#error "max string is 2G" +#endif + +#if (defined(__GNUC__) && defined(__SIZEOF_POINTER__)) +#if (__SIZEOF_POINTER__ != 4) && (__SIZEOF_POINTER__ != 8) +#error "unsupported system" +#endif +#endif + +#if defined(_WIN64) || defined(WIN64) || defined(__LP64__) || defined(_LP64) +#define SECUREC_ON_64BITS +#endif + +#if (!defined(SECUREC_ON_64BITS) && defined(__GNUC__) && defined(__SIZEOF_POINTER__)) +#if __SIZEOF_POINTER__ == 8 +#define SECUREC_ON_64BITS +#endif +#endif + +#if defined(__SVR4) || defined(__svr4__) +#define SECUREC_ON_SOLARIS +#endif + +#if (defined(__hpux) || defined(_AIX) || defined(SECUREC_ON_SOLARIS)) +#define SECUREC_ON_UNIX +#endif + +/* codes should run under the macro SECUREC_COMPATIBLE_LINUX_FORMAT in unknow system on default, + * and strtold. The function + * strtold is referenced first at ISO9899:1999(C99), and some old compilers can + * not support these functions. Here provides a macro to open these functions: + * SECUREC_SUPPORT_STRTOLD -- if defined, strtold will be used + */ +#ifndef SECUREC_SUPPORT_STRTOLD +#define SECUREC_SUPPORT_STRTOLD 0 +#if (defined(SECUREC_COMPATIBLE_LINUX_FORMAT)) +#if defined(__USE_ISOC99) || \ + (defined(_AIX) && defined(_ISOC99_SOURCE)) || \ + (defined(__hpux) && defined(__ia64)) || \ + (defined(SECUREC_ON_SOLARIS) && (!defined(_STRICT_STDC) && !defined(__XOPEN_OR_POSIX)) || \ + defined(_STDC_C99) || defined(__EXTENSIONS__)) +#undef SECUREC_SUPPORT_STRTOLD +#define SECUREC_SUPPORT_STRTOLD 1 +#endif +#endif +#if ((defined(SECUREC_WRLINUX_BELOW4) || defined(_WRLINUX_BELOW4_))) +#undef SECUREC_SUPPORT_STRTOLD +#define SECUREC_SUPPORT_STRTOLD 0 +#endif +#endif + + +#if SECUREC_WITH_PERFORMANCE_ADDONS + +#ifndef SECUREC_TWO_MIN +#define SECUREC_TWO_MIN(a, b) ((a) < (b) ? (a) : (b)) +#endif + +/* for strncpy_s performance optimization */ +#define SECUREC_STRNCPY_SM(dest, destMax, src, count) \ + (((void *)(dest) != NULL && (void *)(src) != NULL && (size_t)(destMax) > 0 && \ + (((unsigned long long)(destMax) & (unsigned long long)(-2)) < SECUREC_STRING_MAX_LEN) && \ + (SECUREC_TWO_MIN((size_t)(count), strlen(src)) + 1) <= (size_t)(destMax)) ? \ + (((size_t)(count) < strlen(src)) ? (memcpy((dest), (src), (count)), *((char *)(dest) + (count)) = '\0', EOK) : \ + (memcpy((dest), (src), strlen(src) + 1), EOK)) : (strncpy_error((dest), (destMax), (src), (count)))) + +#define SECUREC_STRCPY_SM(dest, destMax, src) \ + (((void *)(dest) != NULL && (void *)(src) != NULL && (size_t)(destMax) > 0 && \ + (((unsigned long long)(destMax) & (unsigned long long)(-2)) < SECUREC_STRING_MAX_LEN) && \ + (strlen(src) + 1) <= (size_t)(destMax)) ? (memcpy((dest), (src), strlen(src) + 1), EOK) : \ + (strcpy_error((dest), (destMax), (src)))) + +/* for strcat_s performance optimization */ +#if defined(__GNUC__) +#define SECUREC_STRCAT_SM(dest, destMax, src) ({ \ + int catRet = EOK; \ + if ((void *)(dest) != NULL && (void *)(src) != NULL && (size_t)(destMax) > 0 && \ + (((unsigned long long)(destMax) & (unsigned long long)(-2)) < SECUREC_STRING_MAX_LEN)) { \ + char *catTmpDst = (char *)(dest); \ + size_t catRestSize = (destMax); \ + while (catRestSize > 0 && *catTmpDst != '\0') { \ + ++catTmpDst; \ + --catRestSize; \ + } \ + if (catRestSize == 0) { \ + catRet = EINVAL; \ + } else if ((strlen(src) + 1) <= catRestSize) { \ + memcpy(catTmpDst, (src), strlen(src) + 1); \ + catRet = EOK; \ + } else { \ + catRet = ERANGE; \ + } \ + if (catRet != EOK) { \ + catRet = strcat_s((dest), (destMax), (src)); \ + } \ + } else { \ + catRet = strcat_s((dest), (destMax), (src)); \ + } \ + catRet; \ +}) +#else +#define SECUREC_STRCAT_SM(dest, destMax, src) strcat_s((dest), (destMax), (src)) +#endif + +/* for strncat_s performance optimization */ +#if defined(__GNUC__) +#define SECUREC_STRNCAT_SM(dest, destMax, src, count) ({ \ + int ncatRet = EOK; \ + if ((void *)(dest) != NULL && (void *)(src) != NULL && (size_t)(destMax) > 0 && \ + (((unsigned long long)(destMax) & (unsigned long long)(-2)) < SECUREC_STRING_MAX_LEN) && \ + (((unsigned long long)(count) & (unsigned long long)(-2)) < SECUREC_STRING_MAX_LEN)) { \ + char *ncatTmpDest = (char *)(dest); \ + size_t ncatRestSize = (size_t)(destMax); \ + while (ncatRestSize > 0 && *ncatTmpDest != '\0') { \ + ++ncatTmpDest; \ + --ncatRestSize; \ + } \ + if (ncatRestSize == 0) { \ + ncatRet = EINVAL; \ + } else if ((SECUREC_TWO_MIN((count), strlen(src)) + 1) <= ncatRestSize) { \ + if ((size_t)(count) < strlen(src)) { \ + memcpy(ncatTmpDest, (src), (count)); \ + *(ncatTmpDest + (count)) = '\0'; \ + } else { \ + memcpy(ncatTmpDest, (src), strlen(src) + 1); \ + } \ + } else { \ + ncatRet = ERANGE; \ + } \ + if (ncatRet != EOK) { \ + ncatRet = strncat_s((dest), (destMax), (src), (count)); \ + } \ + } else { \ + ncatRet = strncat_s((dest), (destMax), (src), (count)); \ + } \ + ncatRet; \ +}) +#else +#define SECUREC_STRNCAT_SM(dest, destMax, src, count) strncat_s((dest), (destMax), (src), (count)) +#endif + +/* SECUREC_MEMCPY_SM do NOT check buffer overlap by default */ +#define SECUREC_MEMCPY_SM(dest, destMax, src, count) \ + (!(((size_t)(destMax) == 0) || \ + (((unsigned long long)(destMax) & (unsigned long long)(-2)) > SECUREC_MEM_MAX_LEN) || \ + ((size_t)(count) > (size_t)(destMax)) || ((void *)(dest)) == NULL || ((void *)(src) == NULL))? \ + (memcpy((dest), (src), (count)), EOK) : \ + (memcpy_s((dest), (destMax), (src), (count)))) + +#define SECUREC_MEMSET_SM(dest, destMax, c, count) \ + (!(((size_t)(destMax) == 0) || \ + (((unsigned long long)(destMax) & (unsigned long long)(-2)) > SECUREC_MEM_MAX_LEN) || \ + ((void *)(dest) == NULL) || ((size_t)(count) > (size_t)(destMax))) ? \ + (memset((dest), (c), (count)), EOK) : \ + (memset_s((dest), (destMax), (c), (count)))) + +#endif +#endif /* __SECURECTYPE_H__A7BBB686_AADA_451B_B9F9_44DACDAE18A7 */ + diff --git a/third_party/securec/src/CMakeLists.txt b/third_party/securec/src/CMakeLists.txt new file mode 100644 index 00000000..60ec0a90 --- /dev/null +++ b/third_party/securec/src/CMakeLists.txt @@ -0,0 +1,3 @@ +aux_source_directory(. SECUREC_SRCS) + +add_library(securec STATIC ${SECUREC_SRCS}) diff --git a/third_party/securec/src/fscanf_s.c b/third_party/securec/src/fscanf_s.c new file mode 100644 index 00000000..8ceda9ac --- /dev/null +++ b/third_party/securec/src/fscanf_s.c @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "securec.h" + +/* + * + * The fscanf_s function is equivalent to fscanf except that the c, s, + * and [ conversion specifiers apply to a pair of arguments (unless assignment suppression is indicated by a*) + * The fscanf function reads data from the current position of stream into + * the locations given by argument (if any). Each argument must be a pointer + * to a variable of a type that corresponds to a type specifier in format. + * format controls the interpretation of the input fields and has the same + * form and function as the format argument for scanf. + * + * + * stream Pointer to FILE structure. + * format Format control string, see Format Specifications. + * ... Optional arguments. + * + * + * ... The convered value stored in user assigned address + * + * + * Each of these functions returns the number of fields successfully converted + * and assigned; the return value does not include fields that were read but + * not assigned. A return value of 0 indicates that no fields were assigned. + * return -1 if an error occurs. + */ +int fscanf_s(FILE *stream, const char *format, ...) +{ + int ret; /* If initialization causes e838 */ + va_list argList; + + va_start(argList, format); + ret = vfscanf_s(stream, format, argList); + va_end(argList); + (void)argList; /* to clear e438 last value assigned not used , the compiler will optimize this code */ + + return ret; +} + + diff --git a/third_party/securec/src/fwscanf_s.c b/third_party/securec/src/fwscanf_s.c new file mode 100644 index 00000000..f826b7db --- /dev/null +++ b/third_party/securec/src/fwscanf_s.c @@ -0,0 +1,55 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "securec.h" + +/* + * + * The fwscanf_s function is the wide-character equivalent of the fscanf_s function + * The fwscanf_s function reads data from the current position of stream into + * the locations given by argument (if any). Each argument must be a pointer + * to a variable of a type that corresponds to a type specifier in format. + * format controls the interpretation of the input fields and has the same + * form and function as the format argument for scanf. + * + * + * stream Pointer to FILE structure. + * format Format control string, see Format Specifications. + * ... Optional arguments. + * + * + * ... The converted value stored in user assigned address + * + * + * Each of these functions returns the number of fields successfully converted + * and assigned; the return value does not include fields that were read but + * not assigned. A return value of 0 indicates that no fields were assigned. + * return -1 if an error occurs. + */ +int fwscanf_s(FILE *stream, const wchar_t *format, ...) +{ + int ret; /* If initialization causes e838 */ + va_list argList; + + va_start(argList, format); + ret = vfwscanf_s(stream, format, argList); + va_end(argList); + (void)argList; /* to clear e438 last value assigned not used , the compiler will optimize this code */ + + return ret; +} + + diff --git a/third_party/securec/src/gets_s.c b/third_party/securec/src/gets_s.c new file mode 100644 index 00000000..57fd6231 --- /dev/null +++ b/third_party/securec/src/gets_s.c @@ -0,0 +1,75 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "securecutil.h" + +static void SecTrimCRLF(char *buffer, size_t len) +{ + int i; + /* No need to determine whether integer overflow exists */ + for (i = (int)(len - 1); i >= 0 && (buffer[i] == '\r' || buffer[i] == '\n'); --i) { + buffer[i] = '\0'; + } + return; +} + +/* + * + * The gets_s function reads at most one less than the number of characters + * specified by destMax from the stream pointed to by stdin, into the array pointed to by buffer + * The line consists of all characters up to and including + * the first newline character ('\n'). gets_s then replaces the newline + * character with a null character ('\0') before returning the line. + * If the first character read is the end-of-file character, a null character + * is stored at the beginning of buffer and NULL is returned. + * + * + * buffer Storage location for input string. + * numberOfElements The size of the buffer. + * + * + * buffer is updated + * + * + * buffer Successful operation + * NULL Improper parameter or read fail + */ +char *gets_s(char *buffer, size_t numberOfElements) +{ + size_t len; +#ifdef SECUREC_COMPATIBLE_WIN_FORMAT + size_t bufferSize = ((numberOfElements == (size_t)-1) ? SECUREC_STRING_MAX_LEN : numberOfElements); +#else + size_t bufferSize = numberOfElements; +#endif + + if (buffer == NULL || bufferSize == 0 || bufferSize > SECUREC_STRING_MAX_LEN) { + SECUREC_ERROR_INVALID_PARAMTER("gets_s"); + return NULL; + } + + if (fgets(buffer, (int)bufferSize, stdin) == NULL) { + return NULL; + } + + len = strlen(buffer); + if (len > 0 && len < bufferSize) { + SecTrimCRLF(buffer, len); + } + + return buffer; +} + diff --git a/third_party/securec/src/input.inl b/third_party/securec/src/input.inl new file mode 100644 index 00000000..a5a92e56 --- /dev/null +++ b/third_party/securec/src/input.inl @@ -0,0 +1,2125 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INPUT_INL_5D13A042_DC3F_4ED9_A8D1_882811274C27 +#define INPUT_INL_5D13A042_DC3F_4ED9_A8D1_882811274C27 + +#if SECUREC_IN_KERNEL +#include +#ifndef EOF +#define EOF (-1) +#endif +#else +#if !defined(SECUREC_SYSAPI4VXWORKS) && !defined(SECUREC_CTYPE_MACRO_ADAPT) +#include +#ifdef SECUREC_FOR_WCHAR +#include /* for iswspace */ +#endif +#endif +#endif + +#define SECUREC_NUM_WIDTH_SHORT 0 +#define SECUREC_NUM_WIDTH_INT 1 +#define SECUREC_NUM_WIDTH_LONG 2 +#define SECUREC_NUM_WIDTH_LONG_LONG 3 /* also long double */ + +#define SECUREC_BUF_EXT_MUL 2 +#define SECUREC_BUFFERED_BLOK_SIZE 1024 + +#if defined(SECUREC_VXWORKS_PLATFORM) && !defined(va_copy) && !defined(__va_copy) +/* the name is the same as system macro. */ +#define __va_copy(d, s) do { \ + size_t size_of_d = (size_t)sizeof(d); \ + size_t size_of_s = (size_t)sizeof(s); \ + if (size_of_d != size_of_s) { \ + (void)memcpy((d), (s), sizeof(va_list)); \ + } else { \ + (void)memcpy(&(d), &(s), sizeof(va_list)); \ + } \ +} SECUREC_WHILE_ZERO +#endif + + +#define SECUREC_MULTI_BYTE_MAX_LEN 6 +/* Record a flag for each bit */ +#define SECUREC_BRACKET_INDEX(x) ((unsigned int)(x) >> 3) +#define SECUREC_BRACKET_VALUE(x) ((unsigned char)(1 << ((unsigned int)(x) & 7))) + + +/* Compatibility macro name cannot be modifie */ +#ifndef UNALIGNED +#if !(defined(_M_IA64)) && !(defined(_M_AMD64)) +#define UNALIGNED +#else +#define UNALIGNED __unaligned +#endif +#endif + +#if (defined(SECUREC_COMPATIBLE_LINUX_FORMAT) && !(defined(SECUREC_ON_UNIX))) +/* Max 64bit value is 0xffffffffffffffff */ +#define SECUREC_MAX_64BITS_VALUE 18446744073709551615ULL +#define SECUREC_MAX_64BITS_VALUE_DIV_TEN 1844674407370955161ULL +#define SECUREC_MAX_64BITS_VALUE_CUT_LAST_DIGIT 18446744073709551610ULL +#define SECUREC_MIN_64BITS_NEG_VALUE 9223372036854775808ULL +#define SECUREC_MAX_64BITS_POS_VALUE 9223372036854775807ULL +#define SECUREC_MIN_32BITS_NEG_VALUE 2147483648ULL +#define SECUREC_MAX_32BITS_POS_VALUE 2147483647ULL +#define SECUREC_MAX_32BITS_VALUE 4294967295ULL +#define SECUREC_MAX_32BITS_VALUE_INC 4294967296ULL +#define SECUREC_MAX_32BITS_VALUE_DIV_TEN 429496729ULL +#define SECUREC_LONG_BIT_NUM ((unsigned int)(sizeof(long) << 3U)) + +#define SECUREC_LONG_HEX_BEYOND_MAX(number) (((number) >> (SECUREC_LONG_BIT_NUM - 4U)) > 0) +#define SECUREC_LONG_OCTAL_BEYOND_MAX(number) (((number) >> (SECUREC_LONG_BIT_NUM - 3U)) > 0) + +#define SECUREC_QWORD_HEX_BEYOND_MAX(number) (((number) >> (64U - 4U)) > 0) +#define SECUREC_QWORD_OCTAL_BEYOND_MAX(number) (((number) >> (64U - 3U)) > 0) + +#define SECUREC_LP64_BIT_WIDTH 64 +#define SECUREC_LP32_BIT_WIDTH 32 + +#endif + +#define SECUREC_CHAR(x) (x) +#define SECUREC_BRACE '{' /* [ to { */ + +#ifdef SECUREC_FOR_WCHAR +#define SECUREC_SCANF_BRACKET_CONDITION(comChr, ch, table, mask) ((comChr) == SECUREC_BRACE && \ + (table) != NULL && \ + (((table)[((unsigned int)(int)(ch) & SECUREC_CHAR_MASK) >> 3] ^ (mask)) & \ + (1 << ((unsigned int)(int)(ch) & 7)))) +#else +#define SECUREC_SCANF_BRACKET_CONDITION(comChr, ch, table, mask) ((comChr) == SECUREC_BRACE && \ + (((table)[((unsigned char)(ch) & 0xff) >> 3] ^ (mask)) & (1 << ((unsigned char)(ch) & 7)))) +#endif +#define SECUREC_SCANF_STRING_CONDITION(comChr, ch) ((comChr) == SECUREC_CHAR('s') && \ + (!((ch) >= SECUREC_CHAR('\t') && (ch) <= SECUREC_CHAR('\r')) && (ch) != SECUREC_CHAR(' '))) + +/* Do not use |= optimize this code, it will cause compiling warning */ +/* only supports wide characters with a maximum length of two bytes */ +#define SECUREC_BRACKET_SET_BIT(table, ch) do { \ + unsigned int tableIndex = SECUREC_BRACKET_INDEX(((unsigned int)(int)(ch) & SECUREC_CHAR_MASK)); \ + unsigned int tableValue = SECUREC_BRACKET_VALUE(((unsigned int)(int)(ch) & SECUREC_CHAR_MASK)); \ + (table)[tableIndex] = (unsigned char)((table)[tableIndex] | tableValue); \ +} SECUREC_WHILE_ZERO + +#ifdef SECUREC_FOR_WCHAR +/* table size is 32 x 256 */ +#define SECUREC_BRACKET_TABLE_SIZE 8192 +#define SECUREC_EOF WEOF +#define SECUREC_MB_LEN 16 /* max. # bytes in multibyte char ,see MB_LEN_MAX */ +/* int to unsigned int clear e571 */ +#define SECUREC_IS_DIGIT(chr) (!((unsigned int)(int)(chr) & 0xff00) && isdigit(((unsigned int)(int)(chr) & 0x00ff))) +#define SECUREC_IS_XDIGIT(chr) (!((unsigned int)(int)(chr) & 0xff00) && isxdigit(((unsigned int)(int)(chr) & 0x00ff))) +#define SECUREC_IS_SPACE(chr) iswspace((wint_t)(int)(chr)) +#else +#define SECUREC_BRACKET_TABLE_SIZE 32 +#define SECUREC_EOF EOF +#define SECUREC_IS_DIGIT(chr) isdigit((unsigned char)(chr) & 0x00ff) +#define SECUREC_IS_XDIGIT(chr) isxdigit((unsigned char)(chr) & 0x00ff) +#define SECUREC_IS_SPACE(chr) isspace((unsigned char)(chr) & 0x00ff) +#endif + + +static SecInt SecSkipSpaceChar(SecFileStream *stream, int *counter); +static SecInt SecGetChar(SecFileStream *stream, int *counter); +static void SecUnGetChar(SecInt ch, SecFileStream *stream, int *counter); + +typedef struct { +#ifdef SECUREC_FOR_WCHAR + unsigned char *table; /* default NULL */ +#else + unsigned char table[SECUREC_BRACKET_TABLE_SIZE]; /* Array length is large enough in application scenarios */ +#endif + unsigned char mask; /* default 0 */ +} SecBracketTable; + +#ifdef SECUREC_FOR_WCHAR +#define SECUREC_INIT_BRACKET_TABLE { NULL, 0 } +#else +#define SECUREC_INIT_BRACKET_TABLE { { 0 }, 0 } +#endif + +#if SECUREC_ENABLE_SCANF_FLOAT +typedef struct { + size_t floatStrSize; /* tialization must be length of buffer in charater */ + size_t floatStrUsedLen; /* store float string len */ + SecChar buffer[SECUREC_FLOAT_BUFSIZE + 1]; + SecChar *floatStr; /* Initialization must point to buffer */ + SecChar *allocatedFloatStr; /* Initialization must be NULL to store alloced point */ +} SecFloatSpec; +#endif + +typedef struct { + SecUnsignedInt64 number64; + unsigned long number; + int numberWidth; /* 0 = SHORT, 1 = int, > 1 long or L_DOUBLE */ + int isInt64Arg; /* 1 for 64-bit integer, 0 otherwise */ + int negative; /* 0 is positive */ +#if (defined(SECUREC_COMPATIBLE_LINUX_FORMAT) && !(defined(SECUREC_ON_UNIX))) + int beyondMax; /* Non-zero means beyond */ +#endif + void *argPtr; /* Variable parameter pointer */ + size_t arrayWidth; /* length of pointer Variable parameter, in charaters */ + int width; /* width number in format */ + int widthSet; /* 0 is not set width in format */ + int comChr; /* Lowercase format conversion characters */ + int oriComChr; /* store number conversion */ + signed char isWChar; /* -1/0 not wchar, 1 for wchar */ + char suppress; /* 0 is not have %* in format */ +} SecScanSpec; + +#if (defined(SECUREC_COMPATIBLE_LINUX_FORMAT) && !(defined(SECUREC_ON_UNIX))) +#define SECUREC_INIT_NUMBER_SPEC { 0, 0, 0, 0, 0, 0, NULL, 0, 0, 0, 0, 0, 0 } +#else +#define SECUREC_INIT_NUMBER_SPEC { 0, 0, 0, 0, 0, 0, NULL, 0, 0, 0, 0, 0 } +#endif + +#ifdef SECUREC_FOR_WCHAR +#define SECUREC_GETC fgetwc +#define SECUREC_UN_GETC ungetwc +#define SECUREC_CHAR_MASK 0xffff +#else +#define SECUREC_GETC fgetc +#define SECUREC_UN_GETC ungetc +#define SECUREC_CHAR_MASK 0xff +#endif + +/* + * Determine if it is a 64-bit pointer function + * return 0 is not ,1 is 64bit pointer + */ +static int SecIs64BitPtr(size_t sizeOfVoidStar) +{ + /* point size is 4 or 8 , Under the 64 bit system, the value not 0 */ + /* to clear e778 */ + if ((sizeOfVoidStar & sizeof(SecInt64)) != 0) { + return 1; + } + return 0; +} + +#if SECUREC_ENABLE_SCANF_FLOAT + +/* + * Convert a floating point string to a floating point number + */ +static void SecAssignFloat(const char *floatStr, int numberWidth, void *argPtr) +{ + char *endPtr = NULL; + double d; +#if SECUREC_SUPPORT_STRTOLD + if (numberWidth == SECUREC_NUM_WIDTH_LONG_LONG) { + long double d2 = strtold(floatStr, &endPtr); + *(long double UNALIGNED *)(argPtr) = d2; + return; + } +#endif + d = strtod(floatStr, &endPtr); + if (numberWidth > SECUREC_NUM_WIDTH_INT) { + *(double UNALIGNED *)(argPtr) = (double)d; + } else { + *(float UNALIGNED *)(argPtr) = (float)d; + } +} + +#ifdef SECUREC_FOR_WCHAR +/* + * Convert a floating point wchar string to a floating point number + * Success ret 0 + */ +static int SecAssignFloatW(const SecFloatSpec *floatSpec, const SecScanSpec *spec) +{ + /* convert float string */ + size_t mbsLen; + size_t tempFloatStrLen = (size_t)(floatSpec->floatStrSize + 1) * sizeof(wchar_t); + char *tempFloatStr = (char *)SECUREC_MALLOC(tempFloatStrLen); + + if (tempFloatStr == NULL) { + return -1; + } + tempFloatStr[0] = '\0'; + SECUREC_MASK_MSVC_CRT_WARNING + mbsLen = wcstombs(tempFloatStr, floatSpec->floatStr, tempFloatStrLen - 1); + SECUREC_END_MASK_MSVC_CRT_WARNING + if (mbsLen != (size_t)-1) { + tempFloatStr[mbsLen] = '\0'; + SecAssignFloat(tempFloatStr, spec->numberWidth, spec->argPtr); + } else { + SECUREC_FREE(tempFloatStr); + return -1; + } + SECUREC_FREE(tempFloatStr); + return 0; +} +#endif +/* + * Splice floating point string + * return 0 OK + */ +static int SecUpdateFloatString(SecChar ch, SecFloatSpec *floatSpec) +{ + floatSpec->floatStr[floatSpec->floatStrUsedLen++] = ch; /* ch must be '0' - '9' */ + if (floatSpec->floatStrUsedLen < floatSpec->floatStrSize) { + return 0; + } + if (floatSpec->allocatedFloatStr == NULL) { + /* add 1 to clear ZERO LENGTH ALLOCATIONS warning */ + size_t oriBufSize = floatSpec->floatStrSize* (SECUREC_BUF_EXT_MUL * sizeof(SecChar)) + 1; + void *tmpPointer = (void *)SECUREC_MALLOC(oriBufSize); + if (tmpPointer == NULL) { + return -1; + } + if (memcpy_s(tmpPointer, oriBufSize, floatSpec->floatStr, floatSpec->floatStrSize * sizeof(SecChar)) != EOK) { + SECUREC_FREE(tmpPointer); /* This is a dead code, just to meet the coding requirements */ + return -1; + } + floatSpec->floatStr = (SecChar *) (tmpPointer); + floatSpec->allocatedFloatStr = (SecChar *) (tmpPointer); /* use to clear free on stack warning */ + floatSpec->floatStrSize *= SECUREC_BUF_EXT_MUL; /* this is OK, oriBufSize plus 1 just clear warning */ + return 0; + } else { + /* LSD 2014.3.6 fix, replace realloc to malloc to avoid heap injection */ + size_t oriBufSize = floatSpec->floatStrSize * sizeof(SecChar); + size_t nextSize = (oriBufSize * SECUREC_BUF_EXT_MUL) + 1; /* add 1 to clear satic check tool warning */ + /* Prevents integer overflow when calculating the wide character length. + * The maximum length of SECUREC_MAX_WIDTH_LEN is enough + */ + if (nextSize <= SECUREC_MAX_WIDTH_LEN) { + void *tmpPointer = (void *)SECUREC_MALLOC(nextSize); + if (tmpPointer == NULL) { + return -1; + } + if (memcpy_s(tmpPointer, nextSize, floatSpec->floatStr, oriBufSize) != EOK) { + SECUREC_FREE(tmpPointer); /* This is a dead code, just to meet the coding requirements */ + return -1; + } + if (memset_s(floatSpec->floatStr, oriBufSize, 0, oriBufSize) != EOK) { + SECUREC_FREE(tmpPointer); /* This is a dead code, just to meet the coding requirements */ + return -1; + } + SECUREC_FREE(floatSpec->floatStr); + + floatSpec->floatStr = (SecChar *) (tmpPointer); + floatSpec->allocatedFloatStr = (SecChar *) (tmpPointer); /* use to clear free on stack warning */ + floatSpec->floatStrSize *= SECUREC_BUF_EXT_MUL; /* this is OK, oriBufSize plus 1 just clear warning */ + return 0; + } + } + return -1; +} +#endif + +#ifndef SECUREC_FOR_WCHAR +/* LSD only multi-bytes string need isleadbyte() function */ +static int SecIsLeadByte(SecInt ch) +{ + unsigned int c = (unsigned int)ch; +#if !(defined(_MSC_VER) || defined(_INC_WCTYPE)) + return (int)(c & 0x80); +#else + return (int)isleadbyte((int)(c & 0xff)); +#endif +} +#endif + +/* + * Parsing whether it is a wide character + */ +static void SecUpdateWcharFlagByType(SecUnsignedChar ch, SecScanSpec *spec) +{ +#if defined(SECUREC_FOR_WCHAR) && (defined(SECUREC_COMPATIBLE_WIN_FORMAT)) + signed char flagForUpperType = -1; + signed char flagForLowerType = 1; +#else + signed char flagForUpperType = 1; + signed char flagForLowerType = -1; +#endif + /* if no l or h flag */ + if (spec->isWChar == 0) { + if ((ch == SECUREC_CHAR('C')) || (ch == SECUREC_CHAR('S'))) { + spec->isWChar = flagForUpperType; + } else { + spec->isWChar = flagForLowerType; + } + } + return; +} +/* + * decode %l %ll + */ +static void SecDecodeScanQualifierL(const SecUnsignedChar **format, SecScanSpec *spec) +{ + const SecUnsignedChar *fmt = *format; + if (*(fmt + 1) == SECUREC_CHAR('l')) { + spec->isInt64Arg = 1; + spec->numberWidth = SECUREC_NUM_WIDTH_LONG_LONG; + ++fmt; + } else { + spec->numberWidth = SECUREC_NUM_WIDTH_LONG; +#if defined(SECUREC_ON_64BITS) && !(defined(SECUREC_COMPATIBLE_WIN_FORMAT)) + /* on window 64 system sizeof long is 32bit */ + spec->isInt64Arg = 1; +#endif + spec->isWChar = 1; + } + *format = fmt; +} + +/* + * decode %I %I43 %I64 %Id %Ii %Io ... + * set finishFlag to 1 finish Flag + */ +static void SecDecodeScanQualifierI(const SecUnsignedChar **format, SecScanSpec *spec, int *finishFlag) +{ + const SecUnsignedChar *fmt = *format; + if ((*(fmt + 1) == SECUREC_CHAR('6')) && + (*(fmt + 2) == SECUREC_CHAR('4'))) { /* offset 2 for I64 */ + spec->isInt64Arg = 1; + *format = *format + 2; /* add 2 to skip I64 point to '4' next loop will inc */ + } else if ((*(fmt + 1) == SECUREC_CHAR('3')) && + (*(fmt + 2) == SECUREC_CHAR('2'))) { /* offset 2 for I32 */ + *format = *format + 2; /* add 2 to skip I32 point to '2' next loop will inc */ + } else if ((*(fmt + 1) == SECUREC_CHAR('d')) || + (*(fmt + 1) == SECUREC_CHAR('i')) || + (*(fmt + 1) == SECUREC_CHAR('o')) || + (*(fmt + 1) == SECUREC_CHAR('x')) || + (*(fmt + 1) == SECUREC_CHAR('X'))) { + spec->isInt64Arg = SecIs64BitPtr(sizeof(void *)); + } else { + /* for %I */ + spec->isInt64Arg = SecIs64BitPtr(sizeof(void *)); + *finishFlag = 1; + } +} + +static int SecDecodeScanWidth(const SecUnsignedChar **format, SecScanSpec *spec) +{ + const SecUnsignedChar *fmt = *format; + while (SECUREC_IS_DIGIT(*fmt)) { + spec->widthSet = 1; + if (SECUREC_MUL_TEN_ADD_BEYOND_MAX(spec->width)) { + return -1; + } + spec->width = (int)SECUREC_MUL_TEN((unsigned int)spec->width) + (unsigned char)(*fmt - SECUREC_CHAR('0')); + ++fmt; + } + *format = fmt; + return 0; +} + +/* + * init default flags for each format + */ +static void SecSetDefaultScanSpec(SecScanSpec *spec) +{ + spec->number64 = 0; + spec->number = 0; + spec->numberWidth = SECUREC_NUM_WIDTH_INT; /* 0 = SHORT, 1 = int, > 1 long or L_DOUBLE */ + spec->isInt64Arg = 0; /* 1 for 64-bit integer, 0 otherwise */ + spec->negative = 0; +#if (defined(SECUREC_COMPATIBLE_LINUX_FORMAT) && !(defined(SECUREC_ON_UNIX))) + spec->beyondMax = 0; +#endif + spec->argPtr = NULL; + spec->arrayWidth = 0; + spec->width = 0; + spec->widthSet = 0; + spec->comChr = 0; + spec->isWChar = 0; + spec->suppress = 0; +} + +/* + * decode qualifier %I %L %h ... + * set finishFlag to 1 finish Flag + */ +static void SecDecodeScanQualifier(const SecUnsignedChar **format, SecScanSpec *spec, int *finishFlag) +{ + switch ((int)(unsigned char)(**(format))) { + case SECUREC_CHAR('F'): /* fall-through */ /* FALLTHRU */ + case SECUREC_CHAR('N'): + break; + case SECUREC_CHAR('h'): + --spec->numberWidth; /* h for SHORT , hh for CHAR */ + spec->isWChar = -1; + break; +#ifdef SECUREC_COMPATIBLE_LINUX_FORMAT + case SECUREC_CHAR('j'): + spec->numberWidth = SECUREC_NUM_WIDTH_LONG_LONG; /* intmax_t or uintmax_t */ + spec->isInt64Arg = 1; + break; + case SECUREC_CHAR('t'): /* fall-through */ /* FALLTHRU */ +#endif + case SECUREC_CHAR('z'): +#ifdef SECUREC_ON_64BITS + spec->numberWidth = SECUREC_NUM_WIDTH_LONG_LONG; + spec->isInt64Arg = 1; +#else + spec->numberWidth = SECUREC_NUM_WIDTH_LONG; +#endif + break; + case SECUREC_CHAR('L'): /* long double */ /* fall-through */ /* FALLTHRU */ + case SECUREC_CHAR('q'): + spec->numberWidth = SECUREC_NUM_WIDTH_LONG_LONG; + spec->isInt64Arg = 1; + break; + case SECUREC_CHAR('l'): + SecDecodeScanQualifierL(format, spec); + break; + case SECUREC_CHAR('w'): + spec->isWChar = 1; + break; + case SECUREC_CHAR('*'): + spec->suppress = 1; + break; + case SECUREC_CHAR('I'): + SecDecodeScanQualifierI(format, spec, finishFlag); + break; + default: + *finishFlag = 1; + break; + } + +} +/* + * decode width and qualifier in format + */ +static int SecDecodeScanFlag(const SecUnsignedChar **format, SecScanSpec *spec) +{ + const SecUnsignedChar *fmt = *format; + int finishFlag = 0; + + do { + ++fmt; /* first skip % , next seek fmt */ + /* may %*6d , so put it inside the loop */ + if (SecDecodeScanWidth(&fmt, spec) != 0) { + return -1; + } + SecDecodeScanQualifier(&fmt, spec, &finishFlag); + } while (finishFlag == 0); + *format = fmt; + return 0; +} + + + + + +/* + * Judging whether a zeroing buffer is needed according to different formats + */ +static int SecDecodeClearFormat(const SecUnsignedChar *format, int *comChr) +{ + const SecUnsignedChar *fmt = format; + /* to lowercase */ + int ch = (unsigned char)(*fmt) | (SECUREC_CHAR('a') - SECUREC_CHAR('A')); + if (!(ch == SECUREC_CHAR('c') || ch == SECUREC_CHAR('s') || ch == SECUREC_BRACE)) { + return -1; /* first argument is not a string type */ + } + if (ch == SECUREC_BRACE) { +#if !(defined(SECUREC_COMPATIBLE_WIN_FORMAT)) + if (*fmt == SECUREC_CHAR('{')) { + return -1; + } +#endif + ++fmt; + if (*fmt == SECUREC_CHAR('^')) { + ++fmt; + } + if (*fmt == SECUREC_CHAR(']')) { + ++fmt; + } + while ((*fmt != SECUREC_CHAR('\0')) && (*fmt != SECUREC_CHAR(']'))) { + ++fmt; + } + if (*fmt == SECUREC_CHAR('\0')) { + return -1; /* trunc'd format string */ + } + } + *comChr = ch; + return 0; +} + +/* + * add L'\0' for wchar string , add '\0' for char string + */ +static void SecAddEndingZero(void *ptr, const SecScanSpec *spec) +{ + *(char *)ptr = '\0'; + (void)spec; /* clear not use */ +#if SECUREC_HAVE_WCHART + if (spec->isWChar > 0) { + *(wchar_t UNALIGNED *)ptr = L'\0'; + } +#endif +} + +#ifdef SECUREC_FOR_WCHAR +/* + * Clean up the first %s %c buffer to zero for wchar version + */ +void SecClearDestBufW(const wchar_t *buffer, const wchar_t *format, va_list argList) +#else +/* + * Clean up the first %s %c buffer to zero for char version + */ +void SecClearDestBuf(const char *buffer, const char *format, va_list argList) +#endif +{ + + va_list argListSave; /* backup for argList value, this variable don't need initialized */ + SecScanSpec spec; + int comChr = 0; + const SecUnsignedChar *fmt = (const SecUnsignedChar *)format; + if (fmt == NULL) { + return; + } + + /* find first % */ + while (*fmt != SECUREC_CHAR('\0') && *fmt != SECUREC_CHAR('%')) { + ++fmt; + } + if (*fmt == SECUREC_CHAR('\0')) { + return; + } + + SecSetDefaultScanSpec(&spec); + if (SecDecodeScanFlag(&fmt, &spec) != 0) { + return; + } + + /* update wchar flag for %S %C */ + SecUpdateWcharFlagByType(*fmt, &spec); + + if (spec.suppress != 0 || SecDecodeClearFormat(fmt, &comChr) != 0) { + return; + } + + if ((buffer != NULL) && (*buffer != SECUREC_CHAR('\0')) && (comChr != SECUREC_CHAR('s'))) { + /* when buffer not empty just clear %s. + * example call sscanf by argment of (" \n", "%s", s, sizeof(s)) + */ + return; + } + (void)memset(&argListSave, 0, sizeof(va_list)); /* to clear e530 argListSave not initialized */ +#if defined(va_copy) + va_copy(argListSave, argList); +#elif defined(__va_copy) /* for vxworks */ + __va_copy(argListSave, argList); +#else + argListSave = argList; +#endif + do { + void *argPtr = (void *)va_arg(argListSave, void *); + /* Get the next argument - size of the array in characters */ + size_t arrayWidth = ((size_t)(va_arg(argListSave, size_t))) & 0xFFFFFFFFUL; + va_end(argListSave); + /* to clear e438 last value assigned not used , the compiler will optimize this code */ + (void)argListSave; + /* There is no need to judge the upper limit */ + if (arrayWidth == 0 || argPtr == NULL) { + return; + } + + /* clear one char */ + SecAddEndingZero(argPtr, &spec); + } SECUREC_WHILE_ZERO; + return; + +} + +/* + * Assign number to output buffer + */ +static void SecAssignNumber(const SecScanSpec *spec) +{ + void *argPtr = spec->argPtr; + if (spec->isInt64Arg != 0) { +#if defined(SECUREC_VXWORKS_PLATFORM) +#if defined(SECUREC_VXWORKS_PLATFORM_COMP) + *(SecInt64 UNALIGNED *)argPtr = (SecInt64)(spec->number64); +#else + /* take number64 as unsigned number unsigned to int clear Compile warning */ + *(SecInt64 UNALIGNED *)argPtr = *(SecUnsignedInt64 *)(&(spec->number64)); +#endif +#else + /* take number64 as unsigned number */ + *(SecInt64 UNALIGNED *)argPtr = (SecInt64)(spec->number64); +#endif + return; + } + if (spec->numberWidth > SECUREC_NUM_WIDTH_INT) { + /* take number as unsigned number */ + *(long UNALIGNED *)argPtr = (long)(spec->number); + } else if (spec->numberWidth == SECUREC_NUM_WIDTH_INT) { + *(int UNALIGNED *)argPtr = (int)(spec->number); + } else if (spec->numberWidth == SECUREC_NUM_WIDTH_SHORT) { + /* take number as unsigned number */ + *(short UNALIGNED *)argPtr = (short)(spec->number); + } else { /* < 0 for hh format modifier */ + /* take number as unsigned number */ + *(char UNALIGNED *)argPtr = (char)(spec->number); + } +} + +#if (defined(SECUREC_COMPATIBLE_LINUX_FORMAT) && !(defined(SECUREC_ON_UNIX))) +/* + * Judge the long bit width + */ +static int SecIsLongBitEqual(int bitNum) +{ + return (unsigned int)bitNum == SECUREC_LONG_BIT_NUM; +} +#endif +/* + * Convert hexadecimal characters to decimal value + */ +static int SecHexValueOfChar(SecInt ch) +{ + /* use isdigt Causing tool false alarms */ + return (int)((ch >= '0' && ch <= '9') ? ((unsigned char)ch - '0') : + ((((unsigned char)ch | (unsigned char)('a' - 'A')) - ('a')) + 10)); /* Adding 10 is to hex value */ +} + + + +/* + * Parse decimal character to integer for 32bit . + */ +static void SecDecodeNumberDecimal(SecInt ch, SecScanSpec *spec) +{ +#if (defined(SECUREC_COMPATIBLE_LINUX_FORMAT) && !(defined(SECUREC_ON_UNIX))) + unsigned long decimalEdge = SECUREC_MAX_32BITS_VALUE_DIV_TEN; +#ifdef SECUREC_ON_64BITS + if (SecIsLongBitEqual(SECUREC_LP64_BIT_WIDTH)) { + decimalEdge = (unsigned long)SECUREC_MAX_64BITS_VALUE_DIV_TEN; + } +#else + if (SecIsLongBitEqual(SECUREC_LP32_BIT_WIDTH)) { + decimalEdge = SECUREC_MAX_32BITS_VALUE_DIV_TEN; + } +#endif + if (spec->number > decimalEdge) { + spec->beyondMax = 1; + } +#endif + spec->number = SECUREC_MUL_TEN(spec->number); +#if (defined(SECUREC_COMPATIBLE_LINUX_FORMAT) && !(defined(SECUREC_ON_UNIX))) + if (spec->number == SECUREC_MUL_TEN(decimalEdge)) { + SecUnsignedInt64 number64As = (unsigned long)SECUREC_MAX_64BITS_VALUE - spec->number; + if (number64As < (SecUnsignedInt64)((SecUnsignedInt)ch - SECUREC_CHAR('0'))) { + spec->beyondMax = 1; + } + } +#endif + spec->number += (unsigned long)((SecUnsignedInt)ch - SECUREC_CHAR('0')); + +} + + +/* + * Parse Hex character to integer for 32bit . + */ +static void SecDecodeNumberHex(SecInt ch, SecScanSpec *spec) +{ +#if (defined(SECUREC_COMPATIBLE_LINUX_FORMAT) && !(defined(SECUREC_ON_UNIX))) + if (SECUREC_LONG_HEX_BEYOND_MAX(spec->number)) { + spec->beyondMax = 1; + } +#endif + spec->number = SECUREC_MUL_SIXTEEN(spec->number); + spec->number += (unsigned long)(unsigned int)SecHexValueOfChar(ch); +} + + +/* + * Parse Octal character to integer for 32bit . + */ +static void SecDecodeNumberOctal(SecInt ch, SecScanSpec *spec) +{ +#if (defined(SECUREC_COMPATIBLE_LINUX_FORMAT) && !(defined(SECUREC_ON_UNIX))) + if (SECUREC_LONG_OCTAL_BEYOND_MAX(spec->number)) { + spec->beyondMax = 1; + } +#endif + spec->number = SECUREC_MUL_EIGHT(spec->number); + spec->number += (unsigned long)((SecUnsignedInt)ch - SECUREC_CHAR('0')); +} + + +#if (defined(SECUREC_COMPATIBLE_LINUX_FORMAT) && !(defined(SECUREC_ON_UNIX))) +/* Compatible with integer negative values other than int */ +static void SecFinishNumberNegativeOther(int comChr, int numberWidth, SecScanSpec *spec) +{ + if ((comChr == SECUREC_CHAR('d')) || (comChr == SECUREC_CHAR('i'))) { + if (spec->number > (unsigned long)(1ULL << (SECUREC_LONG_BIT_NUM - 1))) { + spec->number = (unsigned long)(1ULL << (SECUREC_LONG_BIT_NUM - 1)); + } else { + spec->number = (unsigned long)(-(long)spec->number); + } + if (spec->beyondMax != 0) { + if (numberWidth < SECUREC_NUM_WIDTH_INT) { + spec->number = 0; + } else if (numberWidth == SECUREC_NUM_WIDTH_LONG) { + spec->number = ((unsigned long)(1UL << (SECUREC_LONG_BIT_NUM - 1))); + } + } + } else { /* o, u, x, X, p */ + spec->number = (unsigned long)(-(long)spec->number); + if (spec->beyondMax != 0) { + spec->number |= (unsigned long)SECUREC_MAX_64BITS_VALUE; + } + } +} +/* Compatible processing of integer negative numbers */ +static void SecFinishNumberNegativeInt(int comChr, SecScanSpec *spec) +{ + if ((comChr == SECUREC_CHAR('d')) || (comChr == SECUREC_CHAR('i'))) { +#ifdef SECUREC_ON_64BITS + if (SecIsLongBitEqual(SECUREC_LP64_BIT_WIDTH)) { + if ((spec->number > SECUREC_MIN_64BITS_NEG_VALUE)) { + spec->number = 0; + } else { + spec->number = (unsigned int)(-(int)spec->number); + } + } +#else + if (SecIsLongBitEqual(SECUREC_LP32_BIT_WIDTH)) { + if ((spec->number > SECUREC_MIN_32BITS_NEG_VALUE)) { + spec->number = SECUREC_MIN_32BITS_NEG_VALUE; + } else { + spec->number = (unsigned int)(-(int)spec->number); + } + } +#endif + if (spec->beyondMax != 0) { +#ifdef SECUREC_ON_64BITS + if (SecIsLongBitEqual(SECUREC_LP64_BIT_WIDTH)) { + spec->number = 0; + } +#else + if (SecIsLongBitEqual(SECUREC_LP32_BIT_WIDTH)) { + spec->number = SECUREC_MIN_32BITS_NEG_VALUE; + } +#endif + } + } else { /* o, u, x, X ,p */ +#ifdef SECUREC_ON_64BITS + if (spec->number > SECUREC_MAX_32BITS_VALUE_INC) { + spec->number = SECUREC_MAX_32BITS_VALUE; + } else { + spec->number = (unsigned int)(-(int)spec->number); + } +#else + spec->number = (unsigned int)(-(int)spec->number); +#endif + if (spec->beyondMax != 0) { + spec->number |= (unsigned long)SECUREC_MAX_64BITS_VALUE; + } + } +} + +/* Compatible with integer positive values other than int */ +static void SecFinishNumberPositiveOther(int comChr, int numberWidth, SecScanSpec *spec) +{ + if (comChr == SECUREC_CHAR('d') || comChr == SECUREC_CHAR('i')) { + if (spec->number > ((unsigned long)(1UL << (SECUREC_LONG_BIT_NUM - 1)) - 1)) { + spec->number = ((unsigned long)(1UL << (SECUREC_LONG_BIT_NUM - 1)) - 1); + } + if ((spec->beyondMax != 0 && numberWidth < SECUREC_NUM_WIDTH_INT)) { + spec->number |= (unsigned long)SECUREC_MAX_64BITS_VALUE; + } + if (spec->beyondMax != 0 && numberWidth == SECUREC_NUM_WIDTH_LONG) { + spec->number = ((unsigned long)(1UL << (SECUREC_LONG_BIT_NUM - 1)) - 1); + } + } else { + if (spec->beyondMax != 0) { + spec->number |= (unsigned long)SECUREC_MAX_64BITS_VALUE; + } + } +} + +/* Compatible processing of integer positive numbers */ +static void SecFinishNumberPositiveInt(int comChr, SecScanSpec *spec) +{ + if ((comChr == SECUREC_CHAR('d')) || (comChr == SECUREC_CHAR('i'))) { +#ifdef SECUREC_ON_64BITS + if (SecIsLongBitEqual(SECUREC_LP64_BIT_WIDTH)) { + if (spec->number > SECUREC_MAX_64BITS_POS_VALUE) { + spec->number |= (unsigned long)SECUREC_MAX_64BITS_VALUE; + } + } + if (spec->beyondMax != 0 && SecIsLongBitEqual(SECUREC_LP64_BIT_WIDTH)) { + spec->number |= (unsigned long)SECUREC_MAX_64BITS_VALUE; + } +#else + if (SecIsLongBitEqual(SECUREC_LP32_BIT_WIDTH)) { + if (spec->number > SECUREC_MAX_32BITS_POS_VALUE) { + spec->number = SECUREC_MAX_32BITS_POS_VALUE; + } + } + if (spec->beyondMax != 0 && SecIsLongBitEqual(SECUREC_LP32_BIT_WIDTH)) { + spec->number = SECUREC_MAX_32BITS_POS_VALUE; + } +#endif + } else { /* o,u,x,X,p */ + if (spec->beyondMax != 0) { + spec->number = SECUREC_MAX_32BITS_VALUE; + } + } +} + +#endif + + +/* + * Parse decimal character to integer for 64bit . + */ +static void SecDecodeNumber64Decimal(SecInt ch, SecScanSpec *spec) +{ +#if (defined(SECUREC_COMPATIBLE_LINUX_FORMAT) && !(defined(SECUREC_ON_UNIX))) + if (spec->number64 > SECUREC_MAX_64BITS_VALUE_DIV_TEN) { + spec->beyondMax = 1; + } +#endif + spec->number64 = SECUREC_MUL_TEN(spec->number64); +#if (defined(SECUREC_COMPATIBLE_LINUX_FORMAT) && !(defined(SECUREC_ON_UNIX))) + if (spec->number64 == SECUREC_MAX_64BITS_VALUE_CUT_LAST_DIGIT) { + SecUnsignedInt64 number64As = (SecUnsignedInt64)SECUREC_MAX_64BITS_VALUE - spec->number64; + if (number64As < (SecUnsignedInt64)((SecUnsignedInt)ch - SECUREC_CHAR('0'))) { + spec->beyondMax = 1; + } + } +#endif + spec->number64 += (SecUnsignedInt64)((SecUnsignedInt)ch - SECUREC_CHAR('0')); +} + +/* + * Parse Hex character to integer for 64bit . + */ +static void SecDecodeNumber64Hex(SecInt ch, SecScanSpec *spec) +{ +#if (defined(SECUREC_COMPATIBLE_LINUX_FORMAT) && !(defined(SECUREC_ON_UNIX))) + if (SECUREC_QWORD_HEX_BEYOND_MAX(spec->number64)) { + spec->beyondMax = 1; + } +#endif + spec->number64 = SECUREC_MUL_SIXTEEN(spec->number64); + spec->number64 += (SecUnsignedInt64)(unsigned int)SecHexValueOfChar(ch); + +} + +/* + * Parse Octal character to integer for 64bit . + */ +static void SecDecodeNumber64Octal(SecInt ch, SecScanSpec *spec) +{ +#if (defined(SECUREC_COMPATIBLE_LINUX_FORMAT) && !(defined(SECUREC_ON_UNIX))) + if (SECUREC_QWORD_OCTAL_BEYOND_MAX(spec->number64)) { + spec->beyondMax = 1; + } +#endif + spec->number64 = SECUREC_MUL_EIGHT(spec->number64); + spec->number64 += (SecUnsignedInt64)((SecUnsignedInt)ch - SECUREC_CHAR('0')); +} + +#define SECUREC_DECODE_NUMBER_FUNC_NUM 2 +/* Function name cannot add address symbol, causing 546 alarm */ +static void (*g_secDecodeNumberHex[SECUREC_DECODE_NUMBER_FUNC_NUM])(SecInt ch, SecScanSpec *spec) = \ + { SecDecodeNumberHex, SecDecodeNumber64Hex }; +static void (*g_secDecodeNumberOctal[SECUREC_DECODE_NUMBER_FUNC_NUM])(SecInt ch, SecScanSpec *spec) = \ + { SecDecodeNumberOctal, SecDecodeNumber64Octal }; +static void (*g_secDecodeNumberDecimal[SECUREC_DECODE_NUMBER_FUNC_NUM])(SecInt ch, SecScanSpec *spec) = \ + { SecDecodeNumberDecimal, SecDecodeNumber64Decimal }; + +/* + * Parse 64-bit integer formatted input, return 0 when ch is a number. + */ +static int SecDecodeNumber(SecInt ch, SecScanSpec *spec) +{ + if (spec->comChr == SECUREC_CHAR('x') || spec->comChr == SECUREC_CHAR('p')) { + if (SECUREC_IS_XDIGIT(ch)) { + (*g_secDecodeNumberHex[spec->isInt64Arg])(ch, spec); + } else { + return -1; + } + return 0; + } + if (!(SECUREC_IS_DIGIT(ch))) { + return -1; + } + if (spec->comChr == SECUREC_CHAR('o')) { + if (ch < SECUREC_CHAR('8')) { + (*g_secDecodeNumberOctal[spec->isInt64Arg])(ch, spec); + } else { + return -1; + } + } else { /* comChr is 'd' */ + (*g_secDecodeNumberDecimal[spec->isInt64Arg])(ch, spec); + } + return 0; +} + + +/* + * Complete the final 32-bit integer formatted input + */ +static void SecFinishNumber(SecScanSpec *spec) +{ +#if (defined(SECUREC_COMPATIBLE_LINUX_FORMAT) && !(defined(SECUREC_ON_UNIX))) + if (spec->negative != 0) { + if (spec->numberWidth == SECUREC_NUM_WIDTH_INT) { + SecFinishNumberNegativeInt(spec->oriComChr, spec); + } else { + SecFinishNumberNegativeOther(spec->oriComChr, spec->numberWidth, spec); + } + } else { + if (spec->numberWidth == SECUREC_NUM_WIDTH_INT) { + SecFinishNumberPositiveInt(spec->oriComChr, spec); + } else { + SecFinishNumberPositiveOther(spec->oriComChr, spec->numberWidth, spec); + } + } +#else + if (spec->negative != 0) { +#if defined(__hpux) + if (spec->oriComChr != SECUREC_CHAR('p')) { + spec->number = (unsigned long)(-(long)spec->number); + } +#else + spec->number = (unsigned long)(-(long)spec->number); +#endif + } +#endif + return; +} + +/* + * Complete the final 64-bit integer formatted input + */ +static void SecFinishNumber64(SecScanSpec *spec) +{ +#if (defined(SECUREC_COMPATIBLE_LINUX_FORMAT) && !(defined(SECUREC_ON_UNIX))) + if (spec->negative != 0) { + if (spec->oriComChr == (SECUREC_CHAR('d')) || (spec->oriComChr == SECUREC_CHAR('i'))) { + if (spec->number64 > SECUREC_MIN_64BITS_NEG_VALUE) { + spec->number64 = SECUREC_MIN_64BITS_NEG_VALUE; + } else { + spec->number64 = (SecUnsignedInt64)(-(SecInt64)spec->number64); + } + if (spec->beyondMax != 0) { + spec->number64 = SECUREC_MIN_64BITS_NEG_VALUE; + } + } else { /* o, u, x, X, p */ + spec->number64 = (SecUnsignedInt64)(-(SecInt64)spec->number64); + if (spec->beyondMax != 0) { + spec->number64 = SECUREC_MAX_64BITS_VALUE; + } + } + } else { + if ((spec->oriComChr == SECUREC_CHAR('d')) || (spec->oriComChr == SECUREC_CHAR('i'))) { + if (spec->number64 > SECUREC_MAX_64BITS_POS_VALUE) { + spec->number64 = SECUREC_MAX_64BITS_POS_VALUE; + } + if (spec->beyondMax != 0) { + spec->number64 = SECUREC_MAX_64BITS_POS_VALUE; + } + } else { + if (spec->beyondMax != 0) { + spec->number64 = SECUREC_MAX_64BITS_VALUE; + } + } + } +#else + if (spec->negative != 0) { +#if defined(__hpux) + if (spec->oriComChr != SECUREC_CHAR('p')) { + spec->number64 = (SecUnsignedInt64)(-(SecInt64)spec->number64); + } +#else + spec->number64 = (SecUnsignedInt64)(-(SecInt64)spec->number64); +#endif + } +#endif + return; +} +static void (*g_secFinishNumber[SECUREC_DECODE_NUMBER_FUNC_NUM])(SecScanSpec *spec) = \ + { SecFinishNumber, SecFinishNumber64 }; + +#if SECUREC_ENABLE_SCANF_FILE + +/* + * Adjust the pointer position of the file stream + */ +static void SecSeekStream(SecFileStream *stream) +{ + if ((stream->count == 0) && feof(stream->pf)) { + /* file pointer at the end of file, don't need to seek back */ + stream->base[0] = '\0'; + return; + } + /* LSD seek to original position, bug fix 2014 1 21 */ + if (fseek(stream->pf, stream->oriFilePos, SEEK_SET)) { + /* seek failed, ignore it */ + stream->oriFilePos = 0; + return; + } + + if (stream->fileRealRead > 0) { /* LSD bug fix. when file reach to EOF, don't seek back */ +#if (defined(SECUREC_COMPATIBLE_WIN_FORMAT)) + int loops; + for (loops = 0; loops < (stream->fileRealRead / SECUREC_BUFFERED_BLOK_SIZE); ++loops) { + if (fread(stream->base, (size_t)1, (size_t)SECUREC_BUFFERED_BLOK_SIZE, + stream->pf) != SECUREC_BUFFERED_BLOK_SIZE) { + break; + } + } + if ((stream->fileRealRead % SECUREC_BUFFERED_BLOK_SIZE) != 0) { + size_t ret = fread(stream->base, (size_t)((unsigned int)stream->fileRealRead % SECUREC_BUFFERED_BLOK_SIZE), + (size_t)1, stream->pf); + if ((ret == 1 || ret == 0) && (ftell(stream->pf) < stream->oriFilePos + stream->fileRealRead)) { + (void)fseek(stream->pf, stream->oriFilePos + stream->fileRealRead, SEEK_SET); + } + } + +#else + /* in linux like system */ + if (fseek(stream->pf, stream->oriFilePos + stream->fileRealRead, SEEK_SET)) { + /* seek failed, ignore it */ + stream->oriFilePos = 0; + } +#endif + } + + return; +} + +/* + * Adjust the pointer position of the file stream and free memory + */ +static void SecAdjustStream(SecFileStream *stream) +{ + if (stream != NULL && (stream->flag & SECUREC_FILE_STREAM_FLAG) && stream->base != NULL) { + SecSeekStream(stream); + SECUREC_FREE(stream->base); + stream->base = NULL; + } + return; +} +#endif + +static void SecSkipSpaceFormat(const SecUnsignedChar **format) +{ + const SecUnsignedChar *fmt = *format; + while (SECUREC_IS_SPACE(*fmt)) { + ++fmt; + } + *format = fmt; +} +#ifndef SECUREC_FOR_WCHAR +/* + * Handling multi-character characters + */ +static int SecDecodeLeadByte(SecInt ch, const SecUnsignedChar **format, SecFileStream *stream, int *counter) +{ +#if SECUREC_HAVE_MBTOWC + char temp[SECUREC_MULTI_BYTE_MAX_LEN]; + const SecUnsignedChar *fmt = *format; + wchar_t tempWChar = L'\0'; + int ch2 = SecGetChar(stream, counter); + if (*fmt == SECUREC_CHAR('\0') || (int)(*fmt) != (ch2)) { + /* LSD in console mode, ungetc twice may cause problem */ + SecUnGetChar(ch2, stream, counter); + SecUnGetChar(ch, stream, counter); + return -1; + } + ++fmt; + if (MB_CUR_MAX >= SECUREC_UTF8_BOM_HEADER_SIZE && + (((unsigned char)ch & SECUREC_UTF8_LEAD_1ST) == SECUREC_UTF8_LEAD_1ST) && + (((unsigned char)ch2 & SECUREC_UTF8_LEAD_2ND) == SECUREC_UTF8_LEAD_2ND)) { + /* this char is very likely to be a UTF-8 char */ + int ch3 = SecGetChar(stream, counter); + temp[0] = (char)ch; + temp[1] = (char)ch2; /* 1 index of second character */ + temp[2] = (char)ch3; /* 2 index of third character */ + temp[3] = '\0'; /* 3 of string terminator position */ + + if (mbtowc(&tempWChar, temp, sizeof(temp)) > 0) { + /* succeed */ + if (*fmt == SECUREC_CHAR('\0') || (int)(*fmt) != (int)ch3) { + SecUnGetChar(ch3, stream, counter); + return -1; + } + ++fmt; + *counter = *counter - 1; + } else { + SecUnGetChar(ch3, stream, counter); + } + } + *counter = *counter - 1; /* only count as one character read */ + *format = fmt; + return 0; +#else + SecUnGetChar(ch, stream, counter); + (void)format; + return -1; +#endif +} +#endif + + + +/* + * Resolving sequence of characters from %[ format + */ +static int SecSetupBracketTable(const SecUnsignedChar **format, SecBracketTable *bracketTable) +{ + const SecUnsignedChar *fmt = *format; + SecUnsignedChar prevChar = 0; + SecUnsignedChar expCh; + SecUnsignedChar last = 0; +#if !(defined(SECUREC_COMPATIBLE_WIN_FORMAT)) + if (*fmt == SECUREC_CHAR('{')) { + return -1; + } +#endif + /* for building "table" data */ + ++fmt; /* skip [ */ + bracketTable->mask = 0; + if (*fmt == SECUREC_CHAR('^')) { + ++fmt; + bracketTable->mask = (unsigned char)0xff; + } + if (*fmt == SECUREC_CHAR(']')) { + prevChar = SECUREC_CHAR(']'); + ++fmt; + SECUREC_BRACKET_SET_BIT(bracketTable->table, SECUREC_CHAR(']')); + } + while (*fmt != SECUREC_CHAR('\0') && *fmt != SECUREC_CHAR(']')) { + expCh = *fmt++; + if (expCh != SECUREC_CHAR('-') || prevChar == 0 || *fmt == SECUREC_CHAR(']')) { + /* normal character */ + prevChar = expCh; + SECUREC_BRACKET_SET_BIT(bracketTable->table, expCh); + } else { + /* for %[a-z] */ + expCh = *fmt++; /* get end of range */ + if (prevChar < expCh) { /* %[a-z] */ + last = expCh; + } else { + prevChar = expCh; +#if (defined(SECUREC_COMPATIBLE_WIN_FORMAT)) + /* %[z-a] */ + last = prevChar; + +#else + SECUREC_BRACKET_SET_BIT(bracketTable->table, SECUREC_CHAR('-')); + SECUREC_BRACKET_SET_BIT(bracketTable->table, expCh); + continue; +#endif + } + /* format %[a-\xff] last is 0xFF, condition (rnch <= last) cause dead loop */ + for (expCh = prevChar; expCh < last; ++expCh) { + SECUREC_BRACKET_SET_BIT(bracketTable->table, expCh); + } + SECUREC_BRACKET_SET_BIT(bracketTable->table, last); + prevChar = 0; + } + } + *format = fmt; + return 0; +} + + +#ifdef SECUREC_FOR_WCHAR +static int SecInputForWchar(SecInt ch, SecScanSpec *spec) +{ + void *endPtr = spec->argPtr; + if (spec->isWChar > 0) { + *(wchar_t UNALIGNED *)endPtr = (wchar_t)ch; + endPtr = (wchar_t *)endPtr + 1; + --spec->arrayWidth; + } else { +#if SECUREC_HAVE_WCTOMB + int temp; + char tmpBuf[SECUREC_MB_LEN + 1]; + SECUREC_MASK_MSVC_CRT_WARNING temp = wctomb(tmpBuf, (wchar_t)ch); + SECUREC_END_MASK_MSVC_CRT_WARNING + if (temp <= 0 || ((size_t)(unsigned int)temp) > sizeof(tmpBuf)) { + /* if wctomb error, then ignore character */ + return 0; + } + if (((size_t)(unsigned int)temp) > spec->arrayWidth) { + return -1; + } + if (memcpy_s(endPtr, spec->arrayWidth, tmpBuf, (size_t)(unsigned int)temp) != EOK) { + return -1; + } + endPtr = (char *)endPtr + temp; + spec->arrayWidth -= (size_t)(unsigned int)temp; +#else + return -1; +#endif + } + spec->argPtr = endPtr; + return 0; +} +#endif + + +#ifndef SECUREC_FOR_WCHAR +static int SecInputForChar(SecInt ch, SecScanSpec *spec, SecFileStream *stream, int *charCount) +{ + void *endPtr = spec->argPtr; + if (spec->isWChar > 0) { + wchar_t tempWChar = L'?'; /* set default char as ? */ +#if SECUREC_HAVE_MBTOWC + char temp[SECUREC_MULTI_BYTE_MAX_LEN + 1]; + temp[0] = (char)ch; + temp[1] = '\0'; +#if defined(SECUREC_COMPATIBLE_WIN_FORMAT) + if (SecIsLeadByte(ch)) { + temp[1] = (char)SecGetChar(stream, charCount); + temp[2] = '\0'; /* 2 of string terminator position */ + } + if (mbtowc(&tempWChar, temp, sizeof(temp)) <= 0) { + /* no string termination error for tool */ + tempWChar = L'?'; + } +#else + if (SecIsLeadByte(ch)) { + int convRes = 0; + int di = 1; + /* in Linux like system, the string is encoded in UTF-8 */ + while (convRes <= 0 && di < (int)MB_CUR_MAX && di < SECUREC_MULTI_BYTE_MAX_LEN) { + temp[di++] = (char)SecGetChar(stream, charCount); + temp[di] = '\0'; + convRes = mbtowc(&tempWChar, temp, sizeof(temp)); + } + if (convRes <= 0) { + tempWChar = L'?'; + } + } else { + if (mbtowc(&tempWChar, temp, sizeof(temp)) <= 0) { + /* no string termination error for tool */ + tempWChar = L'?'; + } + } +#endif +#endif /* SECUREC_HAVE_MBTOWC */ + *(wchar_t UNALIGNED *)endPtr = tempWChar; + /* just copy L'?' if mbtowc fails, errno is set by mbtowc */ + endPtr = (wchar_t *)endPtr + 1; + --spec->arrayWidth; + (void)charCount; + (void)stream; + } else { + *(char *)endPtr = (char)ch; + endPtr = (char *)endPtr + 1; + --spec->arrayWidth; + } + spec->argPtr = endPtr; + return 0; +} +#endif + + +#if SECUREC_ENABLE_SCANF_FLOAT + +/* no not use localeconv()->decimal_pointif onlay support '.' */ +#define SECURE_IS_FLOAT_DECIMAL(ch) ((ch) == SECUREC_CHAR('.')) +/* + * init SecFloatSpec befor parse format + */ +static void SecInitFloatSpec(SecFloatSpec *floatSpec) +{ + floatSpec->floatStr = floatSpec->buffer; + floatSpec->allocatedFloatStr = NULL; + floatSpec->floatStrSize = sizeof(floatSpec->buffer) / sizeof(floatSpec->buffer[0]); + floatSpec->floatStr = floatSpec->buffer; + floatSpec->floatStrUsedLen = 0; +} + +static void SecClearFloatSpec(SecFloatSpec *floatSpec, int *doneCount) +{ + /* LSD 2014.3.6 add, clear the stack data */ + if (memset_s(floatSpec->buffer, sizeof(floatSpec->buffer), 0, + sizeof(floatSpec->buffer)) != EOK) { + *doneCount = 0; /* This is a dead code, just to meet the coding requirements */ + } + if (floatSpec->allocatedFloatStr != NULL) { + /* pFloatStr can be alloced in SecUpdateFloatString function, clear and free it */ + if (memset_s(floatSpec->allocatedFloatStr, floatSpec->floatStrSize * sizeof(SecChar), 0, + floatSpec->floatStrSize * sizeof(SecChar)) != EOK) { + *doneCount = 0; /* This is a dead code, just to meet the coding requirements */ + } + SECUREC_FREE(floatSpec->allocatedFloatStr); + floatSpec->allocatedFloatStr = NULL; + floatSpec->floatStr = NULL; + } +} + + +/* + * scan value of exponent. + * return 0 OK + */ +static int SecInputFloatE(SecFileStream *stream, SecScanSpec *spec, SecFloatSpec *floatSpec, int *charCount) +{ + SecInt ch = SecGetChar(stream, charCount); + if (ch == SECUREC_CHAR('+') || ch == SECUREC_CHAR('-')) { + if (ch == SECUREC_CHAR('-') && SecUpdateFloatString((SecChar)'-', floatSpec) != 0) { + return -1; + } + if (spec->width != 0) { + ch = SecGetChar(stream, charCount); + --spec->width; + } + } + + while (SECUREC_IS_DIGIT(ch) && spec->width-- != 0) { + if (SecUpdateFloatString((SecChar)ch, floatSpec) != 0) { + return -1; + } + ch = SecGetChar(stream, charCount); + } + return 0; +} + +/* + * scan %f. + * return 0 OK + */ +static int SecInputFloat(SecFileStream *stream, SecScanSpec *spec, SecFloatSpec *floatSpec, int *charCount) +{ + int started = -1; + SecInt ch = SecGetChar(stream, charCount); + + floatSpec->floatStrUsedLen = 0; + if (ch == SECUREC_CHAR('-')) { + floatSpec->floatStr[floatSpec->floatStrUsedLen++] = SECUREC_CHAR('-'); + --spec->width; + ch = SecGetChar(stream, charCount); + } else if (ch == SECUREC_CHAR('+')) { + --spec->width; + ch = SecGetChar(stream, charCount); + } + + if (spec->widthSet == 0) { /* must care width */ + spec->width = -1; /* -1 is unlimited */ + } + + /* now get integral part */ + while (SECUREC_IS_DIGIT(ch) && spec->width-- != 0) { + started = 0; + /* ch must be '0' - '9' */ + if (SecUpdateFloatString((SecChar)ch, floatSpec) != 0) { + return -1; + } + ch = SecGetChar(stream, charCount); + } + + /* now get fractional part */ + if (SECURE_IS_FLOAT_DECIMAL((SecChar)ch) && spec->width-- != 0) { + /* now check for decimal */ + if (SecUpdateFloatString((SecChar)ch, floatSpec) != 0) { + return -1; + } + ch = SecGetChar(stream, charCount); + while (SECUREC_IS_DIGIT(ch) && spec->width-- != 0) { + started = 0; + if (SecUpdateFloatString((SecChar)ch, floatSpec) != 0) { + return -1; + } + ch = SecGetChar(stream, charCount); + } + } + + /* now get exponent part */ + if (started == 0 && (ch == SECUREC_CHAR('e') || ch == SECUREC_CHAR('E')) && spec->width-- != 0) { + if (SecUpdateFloatString((SecChar)'e', floatSpec) != 0) { + return -1; + } + if (SecInputFloatE(stream, spec, floatSpec, charCount) != 0) { + return -1; + } + } + /* un set the last character that is not a floating point number */ + SecUnGetChar(ch, stream, charCount); + /* Make sure have a string terminator, buffer is large enough */ + floatSpec->floatStr[floatSpec->floatStrUsedLen] = SECUREC_CHAR('\0'); + return started; + +} +#endif + +/* + * scan digital part of %d %i %o %u %x %p. + * return 0 OK + */ +static int SecInputNumberDigital(SecInt firstCh, SecFileStream *stream, SecScanSpec *spec, int *charCount) +{ + SecInt ch = firstCh; + int loopFlag = 0; + int started = -1; + while (loopFlag == 0) { + /* decode ch to number */ + loopFlag = SecDecodeNumber(ch, spec); + if (loopFlag == 0) { + started = 0; + if (spec->widthSet != 0 && --spec->width == 0) { + loopFlag = 1; + } else { + ch = SecGetChar(stream, charCount); + } + } else { + SecUnGetChar(ch, stream, charCount); + } + } + + /* Handling integer negative numbers and beyond max */ + (*g_secFinishNumber[spec->isInt64Arg])(spec); + return started; + +} + +/* + * scan %d %i %o %u %x %p. + * return 0 OK + */ +static int SecInputNumber(SecFileStream *stream, SecScanSpec *spec, int *charCount) +{ + SecInt ch = SecGetChar(stream, charCount); + + if (ch == SECUREC_CHAR('+') || ch == SECUREC_CHAR('-')) { + if (ch == SECUREC_CHAR('-')) { + spec->negative = 1; + } + if (spec->widthSet != 0 && --spec->width == 0) { + return -1; + } else { + ch = SecGetChar(stream, charCount); + } + } + + if (spec->oriComChr == SECUREC_CHAR('i')) { + /* i could be d, o, or x, use d as default */ + spec->comChr = SECUREC_CHAR('d'); + } + + if (spec->oriComChr == SECUREC_CHAR('x') || spec->oriComChr == SECUREC_CHAR('i')) { + if (ch != SECUREC_CHAR('0')) { + /* scan number */ + return SecInputNumberDigital(ch, stream, spec, charCount); + } + /* now input string may be 0x123 or 0X123 or just 0 */ + /* get next char */ + ch = SecGetChar(stream, charCount); + if ((SecChar)(ch) == SECUREC_CHAR('x') || (SecChar)ch == SECUREC_CHAR('X')) { + spec->comChr = SECUREC_CHAR('x'); + ch = SecGetChar(stream, charCount); + /* length of 0x is 2 */ + if (spec->widthSet != 0 && spec->width <= (1 + 1)) { + /* length not enough for "0x" */ + return -1; + } + spec->width -= 2; /* Subtract 2 for the length of "0x" */ + } else { + if (spec->oriComChr != SECUREC_CHAR('x')) { + spec->comChr = SECUREC_CHAR('o'); + } + /* unset the character after 0 back to stream, input only '0' result is OK */ + SecUnGetChar(ch, stream, charCount); + ch = SECUREC_CHAR('0'); + } + } + /* scan number */ + return SecInputNumberDigital(ch, stream, spec, charCount); +} +/* + * scan %c %s %[ + * return 0 OK + */ +static int SecInputString(SecFileStream *stream, SecScanSpec *spec, + const SecBracketTable *bracketTable, int *charCount, int *doneCount) +{ + void *startPtr = spec->argPtr; + int suppressed= 0; + int errNoMem = 0; + + while (spec->widthSet == 0 || spec->width-- != 0) { + SecInt ch = SecGetChar(stream, charCount); + /* char condition or string condition and bracket condition. + * only supports wide characters with a maximum length of two bytes + */ + if ((ch != SECUREC_EOF) && (spec->comChr == SECUREC_CHAR('c') || + SECUREC_SCANF_STRING_CONDITION(spec->comChr, ch) || + SECUREC_SCANF_BRACKET_CONDITION(spec->comChr, ch, bracketTable->table, bracketTable->mask))) { + if (spec->suppress != 0) { + /* Used to identify processed data for %* + * use endPtr to identify will cause 613, so use suppressed + */ + suppressed = 1; + continue; + } + /* now suppress is not set */ + if (spec->arrayWidth == 0) { + errNoMem = 1; /* We have exhausted the user's buffer */ + break; + } +#ifdef SECUREC_FOR_WCHAR + errNoMem = SecInputForWchar(ch, spec); +#else + errNoMem = SecInputForChar(ch, spec, stream, charCount); +#endif + if (errNoMem != 0) { + break; + } + } else { + SecUnGetChar(ch, stream, charCount); + break; + } + } + + if (errNoMem != 0) { + /* In case of error, blank out the input buffer */ + if (spec->suppress == 0) { + SecAddEndingZero(startPtr, spec); + } + return -1; + } + + /* No input was scanned */ + if ((spec->suppress != 0 && suppressed == 0) || + (spec->suppress == 0 && startPtr == spec->argPtr)) { + return -1; + } + + if (spec->suppress == 0) { + if (spec->comChr != 'c') { + /* null-terminate strings */ + SecAddEndingZero(spec->argPtr, spec); + } + *doneCount = *doneCount + 1; + } + return 0; +} + +#ifdef SECUREC_FOR_WCHAR +/* + * alloce buffer for wchar version of %[. + * return 0 OK + */ +static int SecAllocBracketTable(SecBracketTable *bracketTable) +{ + if (bracketTable->table == NULL) { + /* table should be freed after use */ + bracketTable->table = (unsigned char *)SECUREC_MALLOC(SECUREC_BRACKET_TABLE_SIZE); + if (bracketTable->table == NULL) { + return -1; + } + } + return 0; +} + +/* + * free buffer for wchar version of %[ + */ +static void SecFreeBracketTable(SecBracketTable *bracketTable) +{ + if (bracketTable->table != NULL) { + SECUREC_FREE(bracketTable->table); + bracketTable->table = NULL; + } +} +#endif + +#ifdef SECUREC_FOR_WCHAR +/* + * Formatting input core functions for wchar version.Called by a function such as vsscanf_s + */ +int SecInputSW(SecFileStream *stream, const wchar_t *cFormat, va_list argList) +#else +/* + * Formatting input core functions for char version.Called by a function such as vswscanf_s + */ +int SecInputS(SecFileStream *stream, const char *cFormat, va_list argList) +#endif +{ + const SecUnsignedChar *format = (const SecUnsignedChar *)cFormat; + SecBracketTable bracketTable = SECUREC_INIT_BRACKET_TABLE; + SecScanSpec spec; + SecInt ch = 0; + int charCount = 0; + int doneCount = 0; + int formatError = 0; + int paraIsNull = 0; +#if SECUREC_ENABLE_SCANF_FLOAT + SecFloatSpec floatSpec; +#endif + int match = 0; + int errRet = 0; +#if SECUREC_ENABLE_SCANF_FLOAT + SecInitFloatSpec(&floatSpec); +#endif + /* format must not NULL */ + /* use err < 1 to claer 845 */ + while (errRet < 1 && *format != SECUREC_CHAR('\0')) { + /* skip space in format and space in input */ + if (SECUREC_IS_SPACE(*format)) { + SecInt nonSpaceChar = SecSkipSpaceChar(stream, &charCount); + /* eat all space chars and put fist no space char backup */ + SecUnGetChar(nonSpaceChar, stream, &charCount); + SecSkipSpaceFormat(&format); + continue; + } + + if (*format != SECUREC_CHAR('%')) { + ch = SecGetChar(stream, &charCount); + if ((int)(*format++) != (int)(ch)) { + SecUnGetChar(ch, stream, &charCount); + ++errRet; /* use plus to clear 845 */ + continue; + } +#ifndef SECUREC_FOR_WCHAR + if (SecIsLeadByte(ch) && SecDecodeLeadByte(ch, &format, stream, &charCount) != 0) { + ++errRet; + continue; + } +#endif + /* for next %n */ + if ((ch == SECUREC_EOF) && ((*format != SECUREC_CHAR('%')) || (*(format + 1) != SECUREC_CHAR('n')))) { + break; + } + continue; + } + + /* now *format is % */ + /* set default value for each % */ + SecSetDefaultScanSpec(&spec); + if (SecDecodeScanFlag(&format, &spec) != 0) { + formatError = 1; + ++errRet; + continue; + } + /* update wchar flag for %S %C */ + SecUpdateWcharFlagByType(*format, &spec); + +#if SECUREC_HAVE_WCHART == 0 + /* in kernel not support wide char */ + if (spec.isWChar > 0) { + formatError = 1; + ++errRet; + continue; + } +#endif + if (spec.widthSet != 0 && spec.width == 0) { + /* 0 width in format */ + ++errRet; + continue; + } + + spec.comChr = (unsigned char)(*format) | (SECUREC_CHAR('a') - SECUREC_CHAR('A')); /* to lowercase */ + spec.oriComChr = spec.comChr; + + if (spec.comChr != SECUREC_CHAR('n')) { + if (spec.comChr != SECUREC_CHAR('c') && spec.comChr != SECUREC_BRACE) { + ch = SecSkipSpaceChar(stream, &charCount); + } else { + ch = SecGetChar(stream, &charCount); + } + if (ch == SECUREC_EOF) { + ++errRet; + continue; + } + } + + /* now no 0 width in format and get one char from input */ + switch (spec.comChr) { + case SECUREC_CHAR('c'): /* also 'C' */ + /* fall-through */ /* FALLTHRU */ + case SECUREC_CHAR('s'): /* also 'S': */ + /* fall-through */ /* FALLTHRU */ + case SECUREC_BRACE: + /* check dest buffer and size */ + if (spec.suppress == 0) { + spec.argPtr = (void *)va_arg(argList, void *); + if (spec.argPtr == NULL) { + paraIsNull = 1; + ++errRet; + continue; + } + /* Get the next argument - size of the array in characters */ +#ifdef SECUREC_ON_64BITS + spec.arrayWidth = ((size_t)(va_arg(argList, size_t))) & 0xFFFFFFFFUL; +#else /* !SECUREC_ON_64BITS */ + spec.arrayWidth = (size_t)va_arg(argList, size_t); +#endif + if (spec.arrayWidth == 0 || (spec.isWChar <= 0 && spec.arrayWidth > SECUREC_STRING_MAX_LEN) || + (spec.isWChar > 0 && spec.arrayWidth > SECUREC_WCHAR_STRING_MAX_LEN)) { + /* do not clear buffer just go error */ + ++errRet; + continue; + } + /* One element is needed for '\0' for %s and %[ */ + if (spec.comChr != SECUREC_CHAR('c')) { + --spec.arrayWidth; + } + } else { + /* Set argPtr to NULL is necessary, in supress mode we don't use argPtr to store data */ + spec.argPtr = NULL; + } + + if (spec.comChr == 'c') { + if (spec.widthSet == 0) { + spec.widthSet = 1; + spec.width = 1; + } + } else if (spec.comChr == SECUREC_BRACE) { + /* malloc when first %[ is meet for wchar version */ +#ifdef SECUREC_FOR_WCHAR + if (SecAllocBracketTable(&bracketTable) != 0) { + ++errRet; + continue; + } + +#endif + (void)memset(bracketTable.table, 0, (size_t)SECUREC_BRACKET_TABLE_SIZE); + if (SecSetupBracketTable(&format, &bracketTable) != 0) { + ++errRet; + continue; + } + + if (*format == SECUREC_CHAR('\0')) { + if (spec.suppress == 0 && spec.arrayWidth > 0) { + SecAddEndingZero(spec.argPtr, &spec); + } + ++errRet; + /* truncated format */ + continue; + } + + } + /* un set last char to stream */ + SecUnGetChar(ch, stream, &charCount); + /* scanset completed. Now read string */ + if (SecInputString(stream, &spec, &bracketTable, &charCount, &doneCount) != 0) { + ++errRet; + continue; + } + break; + case SECUREC_CHAR('p'): + /* make %hp same as %p */ + spec.numberWidth = SECUREC_NUM_WIDTH_INT; +#ifdef SECUREC_ON_64BITS + spec.isInt64Arg = 1; +#endif + /* fall-through */ /* FALLTHRU */ + case SECUREC_CHAR('o'): /* fall-through */ /* FALLTHRU */ + case SECUREC_CHAR('u'): /* fall-through */ /* FALLTHRU */ + case SECUREC_CHAR('d'): /* fall-through */ /* FALLTHRU */ + case SECUREC_CHAR('i'): /* fall-through */ /* FALLTHRU */ + case SECUREC_CHAR('x'): + /* un set last char to stream */ + SecUnGetChar(ch, stream, &charCount); + if (SecInputNumber(stream, &spec, &charCount) != 0) { + ++errRet; + continue; + } + if (spec.suppress == 0) { + spec.argPtr = (void *)va_arg(argList, void *); + if (spec.argPtr == NULL) { + paraIsNull = 1; + ++errRet; + continue; + } + SecAssignNumber(&spec); + ++doneCount; + } + break; + case SECUREC_CHAR('n'): /* char count */ + if (spec.suppress == 0) { + spec.argPtr = (void *)va_arg(argList, void *); + if (spec.argPtr == NULL) { + paraIsNull = 1; + ++errRet; + continue; + } + spec.number = (unsigned long)(unsigned int)charCount; + spec.isInt64Arg = 0; + SecAssignNumber(&spec); + } + break; + case SECUREC_CHAR('e'): /* fall-through */ /* FALLTHRU */ + case SECUREC_CHAR('f'): /* fall-through */ /* FALLTHRU */ + case SECUREC_CHAR('g'): /* scan a float */ +#if SECUREC_ENABLE_SCANF_FLOAT + /* un set last char to stream */ + SecUnGetChar(ch, stream, &charCount); + if (SecInputFloat(stream, &spec, &floatSpec, &charCount) != 0) { + ++errRet; + continue; + } + if (spec.suppress == 0) { + spec.argPtr = (void *)va_arg(argList, void *); + if (spec.argPtr == NULL) { + ++errRet; + paraIsNull = 1; + continue; + } +#ifdef SECUREC_FOR_WCHAR + if (SecAssignFloatW(&floatSpec, &spec) != 0) { + ++errRet; + continue; + } +#else + SecAssignFloat(floatSpec.floatStr, spec.numberWidth, spec.argPtr); +#endif + ++doneCount; + } + + break; +#else /* SECUREC_ENABLE_SCANF_FLOAT */ + ++errRet; + continue; +#endif + default: + if ((int)(*format) != (int)ch) { + SecUnGetChar(ch, stream, &charCount); + formatError = 1; + ++errRet; + continue; + } else { + --match; + } + } + + ++match; + ++format; + if ((ch == SECUREC_EOF) && ((*format != SECUREC_CHAR('%')) || (*(format + 1) != SECUREC_CHAR('n')))) { + break; + } + } + +#ifdef SECUREC_FOR_WCHAR + SecFreeBracketTable(&bracketTable); +#endif + +#if SECUREC_ENABLE_SCANF_FLOAT + SecClearFloatSpec(&floatSpec, &doneCount); +#endif + +#if SECUREC_ENABLE_SCANF_FILE + SecAdjustStream(stream); +#endif + + if (ch == SECUREC_EOF) { + return ((doneCount || match) ? doneCount : SECUREC_SCANF_EINVAL); + } else if (formatError != 0 || paraIsNull != 0) { + /* Invalid Input Format or parameter */ + return SECUREC_SCANF_ERROR_PARA; + } + + return doneCount; +} + +#if SECUREC_ENABLE_SCANF_FILE + +#if defined(SECUREC_NO_STD_UNGETC) +/* + * Get char from stdin or buffer + */ +static SecInt SecGetCharFromStdin(SecFileStream *stream) +{ + SecInt ch; + if (stream->fUnget == 1) { + ch = (SecInt) stream->lastChar; + stream->fUnget = 0; + } else { + ch = SECUREC_GETC(stream->pf); + stream->lastChar = (unsigned int)ch; + } + return ch; +} +#else +/* + * Get char from stdin or buffer use std function + */ +static SecInt SecGetCharFromStdin(const SecFileStream *stream) +{ + SecInt ch; + ch = SECUREC_GETC(stream->pf); + return ch; +} +#endif + +static void SecSkipBomHeader(SecFileStream *stream) +{ +#ifdef SECUREC_FOR_WCHAR + if (stream->count >= SECUREC_BOM_HEADER_SIZE && + (((unsigned char)(stream->base[0]) == SECUREC_BOM_HEADER_LE_1ST && + (unsigned char)(stream->base[1]) == SECUREC_BOM_HEADER_LE_2ST) || + ((unsigned char)(stream->base[0]) == SECUREC_BOM_HEADER_BE_1ST && + (unsigned char)(stream->base[1]) == SECUREC_BOM_HEADER_BE_2ST))) { + + /* the stream->count must be a multiple of sizeof(SecChar), + * otherwise this function will return SECUREC_EOF when read the last character + */ + if ((stream->count - SECUREC_BOM_HEADER_SIZE) % (int)sizeof(SecChar) != 0) { + int ret = (int)fread(stream->base + stream->count, (size_t)1, + (size_t)SECUREC_BOM_HEADER_SIZE, stream->pf); + if (ret > 0 && ret <= SECUREC_BUFFERED_BLOK_SIZE) { + stream->count += ret; + } + } + /* it's BOM header, skip */ + stream->count -= SECUREC_BOM_HEADER_SIZE; + stream->cur += SECUREC_BOM_HEADER_SIZE; + } +#else + if (stream->count >= SECUREC_UTF8_BOM_HEADER_SIZE && + (unsigned char)(stream->base[0]) == SECUREC_UTF8_BOM_HEADER_1ST && + (unsigned char)(stream->base[1]) == SECUREC_UTF8_BOM_HEADER_2ND && + (unsigned char)(stream->base[2]) == SECUREC_UTF8_BOM_HEADER_3RD) { /* 2 offset of third head character */ + /* it's BOM header, skip */ + stream->count -= SECUREC_UTF8_BOM_HEADER_SIZE; + stream->cur += SECUREC_UTF8_BOM_HEADER_SIZE; + } +#endif +} +/* + * Get char from file stream or buffer + */ +static SecInt SecGetCharFromFile(SecFileStream *stream) +{ + SecInt ch; + if (stream->count == 0) { + int firstReadOnFile = 0; + /* load file to buffer */ + if (stream->base == NULL) { + stream->base = (char *)SECUREC_MALLOC(SECUREC_BUFFERED_BLOK_SIZE + 1); + if (stream->base == NULL) { + return SECUREC_EOF; + } + stream->base[SECUREC_BUFFERED_BLOK_SIZE] = '\0'; /* for tool Warning string null */ + } + /* LSD add 2014.3.21 */ + if (stream->oriFilePos == SECUREC_UNINITIALIZED_FILE_POS) { + stream->oriFilePos = ftell(stream->pf); /* save original file read position */ + firstReadOnFile = 1; + } + stream->count = (int)fread(stream->base, (size_t)1, (size_t)SECUREC_BUFFERED_BLOK_SIZE, stream->pf); + stream->base[SECUREC_BUFFERED_BLOK_SIZE] = '\0'; /* for tool Warning string null */ + if (stream->count == 0 || stream->count > SECUREC_BUFFERED_BLOK_SIZE) { + return SECUREC_EOF; + } + stream->cur = stream->base; + stream->flag |= SECUREC_LOAD_FILE_TO_MEM_FLAG; + if (firstReadOnFile != 0) { + SecSkipBomHeader(stream); + } + } + /* according wchar_t has two bytes */ + ch = (SecInt)((stream->count -= (int)sizeof(SecChar)) >= 0 ? \ + (SecInt)(SECUREC_CHAR_MASK & \ + (unsigned int)(int)(*((const SecChar *)(const void *)stream->cur))) : SECUREC_EOF); + stream->cur += sizeof(SecChar); + + if (ch != SECUREC_EOF && stream->base != NULL) { + stream->fileRealRead += (int)sizeof(SecChar); + } + return ch; +} +#endif + +/* + * Get char for wchar version + */ +static SecInt SecGetChar(SecFileStream *stream, int *counter) +{ + SecInt ch = SECUREC_EOF; +#if SECUREC_ENABLE_SCANF_FILE + if ((stream->flag & SECUREC_FROM_STDIN_FLAG) > 0) { + ch = SecGetCharFromStdin(stream); + } else if ((stream->flag & SECUREC_FILE_STREAM_FLAG) > 0) { + ch = SecGetCharFromFile(stream); + } +#endif + if ((stream->flag & SECUREC_MEM_STR_FLAG) > 0) { + /* according wchar_t has two bytes */ + ch = (SecInt)((stream->count -= (int)sizeof(SecChar)) >= 0 ? \ + (SecInt)(SECUREC_CHAR_MASK & \ + (unsigned int)(int)(*((const SecChar *)(const void *)stream->cur))) : SECUREC_EOF); + stream->cur += sizeof(SecChar); + } + *counter = *counter + 1; + return ch; +} + +/* + * Unget Public realizatio char for wchar and char version + */ +static void SecUnGetCharImpl(SecInt ch, SecFileStream *stream) +{ + if ((stream->flag & SECUREC_FROM_STDIN_FLAG) > 0) { +#if SECUREC_ENABLE_SCANF_FILE +#if defined(SECUREC_NO_STD_UNGETC) + stream->lastChar = (unsigned int)ch; + stream->fUnget = 1; +#else + (void)SECUREC_UN_GETC(ch, stream->pf); +#endif +#else + (void)ch; /* to clear e438 last value assigned not used , the compiler will optimize this code */ +#endif + } else if ((stream->flag & SECUREC_MEM_STR_FLAG) || (stream->flag & SECUREC_LOAD_FILE_TO_MEM_FLAG) > 0) { + if (stream->cur > stream->base) { + stream->cur -= sizeof(SecChar); + stream->count += (int)sizeof(SecChar); + } + } +#if SECUREC_ENABLE_SCANF_FILE + if ((stream->flag & SECUREC_FILE_STREAM_FLAG) > 0 && stream->base) { + stream->fileRealRead -= (int)sizeof(SecChar); + } +#endif +} + +/* + * Unget char for char version + */ +static void SecUnGetChar(SecInt ch, SecFileStream *stream, int *counter) +{ + if (ch != SECUREC_EOF) { + SecUnGetCharImpl(ch, stream); + } + *counter = *counter - 1; +} + +/* + * Skip space char by isspace + */ +static SecInt SecSkipSpaceChar(SecFileStream *stream, int *counter) +{ + SecInt ch; + do { + ch = SecGetChar(stream, counter); + } while (ch != SECUREC_EOF && SECUREC_IS_SPACE(ch)); + return ch; +} +#endif /* __INPUT_INL__5D13A042_DC3F_4ED9_A8D1_882811274C27 */ + diff --git a/third_party/securec/src/memcpy_s.c b/third_party/securec/src/memcpy_s.c new file mode 100644 index 00000000..5eb100f4 --- /dev/null +++ b/third_party/securec/src/memcpy_s.c @@ -0,0 +1,577 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#define SECUREC_INLINE_DO_MEMCPY 1 +#include "securecutil.h" + +#ifndef SECUREC_MEMCOPY_WITH_PERFORMANCE +#define SECUREC_MEMCOPY_WITH_PERFORMANCE 0 +#endif + +#if SECUREC_WITH_PERFORMANCE_ADDONS || SECUREC_MEMCOPY_WITH_PERFORMANCE +#ifndef SECUREC_MEMCOPY_THRESHOLD_SIZE +#define SECUREC_MEMCOPY_THRESHOLD_SIZE 64UL +#endif +/* + * Determine whether the address is 8-byte aligned, use static to increase performance + * return 0 is aligned + */ +static int SecIsAddrAligned8(const void *addr, const void *zeroAddr) +{ + return (int)(((size_t)((const char*)addr - (const char*)zeroAddr)) & 7); /* use 7 to check aligned 8 */ +} + +#define SECUREC_SMALL_MEM_COPY do { \ + if (SECUREC_ADDR_ALIGNED_8(dest) && SECUREC_ADDR_ALIGNED_8(src)) { \ + /* use struct assignment */ \ + switch (count) { \ + case 1: \ + *(SecStrBuf1 *)dest = *(const SecStrBuf1 *)src; \ + break; \ + case 2: \ + *(SecStrBuf2 *)dest = *(const SecStrBuf2 *)src; \ + break; \ + case 3: \ + *(SecStrBuf3 *)dest = *(const SecStrBuf3 *)src; \ + break; \ + case 4: \ + *(SecStrBuf4 *)dest = *(const SecStrBuf4 *)src; \ + break; \ + case 5: \ + *(SecStrBuf5 *)dest = *(const SecStrBuf5 *)src; \ + break; \ + case 6: \ + *(SecStrBuf6 *)dest = *(const SecStrBuf6 *)src; \ + break; \ + case 7: \ + *(SecStrBuf7 *)dest = *(const SecStrBuf7 *)src; \ + break; \ + case 8: \ + *(SecStrBuf8 *)dest = *(const SecStrBuf8 *)src; \ + break; \ + case 9: \ + *(SecStrBuf9 *)dest = *(const SecStrBuf9 *)src; \ + break; \ + case 10: \ + *(SecStrBuf10 *)dest = *(const SecStrBuf10 *)src; \ + break; \ + case 11: \ + *(SecStrBuf11 *)dest = *(const SecStrBuf11 *)src; \ + break; \ + case 12: \ + *(SecStrBuf12 *)dest = *(const SecStrBuf12 *)src; \ + break; \ + case 13: \ + *(SecStrBuf13 *)dest = *(const SecStrBuf13 *)src; \ + break; \ + case 14: \ + *(SecStrBuf14 *)dest = *(const SecStrBuf14 *)src; \ + break; \ + case 15: \ + *(SecStrBuf15 *)dest = *(const SecStrBuf15 *)src; \ + break; \ + case 16: \ + *(SecStrBuf16 *)dest = *(const SecStrBuf16 *)src; \ + break; \ + case 17: \ + *(SecStrBuf17 *)dest = *(const SecStrBuf17 *)src; \ + break; \ + case 18: \ + *(SecStrBuf18 *)dest = *(const SecStrBuf18 *)src; \ + break; \ + case 19: \ + *(SecStrBuf19 *)dest = *(const SecStrBuf19 *)src; \ + break; \ + case 20: \ + *(SecStrBuf20 *)dest = *(const SecStrBuf20 *)src; \ + break; \ + case 21: \ + *(SecStrBuf21 *)dest = *(const SecStrBuf21 *)src; \ + break; \ + case 22: \ + *(SecStrBuf22 *)dest = *(const SecStrBuf22 *)src; \ + break; \ + case 23: \ + *(SecStrBuf23 *)dest = *(const SecStrBuf23 *)src; \ + break; \ + case 24: \ + *(SecStrBuf24 *)dest = *(const SecStrBuf24 *)src; \ + break; \ + case 25: \ + *(SecStrBuf25 *)dest = *(const SecStrBuf25 *)src; \ + break; \ + case 26: \ + *(SecStrBuf26 *)dest = *(const SecStrBuf26 *)src; \ + break; \ + case 27: \ + *(SecStrBuf27 *)dest = *(const SecStrBuf27 *)src; \ + break; \ + case 28: \ + *(SecStrBuf28 *)dest = *(const SecStrBuf28 *)src; \ + break; \ + case 29: \ + *(SecStrBuf29 *)dest = *(const SecStrBuf29 *)src; \ + break; \ + case 30: \ + *(SecStrBuf30 *)dest = *(const SecStrBuf30 *)src; \ + break; \ + case 31: \ + *(SecStrBuf31 *)dest = *(const SecStrBuf31 *)src; \ + break; \ + case 32: \ + *(SecStrBuf32 *)dest = *(const SecStrBuf32 *)src; \ + break; \ + case 33: \ + *(SecStrBuf33 *)dest = *(const SecStrBuf33 *)src; \ + break; \ + case 34: \ + *(SecStrBuf34 *)dest = *(const SecStrBuf34 *)src; \ + break; \ + case 35: \ + *(SecStrBuf35 *)dest = *(const SecStrBuf35 *)src; \ + break; \ + case 36: \ + *(SecStrBuf36 *)dest = *(const SecStrBuf36 *)src; \ + break; \ + case 37: \ + *(SecStrBuf37 *)dest = *(const SecStrBuf37 *)src; \ + break; \ + case 38: \ + *(SecStrBuf38 *)dest = *(const SecStrBuf38 *)src; \ + break; \ + case 39: \ + *(SecStrBuf39 *)dest = *(const SecStrBuf39 *)src; \ + break; \ + case 40: \ + *(SecStrBuf40 *)dest = *(const SecStrBuf40 *)src; \ + break; \ + case 41: \ + *(SecStrBuf41 *)dest = *(const SecStrBuf41 *)src; \ + break; \ + case 42: \ + *(SecStrBuf42 *)dest = *(const SecStrBuf42 *)src; \ + break; \ + case 43: \ + *(SecStrBuf43 *)dest = *(const SecStrBuf43 *)src; \ + break; \ + case 44: \ + *(SecStrBuf44 *)dest = *(const SecStrBuf44 *)src; \ + break; \ + case 45: \ + *(SecStrBuf45 *)dest = *(const SecStrBuf45 *)src; \ + break; \ + case 46: \ + *(SecStrBuf46 *)dest = *(const SecStrBuf46 *)src; \ + break; \ + case 47: \ + *(SecStrBuf47 *)dest = *(const SecStrBuf47 *)src; \ + break; \ + case 48: \ + *(SecStrBuf48 *)dest = *(const SecStrBuf48 *)src; \ + break; \ + case 49: \ + *(SecStrBuf49 *)dest = *(const SecStrBuf49 *)src; \ + break; \ + case 50: \ + *(SecStrBuf50 *)dest = *(const SecStrBuf50 *)src; \ + break; \ + case 51: \ + *(SecStrBuf51 *)dest = *(const SecStrBuf51 *)src; \ + break; \ + case 52: \ + *(SecStrBuf52 *)dest = *(const SecStrBuf52 *)src; \ + break; \ + case 53: \ + *(SecStrBuf53 *)dest = *(const SecStrBuf53 *)src; \ + break; \ + case 54: \ + *(SecStrBuf54 *)dest = *(const SecStrBuf54 *)src; \ + break; \ + case 55: \ + *(SecStrBuf55 *)dest = *(const SecStrBuf55 *)src; \ + break; \ + case 56: \ + *(SecStrBuf56 *)dest = *(const SecStrBuf56 *)src; \ + break; \ + case 57: \ + *(SecStrBuf57 *)dest = *(const SecStrBuf57 *)src; \ + break; \ + case 58: \ + *(SecStrBuf58 *)dest = *(const SecStrBuf58 *)src; \ + break; \ + case 59: \ + *(SecStrBuf59 *)dest = *(const SecStrBuf59 *)src; \ + break; \ + case 60: \ + *(SecStrBuf60 *)dest = *(const SecStrBuf60 *)src; \ + break; \ + case 61: \ + *(SecStrBuf61 *)dest = *(const SecStrBuf61 *)src; \ + break; \ + case 62: \ + *(SecStrBuf62 *)dest = *(const SecStrBuf62 *)src; \ + break; \ + case 63: \ + *(SecStrBuf63 *)dest = *(const SecStrBuf63 *)src; \ + break; \ + case 64: \ + *(SecStrBuf64 *)dest = *(const SecStrBuf64 *)src; \ + break; \ + default: \ + break; \ + } /* END switch */ \ + } else { \ + char *tmpDest = (char *)dest; \ + const char *tmpSrc = (const char *)src; \ + switch (count) { \ + case 64: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 63: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 62: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 61: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 60: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 59: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 58: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 57: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 56: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 55: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 54: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 53: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 52: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 51: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 50: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 49: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 48: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 47: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 46: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 45: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 44: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 43: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 42: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 41: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 40: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 39: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 38: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 37: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 36: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 35: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 34: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 33: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 32: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 31: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 30: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 29: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 28: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 27: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 26: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 25: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 24: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 23: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 22: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 21: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 20: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 19: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 18: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 17: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 16: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 15: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 14: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 13: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 12: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 11: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 10: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 9: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 8: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 7: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 6: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 5: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 4: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 3: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 2: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 1: \ + *(tmpDest++) = *(tmpSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + default: \ + break; \ + } \ + } \ +} SECUREC_WHILE_ZERO +#endif + +/* + * Handling errors + */ +static errno_t SecMemcpyError(void *dest, size_t destMax, const void *src, size_t count) +{ + if (destMax == 0 || destMax > SECUREC_MEM_MAX_LEN) { + SECUREC_ERROR_INVALID_RANGE("memcpy_s"); + return ERANGE; + } + if (dest == NULL || src == NULL) { + SECUREC_ERROR_INVALID_PARAMTER("memcpy_s"); + if (dest != NULL) { + (void)memset(dest, 0, destMax); + return EINVAL_AND_RESET; + } + return EINVAL; + } + if (count > destMax) { + (void)memset(dest, 0, destMax); + SECUREC_ERROR_INVALID_RANGE("memcpy_s"); + return ERANGE_AND_RESET; + } + if (dest == src) { + return EOK; + } + if ((dest > src && dest < (const void *)((const unsigned char *)src + count)) || \ + (src > dest && src < (void *)((unsigned char *)dest + count))) { + (void)memset(dest, 0, destMax); + SECUREC_ERROR_BUFFER_OVERLAP("memcpy_s"); + return EOVERLAP_AND_RESET; + } + /* count == 0 also return EOK */ + return EOK; +} + +#if SECUREC_WITH_PERFORMANCE_ADDONS || SECUREC_MEMCOPY_WITH_PERFORMANCE +/* + * Performance optimization + */ +static void SecDoMemcpyOpt(void *dest, const void *src, size_t count) +{ + if (count > SECUREC_MEMCOPY_THRESHOLD_SIZE) { + SecDoMemcpy(dest, src, count); + } else { + SECUREC_SMALL_MEM_COPY; + } + return; +} +#endif + +#if defined(SECUREC_COMPATIBLE_WIN_FORMAT) + /* fread API in windows will call memcpy_s and pass 0xffffffff to destMax. + * To avoid the failure of fread, we don't check desMax limit. + */ +#define SECUREC_MEMCPY_PARAM_OK(dest, destMax, src, count) (SECUREC_LIKELY((count) <= (destMax) && \ + (dest) != NULL && (src) != NULL && \ + (count) > 0 && SECUREC_MEMORY_NO_OVERLAP((dest), (src), (count)))) +#else +#define SECUREC_MEMCPY_PARAM_OK(dest, destMax, src, count) (SECUREC_LIKELY((count) <= (destMax) && \ + (dest) != NULL && (src) != NULL && \ + (destMax) <= SECUREC_MEM_MAX_LEN && \ + (count) > 0 && SECUREC_MEMORY_NO_OVERLAP((dest), (src), (count)))) +#endif + +/* + * + * The memcpy_s function copies n characters from the object pointed to by src into the object pointed to by dest + * + * + * dest Destination buffer. + * destMax Size of the destination buffer. + * src Buffer to copy from. + * count Number of characters to copy + * + * + * dest buffer is updated. + * + * + * EOK Success + * EINVAL dest is NULL and destMax != 0 and destMax <= SECUREC_MEM_MAX_LEN + * EINVAL_AND_RESET dest != NULL and src is NULLL and destMax != 0 and destMax <= SECUREC_MEM_MAX_LEN + * ERANGE destMax > SECUREC_MEM_MAX_LEN or destMax is 0 + * ERANGE_AND_RESET count > destMax and destMax != 0 and destMax <= SECUREC_MEM_MAX_LEN + * and dest != NULL and src != NULL + * EOVERLAP_AND_RESET dest buffer and source buffer are overlapped and + * count <= destMax destMax != 0 and destMax <= SECUREC_MEM_MAX_LEN and dest != NULL + * and src != NULL and dest != src + * + * if an error occured, dest will be filled with 0. + * If the source and destination overlap, the behavior of memcpy_s is undefined. + * Use memmove_s to handle overlapping regions. + */ +errno_t memcpy_s(void *dest, size_t destMax, const void *src, size_t count) +{ + if (SECUREC_MEMCPY_PARAM_OK(dest, destMax, src, count)) { +#if SECUREC_MEMCOPY_WITH_PERFORMANCE + SecDoMemcpyOpt(dest, src, count); +#else + SecDoMemcpy(dest, src, count); +#endif + return EOK; + } + /* meet some runtime violation, return error code */ + return SecMemcpyError(dest, destMax, src, count); +} + +#if SECUREC_IN_KERNEL +EXPORT_SYMBOL(memcpy_s); +#endif + +#if SECUREC_WITH_PERFORMANCE_ADDONS +/* + * Performance optimization + */ +errno_t memcpy_sOptAsm(void *dest, size_t destMax, const void *src, size_t count) +{ + if (SECUREC_MEMCPY_PARAM_OK(dest, destMax, src, count)) { + SecDoMemcpyOpt(dest, src, count); + return EOK; + } + /* meet some runtime violation, return error code */ + return SecMemcpyError(dest, destMax, src, count); +} + +/* trim judgement on "destMax <= SECUREC_MEM_MAX_LEN" */ +errno_t memcpy_sOptTc(void *dest, size_t destMax, const void *src, size_t count) +{ + if (SECUREC_LIKELY(count <= destMax && dest != NULL && src != NULL && \ + count > 0 && \ + ((dest > src && (const void *)((const unsigned char *)src + count) <= dest) || \ + (src > dest && (void *)((unsigned char *)dest + count) <= src)))) { + SecDoMemcpyOpt(dest, src, count); + return EOK; + } + /* meet some runtime violation, return error code */ + return SecMemcpyError(dest, destMax, src, count); +} +#endif + diff --git a/third_party/securec/src/memmove_s.c b/third_party/securec/src/memmove_s.c new file mode 100644 index 00000000..ec6d04a7 --- /dev/null +++ b/third_party/securec/src/memmove_s.c @@ -0,0 +1,120 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "securecutil.h" + +#ifdef SECUREC_NOT_CALL_LIBC_CORE_API +/* + * Implementing memory data movement + */ +static void SecUtilMemmove(void *dst, const void *src, size_t count) +{ + unsigned char *pDest = (unsigned char *)dst; + const unsigned char *pSrc = (const unsigned char *)src; + size_t maxCount = count; + + if (dst <= src || pDest >= (pSrc + maxCount)) { + /* + * Non-Overlapping Buffers + * copy from lower addresses to higher addresses + */ + while (maxCount--) { + *pDest = *pSrc; + ++pDest; + ++pSrc; + } + } else { + /* + * Overlapping Buffers + * copy from higher addresses to lower addresses + */ + pDest = pDest + maxCount - 1; + pSrc = pSrc + maxCount - 1; + + while (maxCount--) { + *pDest = *pSrc; + + --pDest; + --pSrc; + } + } +} +#endif + +/* + * + * The memmove_s function copies count bytes of characters from src to dest. + * This function can be assigned correctly when memory overlaps. + * + * dest Destination object. + * destMax Size of the destination buffer. + * src Source object. + * count Number of characters to copy. + * + * + * dest buffer is uptdated. + * + * + * EOK Success + * EINVAL dest is NULL and destMax != 0 and destMax <= SECUREC_MEM_MAX_LEN + * EINVAL_AND_RESET dest != NULL and src is NULLL and destMax != 0 and destMax <= SECUREC_MEM_MAX_LEN + * ERANGE destMax > SECUREC_MEM_MAX_LEN or destMax is 0 + * ERANGE_AND_RESET count > destMax and dest != NULL and src != NULL and destMax != 0 + * and destMax <= SECUREC_MEM_MAX_LEN + * + * If an error occured, dest will be filled with 0 when dest and destMax valid. + * If some regions of the source area and the destination overlap, memmove_s + * ensures that the original source bytes in the overlapping region are copied + * before being overwritten. + */ +errno_t memmove_s(void *dest, size_t destMax, const void *src, size_t count) +{ + if (destMax == 0 || destMax > SECUREC_MEM_MAX_LEN) { + SECUREC_ERROR_INVALID_RANGE("memmove_s"); + return ERANGE; + } + if (dest == NULL || src == NULL) { + SECUREC_ERROR_INVALID_PARAMTER("memmove_s"); + if (dest != NULL) { + (void)memset(dest, 0, destMax); + return EINVAL_AND_RESET; + } + return EINVAL; + } + if (count > destMax) { + (void)memset(dest, 0, destMax); + SECUREC_ERROR_INVALID_RANGE("memmove_s"); + return ERANGE_AND_RESET; + } + if (dest == src) { + return EOK; + } + + if (count > 0) { +#ifdef SECUREC_NOT_CALL_LIBC_CORE_API + SecUtilMemmove(dest, src, count); +#else + /* use underlying memmove for performance consideration */ + (void)memmove(dest, src, count); +#endif + } + return EOK; +} + +#if SECUREC_IN_KERNEL +EXPORT_SYMBOL(memmove_s); +#endif + diff --git a/third_party/securec/src/memset_s.c b/third_party/securec/src/memset_s.c new file mode 100644 index 00000000..cd3f9887 --- /dev/null +++ b/third_party/securec/src/memset_s.c @@ -0,0 +1,522 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#define SECUREC_INLINE_DO_MEMSET 1 + +#include "securecutil.h" + +#ifndef SECUREC_MEMSET_WITH_PERFORMANCE +#define SECUREC_MEMSET_WITH_PERFORMANCE 0 +#endif + +#define SECUREC_MEMSET_PARAM_OK(dest, destMax, count) (SECUREC_LIKELY((count) <= (destMax) && \ + (dest) != NULL && (destMax) <= SECUREC_MEM_MAX_LEN)) + + +#if SECUREC_WITH_PERFORMANCE_ADDONS || SECUREC_MEMSET_WITH_PERFORMANCE +/* + * Determine whether the address is 8-byte aligned, use static to increase performance + * return 0 is aligned + */ +static int SecIsAddrAligned8(const void *addr, const void *zeroAddr) +{ + return (int)(((size_t)((const char*)addr - (const char*)zeroAddr)) & 7); /* use 7 to check aligned 8 */ +} + +/* use union to clear strict-aliasing warning */ +typedef union { + SecStrBuf32 buf32; + SecStrBuf31 buf31; + SecStrBuf30 buf30; + SecStrBuf29 buf29; + SecStrBuf28 buf28; + SecStrBuf27 buf27; + SecStrBuf26 buf26; + SecStrBuf25 buf25; + SecStrBuf24 buf24; + SecStrBuf23 buf23; + SecStrBuf22 buf22; + SecStrBuf21 buf21; + SecStrBuf20 buf20; + SecStrBuf19 buf19; + SecStrBuf18 buf18; + SecStrBuf17 buf17; + SecStrBuf16 buf16; + SecStrBuf15 buf15; + SecStrBuf14 buf14; + SecStrBuf13 buf13; + SecStrBuf12 buf12; + SecStrBuf11 buf11; + SecStrBuf10 buf10; + SecStrBuf9 buf9; + SecStrBuf8 buf8; + SecStrBuf7 buf7; + SecStrBuf6 buf6; + SecStrBuf5 buf5; + SecStrBuf4 buf4; + SecStrBuf3 buf3; + SecStrBuf2 buf2; + SecStrBuf1 buf1; +} SecStrBuf32Union; +/* C standard initializes the first member of the consortium. */ +static const SecStrBuf32 g_allZero = {{ + '\0', '\0', '\0', '\0', '\0', '\0', '\0', '\0', + '\0', '\0', '\0', '\0', '\0', '\0', '\0', '\0', + '\0', '\0', '\0', '\0', '\0', '\0', '\0', '\0', + '\0', '\0', '\0', '\0', '\0', '\0', '\0', '\0' +}}; +static const SecStrBuf32 g_allFF = {{ + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF +}}; + +static const SecStrBuf32Union *SecStrictAliasingCast(const SecStrBuf32 *buf) +{ + return (const SecStrBuf32Union *)buf; +} + +#ifndef SECUREC_MEMSET_THRESHOLD_SIZE +#define SECUREC_MEMSET_THRESHOLD_SIZE 32UL +#endif + +#define SECUREC_UNALIGNED_SET do { \ + char *pcDest = (char *)dest; \ + switch (count) { \ + case 32: \ + *(pcDest++) = (char)c; \ + /* fall-through */ /* FALLTHRU */ \ + case 31: \ + *(pcDest++) = (char)c; \ + /* fall-through */ /* FALLTHRU */ \ + case 30: \ + *(pcDest++) = (char)c; \ + /* fall-through */ /* FALLTHRU */ \ + case 29: \ + *(pcDest++) = (char)c; \ + /* fall-through */ /* FALLTHRU */ \ + case 28: \ + *(pcDest++) = (char)c; \ + /* fall-through */ /* FALLTHRU */ \ + case 27: \ + *(pcDest++) = (char)c; \ + /* fall-through */ /* FALLTHRU */ \ + case 26: \ + *(pcDest++) = (char)c; \ + /* fall-through */ /* FALLTHRU */ \ + case 25: \ + *(pcDest++) = (char)c; \ + /* fall-through */ /* FALLTHRU */ \ + case 24: \ + *(pcDest++) = (char)c; \ + /* fall-through */ /* FALLTHRU */ \ + case 23: \ + *(pcDest++) = (char)c; \ + /* fall-through */ /* FALLTHRU */ \ + case 22: \ + *(pcDest++) = (char)c; \ + /* fall-through */ /* FALLTHRU */ \ + case 21: \ + *(pcDest++) = (char)c; \ + /* fall-through */ /* FALLTHRU */ \ + case 20: \ + *(pcDest++) = (char)c; \ + /* fall-through */ /* FALLTHRU */ \ + case 19: \ + *(pcDest++) = (char)c; \ + /* fall-through */ /* FALLTHRU */ \ + case 18: \ + *(pcDest++) = (char)c; \ + /* fall-through */ /* FALLTHRU */ \ + case 17: \ + *(pcDest++) = (char)c; \ + /* fall-through */ /* FALLTHRU */ \ + case 16: \ + *(pcDest++) = (char)c; \ + /* fall-through */ /* FALLTHRU */ \ + case 15: \ + *(pcDest++) = (char)c; \ + /* fall-through */ /* FALLTHRU */ \ + case 14: \ + *(pcDest++) = (char)c; \ + /* fall-through */ /* FALLTHRU */ \ + case 13: \ + *(pcDest++) = (char)c; \ + /* fall-through */ /* FALLTHRU */ \ + case 12: \ + *(pcDest++) = (char)c; \ + /* fall-through */ /* FALLTHRU */ \ + case 11: \ + *(pcDest++) = (char)c; \ + /* fall-through */ /* FALLTHRU */ \ + case 10: \ + *(pcDest++) = (char)c; \ + /* fall-through */ /* FALLTHRU */ \ + case 9: \ + *(pcDest++) = (char)c; \ + /* fall-through */ /* FALLTHRU */ \ + case 8: \ + *(pcDest++) = (char)c; \ + /* fall-through */ /* FALLTHRU */ \ + case 7: \ + *(pcDest++) = (char)c; \ + /* fall-through */ /* FALLTHRU */ \ + case 6: \ + *(pcDest++) = (char)c; \ + /* fall-through */ /* FALLTHRU */ \ + case 5: \ + *(pcDest++) = (char)c; \ + /* fall-through */ /* FALLTHRU */ \ + case 4: \ + *(pcDest++) = (char)c; \ + /* fall-through */ /* FALLTHRU */ \ + case 3: \ + *(pcDest++) = (char)c; \ + /* fall-through */ /* FALLTHRU */ \ + case 2: \ + *(pcDest++) = (char)c; \ + /* fall-through */ /* FALLTHRU */ \ + case 1: \ + *(pcDest++) = (char)c; \ + /* fall-through */ /* FALLTHRU */ \ + default: \ + break; \ + } \ +} SECUREC_WHILE_ZERO + +#define SECUREC_ALIGNED_SET_OPT_ZERO_FF do { \ + switch (c) { \ + case 0: \ + switch (count) { \ + case 1: \ + *(SecStrBuf1 *)dest = *(const SecStrBuf1 *)(&((SecStrictAliasingCast(&g_allZero))->buf1)); \ + break; \ + case 2: \ + *(SecStrBuf2 *)dest = *(const SecStrBuf2 *)(&((SecStrictAliasingCast(&g_allZero))->buf2)); \ + break; \ + case 3: \ + *(SecStrBuf3 *)dest = *(const SecStrBuf3 *)(&((SecStrictAliasingCast(&g_allZero))->buf3)); \ + break; \ + case 4: \ + *(SecStrBuf4 *)dest = *(const SecStrBuf4 *)(&((SecStrictAliasingCast(&g_allZero))->buf4)); \ + break; \ + case 5: \ + *(SecStrBuf5 *)dest = *(const SecStrBuf5 *)(&((SecStrictAliasingCast(&g_allZero))->buf5)); \ + break; \ + case 6: \ + *(SecStrBuf6 *)dest = *(const SecStrBuf6 *)(&((SecStrictAliasingCast(&g_allZero))->buf6)); \ + break; \ + case 7: \ + *(SecStrBuf7 *)dest = *(const SecStrBuf7 *)(&((SecStrictAliasingCast(&g_allZero))->buf7)); \ + break; \ + case 8: \ + *(SecStrBuf8 *)dest = *(const SecStrBuf8 *)(&((SecStrictAliasingCast(&g_allZero))->buf8)); \ + break; \ + case 9: \ + *(SecStrBuf9 *)dest = *(const SecStrBuf9 *)(&((SecStrictAliasingCast(&g_allZero))->buf9)); \ + break; \ + case 10: \ + *(SecStrBuf10 *)dest = *(const SecStrBuf10 *)(&((SecStrictAliasingCast(&g_allZero))->buf10)); \ + break; \ + case 11: \ + *(SecStrBuf11 *)dest = *(const SecStrBuf11 *)(&((SecStrictAliasingCast(&g_allZero))->buf11)); \ + break; \ + case 12: \ + *(SecStrBuf12 *)dest = *(const SecStrBuf12 *)(&((SecStrictAliasingCast(&g_allZero))->buf12)); \ + break; \ + case 13: \ + *(SecStrBuf13 *)dest = *(const SecStrBuf13 *)(&((SecStrictAliasingCast(&g_allZero))->buf13)); \ + break; \ + case 14: \ + *(SecStrBuf14 *)dest = *(const SecStrBuf14 *)(&((SecStrictAliasingCast(&g_allZero))->buf14)); \ + break; \ + case 15: \ + *(SecStrBuf15 *)dest = *(const SecStrBuf15 *)(&((SecStrictAliasingCast(&g_allZero))->buf15)); \ + break; \ + case 16: \ + *(SecStrBuf16 *)dest = *(const SecStrBuf16 *)(&((SecStrictAliasingCast(&g_allZero))->buf16)); \ + break; \ + case 17: \ + *(SecStrBuf17 *)dest = *(const SecStrBuf17 *)(&((SecStrictAliasingCast(&g_allZero))->buf17)); \ + break; \ + case 18: \ + *(SecStrBuf18 *)dest = *(const SecStrBuf18 *)(&((SecStrictAliasingCast(&g_allZero))->buf18)); \ + break; \ + case 19: \ + *(SecStrBuf19 *)dest = *(const SecStrBuf19 *)(&((SecStrictAliasingCast(&g_allZero))->buf19)); \ + break; \ + case 20: \ + *(SecStrBuf20 *)dest = *(const SecStrBuf20 *)(&((SecStrictAliasingCast(&g_allZero))->buf20)); \ + break; \ + case 21: \ + *(SecStrBuf21 *)dest = *(const SecStrBuf21 *)(&((SecStrictAliasingCast(&g_allZero))->buf21)); \ + break; \ + case 22: \ + *(SecStrBuf22 *)dest = *(const SecStrBuf22 *)(&((SecStrictAliasingCast(&g_allZero))->buf22)); \ + break; \ + case 23: \ + *(SecStrBuf23 *)dest = *(const SecStrBuf23 *)(&((SecStrictAliasingCast(&g_allZero))->buf23)); \ + break; \ + case 24: \ + *(SecStrBuf24 *)dest = *(const SecStrBuf24 *)(&((SecStrictAliasingCast(&g_allZero))->buf24)); \ + break; \ + case 25: \ + *(SecStrBuf25 *)dest = *(const SecStrBuf25 *)(&((SecStrictAliasingCast(&g_allZero))->buf25)); \ + break; \ + case 26: \ + *(SecStrBuf26 *)dest = *(const SecStrBuf26 *)(&((SecStrictAliasingCast(&g_allZero))->buf26)); \ + break; \ + case 27: \ + *(SecStrBuf27 *)dest = *(const SecStrBuf27 *)(&((SecStrictAliasingCast(&g_allZero))->buf27)); \ + break; \ + case 28: \ + *(SecStrBuf28 *)dest = *(const SecStrBuf28 *)(&((SecStrictAliasingCast(&g_allZero))->buf28)); \ + break; \ + case 29: \ + *(SecStrBuf29 *)dest = *(const SecStrBuf29 *)(&((SecStrictAliasingCast(&g_allZero))->buf29)); \ + break; \ + case 30: \ + *(SecStrBuf30 *)dest = *(const SecStrBuf30 *)(&((SecStrictAliasingCast(&g_allZero))->buf30)); \ + break; \ + case 31: \ + *(SecStrBuf31 *)dest = *(const SecStrBuf31 *)(&((SecStrictAliasingCast(&g_allZero))->buf31)); \ + break; \ + case 32: \ + *(SecStrBuf32 *)dest = *(const SecStrBuf32 *)(&((SecStrictAliasingCast(&g_allZero))->buf32)); \ + break; \ + default: \ + break; \ + } \ + break; \ + case 0xFF: \ + switch (count) { \ + case 1: \ + *(SecStrBuf1 *)dest = *(const SecStrBuf1 *)(&((SecStrictAliasingCast(&g_allFF))->buf1)); \ + break; \ + case 2: \ + *(SecStrBuf2 *)dest = *(const SecStrBuf2 *)(&((SecStrictAliasingCast(&g_allFF))->buf2)); \ + break; \ + case 3: \ + *(SecStrBuf3 *)dest = *(const SecStrBuf3 *)(&((SecStrictAliasingCast(&g_allFF))->buf3)); \ + break; \ + case 4: \ + *(SecStrBuf4 *)dest = *(const SecStrBuf4 *)(&((SecStrictAliasingCast(&g_allFF))->buf4)); \ + break; \ + case 5: \ + *(SecStrBuf5 *)dest = *(const SecStrBuf5 *)(&((SecStrictAliasingCast(&g_allFF))->buf5)); \ + break; \ + case 6: \ + *(SecStrBuf6 *)dest = *(const SecStrBuf6 *)(&((SecStrictAliasingCast(&g_allFF))->buf6)); \ + break; \ + case 7: \ + *(SecStrBuf7 *)dest = *(const SecStrBuf7 *)(&((SecStrictAliasingCast(&g_allFF))->buf7)); \ + break; \ + case 8: \ + *(SecStrBuf8 *)dest = *(const SecStrBuf8 *)(&((SecStrictAliasingCast(&g_allFF))->buf8)); \ + break; \ + case 9: \ + *(SecStrBuf9 *)dest = *(const SecStrBuf9 *)(&((SecStrictAliasingCast(&g_allFF))->buf9)); \ + break; \ + case 10: \ + *(SecStrBuf10 *)dest = *(const SecStrBuf10 *)(&((SecStrictAliasingCast(&g_allFF))->buf10)); \ + break; \ + case 11: \ + *(SecStrBuf11 *)dest = *(const SecStrBuf11 *)(&((SecStrictAliasingCast(&g_allFF))->buf11)); \ + break; \ + case 12: \ + *(SecStrBuf12 *)dest = *(const SecStrBuf12 *)(&((SecStrictAliasingCast(&g_allFF))->buf12)); \ + break; \ + case 13: \ + *(SecStrBuf13 *)dest = *(const SecStrBuf13 *)(&((SecStrictAliasingCast(&g_allFF))->buf13)); \ + break; \ + case 14: \ + *(SecStrBuf14 *)dest = *(const SecStrBuf14 *)(&((SecStrictAliasingCast(&g_allFF))->buf14)); \ + break; \ + case 15: \ + *(SecStrBuf15 *)dest = *(const SecStrBuf15 *)(&((SecStrictAliasingCast(&g_allFF))->buf15)); \ + break; \ + case 16: \ + *(SecStrBuf16 *)dest = *(const SecStrBuf16 *)(&((SecStrictAliasingCast(&g_allFF))->buf16)); \ + break; \ + case 17: \ + *(SecStrBuf17 *)dest = *(const SecStrBuf17 *)(&((SecStrictAliasingCast(&g_allFF))->buf17)); \ + break; \ + case 18: \ + *(SecStrBuf18 *)dest = *(const SecStrBuf18 *)(&((SecStrictAliasingCast(&g_allFF))->buf18)); \ + break; \ + case 19: \ + *(SecStrBuf19 *)dest = *(const SecStrBuf19 *)(&((SecStrictAliasingCast(&g_allFF))->buf19)); \ + break; \ + case 20: \ + *(SecStrBuf20 *)dest = *(const SecStrBuf20 *)(&((SecStrictAliasingCast(&g_allFF))->buf20)); \ + break; \ + case 21: \ + *(SecStrBuf21 *)dest = *(const SecStrBuf21 *)(&((SecStrictAliasingCast(&g_allFF))->buf21)); \ + break; \ + case 22: \ + *(SecStrBuf22 *)dest = *(const SecStrBuf22 *)(&((SecStrictAliasingCast(&g_allFF))->buf22)); \ + break; \ + case 23: \ + *(SecStrBuf23 *)dest = *(const SecStrBuf23 *)(&((SecStrictAliasingCast(&g_allFF))->buf23)); \ + break; \ + case 24: \ + *(SecStrBuf24 *)dest = *(const SecStrBuf24 *)(&((SecStrictAliasingCast(&g_allFF))->buf24)); \ + break; \ + case 25: \ + *(SecStrBuf25 *)dest = *(const SecStrBuf25 *)(&((SecStrictAliasingCast(&g_allFF))->buf25)); \ + break; \ + case 26: \ + *(SecStrBuf26 *)dest = *(const SecStrBuf26 *)(&((SecStrictAliasingCast(&g_allFF))->buf26)); \ + break; \ + case 27: \ + *(SecStrBuf27 *)dest = *(const SecStrBuf27 *)(&((SecStrictAliasingCast(&g_allFF))->buf27)); \ + break; \ + case 28: \ + *(SecStrBuf28 *)dest = *(const SecStrBuf28 *)(&((SecStrictAliasingCast(&g_allFF))->buf28)); \ + break; \ + case 29: \ + *(SecStrBuf29 *)dest = *(const SecStrBuf29 *)(&((SecStrictAliasingCast(&g_allFF))->buf29)); \ + break; \ + case 30: \ + *(SecStrBuf30 *)dest = *(const SecStrBuf30 *)(&((SecStrictAliasingCast(&g_allFF))->buf30)); \ + break; \ + case 31: \ + *(SecStrBuf31 *)dest = *(const SecStrBuf31 *)(&((SecStrictAliasingCast(&g_allFF))->buf31)); \ + break; \ + case 32: \ + *(SecStrBuf32 *)dest = *(const SecStrBuf32 *)(&((SecStrictAliasingCast(&g_allFF))->buf32)); \ + break; \ + default: \ + break; \ + } \ + break; \ + default: \ + SECUREC_UNALIGNED_SET; \ + } /* END switch */ \ +} SECUREC_WHILE_ZERO +#endif + +/* + * Handling errors + */ +static errno_t SecMemsetError(void *dest, size_t destMax, int c, size_t count) +{ + if (destMax == 0 || destMax > SECUREC_MEM_MAX_LEN) { + SECUREC_ERROR_INVALID_RANGE("memset_s"); + return ERANGE; + } + if (dest == NULL) { + SECUREC_ERROR_INVALID_PARAMTER("memset_s"); + return EINVAL; + } + if (count > destMax) { + (void)memset(dest, c, destMax); /* set entire buffer to value c */ + SECUREC_ERROR_INVALID_RANGE("memset_s"); + return ERANGE_AND_RESET; + } + return EOK; +} + +#if SECUREC_WITH_PERFORMANCE_ADDONS || SECUREC_MEMSET_WITH_PERFORMANCE +/* + * Performance optimization + */ +static void SecDoMemsetOpt(void *dest, int c, size_t count) +{ + if (count > SECUREC_MEMSET_THRESHOLD_SIZE) { + SecDoMemset(dest, c, count); + } else { + if (SECUREC_ADDR_ALIGNED_8(dest)) { + /* use struct assignment */ + SECUREC_ALIGNED_SET_OPT_ZERO_FF; + } else { + SECUREC_UNALIGNED_SET; + } + } + return; +} +#endif + +/* + * + * The memset_s function copies the value of c (converted to an unsigned char) + * into each of the first count characters of the object pointed to by dest. + * + * + * dest Pointer to destination. + * destMax The size of the buffer. + * c Character to set. + * count Number of characters. + * + * + * dest buffer is uptdated. + * + * + * EOK Success + * EINVAL dest == NULL and destMax != 0 and destMax <= SECUREC_MEM_MAX_LEN + * ERANGE destMax is 0 or destMax > SECUREC_MEM_MAX_LEN + * ERANGE_AND_RESET count > destMax and destMax != 0 and destMax <= SECUREC_MEM_MAX_LEN and dest != NULL + * + * if return ERANGE_AND_RESET then fill dest to c ,fill length is destMax + */ +errno_t memset_s(void *dest, size_t destMax, int c, size_t count) +{ + if (SECUREC_MEMSET_PARAM_OK(dest, destMax, count)) { +#if SECUREC_MEMSET_WITH_PERFORMANCE + SecDoMemsetOpt(dest, c, count); +#else + SecDoMemset(dest, c, count); +#endif + return EOK; + } else { + /* meet some runtime violation, return error code */ + return SecMemsetError(dest, destMax, c, count); + } +} + +#if SECUREC_IN_KERNEL +EXPORT_SYMBOL(memset_s); +#endif + +#if SECUREC_WITH_PERFORMANCE_ADDONS +/* + * Performance optimization + */ +errno_t memset_sOptAsm(void *dest, size_t destMax, int c, size_t count) +{ + if (SECUREC_MEMSET_PARAM_OK(dest, destMax, count)) { + SecDoMemsetOpt(dest, c, count); + return EOK; + } + /* meet some runtime violation, return error code */ + return SecMemsetError(dest, destMax, c, count); +} + +/* + * Performance optimization + */ +errno_t memset_sOptTc(void *dest, size_t destMax, int c, size_t count) +{ + if (SECUREC_LIKELY(count <= destMax && dest != NULL)) { + SecDoMemsetOpt(dest, c, count); + return EOK; + } + /* meet some runtime violation, return error code */ + return SecMemsetError(dest, destMax, c, count); +} +#endif + diff --git a/third_party/securec/src/output.inl b/third_party/securec/src/output.inl new file mode 100644 index 00000000..d4e136c5 --- /dev/null +++ b/third_party/securec/src/output.inl @@ -0,0 +1,1401 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef OUTPUT_INL_2B263E9C_43D8_44BB_B17A_6D2033DECEE5 +#define OUTPUT_INL_2B263E9C_43D8_44BB_B17A_6D2033DECEE5 + +#define SECUREC_NULL_STRING_SIZE 8 +#define SECUREC_STATE_TABLE_SIZE 337 +#define SECUREC_OFFSET_BITS_WORD 16 +#define SECUREC_OFFSET_BITS_DWORD 32 + +#define SECUREC_OFFSET_DIV_OCTAL 3 +#define SECUREC_OFFSET_DIV_HEX 4 +#define SECUREC_RADIX_OCTAL 8 +#define SECUREC_RADIX_DECIMAL 10 +#define SECUREC_RADIX_HEX 16 +/* Use two displacements to eliminate compilation warnings */ +#define SECUREC_SHR_DWORD(x) (((x) >> 16) >> 16) +#define SECUREC_PREFIX_LEN 2 +/* size include '+' and '\0' */ +#define SECUREC_FLOAT_BUF_EXT 2 + + +#ifdef SECUREC_STACK_SIZE_LESS_THAN_1K +#define SECUREC_FMT_STR_LEN 8 +#else +#define SECUREC_FMT_STR_LEN 16 +#endif + +typedef struct { + unsigned int flags; + int fldWidth; + int precision; + int bufferIsWide; /* flag for buffer contains wide chars ;0 is not wide char */ + int dynWidth; /* %* 1 width from variable parameter ;0 not */ + int dynPrecision; /* %.* 1 precision from variable parameter ;0 not */ +} SecFormatAttr; + +typedef union { + char *str; /* not a null terminated string */ +#if SECUREC_HAVE_WCHART + wchar_t *wStr; +#endif +} SecFormatBuf; + +typedef union { + char str[SECUREC_BUFFER_SIZE + 1]; +#ifdef SECUREC_FOR_WCHAR + wchar_t wStr[SECUREC_BUFFER_SIZE + 1]; +#endif +} SecBuffer; + +#if SECUREC_ENABLE_SPRINTF_FLOAT +/* call system sprintf to format float value */ +static int SecIndirectSprintf(char *strDest, const char *format, ...) +{ + int ret; /* If initialization causes e838 */ + va_list argList; + + va_start(argList, format); + SECUREC_MASK_MSVC_CRT_WARNING + ret = vsprintf(strDest, format, argList); + SECUREC_END_MASK_MSVC_CRT_WARNING + va_end(argList); + (void)argList; /* to clear e438 last value assigned not used , the compiler will optimize this code */ + + return ret; +} + +#ifdef SECUREC_COMPATIBLE_LINUX_FORMAT +/* out put long double value to dest */ +static int SecFormatLongDboule(char *strDest,const SecFormatAttr *formatAttr, const char *fmt, long double ldValue) +{ + int fldWidth = ((formatAttr->flags & SECUREC_FLAG_LEFT) ? (-(formatAttr->fldWidth)) : formatAttr->fldWidth); + if (formatAttr->dynWidth && formatAttr->dynPrecision) { + return SecIndirectSprintf(strDest, fmt, fldWidth, formatAttr->precision, ldValue); + } else if (formatAttr->dynWidth) { + return SecIndirectSprintf(strDest, fmt, fldWidth, ldValue); + } else if (formatAttr->dynPrecision) { + return SecIndirectSprintf(strDest, fmt, formatAttr->precision, ldValue); + } + return SecIndirectSprintf(strDest, fmt, ldValue); +} +#endif + +/* out put double value to dest */ +static int SecFormatDboule(char *strDest, const SecFormatAttr *formatAttr, const char *fmt, double dValue) +{ + int fldWidth = ((formatAttr->flags & SECUREC_FLAG_LEFT) ? (-(formatAttr->fldWidth)) : formatAttr->fldWidth); + if (formatAttr->dynWidth && formatAttr->dynPrecision) { + return SecIndirectSprintf(strDest, fmt, fldWidth, formatAttr->precision, dValue); + } else if (formatAttr->dynWidth) { + return SecIndirectSprintf(strDest, fmt, fldWidth, dValue); + } else if (formatAttr->dynPrecision) { + return SecIndirectSprintf(strDest, fmt, formatAttr->precision, dValue); + } + return SecIndirectSprintf(strDest, fmt, dValue); +} +#endif + +#ifdef SECUREC_COMPATIBLE_LINUX_FORMAT +/* to clear e506 warning */ +static int SecIsSameSize(size_t sizeA, size_t sizeB) +{ + return sizeA == sizeB; +} +#endif + +#define SECUREC_SPECIAL_DWORD(val32, numBase) do { \ + --formatBuf.str; \ + *(formatBuf.str) = digits[(val32) % (numBase)]; \ +} while (((val32) /= (numBase)) != 0) + +#if defined(SECUREC_USE_SPECIAL_DIV64) || (defined(SECUREC_VXWORKS_VERSION_5_4) && !defined(SECUREC_ON_64BITS)) +/* + * Fast divide by 10 algorithm. + * Calculation divisor multiply 0xcccccccccccccccdULL, resultHi64 >> 3 as quotient + */ +static void SecU64Div10(SecUnsignedInt64 divisor, SecUnsignedInt64 *quotient, SecUnsignedInt32 *remainder) +{ + SecUnsignedInt64 mask = 0xffffffffULL; /* use 0xffffffffULL as 32 bit mask */ + SecUnsignedInt64 magicHi = 0xccccccccULL; /* fast divide 10 magic numbers high 32bit 0xccccccccULL */ + SecUnsignedInt64 magicLow = 0xcccccccdULL; /* fast divide 10 magic numbers low 32bit 0xcccccccdULL */ + SecUnsignedInt64 divisorHi = (SecUnsignedInt64)(SECUREC_SHR_DWORD(divisor)); /* hig 32 bit use */ + SecUnsignedInt64 divisorLow = (SecUnsignedInt64)(divisor & mask); /* low 32 bit mask */ + SecUnsignedInt64 factorHi = divisorHi * magicHi; + SecUnsignedInt64 factorLow1 = divisorHi * magicLow; + SecUnsignedInt64 factorLow2 = divisorLow * magicHi; + SecUnsignedInt64 factorLow3 = divisorLow * magicLow; + SecUnsignedInt64 carry = (factorLow1 & mask) + (factorLow2 & mask) + SECUREC_SHR_DWORD(factorLow3); + SecUnsignedInt64 resultHi64 = factorHi + SECUREC_SHR_DWORD(factorLow1) + \ + SECUREC_SHR_DWORD(factorLow2) + SECUREC_SHR_DWORD(carry); + + *quotient = resultHi64 >> 3; /* fast divide 10 magic numbers 3 */ + *remainder = (SecUnsignedInt32)(divisor - ((*quotient) * 10)); /* quotient mul 10 */ + return; +} +#if defined(SECUREC_VXWORKS_VERSION_5_4) && !defined(SECUREC_ON_64BITS) +/* + * Divide function for VXWORKS + */ +static int SecU64Div32(SecUnsignedInt64 divisor, SecUnsignedInt32 radix, + SecUnsignedInt64 *quotient, SecUnsignedInt32 *remainder) +{ + switch (radix) { + case SECUREC_RADIX_DECIMAL: + SecU64Div10(divisor, quotient, remainder); + break; + case SECUREC_RADIX_HEX: + *quotient = divisor >> SECUREC_OFFSET_DIV_HEX; + *remainder = divisor & 0xfULL; /* mask one hex number by 0xfULL */ + break; + case SECUREC_RADIX_OCTAL: + *quotient = divisor >> SECUREC_OFFSET_DIV_OCTAL; + *remainder = divisor & 0x7ULL; /* mask one hex number by 0x7ULL */ + break; + default: + return -1; + } + return 0; +} +#endif +#endif + +#if defined(SECUREC_USE_SPECIAL_DIV64) +/* The compiler does not provide 64 bit division problems */ +#define SECUREC_SPECIAL_QWORD_BASE10(val64) do { \ + SecUnsignedInt64 quotient = 0; \ + SecUnsignedInt32 digit = 0; \ + SecU64Div10((val64), &(quotient), &(digit)); \ + --formatBuf.str; \ + *(formatBuf.str) = digits[digit]; \ + (val64) = quotient; \ +} while ((val64) != 0) +#else +#define SECUREC_SPECIAL_QWORD_BASE10(val64) do { \ + --formatBuf.str; \ + *(formatBuf.str) = digits[(val64) % SECUREC_RADIX_DECIMAL]; \ +} while (((val64) /= SECUREC_RADIX_DECIMAL) != 0) +#endif +#define SECUREC_SPECIAL_QWORD(val64, numBase) do { \ + --formatBuf.str; \ + *(formatBuf.str) = digits[(val64) % (numBase)]; \ +} while (((val64) /= (numBase)) != 0) + + +#define SECUREC_SAFE_WRITE_STR_OPT(src, txtLen, outStream, outChars) do { \ + int ii_; \ + for (ii_ = 0; ii_ < (txtLen); ++ii_) { \ + *((SecChar *)(void *)((outStream)->cur)) = *(SecChar *)(src); \ + (outStream)->cur += sizeof(SecChar); \ + (src) = (src) + 1; \ + } \ + (outStream)->count -= (txtLen) * (int)(sizeof(SecChar)); \ + *(outChars) = *(outChars) + (txtLen); \ +} SECUREC_WHILE_ZERO + +#define SECUREC_SAFE_WRITE_STR(src, txtLen, outStream, outChars) do { \ + if ((txtLen) < 12) { /* performance optimization for mobile number length 12 */ \ + SECUREC_SAFE_WRITE_STR_OPT((src), (txtLen), (outStream), (outChars)); \ + } else { \ + SecDoMemcpy((outStream)->cur, (src), ((size_t)(unsigned int)(txtLen) * (sizeof(SecChar)))); \ + (outStream)->cur += (size_t)((size_t)(unsigned int)(txtLen) * (sizeof(SecChar))); \ + (outStream)->count -= (txtLen) * (int)(sizeof(SecChar)); \ + *(outChars) = *(outChars) + (txtLen); \ + } \ +} SECUREC_WHILE_ZERO + +#define SECUREC_SAFE_WRITE_CHAR(c, outStream, outChars) do { \ + *((SecChar *)(void *)((outStream)->cur)) = (SecChar)(c); \ + (outStream)->cur += sizeof(SecChar); \ + (outStream)->count -= (int)(sizeof(SecChar)); \ + *(outChars) = *(outChars) + 1; \ +} SECUREC_WHILE_ZERO + +#define SECUREC_SAFE_PADDING(padChar, padLen, outStream, outChars) do { \ + int ii_; \ + for (ii_ = 0; ii_ < (padLen); ++ii_) { \ + *((SecChar *)(void *)((outStream)->cur)) = (SecChar)(padChar); \ + (outStream)->cur += sizeof(SecChar); \ + } \ + (outStream)->count -= (padLen) * (int)(sizeof(SecChar)); \ + *(outChars) = *(outChars) + (padLen); \ +} SECUREC_WHILE_ZERO + +/* The count variable can be reduced to 0, and the external function complements the \0 terminator. */ +#define SECUREC_IS_REST_BUF_ENOUGH(stream, needLen) ((int)((stream)->count - \ + (int)(needLen) * (int)(sizeof(SecChar))) >= 0) + +#define SECUREC_FMT_STATE_OFFSET 256 +#ifdef SECUREC_FOR_WCHAR +#define SECUREC_FMT_TYPE(c, fmtTable) ((((unsigned int)(int)(c)) <= (unsigned int)(int)SECUREC_CHAR('~')) ? \ + ((fmtTable)[(unsigned char)(c)]) : 0) +#define SECUREC_DECODE_STATE(c, fmtTable, lastState) (SecFmtState)((((fmtTable)[(SECUREC_FMT_TYPE(c, (fmtTable))) * \ + ((unsigned char)STAT_INVALID + 1) + \ + (unsigned char)(lastState) + \ + SECUREC_FMT_STATE_OFFSET]))) +#else +#define SECUREC_DECODE_STATE(c, fmtTable, lastState) (SecFmtState)(((fmtTable)[((fmtTable)[(unsigned char)(c)]) * \ + ((unsigned char)STAT_INVALID + 1) + \ + (unsigned char)(lastState) + \ + SECUREC_FMT_STATE_OFFSET])) +#endif + +static void SecDecodeFlags(SecChar ch, SecFormatAttr *attr) +{ + switch (ch) { + case SECUREC_CHAR(' '): + attr->flags |= SECUREC_FLAG_SIGN_SPACE; + break; + case SECUREC_CHAR('+'): + attr->flags |= SECUREC_FLAG_SIGN; + break; + case SECUREC_CHAR('-'): + attr->flags |= SECUREC_FLAG_LEFT; + break; + case SECUREC_CHAR('0'): + attr->flags |= SECUREC_FLAG_LEADZERO; /* add zero th the front */ + break; + case SECUREC_CHAR('#'): + attr->flags |= SECUREC_FLAG_ALTERNATE; /* output %x with 0x */ + break; + default: + break; + } + return; +} + + +/* + * Decoded size identifier in format string to Reduce the number of lines of function code + */ +static int SecDecodeSizeI(SecFormatAttr *attr, const SecChar **format) +{ +#ifdef SECUREC_ON_64BITS + attr->flags |= SECUREC_FLAG_I64; /* %I to INT64 */ +#endif + if ((**format == SECUREC_CHAR('6')) && (*((*format) + 1) == SECUREC_CHAR('4'))) { + (*format) += 2; /* add 2 to skip I64 */ + attr->flags |= SECUREC_FLAG_I64; /* %I64 to INT64 */ + } else if ((**format == SECUREC_CHAR('3')) && (*((*format) + 1) == SECUREC_CHAR('2'))) { + (*format) += 2; /* add 2 to skip I32 */ + attr->flags &= ~SECUREC_FLAG_I64; /* %I64 to INT32 */ + } else if ((**format == SECUREC_CHAR('d')) || (**format == SECUREC_CHAR('i')) || + (**format == SECUREC_CHAR('o')) || (**format == SECUREC_CHAR('u')) || + (**format == SECUREC_CHAR('x')) || (**format == SECUREC_CHAR('X'))) { + /* do nothing */ + } else { + /* Compatibility code for "%I" just print I */ + return -1; + } + return 0; +} +/* + * Decoded size identifier in format string + */ +static int SecDecodeSize(SecChar ch, SecFormatAttr *attr, const SecChar **format) +{ + switch (ch) { +#ifdef SECUREC_COMPATIBLE_LINUX_FORMAT + case SECUREC_CHAR('j'): + attr->flags |= SECUREC_FLAG_INTMAX; + break; +#endif + case SECUREC_CHAR('q'): + /* fall-through */ /* FALLTHRU */ + case SECUREC_CHAR('L'): + attr->flags |= SECUREC_FLAG_LONGLONG | SECUREC_FLAG_LONG_DOUBLE; + break; + case SECUREC_CHAR('l'): + if (**format == SECUREC_CHAR('l')) { + *format = *format + 1; + attr->flags |= SECUREC_FLAG_LONGLONG; /* long long */ + } else { + attr->flags |= SECUREC_FLAG_LONG; /* long int or wchar_t */ + } + break; + case SECUREC_CHAR('t'): + attr->flags |= SECUREC_FLAG_PTRDIFF; + break; +#ifdef SECUREC_COMPATIBLE_LINUX_FORMAT + case SECUREC_CHAR('z'): + /* fall-through */ /* FALLTHRU */ + case SECUREC_CHAR('Z'): + attr->flags |= SECUREC_FLAG_SIZE; + break; +#endif + case SECUREC_CHAR('I'): + if (SecDecodeSizeI(attr, format) != 0) { + /* Compatibility code for "%I" just print I */ + return -1; + } + break; + case SECUREC_CHAR('h'): + if (**format == SECUREC_CHAR('h')) { + attr->flags |= SECUREC_FLAG_CHAR; /* char */ + } else { + attr->flags |= SECUREC_FLAG_SHORT; /* short int */ + } + break; + case SECUREC_CHAR('w'): + attr->flags |= SECUREC_FLAG_WIDECHAR; /* wide char */ + break; + default: + break; + } + return 0; +} + +/* + * Decoded char type identifier + */ +static int SecDecodeTypeC(SecFormatAttr *attr, unsigned int cValue, SecFormatBuf *formatBuf, SecBuffer *buffer) +{ +#if (defined(SECUREC_COMPATIBLE_LINUX_FORMAT)) && !(defined(__hpux)) && !(defined(SECUREC_ON_SOLARIS)) + attr->flags &= ~SECUREC_FLAG_LEADZERO; +#endif + +#ifdef SECUREC_FOR_WCHAR + attr->bufferIsWide = 1; + if (attr->flags & SECUREC_FLAG_SHORT) { +#if SECUREC_HAVE_MBTOWC + /* multibyte character to wide character */ + char tmpChar[2]; /* One character string, length is 2 */ + tmpChar[0] = (char)(cValue & 0x00ff); + tmpChar[1] = '\0'; + + if (mbtowc(buffer->wStr, tmpChar, sizeof(tmpChar)) < 0) { + return -1; + } +#else + return -1; +#endif + } else { + buffer->wStr[0] = (wchar_t)cValue; + } + formatBuf->wStr = buffer->wStr; + return 1; /* only 1 wide character */ +#else /* SECUREC_FOR_WCHAR */ + attr->bufferIsWide = 0; + if (attr->flags & (SECUREC_FLAG_LONG | SECUREC_FLAG_WIDECHAR)) { +#if SECUREC_HAVE_WCTOMB + wchar_t wChar = (wchar_t)cValue; + int textLen; + /* wide character to multibyte character */ + SECUREC_MASK_MSVC_CRT_WARNING + textLen = wctomb(buffer->str, wChar); + SECUREC_END_MASK_MSVC_CRT_WARNING + if (textLen < 0) { + return -1; + } + formatBuf->str = buffer->str; + return textLen; +#else + return -1; +#endif + } else { + /* get multibyte character from argument */ + unsigned short temp; + temp = (unsigned short)cValue; + buffer->str[0] = (char)temp; + formatBuf->str = buffer->str; + return 1; /* only 1 character */ + } +#endif + +} + +/* literal string to print null ptr, define it as array rather than const text area + * is to avoid gcc warning with pointing const text with variable + */ +#if SECUREC_HAVE_WCHART +static wchar_t g_wStrNullString[SECUREC_NULL_STRING_SIZE] = { L'(', L'n', L'u', L'l', L'l', L')', L'\0', L'\0' }; +#endif +static char g_strNullString[SECUREC_NULL_STRING_SIZE] = "(null)"; + +static int SecDecodeTypeSchar(const SecFormatAttr *attr, SecFormatBuf *formatBuf) +{ + int finalPrecision = (attr->precision == -1) ? SECUREC_INT_MAX : attr->precision; + int textLen; + + if (formatBuf->str == NULL) { /* NULL passed, use special string */ + formatBuf->str = g_strNullString; + } + if (finalPrecision == SECUREC_INT_MAX) { + /* precision NOT assigned */ + /* The strlen performance is high when the string length is greater than 32 */ + textLen = (int)strlen(formatBuf->str); + } else { + /* precision assigned */ + size_t tmpLen; + SECUREC_CALC_STR_LEN(formatBuf->str, (size_t)(unsigned int)finalPrecision, &tmpLen); + textLen = (int)tmpLen; + } + return textLen; +} + +#if SECUREC_HAVE_WCHART +static int SecDecodeTypeSwchar(SecFormatAttr *attr, SecFormatBuf *formatBuf) +{ + int finalPrecision = (attr->precision == -1) ? SECUREC_INT_MAX : attr->precision; + int textLen; + + attr->bufferIsWide = 1; + if (formatBuf->wStr == NULL) { /* NULL passed, use special string */ + formatBuf->wStr = g_wStrNullString; + } + /* textLen in wchar_t */ + SECUREC_CALC_WSTR_LEN(formatBuf->wStr, finalPrecision, &textLen); + + return textLen; +} +#endif + +/* + * Decoded string identifier + */ +static int SecDecodeTypeS(SecFormatAttr *attr, char *argPtr, SecFormatBuf *formatBuf) +{ + int textLen; +#if (defined(SECUREC_COMPATIBLE_LINUX_FORMAT)) && (!defined(SECUREC_ON_UNIX)) + attr->flags &= ~SECUREC_FLAG_LEADZERO; +#endif + formatBuf->str = argPtr; +#ifdef SECUREC_FOR_WCHAR +#if defined(SECUREC_COMPATIBLE_LINUX_FORMAT) + if (!(attr->flags & SECUREC_FLAG_LONG)) { + attr->flags |= SECUREC_FLAG_SHORT; + } +#endif + if (attr->flags & SECUREC_FLAG_SHORT) { + /* textLen now contains length in multibyte chars */ + textLen = SecDecodeTypeSchar(attr, formatBuf); + } else { + /* textLen now contains length in wide chars */ + textLen = SecDecodeTypeSwchar(attr, formatBuf); + } +#else /* SECUREC_FOR_WCHAR */ + if (attr->flags & (SECUREC_FLAG_LONG | SECUREC_FLAG_WIDECHAR)) { + /* textLen now contains length in wide chars */ +#if SECUREC_HAVE_WCHART + textLen = SecDecodeTypeSwchar(attr, formatBuf); +#else + textLen = 0; +#endif + } else { + /* textLen now contains length in multibyte chars */ + textLen = SecDecodeTypeSchar(attr, formatBuf); + } +#endif /* SECUREC_FOR_WCHAR */ + return textLen; +} + +/* + * Write one character to dest buffer + */ +static void SecOutputOneChar(SecChar ch, SecPrintfStream *stream, int *counter) +{ + /* normal state, write character */ + if (SECUREC_IS_REST_BUF_ENOUGH(stream, 1)) { /* only one char */ + SECUREC_SAFE_WRITE_CHAR(ch, stream, counter); /* char * cast to wchar * */ + } else { +#ifdef SECUREC_FOR_WCHAR + SecWriteCharW(ch, stream, counter); +#else + /* optimize function call to code */ + *counter = -1; + stream->count = -1; +#endif + } +} + +/* + * Check precison in format + */ +static int SecDecodePrecision(SecChar ch, SecFormatAttr *formatAttr) +{ + if (formatAttr->dynPrecision == 0) { + /* add digit to current precision */ + if (SECUREC_MUL_TEN_ADD_BEYOND_MAX(formatAttr->precision)) { + return -1; + } + formatAttr->precision = (int)SECUREC_MUL_TEN((unsigned int)formatAttr->precision) + + (unsigned char)(ch - SECUREC_CHAR('0')); + } else { + if (formatAttr->precision < 0) { + formatAttr->precision = -1; + } + if (formatAttr->precision > SECUREC_MAX_WIDTH_LEN) { + return -1; + } + } + return 0; +} + + +/* + * Check width in format + */ +static int SecDecodeWidth(SecChar ch, SecFormatAttr *formatAttr, SecFmtState lastState) +{ + if (formatAttr->dynWidth == 0) { + if (lastState != STAT_WIDTH) { + formatAttr->fldWidth = 0; + } + if (SECUREC_MUL_TEN_ADD_BEYOND_MAX(formatAttr->fldWidth)) { + return -1; + } + formatAttr->fldWidth = (int)SECUREC_MUL_TEN((unsigned int)formatAttr->fldWidth) + + (unsigned char)(ch - SECUREC_CHAR('0')); + } else { + if (formatAttr->fldWidth < 0) { + formatAttr->flags |= SECUREC_FLAG_LEFT; + formatAttr->fldWidth = (-formatAttr->fldWidth); + if (formatAttr->fldWidth > SECUREC_MAX_WIDTH_LEN) { + return -1; + } + } + } + return 0; +} +#ifdef SECUREC_FOR_WCHAR +/* + * Formatting output core functions for wchar version.Called by a function such as vswprintf_s + * argList must not be declare as const + */ +static int SecOutputSW(SecPrintfStream *stream, const wchar_t *cFormat, va_list argList) +#else +/* + * Formatting output core functions for char version.Called by a function such as vsnprintf_s + */ +static int SecOutputS(SecPrintfStream *stream, const char *cFormat, va_list argList) +#endif +{ + const SecChar *format = cFormat; +#if SECUREC_ENABLE_SPRINTF_FLOAT + char *floatBuf = NULL; +#endif + SecFormatBuf formatBuf; + static const char *itoaUpperDigits = "0123456789ABCDEFX"; + static const char *itoaLowerDigits = "0123456789abcdefx"; + const char *digits = itoaUpperDigits; + unsigned int radix = SECUREC_RADIX_DECIMAL; + int charsOut; /* characters written */ + int prefixLen = 0; /* Must be initialized or compiler alerts */ + int padding = 0; + int textLen; /* length of the text */ + int noOutput = 0; /* Must be initialized or compiler alerts */ + SecFmtState state; + SecFmtState lastState; + SecChar prefix[SECUREC_PREFIX_LEN] = { 0 }; + SecChar ch; /* currently read character */ + static const unsigned char stateTable[SECUREC_STATE_TABLE_SIZE] = { + /* type 0: nospecial meanin; + * 1: '%'; + * 2: '.' + * 3: '*' + * 4: '0' + * 5: '1' ... '9' + * 6: ' ', '+', '-', '#' + * 7: 'h', 'l', 'L', 'F', 'w' , 'N','z','q','t','j' + * 8: 'd','o','u','i','x','X','e','f','g' + */ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x06, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x03, 0x06, 0x00, 0x06, 0x02, 0x00, + 0x04, 0x05, 0x05, 0x05, 0x05, 0x05, 0x05, 0x05, 0x05, 0x05, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x08, 0x00, 0x08, 0x08, 0x08, 0x00, 0x07, 0x00, 0x00, 0x07, 0x00, 0x07, 0x00, + 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00, 0x07, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x08, 0x08, 0x08, 0x08, 0x08, 0x07, 0x08, 0x07, 0x00, 0x07, 0x00, 0x00, 0x08, + 0x08, 0x07, 0x00, 0x08, 0x07, 0x08, 0x00, 0x07, 0x08, 0x00, 0x07, 0x00, 0x00, 0x00, 0x00, 0x00, + /* fill zero for normal char 128 byte for 0x80 - 0xff */ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + /* state 0: normal + * 1: percent + * 2: flag + * 3: width + * 4: dot + * 5: precis + * 6: size + * 7: type + * 8: invalid + */ + 0x00, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x00, 0x00, 0x01, 0x00, 0x08, 0x08, 0x08, 0x08, 0x08, + 0x01, 0x00, 0x00, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x00, 0x00, 0x00, 0x03, 0x03, 0x08, 0x05, + 0x08, 0x08, 0x00, 0x00, 0x00, 0x02, 0x02, 0x03, 0x05, 0x05, 0x08, 0x00, 0x00, 0x00, 0x03, 0x03, + 0x03, 0x05, 0x05, 0x08, 0x00, 0x00, 0x00, 0x02, 0x02, 0x02, 0x08, 0x08, 0x08, 0x00, 0x00, 0x00, + 0x06, 0x06, 0x06, 0x06, 0x06, 0x06, 0x00, 0x00, 0x00, 0x07, 0x07, 0x07, 0x07, 0x07, 0x07, 0x00, + 0x00 + }; + + SecFormatAttr formatAttr; + SecBuffer buffer; + formatAttr.flags = 0; + formatAttr.bufferIsWide = 0; /* flag for buffer contains wide chars */ + formatAttr.fldWidth = 0; + formatAttr.precision = 0; + formatAttr.dynWidth = 0; + formatAttr.dynPrecision = 0; + charsOut = 0; + textLen = 0; + state = STAT_NORMAL; /* starting state */ + formatBuf.str = NULL; + + /* loop each format character */ + /* remove format != NULL */ + while ((ch = *format) != SECUREC_CHAR('\0') && charsOut >= 0) { + ++format; + lastState = state; + state = SECUREC_DECODE_STATE(ch, stateTable, lastState); + switch (state) { + case STAT_NORMAL: + SecOutputOneChar(ch, stream, &charsOut); + continue; + case STAT_PERCENT: + /* set default values */ + prefixLen = 0; + noOutput = 0; + formatAttr.flags = 0; + formatAttr.fldWidth = 0; + formatAttr.precision = -1; + formatAttr.bufferIsWide = 0; + formatAttr.dynWidth = 0; + formatAttr.dynPrecision = 0; + break; + case STAT_FLAG: + /* set flag based on which flag character */ + SecDecodeFlags(ch, &formatAttr); + break; + case STAT_WIDTH: + /* update width value */ + if (ch == SECUREC_CHAR('*')) { + /* get width */ + formatAttr.fldWidth = (int)va_arg(argList, int); + formatAttr.dynWidth = 1; + } else { + formatAttr.dynWidth = 0; + } + if (SecDecodeWidth(ch, &formatAttr, lastState) != 0) { + return -1; + } + break; + case STAT_DOT: + formatAttr.precision = 0; + break; + case STAT_PRECIS: + /* update precison value */ + if (ch == SECUREC_CHAR('*')) { + /* get precision from arg list */ + formatAttr.precision = (int)va_arg(argList, int); + formatAttr.dynPrecision = 1; + } else { + formatAttr.dynPrecision = 0; + } + if (SecDecodePrecision(ch, &formatAttr) != 0) { + return -1; + } + break; + case STAT_SIZE: + /* read a size specifier, set the formatAttr.flags based on it */ + if (SecDecodeSize(ch, &formatAttr, &format) != 0) { + /* Compatibility code for "%I" just print I */ + SecOutputOneChar(ch, stream, &charsOut); + state = STAT_NORMAL; + continue; + } + break; + case STAT_TYPE: + switch (ch) { + case SECUREC_CHAR('C'): + /* wide char */ + if (!(formatAttr.flags & (SECUREC_FLAG_SHORT | SECUREC_FLAG_LONG | SECUREC_FLAG_WIDECHAR))) { +#ifdef SECUREC_FOR_WCHAR + formatAttr.flags |= SECUREC_FLAG_SHORT; +#else + formatAttr.flags |= SECUREC_FLAG_WIDECHAR; +#endif + } + /* fall-through */ + /* FALLTHRU */ + case SECUREC_CHAR('c'): + do { + unsigned int cValue = (unsigned int)va_arg(argList, int); + textLen = SecDecodeTypeC(&formatAttr, cValue, &formatBuf, &buffer); + if (textLen < 0) { + noOutput = 1; + } + } SECUREC_WHILE_ZERO; + break; + case SECUREC_CHAR('S'): /* wide char string */ + if (!(formatAttr.flags & (SECUREC_FLAG_SHORT | SECUREC_FLAG_LONG | SECUREC_FLAG_WIDECHAR))) { +#ifndef SECUREC_FOR_WCHAR + formatAttr.flags |= SECUREC_FLAG_WIDECHAR; +#else + formatAttr.flags |= SECUREC_FLAG_SHORT; +#endif + } + /* fall-through */ + /* FALLTHRU */ + case SECUREC_CHAR('s'): + do { + char *argPtr = (char *)va_arg(argList, char *); + textLen = SecDecodeTypeS(&formatAttr, argPtr, &formatBuf); + } SECUREC_WHILE_ZERO; + break; + case SECUREC_CHAR('n'): + /* higher risk disable it */ + return -1; + case SECUREC_CHAR('E'): /* fall-through */ /* FALLTHRU */ + case SECUREC_CHAR('F'): /* fall-through */ /* FALLTHRU */ + case SECUREC_CHAR('G'): /* fall-through */ /* FALLTHRU */ + case SECUREC_CHAR('A'): /* fall-through */ /* FALLTHRU */ + /* convert format char to lower , use Explicit conversion to clean up compilation warning */ + ch = (SecChar)(ch + ((SecChar)(SECUREC_CHAR('a')) - (SECUREC_CHAR('A')))); + /* fall-through */ + /* FALLTHRU */ + case SECUREC_CHAR('e'): /* fall-through */ /* FALLTHRU */ + case SECUREC_CHAR('f'): /* fall-through */ /* FALLTHRU */ + case SECUREC_CHAR('g'): /* fall-through */ /* FALLTHRU */ + case SECUREC_CHAR('a'): +#if SECUREC_ENABLE_SPRINTF_FLOAT + do { + int bufferSize = 0; /* size of formatBuf.str */ + /* floating point conversion */ + formatBuf.str = buffer.str; /* output buffer for float string with default size */ + + /* compute the precision value */ + if (formatAttr.precision < 0) { + formatAttr.precision = SECUREC_FLOAT_DEFAULT_PRECISION; + } else if (formatAttr.precision == 0 && ch == SECUREC_CHAR('g')) { + formatAttr.precision = 1; + } + + /* calc buffer size to store double value + * The maximum length of SECUREC_MAX_WIDTH_LEN is enough + */ + if (formatAttr.flags & SECUREC_FLAG_LONG_DOUBLE) { + if (formatAttr.precision > (SECUREC_MAX_WIDTH_LEN - SECUREC_FLOAT_BUFSIZE_LB)) { + noOutput = 1; + break; + } + /* Long double needs to meet the basic print length */ + bufferSize = SECUREC_FLOAT_BUFSIZE_LB + formatAttr.precision + SECUREC_FLOAT_BUF_EXT; + } else { + if (formatAttr.precision > (SECUREC_MAX_WIDTH_LEN - SECUREC_FLOAT_BUFSIZE)) { + noOutput = 1; + break; + } + /* Double needs to meet the basic print length */ + bufferSize = SECUREC_FLOAT_BUFSIZE + formatAttr.precision + SECUREC_FLOAT_BUF_EXT; + } + if (formatAttr.fldWidth > bufferSize) { + bufferSize = formatAttr.fldWidth + SECUREC_FLOAT_BUF_EXT; + } + + if (bufferSize > SECUREC_BUFFER_SIZE) { + /* the current vlaue of SECUREC_BUFFER_SIZE could NOT store the + * formatted float string + */ + floatBuf = (char *)SECUREC_MALLOC(((size_t)(unsigned int)bufferSize)); + if (floatBuf != NULL) { + formatBuf.str = floatBuf; + } else { + noOutput = 1; + break; + } + } + + do { + /* add following code to call system sprintf API for float number */ + const SecChar *pFloatFmt = format - 2; /* sub 2 to the position before 'f' or 'g' */ + int k; + int fFmtStrLen; + char fFmtBuf[SECUREC_FMT_STR_LEN]; + char *fFmtStr = fFmtBuf; + char *fFmtHeap = NULL; /* to clear warning */ + + while (SECUREC_CHAR('%') != *pFloatFmt) { /* must meet '%' */ + --pFloatFmt; + } + fFmtStrLen = (int)((format - pFloatFmt) + 1); /* with ending terminator */ + if (fFmtStrLen > SECUREC_FMT_STR_LEN) { + /* if SECUREC_FMT_STR_LEN is NOT enough, alloc a new buffer */ + fFmtHeap = (char *)SECUREC_MALLOC((size_t)((unsigned int)fFmtStrLen)); + if (fFmtHeap == NULL) { + noOutput = 1; + break; + } else { + for (k = 0; k < fFmtStrLen - 1; ++k) { + /* convert wchar to char */ + fFmtHeap[k] = (char)(pFloatFmt[k]); /* copy the format string */ + } + fFmtHeap[k] = '\0'; + + fFmtStr = fFmtHeap; + } + } else { + /* purpose of the repeat code is to solve the tool alarm Redundant_Null_Check */ + for (k = 0; k < fFmtStrLen - 1; ++k) { + /* convert wchar to char */ + fFmtBuf[k] = (char)(pFloatFmt[k]); /* copy the format string */ + } + fFmtBuf[k] = '\0'; + } + + if (formatAttr.flags & SECUREC_FLAG_LONG_DOUBLE) { +#ifdef SECUREC_COMPATIBLE_LINUX_FORMAT + long double tmp = (long double)va_arg(argList, long double); + textLen = SecFormatLongDboule(formatBuf.str, &formatAttr, fFmtStr, tmp); +#else + double tmp = (double)va_arg(argList, double); + textLen = SecFormatDboule(formatBuf.str, &formatAttr, fFmtStr, tmp); +#endif + } else { + double tmp = (double)va_arg(argList, double); + textLen = SecFormatDboule(formatBuf.str, &formatAttr, fFmtStr, tmp); + } + + if (fFmtHeap != NULL) { + /* if buffer is alloced on heap, free it */ + SECUREC_FREE(fFmtHeap); + fFmtHeap = NULL; + /* to clear e438 last value assigned not used , the compiler will + * optimize this code + */ + (void)fFmtHeap; + } + if (textLen < 0 || textLen >= bufferSize) { + /* bufferSize is large enough, just validation the return value */ + noOutput = 1; + break; + } + + /* no padding ,this variable to calculate amount of padding */ + formatAttr.fldWidth = textLen; + prefixLen = 0; /* no padding ,this variable to calculate amount of padding */ + formatAttr.flags = 0; /* clear all internal formatAttr.flags */ + break; + } SECUREC_WHILE_ZERO; + } SECUREC_WHILE_ZERO; + break; +#else + return -1; +#endif + case SECUREC_CHAR('p'): /* fall-through */ /* FALLTHRU */ + case SECUREC_CHAR('X'): /* fall-through */ /* FALLTHRU */ + case SECUREC_CHAR('x'): + /* unsigned lower hex output */ + digits = itoaLowerDigits; + radix = SECUREC_RADIX_HEX; + switch (ch) { + case SECUREC_CHAR('p'): + /* print a pointer */ +#if defined(SECUREC_COMPATIBLE_WIN_FORMAT) + formatAttr.flags &= ~SECUREC_FLAG_LEADZERO; +#else + formatAttr.flags |= SECUREC_FLAG_POINTER; +#endif +#ifdef SECUREC_ON_64BITS + formatAttr.flags |= SECUREC_FLAG_I64; /* converting an int64 */ +#else + formatAttr.flags |= SECUREC_FLAG_LONG; /* converting a long */ +#endif + +#if (defined(SECUREC_COMPATIBLE_LINUX_FORMAT) || defined(SECUREC_VXWORKS_PLATFORM)) && (!defined(SECUREC_ON_UNIX)) +#if defined(SECUREC_VXWORKS_PLATFORM) + formatAttr.precision = 1; +#else + formatAttr.precision = 0; +#endif + formatAttr.flags |= SECUREC_FLAG_ALTERNATE; /* "0x" is not default prefix in UNIX */ + break; +#else + /* not linux vxwoks */ +#if defined(_AIX) || defined(SECUREC_ON_SOLARIS) + formatAttr.precision = 1; +#else + formatAttr.precision = 2 * sizeof(void *); /* 2 precision of different systems */ +#endif +#endif + +#if defined(SECUREC_ON_UNIX) + break; +#endif + /* fall-through */ /* FALLTHRU */ + case SECUREC_CHAR('X'): /* fall-through */ /* FALLTHRU */ + /* unsigned upper hex output */ + digits = itoaUpperDigits; + break; + default: + break; + } + + if (formatAttr.flags & SECUREC_FLAG_ALTERNATE) { + /* alternate form means '0x' prefix */ + prefix[0] = SECUREC_CHAR('0'); + prefix[1] = (SecChar)(digits[16]); /* 16 for 'x' or 'X' */ + +#if (defined(SECUREC_COMPATIBLE_LINUX_FORMAT) || defined(SECUREC_VXWORKS_PLATFORM)) + if (ch == 'p') { + prefix[1] = SECUREC_CHAR('x'); + } +#endif +#if defined(_AIX) || defined(SECUREC_ON_SOLARIS) + if (ch == 'p') { + prefixLen = 0; + } else { + prefixLen = SECUREC_PREFIX_LEN; + } +#else + prefixLen = SECUREC_PREFIX_LEN; +#endif + + } + /* fall-through */ /* FALLTHRU */ + case SECUREC_CHAR('i'): /* fall-through */ /* FALLTHRU */ + case SECUREC_CHAR('d'): /* fall-through */ /* FALLTHRU */ + case SECUREC_CHAR('u'): /* fall-through */ /* FALLTHRU */ + case SECUREC_CHAR('o'): /* fall-through */ /* FALLTHRU */ + switch (ch) { + case SECUREC_CHAR('i'): /* fall-through */ /* FALLTHRU */ + case SECUREC_CHAR('d'): /* fall-through */ /* FALLTHRU */ + /* signed decimal output */ + formatAttr.flags |= SECUREC_FLAG_SIGNED; + /* fall-through */ /* FALLTHRU */ + case SECUREC_CHAR('u'): + radix = SECUREC_RADIX_DECIMAL; + break; + case SECUREC_CHAR('o'): + /* unsigned octal output */ + radix = SECUREC_RADIX_OCTAL; + if (formatAttr.flags & SECUREC_FLAG_ALTERNATE) { + /* alternate form means force a leading 0 */ + formatAttr.flags |= SECUREC_FLAG_FORCE_OCTAL; + } + break; + default: + break; + } + + do { + + SecUnsignedInt64 number = 0; /* number to convert */ + SecInt64 l; /* temp long value */ + + /* read argument into variable l */ + if (formatAttr.flags & SECUREC_FLAG_I64) { + l = (SecInt64)va_arg(argList, SecInt64); + } else if (formatAttr.flags & SECUREC_FLAG_LONGLONG) { + l = (SecInt64)va_arg(argList, SecInt64); + } else +#ifdef SECUREC_ON_64BITS + if (formatAttr.flags & SECUREC_FLAG_LONG) { + l = (long)va_arg(argList, long); + } else +#endif /* SECUREC_ON_64BITS */ + if (formatAttr.flags & SECUREC_FLAG_CHAR) { + if (formatAttr.flags & SECUREC_FLAG_SIGNED) { + l = (char)va_arg(argList, int); /* sign extend */ + if (l >= 128) { /* 128 on some platform, char is always unsigned */ + SecUnsignedInt64 tmpL = (SecUnsignedInt64)l; + unsigned char tmpCh = (unsigned char)(~(tmpL)); + l = tmpCh + 1; + formatAttr.flags |= SECUREC_FLAG_NEGATIVE; + } + } else { + l = (unsigned char)va_arg(argList, int); /* zero-extend */ + } + + } else if (formatAttr.flags & SECUREC_FLAG_SHORT) { + if (formatAttr.flags & SECUREC_FLAG_SIGNED) { + l = (short)va_arg(argList, int); /* sign extend */ + } else { + l = (unsigned short)va_arg(argList, int); /* zero-extend */ + } + + } +#ifdef SECUREC_COMPATIBLE_LINUX_FORMAT + else if (formatAttr.flags & SECUREC_FLAG_PTRDIFF) { + l = (ptrdiff_t)va_arg(argList, ptrdiff_t); /* sign extend */ + } else if (formatAttr.flags & SECUREC_FLAG_SIZE) { + if (formatAttr.flags & SECUREC_FLAG_SIGNED) { + /* No suitable macros were found to handle the branch */ + if (SecIsSameSize(sizeof(size_t), sizeof(long))) { + l = va_arg(argList, long); /* sign extend */ + } else if (SecIsSameSize(sizeof(size_t), sizeof(long long))) { + l = va_arg(argList, long long); /* sign extend */ + } else { + l = va_arg(argList, int); /* sign extend */ + } + } else { + l = (SecInt64)(size_t)va_arg(argList, size_t); /* sign extend */ + } + } else if (formatAttr.flags & SECUREC_FLAG_INTMAX) { + if (formatAttr.flags & SECUREC_FLAG_SIGNED) { + l = va_arg(argList, SecInt64); /* sign extend */ + } else { + /* sign extend */ + l = (SecInt64)(SecUnsignedInt64)va_arg(argList, SecUnsignedInt64); + } + } +#endif + else { + if (formatAttr.flags & SECUREC_FLAG_SIGNED) { + l = va_arg(argList, int); /* sign extend */ + } else { + l = (unsigned int)va_arg(argList, int); /* zero-extend */ + } + + } + + /* check for negative; copy into number */ + if ((formatAttr.flags & SECUREC_FLAG_SIGNED) && l < 0) { + number = (SecUnsignedInt64)(-l); + formatAttr.flags |= SECUREC_FLAG_NEGATIVE; + } else { + number = (SecUnsignedInt64)l; + } + + if (((formatAttr.flags & SECUREC_FLAG_I64) == 0) && +#ifdef SECUREC_COMPATIBLE_LINUX_FORMAT + ((formatAttr.flags & SECUREC_FLAG_INTMAX) == 0) && +#endif +#ifdef SECUREC_ON_64BITS + ((formatAttr.flags & SECUREC_FLAG_PTRDIFF) == 0) && + ((formatAttr.flags & SECUREC_FLAG_SIZE) == 0) && +#if !defined(SECUREC_COMPATIBLE_WIN_FORMAT) /* on window 64 system sizeof long is 32bit */ + ((formatAttr.flags & SECUREC_FLAG_LONG) == 0) && +#endif +#endif + ((formatAttr.flags & SECUREC_FLAG_LONGLONG) == 0)) { + + number &= 0xffffffff; /* use 0xffffffff as 32 bit mask */ + } + + /* check precision value for default */ + if (formatAttr.precision < 0) { + formatAttr.precision = 1; /* default precision */ + } else { +#if defined(SECUREC_COMPATIBLE_WIN_FORMAT) + formatAttr.flags &= ~SECUREC_FLAG_LEADZERO; +#else + if (!(formatAttr.flags & SECUREC_FLAG_POINTER)) { + formatAttr.flags &= ~SECUREC_FLAG_LEADZERO; + } +#endif + if (formatAttr.precision > SECUREC_MAX_PRECISION) { + formatAttr.precision = SECUREC_MAX_PRECISION; + } + } + + /* Check if data is 0; if so, turn off hex prefix, + * 'p' add 0x prefix, otherwise not add prefix + */ + if (number == 0) { +#if !(defined(SECUREC_VXWORKS_PLATFORM) || defined(__hpux)) + prefixLen = 0; +#else + if ((ch == 'p') && (formatAttr.flags & SECUREC_FLAG_ALTERNATE)) { + prefixLen = SECUREC_PREFIX_LEN; + } else { + prefixLen = 0; + } +#endif + } + + /* Convert data to ASCII */ + formatBuf.str = &buffer.str[SECUREC_BUFFER_SIZE]; + + if (number > 0) { +#ifdef SECUREC_ON_64BITS + switch (radix) { + /* the compiler will optimize each one */ + case SECUREC_RADIX_DECIMAL: + SECUREC_SPECIAL_QWORD_BASE10(number); + break; + case SECUREC_RADIX_HEX: + SECUREC_SPECIAL_QWORD(number, SECUREC_RADIX_HEX); + break; + case SECUREC_RADIX_OCTAL: + SECUREC_SPECIAL_QWORD(number, SECUREC_RADIX_OCTAL); + break; + default: + break; + } +#else /* for 32 bits system */ + if (number <= 0xFFFFFFFFUL) { + /* in most case, the value to be converted is small value */ + SecUnsignedInt32 n32Tmp = (SecUnsignedInt32)number; + switch (radix) { + case SECUREC_RADIX_HEX: + SECUREC_SPECIAL_DWORD(n32Tmp, SECUREC_RADIX_HEX); + break; + case SECUREC_RADIX_OCTAL: + SECUREC_SPECIAL_DWORD(n32Tmp, SECUREC_RADIX_OCTAL); + break; + +#ifdef _AIX + /* the compiler will optimize div 10 */ + case SECUREC_RADIX_DECIMAL: + SECUREC_SPECIAL_DWORD(n32Tmp, SECUREC_RADIX_DECIMAL); + break; +#else + case SECUREC_RADIX_DECIMAL: + do { + /* fast div 10 */ + SecUnsignedInt32 q; + SecUnsignedInt32 r; + do { + *--formatBuf.str = digits[n32Tmp % SECUREC_RADIX_DECIMAL]; + q = (n32Tmp >> 1) + (n32Tmp >> 2); /* fast div magic 2 */ + q = q + (q >> 4); /* fast div magic 4 */ + q = q + (q >> 8); /* fast div magic 8 */ + q = q + (q >> 16); /* fast div magic 16 */ + q = q >> 3; /* fast div magic 3 */ + r = n32Tmp - SECUREC_MUL_TEN(q); + n32Tmp = (r > 9) ? (q + 1) : q; /* fast div magic 9 */ + } while (n32Tmp != 0); + } SECUREC_WHILE_ZERO; + break; +#endif + default: + break; + } /* end switch */ + } else { + /* the value to be converted is greater than 4G */ +#if defined(SECUREC_VXWORKS_VERSION_5_4) + do { + SecUnsignedInt32 digit = 0; /* ascii value of digit */ + SecUnsignedInt64 quotient = 0; + if (SecU64Div32(number,(SecUnsignedInt32)radix, "ient, &digit) != 0) { + noOutput = 1; + break; + } + *--formatBuf.str = digits[digit]; + number = quotient; + } while (number != 0); +#else + switch (radix) { + /* the compiler will optimize div 10 */ + case SECUREC_RADIX_DECIMAL: + SECUREC_SPECIAL_QWORD_BASE10(number); + break; + case SECUREC_RADIX_OCTAL: + SECUREC_SPECIAL_QWORD(number, SECUREC_RADIX_OCTAL); + break; + case SECUREC_RADIX_HEX: + SECUREC_SPECIAL_QWORD(number, SECUREC_RADIX_HEX); + break; + default: + break; + } +#endif + } +#endif + + } + /* compute length of number,.if textLen > 0, then formatBuf.str must be in buffer.str */ + textLen = (int)(size_t)((char *)&buffer.str[SECUREC_BUFFER_SIZE] - formatBuf.str); + if (formatAttr.precision > textLen) { + int ii; + for (ii = 0; ii < formatAttr.precision - textLen; ++ii) { + *--formatBuf.str = '0'; + } + textLen = formatAttr.precision; + } + + /* Force a leading zero if FORCEOCTAL flag set */ + if ((formatAttr.flags & SECUREC_FLAG_FORCE_OCTAL) && + (textLen == 0 || formatBuf.str[0] != '0')) { + *--formatBuf.str = '0'; + ++textLen; /* add a zero */ + } + } SECUREC_WHILE_ZERO; + break; + default: + break; + } + + while (noOutput < 1) { + if (formatAttr.flags & SECUREC_FLAG_SIGNED) { + if (formatAttr.flags & SECUREC_FLAG_NEGATIVE) { + /* prefix is a '-' */ + prefix[0] = SECUREC_CHAR('-'); + prefixLen = 1; + } else if (formatAttr.flags & SECUREC_FLAG_SIGN) { + /* prefix is '+' */ + prefix[0] = SECUREC_CHAR('+'); + prefixLen = 1; + } else if (formatAttr.flags & SECUREC_FLAG_SIGN_SPACE) { + /* prefix is ' ' */ + prefix[0] = SECUREC_CHAR(' '); + prefixLen = 1; + } + } + +#if defined(SECUREC_COMPATIBLE_LINUX_FORMAT) && (!defined(SECUREC_ON_UNIX)) + if ((formatAttr.flags & SECUREC_FLAG_POINTER) && (textLen == 0)) { + formatAttr.flags &= ~SECUREC_FLAG_LEADZERO; + formatBuf.str = &buffer.str[SECUREC_BUFFER_SIZE - 1]; + *formatBuf.str-- = '\0'; + *formatBuf.str-- = ')'; + *formatBuf.str-- = 'l'; + *formatBuf.str-- = 'i'; + *formatBuf.str-- = 'n'; + *formatBuf.str = '('; + textLen = 5; /* length of (nil) is 5 */ + } +#endif + + /* calculate amount of padding */ + padding = (formatAttr.fldWidth - textLen) - prefixLen; + + /* put out the padding, prefix, and text, in the correct order */ + + if (!(formatAttr.flags & (SECUREC_FLAG_LEFT | SECUREC_FLAG_LEADZERO)) && padding > 0) { + /* pad on left with blanks */ + if (SECUREC_IS_REST_BUF_ENOUGH(stream, padding)) { + /* char * cast to wchar * */ + SECUREC_SAFE_PADDING(SECUREC_CHAR(' '), padding, stream, &charsOut); + } else { + SECUREC_WRITE_MULTI_CHAR(SECUREC_CHAR(' '), padding, stream, &charsOut); + } + } + + /* write prefix */ + if (prefixLen > 0) { + SecChar *pPrefix = prefix; + if (SECUREC_IS_REST_BUF_ENOUGH(stream, prefixLen)) { + /* max prefix len is 2, use loop copy */ /* char * cast to wchar * in WCHAR version */ + SECUREC_SAFE_WRITE_STR_OPT(pPrefix, prefixLen, stream, &charsOut); + } else { + SECUREC_WRITE_STRING(prefix, prefixLen, stream, &charsOut); + } + } + + if ((formatAttr.flags & SECUREC_FLAG_LEADZERO) && !(formatAttr.flags & SECUREC_FLAG_LEFT) + && padding > 0) { + /* write leading zeros */ + if (SECUREC_IS_REST_BUF_ENOUGH(stream, padding)) { + /* char * cast to wchar * */ + SECUREC_SAFE_PADDING(SECUREC_CHAR('0'), padding, stream, &charsOut); + } else { + SECUREC_WRITE_MULTI_CHAR(SECUREC_CHAR('0'), padding, stream, &charsOut); + } + } + + /* write text */ +#ifndef SECUREC_FOR_WCHAR + if (formatAttr.bufferIsWide != 0 && (textLen > 0)) { +#if SECUREC_HAVE_WCTOMB + wchar_t *p = formatBuf.wStr; + int count = textLen; + while (count > 0) { + char tmpBuf[SECUREC_MB_LEN + 1]; + SECUREC_MASK_MSVC_CRT_WARNING + int retVal = wctomb(tmpBuf, *p); + SECUREC_END_MASK_MSVC_CRT_WARNING + if (retVal <= 0) { + charsOut = -1; + break; + } + SECUREC_WRITE_STRING(tmpBuf, retVal, stream, &charsOut); + --count; + ++p; + } +#else + charsOut = -1; + break; +#endif + } else { + if (SECUREC_IS_REST_BUF_ENOUGH(stream, textLen)) { + SECUREC_SAFE_WRITE_STR(formatBuf.str, textLen, stream, &charsOut); + } else { + SECUREC_WRITE_STRING(formatBuf.str, textLen, stream, &charsOut); + } + } +#else /* SECUREC_FOR_WCHAR */ + if (formatAttr.bufferIsWide == 0 && textLen > 0) { +#if SECUREC_HAVE_MBTOWC + int count = textLen; + char *p = formatBuf.str; + + while (count > 0) { + wchar_t wChar = L'\0'; + int retVal = mbtowc(&wChar, p, (size_t)MB_CUR_MAX); + if (retVal <= 0) { + charsOut = -1; + break; + } + SecWriteCharW(wChar, stream, &charsOut); + p += retVal; + count -= retVal; + } +#else + charsOut = -1; + break; +#endif + } else { + if (SECUREC_IS_REST_BUF_ENOUGH(stream, textLen)) { + /* char * cast to wchar * */ + SECUREC_SAFE_WRITE_STR(formatBuf.wStr, textLen, stream, &charsOut); + } else { + SECUREC_WRITE_STRING(formatBuf.wStr, textLen, stream, &charsOut); + } + } +#endif /* SECUREC_FOR_WCHAR */ + + if (charsOut >= 0 && (formatAttr.flags & SECUREC_FLAG_LEFT) && padding > 0) { + /* pad on right with blanks */ + if (SECUREC_IS_REST_BUF_ENOUGH(stream, padding)) { + /* char * cast to wchar * */ + SECUREC_SAFE_PADDING(SECUREC_CHAR(' '), padding, stream, &charsOut); + } else { + SECUREC_WRITE_MULTI_CHAR(SECUREC_CHAR(' '), padding, stream, &charsOut); + } + } + break; + } +#if SECUREC_ENABLE_SPRINTF_FLOAT + if (floatBuf != NULL) { + SECUREC_FREE(floatBuf); + floatBuf = NULL; + } +#endif + break; + case STAT_INVALID: + return -1; + default: + return -1; /* input format is wrong, directly return */ + } + } + + if (state != STAT_NORMAL && state != STAT_TYPE) { + return -1; + } + + return charsOut; /* the number of characters written */ +} +#endif /* OUTPUT_INL_2B263E9C_43D8_44BB_B17A_6D2033DECEE5 */ + diff --git a/third_party/securec/src/scanf_s.c b/third_party/securec/src/scanf_s.c new file mode 100644 index 00000000..e4b0e602 --- /dev/null +++ b/third_party/securec/src/scanf_s.c @@ -0,0 +1,55 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "securec.h" + +/* + * + * The scanf_s function is equivalent to fscanf_s with the argument stdin interposed before the arguments to scanf_s + * The scanf_s function reads data from the standard input stream stdin and + * writes the data into the location that's given by argument. Each argument + * must be a pointer to a variable of a type that corresponds to a type specifier + * in format. If copying occurs between strings that overlap, the behavior is + * undefined. + * + * + * format Format control string. + * ... Optional arguments. + * + * + * ... The converted value stored in user assigned address + * + * + * Returns the number of fields successfully converted and assigned; + * the return value does not include fields that were read but not assigned. + * A return value of 0 indicates that no fields were assigned. + * return -1 if an error occurs. + */ + +int scanf_s(const char *format, ...) +{ + int ret; /* If initialization causes e838 */ + va_list argList; + + va_start(argList, format); + ret = vscanf_s(format, argList); + va_end(argList); + (void)argList; /* to clear e438 last value assigned not used , the compiler will optimize this code */ + + return ret; +} + + diff --git a/third_party/securec/src/secinput.h b/third_party/securec/src/secinput.h new file mode 100644 index 00000000..8cd92849 --- /dev/null +++ b/third_party/securec/src/secinput.h @@ -0,0 +1,156 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef SEC_INPUT_H_E950DA2C_902F_4B15_BECD_948E99090D9C +#define SEC_INPUT_H_E950DA2C_902F_4B15_BECD_948E99090D9C +#include "securecutil.h" + +#define SECUREC_SCANF_EINVAL (-1) +#define SECUREC_SCANF_ERROR_PARA (-2) + +/* for internal stream flag */ +#define SECUREC_MEM_STR_FLAG 0X01 +#define SECUREC_FILE_STREAM_FLAG 0X02 +#define SECUREC_FROM_STDIN_FLAG 0X04 +#define SECUREC_LOAD_FILE_TO_MEM_FLAG 0X08 + +#define SECUREC_UNINITIALIZED_FILE_POS (-1) +#define SECUREC_BOM_HEADER_SIZE 2 +#define SECUREC_BOM_HEADER_BE_1ST 0xFEU +#define SECUREC_BOM_HEADER_BE_2ST 0xFFU +#define SECUREC_BOM_HEADER_LE_1ST 0xFFU +#define SECUREC_BOM_HEADER_LE_2ST 0xFEU +#define SECUREC_UTF8_BOM_HEADER_SIZE 3 +#define SECUREC_UTF8_BOM_HEADER_1ST 0xEFU +#define SECUREC_UTF8_BOM_HEADER_2ND 0xBBU +#define SECUREC_UTF8_BOM_HEADER_3RD 0xBFU +#define SECUREC_UTF8_LEAD_1ST 0xE0 +#define SECUREC_UTF8_LEAD_2ND 0x80 + +typedef struct { + unsigned int flag; /* mark the properties of input stream */ + int count; /* the size of buffered string in bytes */ + const char *cur; /* the pointer to next read position */ + char *base; /* the pointer to the header of buffered string */ +#if SECUREC_ENABLE_SCANF_FILE + FILE *pf; /* the file pointer */ + long oriFilePos; /* the original position of file offset when fscanf is called */ + int fileRealRead; +#if defined(SECUREC_NO_STD_UNGETC) + unsigned int lastChar; /* the char code of last input */ + int fUnget; /* the boolean flag of pushing a char back to read stream */ +#endif +#endif +} SecFileStream; + + +#define SECUREC_INIT_SEC_FILE_STREAM_COMMON(fileStream, streamFlag, curPtr, strCount) do { \ + (fileStream).flag = (streamFlag); \ + (fileStream).count = (strCount); \ + (fileStream).cur = (curPtr); \ + (fileStream).base = NULL; \ +} SECUREC_WHILE_ZERO + +#if SECUREC_ENABLE_SCANF_FILE +#if defined(SECUREC_NO_STD_UNGETC) +/* This initialization for eliminating redundant initialization. + * Compared with the previous version initialization 0, + * the current code causes the binary size to increase by some bytes + */ +#define SECUREC_INIT_SEC_FILE_STREAM(fileStream, streamFlag, stream, filePos, curPtr, strCount) do { \ + SECUREC_INIT_SEC_FILE_STREAM_COMMON((fileStream), (streamFlag), (curPtr), (strCount)); \ + (fileStream).pf = (stream); \ + (fileStream).oriFilePos = (filePos); \ + (fileStream).fileRealRead = 0; \ + (fileStream).lastChar = 0; \ + (fileStream).fUnget = 0; \ +} SECUREC_WHILE_ZERO +#else +#define SECUREC_INIT_SEC_FILE_STREAM(fileStream, streamFlag, stream, filePos, curPtr, strCount) do { \ + SECUREC_INIT_SEC_FILE_STREAM_COMMON((fileStream), (streamFlag), (curPtr), (strCount)); \ + (fileStream).pf = (stream); \ + (fileStream).oriFilePos = (filePos); \ + (fileStream).fileRealRead = 0; \ +} SECUREC_WHILE_ZERO +#endif +#else /* No SECUREC_ENABLE_SCANF_FILE */ +#define SECUREC_INIT_SEC_FILE_STREAM(fileStream, streamFlag, stream, filePos, curPtr, strCount) do { \ + SECUREC_INIT_SEC_FILE_STREAM_COMMON((fileStream), (streamFlag), (curPtr), (strCount)); \ +} SECUREC_WHILE_ZERO +#endif + +#ifdef __cplusplus +extern "C" { +#endif + + extern int SecInputS(SecFileStream *stream, const char *cFormat, va_list argList); + extern void SecClearDestBuf(const char *buffer, const char *format, va_list argList); +#if SECUREC_IN_KERNEL == 0 + extern int SecInputSW(SecFileStream *stream, const wchar_t *cFormat, va_list argList); + extern void SecClearDestBufW(const wchar_t *buffer, const wchar_t *format, va_list argList); +#endif +/* 20150105 For software and hardware decoupling,such as UMG */ +#if defined(SECUREC_SYSAPI4VXWORKS) +#ifdef feof +#undef feof +#endif + extern int feof(FILE *stream); +#endif + +#if defined(SECUREC_SYSAPI4VXWORKS) || defined(SECUREC_CTYPE_MACRO_ADAPT) +#ifndef isspace +#define isspace(c) (((c) == ' ') || ((c) == '\t') || ((c) == '\r') || ((c) == '\n')) +#endif +#ifndef iswspace +#define iswspace(c) (((c) == L' ') || ((c) == L'\t') || ((c) == L'\r') || ((c) == L'\n')) +#endif +#ifndef isascii +#define isascii(c) (((unsigned char)(c)) <= 0x7f) +#endif +#ifndef isupper +#define isupper(c) ((c) >= 'A' && (c) <= 'Z') +#endif +#ifndef islower +#define islower(c) ((c) >= 'a' && (c) <= 'z') +#endif +#ifndef isalpha +#define isalpha(c) (isupper(c) || (islower(c))) +#endif +#ifndef isdigit +#define isdigit(c) ((c) >= '0' && (c) <= '9') +#endif +#ifndef isxupper +#define isxupper(c) ((c) >= 'A' && (c) <= 'F') +#endif +#ifndef isxlower +#define isxlower(c) ((c) >= 'a' && (c) <= 'f') +#endif +#ifndef isxdigit +#define isxdigit(c) (isdigit(c) || isxupper(c) || isxlower(c)) +#endif +#endif + +#ifdef __cplusplus +} +#endif +/* Reserved file operation macro interface */ +#define SECUREC_LOCK_FILE(s) +#define SECUREC_UNLOCK_FILE(s) +#define SECUREC_LOCK_STDIN(i, s) +#define SECUREC_UNLOCK_STDIN(i, s) +#endif + + diff --git a/third_party/securec/src/securecutil.c b/third_party/securec/src/securecutil.c new file mode 100644 index 00000000..1a44cfbe --- /dev/null +++ b/third_party/securec/src/securecutil.c @@ -0,0 +1,74 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* Avoid duplicate header files,not include securecutil.h */ +#include "securecutil.h" + + +#if defined(ANDROID) && (SECUREC_HAVE_WCTOMB || SECUREC_HAVE_MBTOWC) +#include +#if SECUREC_HAVE_WCTOMB +/* + * Convert wide characters to narrow multi-bytes + */ +int wctomb(char *s, wchar_t wc) +{ + return wcrtomb(s, wc, NULL); +} +#endif + +#if SECUREC_HAVE_MBTOWC +/* + * Converting narrow multi-byte characters to wide characters + */ +int mbtowc(wchar_t *pwc, const char *s, size_t n) +{ + return mbrtowc(pwc, s, n, NULL); +} +#endif +#endif + +/* high Num << 8 | num of SPC Ver */ +#define SECUREC_C_VERSION (0x5 << 8) +#define SECUREC_SPC_VERSION 7 +#define SECUREC_VERSION_STR "Huawei Secure C V100R001C01SPC007B002" + +/* SPC verNumber<->verStr like: + * 0X201<->C01 + * 0X202<->SPC001 Redefine numbers after this version + * 0X502<->SPC002 + * 0X503<->SPC003 + * ... + * 0X50a<->SPC010 + * 0X50b<->SPC011 + * ... + */ +/* CP verNumber<->verStr like: + * 0X601<->CP0001 + * 0X602<->CP0002 + * ... + */ +const char *GetHwSecureCVersion(unsigned short *verNumber) +{ + if (verNumber != NULL) { + *verNumber = (unsigned short)(SECUREC_C_VERSION | SECUREC_SPC_VERSION); + } + return SECUREC_VERSION_STR; +} +#if SECUREC_IN_KERNEL +EXPORT_SYMBOL(GetHwSecureCVersion); +#endif + diff --git a/third_party/securec/src/securecutil.h b/third_party/securec/src/securecutil.h new file mode 100644 index 00000000..98c9aad0 --- /dev/null +++ b/third_party/securec/src/securecutil.h @@ -0,0 +1,541 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef SECURECUTIL_H_46C86578_F8FF_4E49_8E64_9B175241761F +#define SECURECUTIL_H_46C86578_F8FF_4E49_8E64_9B175241761F +#include "securec.h" + +#if (defined(_MSC_VER)) && (_MSC_VER >= 1400) +#define SECUREC_MASK_MSVC_CRT_WARNING __pragma(warning(push)) \ + __pragma(warning(disable:4996 4127)) +#define SECUREC_END_MASK_MSVC_CRT_WARNING __pragma(warning(pop)) +#else +#define SECUREC_MASK_MSVC_CRT_WARNING +#define SECUREC_END_MASK_MSVC_CRT_WARNING +#endif +#define SECUREC_WHILE_ZERO SECUREC_MASK_MSVC_CRT_WARNING while (0) SECUREC_END_MASK_MSVC_CRT_WARNING + +#ifndef SECUREC_HAVE_STRNLEN +#if (defined(_XOPEN_SOURCE) && _XOPEN_SOURCE >= 700) || (defined(_POSIX_C_SOURCE) && _POSIX_C_SOURCE >= 200809L) +#if SECUREC_IN_KERNEL +#define SECUREC_HAVE_STRNLEN 0 +#else +#if defined(__GLIBC__) && __GLIBC__ >= 2 && defined(__GLIBC_MINOR__) && __GLIBC_MINOR__ >= 10 +#define SECUREC_HAVE_STRNLEN 1 +#else +#define SECUREC_HAVE_STRNLEN 0 +#endif +#endif +#else +#define SECUREC_HAVE_STRNLEN 0 +#endif +#endif + +#if SECUREC_IN_KERNEL +/* in kernel disbale functions */ +#ifndef SECUREC_ENABLE_SCANF_FILE +#define SECUREC_ENABLE_SCANF_FILE 0 +#endif +#ifndef SECUREC_ENABLE_SCANF_FLOAT +#define SECUREC_ENABLE_SCANF_FLOAT 0 +#endif +#ifndef SECUREC_ENABLE_SPRINTF_FLOAT +#define SECUREC_ENABLE_SPRINTF_FLOAT 0 +#endif +#ifndef SECUREC_HAVE_MBTOWC +#define SECUREC_HAVE_MBTOWC 0 +#endif +#ifndef SECUREC_HAVE_WCTOMB +#define SECUREC_HAVE_WCTOMB 0 +#endif +#ifndef SECUREC_HAVE_WCHART +#define SECUREC_HAVE_WCHART 0 +#endif +#else /* no in kernel */ +/* Systems that do not support file, can define this macro to 0. */ +#ifndef SECUREC_ENABLE_SCANF_FILE +#define SECUREC_ENABLE_SCANF_FILE 1 +#endif +#ifndef SECUREC_ENABLE_SCANF_FLOAT +#define SECUREC_ENABLE_SCANF_FLOAT 1 +#endif +/* Systems that do not support float, can define this macro to 0. */ +#ifndef SECUREC_ENABLE_SPRINTF_FLOAT +#define SECUREC_ENABLE_SPRINTF_FLOAT 1 +#endif +#ifndef SECUREC_HAVE_MBTOWC +#define SECUREC_HAVE_MBTOWC 1 +#endif +#ifndef SECUREC_HAVE_WCTOMB +#define SECUREC_HAVE_WCTOMB 1 +#endif +#ifndef SECUREC_HAVE_WCHART +#define SECUREC_HAVE_WCHART 1 +#endif +#endif + + +#define SECUREC_INT_MAX 2147483647 +#define SECUREC_MUL_SIXTEEN(x) ((x) << 4) +#define SECUREC_MUL_EIGHT(x) ((x) << 3) +#define SECUREC_MUL_TEN(x) ((((x) << 2) + (x)) << 1) +/* Limited format input and output width */ +#define SECUREC_MAX_WIDTH_LEN_DIV_TEN 21474836 +#define SECUREC_MAX_WIDTH_LEN SECUREC_MUL_TEN(SECUREC_MAX_WIDTH_LEN_DIV_TEN) +/* Is the x multiplied by 10 greater than */ +#define SECUREC_MUL_TEN_ADD_BEYOND_MAX(x) (((x) > SECUREC_MAX_WIDTH_LEN_DIV_TEN)) + +#define SECUREC_FLOAT_BUFSIZE (309 + 40) /* Max length of double value */ +#define SECUREC_FLOAT_BUFSIZE_LB (4932 + 40) /* Max length of long double value */ +#define SECUREC_FLOAT_DEFAULT_PRECISION 6 + +/* This macro does not handle pointer equality or integer overflow */ +#define SECUREC_MEMORY_NO_OVERLAP(dest, src, count) \ + (((src) < (dest) && ((const char *)(src) + (count)) <= (char *)(dest)) || \ + ((dest) < (src) && ((char *)(dest) + (count)) <= (const char *)(src))) + +#define SECUREC_MEMORY_IS_OVERLAP(dest, src, count) \ + (((src) < (dest) && ((const char *)(src) + (count)) > (char *)(dest)) || \ + ((dest) < (src) && ((char *)(dest) + (count)) > (const char *)(src))) + +/* + * Check whether the strings overlap, len is the length of the string not include terminator + * Length is related to data type char or wchar , do not force conversion of types + */ +#define SECUREC_STRING_NO_OVERLAP(dest, src, len) \ + (((src) < (dest) && ((src) + (len)) < (dest)) || \ + ((dest) < (src) && ((dest) + (len)) < (src))) + +/* + * Check whether the strings overlap for strcpy wcscpy function, dest len and src Len are not include terminator + * Length is related to data type char or wchar , do not force conversion of types + */ +#define SECUREC_STRING_IS_OVERLAP(dest, src, len) \ + (((src) < (dest) && ((src) + (len)) >= (dest)) || \ + ((dest) < (src) && ((dest) + (len)) >= (src))) + +/* + * Check whether the strings overlap for strcat wcscat function, dest len and src Len are not include terminator + * Length is related to data type char or wchar , do not force conversion of types + */ +#define SECUREC_CAT_STRING_IS_OVERLAP(dest, destLen, src, srcLen) \ + (((dest) < (src) && ((dest) + (destLen) + (srcLen)) >= (src)) || \ + ((src) < (dest) && ((src) + (srcLen)) >= (dest))) + + +#if SECUREC_HAVE_STRNLEN +#define SECUREC_CALC_STR_LEN(str, maxLen, outLen) do { \ + *(outLen) = strnlen((str), (maxLen)); \ +} SECUREC_WHILE_ZERO +#define SECUREC_CALC_STR_LEN_OPT(str, maxLen, outLen) do { \ + if ((maxLen) > 8) { \ + /* Optimization or len less then 8 */ \ + if (*((str) + 0) == '\0') { \ + *(outLen) = 0; \ + } else if (*((str) + 1) == '\0') { \ + *(outLen) = 1; \ + } else if (*((str) + 2) == '\0') { \ + *(outLen) = 2; \ + } else if (*((str) + 3) == '\0') { \ + *(outLen) = 3; \ + } else if (*((str) + 4) == '\0') { \ + *(outLen) = 4; \ + } else if (*((str) + 5) == '\0') { \ + *(outLen) = 5; \ + } else if (*((str) + 6) == '\0') { \ + *(outLen) = 6; \ + } else if (*((str) + 7) == '\0') { \ + *(outLen) = 7; \ + } else if (*((str) + 8) == '\0') { \ + /* Optimization with a length of 8 */ \ + *(outLen) = 8; \ + } else { \ + /* The offset is 8 because the performance of 8 byte alignment is high */ \ + *(outLen) = 8 + strnlen((str) + 8, (maxLen) - 8); \ + } \ + } else { \ + SECUREC_CALC_STR_LEN((str), (maxLen), (outLen)); \ + } \ +} SECUREC_WHILE_ZERO +#else +#define SECUREC_CALC_STR_LEN(str, maxLen, outLen) do { \ + const char *strEnd = (const char *)(str); \ + size_t availableSize = (size_t)(maxLen); \ + while (availableSize > 0 && *strEnd != '\0') { \ + --availableSize; \ + ++strEnd; \ + } \ + *(outLen) = (size_t)(strEnd - (str)); \ +} SECUREC_WHILE_ZERO +#define SECUREC_CALC_STR_LEN_OPT SECUREC_CALC_STR_LEN +#endif + +#define SECUREC_CALC_WSTR_LEN(str, maxLen, outLen) do { \ + const wchar_t *strEnd = (const wchar_t *)(str); \ + *(outLen) = 0; \ + while (*(outLen) < (maxLen) && *strEnd != L'\0') { \ + *(outLen) = *(outLen) + 1; \ + ++strEnd; \ + } \ +} SECUREC_WHILE_ZERO + + +#ifdef SECUREC_FORMAT_OUTPUT_INPUT +#if defined(SECUREC_COMPATIBLE_WIN_FORMAT) || defined(__ARMCC_VERSION) +typedef __int64 SecInt64; +typedef unsigned __int64 SecUnsignedInt64; +#if defined(__ARMCC_VERSION) +typedef unsigned int SecUnsignedInt32; +#else +typedef unsigned __int32 SecUnsignedInt32; +#endif +#else +typedef unsigned int SecUnsignedInt32; +typedef long long SecInt64; +typedef unsigned long long SecUnsignedInt64; +#endif + +#ifdef SECUREC_FOR_WCHAR +#if defined(SECUREC_VXWORKS_PLATFORM) && !defined(__WINT_TYPE__) +typedef wchar_t wint_t; +#endif +typedef wchar_t SecChar; +typedef wchar_t SecUnsignedChar; +typedef wint_t SecInt; +typedef wint_t SecUnsignedInt; +#else /* no SECUREC_FOR_WCHAR */ +typedef char SecChar; +typedef unsigned char SecUnsignedChar; +typedef int SecInt; +typedef unsigned int SecUnsignedInt; +#endif +#endif + +/* Determine whether the address is 8-byte aligned + * Some systems do not have uintptr_t type, so use NULL to clear tool alarm 507 + */ +#define SECUREC_ADDR_ALIGNED_8(addr) (SecIsAddrAligned8((addr), NULL) == 0) + +/* If you define the memory allocation function, + * you need to define the function prototype. You can define this macro as a header file. + */ +#if defined(SECUREC_MALLOC_PROTOTYPE) +SECUREC_MALLOC_PROTOTYPE +#endif + +#ifndef SECUREC_MALLOC +#define SECUREC_MALLOC(x) malloc((size_t)(x)) +#endif + +#ifndef SECUREC_FREE +#define SECUREC_FREE(x) free((void *)(x)) +#endif + +/* struct for performance */ +typedef struct { + unsigned char buf[1]; /* Performance optimization code structure assignment length 1 bytes */ +} SecStrBuf1; +typedef struct { + unsigned char buf[2]; /* Performance optimization code structure assignment length 2 bytes */ +} SecStrBuf2; +typedef struct { + unsigned char buf[3]; /* Performance optimization code structure assignment length 3 bytes */ +} SecStrBuf3; +typedef struct { + unsigned char buf[4]; /* Performance optimization code structure assignment length 4 bytes */ +} SecStrBuf4; +typedef struct { + unsigned char buf[5]; /* Performance optimization code structure assignment length 5 bytes */ +} SecStrBuf5; +typedef struct { + unsigned char buf[6]; /* Performance optimization code structure assignment length 6 bytes */ +} SecStrBuf6; +typedef struct { + unsigned char buf[7]; /* Performance optimization code structure assignment length 7 bytes */ +} SecStrBuf7; +typedef struct { + unsigned char buf[8]; /* Performance optimization code structure assignment length 8 bytes */ +} SecStrBuf8; +typedef struct { + unsigned char buf[9]; /* Performance optimization code structure assignment length 9 bytes */ +} SecStrBuf9; +typedef struct { + unsigned char buf[10]; /* Performance optimization code structure assignment length 10 bytes */ +} SecStrBuf10; +typedef struct { + unsigned char buf[11]; /* Performance optimization code structure assignment length 11 bytes */ +} SecStrBuf11; +typedef struct { + unsigned char buf[12]; /* Performance optimization code structure assignment length 12 bytes */ +} SecStrBuf12; +typedef struct { + unsigned char buf[13]; /* Performance optimization code structure assignment length 13 bytes */ +} SecStrBuf13; +typedef struct { + unsigned char buf[14]; /* Performance optimization code structure assignment length 14 bytes */ +} SecStrBuf14; +typedef struct { + unsigned char buf[15]; /* Performance optimization code structure assignment length 15 bytes */ +} SecStrBuf15; +typedef struct { + unsigned char buf[16]; /* Performance optimization code structure assignment length 16 bytes */ +} SecStrBuf16; +typedef struct { + unsigned char buf[17]; /* Performance optimization code structure assignment length 17 bytes */ +} SecStrBuf17; +typedef struct { + unsigned char buf[18]; /* Performance optimization code structure assignment length 18 bytes */ +} SecStrBuf18; +typedef struct { + unsigned char buf[19]; /* Performance optimization code structure assignment length 19 bytes */ +} SecStrBuf19; +typedef struct { + unsigned char buf[20]; /* Performance optimization code structure assignment length 20 bytes */ +} SecStrBuf20; +typedef struct { + unsigned char buf[21]; /* Performance optimization code structure assignment length 21 bytes */ +} SecStrBuf21; +typedef struct { + unsigned char buf[22]; /* Performance optimization code structure assignment length 22 bytes */ +} SecStrBuf22; +typedef struct { + unsigned char buf[23]; /* Performance optimization code structure assignment length 23 bytes */ +} SecStrBuf23; +typedef struct { + unsigned char buf[24]; /* Performance optimization code structure assignment length 24 bytes */ +} SecStrBuf24; +typedef struct { + unsigned char buf[25]; /* Performance optimization code structure assignment length 25 bytes */ +} SecStrBuf25; +typedef struct { + unsigned char buf[26]; /* Performance optimization code structure assignment length 26 bytes */ +} SecStrBuf26; +typedef struct { + unsigned char buf[27]; /* Performance optimization code structure assignment length 27 bytes */ +} SecStrBuf27; +typedef struct { + unsigned char buf[28]; /* Performance optimization code structure assignment length 28 bytes */ +} SecStrBuf28; +typedef struct { + unsigned char buf[29]; /* Performance optimization code structure assignment length 29 bytes */ +} SecStrBuf29; +typedef struct { + unsigned char buf[30]; /* Performance optimization code structure assignment length 30 bytes */ +} SecStrBuf30; +typedef struct { + unsigned char buf[31]; /* Performance optimization code structure assignment length 31 bytes */ +} SecStrBuf31; +typedef struct { + unsigned char buf[32]; /* Performance optimization code structure assignment length 32 bytes */ +} SecStrBuf32; +typedef struct { + unsigned char buf[33]; /* Performance optimization code structure assignment length 33 bytes */ +} SecStrBuf33; +typedef struct { + unsigned char buf[34]; /* Performance optimization code structure assignment length 34 bytes */ +} SecStrBuf34; +typedef struct { + unsigned char buf[35]; /* Performance optimization code structure assignment length 35 bytes */ +} SecStrBuf35; +typedef struct { + unsigned char buf[36]; /* Performance optimization code structure assignment length 36 bytes */ +} SecStrBuf36; +typedef struct { + unsigned char buf[37]; /* Performance optimization code structure assignment length 37 bytes */ +} SecStrBuf37; +typedef struct { + unsigned char buf[38]; /* Performance optimization code structure assignment length 38 bytes */ +} SecStrBuf38; +typedef struct { + unsigned char buf[39]; /* Performance optimization code structure assignment length 39 bytes */ +} SecStrBuf39; +typedef struct { + unsigned char buf[40]; /* Performance optimization code structure assignment length 40 bytes */ +} SecStrBuf40; +typedef struct { + unsigned char buf[41]; /* Performance optimization code structure assignment length 41 bytes */ +} SecStrBuf41; +typedef struct { + unsigned char buf[42]; /* Performance optimization code structure assignment length 42 bytes */ +} SecStrBuf42; +typedef struct { + unsigned char buf[43]; /* Performance optimization code structure assignment length 43 bytes */ +} SecStrBuf43; +typedef struct { + unsigned char buf[44]; /* Performance optimization code structure assignment length 44 bytes */ +} SecStrBuf44; +typedef struct { + unsigned char buf[45]; /* Performance optimization code structure assignment length 45 bytes */ +} SecStrBuf45; +typedef struct { + unsigned char buf[46]; /* Performance optimization code structure assignment length 46 bytes */ +} SecStrBuf46; +typedef struct { + unsigned char buf[47]; /* Performance optimization code structure assignment length 47 bytes */ +} SecStrBuf47; +typedef struct { + unsigned char buf[48]; /* Performance optimization code structure assignment length 48 bytes */ +} SecStrBuf48; +typedef struct { + unsigned char buf[49]; /* Performance optimization code structure assignment length 49 bytes */ +} SecStrBuf49; +typedef struct { + unsigned char buf[50]; /* Performance optimization code structure assignment length 50 bytes */ +} SecStrBuf50; +typedef struct { + unsigned char buf[51]; /* Performance optimization code structure assignment length 51 bytes */ +} SecStrBuf51; +typedef struct { + unsigned char buf[52]; /* Performance optimization code structure assignment length 52 bytes */ +} SecStrBuf52; +typedef struct { + unsigned char buf[53]; /* Performance optimization code structure assignment length 53 bytes */ +} SecStrBuf53; +typedef struct { + unsigned char buf[54]; /* Performance optimization code structure assignment length 54 bytes */ +} SecStrBuf54; +typedef struct { + unsigned char buf[55]; /* Performance optimization code structure assignment length 55 bytes */ +} SecStrBuf55; +typedef struct { + unsigned char buf[56]; /* Performance optimization code structure assignment length 56 bytes */ +} SecStrBuf56; +typedef struct { + unsigned char buf[57]; /* Performance optimization code structure assignment length 57 bytes */ +} SecStrBuf57; +typedef struct { + unsigned char buf[58]; /* Performance optimization code structure assignment length 58 bytes */ +} SecStrBuf58; +typedef struct { + unsigned char buf[59]; /* Performance optimization code structure assignment length 59 bytes */ +} SecStrBuf59; +typedef struct { + unsigned char buf[60]; /* Performance optimization code structure assignment length 60 bytes */ +} SecStrBuf60; +typedef struct { + unsigned char buf[61]; /* Performance optimization code structure assignment length 61 bytes */ +} SecStrBuf61; +typedef struct { + unsigned char buf[62]; /* Performance optimization code structure assignment length 62 bytes */ +} SecStrBuf62; +typedef struct { + unsigned char buf[63]; /* Performance optimization code structure assignment length 63 bytes */ +} SecStrBuf63; +typedef struct { + unsigned char buf[64]; /* Performance optimization code structure assignment length 64 bytes */ +} SecStrBuf64; + + + + +/* User can change the error handler by modify the following definition, + * such as logging the detail error in file. + */ +#if defined(_DEBUG) || defined(DEBUG) +#if defined(SECUREC_ERROR_HANDLER_BY_ASSERT) +#define SECUREC_ERROR_INVALID_PARAMTER(msg) assert(msg "invalid argument" == NULL) +#define SECUREC_ERROR_INVALID_RANGE(msg) assert(msg "invalid dest buffer size" == NULL) +#define SECUREC_ERROR_BUFFER_OVERLAP(msg) assert(msg "buffer overlap" == NULL) +#elif defined(SECUREC_ERROR_HANDLER_BY_PRINTF) +#if SECUREC_IN_KERNEL +#define SECUREC_ERROR_INVALID_PARAMTER(msg) printk("%s invalid argument\n", msg) +#define SECUREC_ERROR_INVALID_RANGE(msg) printk("%s invalid dest buffer size\n", msg) +#define SECUREC_ERROR_BUFFER_OVERLAP(msg) printk("%s buffer overlap\n", msg) +#else +#define SECUREC_ERROR_INVALID_PARAMTER(msg) printf("%s invalid argument\n", msg) +#define SECUREC_ERROR_INVALID_RANGE(msg) printf("%s invalid dest buffer size\n", msg) +#define SECUREC_ERROR_BUFFER_OVERLAP(msg) printf("%s buffer overlap\n", msg) +#endif +#elif defined(SECUREC_ERROR_HANDLER_BY_FILE_LOG) +#define SECUREC_ERROR_INVALID_PARAMTER(msg) LogSecureCRuntimeError(msg " EINVAL\n") +#define SECUREC_ERROR_INVALID_RANGE(msg) LogSecureCRuntimeError(msg " ERANGE\n") +#define SECUREC_ERROR_BUFFER_OVERLAP(msg) LogSecureCRuntimeError(msg " EOVERLAP\n") +#else /* no HANDLER is defined */ +#define SECUREC_ERROR_INVALID_PARAMTER(msg) ((void)0) +#define SECUREC_ERROR_INVALID_RANGE(msg) ((void)0) +#define SECUREC_ERROR_BUFFER_OVERLAP(msg) ((void)0) +#endif +#else /* no DEBUG */ +#define SECUREC_ERROR_INVALID_PARAMTER(msg) ((void)0) +#define SECUREC_ERROR_INVALID_RANGE(msg) ((void)0) +#define SECUREC_ERROR_BUFFER_OVERLAP(msg) ((void)0) +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +/* assembly language memory copy and memory set for X86 or MIPS ... */ +#ifdef SECUREC_USE_ASM + extern void *memcpy_opt(void *, const void *, size_t); + extern void *memset_opt(void *, int, size_t); +#endif + +#if defined(SECUREC_ERROR_HANDLER_BY_FILE_LOG) + extern void LogSecureCRuntimeError(const char *errDetail); +#endif + +#ifdef SECUREC_INLINE_DO_MEMCPY +static void SecDoMemcpy(void *dest, const void *src, size_t count) +{ + /* + * if SECUREC_USE_ASM macro is enabled, it will call assembly language function to improve performance. + */ +#ifdef SECUREC_USE_ASM + (void)memcpy_opt(dest, src, count); +#else + /* large enough, let system API do it */ + (void)memcpy(dest, src, count); +#endif +} +#endif + +#ifdef SECUREC_INLINE_DO_MEMSET +static void SecDoMemset(void *dest, int c, size_t count) +{ +#ifdef SECUREC_USE_ASM + (void)memset_opt(dest, c, count); +#else + (void)memset(dest, c, count); +#endif +} +#endif + +#ifdef SECUREC_INLINE_STR_LEN +/* The function compiler will be inlined and not placed in other files */ +static size_t SecStrMinLen(const char *str, size_t maxLen) +{ + size_t len; + SECUREC_CALC_STR_LEN(str, maxLen, &len); + return len; +} +#endif + +#ifdef SECUREC_INLINE_STR_LEN_OPT +/* The function compiler will be inlined and not placed in other files */ +static size_t SecStrMinLenOpt(const char *str, size_t maxLen) +{ + size_t len; + SECUREC_CALC_STR_LEN_OPT(str, maxLen, &len); + return len; +} +#endif + +#ifdef __cplusplus +} +#endif /* __cplusplus */ +#endif + diff --git a/third_party/securec/src/secureinput_a.c b/third_party/securec/src/secureinput_a.c new file mode 100644 index 00000000..4f9bae83 --- /dev/null +++ b/third_party/securec/src/secureinput_a.c @@ -0,0 +1,25 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#define SECUREC_FORMAT_OUTPUT_INPUT 1 +#ifdef SECUREC_FOR_WCHAR +#undef SECUREC_FOR_WCHAR +#endif + +#include "secinput.h" + +#include "input.inl" + diff --git a/third_party/securec/src/secureinput_w.c b/third_party/securec/src/secureinput_w.c new file mode 100644 index 00000000..7a4bef42 --- /dev/null +++ b/third_party/securec/src/secureinput_w.c @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* if some platforms don't have wchar.h, dont't include it */ +#if !(defined(SECUREC_VXWORKS_PLATFORM)) +/* This header file is placed below secinput.h, which will cause tool alarm, + * but If there is no macro above, it will cause vs2010 compiling alarm + */ +#if defined(_MSC_VER) && (_MSC_VER >= 1400) +#ifndef __STDC_WANT_SECURE_LIB__ +/* The order of adjustment is to eliminate alarm of Duplicate Block */ +#define __STDC_WANT_SECURE_LIB__ 0 +#endif +#ifndef _CRTIMP_ALTERNATIVE +#define _CRTIMP_ALTERNATIVE /* comment microsoft *_s function */ +#endif +#endif +#include +#endif +#define SECUREC_ENABLE_WCHAR_FUNC 0 +#define SECUREC_FORMAT_OUTPUT_INPUT 1 +#ifndef SECUREC_FOR_WCHAR +#define SECUREC_FOR_WCHAR +#endif + +#include "secinput.h" + +#ifndef WEOF +#define WEOF ((wchar_t)(-1)) +#endif + +#include "input.inl" + diff --git a/third_party/securec/src/secureprintoutput.h b/third_party/securec/src/secureprintoutput.h new file mode 100644 index 00000000..b690ec92 --- /dev/null +++ b/third_party/securec/src/secureprintoutput.h @@ -0,0 +1,98 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef SECUREPRINTOUTPUT_H_E950DA2C_902F_4B15_BECD_948E99090D9C +#define SECUREPRINTOUTPUT_H_E950DA2C_902F_4B15_BECD_948E99090D9C +#include "securecutil.h" + +/* flag definitions */ +/* Using macros instead of enumerations is because some of the enumerated types under the compiler are 16bit. */ +#define SECUREC_FLAG_SIGN 0x00001U +#define SECUREC_FLAG_SIGN_SPACE 0x00002U +#define SECUREC_FLAG_LEFT 0x00004U +#define SECUREC_FLAG_LEADZERO 0x00008U +#define SECUREC_FLAG_LONG 0x00010U +#define SECUREC_FLAG_SHORT 0x00020U +#define SECUREC_FLAG_SIGNED 0x00040U +#define SECUREC_FLAG_ALTERNATE 0x00080U +#define SECUREC_FLAG_NEGATIVE 0x00100U +#define SECUREC_FLAG_FORCE_OCTAL 0x00200U +#define SECUREC_FLAG_LONG_DOUBLE 0x00400U +#define SECUREC_FLAG_WIDECHAR 0x00800U +#define SECUREC_FLAG_LONGLONG 0x01000U +#define SECUREC_FLAG_CHAR 0x02000U +#define SECUREC_FLAG_POINTER 0x04000U +#define SECUREC_FLAG_I64 0x08000U +#define SECUREC_FLAG_PTRDIFF 0x10000U +#define SECUREC_FLAG_SIZE 0x20000U +#ifdef SECUREC_COMPATIBLE_LINUX_FORMAT +#define SECUREC_FLAG_INTMAX 0x40000U +#endif + +/* state definitions. Identify the status of the current format */ +typedef enum { + STAT_NORMAL, + STAT_PERCENT, + STAT_FLAG, + STAT_WIDTH, + STAT_DOT, + STAT_PRECIS, + STAT_SIZE, + STAT_TYPE, + STAT_INVALID +} SecFmtState; + +/* Format output buffer pointer and available size */ +typedef struct { + int count; + char *cur; +} SecPrintfStream; + + +#ifndef SECUREC_BUFFER_SIZE +#ifdef SECUREC_STACK_SIZE_LESS_THAN_1K +/* SECUREC_BUFFER_SIZE Can not be less than 23 , + * the length of the octal representation of 64-bit integers with zero lead + */ +#define SECUREC_BUFFER_SIZE 256 +#else +#define SECUREC_BUFFER_SIZE 512 +#endif +#endif +#if SECUREC_BUFFER_SIZE < 23 +#error SECUREC_BUFFER_SIZE Can not be less than 23 +#endif + +#define SECUREC_MAX_PRECISION SECUREC_BUFFER_SIZE +/* max. # bytes in multibyte char ,see MB_LEN_MAX */ +#define SECUREC_MB_LEN 16 +/* The return value of the internal function, which is returned when truncated */ +#define SECUREC_PRINTF_TRUNCATE (-2) + +#ifdef __cplusplus +extern "C" { +#endif + extern int SecVsnprintfImpl(char *string, size_t count, const char *format, va_list argList); +#if SECUREC_IN_KERNEL == 0 + extern int SecVswprintfImpl(wchar_t *string, size_t sizeInWchar, const wchar_t *format, va_list argList); +#endif +#ifdef __cplusplus +} +#endif + +#endif + + diff --git a/third_party/securec/src/secureprintoutput_a.c b/third_party/securec/src/secureprintoutput_a.c new file mode 100644 index 00000000..746878a1 --- /dev/null +++ b/third_party/securec/src/secureprintoutput_a.c @@ -0,0 +1,101 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#define SECUREC_INLINE_DO_MEMCPY 1 +#define SECUREC_FORMAT_OUTPUT_INPUT 1 +#ifdef SECUREC_FOR_WCHAR +#undef SECUREC_FOR_WCHAR +#endif + +#include "secureprintoutput.h" + +#define SECUREC_CHAR(x) x +#define SECUREC_WRITE_MULTI_CHAR SecWriteMultiChar +#define SECUREC_WRITE_STRING SecWriteString + +#ifndef EOF +#define EOF (-1) +#endif + +/* put a char to output */ +#define SECUREC_PUTC(c, outStream) ((--(outStream)->count >= 0) ? \ + (int)((unsigned int)(unsigned char)(*((outStream)->cur++) = (char)(c)) & 0xff) : EOF) +/* to clear e835 */ +#define SECUREC_PUTC_ZERO(outStream) ((--(outStream)->count >= 0) ? \ + ((*((outStream)->cur++) = (char)('\0'))) : EOF) + +static void SecWriteMultiChar(char ch, int num, SecPrintfStream *f, int *pnumwritten); +static void SecWriteString(const char *string, int len, SecPrintfStream *f, int *pnumwritten); + +#include "output.inl" + +/* + * Wide character formatted output implementation + */ +int SecVsnprintfImpl(char *string, size_t count, const char *format, va_list argList) +{ + SecPrintfStream str; + int retVal; + + str.count = (int)count; /* this count include \0 character, Must be greater than zero */ + str.cur = string; + + retVal = SecOutputS(&str, format, argList); + if ((retVal >= 0) && (SECUREC_PUTC_ZERO(&str) != EOF)) { + return retVal; + } else if (str.count < 0) { + /* the buffer was too small; we return truncation */ + string[count - 1] = '\0'; + return SECUREC_PRINTF_TRUNCATE; + } + string[0] = '\0'; /* empty the dest strDest */ + return -1; +} + +/* + * Sec write Wide character + */ +static void SecWriteMultiChar(char ch, int num, SecPrintfStream *f, int *pnumwritten) +{ + int count = num; + while (count-- > 0) { + if (SECUREC_PUTC(ch, f) == EOF) { + *pnumwritten = -1; + break; + } else { + *pnumwritten = *pnumwritten + 1; + } + } +} + +/* + * Sec write string function + */ +static void SecWriteString(const char *string, int len, SecPrintfStream *f, int *pnumwritten) +{ + const char *str = string; + int count = len; + while (count-- > 0) { + if (SECUREC_PUTC(*str, f) == EOF) { + *pnumwritten = -1; + break; + } else { + *pnumwritten = *pnumwritten + 1; + ++str; + } + } +} + diff --git a/third_party/securec/src/secureprintoutput_w.c b/third_party/securec/src/secureprintoutput_w.c new file mode 100644 index 00000000..9063ab4d --- /dev/null +++ b/third_party/securec/src/secureprintoutput_w.c @@ -0,0 +1,170 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* if some platforms don't have wchar.h, dont't include it */ +#if !(defined(SECUREC_VXWORKS_PLATFORM)) +/* This header file is placed below secinput.h, which will cause tool alarm, + * but if there is no macro above, it will cause compiling alarm + */ +#if defined(_MSC_VER) && (_MSC_VER >= 1400) +#ifndef _CRTIMP_ALTERNATIVE +#define _CRTIMP_ALTERNATIVE /* comment microsoft *_s function */ +#endif +#ifndef __STDC_WANT_SECURE_LIB__ +#define __STDC_WANT_SECURE_LIB__ 0 +#endif +#endif +#include +#endif + +#define SECUREC_ENABLE_WCHAR_FUNC 0 +#define SECUREC_INLINE_DO_MEMCPY 1 +#define SECUREC_FORMAT_OUTPUT_INPUT 1 +#ifndef SECUREC_FOR_WCHAR +#define SECUREC_FOR_WCHAR +#endif + +#include "secureprintoutput.h" + +#ifndef WEOF +#define WEOF ((wchar_t)(-1)) +#endif + +#define SECUREC_CHAR(x) L ## x +#define SECUREC_WRITE_MULTI_CHAR SecWriteMultiCharW +#define SECUREC_WRITE_STRING SecWriteStringW + +static void SecWriteCharW(wchar_t ch, SecPrintfStream *f, int *pnumwritten); +static void SecWriteMultiCharW(wchar_t ch, int num, SecPrintfStream *f, int *pnumwritten); +static void SecWriteStringW(const wchar_t *string, int len, SecPrintfStream *f, int *pnumwritten); +static int SecPutWcharStrEndingZero(SecPrintfStream *str, int zeroCount); + + +#include "output.inl" + +/* + * Wide character formatted output implementation + */ +int SecVswprintfImpl(wchar_t *string, size_t sizeInWchar, const wchar_t *format, va_list argList) +{ + SecPrintfStream str; + int retVal; /* If initialization causes e838 */ + + str.cur = (char *)string; + /* this count include \0 character, Must be greater than zero */ + str.count = (int)(sizeInWchar * sizeof(wchar_t)); + + retVal = SecOutputSW(&str, format, argList); + if ((retVal >= 0) && SecPutWcharStrEndingZero(&str, (int)sizeof(wchar_t))) { + return (retVal); + } else if (str.count < 0) { + /* the buffer was too small; we return truncation */ + string[sizeInWchar - 1] = L'\0'; + return SECUREC_PRINTF_TRUNCATE; + } + string[0] = L'\0'; + return -1; +} + +/* + * Output one zero character zero into the SecPrintfStream structure + */ +static int SecPutZeroChar(SecPrintfStream *str) +{ + if (str->count > 0) { + *(str->cur) = (char)('\0'); + str->count = str->count - 1; + str->cur = str->cur + 1; + return 0; + } + return -1; +} + +/* + * Output a wide character zero end into the SecPrintfStream structure + */ +static int SecPutWcharStrEndingZero(SecPrintfStream *str, int zeroCount) +{ + int succeed = 0; + int i = 0; + + while (i < zeroCount && (SecPutZeroChar(str) == 0)) { + ++i; + } + if (i == zeroCount) { + succeed = 1; + } + return succeed; +} + + +/* + * Output a wide character into the SecPrintfStream structure + */ +static wchar_t SecPutCharW(wchar_t ch, SecPrintfStream *f) +{ + wchar_t wcRet = 0; + if (((f)->count -= (int)sizeof(wchar_t)) >= 0) { + *(wchar_t *)(void *)(f->cur) = ch; + f->cur += sizeof(wchar_t); + wcRet = ch; + } else { + wcRet = (wchar_t)WEOF; + } + return wcRet; +} + +/* + * Output a wide character into the SecPrintfStream structure, returns the number of characters written + */ +static void SecWriteCharW(wchar_t ch, SecPrintfStream *f, int *pnumwritten) +{ + if (SecPutCharW(ch, f) == (wchar_t)WEOF) { + *pnumwritten = -1; + } else { + *pnumwritten = *pnumwritten + 1; + } +} + +/* + * Output multiple wide character into the SecPrintfStream structure, returns the number of characters written + */ +static void SecWriteMultiCharW(wchar_t ch, int num, SecPrintfStream *f, int *pnumwritten) +{ + int count = num; + while (count-- > 0) { + SecWriteCharW(ch, f, pnumwritten); + if (*pnumwritten == -1) { + break; + } + } +} + +/* + * Output a wide string into the SecPrintfStream structure, returns the number of characters written + */ +static void SecWriteStringW(const wchar_t *string, int len, SecPrintfStream *f, int *pnumwritten) +{ + const wchar_t *str = string; + int count = len; + while (count-- > 0) { + SecWriteCharW(*str++, f, pnumwritten); + if (*pnumwritten == -1) { + break; + } + } +} + diff --git a/third_party/securec/src/snprintf_s.c b/third_party/securec/src/snprintf_s.c new file mode 100644 index 00000000..0bd7ed1b --- /dev/null +++ b/third_party/securec/src/snprintf_s.c @@ -0,0 +1,113 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "securec.h" + +#if SECUREC_ENABLE_SNPRINTF +/* + * + * The snprintf_s function is equivalent to the snprintf function + * except for the parameter destMax/count and the explicit runtime-constraints violation + * The snprintf_s function formats and stores count or fewer characters in + * strDest and appends a terminating null. Each argument (if any) is converted + * and output according to the corresponding format specification in format. + * The formatting is consistent with the printf family of functions; If copying + * occurs between strings that overlap, the behavior is undefined. + * + * + * strDest Storage location for the output. + * destMax The size of the storage location for output. Size + * in bytes for snprintf_s or size in words for snwprintf_s. + * count Maximum number of character to store. + * format Format-control string. + * ... Optional arguments. + * + * + * strDest is updated + * + * + * return the number of characters written, not including the terminating null + * return -1 if an error occurs. + * return -1 if count < destMax and the output string has been truncated + * + * If there is a runtime-constraint violation, strDest[0] will be set to the '\0' when strDest and destMax valid + * + */ +int snprintf_s(char *strDest, size_t destMax, size_t count, const char *format, ...) +{ + int ret; /* If initialization causes e838 */ + va_list argList; + + va_start(argList, format); + ret = vsnprintf_s(strDest, destMax, count, format, argList); + va_end(argList); + (void)argList; /* to clear e438 last value assigned not used , the compiler will optimize this code */ + + return ret; +} +#if SECUREC_IN_KERNEL +EXPORT_SYMBOL(snprintf_s); +#endif +#endif + +#if SECUREC_SNPRINTF_TRUNCATED +/* + * + * The snprintf_truncated_s function is equivalent to the snprintf function + * except for the parameter destMax/count and the explicit runtime-constraints violation + * The snprintf_truncated_s function formats and stores count or fewer characters in + * strDest and appends a terminating null. Each argument (if any) is converted + * and output according to the corresponding format specification in format. + * The formatting is consistent with the printf family of functions; If copying + * occurs between strings that overlap, the behavior is undefined. + * + * + * strDest Storage location for the output. + * destMax The size of the storage location for output. Size + * in bytes for snprintf_truncated_s or size in words for snwprintf_s. + * format Format-control string. + * ... Optional arguments. + * + * + * strDest is updated + * + * + * return the number of characters written, not including the terminating null + * return -1 if an error occurs. + * return destMax-1 if output string has been truncated + * + * If there is a runtime-constraint violation, strDest[0] will be set to the '\0' when strDest and destMax valid + * + */ +int snprintf_truncated_s(char *strDest, size_t destMax, const char *format, ...) +{ + int ret; /* If initialization causes e838 */ + va_list argList; + + va_start(argList, format); + ret = vsnprintf_truncated_s(strDest, destMax, format, argList); + va_end(argList); + (void)argList; /* to clear e438 last value assigned not used , the compiler will optimize this code */ + + return ret; +} +#if SECUREC_IN_KERNEL +EXPORT_SYMBOL(snprintf_truncated_s); +#endif + +#endif + + diff --git a/third_party/securec/src/sprintf_s.c b/third_party/securec/src/sprintf_s.c new file mode 100644 index 00000000..54a79604 --- /dev/null +++ b/third_party/securec/src/sprintf_s.c @@ -0,0 +1,61 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "securec.h" + +/* + * + * The sprintf_s function is equivalent to the sprintf function + * except for the parameter destMax and the explicit runtime-constraints violation + * The sprintf_s function formats and stores a series of characters and values + * in strDest. Each argument (if any) is converted and output according to + * the corresponding format specification in format. The format consists of + * ordinary characters and has the same form and function as the format argument + * for printf. A null character is appended after the last character written. + * If copying occurs between strings that overlap, the behavior is undefined. + * + * + * strDest Storage location for output. + * destMax Maximum number of characters to store. + * format Format-control string. + * ... Optional arguments + * + * + * strDest is updated + * + * + * return the number of bytes stored in strDest, not counting the terminating null character. + * return -1 if an error occurred. + * + * If there is a runtime-constraint violation, strDest[0] will be set to the '\0' when strDest and destMax valid + */ +int sprintf_s(char *strDest, size_t destMax, const char *format, ...) +{ + int ret; /* If initialization causes e838 */ + va_list argList; + + va_start(argList, format); + ret = vsprintf_s(strDest, destMax, format, argList); + va_end(argList); + (void)argList; /* to clear e438 last value assigned not used , the compiler will optimize this code */ + + return ret; +} +#if SECUREC_IN_KERNEL +EXPORT_SYMBOL(sprintf_s); +#endif + + diff --git a/third_party/securec/src/sscanf_s.c b/third_party/securec/src/sscanf_s.c new file mode 100644 index 00000000..c8f097ef --- /dev/null +++ b/third_party/securec/src/sscanf_s.c @@ -0,0 +1,61 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "securec.h" + +/* + * + * The sscanf_s function is equivalent to fscanf_s, + * except that input is obtained from a string (specified by the argument buffer) rather than from a stream + * The sscanf function reads data from buffer into the location given by each + * argument. Every argument must be a pointer to a variable with a type that + * corresponds to a type specifier in format. The format argument controls the + * interpretation of the input fields and has the same form and function as + * the format argument for the scanf function. + * If copying takes place between strings that overlap, the behavior is undefined. + * + * + * buffer Stored data. + * format Format control string, see Format Specifications. + * ... Optional arguments. + * + * + * ... The converted value stored in user assigned address + * + * + * Each of these functions returns the number of fields successfully converted + * and assigned; the return value does not include fields that were read but + * not assigned. + * A return value of 0 indicates that no fields were assigned. + * return -1 if an error occurs. + */ +int sscanf_s(const char *buffer, const char *format, ...) +{ + int ret; /* If initialization causes e838 */ + va_list argList; + + va_start(argList, format); + ret = vsscanf_s(buffer, format, argList); + va_end(argList); + (void)argList; /* to clear e438 last value assigned not used , the compiler will optimize this code */ + + return ret; +} +#if SECUREC_IN_KERNEL +EXPORT_SYMBOL(sscanf_s); +#endif + + diff --git a/third_party/securec/src/strcat_s.c b/third_party/securec/src/strcat_s.c new file mode 100644 index 00000000..6bf1379b --- /dev/null +++ b/third_party/securec/src/strcat_s.c @@ -0,0 +1,102 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#define SECUREC_INLINE_STR_LEN 1 +#define SECUREC_INLINE_STR_LEN_OPT 1 +#define SECUREC_INLINE_DO_MEMCPY 1 +#include "securecutil.h" + +/* + * Befor this function, the basic parameter checking has been done + */ +static errno_t SecDoStrcat(char *strDest, size_t destMax, const char *strSrc) +{ + size_t destLen = SecStrMinLen(strDest, destMax); + /* Only optimize strSrc, do not apply this function to strDest */ + size_t srcLen = SecStrMinLenOpt(strSrc, destMax - destLen); + + if (SECUREC_CAT_STRING_IS_OVERLAP(strDest, destLen, strSrc, srcLen)) { + strDest[0] = '\0'; + if (strDest + destLen <= strSrc && destLen == destMax) { + SECUREC_ERROR_INVALID_PARAMTER("strcat_s"); + return EINVAL_AND_RESET; + } + SECUREC_ERROR_BUFFER_OVERLAP("strcat_s"); + return EOVERLAP_AND_RESET; + } + if (srcLen + destLen >= destMax || strDest == strSrc) { + strDest[0] = '\0'; + if (destLen == destMax) { + SECUREC_ERROR_INVALID_PARAMTER("strcat_s"); + return EINVAL_AND_RESET; + } + SECUREC_ERROR_INVALID_RANGE("strcat_s"); + return ERANGE_AND_RESET; + } + SecDoMemcpy(strDest + destLen, strSrc, srcLen + 1); /* single character length include \0 */ + return EOK; +} + +/* + * + * The strcat_s function appends a copy of the string pointed to by strSrc (including the terminating null character) + * to the end of the string pointed to by strDest. + * The initial character of strSrc overwrites the terminating null character of strDest. + * strcat_s will return EOVERLAP_AND_RESET if the source and destination strings overlap. + * + * Note that the second parameter is the total size of the buffer, not the + * remaining size. + * + * + * strDest Null-terminated destination string buffer. + * destMax Size of the destination string buffer. + * strSrc Null-terminated source string buffer. + * + * + * strDest is updated + * + * + * EOK Success + * EINVAL strDest is NULL and destMax != 0 and destMax <= SECUREC_STRING_MAX_LEN + * EINVAL_AND_RESET (strDest unterminated and all other parameters are valid)or + * (strDest != NULL and strSrc is NULL and destMax != 0 and destMax <= SECUREC_STRING_MAX_LEN) + * ERANGE destMax is 0 and destMax > SECUREC_STRING_MAX_LEN + * ERANGE_AND_RESET strDest have not enough space and all other parameters are valid and not overlap + * EOVERLAP_AND_RESET dest buffer and source buffer are overlapped and all parameters are valid + * + * If there is a runtime-constraint violation, strDest[0] will be set to the '\0' when strDest and destMax valid + */ +errno_t strcat_s(char *strDest, size_t destMax, const char *strSrc) +{ + if (destMax == 0 || destMax > SECUREC_STRING_MAX_LEN) { + SECUREC_ERROR_INVALID_RANGE("strcat_s"); + return ERANGE; + } + if (strDest == NULL || strSrc == NULL) { + SECUREC_ERROR_INVALID_PARAMTER("strcat_s"); + if (strDest != NULL) { + strDest[0] = '\0'; + return EINVAL_AND_RESET; + } + return EINVAL; + } + return SecDoStrcat(strDest, destMax, strSrc); +} + +#if SECUREC_IN_KERNEL +EXPORT_SYMBOL(strcat_s); +#endif + diff --git a/third_party/securec/src/strcpy_s.c b/third_party/securec/src/strcpy_s.c new file mode 100644 index 00000000..e248da7c --- /dev/null +++ b/third_party/securec/src/strcpy_s.c @@ -0,0 +1,351 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#define SECUREC_INLINE_STR_LEN 1 +#define SECUREC_INLINE_DO_MEMCPY 1 + +#include "securecutil.h" + +#if SECUREC_IN_KERNEL== 0 +#ifndef SECUREC_STRCOPY_THRESHOLD_SIZE +#define SECUREC_STRCOPY_THRESHOLD_SIZE 32UL +#endif + +/* + * Determine whether the address is 8-byte aligned, use static to increase performance + * return 0 is aligned + */ +static int SecIsAddrAligned8(const void *addr, const void *zeroAddr) +{ + return (int)(((size_t)((const char*)addr - (const char*)zeroAddr)) & 7); /* use 7 to check aligned 8 */ +} + +/* The purpose of converting to void is to clean up the alarm */ +#define SECUREC_SMALL_STR_COPY do { \ + if (SECUREC_ADDR_ALIGNED_8(strDest) && SECUREC_ADDR_ALIGNED_8(strSrc)) { \ + /* use struct assignment */ \ + switch (srcStrLen) { \ + case 1: \ + *(SecStrBuf1 *)(void *)strDest = *(const SecStrBuf1 *)(const void *)strSrc; \ + break; \ + case 2: \ + *(SecStrBuf2 *)(void *)strDest = *(const SecStrBuf2 *)(const void *)strSrc; \ + break; \ + case 3: \ + *(SecStrBuf3 *)(void *)strDest = *(const SecStrBuf3 *)(const void *)strSrc; \ + break; \ + case 4: \ + *(SecStrBuf4 *)(void *)strDest = *(const SecStrBuf4 *)(const void *)strSrc; \ + break; \ + case 5: \ + *(SecStrBuf5 *)(void *)strDest = *(const SecStrBuf5 *)(const void *)strSrc; \ + break; \ + case 6: \ + *(SecStrBuf6 *)(void *)strDest = *(const SecStrBuf6 *)(const void *)strSrc; \ + break; \ + case 7: \ + *(SecStrBuf7 *)(void *)strDest = *(const SecStrBuf7 *)(const void *)strSrc; \ + break; \ + case 8: \ + *(SecStrBuf8 *)(void *)strDest = *(const SecStrBuf8 *)(const void *)strSrc; \ + break; \ + case 9: \ + *(SecStrBuf9 *)(void *)strDest = *(const SecStrBuf9 *)(const void *)strSrc; \ + break; \ + case 10: \ + *(SecStrBuf10 *)(void *)strDest = *(const SecStrBuf10 *)(const void *)strSrc; \ + break; \ + case 11: \ + *(SecStrBuf11 *)(void *)strDest = *(const SecStrBuf11 *)(const void *)strSrc; \ + break; \ + case 12: \ + *(SecStrBuf12 *)(void *)strDest = *(const SecStrBuf12 *)(const void *)strSrc; \ + break; \ + case 13: \ + *(SecStrBuf13 *)(void *)strDest = *(const SecStrBuf13 *)(const void *)strSrc; \ + break; \ + case 14: \ + *(SecStrBuf14 *)(void *)strDest = *(const SecStrBuf14 *)(const void *)strSrc; \ + break; \ + case 15: \ + *(SecStrBuf15 *)(void *)strDest = *(const SecStrBuf15 *)(const void *)strSrc; \ + break; \ + case 16: \ + *(SecStrBuf16 *)(void *)strDest = *(const SecStrBuf16 *)(const void *)strSrc; \ + break; \ + case 17: \ + *(SecStrBuf17 *)(void *)strDest = *(const SecStrBuf17 *)(const void *)strSrc; \ + break; \ + case 18: \ + *(SecStrBuf18 *)(void *)strDest = *(const SecStrBuf18 *)(const void *)strSrc; \ + break; \ + case 19: \ + *(SecStrBuf19 *)(void *)strDest = *(const SecStrBuf19 *)(const void *)strSrc; \ + break; \ + case 20: \ + *(SecStrBuf20 *)(void *)strDest = *(const SecStrBuf20 *)(const void *)strSrc; \ + break; \ + case 21: \ + *(SecStrBuf21 *)(void *)strDest = *(const SecStrBuf21 *)(const void *)strSrc; \ + break; \ + case 22: \ + *(SecStrBuf22 *)(void *)strDest = *(const SecStrBuf22 *)(const void *)strSrc; \ + break; \ + case 23: \ + *(SecStrBuf23 *)(void *)strDest = *(const SecStrBuf23 *)(const void *)strSrc; \ + break; \ + case 24: \ + *(SecStrBuf24 *)(void *)strDest = *(const SecStrBuf24 *)(const void *)strSrc; \ + break; \ + case 25: \ + *(SecStrBuf25 *)(void *)strDest = *(const SecStrBuf25 *)(const void *)strSrc; \ + break; \ + case 26: \ + *(SecStrBuf26 *)(void *)strDest = *(const SecStrBuf26 *)(const void *)strSrc; \ + break; \ + case 27: \ + *(SecStrBuf27 *)(void *)strDest = *(const SecStrBuf27 *)(const void *)strSrc; \ + break; \ + case 28: \ + *(SecStrBuf28 *)(void *)strDest = *(const SecStrBuf28 *)(const void *)strSrc; \ + break; \ + case 29: \ + *(SecStrBuf29 *)(void *)strDest = *(const SecStrBuf29 *)(const void *)strSrc; \ + break; \ + case 30: \ + *(SecStrBuf30 *)(void *)strDest = *(const SecStrBuf30 *)(const void *)strSrc; \ + break; \ + case 31: \ + *(SecStrBuf31 *)(void *)strDest = *(const SecStrBuf31 *)(const void *)strSrc; \ + break; \ + case 32: \ + *(SecStrBuf32 *)(void *)strDest = *(const SecStrBuf32 *)(const void *)strSrc; \ + break; \ + default: \ + break; \ + } /* END switch */ \ + } else { \ + char *tmpStrDest = (char *)strDest; \ + const char *tmpStrSrc = (const char *)strSrc; \ + switch (srcStrLen) { \ + case 32: \ + *(tmpStrDest++) = *(tmpStrSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 31: \ + *(tmpStrDest++) = *(tmpStrSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 30: \ + *(tmpStrDest++) = *(tmpStrSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 29: \ + *(tmpStrDest++) = *(tmpStrSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 28: \ + *(tmpStrDest++) = *(tmpStrSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 27: \ + *(tmpStrDest++) = *(tmpStrSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 26: \ + *(tmpStrDest++) = *(tmpStrSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 25: \ + *(tmpStrDest++) = *(tmpStrSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 24: \ + *(tmpStrDest++) = *(tmpStrSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 23: \ + *(tmpStrDest++) = *(tmpStrSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 22: \ + *(tmpStrDest++) = *(tmpStrSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 21: \ + *(tmpStrDest++) = *(tmpStrSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 20: \ + *(tmpStrDest++) = *(tmpStrSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 19: \ + *(tmpStrDest++) = *(tmpStrSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 18: \ + *(tmpStrDest++) = *(tmpStrSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 17: \ + *(tmpStrDest++) = *(tmpStrSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 16: \ + *(tmpStrDest++) = *(tmpStrSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 15: \ + *(tmpStrDest++) = *(tmpStrSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 14: \ + *(tmpStrDest++) = *(tmpStrSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 13: \ + *(tmpStrDest++) = *(tmpStrSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 12: \ + *(tmpStrDest++) = *(tmpStrSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 11: \ + *(tmpStrDest++) = *(tmpStrSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 10: \ + *(tmpStrDest++) = *(tmpStrSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 9: \ + *(tmpStrDest++) = *(tmpStrSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 8: \ + *(tmpStrDest++) = *(tmpStrSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 7: \ + *(tmpStrDest++) = *(tmpStrSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 6: \ + *(tmpStrDest++) = *(tmpStrSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 5: \ + *(tmpStrDest++) = *(tmpStrSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 4: \ + *(tmpStrDest++) = *(tmpStrSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 3: \ + *(tmpStrDest++) = *(tmpStrSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 2: \ + *(tmpStrDest++) = *(tmpStrSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + case 1: \ + *(tmpStrDest++) = *(tmpStrSrc++); \ + /* fall-through */ /* FALLTHRU */ \ + default: \ + break; \ + } \ + } \ +} SECUREC_WHILE_ZERO +#endif + +/* + * Check Src Range + */ +static errno_t CheckSrcRange(char *strDest, size_t destMax, const char *strSrc) +{ + size_t tmpDestMax = destMax; + const char *tmpSrc = strSrc; + /* use destMax as boundary checker and destMax must be greater than zero */ + while (*(tmpSrc) != '\0' && tmpDestMax > 0) { + ++tmpSrc; + --tmpDestMax; + } + if (tmpDestMax == 0) { + strDest[0] = '\0'; + SECUREC_ERROR_INVALID_RANGE("strcpy_s"); + return ERANGE_AND_RESET; + } + return EOK; +} + +/* + * Handling errors + */ +errno_t strcpy_error(char *strDest, size_t destMax, const char *strSrc) +{ + if (destMax == 0 || destMax > SECUREC_STRING_MAX_LEN) { + SECUREC_ERROR_INVALID_RANGE("strcpy_s"); + return ERANGE; + } else if (strDest == NULL || strSrc == NULL) { + SECUREC_ERROR_INVALID_PARAMTER("strcpy_s"); + if (strDest != NULL) { + strDest[0] = '\0'; + return EINVAL_AND_RESET; + } + return EINVAL; + } + return CheckSrcRange(strDest, destMax, strSrc); +} + +/* + * Performance optimization. srcStrLen include '\0' + */ +static void SecDoStrcpyOpt(char *strDest, const char *strSrc, size_t srcStrLen) +{ +#if SECUREC_IN_KERNEL + SecDoMemcpy(strDest, strSrc, srcStrLen); +#else + if (srcStrLen > SECUREC_STRCOPY_THRESHOLD_SIZE) { + SecDoMemcpy(strDest, strSrc, srcStrLen); + } else { + SECUREC_SMALL_STR_COPY; + } +#endif +} + +/* + * + * The strcpy_s function copies the string pointed to strSrc + * (including the terminating null character) into the array pointed to by strDest + * The destination string must be large enough to hold the source string, + * including the terminating null character. strcpy_s will return EOVERLAP_AND_RESET + * if the source and destination strings overlap. + * + * + * strDest Location of destination string buffer + * destMax Size of the destination string buffer. + * strSrc Null-terminated source string buffer. + * + * + * strDest is updated. + * + * + * EOK Success + * EINVAL strDest is NULL and destMax != 0 and destMax <= SECUREC_STRING_MAX_LEN + * EINVAL_AND_RESET strDest != NULL and strSrc is NULL and destMax != 0 and destMax <= SECUREC_STRING_MAX_LEN + * ERANGE destMax is 0 and destMax > SECUREC_STRING_MAX_LEN + * ERANGE_AND_RESET strDest have not enough space and all other parameters are valid and not overlap + * EOVERLAP_AND_RESET dest buffer and source buffer are overlapped and all parameters are valid + * + * If there is a runtime-constraint violation, strDest[0] will be set to the '\0' when strDest and destMax valid + */ +errno_t strcpy_s(char *strDest, size_t destMax, const char *strSrc) +{ + if ((destMax > 0 && destMax <= SECUREC_STRING_MAX_LEN && strDest != NULL && strSrc != NULL && strDest != strSrc)) { + size_t srcStrLen = SecStrMinLen(strSrc, destMax) + 1; /* len include \0 */ + if (srcStrLen <= destMax) { + /* use mem overlap check include \0 */ + if (SECUREC_MEMORY_NO_OVERLAP(strDest, strSrc, srcStrLen)) { + /* performance optimization srcStrLen include '\0' */ + SecDoStrcpyOpt(strDest, strSrc, srcStrLen); + return EOK; + } else { + strDest[0] = '\0'; + SECUREC_ERROR_BUFFER_OVERLAP("strcpy_s"); + return EOVERLAP_AND_RESET; + } + } + } + return strcpy_error(strDest, destMax, strSrc); +} + +#if SECUREC_IN_KERNEL +EXPORT_SYMBOL(strcpy_s); +#endif + diff --git a/third_party/securec/src/strncat_s.c b/third_party/securec/src/strncat_s.c new file mode 100644 index 00000000..78234fd5 --- /dev/null +++ b/third_party/securec/src/strncat_s.c @@ -0,0 +1,121 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#define SECUREC_INLINE_STR_LEN 1 +#define SECUREC_INLINE_DO_MEMCPY 1 + +#include "securecutil.h" + +/* + * Befor this function, the basic parameter checking has been done + */ +static errno_t SecDoStrncat(char *strDest, size_t destMax, const char *strSrc, size_t count) +{ + size_t destLen = SecStrMinLen(strDest, destMax); + /* The strSrc is no longer optimized. The reason is that when count is small, + * the efficiency of strnlen is higher than that of self realization. + */ + size_t srcLen = SecStrMinLen(strSrc, count); + + if (SECUREC_CAT_STRING_IS_OVERLAP(strDest, destLen, strSrc, srcLen)) { + strDest[0] = '\0'; + if (strDest + destLen <= strSrc && destLen == destMax) { + SECUREC_ERROR_INVALID_PARAMTER("strncat_s"); + return EINVAL_AND_RESET; + } + SECUREC_ERROR_BUFFER_OVERLAP("strncat_s"); + return EOVERLAP_AND_RESET; + } + if (srcLen + destLen >= destMax || strDest == strSrc) { + strDest[0] = '\0'; + if (destLen == destMax) { + SECUREC_ERROR_INVALID_PARAMTER("strncat_s"); + return EINVAL_AND_RESET; + } + SECUREC_ERROR_INVALID_RANGE("strncat_s"); + return ERANGE_AND_RESET; + } + SecDoMemcpy(strDest + destLen, strSrc, srcLen); /* no terminator */ + *(strDest + destLen + srcLen) = '\0'; + return EOK; +} + +/* + * + * The strncat_s function appends not more than n successive characters + * (not including the terminating null character) + * from the array pointed to by strSrc to the end of the string pointed to by strDest + * The strncat_s function try to append the first D characters of strSrc to + * the end of strDest, where D is the lesser of count and the length of strSrc. + * If appending those D characters will fit within strDest (whose size is given + * as destMax) and still leave room for a null terminator, then those characters + * are appended, starting at the original terminating null of strDest, and a + * new terminating null is appended; otherwise, strDest[0] is set to the null + * character. + * + * + * strDest Null-terminated destination string. + * destMax Size of the destination buffer. + * strSrc Null-terminated source string. + * count Number of character to append, or truncate. + * + * + * strDest is updated + * + * + * EOK Success + * EINVAL strDest is NULL and destMax != 0 and destMax <= SECUREC_STRING_MAX_LEN + * EINVAL_AND_RESET (strDest unterminated and all other parameters are valid)or + * (strDest != NULL and strSrc is NULL and destMax != 0 and destMax <= SECUREC_STRING_MAX_LEN) + * ERANGE destMax is 0 and destMax > SECUREC_STRING_MAX_LEN + * ERANGE_AND_RESET strDest have not enough space and all other parameters are valid and not overlap + * EOVERLAP_AND_RESET dest buffer and source buffer are overlapped and all parameters are valid + * + * If there is a runtime-constraint violation, strDest[0] will be set to the '\0' when strDest and destMax valid + */ +errno_t strncat_s(char *strDest, size_t destMax, const char *strSrc, size_t count) +{ + if (destMax == 0 || destMax > SECUREC_STRING_MAX_LEN) { + SECUREC_ERROR_INVALID_RANGE("strncat_s"); + return ERANGE; + } + + if (strDest == NULL || strSrc == NULL) { + SECUREC_ERROR_INVALID_PARAMTER("strncat_s"); + if (strDest != NULL) { + strDest[0] = '\0'; + return EINVAL_AND_RESET; + } + return EINVAL; + } + if (count > SECUREC_STRING_MAX_LEN) { +#ifdef SECUREC_COMPATIBLE_WIN_FORMAT + if (count == (size_t)(-1)) { + /* Windows internal functions may pass in -1 when calling this function */ + return SecDoStrncat(strDest, destMax, strSrc, destMax); + } +#endif + strDest[0] = '\0'; + SECUREC_ERROR_INVALID_RANGE("strncat_s"); + return ERANGE_AND_RESET; + } + return SecDoStrncat(strDest, destMax, strSrc, count); +} + +#if SECUREC_IN_KERNEL +EXPORT_SYMBOL(strncat_s); +#endif + diff --git a/third_party/securec/src/strncpy_s.c b/third_party/securec/src/strncpy_s.c new file mode 100644 index 00000000..493d1f74 --- /dev/null +++ b/third_party/securec/src/strncpy_s.c @@ -0,0 +1,143 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#define SECUREC_INLINE_STR_LEN 1 +#define SECUREC_INLINE_DO_MEMCPY 1 + +#include "securecutil.h" + +#if defined(SECUREC_COMPATIBLE_WIN_FORMAT) +#define SECUREC_STRNCPY_PARAM_OK(strDest, destMax, strSrc, count) \ + (((destMax) > 0 && (destMax) <= SECUREC_STRING_MAX_LEN && (strDest) != NULL && (strSrc) != NULL && \ + ((count) <= SECUREC_STRING_MAX_LEN || (count) == ((size_t)(-1))) && (count) > 0)) +#else +#define SECUREC_STRNCPY_PARAM_OK(strDest, destMax, strSrc, count) \ + (((destMax) > 0 && (destMax) <= SECUREC_STRING_MAX_LEN && (strDest) != NULL && (strSrc) != NULL && \ + (count) <= SECUREC_STRING_MAX_LEN && (count) > 0)) +#endif + +/* + * Check Src Count Range + */ +static errno_t CheckSrcCountRange(char *strDest, size_t destMax, const char *strSrc, size_t count) +{ + size_t tmpDestMax = destMax; + size_t tmpCount = count; + const char *endPos = strSrc; + + /* use destMax and count as boundary checker and destMax must be greater than zero */ + while (*(endPos) != '\0' && tmpDestMax > 0 && tmpCount > 0) { + ++endPos; + --tmpCount; + --tmpDestMax; + } + if (tmpDestMax == 0) { + strDest[0] = '\0'; + SECUREC_ERROR_INVALID_RANGE("strncpy_s"); + return ERANGE_AND_RESET; + } + return EOK; +} + +/* + * Handling errors, when dest euqal src return EOK + */ +errno_t strncpy_error(char *strDest, size_t destMax, const char *strSrc, size_t count) +{ + if (destMax == 0 || destMax > SECUREC_STRING_MAX_LEN) { + SECUREC_ERROR_INVALID_RANGE("strncpy_s"); + return ERANGE; + } else if (strDest == NULL || strSrc == NULL) { + SECUREC_ERROR_INVALID_PARAMTER("strncpy_s"); + if (strDest != NULL) { + strDest[0] = '\0'; + return EINVAL_AND_RESET; + } + return EINVAL; + } else if (count > SECUREC_STRING_MAX_LEN) { + strDest[0] = '\0'; /* clear dest string */ + SECUREC_ERROR_INVALID_RANGE("strncpy_s"); + return ERANGE_AND_RESET; + } else if (count == 0) { + strDest[0] = '\0'; + return EOK; + } + + return CheckSrcCountRange(strDest, destMax, strSrc, count); +} + +/* + * + * The strncpy_s function copies not more than n successive characters (not including the terminating null character) + * from the array pointed to by strSrc to the array pointed to by strDest. + * + * + * strDest Destination string. + * destMax The size of the destination string, in characters. + * strSrc Source string. + * count Number of characters to be copied. + * + * + * strDest is updated + * + * + * EOK Success + * EINVAL strDest is NULL and destMax != 0 and destMax <= SECUREC_STRING_MAX_LEN + * EINVAL_AND_RESET strDest != NULL and strSrc is NULL and destMax != 0 and destMax <= SECUREC_STRING_MAX_LEN + * ERANGE destMax is 0 and destMax > SECUREC_STRING_MAX_LEN + * ERANGE_AND_RESET strDest have not enough space and all other parameters are valid and not overlap + * EOVERLAP_AND_RESET dest buffer and source buffer are overlapped and all parameters are valid + * + * If there is a runtime-constraint violation, strDest[0] will be set to the '\0' when strDest and destMax valid + */ +errno_t strncpy_s(char *strDest, size_t destMax, const char *strSrc, size_t count) +{ + if (SECUREC_STRNCPY_PARAM_OK(strDest, destMax, strSrc, count)) { + size_t minCpLen; /* use it to store the maxi length limit */ + if (count < destMax) { + minCpLen = SecStrMinLen(strSrc, count); /* no ending terminator */ + } else { + size_t tmpCount = destMax; +#ifdef SECUREC_COMPATIBLE_WIN_FORMAT + if (count == ((size_t)(-1))) { + tmpCount = destMax - 1; + } +#endif + minCpLen = SecStrMinLen(strSrc, tmpCount); + if (minCpLen == destMax) { + strDest[0] = '\0'; + SECUREC_ERROR_INVALID_RANGE("strncpy_s"); + return ERANGE_AND_RESET; + } + } + if (SECUREC_STRING_NO_OVERLAP(strDest, strSrc, minCpLen) || strDest == strSrc) { + /* Not overlap */ + SecDoMemcpy(strDest, strSrc, minCpLen); /* copy string without terminator */ + strDest[minCpLen] = '\0'; + return EOK; + } else { + strDest[0] = '\0'; + SECUREC_ERROR_BUFFER_OVERLAP("strncpy_s"); + return EOVERLAP_AND_RESET; + } + } + return strncpy_error(strDest, destMax, strSrc, count); +} + +#if SECUREC_IN_KERNEL +EXPORT_SYMBOL(strncpy_s); +#endif + diff --git a/third_party/securec/src/strtok_s.c b/third_party/securec/src/strtok_s.c new file mode 100644 index 00000000..18f977a7 --- /dev/null +++ b/third_party/securec/src/strtok_s.c @@ -0,0 +1,117 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "securec.h" + +/* + * Find beginning of token (skip over leading delimiters).Note that + * there is no token if this loop sets string to point to the terminal null. + */ +static char *SecFindBegin(char *strToken, const char *strDelimit) +{ + char *token = strToken; + while (*token != '\0') { + const char *ctl = strDelimit; + while (*ctl != '\0' && *ctl != *token) { + ++ctl; + } + if (*ctl == '\0') { /* don't find any delimiter in string header, break the loop */ + break; + } + ++token; + } + return token; +} + +/* + * Find rest of token + */ +static char *SecFindRest(char *strToken, const char *strDelimit) +{ + /* Find the rest of the token. If it is not the end of the string, + * put a null there. + */ + char *token = strToken; + while (*token != '\0') { + const char *ctl = strDelimit; + while (*ctl != '\0' && *ctl != *token) { + ++ctl; + } + if (*ctl != '\0') { /* find a delimiter */ + *token++ = '\0'; /* set string termintor */ + break; + } + ++token; + } + return token; +} + +/* + * Find the final position pointer + */ +static char *SecUpdateToken(char *strToken, const char *strDelimit, char **context) +{ + /* point to updated position */ + char *token = SecFindRest(strToken, strDelimit); + /* record string position for next search in the context */ + *context = token; + /* Determine if a token has been found. */ + if (token == strToken) { + return NULL; + } + return strToken; +} + +/* + * + * The strtok_s function parses a string into a sequence of strToken, + * replace all characters in strToken string that match to strDelimit set with 0. + * On the first call to strtok_s the string to be parsed should be specified in strToken. + * In each subsequent call that should parse the same string, strToken should be NULL + * + * strToken String containing token or tokens. + * strDelimit Set of delimiter characters. + * context Used to store position information between calls + * to strtok_s + * + * context is updated + * + * On the first call returns the address of the first non \0 character, otherwise NULL is returned. + * In subsequent calls, the strtoken is set to NULL, and the context set is the same as the previous call, + * return NULL if the *context string length is equal 0, otherwise return *context. + */ +char *strtok_s(char *strToken, const char *strDelimit, char **context) +{ + char *orgToken = strToken; + /* validate delimiter and string context */ + if (context == NULL || strDelimit == NULL) { + return NULL; + } + /* valid input string and string pointer from where to search */ + if (orgToken == NULL && (*context) == NULL) { + return NULL; + } + /* If string is null, continue searching from previous string position stored in context */ + if (orgToken == NULL) { + orgToken = *context; + } + orgToken = SecFindBegin(orgToken, strDelimit); + return SecUpdateToken(orgToken, strDelimit, context); +} +#if SECUREC_IN_KERNEL +EXPORT_SYMBOL(strtok_s); +#endif + diff --git a/third_party/securec/src/swprintf_s.c b/third_party/securec/src/swprintf_s.c new file mode 100644 index 00000000..1fb0f6c7 --- /dev/null +++ b/third_party/securec/src/swprintf_s.c @@ -0,0 +1,51 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "securec.h" + +/* + * + * The swprintf_s function is the wide-character equivalent of the sprintf_s function + * + * + * strDest Storage location for the output. + * destMax Maximum number of characters to store. + * format Format-control string. + * ... Optional arguments + * + * + * strDest is updated + * + * + * return the number of wide characters stored in strDest, not counting the terminating null wide character. + * return -1 if an error occurred. + * + * If there is a runtime-constraint violation, strDest[0] will be set to the '\0' when strDest and destMax valid + */ +int swprintf_s(wchar_t *strDest, size_t destMax, const wchar_t *format, ...) +{ + int ret; /* If initialization causes e838 */ + va_list argList; + + va_start(argList, format); + ret = vswprintf_s(strDest, destMax, format, argList); + va_end(argList); + (void)argList; /* to clear e438 last value assigned not used , the compiler will optimize this code */ + + return ret; +} + + diff --git a/third_party/securec/src/swscanf_s.c b/third_party/securec/src/swscanf_s.c new file mode 100644 index 00000000..c16045fa --- /dev/null +++ b/third_party/securec/src/swscanf_s.c @@ -0,0 +1,57 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "securec.h" + +/* + * + * The swscanf_s function is the wide-character equivalent of the sscanf_s function + * The swscanf_s function reads data from buffer into the location given by + * each argument. Every argument must be a pointer to a variable with a type + * that corresponds to a type specifier in format. The format argument controls + * the interpretation of the input fields and has the same form and function + * as the format argument for the scanf function. If copying takes place between + * strings that overlap, the behavior is undefined. + * + * + * buffer Stored data. + * format Format control string, see Format Specifications. + * ... Optional arguments. + * + * + * ... the converted value stored in user assigned address + * + * + * Each of these functions returns the number of fields successfully converted + * and assigned; The return value does not include fields that were read but not + * assigned. + * A return value of 0 indicates that no fields were assigned. + * return -1 if an error occurs. + */ +int swscanf_s(const wchar_t *buffer, const wchar_t *format, ...) +{ + int ret; /* If initialization causes e838 */ + va_list argList; + + va_start(argList, format); + ret = vswscanf_s(buffer, format, argList); + va_end(argList); + (void)argList; /* to clear e438 last value assigned not used , the compiler will optimize this code */ + + return ret; +} + + diff --git a/third_party/securec/src/vfscanf_s.c b/third_party/securec/src/vfscanf_s.c new file mode 100644 index 00000000..78444e4b --- /dev/null +++ b/third_party/securec/src/vfscanf_s.c @@ -0,0 +1,67 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "secinput.h" + +/* + * + * The vfscanf_s function is equivalent to fscanf_s, with the variable argument list replaced by argList + * The vfscanf_s function reads data from the current position of stream into + * the locations given by argument (if any). Each argument must be a pointer + * to a variable of a type that corresponds to a type specifier in format. + * format controls the interpretation of the input fields and has the same + * form and function as the format argument for scanf. + * + * + * stream Pointer to FILE structure. + * format Format control string, see Format Specifications. + * argList pointer to list of arguments + * + * + * argList the converted value stored in user assigned address + * + * + * Each of these functions returns the number of fields successfully converted + * and assigned; the return value does not include fields that were read but + * not assigned. A return value of 0 indicates that no fields were assigned. + * return -1 if an error occurs. + */ +int vfscanf_s(FILE *stream, const char *format, va_list argList) +{ + int retVal; /* If initialization causes e838 */ + SecFileStream fStr; + + if ((stream == NULL) || (format == NULL)) { + SECUREC_ERROR_INVALID_PARAMTER("vfscanf_s"); + return SECUREC_SCANF_EINVAL; + } + if (stream == stdin) { + return vscanf_s(format, argList); + } + + SECUREC_LOCK_FILE(stream); + SECUREC_INIT_SEC_FILE_STREAM(fStr, SECUREC_FILE_STREAM_FLAG, stream, SECUREC_UNINITIALIZED_FILE_POS, NULL, 0); + retVal = SecInputS(&fStr, format, argList); + SECUREC_UNLOCK_FILE(stream); + if (retVal < 0) { + SECUREC_ERROR_INVALID_PARAMTER("vfscanf_s"); + return SECUREC_SCANF_EINVAL; + } + + return retVal; +} + + diff --git a/third_party/securec/src/vfwscanf_s.c b/third_party/securec/src/vfwscanf_s.c new file mode 100644 index 00000000..3ae62eea --- /dev/null +++ b/third_party/securec/src/vfwscanf_s.c @@ -0,0 +1,66 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "secinput.h" + +/* + * + * The vfwscanf_s function is the wide-character equivalent of the vfscanf_s function + * The vfwscanf_s function reads data from the current position of stream into + * the locations given by argument (if any). Each argument must be a pointer + * to a variable of a type that corresponds to a type specifier in format. + * format controls the interpretation of the input fields and has the same form + * and function as the format argument for scanf. + * + * + * stream Pointer to FILE structure. + * format Format control string, see Format Specifications. + * argList pointer to list of arguments + * + * + * argList the converted value stored in user assigned address + * + * + * Each of these functions returns the number of fields successfully converted + * and assigned; the return value does not include fields that were read but + * not assigned. A return value of 0 indicates that no fields were assigned. + * return -1 if an error occurs. + */ +int vfwscanf_s(FILE *stream, const wchar_t *format, va_list argList) +{ + int retVal; /* If initialization causes e838 */ + SecFileStream fStr; + + if ((stream == NULL) || (format == NULL)) { + SECUREC_ERROR_INVALID_PARAMTER("vfwscanf_s"); + return SECUREC_SCANF_EINVAL; + } + if (stream == stdin) { + return vwscanf_s(format, argList); + } + + SECUREC_LOCK_FILE(stream); + SECUREC_INIT_SEC_FILE_STREAM(fStr, SECUREC_FILE_STREAM_FLAG, stream, SECUREC_UNINITIALIZED_FILE_POS, NULL, 0); + retVal = SecInputSW(&fStr, format, argList); + SECUREC_UNLOCK_FILE(stream); + if (retVal < 0) { + SECUREC_ERROR_INVALID_PARAMTER("vfwscanf_s"); + return SECUREC_SCANF_EINVAL; + } + return retVal; +} + + diff --git a/third_party/securec/src/vscanf_s.c b/third_party/securec/src/vscanf_s.c new file mode 100644 index 00000000..66669765 --- /dev/null +++ b/third_party/securec/src/vscanf_s.c @@ -0,0 +1,68 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "secinput.h" + +/* + * + * The vscanf_s function is equivalent to scanf_s, with the variable argument list replaced by argList, + * The vscanf_s function reads data from the standard input stream stdin and + * writes the data into the location that's given by argument. Each argument + * must be a pointer to a variable of a type that corresponds to a type specifier + * in format. If copying occurs between strings that overlap, the behavior is + * undefined. + * + * + * format Format control string. + * argList pointer to list of arguments + * + * + * argList the converted value stored in user assigned address + * + * + * Returns the number of fields successfully converted and assigned; + * the return value does not include fields that were read but not assigned. + * A return value of 0 indicates that no fields were assigned. + * return -1 if an error occurs. + */ +int vscanf_s(const char *format, va_list argList) +{ + int retVal; /* If initialization causes e838 */ + SecFileStream fStr; + SECUREC_INIT_SEC_FILE_STREAM(fStr, SECUREC_FROM_STDIN_FLAG, stdin, 0, NULL, 0); + /* + * "va_list" has different definition on different platform, so we can't use argList == NULL + * to determine it's invalid. If you has fixed platform, you can check some fields to validate it, + * such as "argList == NULL" or argList.xxx != NULL or *(size_t *)&argList != 0. + */ + if (format == NULL || fStr.pf == NULL) { + SECUREC_ERROR_INVALID_PARAMTER("vscanf_s"); + return SECUREC_SCANF_EINVAL; + } + + SECUREC_LOCK_STDIN(0, fStr.pf); + + retVal = SecInputS(&fStr, format, argList); + + SECUREC_UNLOCK_STDIN(0, fStr.pf); + if (retVal < 0) { + SECUREC_ERROR_INVALID_PARAMTER("vscanf_s"); + return SECUREC_SCANF_EINVAL; + } + return retVal; +} + + diff --git a/third_party/securec/src/vsnprintf_s.c b/third_party/securec/src/vsnprintf_s.c new file mode 100644 index 00000000..dfa55bab --- /dev/null +++ b/third_party/securec/src/vsnprintf_s.c @@ -0,0 +1,149 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "secureprintoutput.h" + +#if SECUREC_ENABLE_VSNPRINTF +/* + * + * The vsnprintf_s function is equivalent to the vsnprintf function + * except for the parameter destMax/count and the explicit runtime-constraints violation + * The vsnprintf_s function takes a pointer to an argument list, then formats + * and writes up to count characters of the given data to the memory pointed + * to by strDest and appends a terminating null. + * + * + * strDest Storage location for the output. + * destMax The size of the strDest for output. + * count Maximum number of character to write(not including + * the terminating NULL) + * format Format-control string. + * argList pointer to list of arguments. + * + * + * strDest is updated + * + * + * return the number of characters written, not including the terminating null + * return -1 if an error occurs. + * return -1 if count < destMax and the output string has been truncated + * + * If there is a runtime-constraint violation, strDest[0] will be set to the '\0' when strDest and destMax valid + */ +int vsnprintf_s(char *strDest, size_t destMax, size_t count, const char *format, va_list argList) +{ + int retVal; + + if (format == NULL || strDest == NULL || destMax == 0 || destMax > SECUREC_STRING_MAX_LEN || + (count > (SECUREC_STRING_MAX_LEN - 1) && count != (size_t)(-1))) { + if (strDest != NULL && destMax > 0 && destMax <= SECUREC_STRING_MAX_LEN) { + strDest[0] = '\0'; + } + SECUREC_ERROR_INVALID_PARAMTER("vsnprintf_s"); + return -1; + } + + if (destMax > count) { + retVal = SecVsnprintfImpl(strDest, count + 1, format, argList); + if (retVal == SECUREC_PRINTF_TRUNCATE) { /* lsd add to keep dest buffer not destroyed 2014.2.18 */ + /* the string has been truncated, return -1 */ + return -1; /* to skip error handler, return strlen(strDest) or -1 */ + } + } else { + retVal = SecVsnprintfImpl(strDest, destMax, format, argList); +#ifdef SECUREC_COMPATIBLE_WIN_FORMAT + if (retVal == SECUREC_PRINTF_TRUNCATE && count == (size_t)(-1)) { + return -1; + } +#endif + } + + if (retVal < 0) { + strDest[0] = '\0'; /* empty the dest strDest */ + + if (retVal == SECUREC_PRINTF_TRUNCATE) { + /* Buffer too small */ + SECUREC_ERROR_INVALID_RANGE("vsnprintf_s"); + } + + SECUREC_ERROR_INVALID_PARAMTER("vsnprintf_s"); + return -1; + } + + return retVal; +} +#if SECUREC_IN_KERNEL +EXPORT_SYMBOL(vsnprintf_s); +#endif +#endif + +#if SECUREC_SNPRINTF_TRUNCATED +/* + * + * The vsnprintf_truncated_s function is equivalent to the vsnprintf function + * except for the parameter destMax/count and the explicit runtime-constraints violation + * The vsnprintf_truncated_s function takes a pointer to an argument list, then formats + * and writes up to count characters of the given data to the memory pointed + * to by strDest and appends a terminating null. + * + * + * strDest Storage location for the output. + * destMax The size of the strDest for output. + * the terminating NULL) + * format Format-control string. + * argList pointer to list of arguments. + * + * + * strDest is updated + * + * + * return the number of characters written, not including the terminating null + * return -1 if an error occurs. + * return destMax-1 if output string has been truncated + * + * If there is a runtime-constraint violation, strDest[0] will be set to the '\0' when strDest and destMax valid + */ +int vsnprintf_truncated_s(char *strDest, size_t destMax, const char *format, va_list argList) +{ + int retVal; + + if (format == NULL || strDest == NULL || destMax == 0 || destMax > SECUREC_STRING_MAX_LEN) { + if (strDest != NULL && destMax > 0 && destMax <= SECUREC_STRING_MAX_LEN) { + strDest[0] = '\0'; + } + SECUREC_ERROR_INVALID_PARAMTER("vsnprintf_truncated_s"); + return -1; + } + + retVal = SecVsnprintfImpl(strDest, destMax, format, argList); + + if (retVal < 0) { + if (retVal == SECUREC_PRINTF_TRUNCATE) { + return (int)(destMax - 1); /* to skip error handler, return strlen(strDest) */ + } + strDest[0] = '\0'; /* empty the dest strDest */ + SECUREC_ERROR_INVALID_PARAMTER("vsnprintf_truncated_s"); + return -1; + } + + return retVal; +} +#if SECUREC_IN_KERNEL +EXPORT_SYMBOL(vsnprintf_truncated_s); +#endif +#endif + + diff --git a/third_party/securec/src/vsprintf_s.c b/third_party/securec/src/vsprintf_s.c new file mode 100644 index 00000000..e74c7748 --- /dev/null +++ b/third_party/securec/src/vsprintf_s.c @@ -0,0 +1,73 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "secureprintoutput.h" + +/* + * + * The vsprintf_s function is equivalent to the vsprintf function + * except for the parameter destMax and the explicit runtime-constraints violation + * The vsprintf_s function takes a pointer to an argument list, and then formats + * and writes the given data to the memory pointed to by strDest. + * The function differ from the non-secure versions only in that the secure + * versions support positional parameters. + * + * + * strDest Storage location for the output. + * destMax Size of strDest + * format Format specification. + * argList pointer to list of arguments + * + * + * strDest is updated + * + * + * return the number of characters written, not including the terminating null character, + * return -1 if an error occurs. + * + * If there is a runtime-constraint violation, strDest[0] will be set to the '\0' when strDest and destMax valid + */ +int vsprintf_s(char *strDest, size_t destMax, const char *format, va_list argList) +{ + int retVal; /* If initialization causes e838 */ + + if (format == NULL || strDest == NULL || destMax == 0 || destMax > SECUREC_STRING_MAX_LEN) { + if (strDest != NULL && destMax > 0 && destMax <= SECUREC_STRING_MAX_LEN) { + strDest[0] = '\0'; + } + SECUREC_ERROR_INVALID_PARAMTER("vsprintf_s"); + return -1; + } + + retVal = SecVsnprintfImpl(strDest, destMax, format, argList); + + if (retVal < 0) { + strDest[0] = '\0'; + if (retVal == SECUREC_PRINTF_TRUNCATE) { + /* Buffer is too small */ + SECUREC_ERROR_INVALID_RANGE("vsprintf_s"); + } + SECUREC_ERROR_INVALID_PARAMTER("vsprintf_s"); + return -1; + } + + return retVal; +} +#if SECUREC_IN_KERNEL +EXPORT_SYMBOL(vsprintf_s); +#endif + + diff --git a/third_party/securec/src/vsscanf_s.c b/third_party/securec/src/vsscanf_s.c new file mode 100644 index 00000000..e0a5ecda --- /dev/null +++ b/third_party/securec/src/vsscanf_s.c @@ -0,0 +1,88 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "secinput.h" +#if defined(SECUREC_VXWORKS_PLATFORM) && (!defined(SECUREC_SYSAPI4VXWORKS) && !defined(SECUREC_CTYPE_MACRO_ADAPT)) +#include +#endif + +/* + * + * vsscanf_s + * + * + * + * The vsscanf_s function is equivalent to sscanf_s, with the variable argument list replaced by argList + * The vsscanf_s function reads data from buffer into the location given by + * each argument. Every argument must be a pointer to a variable with a type + * that corresponds to a type specifier in format. The format argument controls + * the interpretation of the input fields and has the same form and function + * as the format argument for the scanf function. + * If copying takes place between strings that overlap, the behavior is undefined. + * + * + * buffer Stored data + * format Format control string, see Format Specifications. + * argList pointer to list of arguments + * + * + * argList the converted value stored in user assigned address + * + * + * Each of these functions returns the number of fields successfully converted + * and assigned; the return value does not include fields that were read but + * not assigned. A return value of 0 indicates that no fields were assigned. + * return -1 if an error occurs. + */ +int vsscanf_s(const char *buffer, const char *format, va_list argList) +{ + size_t count; /* If initialization causes e838 */ + int retVal; + SecFileStream fStr; + + /* validation section */ + if (buffer == NULL || format == NULL) { + SECUREC_ERROR_INVALID_PARAMTER("vsscanf_s"); + return SECUREC_SCANF_EINVAL; + } + count = strlen(buffer); + if (count == 0 || count > SECUREC_STRING_MAX_LEN) { + SecClearDestBuf(buffer, format, argList); + SECUREC_ERROR_INVALID_PARAMTER("vsscanf_s"); + return SECUREC_SCANF_EINVAL; + } +#ifdef SECUREC_VXWORKS_PLATFORM + /* + * in vxworks platform when buffer is white string, will set first %s argument tu zero.like following useage: + * " \v\f\t\r\n", "%s", str, strSize + * do not check all character, just first and last character then consider it is white string + */ + if (isspace((int)buffer[0]) && isspace((int)buffer[count - 1])) { + SecClearDestBuf(buffer, format, argList); + } +#endif + SECUREC_INIT_SEC_FILE_STREAM(fStr, SECUREC_MEM_STR_FLAG, NULL, 0, buffer, (int)count); + retVal = SecInputS(&fStr, format, argList); + if (retVal < 0) { + SECUREC_ERROR_INVALID_PARAMTER("vsscanf_s"); + return SECUREC_SCANF_EINVAL; + } + return retVal; +} +#if SECUREC_IN_KERNEL +EXPORT_SYMBOL(vsscanf_s); +#endif + diff --git a/third_party/securec/src/vswprintf_s.c b/third_party/securec/src/vswprintf_s.c new file mode 100644 index 00000000..3403a6b5 --- /dev/null +++ b/third_party/securec/src/vswprintf_s.c @@ -0,0 +1,66 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "secureprintoutput.h" + + +/* + * + * The vswprintf_s function is the wide-character equivalent of the vsprintf_s function + * + * + * strDest Storage location for the output. + * destMax Size of strDest + * format Format specification. + * argList pointer to list of arguments + * + * + * strDest is updated + * + * + * return the number of wide characters stored in strDest, not counting the terminating null wide character. + * return -1 if an error occurred. + * + * If there is a runtime-constraint violation, strDest[0] will be set to the '\0' when strDest and destMax valid + */ +int vswprintf_s(wchar_t *strDest, size_t destMax, const wchar_t *format, va_list argList) +{ + int retVal; /* If initialization causes e838 */ + + if (format == NULL || strDest == NULL || destMax == 0 || destMax > (SECUREC_WCHAR_STRING_MAX_LEN)) { + if (strDest != NULL && destMax > 0) { + strDest[0] = '\0'; + } + SECUREC_ERROR_INVALID_PARAMTER("vswprintf_s"); + return -1; + } + + retVal = SecVswprintfImpl(strDest, destMax, format, argList); + + if (retVal < 0) { + strDest[0] = '\0'; + if (retVal == SECUREC_PRINTF_TRUNCATE) { + /* Buffer too small */ + SECUREC_ERROR_INVALID_RANGE("vswprintf_s"); + } + SECUREC_ERROR_INVALID_PARAMTER("vswprintf_s"); + return -1; + } + + return retVal; +} + + diff --git a/third_party/securec/src/vswscanf_s.c b/third_party/securec/src/vswscanf_s.c new file mode 100644 index 00000000..269e1053 --- /dev/null +++ b/third_party/securec/src/vswscanf_s.c @@ -0,0 +1,79 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "secinput.h" + +static size_t SecWcslen(const wchar_t *s) +{ + const wchar_t *end = s; + while (*end != L'\0') { + ++end; + } + return ((size_t)((end - s))); +} + +/* + * + * The vswscanf_s function is the wide-character equivalent of the vsscanf_s function + * The vsscanf_s function reads data from buffer into the location given by + * each argument. Every argument must be a pointer to a variable with a type + * that corresponds to a type specifier in format. + * The format argument controls the interpretation of the input fields and + * has the same form and function as the format argument for the scanf function. + * If copying takes place between strings that overlap, the behavior is undefined. + * + * + * buffer Stored data + * format Format control string, see Format Specifications. + * argList pointer to list of arguments + * + * + * argList the converted value stored in user assigned address + * + * + * Each of these functions returns the number of fields successfully converted + * and assigned; the return value does not include fields that were read but + * not assigned. A return value of 0 indicates that no fields were assigned. + * return -1 if an error occurs. + */ +int vswscanf_s(const wchar_t *buffer, const wchar_t *format, va_list argList) +{ + size_t count; /* If initialization causes e838 */ + SecFileStream fStr; + int retVal; + + /* validation section */ + if (buffer == NULL || format == NULL) { + SECUREC_ERROR_INVALID_PARAMTER("vswscanf_s"); + return SECUREC_SCANF_EINVAL; + } + count = SecWcslen(buffer); + if (count == 0 || count > SECUREC_WCHAR_STRING_MAX_LEN) { + SecClearDestBufW(buffer, format, argList); + SECUREC_ERROR_INVALID_PARAMTER("vswscanf_s"); + return SECUREC_SCANF_EINVAL; + } + SECUREC_INIT_SEC_FILE_STREAM(fStr, SECUREC_MEM_STR_FLAG, NULL, 0,\ + (const char *)buffer, (int)count * ((int)sizeof(wchar_t))); + retVal = SecInputSW(&fStr, format, argList); + if (retVal < 0) { + SECUREC_ERROR_INVALID_PARAMTER("vswscanf_s"); + return SECUREC_SCANF_EINVAL; + } + return retVal; +} + + diff --git a/third_party/securec/src/vwscanf_s.c b/third_party/securec/src/vwscanf_s.c new file mode 100644 index 00000000..56e0f6b4 --- /dev/null +++ b/third_party/securec/src/vwscanf_s.c @@ -0,0 +1,67 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "secinput.h" + +/* + * + * The vwscanf_s function is the wide-character equivalent of the vscanf_s function + * The vwscanf_s function is the wide-character version of vscanf_s. The + * function reads data from the standard input stream stdin and writes the + * data into the location that's given by argument. Each argument must be a + * pointer to a variable of a type that corresponds to a type specifier in + * format. If copying occurs between strings that overlap, the behavior is + * undefined. + * + * + * format Format control string. + * argList pointer to list of arguments + * + * + * argList the converted value stored in user assigned address + * + * + * Returns the number of fields successfully converted and assigned; + * the return value does not include fields that were read but not assigned. + * A return value of 0 indicates that no fields were assigned. + * return -1 if an error occurs. + */ +int vwscanf_s(const wchar_t *format, va_list argList) +{ + int retVal; /* If initialization causes e838 */ + SecFileStream fStr; + + SECUREC_INIT_SEC_FILE_STREAM(fStr, SECUREC_FROM_STDIN_FLAG, stdin, 0, NULL, 0); + if (format == NULL || fStr.pf == NULL) { + SECUREC_ERROR_INVALID_PARAMTER("vwscanf_s"); + return SECUREC_SCANF_EINVAL; + } + + SECUREC_LOCK_STDIN(0, fStr.pf); + + retVal = SecInputSW(&fStr, format, argList); + + SECUREC_UNLOCK_STDIN(0, fStr.pf); + + if (retVal < 0) { + SECUREC_ERROR_INVALID_PARAMTER("vwscanf_s"); + return SECUREC_SCANF_EINVAL; + } + + return retVal; +} + + diff --git a/third_party/securec/src/wcscat_s.c b/third_party/securec/src/wcscat_s.c new file mode 100644 index 00000000..51254b3f --- /dev/null +++ b/third_party/securec/src/wcscat_s.c @@ -0,0 +1,111 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#define SECUREC_INLINE_DO_MEMCPY 1 + +#include "securecutil.h" + +/* + * Befor this function, the basic parameter checking has been done + */ +static errno_t SecDoWcscat(wchar_t *strDest, size_t destMax, const wchar_t *strSrc) +{ + size_t destLen; + size_t srcLen; + size_t maxCount; /* Store the maximum available count */ + + /* To calculate the length of a wide character, the parameter must be a wide character */ + SECUREC_CALC_WSTR_LEN(strDest, destMax, &destLen); + maxCount = destMax - destLen; + SECUREC_CALC_WSTR_LEN(strSrc, maxCount, &srcLen); + + if (SECUREC_CAT_STRING_IS_OVERLAP(strDest, destLen, strSrc, srcLen)) { + strDest[0] = L'\0'; + if (strDest + destLen <= strSrc && destLen == destMax) { + SECUREC_ERROR_INVALID_PARAMTER("wcscat_s"); + return EINVAL_AND_RESET; + } + SECUREC_ERROR_BUFFER_OVERLAP("wcscat_s"); + return EOVERLAP_AND_RESET; + } + if (srcLen + destLen >= destMax || strDest == strSrc) { + strDest[0] = L'\0'; + if (destLen == destMax) { + SECUREC_ERROR_INVALID_PARAMTER("wcscat_s"); + return EINVAL_AND_RESET; + } + SECUREC_ERROR_INVALID_RANGE("wcscat_s"); + return ERANGE_AND_RESET; + } + SecDoMemcpy(strDest + destLen, strSrc, (srcLen + 1) * sizeof(wchar_t)); /* single character length include \0 */ + return EOK; +} + +/* + * + * The wcscat_s function appends a copy of the wide string pointed to by strSrc +* (including the terminating null wide character) + * to the end of the wide string pointed to by strDest. + * The arguments and return value of wcscat_s are wide-character strings. + * + * The wcscat_s function appends strSrc to strDest and terminates the resulting + * string with a null character. The initial character of strSrc overwrites the + * terminating null character of strDest. wcscat_s will return EOVERLAP_AND_RESET if the + * source and destination strings overlap. + * + * Note that the second parameter is the total size of the buffer, not the + * remaining size. + * + * + * strDest Null-terminated destination string buffer. + * destMax Size of the destination string buffer. + * strSrc Null-terminated source string buffer. + * + * + * strDest is updated + * + * + * EOK Success + * EINVAL strDest is NULL and destMax != 0 and destMax <= SECUREC_WCHAR_STRING_MAX_LEN + * EINVAL_AND_RESET (strDest unterminated and all other parameters are valid) or + * (strDest != NULL and strSrc is NULLL and destMax != 0 + * and destMax <= SECUREC_WCHAR_STRING_MAX_LEN) + * ERANGE destMax > SECUREC_WCHAR_STRING_MAX_LEN or destMax is 0 + * ERANGE_AND_RESET strDest have not enough space and all other parameters are valid and not overlap + * EOVERLAP_AND_RESET dest buffer and source buffer are overlapped and all parameters are valid + * + * If there is a runtime-constraint violation, strDest[0] will be set to the '\0' when strDest and destMax valid + */ +errno_t wcscat_s(wchar_t *strDest, size_t destMax, const wchar_t *strSrc) +{ + if (destMax == 0 || destMax > SECUREC_WCHAR_STRING_MAX_LEN) { + SECUREC_ERROR_INVALID_RANGE("wcscat_s"); + return ERANGE; + } + + if (strDest == NULL || strSrc == NULL) { + SECUREC_ERROR_INVALID_PARAMTER("wcscat_s"); + if (strDest != NULL) { + strDest[0] = L'\0'; + return EINVAL_AND_RESET; + } + return EINVAL; + } + + return SecDoWcscat(strDest, destMax, strSrc); +} + + diff --git a/third_party/securec/src/wcscpy_s.c b/third_party/securec/src/wcscpy_s.c new file mode 100644 index 00000000..2c348d4b --- /dev/null +++ b/third_party/securec/src/wcscpy_s.c @@ -0,0 +1,91 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#define SECUREC_INLINE_DO_MEMCPY 1 + +#include "securecutil.h" + +static errno_t SecDoWcscpy(wchar_t *strDest, size_t destMax, const wchar_t *strSrc) +{ + size_t srcStrLen; + + SECUREC_CALC_WSTR_LEN(strSrc, destMax, &srcStrLen); + if (srcStrLen == destMax) { + strDest[0] = '\0'; + SECUREC_ERROR_INVALID_RANGE("wcscpy_s"); + return ERANGE_AND_RESET; + } + if (strDest == strSrc) { + return EOK; + } + + if (SECUREC_STRING_NO_OVERLAP(strDest, strSrc, srcStrLen)) { + /* performance optimization srcStrLen include '\0' */ + SecDoMemcpy(strDest, strSrc, (srcStrLen + 1) * sizeof(wchar_t)); /* single character length include \0 */ + return EOK; + } else { + strDest[0] = L'\0'; + SECUREC_ERROR_BUFFER_OVERLAP("wcscpy_s"); + return EOVERLAP_AND_RESET; + } +} + +/* + * + * The wcscpy_s function copies the wide string pointed to by strSrc + * (including theterminating null wide character) into the array pointed to by strDest + + * + * strDest Destination string buffer + * destMax Size of the destination string buffer. + * strSrc Null-terminated source string buffer. + * + * + * strDest is updated. + * + * + * EOK Success + * EINVAL strDest is NULL and destMax != 0 and destMax <= SECUREC_WCHAR_STRING_MAX_LEN + * EINVAL_AND_RESET strDest != NULL and strSrc is NULLL and destMax != 0 + * and destMax <= SECUREC_WCHAR_STRING_MAX_LEN + * ERANGE destMax > SECUREC_WCHAR_STRING_MAX_LEN or destMax is 0 + * ERANGE_AND_RESET destMax <= length of strSrc and strDest != strSrc + * and strDest != NULL and strSrc != NULL and destMax != 0 + * and destMax <= SECUREC_WCHAR_STRING_MAX_LEN and not overlap + * EOVERLAP_AND_RESET dest buffer and source buffer are overlapped and destMax != 0 + * and destMax <= SECUREC_WCHAR_STRING_MAX_LEN + * and strDest != NULL and strSrc !=NULL and strDest != strSrc + * + * If there is a runtime-constraint violation, strDest[0] will be set to the '\0' when strDest and destMax valid + */ +errno_t wcscpy_s(wchar_t *strDest, size_t destMax, const wchar_t *strSrc) +{ + if (destMax == 0 || destMax > SECUREC_WCHAR_STRING_MAX_LEN) { + SECUREC_ERROR_INVALID_RANGE("wcscpy_s"); + return ERANGE; + } + if (strDest == NULL || strSrc == NULL) { + SECUREC_ERROR_INVALID_PARAMTER("wcscpy_s"); + if (strDest != NULL) { + strDest[0] = L'\0'; + return EINVAL_AND_RESET; + } + return EINVAL; + } + return SecDoWcscpy(strDest, destMax, strSrc); +} + + diff --git a/third_party/securec/src/wcsncat_s.c b/third_party/securec/src/wcsncat_s.c new file mode 100644 index 00000000..bc9e6e39 --- /dev/null +++ b/third_party/securec/src/wcsncat_s.c @@ -0,0 +1,118 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#define SECUREC_INLINE_DO_MEMCPY 1 + +#include "securecutil.h" + +/* + * Befor this function, the basic parameter checking has been done + */ +static errno_t SecDoWcsncat(wchar_t *strDest, size_t destMax, const wchar_t *strSrc, size_t count) +{ + size_t destLen; + size_t srcLen; + + /* To calculate the length of a wide character, the parameter must be a wide character */ + SECUREC_CALC_WSTR_LEN(strDest, destMax, &destLen); + SECUREC_CALC_WSTR_LEN(strSrc, count, &srcLen); + + if (SECUREC_CAT_STRING_IS_OVERLAP(strDest, destLen, strSrc, srcLen)) { + strDest[0] = L'\0'; + if (strDest + destLen <= strSrc && destLen == destMax) { + SECUREC_ERROR_INVALID_PARAMTER("wcsncat_s"); + return EINVAL_AND_RESET; + } + SECUREC_ERROR_BUFFER_OVERLAP("wcsncat_s"); + return EOVERLAP_AND_RESET; + } + if (srcLen + destLen >= destMax || strDest == strSrc) { + strDest[0] = L'\0'; + if (destLen == destMax) { + SECUREC_ERROR_INVALID_PARAMTER("wcsncat_s"); + return EINVAL_AND_RESET; + } + SECUREC_ERROR_INVALID_RANGE("wcsncat_s"); + return ERANGE_AND_RESET; + } + SecDoMemcpy(strDest + destLen, strSrc, srcLen * sizeof(wchar_t)); /* no terminator */ + *(strDest + destLen + srcLen) = L'\0'; + return EOK; +} + +/* + * + * The wcsncat_s function appends not more than n successive wide characters + * (not including the terminating null wide character) + * from the array pointed to by strSrc to the end of the wide string pointed to by strDest. + * + * The wcsncat_s function try to append the first D characters of strSrc to + * the end of strDest, where D is the lesser of count and the length of strSrc. + * If appending those D characters will fit within strDest (whose size is + * given as destMax) and still leave room for a null terminator, then those + * characters are appended, starting at the original terminating null of + * strDest, and a new terminating null is appended; otherwise, strDest[0] is + * set to the null character. + * + * + * strDest Null-terminated destination string. + * destMax Size of the destination buffer. + * strSrc Null-terminated source string. + * count Number of character to append, or truncate. + * + * + * strDest is updated + * + * + * EOK Success + * EINVAL strDest is NULL and destMax != 0 and destMax <= SECUREC_WCHAR_STRING_MAX_LEN + * EINVAL_AND_RESET (strDest unterminated and all other parameters are valid) or + * (strDest != NULL and strSrc is NULLL and destMax != 0 and destMax <= SECUREC_WCHAR_STRING_MAX_LEN) + * ERANGE destMax > SECUREC_WCHAR_STRING_MAX_LEN or destMax is 0 + * ERANGE_AND_RESET strDest have not enough space and all other parameters are valid and not overlap + * EOVERLAP_AND_RESET dest buffer and source buffer are overlapped and all parameters are valid + * + * If there is a runtime-constraint violation, strDest[0] will be set to the '\0' when strDest and destMax valid + */ +errno_t wcsncat_s(wchar_t *strDest, size_t destMax, const wchar_t *strSrc, size_t count) +{ + if (destMax == 0 || destMax > SECUREC_WCHAR_STRING_MAX_LEN) { + SECUREC_ERROR_INVALID_RANGE("wcsncat_s"); + return ERANGE; + } + if (strDest == NULL || strSrc == NULL) { + SECUREC_ERROR_INVALID_PARAMTER("wcsncat_s"); + if (strDest != NULL) { + strDest[0] = L'\0'; + return EINVAL_AND_RESET; + } + return EINVAL; + } + if (count > SECUREC_WCHAR_STRING_MAX_LEN) { +#ifdef SECUREC_COMPATIBLE_WIN_FORMAT + if (count == ((size_t)-1)) { + /* Windows internal functions may pass in -1 when calling this function */ + return SecDoWcsncat(strDest, destMax, strSrc, destMax); + } +#endif + strDest[0] = L'\0'; + SECUREC_ERROR_INVALID_RANGE("wcsncat_s"); + return ERANGE_AND_RESET; + } + return SecDoWcsncat(strDest, destMax, strSrc, count); +} + + diff --git a/third_party/securec/src/wcsncpy_s.c b/third_party/securec/src/wcsncpy_s.c new file mode 100644 index 00000000..746b1d44 --- /dev/null +++ b/third_party/securec/src/wcsncpy_s.c @@ -0,0 +1,111 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#define SECUREC_INLINE_DO_MEMCPY 1 + +#include "securecutil.h" + +static errno_t SecDoWcsncpy(wchar_t *strDest, size_t destMax, const wchar_t *strSrc, size_t count) +{ + size_t srcStrLen; + if (count < destMax) { + SECUREC_CALC_WSTR_LEN(strSrc, count, &srcStrLen); + } else { + SECUREC_CALC_WSTR_LEN(strSrc, destMax, &srcStrLen); + } + if (srcStrLen == destMax) { + strDest[0] = '\0'; + SECUREC_ERROR_INVALID_RANGE("wcsncpy_s"); + return ERANGE_AND_RESET; + } + if (strDest == strSrc) { + return EOK; + } + if (SECUREC_STRING_NO_OVERLAP(strDest, strSrc, srcStrLen)) { + /* performance optimization srcStrLen not include '\0' */ + SecDoMemcpy(strDest, strSrc, srcStrLen * sizeof(wchar_t)); + *(strDest + srcStrLen) = L'\0'; + return EOK; + } else { + strDest[0] = L'\0'; + SECUREC_ERROR_BUFFER_OVERLAP("wcsncpy_s"); + return EOVERLAP_AND_RESET; + } +} + +/* + * + * The wcsncpy_s function copies not more than n successive wide characters + * (not including the terminating null wide character) + * from the array pointed to by strSrc to the array pointed to by strDest + * + * + * strDest Destination string. + * destMax The size of the destination string, in characters. + * strSrc Source string. + * count Number of characters to be copied. + * + * + * strDest is updated + * + * + * EOK Success + * EINVAL strDest is NULL and destMax != 0 and destMax <= SECUREC_WCHAR_STRING_MAX_LEN + * EINVAL_AND_RESET strDest != NULL and strSrc is NULLL and destMax != 0 + * and destMax <= SECUREC_WCHAR_STRING_MAX_LEN + * ERANGE destMax > SECUREC_WCHAR_STRING_MAX_LEN or destMax is 0 + * ERANGE_AND_RESET count > SECUREC_WCHAR_STRING_MAX_LEN or + * (destMax <= length of strSrc and destMax <= count and strDest != strSrc + * and strDest != NULL and strSrc != NULL and destMax != 0 and + * destMax <= SECUREC_WCHAR_STRING_MAX_LEN and not overlap) + * EOVERLAP_AND_RESET dest buffer and source buffer are overlapped and all parameters are valid + * + * + * If there is a runtime-constraint violation, strDest[0] will be set to the '\0' when strDest and destMax valid + */ +errno_t wcsncpy_s(wchar_t *strDest, size_t destMax, const wchar_t *strSrc, size_t count) +{ + if (destMax == 0 || destMax > SECUREC_WCHAR_STRING_MAX_LEN) { + SECUREC_ERROR_INVALID_RANGE("wcsncpy_s"); + return ERANGE; + } + if (strDest == NULL || strSrc == NULL) { + SECUREC_ERROR_INVALID_PARAMTER("wcsncpy_s"); + if (strDest != NULL) { + strDest[0] = '\0'; + return EINVAL_AND_RESET; + } + return EINVAL; + } + if (count > SECUREC_WCHAR_STRING_MAX_LEN) { +#ifdef SECUREC_COMPATIBLE_WIN_FORMAT + if (count == (size_t)(-1)) { + return SecDoWcsncpy(strDest, destMax, strSrc, destMax - 1); + } +#endif + strDest[0] = '\0'; /* clear dest string */ + SECUREC_ERROR_INVALID_RANGE("wcsncpy_s"); + return ERANGE_AND_RESET; + } + + if (count == 0) { + strDest[0] = '\0'; + return EOK; + } + + return SecDoWcsncpy(strDest, destMax, strSrc, count); +} + diff --git a/third_party/securec/src/wcstok_s.c b/third_party/securec/src/wcstok_s.c new file mode 100644 index 00000000..99c524f0 --- /dev/null +++ b/third_party/securec/src/wcstok_s.c @@ -0,0 +1,116 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "securec.h" + +/* + * FindBegin Wide character postion function + */ +static wchar_t *SecFindBeginW(wchar_t *strToken, const wchar_t *strDelimit) +{ + /* Find beginning of token (skip over leading delimiters). Note that + * there is no token if this loop sets string to point to the terminal null. + */ + wchar_t *token = strToken; + while (*token != L'\0') { + const wchar_t *ctl = strDelimit; + while (*ctl != L'\0' && *ctl != *token) { + ++ctl; + } + if (*ctl == L'\0') { + break; + } + ++token; + } + return token; +} + +/* + * FindBegin rest Wide character postion function + */ +static wchar_t *SecFindRestW(wchar_t *strToken, const wchar_t *strDelimit) +{ + /* Find the end of the token. If it is not the end of the string, + * put a null there. + */ + wchar_t *token = strToken; + while (*token != L'\0') { + const wchar_t *ctl = strDelimit; + while (*ctl != L'\0' && *ctl != *token) { + ++ctl; + } + if (*ctl != L'\0') { + *token++ = L'\0'; + break; + } + ++token; + } + return token; +} + +/* + * Update Token wide character function + */ +static wchar_t *SecUpdateTokenW(wchar_t *strToken, const wchar_t *strDelimit, wchar_t **context) +{ + /* point to updated position */ + wchar_t *token = SecFindRestW(strToken, strDelimit); + /* Update the context */ + *context = token; + /* Determine if a token has been found. */ + if (token == strToken) { + return NULL; + } + return strToken; +} + +/* + * + * wcstok_s + * + * + * + * The wcstok_s function is the wide-character equivalent of the strtok_s function + * + * + * strToken String containing token or tokens. + * strDelimit Set of delimiter characters. + * context Used to store position information between calls to + * wcstok_s. + * + * + * context is updated + * + * The wcstok_s function is the wide-character equivalent of the strtok_s function + */ +wchar_t *wcstok_s(wchar_t *strToken, const wchar_t *strDelimit, wchar_t **context) +{ + wchar_t *orgToken = strToken; + /* validation section */ + if (context == NULL || strDelimit == NULL) { + return NULL; + } + if (orgToken == NULL && (*context) == NULL) { + return NULL; + } + /* If string==NULL, continue with previous string */ + if (orgToken == NULL) { + orgToken = *context; + } + orgToken = SecFindBeginW(orgToken, strDelimit); + return SecUpdateTokenW(orgToken, strDelimit, context); +} + diff --git a/third_party/securec/src/wmemcpy_s.c b/third_party/securec/src/wmemcpy_s.c new file mode 100644 index 00000000..236fcce1 --- /dev/null +++ b/third_party/securec/src/wmemcpy_s.c @@ -0,0 +1,68 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "securecutil.h" + +/* + * + * The wmemcpy_s function copies n successive wide characters + * from the object pointed to by src into the object pointed to by dest.t. + * + * + * dest Destination buffer. + * destMax Size of the destination buffer. + * src Buffer to copy from. + * count Number of characters to copy. + * + * + * dest buffer is uptdated. + * + * + * EOK Success + * EINVAL dest is NULL and destMax != 0 and count <= destMax + * and destMax <= SECUREC_WCHAR_MEM_MAX_LEN + * EINVAL_AND_RESET dest != NULL and src is NULLL and destMax != 0 + * and destMax <= SECUREC_WCHAR_MEM_MAX_LEN and count <= destMax + * ERANGE destMax > SECUREC_WCHAR_MEM_MAX_LEN or destMax is 0 or + * (count > destMax and dest is NULL and destMax != 0 + * and destMax <= SECUREC_WCHAR_MEM_MAX_LEN) + * ERANGE_AND_RESET count > destMax and dest != NULL and destMax != 0 + * and destMax <= SECUREC_WCHAR_MEM_MAX_LEN + * EOVERLAP_AND_RESET dest buffer and source buffer are overlapped and + * count <= destMax destMax != 0 and destMax <= SECUREC_WCHAR_MEM_MAX_LEN + * and dest != NULL and src != NULL and dest != src + * + * if an error occured, dest will be filled with 0 when dest and destMax valid . + * If the source and destination overlap, the behavior of wmemcpy_s is undefined. + * Use wmemmove_s to handle overlapping regions. + */ +errno_t wmemcpy_s(wchar_t *dest, size_t destMax, const wchar_t *src, size_t count) +{ + if (destMax == 0 || destMax > SECUREC_WCHAR_MEM_MAX_LEN) { + SECUREC_ERROR_INVALID_PARAMTER("wmemcpy_s"); + return ERANGE; + } + if (count > destMax) { + SECUREC_ERROR_INVALID_PARAMTER("wmemcpy_s"); + if (dest != NULL) { + (void)memset(dest, 0, destMax * sizeof(wchar_t)); + return ERANGE_AND_RESET; + } + return ERANGE; + } + return memcpy_s(dest, destMax * sizeof(wchar_t), src, count * sizeof(wchar_t)); +} + diff --git a/third_party/securec/src/wmemmove_s.c b/third_party/securec/src/wmemmove_s.c new file mode 100644 index 00000000..2ef549a0 --- /dev/null +++ b/third_party/securec/src/wmemmove_s.c @@ -0,0 +1,67 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "securecutil.h" + +/* + * + * The wmemmove_s function copies n successive wide characters from the object pointed + * to by src into the object pointed to by dest. + * + * + * dest Destination buffer. + * destMax Size of the destination buffer. + * src Source object. + * count Number of bytes or character to copy. + * + * + * dest is updated. + * + * + * EOK Success + * EINVAL dest is NULL and destMax != 0 and count <= destMax + * and destMax <= SECUREC_WCHAR_MEM_MAX_LEN + * EINVAL_AND_RESET dest != NULL and src is NULLL and destMax != 0 + * and destMax <= SECUREC_WCHAR_MEM_MAX_LEN and count <= destMax + * ERANGE destMax > SECUREC_WCHAR_MEM_MAX_LEN or destMax is 0 or + * (count > destMax and dest is NULL and destMax != 0 + * and destMax <= SECUREC_WCHAR_MEM_MAX_LEN) + * ERANGE_AND_RESET count > destMax and dest != NULL and destMax != 0 + * and destMax <= SECUREC_WCHAR_MEM_MAX_LEN + * + * + * If an error occured, dest will be filled with 0 when dest and destMax valid. + * If some regions of the source area and the destination overlap, wmemmove_s + * ensures that the original source bytes in the overlapping region are copied + * before being overwritten + */ +errno_t wmemmove_s(wchar_t *dest, size_t destMax, const wchar_t *src, size_t count) +{ + if (destMax == 0 || destMax > SECUREC_WCHAR_MEM_MAX_LEN) { + SECUREC_ERROR_INVALID_PARAMTER("wmemmove_s"); + return ERANGE; + } + if (count > destMax) { + SECUREC_ERROR_INVALID_PARAMTER("wmemmove_s"); + if (dest != NULL) { + (void)memset(dest, 0, destMax * sizeof(wchar_t)); + return ERANGE_AND_RESET; + } + return ERANGE; + } + return memmove_s(dest, destMax * sizeof(wchar_t), src, count * sizeof(wchar_t)); +} + diff --git a/third_party/securec/src/wscanf_s.c b/third_party/securec/src/wscanf_s.c new file mode 100644 index 00000000..c1dcce27 --- /dev/null +++ b/third_party/securec/src/wscanf_s.c @@ -0,0 +1,55 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "securec.h" + +/* + * + * + * The wscanf_s function is the wide-character equivalent of the scanf_s function + * The wscanf_s function reads data from the standard input stream stdin and + * writes the data into the location that's given by argument. Each argument + * must be a pointer to a variable of a type that corresponds to a type specifier + * in format. If copying occurs between strings that overlap, the behavior is + * undefined. + * + * + * format Format control string. + * ... Optional arguments. + * + * + * ... the converted value stored in user assigned address + * + * + * Returns the number of fields successfully converted and assigned; + * the return value does not include fields that were read but not assigned. + * A return value of 0 indicates that no fields were assigned. + * return -1 if an error occurs. + */ + +int wscanf_s(const wchar_t *format, ...) +{ + int ret; /* If initialization causes e838 */ + va_list argList; + + va_start(argList, format); + ret = vwscanf_s(format, argList); + va_end(argList); + (void)argList; /* to clear e438 last value assigned not used , the compiler will optimize this code */ + + return ret; +} +