!11536 support ncdhw to frac_z_3d data trans at host

From: @liubuyu
Reviewed-by: @kisnwang,@zhoufeng54
Signed-off-by: @zhoufeng54
pull/11536/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 59a277756e

@ -200,7 +200,7 @@ size_t CubeSizeByType(const TypeId data_type) {
namespace {
bool CheckDims(const std::vector<size_t> &shape) {
if (shape.size() != kNchwDims) {
MS_LOG(ERROR) << "Host shape dims shoud be 4";
MS_LOG(ERROR) << "Host shape dims should be 4";
return false;
}
return true;
@ -370,7 +370,7 @@ std::vector<size_t> PaddingShapeTo4dByDefault(const std::vector<size_t> &shape)
std::copy(shape.begin(), shape.end(), shape_4d.begin());
break;
default:
MS_LOG(EXCEPTION) << "Unexpect shape size = " << shape.size();
MS_LOG(EXCEPTION) << "Unexpected shape size = " << shape.size();
}
return shape_4d;
}
@ -545,7 +545,8 @@ bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result) {
const std::map<std::string, FormatTransfer> format_trans_map{
{kOpFormat_FRAC_Z, FracZToNchw}, {kOpFormat_FRAC_NZ, FracNzToNchw},
{kOpFormat_NC1HWC0, Nc1hwc0ToNchw}, {kOpFormat_C1HWNCoC0, C1hwncoc0ToNchw},
{kOpFormat_NC1HWC0_C04, Nc1hwc04ToNchw}, {kOpFormat_NDC1HWC0, Ndc1hwc0ToNcdhw}};
{kOpFormat_NC1HWC0_C04, Nc1hwc04ToNchw}, {kOpFormat_NDC1HWC0, Ndc1hwc0ToNcdhw},
{kOpFormat_FRACTAL_Z_3D, FracZ3DToNcdhw}};
MS_LOG(DEBUG) << "Start trans format.";
if (abstract::TypeIdSize(args.src_data_type) < 1) {
@ -1248,5 +1249,119 @@ bool NcdhwToNdc1hwc0(const FormatArgs &args, void *result) {
}
return true;
}
bool NcdhwToFracZ3D(const FormatArgs &args, void *result) {
MS_LOG(DEBUG) << "Trans from ncdhw to frac_z_3d";
MS_EXCEPTION_IF_NULL(result);
if (args.host_shape.size() != 5) {
MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
return false;
}
auto size = abstract::TypeIdSize(args.src_data_type);
if (size < 1) {
MS_LOG(ERROR) << "Illegal dtype.";
return false;
}
auto total_size = abstract::ShapeSize(args.device_shape) * size;
if (total_size != args.device_size) {
MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
return false;
}
auto n = args.host_shape[0];
auto c = args.host_shape[1];
auto d = args.host_shape[2];
auto h = args.host_shape[3];
auto w = args.host_shape[4];
auto n1n0 = DivCeil(n, kCubeSize) * kCubeSize;
auto c0 = CubeSizeByType(args.src_data_type);
auto c1 = DivCeil(c, c0);
auto hw = h * w;
auto dhw = d * hw;
auto cdhw = c * dhw;
auto n1n0c0 = n1n0 * c0;
auto wn1n0c0 = w * n1n0c0;
auto hwn1n0c0 = h * wn1n0c0;
auto dhwn1n0c0 = d * hwn1n0c0;
for (size_t c1_i = 0; c1_i < c1; c1_i++) {
for (size_t d_i = 0; d_i < d; d_i++) {
for (size_t h_i = 0; h_i < h; h_i++) {
for (size_t w_i = 0; w_i < w; w_i++) {
for (size_t n1n0_i = 0; n1n0_i < n1n0; n1n0_i++) {
for (size_t c0_i = 0; c0_i < c0; c0_i++) {
size_t dst_i = c1_i * dhwn1n0c0 + d_i * hwn1n0c0 + h_i * wn1n0c0 + w_i * n1n0c0 + n1n0_i * c0 + c0_i;
// ncdhw
size_t src_i = n1n0_i * cdhw + (c1_i * c0 + c0_i) * dhw + d_i * hw + h_i * w + w_i;
auto pad_zero = ((c1_i * c0 + c0_i) >= c) || (n1n0_i >= n);
SetData(size, pad_zero, src_i, dst_i, args, result);
}
}
}
}
}
}
return true;
}
bool FracZ3DToNcdhw(const FormatArgs &args, void *result) {
MS_LOG(DEBUG) << "Trans from frac_z_3d to ncdhw";
MS_EXCEPTION_IF_NULL(result);
if (args.host_shape.size() != 5) {
MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
return false;
}
auto size = abstract::TypeIdSize(args.src_data_type);
if (size < 1) {
MS_LOG(ERROR) << "Illegal dtype.";
return false;
}
auto total_size = abstract::ShapeSize(args.device_shape) * size;
if (total_size != args.device_size) {
MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
return false;
}
auto n = args.host_shape[0];
auto c = args.host_shape[1];
auto d = args.host_shape[2];
auto h = args.host_shape[3];
auto w = args.host_shape[4];
auto n0 = args.device_shape[1];
auto ni = args.device_shape[1];
auto c0 = args.device_shape[3];
auto hw = h * w;
auto dhw = d * hw;
auto cdhw = c * dhw;
auto nc = ni * n0;
auto ncc0 = nc * c0;
auto wncc0 = w * ncc0;
auto hwncc0 = h * wncc0;
auto dhwncc0 = d * hwncc0;
for (size_t n_i = 0; n_i < n; n_i++) {
size_t n_head = n_i * cdhw;
for (size_t c_i = 0; c_i < c; c_i++) {
size_t c_head = n_head + c_i * dhw;
for (size_t d_i = 0; d_i < d; d_i++) {
size_t d_head = c_head + d_i * hw;
for (size_t h_i = 0; h_i < h; h_i++) {
size_t h_head = d_head + h_i * w;
for (size_t w_i = 0; w_i < w; w_i++) {
size_t dst_i = h_head + w_i;
size_t c1_i = c_i / c0;
size_t c0_i = c_i % c0;
size_t nc_i = n_i;
size_t src_i = c1_i * dhwncc0 + d_i * hwncc0 + h_i * wncc0 + w_i * ncc0 + nc_i * c0 + c0_i;
SetData(size, false, src_i, dst_i, args, result);
}
}
}
}
}
return true;
}
} // namespace trans
} // namespace mindspore

