|
|
|
@ -154,7 +154,7 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) {
|
|
|
|
|
} else if (opType == schema::PrimitiveType_DepthwiseConv2D) {
|
|
|
|
|
weightTensor->format = schema::Format_CHWK;
|
|
|
|
|
} else if (opType == schema::PrimitiveType_DeConv2D) {
|
|
|
|
|
weightTensor->format = schema::Format_KHWC;
|
|
|
|
|
weightTensor->format = schema::Format_CHWK;
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(ERROR) << "unsupport format";
|
|
|
|
|
return -1;
|
|
|
|
@ -367,8 +367,8 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) {
|
|
|
|
|
} else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be KHWC
|
|
|
|
|
if (weightTensor->format == schema::Format_KCHW) { // from caffe or onnx or ms
|
|
|
|
|
status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC);
|
|
|
|
|
} else if (weightTensor->format == schema::Format_KHWC) { // from tf
|
|
|
|
|
status = RET_OK;
|
|
|
|
|
} else if (weightTensor->format == schema::Format_CHWK) { // from tf
|
|
|
|
|
status = TransFilterFormat<float>(weightTensor.get(), kCHWK2KHWC);
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format;
|
|
|
|
|
return -1;
|
|
|
|
@ -390,7 +390,7 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) {
|
|
|
|
|
return -1;
|
|
|
|
|
}
|
|
|
|
|
if (status == 0) {
|
|
|
|
|
node->primitive->value.AsDepthwiseConv2D()->format = schema::Format_NCHW;
|
|
|
|
|
node->primitive->value.AsDepthwiseConv2D()->format = schema::Format_NHWC;
|
|
|
|
|
weightTensor->format = schema::Format_CKHW;
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(WARNING) << "TransFilter HWKCToCKHW failed, node : " << node->name.c_str();
|
|
|
|
|