diff --git a/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_utils.cc b/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_utils.cc index fd5185a993..4329a43e33 100644 --- a/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_utils.cc +++ b/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_utils.cc @@ -37,7 +37,7 @@ constexpr char kIterEndNode[] = "PROFILING_ITER_END"; // PROFILING_CUSTOM_LOGID_START 3 constexpr uint64_t kProfilingFpStartLogId = 1; constexpr uint64_t kProfilingBpEndLogId = 2; -constexpr uint64_t kProfilingIterEndLogId = 255; +constexpr uint64_t kProfilingIterEndLogId = 65535; std::map> ProfilingUtils::graph_profiling_cnode_; std::map> ProfilingUtils::graph_kernel_name_; std::map>> ProfilingUtils::graph_point_; diff --git a/mindspore/profiler/parser/step_trace_parser.py b/mindspore/profiler/parser/step_trace_parser.py index 82c6799f22..32c244ad4a 100644 --- a/mindspore/profiler/parser/step_trace_parser.py +++ b/mindspore/profiler/parser/step_trace_parser.py @@ -105,6 +105,8 @@ class BaseStepTraceParser: Args: point_info (dict): The point info about tag id and relative op name. """ + self._get_step_trace_files() + self._get_step_end_tag_id() tag_map = {} for tag, op_name in point_info.items(): op_type = self._get_op_type(tag, op_name) @@ -123,13 +125,13 @@ class BaseStepTraceParser: Returns: str, the op type or communication op name. """ - tag_map = {self._fp_tag: 'fp', self._bp_tag: 'bp', self._end_tag: 'end'} + tag_map = {self._fp_tag: 'fp', self._bp_tag: 'bp', self._step_end_tag_id: 'end'} # get solid tag type op_type = tag_map.get(tag, '') if op_type: return op_type # check if the tag is step tag. - if tag > self._end_tag or tag == 0: + if tag > self._step_end_tag_id or tag == 0: return 'start' # analyze the reduce tag op_name = name.rsplit('/', 1)[-1] @@ -477,7 +479,7 @@ class AscendStepTraceParser(BaseStepTraceParser): _event_size = 20 _fp_tag = 1 _bp_tag = 2 - _end_tag = 255 + _step_trace_files = [] def record_point_info(self, point_info, output_path): """ @@ -513,6 +515,9 @@ class AscendStepTraceParser(BaseStepTraceParser): def _get_step_trace_files(self): """Get step trace files.""" # step trace files may under $profiler_dir or $profiler_dir/data + if self._step_trace_files: + return self._step_trace_files + profiler_dir = self._input_dir step_trace_files = self._search_file(profiler_dir) if not step_trace_files: @@ -521,17 +526,21 @@ class AscendStepTraceParser(BaseStepTraceParser): step_trace_files = self._search_file(profiler_dir) if not step_trace_files: raise ProfilerPathErrorException('Training trace file does not exist.') + self._step_trace_files = step_trace_files return step_trace_files - def _get_step_end_tag_id(self, source_files): + def _get_step_end_tag_id(self): """ Get step end tag id.This id is 255 before 2020.12.16,and 65535 now. File is an old version if there is no 65535 tag id, or it is a new version. """ + if not self._step_trace_files: + return + step_num = 0 - source_file = validate_and_normalize_path(source_files[0]) + source_file = validate_and_normalize_path(self._step_trace_files[0]) try: with open(source_file, 'rb') as handler: content = handler.read() @@ -555,8 +564,6 @@ class AscendStepTraceParser(BaseStepTraceParser): log.info("Start to parse step trace file.") event_info = {} - self._get_step_end_tag_id(source_files) - for source_file in source_files: source_file = validate_and_normalize_path(source_file) try: