Update the index computation in the frac_z 3D transform functions: swap the order of the c1 and d dimensions

pull/12723/head
liubuyu 4 years ago
parent d1b1ad8ad0
commit 6f4b1880df

@ -1284,15 +1284,15 @@ bool NcdhwToFracZ3D(const FormatArgs &args, void *result) {
auto n1n0c0 = n1n0 * c0;
auto wn1n0c0 = w * n1n0c0;
auto hwn1n0c0 = h * wn1n0c0;
auto dhwn1n0c0 = d * hwn1n0c0;
auto c1hwn1n0c0 = c1 * hwn1n0c0;
for (size_t c1_i = 0; c1_i < c1; c1_i++) {
for (size_t d_i = 0; d_i < d; d_i++) {
for (size_t d_i = 0; d_i < d; d_i++) {
for (size_t c1_i = 0; c1_i < c1; c1_i++) {
for (size_t h_i = 0; h_i < h; h_i++) {
for (size_t w_i = 0; w_i < w; w_i++) {
for (size_t n1n0_i = 0; n1n0_i < n1n0; n1n0_i++) {
for (size_t c0_i = 0; c0_i < c0; c0_i++) {
size_t dst_i = c1_i * dhwn1n0c0 + d_i * hwn1n0c0 + h_i * wn1n0c0 + w_i * n1n0c0 + n1n0_i * c0 + c0_i;
auto dst_i = d_i * c1hwn1n0c0 + c1_i * hwn1n0c0 + h_i * wn1n0c0 + w_i * n1n0c0 + n1n0_i * c0 + c0_i;
// ncdhw
size_t src_i = n1n0_i * cdhw + (c1_i * c0 + c0_i) * dhw + d_i * hw + h_i * w + w_i;
auto pad_zero = ((c1_i * c0 + c0_i) >= c) || (n1n0_i >= n);
@ -1329,17 +1329,16 @@ bool FracZ3DToNcdhw(const FormatArgs &args, void *result) {
auto d = args.host_shape[2];
auto h = args.host_shape[3];
auto w = args.host_shape[4];
auto n0 = args.device_shape[1];
auto ni = args.device_shape[2];
auto c0 = args.device_shape[3];
auto c1 = DivCeil(c, kCubeSize);
auto n1n0 = DivCeil(n, kCubeSize) * kCubeSize;
auto n1n0c0 = n1n0 * c0;
auto wn1n0c0 = w * n1n0c0;
auto hwn1n0c0 = h * wn1n0c0;
auto c1hwn1n0c0 = c1 * hwn1n0c0;
auto hw = h * w;
auto dhw = d * hw;
auto cdhw = c * dhw;
auto nc = ni * n0;
auto ncc0 = nc * c0;
auto wncc0 = w * ncc0;
auto hwncc0 = h * wncc0;
auto dhwncc0 = d * hwncc0;
for (size_t n_i = 0; n_i < n; n_i++) {
size_t n_head = n_i * cdhw;
@ -1354,7 +1353,7 @@ bool FracZ3DToNcdhw(const FormatArgs &args, void *result) {
size_t c1_i = c_i / c0;
size_t c0_i = c_i % c0;
size_t nc_i = n_i;
size_t src_i = c1_i * dhwncc0 + d_i * hwncc0 + h_i * wncc0 + w_i * ncc0 + nc_i * c0 + c0_i;
size_t src_i = d_i * c1hwn1n0c0 + c1_i * hwn1n0c0 + h_i * wn1n0c0 + w_i * n1n0c0 + nc_i * c0 + c0_i;
SetData(size, false, src_i, dst_i, args, result);
}
}

Loading…
Cancel
Save