fix detect model to eval

pull/3988/head
guohongzilong 5 years ago
parent 34214e8f4c
commit 226d1ec6fd

@ -35,7 +35,7 @@ int DeConv2D::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tens
int32_t output_n = input->Batch();
int32_t output_h = 0;
int32_t output_w = 0;
int32_t output_c = weight->Batch();
int32_t output_c = weight->Channel();
auto deconv = GetAttribute();
int kernel_w = deconv->kernelW();

@ -154,7 +154,7 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) {
} else if (opType == schema::PrimitiveType_DepthwiseConv2D) {
weightTensor->format = schema::Format_CHWK;
} else if (opType == schema::PrimitiveType_DeConv2D) {
weightTensor->format = schema::Format_KHWC;
weightTensor->format = schema::Format_CHWK;
} else {
MS_LOG(ERROR) << "unsupport format";
return -1;
@ -367,8 +367,8 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) {
} else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be KHWC
if (weightTensor->format == schema::Format_KCHW) { // from caffe or onnx or ms
status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC);
} else if (weightTensor->format == schema::Format_KHWC) { // from tf
status = RET_OK;
} else if (weightTensor->format == schema::Format_CHWK) { // from tf
status = TransFilterFormat<float>(weightTensor.get(), kCHWK2KHWC);
} else {
MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format;
return -1;
@ -390,7 +390,7 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) {
return -1;
}
if (status == 0) {
node->primitive->value.AsDepthwiseConv2D()->format = schema::Format_NCHW;
node->primitive->value.AsDepthwiseConv2D()->format = schema::Format_NHWC;
weightTensor->format = schema::Format_CKHW;
} else {
MS_LOG(WARNING) << "TransFilter HWKCToCKHW failed, node : " << node->name.c_str();

@ -49,10 +49,10 @@ STATUS TfliteDeConvParser::Parse(const std::unique_ptr<tflite::OperatorT> &tflit
return RET_ERROR;
}
auto weight_shape = weight_tensor->shape;
attr->channelIn = weight_shape[KHWC_C];
attr->channelOut = weight_shape[KHWC_K];
attr->kernelW = weight_shape[KHWC_W];
attr->kernelH = weight_shape[KHWC_H];
attr->channelIn = weight_shape[CHWK_K];
attr->channelOut = weight_shape[CHWK_C];
attr->kernelW = weight_shape[CHWK_W];
attr->kernelH = weight_shape[CHWK_H];
if (op != nullptr) {
op->primitive = std::make_unique<schema::PrimitiveT>();

Loading…
Cancel
Save