add caffe deconv weight format in weight_format_pass

pull/4106/head
yeyunpeng 5 years ago
parent 7a4dcaac5a
commit 6be0fc8c03

@ -50,7 +50,7 @@ void WeightFormatPass::SetFmkType(converter::FmkType fmkType) { this->fmkType =
// pre set tensor format
// non quant, filterFormat:
// conv deconv depth dedepth
// caffe K(C/g)HW C(K/g)HW / / // todo with deconvOp
// caffe K(C/g)HW C(K/g)HW / /
// tf HWCK HWKC HWCK HWKC
// onnx K(C/g)HW C(K/g)HW / /
@ -78,7 +78,8 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) {
if (fmkType == converter::FmkType_CAFFE) {
switch (node->quantType) {
case QuantType_QUANT_NONE: {
if (opType == schema::PrimitiveType_Conv2D || opType == schema::PrimitiveType_DepthwiseConv2D) {
if (opType == schema::PrimitiveType_Conv2D || opType == schema::PrimitiveType_DepthwiseConv2D ||
opType == schema::PrimitiveType_DeConv2D) {
weightTensor->format = schema::Format_KCHW;
} else {
MS_LOG(ERROR) << "Invalid opType: " << schema::EnumNamePrimitiveType(opType)
@ -227,7 +228,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
STATUS status;
if (opType == schema::PrimitiveType_Conv2D) { // weight should be HWCK
if (weightTensor->format == schema::Format_KCHW) { // from caffe
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) {
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) {
MS_LOG(DEBUG) << "**weight tensor index: %d, format: %d, datatype: " << weightIndex << weightTensor->format
<< weightTensor->dataType;
status = TransFilterFormat<int8_t>(weightTensor.get(), kKCHW2HWCK);
@ -237,7 +238,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
status = TransFilterFormat<float>(weightTensor.get(), kKCHW2HWCK);
}
} else if (weightTensor->format == schema::Format_KHWC) { // from onnx
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) {
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) {
status = TransFilterFormat<int8_t>(weightTensor.get(), kKHWC2HWCK);
} else {
status = TransFilterFormat<float>(weightTensor.get(), kKHWC2HWCK);
@ -259,7 +260,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
}
} else if (opType == schema::PrimitiveType_DepthwiseConv2D) { // weight should be HWCK
if (weightTensor->format == schema::Format_CKHW) { // from caffe
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) {
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) {
MS_LOG(DEBUG) << "**weight tensor index: %d, format: %d, datatype: " << weightIndex, weightTensor->format,
weightTensor->dataType;
status = TransFilterFormat<uint8_t>(weightTensor.get(), kCKHW2HWCK);
@ -272,13 +273,13 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
} else if (weightTensor->format == schema::Format_HWCK) { // from tf
return 0;
} else if (weightTensor->format == schema::Format_CHWK) { // from onnx
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) {
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) {
status = TransFilterFormat<uint8_t>(weightTensor.get(), kCHWK2HWCK);
} else {
status = TransFilterFormat<float>(weightTensor.get(), kCHWK2HWCK);
}
} else if (weightTensor->format == schema::Format_KCHW) {
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) {
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) {
status = TransFilterFormat<uint8_t>(weightTensor.get(), kKCHW2HWCK);
} else {
status = TransFilterFormat<float>(weightTensor.get(), kKCHW2HWCK);
@ -365,7 +366,7 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) {
// todo(00445839): consider varible weight condition
}
} else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be KHWC
if (weightTensor->format == schema::Format_KCHW) { // from caffe or onnx or ms
if (weightTensor->format == schema::Format_KCHW) { // from caffe or onnx or ms
status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC);
} else if (weightTensor->format == schema::Format_CHWK) { // from tf
status = TransFilterFormat<float>(weightTensor.get(), kCHWK2KHWC);

Loading…
Cancel
Save