|
|
|
@ -114,8 +114,11 @@ bool AscendDeviceAddress::SyncDeviceToHost(const std::vector<int> &shape, size_t
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else if (format_ == kOpFormat_NC1HWC0 || format_ == kOpFormat_FRAC_Z || format_ == kOpFormat_FRAC_NZ) {
|
|
|
|
|
sync_ok = SyncDeviceToHostAndConvertFormat(shape, size, type, host_ptr);
|
|
|
|
|
} else {
|
|
|
|
|
auto iter = kNeedTransFormatSet.find(format_);
|
|
|
|
|
if (iter != kNeedTransFormatSet.end()) {
|
|
|
|
|
sync_ok = ConvertFormatAndSyncHostToDevice(shape, size, type, host_ptr);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (!sync_ok) {
|
|
|
|
|
MS_LOG(ERROR) << "Not support to trans, dev_format:" << format_ << ", dev_type:" << TypeIdLabel(type_id_)
|
|
|
|
@ -199,8 +202,11 @@ bool AscendDeviceAddress::SyncHostToDevice(const std::vector<int> &shape, size_t
|
|
|
|
|
}
|
|
|
|
|
SyncMemory(ptr_, host_tmp.data(), size_, RT_MEMCPY_HOST_TO_DEVICE);
|
|
|
|
|
}
|
|
|
|
|
} else if (format_ == kOpFormat_NC1HWC0 || format_ == kOpFormat_FRAC_Z || format_ == kOpFormat_FRAC_NZ) {
|
|
|
|
|
sync_ok = ConvertFormatAndSyncHostToDevice(shape, size, type, host_ptr);
|
|
|
|
|
} else {
|
|
|
|
|
auto iter = kNeedTransFormatSet.find(format_);
|
|
|
|
|
if (iter != kNeedTransFormatSet.end()) {
|
|
|
|
|
sync_ok = ConvertFormatAndSyncHostToDevice(shape, size, type, host_ptr);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (!sync_ok) {
|
|
|
|
|
MS_LOG(ERROR) << "Not support to trans, dev_format:" << format_ << ", dev_type:" << TypeIdLabel(type_id_)
|
|
|
|
|