diff --git a/mindspore/ccsrc/runtime/device/ascend/dump/data_dumper.cc b/mindspore/ccsrc/runtime/device/ascend/dump/data_dumper.cc index 00e45a839c..af155c09b1 100644 --- a/mindspore/ccsrc/runtime/device/ascend/dump/data_dumper.cc +++ b/mindspore/ccsrc/runtime/device/ascend/dump/data_dumper.cc @@ -355,6 +355,12 @@ void DataDumper::RtLoadDumpData(const aicpu::dump::OpMappingInfo &dump_info, voi } } +void SetDumpShape(const std::vector &ms_shape, NotNull dump_shape) { + for (auto &dim : ms_shape) { + dump_shape->add_dim(dim); + } +} + void DataDumper::DumpKernelOutput(const CNodePtr &kernel, void *args, NotNull task) { if (!DumpJsonParser::GetInstance().OutputNeedDump()) { MS_LOG(INFO) << "Skip dump output"; @@ -368,14 +374,14 @@ void DataDumper::DumpKernelOutput(const CNodePtr &kernel, void *args, NotNulladd_dim(dim); - } + SetDumpShape(output_shape, NOT_NULL(output.mutable_shape())); + SetDumpShape(output_origin_shape, NOT_NULL(output.mutable_origin_shape())); + output.set_original_output_format(GeTypesConvert::GetGeFormat(output_format, output_shape.size())); output.set_address(static_cast(reinterpret_cast(args)) + offset); // device address data size @@ -409,13 +415,13 @@ void DataDumper::DumpKernelInput(const CNodePtr &kernel, void *args, NotNulladd_dim(dim); - } + SetDumpShape(output_shape, NOT_NULL(input.mutable_shape())); + SetDumpShape(output_origin_shape, NOT_NULL(input.mutable_origin_shape())); + input.set_address(static_cast(reinterpret_cast(args)) + offset); // device address data size auto address = AnfAlgo::GetPrevNodeOutputAddr(kernel, i); diff --git a/mindspore/ccsrc/runtime/device/ascend/dump/proto/op_mapping_info.proto b/mindspore/ccsrc/runtime/device/ascend/dump/proto/op_mapping_info.proto index d3377c655d..df019c153d 100644 --- a/mindspore/ccsrc/runtime/device/ascend/dump/proto/op_mapping_info.proto +++ b/mindspore/ccsrc/runtime/device/ascend/dump/proto/op_mapping_info.proto @@ -31,7 +31,8 @@ message Output { int32 original_output_data_type = 7; int32 original_output_format = 8; uint64 size = 9; -}; + Shape origin_shape = 10; +} message Input { int32 data_type = 1; @@ -39,12 +40,13 @@ message Input { Shape shape = 3; uint64 address = 4; uint64 size = 5; + Shape origin_shape = 6; } message Op { string op_name = 1; string op_type = 2; -}; +} message Task { uint32 task_id = 1; @@ -53,7 +55,7 @@ message Task { repeated Output output = 4; bool end_graph = 5; repeated Input input = 6; -}; +} message OpMappingInfo { string dump_path = 1; @@ -75,4 +77,4 @@ message OpMappingInfo { uint32 flag = 7; // 0x01 load, 0x00 unload repeated Task task = 8; string dump_step = 9; -}; +}