|
|
|
@ -104,10 +104,10 @@ bool AscendDeviceAddress::SyncDeviceToHost(const std::vector<int> &shape, size_t
|
|
|
|
|
} else if (type_id_ == kNumberTypeFloat32 && type == kNumberTypeFloat64) {
|
|
|
|
|
sync_ok = SyncDeviceToHostAndFloatToFloat64(host_ptr, size, ptr_, size_);
|
|
|
|
|
} else {
|
|
|
|
|
auto shape_size = trans::ShapeSize(host_shape);
|
|
|
|
|
auto host_size = trans::ShapeSize(host_shape);
|
|
|
|
|
auto host = std::vector<uint8_t>(size_);
|
|
|
|
|
SyncMemory(host.data(), ptr_, size_, RT_MEMCPY_DEVICE_TO_HOST);
|
|
|
|
|
const trans::TypeIdArgs type_args{host.data(), shape_size, type_id_, type};
|
|
|
|
|
const trans::TypeIdArgs type_args{host.data(), size_, size, type_id_, type, host_size, host_size};
|
|
|
|
|
sync_ok = trans::TransDataType(type_args, host_ptr);
|
|
|
|
|
if (!sync_ok) {
|
|
|
|
|
MS_LOG(ERROR) << "trans data type failed.";
|
|
|
|
@ -153,14 +153,15 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const std::vector<int
|
|
|
|
|
auto host = std::vector<uint8_t>(size_);
|
|
|
|
|
sync_ok = trans::TransFormatFromDeviceToHost(format_args, host.data());
|
|
|
|
|
if (!sync_ok) {
|
|
|
|
|
MS_LOG(ERROR) << "trans format failed.";
|
|
|
|
|
MS_LOG(ERROR) << "Trans format failed.";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto shape_size = trans::ShapeSize(host_shape);
|
|
|
|
|
const trans::TypeIdArgs type_args{host.data(), shape_size, type_id_, type};
|
|
|
|
|
auto host_size = trans::ShapeSize(host_shape);
|
|
|
|
|
auto device_size = trans::ShapeSize(device_shape);
|
|
|
|
|
const trans::TypeIdArgs type_args{host.data(), size_, size, type_id_, type, device_size, host_size};
|
|
|
|
|
sync_ok = trans::TransDataType(type_args, host_ptr);
|
|
|
|
|
if (!sync_ok) {
|
|
|
|
|
MS_LOG(ERROR) << "trans format failed.";
|
|
|
|
|
MS_LOG(ERROR) << "Trans format failed.";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
@ -168,7 +169,7 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const std::vector<int
|
|
|
|
|
host_shape, device_shape, type_id_};
|
|
|
|
|
sync_ok = trans::TransFormatFromDeviceToHost(format_args, host_ptr);
|
|
|
|
|
if (!sync_ok) {
|
|
|
|
|
MS_LOG(ERROR) << "trans format failed.";
|
|
|
|
|
MS_LOG(ERROR) << "Trans format failed.";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -192,12 +193,12 @@ bool AscendDeviceAddress::SyncHostToDevice(const std::vector<int> &shape, size_t
|
|
|
|
|
} else if (type_id_ == kNumberTypeFloat32 && type == kNumberTypeFloat64) {
|
|
|
|
|
sync_ok = Float64ToFloatAndSyncHostToDevice(ptr_, size_, host_ptr, size);
|
|
|
|
|
} else {
|
|
|
|
|
auto shape_size = trans::ShapeSize(host_shape);
|
|
|
|
|
const trans::TypeIdArgs type_args{host_ptr, shape_size, type, type_id_};
|
|
|
|
|
auto host_size = trans::ShapeSize(host_shape);
|
|
|
|
|
const trans::TypeIdArgs type_args{host_ptr, size, size_, type, type_id_, host_size, host_size};
|
|
|
|
|
auto host_tmp = std::vector<uint8_t>(size_);
|
|
|
|
|
sync_ok = trans::TransDataType(type_args, host_tmp.data());
|
|
|
|
|
if (!sync_ok) {
|
|
|
|
|
MS_LOG(ERROR) << "trans data type failed.";
|
|
|
|
|
MS_LOG(ERROR) << "Trans data type failed.";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
SyncMemory(ptr_, host_tmp.data(), size_, RT_MEMCPY_HOST_TO_DEVICE);
|
|
|
|
@ -234,12 +235,13 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const std::vector<int
|
|
|
|
|
device_shape = trans::TransShapeToDevice(host_shape, format_);
|
|
|
|
|
}
|
|
|
|
|
if (type_id_ != type) {
|
|
|
|
|
auto shape_size = trans::ShapeSize(host_shape);
|
|
|
|
|
const trans::TypeIdArgs type_args{host_ptr, shape_size, type, type_id_};
|
|
|
|
|
auto host_size = trans::ShapeSize(host_shape);
|
|
|
|
|
auto device_size = trans::ShapeSize(device_shape);
|
|
|
|
|
const trans::TypeIdArgs type_args{host_ptr, size, size_, type, type_id_, host_size, device_size};
|
|
|
|
|
auto host_tmp = std::vector<uint8_t>(size_);
|
|
|
|
|
sync_ok = trans::TransDataType(type_args, host_tmp.data());
|
|
|
|
|
if (!sync_ok) {
|
|
|
|
|
MS_LOG(ERROR) << "trans datatype failed.";
|
|
|
|
|
MS_LOG(ERROR) << "Trans datatype failed.";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
const trans::FormatArgs format_args{host_tmp.data(), size_, kOpFormat_NCHW, format_,
|
|
|
|
@ -247,7 +249,7 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const std::vector<int
|
|
|
|
|
auto dst_tmp = std::vector<uint8_t>(size_);
|
|
|
|
|
sync_ok = trans::TransFormat(format_args, dst_tmp.data());
|
|
|
|
|
if (!sync_ok) {
|
|
|
|
|
MS_LOG(ERROR) << "trans format failed.";
|
|
|
|
|
MS_LOG(ERROR) << "Trans format failed.";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
SyncMemory(ptr_, dst_tmp.data(), size_, RT_MEMCPY_HOST_TO_DEVICE);
|
|
|
|
@ -256,7 +258,7 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const std::vector<int
|
|
|
|
|
auto host_tmp = std::vector<uint8_t>(size_);
|
|
|
|
|
sync_ok = trans::TransFormat(format_args, host_tmp.data());
|
|
|
|
|
if (!sync_ok) {
|
|
|
|
|
MS_LOG(ERROR) << "trans format failed.";
|
|
|
|
|
MS_LOG(ERROR) << "Trans format failed.";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
SyncMemory(ptr_, host_tmp.data(), size_, RT_MEMCPY_HOST_TO_DEVICE);
|
|
|
|
|