|
|
|
@ -266,6 +266,41 @@ std::vector<size_t> Nc1hwc0DeviceShape(const std::vector<size_t> &shape) {
|
|
|
|
|
return device_shape;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<size_t> Ndc1hwc0DeviceShape(const std::vector<size_t> &shape) {
|
|
|
|
|
// NCDHW
|
|
|
|
|
if (shape.size() != 5) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size();
|
|
|
|
|
}
|
|
|
|
|
std::vector<size_t> device_shape;
|
|
|
|
|
const size_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize;
|
|
|
|
|
const size_t C0 = kCubeSize;
|
|
|
|
|
device_shape.push_back(shape[0]);
|
|
|
|
|
device_shape.push_back(shape[2]);
|
|
|
|
|
device_shape.push_back(C1);
|
|
|
|
|
device_shape.push_back(shape[3]);
|
|
|
|
|
device_shape.push_back(shape[4]);
|
|
|
|
|
device_shape.push_back(C0);
|
|
|
|
|
return device_shape;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<size_t> Fracz3DDeviceShape(const std::vector<size_t> &shape) {
|
|
|
|
|
// NCDHW -> Frac_Z_3D
|
|
|
|
|
if (shape.size() != 5) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size();
|
|
|
|
|
}
|
|
|
|
|
std::vector<size_t> device_shape;
|
|
|
|
|
const size_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize;
|
|
|
|
|
const size_t N1 = (shape[0] + kCubeSize - 1) / kCubeSize;
|
|
|
|
|
device_shape.push_back(shape[2]);
|
|
|
|
|
device_shape.push_back(C1);
|
|
|
|
|
device_shape.push_back(shape[3]);
|
|
|
|
|
device_shape.push_back(shape[4]);
|
|
|
|
|
device_shape.push_back(N1);
|
|
|
|
|
device_shape.push_back(kCubeSize);
|
|
|
|
|
device_shape.push_back(kCubeSize);
|
|
|
|
|
return device_shape;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<size_t> C1hwncoc0DeviceShape(const std::vector<size_t> &shape) {
|
|
|
|
|
if (!CheckDims(shape)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Check dims failed.";
|
|
|
|
@ -310,7 +345,7 @@ std::vector<size_t> Nc1hwc04DeviceShape(const std::vector<size_t> &shape) {
|
|
|
|
|
return device_shape;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<size_t> NdhwcDeviceShape(const std::vector<size_t> &shape) {
|
|
|
|
|
std::vector<size_t> NcdhwDeviceShape(const std::vector<size_t> &shape) {
|
|
|
|
|
if (shape.size() < kNdhwc) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Shape dims must be 5 when format is ndhwc.";
|
|
|
|
|
}
|
|
|
|
@ -405,7 +440,9 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
|
|
|
|
|
{kOpFormat_C1HWNCoC0, C1hwncoc0DeviceShape},
|
|
|
|
|
{kOpFormat_FRACTAL_Z_C04, FracZc04DeviceShape},
|
|
|
|
|
{kOpFormat_NC1HWC0_C04, Nc1hwc04DeviceShape},
|
|
|
|
|
{kOpFormat_NDHWC, NdhwcDeviceShape}};
|
|
|
|
|
{kOpFormat_NCDHW, NcdhwDeviceShape},
|
|
|
|
|
{kOpFormat_NDC1HWC0, Ndc1hwc0DeviceShape},
|
|
|
|
|
{kOpFormat_FRACTAL_Z_3D, Fracz3DDeviceShape}};
|
|
|
|
|
|
|
|
|
|
if (format == kOpFormat_ND || format == kOpFormat_DEFAULT) {
|
|
|
|
|
return shape;
|
|
|
|
@ -441,7 +478,7 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
|
|
|
|
|
device_shape.push_back(kCubeSize);
|
|
|
|
|
return device_shape;
|
|
|
|
|
}
|
|
|
|
|
if (shape.size() != kNchwDims) {
|
|
|
|
|
if (shape.size() != kNchwDims && shape.size() != 5) {
|
|
|
|
|
MS_LOG(WARNING) << "Get Device Shape using a shape size is less than 4 ,should be Padding shape by Default firstly";
|
|
|
|
|
temp_shape = PaddingShapeTo4dByDefault(shape);
|
|
|
|
|
}
|
|
|
|
@ -496,7 +533,9 @@ bool TransFormat(const FormatArgs &args, void *result) {
|
|
|
|
|
const std::map<std::string, FormatTransfer> format_trans_map{
|
|
|
|
|
{kOpFormat_FRAC_Z, NchwToFracZ}, {kOpFormat_FRAC_NZ, NchwToFracNz},
|
|
|
|
|
{kOpFormat_NC1HWC0, NchwToNc1hwc0}, {kOpFormat_C1HWNCoC0, NchwToC1hwncoc0},
|
|
|
|
|
{kOpFormat_FRACTAL_Z_C04, NchwToFracZc04}, {kOpFormat_NC1HWC0_C04, NchwToNc1hwc04}};
|
|
|
|
|
{kOpFormat_FRACTAL_Z_C04, NchwToFracZc04}, {kOpFormat_NC1HWC0_C04, NchwToNc1hwc04},
|
|
|
|
|
{kOpFormat_NDC1HWC0, NcdhwToNdc1hwc0}};
|
|
|
|
|
|
|
|
|
|
MS_LOG(DEBUG) << "Start trans format.";
|
|
|
|
|
if (abstract::TypeIdSize(args.src_data_type) < 1) {
|
|
|
|
|
MS_LOG(ERROR) << "Invalid datatype..";
|
|
|
|
@ -514,11 +553,11 @@ bool TransFormat(const FormatArgs &args, void *result) {
|
|
|
|
|
|
|
|
|
|
bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result) {
|
|
|
|
|
using FormatTransfer = std::function<bool(const FormatArgs &, void *)>;
|
|
|
|
|
const std::map<std::string, FormatTransfer> format_trans_map{{kOpFormat_FRAC_Z, FracZToNchw},
|
|
|
|
|
{kOpFormat_FRAC_NZ, FracNzToNchw},
|
|
|
|
|
{kOpFormat_NC1HWC0, Nc1hwc0ToNchw},
|
|
|
|
|
{kOpFormat_C1HWNCoC0, C1hwncoc0ToNchw},
|
|
|
|
|
{kOpFormat_NC1HWC0_C04, Nc1hwc04ToNchw}};
|
|
|
|
|
const std::map<std::string, FormatTransfer> format_trans_map{
|
|
|
|
|
{kOpFormat_FRAC_Z, FracZToNchw}, {kOpFormat_FRAC_NZ, FracNzToNchw},
|
|
|
|
|
{kOpFormat_NC1HWC0, Nc1hwc0ToNchw}, {kOpFormat_C1HWNCoC0, C1hwncoc0ToNchw},
|
|
|
|
|
{kOpFormat_NC1HWC0_C04, Nc1hwc04ToNchw}, {kOpFormat_NDC1HWC0, Ndc1hwc0ToNcdhw}};
|
|
|
|
|
|
|
|
|
|
MS_LOG(DEBUG) << "Start trans format.";
|
|
|
|
|
if (abstract::TypeIdSize(args.src_data_type) < 1) {
|
|
|
|
|
MS_LOG(ERROR) << "Invalid datatype..";
|
|
|
|
@ -1106,5 +1145,119 @@ bool C1hwncoc0ToNchw(const FormatArgs &args, void *result) {
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool Ndc1hwc0ToNcdhw(const FormatArgs &args, void *result) {
|
|
|
|
|
MS_LOG(DEBUG) << "Trans from ndc1hwc0 to ncdhw";
|
|
|
|
|
MS_EXCEPTION_IF_NULL(result);
|
|
|
|
|
|
|
|
|
|
if (args.host_shape.size() != 5) {
|
|
|
|
|
MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto size = abstract::TypeIdSize(args.src_data_type);
|
|
|
|
|
if (size < 1) {
|
|
|
|
|
MS_LOG(ERROR) << "Illegal dtype.";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto total_size = abstract::ShapeSize(args.device_shape) * size;
|
|
|
|
|
if (total_size != args.device_size) {
|
|
|
|
|
MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto n = args.host_shape[0];
|
|
|
|
|
auto c = args.host_shape[1];
|
|
|
|
|
auto d = args.host_shape[2];
|
|
|
|
|
auto h = args.host_shape[3];
|
|
|
|
|
auto w = args.host_shape[4];
|
|
|
|
|
auto c1 = args.device_shape[2];
|
|
|
|
|
auto c0 = args.device_shape[5];
|
|
|
|
|
const size_t cdhw = c * d * h * w;
|
|
|
|
|
const size_t dhw = d * h * w;
|
|
|
|
|
const size_t hw = h * w;
|
|
|
|
|
const size_t dc1hwc0 = d * c1 * h * w * c0;
|
|
|
|
|
const size_t c1hwc0 = c1 * h * w * c0;
|
|
|
|
|
const size_t hwc0 = h * w * c0;
|
|
|
|
|
const size_t wc0 = w * c0;
|
|
|
|
|
|
|
|
|
|
for (size_t n_i = 0; n_i < n; n_i++) {
|
|
|
|
|
size_t n_head = n_i * cdhw;
|
|
|
|
|
for (size_t c_i = 0; c_i < c; c_i++) {
|
|
|
|
|
size_t c_head = n_head + c_i * dhw;
|
|
|
|
|
for (size_t d_i = 0; d_i < d; d_i++) {
|
|
|
|
|
size_t d_head = c_head + d_i * hw;
|
|
|
|
|
for (size_t h_i = 0; h_i < h; h_i++) {
|
|
|
|
|
size_t h_head = d_head + h_i * w;
|
|
|
|
|
for (size_t w_i = 0; w_i < w; w_i++) {
|
|
|
|
|
size_t dst_i = h_head + w_i;
|
|
|
|
|
size_t c1_i = c_i / c0;
|
|
|
|
|
size_t c0_i = c_i % c0;
|
|
|
|
|
auto src_idx = n_i * dc1hwc0 + d_i * c1hwc0 + c1_i * hwc0 + h_i * wc0 + w_i * c0 + c0_i;
|
|
|
|
|
SetData(size, false, src_idx, dst_i, args, result);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool NcdhwToNdc1hwc0(const FormatArgs &args, void *result) {
|
|
|
|
|
MS_LOG(DEBUG) << "Trans from ncdhw to ndc1hwc0";
|
|
|
|
|
MS_EXCEPTION_IF_NULL(result);
|
|
|
|
|
|
|
|
|
|
if (args.host_shape.size() != 5) {
|
|
|
|
|
MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto size = abstract::TypeIdSize(args.src_data_type);
|
|
|
|
|
if (size < 1) {
|
|
|
|
|
MS_LOG(ERROR) << "Illegal dtype.";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto total_size = abstract::ShapeSize(args.device_shape) * size;
|
|
|
|
|
if (total_size != args.device_size) {
|
|
|
|
|
MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto n = args.host_shape[0];
|
|
|
|
|
auto c = args.host_shape[1];
|
|
|
|
|
auto d = args.host_shape[2];
|
|
|
|
|
auto h = args.host_shape[3];
|
|
|
|
|
auto w = args.host_shape[4];
|
|
|
|
|
auto c0 = kCubeSize;
|
|
|
|
|
auto c1 = DivCeil(c, c0);
|
|
|
|
|
const size_t cdhw = c * d * h * w;
|
|
|
|
|
const size_t dhw = d * h * w;
|
|
|
|
|
const size_t hw = h * w;
|
|
|
|
|
const size_t dc1hwc0 = d * c1 * h * w * c0;
|
|
|
|
|
const size_t c1hwc0 = c1 * h * w * c0;
|
|
|
|
|
const size_t hwc0 = h * w * c0;
|
|
|
|
|
const size_t wc0 = w * c0;
|
|
|
|
|
|
|
|
|
|
for (size_t n_i = 0; n_i < n; n_i++) {
|
|
|
|
|
size_t n_head = n_i * dc1hwc0;
|
|
|
|
|
for (size_t d_i = 0; d_i < d; d_i++) {
|
|
|
|
|
size_t d_head = n_head + d_i * c1hwc0;
|
|
|
|
|
for (size_t c1_i = 0; c1_i < c1; c1_i++) {
|
|
|
|
|
size_t c1_head = d_head + c1_i * hwc0;
|
|
|
|
|
for (size_t h_i = 0; h_i < h; h_i++) {
|
|
|
|
|
size_t h_head = c1_head + h_i * wc0;
|
|
|
|
|
for (size_t w_i = 0; w_i < w; w_i++) {
|
|
|
|
|
size_t w_head = h_head + w_i * c0;
|
|
|
|
|
for (size_t c0_i = 0; c0_i < c0; c0_i++) {
|
|
|
|
|
size_t dst_i = c0_i + w_head;
|
|
|
|
|
size_t c_i = c0_i + c1_i * c0;
|
|
|
|
|
size_t src_i = n_i * cdhw + c_i * dhw + d_i * hw + h_i * w + w_i;
|
|
|
|
|
auto pad_zero = c_i >= c;
|
|
|
|
|
SetData(size, pad_zero, src_i, dst_i, args, result);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
} // namespace trans
|
|
|
|
|
} // namespace mindspore
|
|
|
|
|