diff --git a/mindspore/ccsrc/common/trans.cc b/mindspore/ccsrc/common/trans.cc index 55e4761036..a9ce32c8df 100644 --- a/mindspore/ccsrc/common/trans.cc +++ b/mindspore/ccsrc/common/trans.cc @@ -14,11 +14,9 @@ * limitations under the License. */ #include "common/trans.h" -#include #include #include #include -#include "./securec.h" #include "common/utils.h" #include "session/anf_runtime_algorithm.h" #include "kernel/kernel.h" @@ -29,34 +27,7 @@ namespace mindspore { namespace trans { -namespace { -std::vector PaddingShapeTo4dByDefault(const std::vector &shape) { - std::vector shape_4d(4, 1); - switch (shape.size()) { - case 0: - return shape_4d; - case 1: - shape_4d[1] = shape[0]; - break; - case 2: - shape_4d[1] = shape[0]; - shape_4d[2] = shape[1]; - break; - case 3: - shape_4d[1] = shape[0]; - shape_4d[2] = shape[1]; - shape_4d[3] = shape[2]; - break; - case 4: - std::copy(shape.begin(), shape.end(), shape_4d.begin()); - break; - default: - MS_LOG(EXCEPTION) << "Unexpect shape size = " << shape.size(); - } - return shape_4d; -} -} // namespace -const size_t kNchwDims = 4; +enum kAxis : int { kN = 0, kC, kH, kW, kNchwDims, kNdhwc }; const std::map type_map = {{kNumberTypeBool, 1}, {kNumberTypeInt, 4}, {kNumberTypeInt8, 1}, {kNumberTypeInt16, 2}, {kNumberTypeInt32, 4}, {kNumberTypeInt64, 8}, {kNumberTypeUInt, 4}, {kNumberTypeUInt8, 1}, {kNumberTypeUInt16, 2}, @@ -84,7 +55,10 @@ inline void SetData(size_t size, bool pad_zero, size_t src_idx, size_t dst_idx, template T DivCeil(T n1, T n2) { - return (n2 != 0) ? (n1 - 1) / n2 + 1 : 0; + if (n2 != 0) { + return (n1 - 1) / n2 + 1; + } + return 0; } enum DataTypeTransMode { @@ -226,8 +200,7 @@ size_t CubeSizeByType(const TypeId data_type) { } size_t ShapeSize(const std::vector &shape) { - size_t product = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); - return product; + return std::accumulate(shape.begin(), shape.end(), IntToSize(1), std::multiplies()); } size_t TypeIdSize(const TypeId data_type) { @@ -239,57 +212,9 @@ size_t TypeIdSize(const TypeId data_type) { return unsupported_type_error; } -bool IsNeedPadding(const std::string &format, const size_t shape_size) { - if (shape_size == 0) { - return false; - } - if (format == kOpFormat_DEFAULT || format == kOpFormat_FRAC_NZ) { - return false; - } else if (shape_size < 4) { - return true; - } - return false; -} - -std::vector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) { - std::vector shape; - std::vector host_shape; - if (node->isa()) { - auto value_node = node->cast(); - auto node_value = value_node->value(); - auto tensor = node_value->cast(); - if (tensor == nullptr) { - MS_LOG(EXCEPTION) << " the node[ " << node->DebugString() << "]'s cannot convert "; - } - auto shape_temp = tensor->shape(); - (void)std::transform(shape_temp.begin(), shape_temp.end(), std::back_inserter(host_shape), IntToSize); - if (host_shape.empty()) { - host_shape.push_back(1); - } - } else { - host_shape = AnfAlgo::GetOutputInferShape(node, index); - } - if (trans::IsNeedPadding(AnfAlgo::GetOutputFormat(node, 0), host_shape.size())) { - host_shape = trans::PaddingShapeTo4d(host_shape, AnfAlgo::GetOutputReshapeType(node, 0)); - } - std::transform(host_shape.begin(), host_shape.end(), std::back_inserter(shape), SizeToInt); - return shape; -} - -std::vector PaddingShapeTo4d(const std::vector &shape, const std::vector &padding_axis) { - if (padding_axis.empty() || shape.size() != padding_axis.size()) { - return PaddingShapeTo4dByDefault(shape); - } - std::vector shape_4d(4, 1); - for (size_t index = 0; index < padding_axis.size(); index++) { - shape_4d[padding_axis[index]] = shape[index]; - } - return shape_4d; -} - namespace { bool CheckDims(const std::vector &shape) { - if (shape.size() != 4) { + if (shape.size() != kNchwDims) { MS_LOG(ERROR) << "Host shape dims shoud be 4"; return false; } @@ -308,10 +233,10 @@ std::vector NhwcDeviceShape(const std::vector &shape) { MS_LOG(EXCEPTION) << "Ccheck dims failed."; } std::vector device_shape; - device_shape.push_back(shape[0]); - device_shape.push_back(shape[2]); - device_shape.push_back(shape[3]); - device_shape.push_back(shape[1]); + device_shape.push_back(shape[kN]); + device_shape.push_back(shape[kH]); + device_shape.push_back(shape[kW]); + device_shape.push_back(shape[kC]); return device_shape; } @@ -320,10 +245,10 @@ std::vector HwchDeviceShape(const std::vector &shape) { MS_LOG(EXCEPTION) << "Check dims failed."; } std::vector device_shape; - device_shape.push_back(shape[2]); - device_shape.push_back(shape[3]); - device_shape.push_back(shape[1]); - device_shape.push_back(shape[0]); + device_shape.push_back(shape[kH]); + device_shape.push_back(shape[kW]); + device_shape.push_back(shape[kC]); + device_shape.push_back(shape[kN]); return device_shape; } @@ -332,9 +257,9 @@ std::vector FracZDeviceShape(const std::vector &shape) { MS_LOG(EXCEPTION) << "Check dims failed."; } std::vector device_shape; - size_t cout16 = ((shape[0] + kCubeSize - 1) / kCubeSize) * kCubeSize; - size_t cin16 = ((shape[1] + kCubeSize - 1) / kCubeSize) * kCubeSize; - device_shape.push_back(shape[2] * shape[3] * cin16 / kCubeSize); + const size_t cout16 = ((shape[kN] + kCubeSize - 1) / kCubeSize) * kCubeSize; + const size_t cin16 = ((shape[kC] + kCubeSize - 1) / kCubeSize) * kCubeSize; + device_shape.push_back(shape[kH] * shape[kW] * cin16 / kCubeSize); device_shape.push_back(cout16 / kCubeSize); device_shape.push_back(kCubeSize); device_shape.push_back(kCubeSize); @@ -346,12 +271,12 @@ std::vector Nc1hwc0DeviceShape(const std::vector &shape) { MS_LOG(EXCEPTION) << "Check dims failed."; } std::vector device_shape; - size_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize; - size_t C0 = kCubeSize; - device_shape.push_back(shape[0]); + const size_t C1 = (shape[kC] + kCubeSize - 1) / kCubeSize; + const size_t C0 = kCubeSize; + device_shape.push_back(shape[kN]); device_shape.push_back(C1); - device_shape.push_back(shape[2]); - device_shape.push_back(shape[3]); + device_shape.push_back(shape[kH]); + device_shape.push_back(shape[kW]); device_shape.push_back(C0); return device_shape; } @@ -361,10 +286,10 @@ std::vector C1hwncoc0DeviceShape(const std::vector &shape) { MS_LOG(EXCEPTION) << "Check dims failed."; } std::vector device_shape; - device_shape.push_back((shape[1] - 1) / kCubeSize + 1); - device_shape.push_back(shape[2]); - device_shape.push_back(shape[3]); - device_shape.push_back(shape[0]); + device_shape.push_back((shape[kC] - 1) / kCubeSize + 1); + device_shape.push_back(shape[kH]); + device_shape.push_back(shape[kW]); + device_shape.push_back(shape[kN]); device_shape.push_back(kCubeSize); device_shape.push_back(kCubeSize); return device_shape; @@ -375,9 +300,9 @@ std::vector FracZc04DeviceShape(const std::vector &shape) { MS_LOG(EXCEPTION) << "Check dims failed."; } std::vector device_shape; - size_t c0 = 4; - auto first_dim = DivCeil(c0 * shape.at(2) * shape.at(3), kCubeSize); - auto no = DivCeil(shape.at(0), kCubeSize); + const size_t c0 = 4; + auto first_dim = DivCeil(c0 * shape[kH] * shape[kW], kCubeSize); + auto no = DivCeil(shape.at(kN), kCubeSize); device_shape.push_back(first_dim); device_shape.push_back(no); device_shape.push_back(kCubeSize); @@ -390,24 +315,101 @@ std::vector Nc1hwc04DeviceShape(const std::vector &shape) { MS_LOG(EXCEPTION) << "Check dims failed."; } std::vector device_shape; - size_t C1 = 1; - size_t C0 = 4; - device_shape.push_back(shape[0]); + const size_t C1 = 1; + const size_t C0 = 4; + device_shape.push_back(shape[kN]); device_shape.push_back(C1); - device_shape.push_back(shape[2]); - device_shape.push_back(shape[3]); + device_shape.push_back(shape[kH]); + device_shape.push_back(shape[kW]); device_shape.push_back(C0); return device_shape; } std::vector NdhwcDeviceShape(const std::vector &shape) { - if (shape.size() < 5) { + if (shape.size() < kNdhwc) { MS_LOG(EXCEPTION) << "Shape dims must be 5 when format is ndhwc."; } return shape; } + +std::vector PaddingShapeTo4dByDefault(const std::vector &shape) { + std::vector shape_4d(kNchwDims, 1); + switch (shape.size()) { + case 0: + return shape_4d; + case 1: + shape_4d[kC] = shape[kN]; + break; + case 2: + shape_4d[kC] = shape[kN]; + shape_4d[kH] = shape[kC]; + break; + case 3: + shape_4d[kC] = shape[kN]; + shape_4d[kH] = shape[kC]; + shape_4d[kW] = shape[kH]; + break; + case 4: + std::copy(shape.begin(), shape.end(), shape_4d.begin()); + break; + default: + MS_LOG(EXCEPTION) << "Unexpect shape size = " << shape.size(); + } + return shape_4d; +} } // namespace +bool IsNeedPadding(const std::string &format, const size_t shape_size) { + if (shape_size == 0) { + return false; + } + if (format == kOpFormat_DEFAULT || format == kOpFormat_FRAC_NZ) { + return false; + } else if (shape_size < kNchwDims) { + return true; + } + return false; +} + +std::vector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) { + MS_EXCEPTION_IF_NULL(node); + std::vector shape; + std::vector host_shape; + if (node->isa()) { + auto value_node = node->cast(); + MS_EXCEPTION_IF_NULL(value_node); + auto node_value = value_node->value(); + MS_EXCEPTION_IF_NULL(node_value); + auto tensor = node_value->cast(); + if (tensor == nullptr) { + MS_LOG(EXCEPTION) << " The node[ " << node->DebugString() << "]'s cannot convert "; + } + auto shape_temp = tensor->shape(); + (void)std::transform(shape_temp.begin(), shape_temp.end(), std::back_inserter(host_shape), IntToSize); + if (host_shape.empty()) { + host_shape.push_back(1); + } + } else { + host_shape = AnfAlgo::GetOutputInferShape(node, index); + } + if (trans::IsNeedPadding(AnfAlgo::GetOutputFormat(node, 0), host_shape.size())) { + host_shape = trans::PaddingShapeTo4d(host_shape, AnfAlgo::GetOutputReshapeType(node, 0)); + } + std::transform(host_shape.begin(), host_shape.end(), std::back_inserter(shape), SizeToInt); + return shape; +} + +std::vector PaddingShapeTo4d(const std::vector &shape, const std::vector &padding_axis) { + if (padding_axis.empty() || shape.size() != padding_axis.size()) { + return PaddingShapeTo4dByDefault(shape); + } + std::vector shape_4d(kNchwDims, 1); + for (size_t index = 0; index < padding_axis.size(); index++) { + shape_4d[padding_axis[index]] = shape[index]; + } + return shape_4d; +} + std::vector TransShapeToDevice(const std::vector &shape, const std::string &format) { using DeviceShapeTransfer = std::function(const std::vector &)>; const std::map device_shape_map{{kOpFormat_NCHW, NchwDeviceShape}, @@ -439,7 +441,7 @@ std::vector TransShapeToDevice(const std::vector &shape, const s device_shape.push_back(kCubeSize); return device_shape; } - if (shape.size() != 4) { + if (shape.size() != kNchwDims) { MS_LOG(WARNING) << "Get Device Shape using a shape size is less than 4 ,should be Padding shape by Default firstly"; temp_shape = PaddingShapeTo4dByDefault(shape); } @@ -455,6 +457,8 @@ bool CheckArgs(const FormatArgs &args, size_t *size, size_t *total_size) { MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims; return false; } + MS_EXCEPTION_IF_NULL(size); + MS_EXCEPTION_IF_NULL(total_size); *size = TypeIdSize(args.src_data_type); if (*size < 1) { MS_LOG(ERROR) << "Illegal dtype."; @@ -540,10 +544,10 @@ bool NchwTo4D(const FormatArgs &args, void *result) { MS_LOG(ERROR) << "Check args failed."; return false; } - size_t n = args.host_shape[0]; - size_t c = args.host_shape[1]; - size_t h = args.host_shape[2]; - size_t w = args.host_shape[3]; + auto n = args.host_shape[kN]; + auto c = args.host_shape[kC]; + auto h = args.host_shape[kH]; + auto w = args.host_shape[kW]; for (size_t ni = 0; ni < n; ni++) { for (size_t ci = 0; ci < c; ci++) { for (size_t hi = 0; hi < h; hi++) { @@ -572,10 +576,10 @@ bool ToNchw(const FormatArgs &args, void *result) { MS_LOG(ERROR) << "Check args failed."; return false; } - size_t n = args.host_shape[0]; - size_t c = args.host_shape[1]; - size_t h = args.host_shape[2]; - size_t w = args.host_shape[3]; + auto n = args.host_shape[kN]; + auto c = args.host_shape[kC]; + auto h = args.host_shape[kH]; + auto w = args.host_shape[kW]; for (size_t ni = 0; ni < n; ni++) { for (size_t ci = 0; ci < c; ci++) { for (size_t hi = 0; hi < h; hi++) { @@ -602,32 +606,32 @@ bool NchwToFracZ(const FormatArgs &args, void *result) { MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims; return false; } - size_t size = TypeIdSize(args.src_data_type); + auto size = TypeIdSize(args.src_data_type); if (size < 1) { MS_LOG(ERROR) << "Illegal dtype."; return false; } - auto n = args.host_shape[0]; - auto c = args.host_shape[1]; - auto h = args.host_shape[2]; - auto w = args.host_shape[3]; + auto n = args.host_shape[kN]; + auto c = args.host_shape[kC]; + auto h = args.host_shape[kH]; + auto w = args.host_shape[kW]; - size_t c0 = CubeSizeByType(args.src_data_type); + auto c0 = CubeSizeByType(args.src_data_type); if (c0 < 1) { MS_LOG(ERROR) << "Illegal dtype."; return false; } - size_t c1 = DivCeil(c, c0); - size_t hw = h * w; - size_t chw = c * hw; - size_t hwc0 = hw * c0; - size_t nchw = n * chw; - - size_t hf_cnt = DivCeil(n, kCubeSize); - size_t vf_cnt = c1 * hw; - size_t fractal_ele_cnt = c0 * kCubeSize; - size_t total_ele_cnt = hf_cnt * vf_cnt * fractal_ele_cnt; - size_t dst_size = total_ele_cnt * size; + auto c1 = DivCeil(c, c0); + auto hw = h * w; + auto chw = c * hw; + auto hwc0 = hw * c0; + auto nchw = n * chw; + + auto hf_cnt = DivCeil(n, kCubeSize); + auto vf_cnt = c1 * hw; + auto fractal_ele_cnt = c0 * kCubeSize; + auto total_ele_cnt = hf_cnt * vf_cnt * fractal_ele_cnt; + auto dst_size = total_ele_cnt * size; if (dst_size != args.device_size) { MS_LOG(ERROR) << "Illegal total data size." << "dst size is :" << dst_size << "device size is :" << args.device_size; @@ -647,7 +651,7 @@ bool NchwToFracZ(const FormatArgs &args, void *result) { auto src_ni = hfi * kCubeSize + col; auto src_idx = src_row_offset + chw * col; auto dst_idx = gfi * fractal_ele_cnt + col * c0 + row; - auto pad_zero = (src_ni >= n || src_idx >= nchw || src_ci >= c) ? true : false; + auto pad_zero = src_ni >= n || src_idx >= nchw || src_ci >= c; SetData(size, pad_zero, src_idx, dst_idx, args, result); } } @@ -663,12 +667,12 @@ bool FracZToNchw(const FormatArgs &args, void *result) { MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims; return false; } - size_t size = TypeIdSize(args.src_data_type); + auto size = TypeIdSize(args.src_data_type); if (size < 1) { MS_LOG(ERROR) << "Illegal dtype."; return false; } - size_t total_size = ShapeSize(args.device_shape) * size; + auto total_size = ShapeSize(args.device_shape) * size; if (total_size != args.device_size) { MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size; return false; @@ -677,18 +681,16 @@ bool FracZToNchw(const FormatArgs &args, void *result) { auto n0 = args.device_shape.at(1); auto ni = args.device_shape.at(2); auto c0 = args.device_shape.at(3); - - auto n = args.host_shape[0]; - auto c = args.host_shape[1]; - auto h = args.host_shape[2]; - auto w = args.host_shape[3]; - - size_t nc = ni * n0; - size_t ncc0 = nc * c0; - size_t wncc0 = w * ncc0; - size_t hwncc0 = h * wncc0; - size_t hw = h * w; - size_t chw = c * hw; + auto n = args.host_shape[kN]; + auto c = args.host_shape[kC]; + auto h = args.host_shape[kH]; + auto w = args.host_shape[kW]; + auto nc = ni * n0; + auto ncc0 = nc * c0; + auto wncc0 = w * ncc0; + auto hwncc0 = h * wncc0; + auto hw = h * w; + auto chw = c * hw; for (size_t n_idx = 0; n_idx < n; n_idx++) { size_t n_head_addr = n_idx * chw; @@ -720,20 +722,18 @@ bool NchwToFracZc04(const FormatArgs &args, void *result) { MS_LOG(ERROR) << "Check args failed."; return false; } - size_t cube = kCubeSize; - size_t n = args.host_shape[0]; - size_t c = args.host_shape[1]; - size_t h = args.host_shape[2]; - size_t w = args.host_shape[3]; - - size_t c0 = 4; - size_t c1 = DivCeil(c, c0); - size_t hwc0 = h * w * c0; - size_t hwc = h * w * c; - size_t nhwc = n * h * w * c; - - size_t n_cnt = DivCeil(n, cube); - size_t v_cnt = DivCeil(h * w * c0 * c1, cube); + auto cube = kCubeSize; + auto n = args.host_shape[kN]; + auto c = args.host_shape[kC]; + auto h = args.host_shape[kH]; + auto w = args.host_shape[kW]; + const size_t c0 = 4; + auto c1 = DivCeil(c, c0); + auto hwc0 = h * w * c0; + auto hwc = h * w * c; + auto nhwc = n * h * w * c; + auto n_cnt = DivCeil(n, cube); + auto v_cnt = DivCeil(h * w * c0 * c1, cube); size_t dst_idx = 0; for (size_t vi = 0; vi < v_cnt; vi++) { @@ -929,7 +929,7 @@ bool NchwToNc1hwc0(const FormatArgs &args, void *result) { MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims; return false; } - size_t size = TypeIdSize(args.src_data_type); + auto size = TypeIdSize(args.src_data_type); if (size < 1) { MS_LOG(ERROR) << "Illegal dtype."; return false; @@ -940,20 +940,23 @@ bool NchwToNc1hwc0(const FormatArgs &args, void *result) { return false; } - auto n = args.host_shape[0]; - auto c = args.host_shape[1]; - auto h = args.host_shape[2]; - auto w = args.host_shape[3]; - size_t c0 = CubeSizeByType(args.src_data_type); + auto n = args.host_shape[kN]; + auto c = args.host_shape[kC]; + auto h = args.host_shape[kH]; + auto w = args.host_shape[kW]; + auto c0 = CubeSizeByType(args.src_data_type); if (c0 < 1) { MS_LOG(ERROR) << "Illegal dtype."; return false; } - size_t c1 = DivCeil(c, c0); - size_t hw = h * w; - size_t chw = c * hw; - size_t c1hwc0 = c1 * hw * c0; - size_t wc0 = w * c0; + if (args.device_format == kOpFormat_NC1HWC0_C04) { + c0 = 4; + } + auto c1 = DivCeil(c, c0); + auto hw = h * w; + auto chw = c * hw; + auto c1hwc0 = c1 * hw * c0; + auto wc0 = w * c0; for (size_t n_idx = 0; n_idx < n; n_idx++) { size_t n_head_addr = n_idx * c1hwc0; @@ -967,7 +970,7 @@ bool NchwToNc1hwc0(const FormatArgs &args, void *result) { size_t dst_idx = c0_idx + w_head_addr; size_t c_idx = c0_idx + c1_idx * c0; size_t src_idx = n_idx * chw + c_idx * hw + h_idx * w + w_idx; - auto pad_zero = (c_idx < c) ? false : true; + auto pad_zero = c_idx >= c; SetData(size, pad_zero, src_idx, dst_idx, args, result); } } @@ -984,29 +987,29 @@ bool Nc1hwc0ToNchw(const FormatArgs &args, void *result) { MS_LOG(ERROR) << "Invalid host shape, host shape dims:" << args.host_shape.size() << ", expect dims:" << kNchwDims; return false; } - size_t size = TypeIdSize(args.src_data_type); + auto size = TypeIdSize(args.src_data_type); if (size < 1) { MS_LOG(ERROR) << "Illegal dtype."; return false; } - size_t total_size = ShapeSize(args.device_shape) * size; + auto total_size = ShapeSize(args.device_shape) * size; if (total_size != args.device_size) { MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size; return false; } - auto n = args.host_shape[0]; - auto c = args.host_shape[1]; - auto h = args.host_shape[2]; - auto w = args.host_shape[3]; + auto n = args.host_shape[kN]; + auto c = args.host_shape[kC]; + auto h = args.host_shape[kH]; + auto w = args.host_shape[kW]; auto c1 = args.device_shape[1]; auto c0 = args.device_shape[4]; - size_t hw = h * w; - size_t chw = c * hw; - size_t wc0 = w * c0; - size_t hwc0 = h * wc0; - size_t c1hwc0 = c1 * hwc0; + auto hw = h * w; + auto chw = c * hw; + auto wc0 = w * c0; + auto hwc0 = h * wc0; + auto c1hwc0 = c1 * hwc0; for (size_t n_idx = 0; n_idx < n; n_idx++) { size_t n_head_addr = n_idx * chw; @@ -1037,13 +1040,15 @@ bool NchwToC1hwncoc0(const FormatArgs &args, void *result) { MS_LOG(ERROR) << "Check args failed."; return false; } - auto n = args.host_shape[0]; - auto c = args.host_shape[1]; - auto h = args.host_shape[2]; - auto w = args.host_shape[3]; + auto n = args.host_shape[kN]; + auto c = args.host_shape[kC]; + auto h = args.host_shape[kH]; + auto w = args.host_shape[kW]; + const int co_idx = 4; + const int c0_idx = 5; auto c1 = args.device_shape[0]; - auto co = args.device_shape[4]; - auto c0 = args.device_shape[5]; + auto co = args.device_shape[co_idx]; + auto c0 = args.device_shape[c0_idx]; for (size_t c1_i = 0; c1_i < c1; c1_i++) { for (size_t h_i = 0; h_i < h; h_i++) { @@ -1055,7 +1060,7 @@ bool NchwToC1hwncoc0(const FormatArgs &args, void *result) { co_i * c0 + c0_i; size_t c_i = c0_i + c1_i * c0; size_t src_idx = n_i * c * h * w + c_i * h * w + h_i * w + w_i; - auto pad_zero = (c_i < c && c0_i == co_i) ? false : true; + auto pad_zero = !(c_i < c && c0_i == co_i); SetData(size, pad_zero, src_idx, dst_idx, args, result); } } @@ -1076,12 +1081,14 @@ bool C1hwncoc0ToNchw(const FormatArgs &args, void *result) { MS_LOG(ERROR) << "Check args failed."; return false; } - auto n = args.host_shape[0]; - auto c = args.host_shape[1]; - auto h = args.host_shape[2]; - auto w = args.host_shape[3]; - auto co = args.device_shape[4]; - auto c0 = args.device_shape[5]; + auto n = args.host_shape[kN]; + auto c = args.host_shape[kC]; + auto h = args.host_shape[kH]; + auto w = args.host_shape[kW]; + const int co_idx = 4; + const int c0_idx = 5; + auto co = args.device_shape[co_idx]; + auto c0 = args.device_shape[c0_idx]; for (size_t n_i = 0; n_i < n; n_i++) { for (size_t c_i = 0; c_i < c; c_i++) { for (size_t h_i = 0; h_i < h; h_i++) {