|
|
|
@ -200,7 +200,7 @@ size_t CubeSizeByType(const TypeId data_type) {
|
|
|
|
|
namespace {
|
|
|
|
|
bool CheckDims(const std::vector<size_t> &shape) {
|
|
|
|
|
if (shape.size() != kNchwDims) {
|
|
|
|
|
MS_LOG(ERROR) << "Host shape dims shoud be 4";
|
|
|
|
|
MS_LOG(ERROR) << "Host shape dims should be 4";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
@ -370,7 +370,7 @@ std::vector<size_t> PaddingShapeTo4dByDefault(const std::vector<size_t> &shape)
|
|
|
|
|
std::copy(shape.begin(), shape.end(), shape_4d.begin());
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
MS_LOG(EXCEPTION) << "Unexpect shape size = " << shape.size();
|
|
|
|
|
MS_LOG(EXCEPTION) << "Unexpected shape size = " << shape.size();
|
|
|
|
|
}
|
|
|
|
|
return shape_4d;
|
|
|
|
|
}
|
|
|
|
@ -545,7 +545,8 @@ bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result) {
|
|
|
|
|
const std::map<std::string, FormatTransfer> format_trans_map{
|
|
|
|
|
{kOpFormat_FRAC_Z, FracZToNchw}, {kOpFormat_FRAC_NZ, FracNzToNchw},
|
|
|
|
|
{kOpFormat_NC1HWC0, Nc1hwc0ToNchw}, {kOpFormat_C1HWNCoC0, C1hwncoc0ToNchw},
|
|
|
|
|
{kOpFormat_NC1HWC0_C04, Nc1hwc04ToNchw}, {kOpFormat_NDC1HWC0, Ndc1hwc0ToNcdhw}};
|
|
|
|
|
{kOpFormat_NC1HWC0_C04, Nc1hwc04ToNchw}, {kOpFormat_NDC1HWC0, Ndc1hwc0ToNcdhw},
|
|
|
|
|
{kOpFormat_FRACTAL_Z_3D, FracZ3DToNcdhw}};
|
|
|
|
|
|
|
|
|
|
MS_LOG(DEBUG) << "Start trans format.";
|
|
|
|
|
if (abstract::TypeIdSize(args.src_data_type) < 1) {
|
|
|
|
@ -1248,5 +1249,119 @@ bool NcdhwToNdc1hwc0(const FormatArgs &args, void *result) {
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool NcdhwToFracZ3D(const FormatArgs &args, void *result) {
|
|
|
|
|
MS_LOG(DEBUG) << "Trans from ncdhw to frac_z_3d";
|
|
|
|
|
MS_EXCEPTION_IF_NULL(result);
|
|
|
|
|
|
|
|
|
|
if (args.host_shape.size() != 5) {
|
|
|
|
|
MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto size = abstract::TypeIdSize(args.src_data_type);
|
|
|
|
|
if (size < 1) {
|
|
|
|
|
MS_LOG(ERROR) << "Illegal dtype.";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto total_size = abstract::ShapeSize(args.device_shape) * size;
|
|
|
|
|
if (total_size != args.device_size) {
|
|
|
|
|
MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto n = args.host_shape[0];
|
|
|
|
|
auto c = args.host_shape[1];
|
|
|
|
|
auto d = args.host_shape[2];
|
|
|
|
|
auto h = args.host_shape[3];
|
|
|
|
|
auto w = args.host_shape[4];
|
|
|
|
|
|
|
|
|
|
auto n1n0 = DivCeil(n, kCubeSize) * kCubeSize;
|
|
|
|
|
auto c0 = CubeSizeByType(args.src_data_type);
|
|
|
|
|
auto c1 = DivCeil(c, c0);
|
|
|
|
|
auto hw = h * w;
|
|
|
|
|
auto dhw = d * hw;
|
|
|
|
|
auto cdhw = c * dhw;
|
|
|
|
|
auto n1n0c0 = n1n0 * c0;
|
|
|
|
|
auto wn1n0c0 = w * n1n0c0;
|
|
|
|
|
auto hwn1n0c0 = h * wn1n0c0;
|
|
|
|
|
auto dhwn1n0c0 = d * hwn1n0c0;
|
|
|
|
|
|
|
|
|
|
for (size_t c1_i = 0; c1_i < c1; c1_i++) {
|
|
|
|
|
for (size_t d_i = 0; d_i < d; d_i++) {
|
|
|
|
|
for (size_t h_i = 0; h_i < h; h_i++) {
|
|
|
|
|
for (size_t w_i = 0; w_i < w; w_i++) {
|
|
|
|
|
for (size_t n1n0_i = 0; n1n0_i < n1n0; n1n0_i++) {
|
|
|
|
|
for (size_t c0_i = 0; c0_i < c0; c0_i++) {
|
|
|
|
|
size_t dst_i = c1_i * dhwn1n0c0 + d_i * hwn1n0c0 + h_i * wn1n0c0 + w_i * n1n0c0 + n1n0_i * c0 + c0_i;
|
|
|
|
|
// ncdhw
|
|
|
|
|
size_t src_i = n1n0_i * cdhw + (c1_i * c0 + c0_i) * dhw + d_i * hw + h_i * w + w_i;
|
|
|
|
|
auto pad_zero = ((c1_i * c0 + c0_i) >= c) || (n1n0_i >= n);
|
|
|
|
|
SetData(size, pad_zero, src_i, dst_i, args, result);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool FracZ3DToNcdhw(const FormatArgs &args, void *result) {
|
|
|
|
|
MS_LOG(DEBUG) << "Trans from frac_z_3d to ncdhw";
|
|
|
|
|
MS_EXCEPTION_IF_NULL(result);
|
|
|
|
|
|
|
|
|
|
if (args.host_shape.size() != 5) {
|
|
|
|
|
MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto size = abstract::TypeIdSize(args.src_data_type);
|
|
|
|
|
if (size < 1) {
|
|
|
|
|
MS_LOG(ERROR) << "Illegal dtype.";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto total_size = abstract::ShapeSize(args.device_shape) * size;
|
|
|
|
|
if (total_size != args.device_size) {
|
|
|
|
|
MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto n = args.host_shape[0];
|
|
|
|
|
auto c = args.host_shape[1];
|
|
|
|
|
auto d = args.host_shape[2];
|
|
|
|
|
auto h = args.host_shape[3];
|
|
|
|
|
auto w = args.host_shape[4];
|
|
|
|
|
auto n0 = args.device_shape[1];
|
|
|
|
|
auto ni = args.device_shape[1];
|
|
|
|
|
auto c0 = args.device_shape[3];
|
|
|
|
|
auto hw = h * w;
|
|
|
|
|
auto dhw = d * hw;
|
|
|
|
|
auto cdhw = c * dhw;
|
|
|
|
|
auto nc = ni * n0;
|
|
|
|
|
auto ncc0 = nc * c0;
|
|
|
|
|
auto wncc0 = w * ncc0;
|
|
|
|
|
auto hwncc0 = h * wncc0;
|
|
|
|
|
auto dhwncc0 = d * hwncc0;
|
|
|
|
|
|
|
|
|
|
for (size_t n_i = 0; n_i < n; n_i++) {
|
|
|
|
|
size_t n_head = n_i * cdhw;
|
|
|
|
|
for (size_t c_i = 0; c_i < c; c_i++) {
|
|
|
|
|
size_t c_head = n_head + c_i * dhw;
|
|
|
|
|
for (size_t d_i = 0; d_i < d; d_i++) {
|
|
|
|
|
size_t d_head = c_head + d_i * hw;
|
|
|
|
|
for (size_t h_i = 0; h_i < h; h_i++) {
|
|
|
|
|
size_t h_head = d_head + h_i * w;
|
|
|
|
|
for (size_t w_i = 0; w_i < w; w_i++) {
|
|
|
|
|
size_t dst_i = h_head + w_i;
|
|
|
|
|
size_t c1_i = c_i / c0;
|
|
|
|
|
size_t c0_i = c_i % c0;
|
|
|
|
|
size_t nc_i = n_i;
|
|
|
|
|
size_t src_i = c1_i * dhwncc0 + d_i * hwncc0 + h_i * wncc0 + w_i * ncc0 + nc_i * c0 + c0_i;
|
|
|
|
|
SetData(size, false, src_i, dst_i, args, result);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
} // namespace trans
|
|
|
|
|
} // namespace mindspore
|
|
|
|
|