diff --git a/mindspore/ccsrc/common/trans.cc b/mindspore/ccsrc/common/trans.cc index b4e02c8fe6..1174be1f48 100644 --- a/mindspore/ccsrc/common/trans.cc +++ b/mindspore/ccsrc/common/trans.cc @@ -103,17 +103,39 @@ const std::map, DataTypeTransMode> mode_map{ template void TransDataSrc2Dst(const TypeIdArgs &args, void *dst, const size_t data_size) { + auto src_id = TypeIdSize(args.src_type); + auto dst_id = TypeIdSize(args.dst_type); + if (args.src_size / src_id != args.src_shape_size || args.dst_size / dst_id != args.dst_shape_size) { + MS_LOG(EXCEPTION) << "Invalid src or dst data size."; + } for (size_t idx = 0; idx != data_size; idx++) { SrcT src_data = static_cast(args.data)[idx]; static_cast(dst)[idx] = static_cast(src_data); } } +template +void TransDataSrc2Fp16(const TypeIdArgs &args, void *dst, const size_t data_size) { + auto src_id = TypeIdSize(args.src_type); + auto dst_id = TypeIdSize(args.dst_type); + if (args.src_size / src_id != args.src_shape_size || args.dst_size / dst_id != args.dst_shape_size) { + MS_LOG(EXCEPTION) << "Invalid src or dst data size."; + } + auto src_data = static_cast(args.data); + auto half_data = static_cast(dst); + for (size_t i = 0; i < data_size; i++) { + half_data[i] = Eigen::half(src_data[i]); + } +} + bool CastKernel(const TypeIdArgs &args, void *dst, const size_t data_size, const DataTypeTransMode mode) { switch (mode) { case FROM_FLOAT_TO_FLOAT16: device::FloatToHalf(dst, args.data, data_size); break; + case FROM_INT32_TO_FLOAT16: + TransDataSrc2Fp16(args, dst, data_size); + break; case FROM_FLOAT16_TO_FLOAT: device::HalfToFloat(dst, args.data, data_size); break; @@ -372,27 +394,27 @@ bool CheckArgs(const FormatArgs &args, size_t *size, size_t *total_size) { } bool TransDataType(const TypeIdArgs &args, void *result) { - MS_LOG(DEBUG) << "Begin trans datatype from " << TypeIdLabel(args.host_data_type) << " to " - << TypeIdLabel(args.device_data_type); + MS_LOG(DEBUG) << "Begin trans datatype from " << TypeIdLabel(args.src_type) << " to " << TypeIdLabel(args.dst_type); MS_EXCEPTION_IF_NULL(result); - std::pair type_info(args.host_data_type, args.device_data_type); + std::pair type_info(args.src_type, args.dst_type); auto iter = mode_map.find(type_info); if (iter == mode_map.end()) { - MS_LOG(ERROR) << "Unsupported datatype trans. src_type :" << TypeIdLabel(args.host_data_type) - << ", dst_type:" << TypeIdLabel(args.device_data_type); + MS_LOG(ERROR) << "Unsupported datatype trans. src_type :" << TypeIdLabel(args.src_type) + << ", dst_type:" << TypeIdLabel(args.dst_type); return false; } auto trans_mode = iter->second; - auto type_size = TypeIdSize(args.device_data_type); - if (type_size < 1) { - MS_LOG(ERROR) << "Invalid host data type."; + auto src_id = TypeIdSize(args.src_type); + auto dst_id = TypeIdSize(args.dst_type); + if (src_id < 1 || dst_id < 1) { + MS_LOG(ERROR) << "Invalid src or dst data type."; return false; } - if (args.host_shape_size < 1) { - MS_LOG(ERROR) << "Invalid host data size."; + if (args.src_size / src_id != args.src_shape_size || args.dst_size / dst_id != args.dst_shape_size) { + MS_LOG(ERROR) << "Invalid src or dst data size."; return false; } - if (!CastKernel(args, result, args.host_shape_size, trans_mode)) { + if (!CastKernel(args, result, args.dst_shape_size, trans_mode)) { MS_LOG(ERROR) << "Failed to trans datatype.."; return false; } diff --git a/mindspore/ccsrc/common/trans.h b/mindspore/ccsrc/common/trans.h index 054fa89a06..e6e81ed359 100644 --- a/mindspore/ccsrc/common/trans.h +++ b/mindspore/ccsrc/common/trans.h @@ -31,9 +31,12 @@ namespace mindspore { namespace trans { struct TypeIdArgs { const void *data; - size_t host_shape_size; // Multiply each dimension elements. [a, b, c, d] => a*b*c*d - TypeId host_data_type; - TypeId device_data_type; + size_t src_size; + size_t dst_size; + TypeId src_type; + TypeId dst_type; + size_t src_shape_size; + size_t dst_shape_size; }; struct FormatArgs { diff --git a/mindspore/ccsrc/device/ascend/ascend_device_address.cc b/mindspore/ccsrc/device/ascend/ascend_device_address.cc index 79241df612..df49400341 100644 --- a/mindspore/ccsrc/device/ascend/ascend_device_address.cc +++ b/mindspore/ccsrc/device/ascend/ascend_device_address.cc @@ -104,10 +104,10 @@ bool AscendDeviceAddress::SyncDeviceToHost(const std::vector &shape, size_t } else if (type_id_ == kNumberTypeFloat32 && type == kNumberTypeFloat64) { sync_ok = SyncDeviceToHostAndFloatToFloat64(host_ptr, size, ptr_, size_); } else { - auto shape_size = trans::ShapeSize(host_shape); + auto host_size = trans::ShapeSize(host_shape); auto host = std::vector(size_); SyncMemory(host.data(), ptr_, size_, RT_MEMCPY_DEVICE_TO_HOST); - const trans::TypeIdArgs type_args{host.data(), shape_size, type_id_, type}; + const trans::TypeIdArgs type_args{host.data(), size_, size, type_id_, type, host_size, host_size}; sync_ok = trans::TransDataType(type_args, host_ptr); if (!sync_ok) { MS_LOG(ERROR) << "trans data type failed."; @@ -153,14 +153,15 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const std::vector(size_); sync_ok = trans::TransFormatFromDeviceToHost(format_args, host.data()); if (!sync_ok) { - MS_LOG(ERROR) << "trans format failed."; + MS_LOG(ERROR) << "Trans format failed."; return false; } - auto shape_size = trans::ShapeSize(host_shape); - const trans::TypeIdArgs type_args{host.data(), shape_size, type_id_, type}; + auto host_size = trans::ShapeSize(host_shape); + auto device_size = trans::ShapeSize(device_shape); + const trans::TypeIdArgs type_args{host.data(), size_, size, type_id_, type, device_size, host_size}; sync_ok = trans::TransDataType(type_args, host_ptr); if (!sync_ok) { - MS_LOG(ERROR) << "trans format failed."; + MS_LOG(ERROR) << "Trans format failed."; return false; } } else { @@ -168,7 +169,7 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const std::vector &shape, size_t } else if (type_id_ == kNumberTypeFloat32 && type == kNumberTypeFloat64) { sync_ok = Float64ToFloatAndSyncHostToDevice(ptr_, size_, host_ptr, size); } else { - auto shape_size = trans::ShapeSize(host_shape); - const trans::TypeIdArgs type_args{host_ptr, shape_size, type, type_id_}; + auto host_size = trans::ShapeSize(host_shape); + const trans::TypeIdArgs type_args{host_ptr, size, size_, type, type_id_, host_size, host_size}; auto host_tmp = std::vector(size_); sync_ok = trans::TransDataType(type_args, host_tmp.data()); if (!sync_ok) { - MS_LOG(ERROR) << "trans data type failed."; + MS_LOG(ERROR) << "Trans data type failed."; return false; } SyncMemory(ptr_, host_tmp.data(), size_, RT_MEMCPY_HOST_TO_DEVICE); @@ -234,12 +235,13 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const std::vector(size_); sync_ok = trans::TransDataType(type_args, host_tmp.data()); if (!sync_ok) { - MS_LOG(ERROR) << "trans datatype failed."; + MS_LOG(ERROR) << "Trans datatype failed."; return false; } const trans::FormatArgs format_args{host_tmp.data(), size_, kOpFormat_NCHW, format_, @@ -247,7 +249,7 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const std::vector(size_); sync_ok = trans::TransFormat(format_args, dst_tmp.data()); if (!sync_ok) { - MS_LOG(ERROR) << "trans format failed."; + MS_LOG(ERROR) << "Trans format failed."; return false; } SyncMemory(ptr_, dst_tmp.data(), size_, RT_MEMCPY_HOST_TO_DEVICE); @@ -256,7 +258,7 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const std::vector(size_); sync_ok = trans::TransFormat(format_args, host_tmp.data()); if (!sync_ok) { - MS_LOG(ERROR) << "trans format failed."; + MS_LOG(ERROR) << "Trans format failed."; return false; } SyncMemory(ptr_, host_tmp.data(), size_, RT_MEMCPY_HOST_TO_DEVICE);