From 311e7be605ef979f5680e80a707b5e7fe2032c3c Mon Sep 17 00:00:00 2001 From: HuangBingjian Date: Tue, 15 Dec 2020 16:29:21 +0800 Subject: [PATCH] fix scalar tensor shape=[] --- mindspore/ccsrc/utils/tensorprint_utils.cc | 4 ++-- tests/st/ops/ascend/test_tensor_print/test_tensor_print.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mindspore/ccsrc/utils/tensorprint_utils.cc b/mindspore/ccsrc/utils/tensorprint_utils.cc index 5f52d5d56a..a92d1e2e46 100644 --- a/mindspore/ccsrc/utils/tensorprint_utils.cc +++ b/mindspore/ccsrc/utils/tensorprint_utils.cc @@ -103,7 +103,7 @@ template void PrintScalarToString(const char *str_data_ptr, const string &tensor_type, std::ostringstream *const buf) { MS_EXCEPTION_IF_NULL(str_data_ptr); MS_EXCEPTION_IF_NULL(buf); - *buf << "Tensor(shape=[1], dtype=" << GetParseType(tensor_type) << ", value="; + *buf << "Tensor(shape=[], dtype=" << GetParseType(tensor_type) << ", value="; const T *data_ptr = reinterpret_cast(str_data_ptr); if constexpr (std::is_same::value || std::is_same::value) { const int int_data = static_cast(*data_ptr); @@ -117,7 +117,7 @@ void PrintScalarToBoolString(const char *str_data_ptr, const string &tensor_type MS_EXCEPTION_IF_NULL(str_data_ptr); MS_EXCEPTION_IF_NULL(buf); const bool *data_ptr = reinterpret_cast(str_data_ptr); - *buf << "Tensor(shape=[1], dtype=" << GetParseType(tensor_type) << ", value="; + *buf << "Tensor(shape=[], dtype=" << GetParseType(tensor_type) << ", value="; if (*data_ptr) { *buf << "True)\n"; } else { diff --git a/tests/st/ops/ascend/test_tensor_print/test_tensor_print.py b/tests/st/ops/ascend/test_tensor_print/test_tensor_print.py index 5efd5ed230..5e676c926f 100644 --- a/tests/st/ops/ascend/test_tensor_print/test_tensor_print.py +++ b/tests/st/ops/ascend/test_tensor_print/test_tensor_print.py @@ -25,7 +25,7 @@ expect_array = {'Bool': '\n[[ True False]\n [False True]]', 'UInt': '\n[[1 2 3] '[ *.********e*** **.********e*** *.********e***]]'} def get_expect_value(res): - if res[0] == '[1]': + if res[0] == '[]': if res[1] == 'Bool': return expect_scalar['Bool'] if res[1] in ['Uint8', 'Uint16', 'Uint32', 'Uint64']: