|
|
|
@ -1284,15 +1284,15 @@ bool NcdhwToFracZ3D(const FormatArgs &args, void *result) {
|
|
|
|
|
auto n1n0c0 = n1n0 * c0;
|
|
|
|
|
auto wn1n0c0 = w * n1n0c0;
|
|
|
|
|
auto hwn1n0c0 = h * wn1n0c0;
|
|
|
|
|
auto dhwn1n0c0 = d * hwn1n0c0;
|
|
|
|
|
auto c1hwn1n0c0 = c1 * hwn1n0c0;
|
|
|
|
|
|
|
|
|
|
for (size_t c1_i = 0; c1_i < c1; c1_i++) {
|
|
|
|
|
for (size_t d_i = 0; d_i < d; d_i++) {
|
|
|
|
|
for (size_t d_i = 0; d_i < d; d_i++) {
|
|
|
|
|
for (size_t c1_i = 0; c1_i < c1; c1_i++) {
|
|
|
|
|
for (size_t h_i = 0; h_i < h; h_i++) {
|
|
|
|
|
for (size_t w_i = 0; w_i < w; w_i++) {
|
|
|
|
|
for (size_t n1n0_i = 0; n1n0_i < n1n0; n1n0_i++) {
|
|
|
|
|
for (size_t c0_i = 0; c0_i < c0; c0_i++) {
|
|
|
|
|
size_t dst_i = c1_i * dhwn1n0c0 + d_i * hwn1n0c0 + h_i * wn1n0c0 + w_i * n1n0c0 + n1n0_i * c0 + c0_i;
|
|
|
|
|
auto dst_i = d_i * c1hwn1n0c0 + c1_i * hwn1n0c0 + h_i * wn1n0c0 + w_i * n1n0c0 + n1n0_i * c0 + c0_i;
|
|
|
|
|
// ncdhw
|
|
|
|
|
size_t src_i = n1n0_i * cdhw + (c1_i * c0 + c0_i) * dhw + d_i * hw + h_i * w + w_i;
|
|
|
|
|
auto pad_zero = ((c1_i * c0 + c0_i) >= c) || (n1n0_i >= n);
|
|
|
|
@ -1329,17 +1329,16 @@ bool FracZ3DToNcdhw(const FormatArgs &args, void *result) {
|
|
|
|
|
auto d = args.host_shape[2];
|
|
|
|
|
auto h = args.host_shape[3];
|
|
|
|
|
auto w = args.host_shape[4];
|
|
|
|
|
auto n0 = args.device_shape[1];
|
|
|
|
|
auto ni = args.device_shape[2];
|
|
|
|
|
auto c0 = args.device_shape[3];
|
|
|
|
|
auto c1 = DivCeil(c, kCubeSize);
|
|
|
|
|
auto n1n0 = DivCeil(n, kCubeSize) * kCubeSize;
|
|
|
|
|
auto n1n0c0 = n1n0 * c0;
|
|
|
|
|
auto wn1n0c0 = w * n1n0c0;
|
|
|
|
|
auto hwn1n0c0 = h * wn1n0c0;
|
|
|
|
|
auto c1hwn1n0c0 = c1 * hwn1n0c0;
|
|
|
|
|
auto hw = h * w;
|
|
|
|
|
auto dhw = d * hw;
|
|
|
|
|
auto cdhw = c * dhw;
|
|
|
|
|
auto nc = ni * n0;
|
|
|
|
|
auto ncc0 = nc * c0;
|
|
|
|
|
auto wncc0 = w * ncc0;
|
|
|
|
|
auto hwncc0 = h * wncc0;
|
|
|
|
|
auto dhwncc0 = d * hwncc0;
|
|
|
|
|
|
|
|
|
|
for (size_t n_i = 0; n_i < n; n_i++) {
|
|
|
|
|
size_t n_head = n_i * cdhw;
|
|
|
|
@ -1354,7 +1353,7 @@ bool FracZ3DToNcdhw(const FormatArgs &args, void *result) {
|
|
|
|
|
size_t c1_i = c_i / c0;
|
|
|
|
|
size_t c0_i = c_i % c0;
|
|
|
|
|
size_t nc_i = n_i;
|
|
|
|
|
size_t src_i = c1_i * dhwncc0 + d_i * hwncc0 + h_i * wncc0 + w_i * ncc0 + nc_i * c0 + c0_i;
|
|
|
|
|
size_t src_i = d_i * c1hwn1n0c0 + c1_i * hwn1n0c0 + h_i * wn1n0c0 + w_i * n1n0c0 + nc_i * c0 + c0_i;
|
|
|
|
|
SetData(size, false, src_i, dst_i, args, result);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|