|
|
|
@ -526,13 +526,6 @@ bool TransDataType(const TypeIdArgs &args, void *result) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool TransFormat(const FormatArgs &args, void *result) {
|
|
|
|
|
using FormatTransfer = std::function<bool(const FormatArgs &, void *)>;
|
|
|
|
|
const std::map<std::string, FormatTransfer> format_trans_map{
|
|
|
|
|
{kOpFormat_FRAC_Z, NchwToFracZ}, {kOpFormat_FRAC_NZ, NchwToFracNz},
|
|
|
|
|
{kOpFormat_NC1HWC0, NchwToNc1hwc0}, {kOpFormat_C1HWNCoC0, NchwToC1hwncoc0},
|
|
|
|
|
{kOpFormat_FRACTAL_Z_C04, NchwToFracZc04}, {kOpFormat_NC1HWC0_C04, NchwToNc1hwc04},
|
|
|
|
|
{kOpFormat_NDC1HWC0, NcdhwToNdc1hwc0}};
|
|
|
|
|
|
|
|
|
|
MS_LOG(DEBUG) << "Start trans format.";
|
|
|
|
|
if (abstract::TypeIdSize(args.src_data_type) < 1) {
|
|
|
|
|
MS_LOG(ERROR) << "Invalid datatype..";
|
|
|
|
@ -541,15 +534,14 @@ bool TransFormat(const FormatArgs &args, void *result) {
|
|
|
|
|
if (args.device_format == kOpFormat_HWCN || args.device_format == kOpFormat_NHWC) {
|
|
|
|
|
return NchwTo4D(args, result);
|
|
|
|
|
}
|
|
|
|
|
auto iter = format_trans_map.find(args.device_format);
|
|
|
|
|
if (iter == format_trans_map.end()) {
|
|
|
|
|
auto iter = kTransFormatMapOfHostToDevice.find(args.device_format);
|
|
|
|
|
if (iter == kTransFormatMapOfHostToDevice.end()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Unexpected format[" << args.device_format << "]";
|
|
|
|
|
}
|
|
|
|
|
return iter->second(args, result);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result) {
|
|
|
|
|
using FormatTransfer = std::function<bool(const FormatArgs &, void *)>;
|
|
|
|
|
const std::map<std::string, FormatTransfer> format_trans_map{
|
|
|
|
|
{kOpFormat_FRAC_Z, FracZToNchw}, {kOpFormat_FRAC_NZ, FracNzToNchw},
|
|
|
|
|
{kOpFormat_NC1HWC0, Nc1hwc0ToNchw}, {kOpFormat_C1HWNCoC0, C1hwncoc0ToNchw},
|
|
|
|
|