!12089 [MS_LITE]fix tf model

From: @YeFeng_24
Reviewed-by: @hangangqiang
Signed-off-by: @hangangqiang
pull/12089/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 269d276473

@ -56,3 +56,6 @@ mtk_age_gender.pb 1
mtk_model_ckpt.pb 1
mtk_model_face_dress.pb 1;1,128,128,3
mtk_model_normalize_object_scene_ps_20200519.pb 1;1,224,224,3
ml_ocr_latin.pb 1
ml_noya_tts_melgan.pb 1;16,16,80
ml_video_edit_oneclick_adaptis.pb 3

@ -26,7 +26,7 @@ namespace lite {
STATUS TFDeconvParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC,
std::vector<std::string> *inputs, int *output_size) {
MS_LOG(INFO) << "TF DeConvParser";
MS_LOG(DEBUG) << "TF DeConvParser";
if (primitiveC == nullptr || output_size == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_NULL_PTR;
@ -71,8 +71,8 @@ STATUS TFDeconvParser::Parse(const tensorflow::NodeDef &tf_op,
}
attr->kernelH = kernels[0];
attr->kernelW = kernels[1];
attr->channelIn = kernels[2];
attr->channelOut = kernels[3];
attr->channelOut = kernels[2];
attr->channelIn = kernels[3];
} else {
attr->kernelH = -1;
attr->kernelW = -1;

@ -45,6 +45,27 @@ STATUS TFLogicalParser::Parse(const tensorflow::NodeDef &tf_op,
primitive->value.type = schema::PrimitiveType_LogicalAnd;
primitive->value.value = attr.release();
*primitiveC = PrimitiveC::Create(primitive.release());
} else if (tf_op.op() == "LogicalOr") {
auto attr = std::make_unique<schema::LogicalOrT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
primitive->value.type = schema::PrimitiveType_LogicalOr;
primitive->value.value = attr.release();
*primitiveC = PrimitiveC::Create(primitive.release());
} else if (tf_op.op() == "LogicalNot") {
auto attr = std::make_unique<schema::LogicalNotT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
primitive->value.type = schema::PrimitiveType_LogicalNot;
primitive->value.value = attr.release();
*primitiveC = PrimitiveC::Create(primitive.release());
} else {
MS_LOG(ERROR) << tf_op.op() << " is not supported.";
return RET_ERROR;
}
if (*primitiveC == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
@ -59,5 +80,7 @@ STATUS TFLogicalParser::Parse(const tensorflow::NodeDef &tf_op,
return RET_OK;
}
TFNodeRegistrar g_tfLogicalAndParser("LogicalAnd", new TFLogicalParser());
TFNodeRegistrar g_tfLogicalOrParser("LogicalOr", new TFLogicalParser());
TFNodeRegistrar g_tfLogicalNotParser("LogicalNot", new TFLogicalParser());
} // namespace lite
} // namespace mindspore

@ -222,8 +222,16 @@ void ConvTransformFusion::CalNewWeightTensor(const CNodePtr &conv_node, const Pa
return;
}
if (this->fmk_type_ == lite::converter::FmkType_TF) {
for (int i = 0; i < weight_shape_size; i++) {
tmp_weight_data[i] = weight_data[i] * trans_scale[i % kernel_num];
auto group = primc->GetGroup();
auto cin_group = weight_tensor->tensor_shape()[3] / group;
int area_size = weight_tensor->tensor_shape()[0] * weight_tensor->tensor_shape()[1];
for (int j = 0; j < area_size; j++) {
for (int i = 0; i < kernel_num; ++i) {
for (int k = 0; k < cin_group; ++k) {
tmp_weight_data[k + i * cin_group + j * kernel_num * cin_group] =
weight_data[k + i * cin_group + j * kernel_num * cin_group] * trans_scale[i];
}
}
}
} else {
auto group = primc->GetGroup();

Loading…
Cancel
Save