@ -63,6 +63,7 @@ bool NchwTo4D(const FormatArgs &args, void *result);
bool NchwToFracZ(const FormatArgs &args, void *result);
bool NchwToFracNz(const FormatArgs &args, void *result);
bool NchwToNc1hwc0(const FormatArgs &args, void *result);
bool NcdhwToFracZ3D(const FormatArgs &args, void *result);
bool NchwToFracZc04(const FormatArgs &args, void *result);
bool NchwToNc1hwc04(const FormatArgs &args, void *result);
bool NchwToC1hwncoc0(const FormatArgs &args, void *result);
@ -74,6 +75,7 @@ bool FracZToNchw(const FormatArgs &args, void *result);
bool FracNzToNchw(const FormatArgs &args, void *result);
bool Nc1hwc0ToNchw(const FormatArgs &args, void *result);
bool Nc1hwc04ToNchw(const FormatArgs &args, void *result);
bool FracZ3DToNcdhw(const FormatArgs &args, void *result);
bool C1hwncoc0ToNchw(const FormatArgs &args, void *result);
bool Ndc1hwc0ToNcdhw(const FormatArgs &args, void *result);
using FormatTransfer = std::function<bool(const FormatArgs &, void *)>;
@ -81,7 +83,7 @@ const std::map<std::string, FormatTransfer> kTransFormatMapOfHostToDevice{
{kOpFormat_FRAC_Z, NchwToFracZ}, {kOpFormat_FRAC_NZ, NchwToFracNz},
{kOpFormat_NC1HWC0, NchwToNc1hwc0}, {kOpFormat_C1HWNCoC0, NchwToC1hwncoc0},
{kOpFormat_FRACTAL_Z_C04, NchwToFracZc04}, {kOpFormat_NC1HWC0_C04, NchwToNc1hwc04},
{kOpFormat_NDC1HWC0, NcdhwToNdc1hwc0}};
{kOpFormat_NDC1HWC0, NcdhwToNdc1hwc0}, {kOpFormat_FRACTAL_Z_3D, NcdhwToFracZ3D}};
} // namespace trans
} // namespace mindspore

@ -93,9 +93,9 @@ namespace device {
namespace ascend {
const int FLOAT_LEN = sizeof(float);
const int FLOAT16_LEN = 2; // sizeof(float16);
const std::set<std::string> kOpNeedTransFormat = {kOpFormat_NHWC, kOpFormat_HWCN, kOpFormat_NC1HWC0,
kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0, kOpFormat_FRAC_NZ,
kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_NDC1HWC0};
const std::set<std::string> kOpNeedTransFormat = {
kOpFormat_NHWC, kOpFormat_HWCN, kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0,
kOpFormat_FRAC_NZ, kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D};
void SyncMemory(void *dst, const void *src, uint64_t size, rtMemcpyKind_t kind) {
auto ms_context = MsContext::GetInstance();
@ -575,7 +575,8 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const ShapeVector &sh
host_shape.emplace_back(1);
}
std::vector<size_t> device_shape;
if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NCDHW || format_ == kOpFormat_NDC1HWC0) {
if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NCDHW || format_ == kOpFormat_NDC1HWC0 ||
format_ == kOpFormat_FRACTAL_Z_3D) {
device_shape = trans::TransShapeToDevice(host_shape, format_);
} else {
host_shape = trans::PaddingShapeTo4d(host_shape);

Loading…
Cancel
Save