|
|
|
@ -189,7 +189,7 @@ size_t CubeSizeByType(const TypeId data_type) {
|
|
|
|
|
const size_t default_error = 0;
|
|
|
|
|
auto dt_size = abstract::TypeIdSize(data_type);
|
|
|
|
|
if (dt_size < 1) {
|
|
|
|
|
MS_LOG(ERROR) << "Illegal dtype.";
|
|
|
|
|
MS_LOG(EXCEPTION) << "Illegal dtype.";
|
|
|
|
|
return default_error;
|
|
|
|
|
} else if (dt_size == 1) {
|
|
|
|
|
return kCubeSize * 2;
|
|
|
|
@ -206,14 +206,14 @@ bool CheckDims(const std::vector<size_t> &shape) {
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<size_t> NchwDeviceShape(const std::vector<size_t> &shape) {
|
|
|
|
|
std::vector<size_t> NchwDeviceShape(const std::vector<size_t> &shape, const TypeId &type) {
|
|
|
|
|
if (!CheckDims(shape)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Check dims failed.";
|
|
|
|
|
}
|
|
|
|
|
return shape;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<size_t> NhwcDeviceShape(const std::vector<size_t> &shape) {
|
|
|
|
|
std::vector<size_t> NhwcDeviceShape(const std::vector<size_t> &shape, const TypeId &type) {
|
|
|
|
|
if (!CheckDims(shape)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Ccheck dims failed.";
|
|
|
|
|
}
|
|
|
|
@ -225,7 +225,7 @@ std::vector<size_t> NhwcDeviceShape(const std::vector<size_t> &shape) {
|
|
|
|
|
return device_shape;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<size_t> HwchDeviceShape(const std::vector<size_t> &shape) {
|
|
|
|
|
std::vector<size_t> HwchDeviceShape(const std::vector<size_t> &shape, const TypeId &type) {
|
|
|
|
|
if (!CheckDims(shape)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Check dims failed.";
|
|
|
|
|
}
|
|
|
|
@ -237,27 +237,29 @@ std::vector<size_t> HwchDeviceShape(const std::vector<size_t> &shape) {
|
|
|
|
|
return device_shape;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<size_t> FracZDeviceShape(const std::vector<size_t> &shape) {
|
|
|
|
|
std::vector<size_t> FracZDeviceShape(const std::vector<size_t> &shape, const TypeId &type) {
|
|
|
|
|
if (!CheckDims(shape)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Check dims failed.";
|
|
|
|
|
}
|
|
|
|
|
auto kCube = CubeSizeByType(type);
|
|
|
|
|
std::vector<size_t> device_shape;
|
|
|
|
|
const size_t cout16 = ((shape[kN] + kCubeSize - 1) / kCubeSize) * kCubeSize;
|
|
|
|
|
const size_t cin16 = ((shape[kC] + kCubeSize - 1) / kCubeSize) * kCubeSize;
|
|
|
|
|
device_shape.push_back(shape[kH] * shape[kW] * cin16 / kCubeSize);
|
|
|
|
|
device_shape.push_back(cout16 / kCubeSize);
|
|
|
|
|
device_shape.push_back(kCubeSize);
|
|
|
|
|
auto c1 = DivCeil(shape[kC], kCube);
|
|
|
|
|
auto n0 = DivCeil(shape[kN], kCubeSize);
|
|
|
|
|
device_shape.push_back(shape[kH] * shape[kW] * c1);
|
|
|
|
|
device_shape.push_back(n0);
|
|
|
|
|
device_shape.push_back(kCubeSize);
|
|
|
|
|
device_shape.push_back(kCube);
|
|
|
|
|
return device_shape;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<size_t> Nc1hwc0DeviceShape(const std::vector<size_t> &shape) {
|
|
|
|
|
std::vector<size_t> Nc1hwc0DeviceShape(const std::vector<size_t> &shape, const TypeId &type) {
|
|
|
|
|
if (!CheckDims(shape)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Check dims failed.";
|
|
|
|
|
}
|
|
|
|
|
auto kCube = CubeSizeByType(type);
|
|
|
|
|
std::vector<size_t> device_shape;
|
|
|
|
|
const size_t C1 = (shape[kC] + kCubeSize - 1) / kCubeSize;
|
|
|
|
|
const size_t C0 = kCubeSize;
|
|
|
|
|
const size_t C1 = (shape[kC] + kCube - 1) / kCube;
|
|
|
|
|
const size_t C0 = kCube;
|
|
|
|
|
device_shape.push_back(shape[kN]);
|
|
|
|
|
device_shape.push_back(C1);
|
|
|
|
|
device_shape.push_back(shape[kH]);
|
|
|
|
@ -266,7 +268,7 @@ std::vector<size_t> Nc1hwc0DeviceShape(const std::vector<size_t> &shape) {
|
|
|
|
|
return device_shape;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<size_t> Ndc1hwc0DeviceShape(const std::vector<size_t> &shape) {
|
|
|
|
|
std::vector<size_t> Ndc1hwc0DeviceShape(const std::vector<size_t> &shape, const TypeId &type) {
|
|
|
|
|
// NCDHW
|
|
|
|
|
if (shape.size() != 5) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size();
|
|
|
|
@ -283,51 +285,54 @@ std::vector<size_t> Ndc1hwc0DeviceShape(const std::vector<size_t> &shape) {
|
|
|
|
|
return device_shape;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<size_t> Fracz3DDeviceShape(const std::vector<size_t> &shape) {
|
|
|
|
|
std::vector<size_t> Fracz3DDeviceShape(const std::vector<size_t> &shape, const TypeId &type) {
|
|
|
|
|
// NCDHW -> Frac_Z_3D
|
|
|
|
|
if (shape.size() != 5) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size();
|
|
|
|
|
}
|
|
|
|
|
auto kCube = CubeSizeByType(type);
|
|
|
|
|
std::vector<size_t> device_shape;
|
|
|
|
|
const size_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize;
|
|
|
|
|
const size_t C1 = (shape[1] + kCube - 1) / kCube;
|
|
|
|
|
const size_t N1 = (shape[0] + kCubeSize - 1) / kCubeSize;
|
|
|
|
|
device_shape.push_back(shape[2] * C1 * shape[3] * shape[4]);
|
|
|
|
|
device_shape.push_back(N1);
|
|
|
|
|
device_shape.push_back(kCubeSize);
|
|
|
|
|
device_shape.push_back(kCubeSize);
|
|
|
|
|
device_shape.push_back(kCube);
|
|
|
|
|
return device_shape;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<size_t> C1hwncoc0DeviceShape(const std::vector<size_t> &shape) {
|
|
|
|
|
std::vector<size_t> C1hwncoc0DeviceShape(const std::vector<size_t> &shape, const TypeId &type) {
|
|
|
|
|
if (!CheckDims(shape)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Check dims failed.";
|
|
|
|
|
}
|
|
|
|
|
auto kCube = CubeSizeByType(type);
|
|
|
|
|
std::vector<size_t> device_shape;
|
|
|
|
|
device_shape.push_back((shape[kC] - 1) / kCubeSize + 1);
|
|
|
|
|
device_shape.push_back((shape[kC] - 1) / kCube + 1);
|
|
|
|
|
device_shape.push_back(shape[kH]);
|
|
|
|
|
device_shape.push_back(shape[kW]);
|
|
|
|
|
device_shape.push_back(shape[kN]);
|
|
|
|
|
device_shape.push_back(kCubeSize);
|
|
|
|
|
device_shape.push_back(kCubeSize);
|
|
|
|
|
device_shape.push_back(kCube);
|
|
|
|
|
device_shape.push_back(kCube);
|
|
|
|
|
return device_shape;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<size_t> FracZc04DeviceShape(const std::vector<size_t> &shape) {
|
|
|
|
|
std::vector<size_t> FracZc04DeviceShape(const std::vector<size_t> &shape, const TypeId &type) {
|
|
|
|
|
if (!CheckDims(shape)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Check dims failed.";
|
|
|
|
|
}
|
|
|
|
|
auto kCube = CubeSizeByType(type);
|
|
|
|
|
std::vector<size_t> device_shape;
|
|
|
|
|
const size_t c0 = 4;
|
|
|
|
|
auto first_dim = DivCeil(c0 * shape[kH] * shape[kW], kCubeSize);
|
|
|
|
|
auto no = DivCeil(shape.at(kN), kCubeSize);
|
|
|
|
|
auto first_dim = DivCeil(c0 * shape[kH] * shape[kW], kCube);
|
|
|
|
|
auto no = DivCeil(shape.at(kN), kCube);
|
|
|
|
|
device_shape.push_back(first_dim);
|
|
|
|
|
device_shape.push_back(no);
|
|
|
|
|
device_shape.push_back(kCubeSize);
|
|
|
|
|
device_shape.push_back(kCubeSize);
|
|
|
|
|
device_shape.push_back(kCube);
|
|
|
|
|
device_shape.push_back(kCube);
|
|
|
|
|
return device_shape;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<size_t> Nc1hwc04DeviceShape(const std::vector<size_t> &shape) {
|
|
|
|
|
std::vector<size_t> Nc1hwc04DeviceShape(const std::vector<size_t> &shape, const TypeId &type) {
|
|
|
|
|
if (!CheckDims(shape)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Check dims failed.";
|
|
|
|
|
}
|
|
|
|
@ -342,7 +347,7 @@ std::vector<size_t> Nc1hwc04DeviceShape(const std::vector<size_t> &shape) {
|
|
|
|
|
return device_shape;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<size_t> NcdhwDeviceShape(const std::vector<size_t> &shape) {
|
|
|
|
|
std::vector<size_t> NcdhwDeviceShape(const std::vector<size_t> &shape, const TypeId &type) {
|
|
|
|
|
if (shape.size() < kNdhwc) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Shape dims must be 5 when format is ndhwc.";
|
|
|
|
|
}
|
|
|
|
@ -427,8 +432,9 @@ std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std
|
|
|
|
|
return shape_4d;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format) {
|
|
|
|
|
using DeviceShapeTransfer = std::function<std::vector<size_t>(const std::vector<size_t> &)>;
|
|
|
|
|
std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format,
|
|
|
|
|
const TypeId &type) {
|
|
|
|
|
using DeviceShapeTransfer = std::function<std::vector<size_t>(const std::vector<size_t> &, const TypeId &)>;
|
|
|
|
|
const std::map<std::string, DeviceShapeTransfer> device_shape_map{{kOpFormat_NCHW, NchwDeviceShape},
|
|
|
|
|
{kOpFormat_NHWC, NhwcDeviceShape},
|
|
|
|
|
{kOpFormat_HWCN, HwchDeviceShape},
|
|
|
|
@ -446,8 +452,9 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
|
|
|
|
|
}
|
|
|
|
|
auto temp_shape = shape;
|
|
|
|
|
std::vector<size_t> device_shape;
|
|
|
|
|
auto kCube = CubeSizeByType(type);
|
|
|
|
|
if (format == kOpFormat_FRAC_NZ) {
|
|
|
|
|
if (shape.size() == 1 && (shape[0] == 1 || shape[0] % kCubeSize == 0)) {
|
|
|
|
|
if (shape.size() == 1 && (shape[0] == 1 || shape[0] % kCube == 0)) {
|
|
|
|
|
// For [1] and [1024] shape we can trait it as NZ shape
|
|
|
|
|
return shape;
|
|
|
|
|
}
|
|
|
|
@ -456,12 +463,12 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
|
|
|
|
|
} else {
|
|
|
|
|
(void)std::copy(shape.begin(), shape.end() - 2, std::back_inserter(device_shape));
|
|
|
|
|
}
|
|
|
|
|
auto w1 = (shape[shape.size() - 1] - 1) / kCube + 1;
|
|
|
|
|
auto h1 = (shape[shape.size() - 2] - 1) / kCubeSize + 1;
|
|
|
|
|
auto w1 = (shape[shape.size() - 1] - 1) / kCubeSize + 1;
|
|
|
|
|
device_shape.push_back(w1);
|
|
|
|
|
device_shape.push_back(h1);
|
|
|
|
|
device_shape.push_back(kCubeSize);
|
|
|
|
|
device_shape.push_back(kCubeSize);
|
|
|
|
|
device_shape.push_back(kCube);
|
|
|
|
|
return device_shape;
|
|
|
|
|
} else if (format == kOpFormat_FRACTAL_ZN_LSTM) {
|
|
|
|
|
const size_t c0 = 4;
|
|
|
|
@ -483,7 +490,7 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
|
|
|
|
|
if (iter == device_shape_map.end()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Unexpected format[" << format << "]";
|
|
|
|
|
}
|
|
|
|
|
return iter->second(temp_shape);
|
|
|
|
|
return iter->second(temp_shape, type);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool CheckArgs(const FormatArgs &args, size_t *size, size_t *total_size) {
|
|
|
|
|