From 0b6b5e5123f9c9f73284c25fe1027bb6e434056a Mon Sep 17 00:00:00 2001 From: liubuyu Date: Thu, 23 Apr 2020 17:03:42 +0800 Subject: [PATCH] fix codedex warning --- mindspore/ccsrc/common/trans.cc | 44 ++++++++----------- mindspore/ccsrc/common/trans.h | 10 ++--- .../device/ascend/ascend_device_address.cc | 18 ++++---- 3 files changed, 31 insertions(+), 41 deletions(-) diff --git a/mindspore/ccsrc/common/trans.cc b/mindspore/ccsrc/common/trans.cc index 1174be1f48..3e8d922971 100644 --- a/mindspore/ccsrc/common/trans.cc +++ b/mindspore/ccsrc/common/trans.cc @@ -101,13 +101,20 @@ const std::map, DataTypeTransMode> mode_map{ {std::pair(kNumberTypeInt64, kNumberTypeInt32), FROM_INT64_TO_INT32}, {std::pair(kNumberTypeUInt16, kNumberTypeInt32), FROM_UINT16_TO_INT32}}; -template -void TransDataSrc2Dst(const TypeIdArgs &args, void *dst, const size_t data_size) { - auto src_id = TypeIdSize(args.src_type); - auto dst_id = TypeIdSize(args.dst_type); - if (args.src_size / src_id != args.src_shape_size || args.dst_size / dst_id != args.dst_shape_size) { +void CheckMemSize(const TypeIdArgs &args) { + auto src_type_size = TypeIdSize(args.host_data_type); + auto dst_type_size = TypeIdSize(args.device_data_type); + if (src_type_size < 1 || dst_type_size < 1) { + MS_LOG(EXCEPTION) << "Invalid src or dst data type."; + } + if (args.data_size / src_type_size != args.host_shape_size) { MS_LOG(EXCEPTION) << "Invalid src or dst data size."; } +} + +template +void TransDataSrc2Dst(const TypeIdArgs &args, void *dst, const size_t data_size) { + CheckMemSize(args); for (size_t idx = 0; idx != data_size; idx++) { SrcT src_data = static_cast(args.data)[idx]; static_cast(dst)[idx] = static_cast(src_data); @@ -116,11 +123,7 @@ void TransDataSrc2Dst(const TypeIdArgs &args, void *dst, const size_t data_size) template void TransDataSrc2Fp16(const TypeIdArgs &args, void *dst, const size_t data_size) { - auto src_id = TypeIdSize(args.src_type); - auto dst_id = TypeIdSize(args.dst_type); - if (args.src_size / src_id != args.src_shape_size || args.dst_size / dst_id != args.dst_shape_size) { - MS_LOG(EXCEPTION) << "Invalid src or dst data size."; - } + CheckMemSize(args); auto src_data = static_cast(args.data); auto half_data = static_cast(dst); for (size_t i = 0; i < data_size; i++) { @@ -394,27 +397,18 @@ bool CheckArgs(const FormatArgs &args, size_t *size, size_t *total_size) { } bool TransDataType(const TypeIdArgs &args, void *result) { - MS_LOG(DEBUG) << "Begin trans datatype from " << TypeIdLabel(args.src_type) << " to " << TypeIdLabel(args.dst_type); + MS_LOG(DEBUG) << "Begin trans datatype from " << TypeIdLabel(args.host_data_type) << " to " + << TypeIdLabel(args.device_data_type); MS_EXCEPTION_IF_NULL(result); - std::pair type_info(args.src_type, args.dst_type); + std::pair type_info(args.host_data_type, args.device_data_type); auto iter = mode_map.find(type_info); if (iter == mode_map.end()) { - MS_LOG(ERROR) << "Unsupported datatype trans. src_type :" << TypeIdLabel(args.src_type) - << ", dst_type:" << TypeIdLabel(args.dst_type); + MS_LOG(ERROR) << "Unsupported datatype trans. src_type :" << TypeIdLabel(args.host_data_type) + << ", dst_type:" << TypeIdLabel(args.device_data_type); return false; } auto trans_mode = iter->second; - auto src_id = TypeIdSize(args.src_type); - auto dst_id = TypeIdSize(args.dst_type); - if (src_id < 1 || dst_id < 1) { - MS_LOG(ERROR) << "Invalid src or dst data type."; - return false; - } - if (args.src_size / src_id != args.src_shape_size || args.dst_size / dst_id != args.dst_shape_size) { - MS_LOG(ERROR) << "Invalid src or dst data size."; - return false; - } - if (!CastKernel(args, result, args.dst_shape_size, trans_mode)) { + if (!CastKernel(args, result, args.host_shape_size, trans_mode)) { MS_LOG(ERROR) << "Failed to trans datatype.."; return false; } diff --git a/mindspore/ccsrc/common/trans.h b/mindspore/ccsrc/common/trans.h index e6e81ed359..0593466c38 100644 --- a/mindspore/ccsrc/common/trans.h +++ b/mindspore/ccsrc/common/trans.h @@ -31,12 +31,10 @@ namespace mindspore { namespace trans { struct TypeIdArgs { const void *data; - size_t src_size; - size_t dst_size; - TypeId src_type; - TypeId dst_type; - size_t src_shape_size; - size_t dst_shape_size; + size_t host_shape_size; // Multiply each dimension elements. [a, b, c, d] => a*b*c*d + TypeId host_data_type; + TypeId device_data_type; + size_t data_size; }; struct FormatArgs { diff --git a/mindspore/ccsrc/device/ascend/ascend_device_address.cc b/mindspore/ccsrc/device/ascend/ascend_device_address.cc index df49400341..1f452ce9e2 100644 --- a/mindspore/ccsrc/device/ascend/ascend_device_address.cc +++ b/mindspore/ccsrc/device/ascend/ascend_device_address.cc @@ -104,10 +104,10 @@ bool AscendDeviceAddress::SyncDeviceToHost(const std::vector &shape, size_t } else if (type_id_ == kNumberTypeFloat32 && type == kNumberTypeFloat64) { sync_ok = SyncDeviceToHostAndFloatToFloat64(host_ptr, size, ptr_, size_); } else { - auto host_size = trans::ShapeSize(host_shape); + auto shape_size = trans::ShapeSize(host_shape); auto host = std::vector(size_); SyncMemory(host.data(), ptr_, size_, RT_MEMCPY_DEVICE_TO_HOST); - const trans::TypeIdArgs type_args{host.data(), size_, size, type_id_, type, host_size, host_size}; + const trans::TypeIdArgs type_args{host.data(), shape_size, type_id_, type, size}; sync_ok = trans::TransDataType(type_args, host_ptr); if (!sync_ok) { MS_LOG(ERROR) << "trans data type failed."; @@ -156,9 +156,8 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const std::vector &shape, size_t } else if (type_id_ == kNumberTypeFloat32 && type == kNumberTypeFloat64) { sync_ok = Float64ToFloatAndSyncHostToDevice(ptr_, size_, host_ptr, size); } else { - auto host_size = trans::ShapeSize(host_shape); - const trans::TypeIdArgs type_args{host_ptr, size, size_, type, type_id_, host_size, host_size}; + auto shape_size = trans::ShapeSize(host_shape); + const trans::TypeIdArgs type_args{host_ptr, shape_size, type, type_id_, size}; auto host_tmp = std::vector(size_); sync_ok = trans::TransDataType(type_args, host_tmp.data()); if (!sync_ok) { @@ -235,9 +234,8 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const std::vector(size_); sync_ok = trans::TransDataType(type_args, host_tmp.data()); if (!sync_ok) {