!414 add 6d format transfer

Merge pull request !414 from liubuyu/dev_lby
pull/414/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 41c969ab00

File diff suppressed because it is too large Load Diff

@ -63,10 +63,12 @@ bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result);
bool NchwToFracZ(const FormatArgs &args, void *result);
bool NchwToFracNz(const FormatArgs &args, void *result);
bool NchwToNc1hwc0(const FormatArgs &args, void *result);
bool NchwToC1hwncoc0(const FormatArgs &args, void *result);
// device to host
bool FracZToNchw(const FormatArgs &args, void *result);
bool FracNzToNchw(const FormatArgs &args, void *result);
bool Nc1hwc0ToNchw(const FormatArgs &args, void *result);
bool C1hwncoc0ToNchw(const FormatArgs &args, void *result);
} // namespace trans
} // namespace mindspore

@ -114,8 +114,11 @@ bool AscendDeviceAddress::SyncDeviceToHost(const std::vector<int> &shape, size_t
return false;
}
}
} else if (format_ == kOpFormat_NC1HWC0 || format_ == kOpFormat_FRAC_Z || format_ == kOpFormat_FRAC_NZ) {
sync_ok = SyncDeviceToHostAndConvertFormat(shape, size, type, host_ptr);
} else {
auto iter = kNeedTransFormatSet.find(format_);
if (iter != kNeedTransFormatSet.end()) {
sync_ok = ConvertFormatAndSyncHostToDevice(shape, size, type, host_ptr);
}
}
if (!sync_ok) {
MS_LOG(ERROR) << "Not support to trans, dev_format:" << format_ << ", dev_type:" << TypeIdLabel(type_id_)
@ -199,8 +202,11 @@ bool AscendDeviceAddress::SyncHostToDevice(const std::vector<int> &shape, size_t
}
SyncMemory(ptr_, host_tmp.data(), size_, RT_MEMCPY_HOST_TO_DEVICE);
}
} else if (format_ == kOpFormat_NC1HWC0 || format_ == kOpFormat_FRAC_Z || format_ == kOpFormat_FRAC_NZ) {
sync_ok = ConvertFormatAndSyncHostToDevice(shape, size, type, host_ptr);
} else {
auto iter = kNeedTransFormatSet.find(format_);
if (iter != kNeedTransFormatSet.end()) {
sync_ok = ConvertFormatAndSyncHostToDevice(shape, size, type, host_ptr);
}
}
if (!sync_ok) {
MS_LOG(ERROR) << "Not support to trans, dev_format:" << format_ << ", dev_type:" << TypeIdLabel(type_id_)

@ -186,8 +186,10 @@ constexpr auto kOpFormat_FRAC_Z = "FracZ";
constexpr auto kOpFormat_FRAC_NZ = "FRACTAL_NZ";
constexpr auto kOpFormat_C1HWNCoC0 = "C1HWNCoC0";
constexpr auto kOpFormat_NC1HWC0_C04 = "NC1HWC0_C04";
const std::set<std::string> k1DSupportFormat = {kOpFormat_DEFAULT, kOpFormat_NCHW, kOpFormat_NHWC,
kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0};
const std::set<std::string> k1DSupportFormat = {kOpFormat_DEFAULT, kOpFormat_NCHW, kOpFormat_NHWC,
kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0,
kOpFormat_C1HWNCoC0};
const std::set<std::string> k2DSupportFormat = {kOpFormat_DEFAULT, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_FRAC_Z,
kOpFormat_NC1KHKWHWC0};
const std::set<std::string> k3DSupportFormat = {kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0};

Loading…
Cancel
Save