From fc07cd908e85415e1e695fb63e2f8e9612b9f50c Mon Sep 17 00:00:00 2001 From: liubuyu Date: Fri, 17 Apr 2020 18:07:36 +0800 Subject: [PATCH] add 6d format transfer --- mindspore/ccsrc/common/trans.cc | 243 +++++++++++++++--- mindspore/ccsrc/common/trans.h | 2 + .../device/ascend/ascend_device_address.cc | 14 +- mindspore/ccsrc/utils/utils.h | 6 +- 4 files changed, 228 insertions(+), 37 deletions(-) diff --git a/mindspore/ccsrc/common/trans.cc b/mindspore/ccsrc/common/trans.cc index a2b9f7ef24..5c982166dd 100644 --- a/mindspore/ccsrc/common/trans.cc +++ b/mindspore/ccsrc/common/trans.cc @@ -231,7 +231,98 @@ std::vector PaddingShapeTo4d(const std::vector &shape, const std return shape_4d; } +namespace { +bool CheckDims(const std::vector &shape) { + if (shape.size() != 4) { + MS_LOG(ERROR) << "Host shape dims shoud be 4"; + return false; + } + return true; +} + +std::vector NchwDeviceShape(const std::vector &shape) { + if (!CheckDims(shape)) { + MS_LOG(EXCEPTION) << "Check dims failed."; + } + return shape; +} + +std::vector NhwcDeviceShape(const std::vector &shape) { + if (!CheckDims(shape)) { + MS_LOG(EXCEPTION) << "Ccheck dims failed."; + } + std::vector device_shape; + device_shape.push_back(shape[0]); + device_shape.push_back(shape[2]); + device_shape.push_back(shape[3]); + device_shape.push_back(shape[1]); + return device_shape; +} + +std::vector HwchDeviceShape(const std::vector &shape) { + if (!CheckDims(shape)) { + MS_LOG(EXCEPTION) << "Check dims failed."; + } + std::vector device_shape; + device_shape.push_back(shape[2]); + device_shape.push_back(shape[3]); + device_shape.push_back(shape[1]); + device_shape.push_back(shape[0]); + return device_shape; +} + +std::vector FracZDeviceShape(const std::vector &shape) { + if (!CheckDims(shape)) { + MS_LOG(EXCEPTION) << "Check dims failed."; + } + std::vector device_shape; + size_t cout16 = ((shape[0] + kCubeSize - 1) / kCubeSize) * kCubeSize; + size_t cin16 = ((shape[1] + kCubeSize - 1) / kCubeSize) * kCubeSize; + device_shape.push_back(shape[2] * shape[3] * cin16 / kCubeSize); + device_shape.push_back(cout16 / kCubeSize); + device_shape.push_back(kCubeSize); + device_shape.push_back(kCubeSize); + return device_shape; +} + +std::vector Nc1hwc0DeviceShape(const std::vector &shape) { + if (!CheckDims(shape)) { + MS_LOG(EXCEPTION) << "Check dims failed."; + } + std::vector device_shape; + size_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize; + size_t C0 = kCubeSize; + device_shape.push_back(shape[0]); + device_shape.push_back(C1); + device_shape.push_back(shape[2]); + device_shape.push_back(shape[3]); + device_shape.push_back(C0); + return device_shape; +} + +std::vector C1hwncoc0DeviceShape(const std::vector &shape) { + if (!CheckDims(shape)) { + MS_LOG(EXCEPTION) << "Check dims failed."; + } + std::vector device_shape; + device_shape.push_back((shape[1] - 1) / kCubeSize + 1); + device_shape.push_back(shape[2]); + device_shape.push_back(shape[3]); + device_shape.push_back(shape[0]); + device_shape.push_back(kCubeSize); + device_shape.push_back(kCubeSize); + return device_shape; +} +} // namespace + std::vector TransShapeToDevice(const std::vector &shape, const std::string &format) { + using DeviceShapeTransfer = std::function(const std::vector &)>; + const std::map device_shape_map{ + {kOpFormat_NCHW, NchwDeviceShape}, {kOpFormat_NHWC, NhwcDeviceShape}, + {kOpFormat_HWCN, HwchDeviceShape}, {kOpFormat_FRAC_Z, FracZDeviceShape}, + {kOpFormat_NC1HWC0, Nc1hwc0DeviceShape}, {kOpFormat_C1HWNCoC0, C1hwncoc0DeviceShape}, + }; + if (format == kOpFormat_ND || format == kOpFormat_DEFAULT) { return shape; } @@ -255,37 +346,31 @@ std::vector TransShapeToDevice(const std::vector &shape, const s MS_LOG(WARNING) << "Get Device Shape using a shape size is less than 4 ,should be Padding shape by Default firstly"; temp_shape = PaddingShapeTo4dByDefault(shape); } - if (format == kOpFormat_NC1HWC0) { - size_t C1 = (temp_shape[1] + kCubeSize - 1) / kCubeSize; - size_t C0 = kCubeSize; - device_shape.push_back(temp_shape[0]); - device_shape.push_back(C1); - device_shape.push_back(temp_shape[2]); - device_shape.push_back(temp_shape[3]); - device_shape.push_back(C0); - return device_shape; - } else if (format == kOpFormat_FRAC_Z) { - size_t cout16 = ((temp_shape[0] + kCubeSize - 1) / kCubeSize) * kCubeSize; - size_t cin16 = ((temp_shape[1] + kCubeSize - 1) / kCubeSize) * kCubeSize; - device_shape.push_back(temp_shape[2] * temp_shape[3] * cin16 / kCubeSize); - device_shape.push_back(cout16 / kCubeSize); - device_shape.push_back(kCubeSize); - device_shape.push_back(kCubeSize); - return device_shape; - } else if (format == kOpFormat_NHWC) { - device_shape.push_back(temp_shape[0]); - device_shape.push_back(temp_shape[2]); - device_shape.push_back(temp_shape[3]); - device_shape.push_back(temp_shape[1]); - return device_shape; - } else if (format == kOpFormat_HWCN) { - return {temp_shape[2], temp_shape[3], temp_shape[1], temp_shape[0]}; - } else if (format == kOpFormat_NCHW) { - return temp_shape; + auto iter = device_shape_map.find(format); + if (iter != device_shape_map.end()) { + return iter->second(temp_shape); } MS_LOG(EXCEPTION) << "Unexpected format[" << format << "]"; } +bool CheckArgs(const FormatArgs &args, size_t *size, size_t *total_size) { + if (args.host_shape.size() != kNchwDims) { + MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims; + return false; + } + *size = TypeIdSize(args.src_data_type); + if (*size < 1) { + MS_LOG(ERROR) << "Illegal dtype."; + return false; + } + *total_size = ShapeSize(args.device_shape) * (*size); + if (*total_size != args.device_size) { + MS_LOG(ERROR) << "Illegal total data size, total_size:" << *total_size << ", device_size:" << args.device_size; + return false; + } + return true; +} + bool TransDataType(const TypeIdArgs &args, void *result) { MS_LOG(DEBUG) << "Begin trans datatype from " << TypeIdLabel(args.host_data_type) << " to " << TypeIdLabel(args.device_data_type); @@ -320,13 +405,14 @@ bool TransFormat(const FormatArgs &args, void *result) { MS_LOG(ERROR) << "Invalid datatype.."; return false; } - if ((args.host_format == kOpFormat_NCHW || args.host_format == kOpFormat_ND) && - args.device_format == kOpFormat_FRAC_Z) { + if (args.device_format == kOpFormat_FRAC_Z) { return NchwToFracZ(args, result); } else if (args.device_format == kOpFormat_FRAC_NZ) { return NchwToFracNz(args, result); } else if (args.device_format == kOpFormat_NC1HWC0) { return NchwToNc1hwc0(args, result); + } else if (args.device_format == kOpFormat_C1HWNCoC0) { + return NchwToC1hwncoc0(args, result); } return true; } @@ -337,13 +423,14 @@ bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result) { MS_LOG(ERROR) << "Invalid datatype.."; return false; } - if ((args.host_format == kOpFormat_NCHW || args.host_format == kOpFormat_ND) && - args.device_format == kOpFormat_FRAC_Z) { + if (args.device_format == kOpFormat_FRAC_Z) { return FracZToNchw(args, result); } else if (args.device_format == kOpFormat_FRAC_NZ) { return FracNzToNchw(args, result); } else if (args.device_format == kOpFormat_NC1HWC0) { return Nc1hwc0ToNchw(args, result); + } else if (args.device_format == kOpFormat_C1HWNCoC0) { + return C1hwncoc0ToNchw(args, result); } return true; } @@ -801,5 +888,99 @@ bool Nc1hwc0ToNchw(const FormatArgs &args, void *result) { } return true; } + +bool NchwToC1hwncoc0(const FormatArgs &args, void *result) { + // trans nchw to c1hwncoc0 + MS_LOG(DEBUG) << "Trans format from nchw to c1hwncoc0."; + MS_EXCEPTION_IF_NULL(result); + size_t size = 0; + size_t total_size = 0; + if (!CheckArgs(args, &size, &total_size)) { + MS_LOG(ERROR) << "Check args failed."; + return false; + } + auto n = args.host_shape[0]; + auto c = args.host_shape[1]; + auto h = args.host_shape[2]; + auto w = args.host_shape[3]; + auto c1 = args.device_shape[0]; + auto co = args.device_shape[4]; + auto c0 = args.device_shape[5]; + for (size_t c1_i = 0; c1_i < c1; c1_i++) { + for (size_t h_i = 0; h_i < h; h_i++) { + for (size_t w_i = 0; w_i < w; w_i++) { + for (size_t n_i = 0; n_i < n; n_i++) { + for (size_t co_i = 0; co_i < co; co_i++) { + for (size_t c0_i = 0; c0_i < c0; c0_i++) { + size_t dst_offset = (c1_i * h * w * n * co * c0 + h_i * w * n * co * c0 + w_i * n * co * c0 + + n_i * co * c0 + co_i * c0 + c0_i) * + size; + size_t protected_size = total_size - dst_offset < static_cast(SECUREC_MEM_MAX_LEN) + ? total_size - dst_offset + : static_cast(SECUREC_MEM_MAX_LEN); + size_t c_i = c0_i + c1_i * c0; + size_t src_offset = (n_i * c * h * w + c_i * h * w + h_i * w + w_i) * size; + error_t ret; + if (c_i < c && c0_i == co_i) { + ret = memcpy_s(static_cast(result) + dst_offset, protected_size, + static_cast(args.data) + src_offset, size); + } else { + ret = memset_s(static_cast(result) + dst_offset, protected_size, 0, size); + } + if (ret != EOK) { + MS_LOG(ERROR) << "Failed to operate the dst memory, error-code:" << ret; + return false; + } + } + } + } + } + } + } + return true; +} + +bool C1hwncoc0ToNchw(const FormatArgs &args, void *result) { + // trans c1hwncoc0 to nchw + MS_LOG(DEBUG) << "Trans format from c1hwncoc0 to nchw"; + MS_EXCEPTION_IF_NULL(result); + size_t size = 0; + size_t total_size = 0; + if (!CheckArgs(args, &size, &total_size)) { + MS_LOG(ERROR) << "Check args failed."; + return false; + } + auto n = args.host_shape[0]; + auto c = args.host_shape[1]; + auto h = args.host_shape[2]; + auto w = args.host_shape[3]; + auto co = args.device_shape[4]; + auto c0 = args.device_shape[5]; + for (size_t n_i = 0; n_i < n; n_i++) { + for (size_t c_i = 0; c_i < c; c_i++) { + for (size_t h_i = 0; h_i < h; h_i++) { + for (size_t w_i = 0; w_i < w; w_i++) { + size_t dst_offset = (n_i * c * h * w + c_i * h * w + h_i * w + w_i) * size; + size_t c1_i = c_i / kCubeSize; + size_t c0_i = c_i % kCubeSize; + size_t co_i = c0_i; + size_t src_offset = (c1_i * h * w * n * co * c0 + h_i * w * n * co * c0 + w_i * n * co * c0 + n_i * co * c0 + + co_i * c0 + c0_i) * + size; + size_t protected_size = total_size - dst_offset < static_cast(SECUREC_MEM_MAX_LEN) + ? total_size - dst_offset + : static_cast(SECUREC_MEM_MAX_LEN); + auto ret = memcpy_s(static_cast(result) + dst_offset, protected_size, + static_cast(args.data) + src_offset, size); + if (ret != EOK) { + MS_LOG(ERROR) << "Failed to operate the dst memory, error-code:" << ret; + return false; + } + } + } + } + } + return true; +} } // namespace trans } // namespace mindspore diff --git a/mindspore/ccsrc/common/trans.h b/mindspore/ccsrc/common/trans.h index 4bebdde814..054fa89a06 100644 --- a/mindspore/ccsrc/common/trans.h +++ b/mindspore/ccsrc/common/trans.h @@ -63,10 +63,12 @@ bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result); bool NchwToFracZ(const FormatArgs &args, void *result); bool NchwToFracNz(const FormatArgs &args, void *result); bool NchwToNc1hwc0(const FormatArgs &args, void *result); +bool NchwToC1hwncoc0(const FormatArgs &args, void *result); // device to host bool FracZToNchw(const FormatArgs &args, void *result); bool FracNzToNchw(const FormatArgs &args, void *result); bool Nc1hwc0ToNchw(const FormatArgs &args, void *result); +bool C1hwncoc0ToNchw(const FormatArgs &args, void *result); } // namespace trans } // namespace mindspore diff --git a/mindspore/ccsrc/device/ascend/ascend_device_address.cc b/mindspore/ccsrc/device/ascend/ascend_device_address.cc index 69d1918163..f0a30e4c42 100644 --- a/mindspore/ccsrc/device/ascend/ascend_device_address.cc +++ b/mindspore/ccsrc/device/ascend/ascend_device_address.cc @@ -114,8 +114,11 @@ bool AscendDeviceAddress::SyncDeviceToHost(const std::vector &shape, size_t return false; } } - } else if (format_ == kOpFormat_NC1HWC0 || format_ == kOpFormat_FRAC_Z || format_ == kOpFormat_FRAC_NZ) { - sync_ok = SyncDeviceToHostAndConvertFormat(shape, size, type, host_ptr); + } else { + auto iter = kNeedTransFormatSet.find(format_); + if (iter != kNeedTransFormatSet.end()) { + sync_ok = ConvertFormatAndSyncHostToDevice(shape, size, type, host_ptr); + } } if (!sync_ok) { MS_LOG(ERROR) << "Not support to trans, dev_format:" << format_ << ", dev_type:" << TypeIdLabel(type_id_) @@ -199,8 +202,11 @@ bool AscendDeviceAddress::SyncHostToDevice(const std::vector &shape, size_t } SyncMemory(ptr_, host_tmp.data(), size_, RT_MEMCPY_HOST_TO_DEVICE); } - } else if (format_ == kOpFormat_NC1HWC0 || format_ == kOpFormat_FRAC_Z || format_ == kOpFormat_FRAC_NZ) { - sync_ok = ConvertFormatAndSyncHostToDevice(shape, size, type, host_ptr); + } else { + auto iter = kNeedTransFormatSet.find(format_); + if (iter != kNeedTransFormatSet.end()) { + sync_ok = ConvertFormatAndSyncHostToDevice(shape, size, type, host_ptr); + } } if (!sync_ok) { MS_LOG(ERROR) << "Not support to trans, dev_format:" << format_ << ", dev_type:" << TypeIdLabel(type_id_) diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 9405d0d334..10ef4abf62 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -186,8 +186,10 @@ constexpr auto kOpFormat_FRAC_Z = "FracZ"; constexpr auto kOpFormat_FRAC_NZ = "FRACTAL_NZ"; constexpr auto kOpFormat_C1HWNCoC0 = "C1HWNCoC0"; constexpr auto kOpFormat_NC1HWC0_C04 = "NC1HWC0_C04"; -const std::set k1DSupportFormat = {kOpFormat_DEFAULT, kOpFormat_NCHW, kOpFormat_NHWC, - kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0}; +const std::set k1DSupportFormat = {kOpFormat_DEFAULT, kOpFormat_NCHW, kOpFormat_NHWC, + kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0, + kOpFormat_C1HWNCoC0}; + const std::set k2DSupportFormat = {kOpFormat_DEFAULT, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0}; const std::set k3DSupportFormat = {kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0};