!9990 Fix scalar tensor print

From: @huangbingjian
Reviewed-by: @zhunaipan,@ginfung,@zh_qh
Signed-off-by: @zh_qh
pull/9990/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit f8aada52be

@ -103,7 +103,7 @@ template <typename T>
void PrintScalarToString(const char *str_data_ptr, const string &tensor_type, std::ostringstream *const buf) {
MS_EXCEPTION_IF_NULL(str_data_ptr);
MS_EXCEPTION_IF_NULL(buf);
*buf << "Tensor(shape=[1], dtype=" << GetParseType(tensor_type) << ", value=";
*buf << "Tensor(shape=[], dtype=" << GetParseType(tensor_type) << ", value=";
const T *data_ptr = reinterpret_cast<const T *>(str_data_ptr);
if constexpr (std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value) {
const int int_data = static_cast<int>(*data_ptr);
@ -117,7 +117,7 @@ void PrintScalarToBoolString(const char *str_data_ptr, const string &tensor_type
MS_EXCEPTION_IF_NULL(str_data_ptr);
MS_EXCEPTION_IF_NULL(buf);
const bool *data_ptr = reinterpret_cast<const bool *>(str_data_ptr);
*buf << "Tensor(shape=[1], dtype=" << GetParseType(tensor_type) << ", value=";
*buf << "Tensor(shape=[], dtype=" << GetParseType(tensor_type) << ", value=";
if (*data_ptr) {
*buf << "True)\n";
} else {

@ -25,7 +25,7 @@ expect_array = {'Bool': '\n[[ True False]\n [False True]]', 'UInt': '\n[[1 2 3]
'[ *.********e*** **.********e*** *.********e***]]'}
def get_expect_value(res):
if res[0] == '[1]':
if res[0] == '[]':
if res[1] == 'Bool':
return expect_scalar['Bool']
if res[1] in ['Uint8', 'Uint16', 'Uint32', 'Uint64']:

Loading…
Cancel
Save