!8010 fix int8transpose parser

Merge pull request !8010 from yankai10/merge_1030
pull/8010/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 27d4c8f5fb

@ -824,6 +824,10 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
return new InstanceNorm(primitive);
case schema::PrimitiveType_While:
return new While(primitive);
case schema::PrimitiveType_OnnxInt8Quantize:
return new Quant(primitive);
case schema::PrimitiveType_OnnxInt8Dequantize:
return new Dequant(primitive);
#ifdef SUPPORT_TRAIN
case schema::PrimitiveType_ActivationGrad:

@ -196,6 +196,7 @@ if(ENABLE_CONVERTER)
${LITE_DIR}/tools/optimizer/graph/weight_format_hardcode_pass.cc
${LITE_DIR}/tools/optimizer/graph/clip_convert_activation_pass.cc
${LITE_DIR}/tools/optimizer/graph/unused_cast_node_remove_pass.cc
${LITE_DIR}/tools/optimizer/graph/unused_transpose_node_remove_pass.cc
${LITE_DIR}/tools/optimizer/graph/identity_remove_pass.cc
)
endif()

@ -152,7 +152,7 @@ std::string FlagParser::Usage(const Option<std::string> &usgMsg) const {
// first line, brief of the usage
std::string usageString = usgMsg.IsSome() ? usgMsg.Get() + "\n" : "";
// usage of bin name
usageString += usageMsg.IsNone() ? "usage: " + binName + " [options]\n" : usageMsg.Get() + "\n";
usageString += usageMsg.IsNone() ? "\nusage: " + binName + " [options]\n" : usageMsg.Get() + "\n";
// help line of help message, usageLine:message of parametors
std::string helpLine = "";
std::string usageLine = "";

@ -47,6 +47,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
../optimizer/graph/weight_format_hardcode_pass.cc
../optimizer/graph/clip_convert_activation_pass.cc
../optimizer/graph/unused_cast_node_remove_pass.cc
../optimizer/graph/unused_transpose_node_remove_pass.cc
../optimizer/graph/identity_remove_pass.cc
)

@ -33,6 +33,7 @@
#include "tools/optimizer/graph/weight_format_transform_pass.h"
#include "tools/optimizer/graph/clip_convert_activation_pass.h"
#include "tools/optimizer/graph/unused_cast_node_remove_pass.h"
#include "tools/optimizer/graph/unused_transpose_node_remove_pass.h"
#include "tools/converter/quantizer/post_training_quantizer.h"
#include "tools/converter/quantizer/quant_cast.h"
#include "tools/converter/quantizer/weight_quantizer.h"
@ -90,9 +91,22 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
if (config->fmk == lite::converter::FmkType_MS) {
auto remove_unused_cast_pass = std::make_shared<opt::RemoveUnusedCastOpPass>();
if (remove_unused_cast_pass == nullptr) {
MS_LOG(ERROR) << "RemoveUnusedCastOpPass shoud be specified";
return nullptr;
}
remove_unused_cast_pass->SetFmkType(config->fmk);
pm->AddPass(remove_unused_cast_pass);
}
if (config->fmk == lite::converter::FmkType_ONNX) {
auto remove_unused_transpose_pass = std::make_shared<opt::RemoveUnusedTransposeOpPass>();
if (remove_unused_transpose_pass == nullptr) {
MS_LOG(ERROR) << "RemoveUnusedTransposeOpPass shoud be specified";
return nullptr;
}
remove_unused_transpose_pass->SetFmkType(config->fmk);
pm->AddPass(remove_unused_transpose_pass);
}
pm->AddPass(std::make_shared<opt::ConstFoldPass>());
convert_pm->AddPass(std::make_shared<opt::ClipConvertActivationPass>());
optimizer->AddPassManager(convert_pm);

@ -61,5 +61,6 @@ STATUS OnnxTransposeParser::Parse(const onnx::GraphProto &onnx_graph, const onnx
}
OnnxNodeRegistrar g_onnxTransposeParser("Transpose", new OnnxTransposeParser());
OnnxNodeRegistrar g_onnxInt8TransposeParser("Int8Transpose", new OnnxTransposeParser());
} // namespace lite
} // namespace mindspore

@ -0,0 +1,90 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/graph/unused_transpose_node_remove_pass.h"
#include <vector>
#include <memory>
#include "tools/optimizer/common/gllo_utils.h"
#include "mindspore/lite/include/errorcode.h"
#include "src/ops/primitive_c.h"
namespace mindspore::opt {
static constexpr size_t kTransposeInput = 1;
const std::vector<int> kPermNCHW{0, 3, 1, 2};
const std::vector<int> kPermNHWC{0, 2, 3, 1};
void RemoveUnusedTransposeOpPass::SetFmkType(FmkType type) { this->fmk_type = type; }
bool RemoveUnusedTransposeOpPass::Run(const FuncGraphPtr &func_graph) {
if (this->fmk_type != lite::converter::FmkType_ONNX) {
MS_LOG(ERROR) << "The framework type of model should be onnx.";
return RET_ERROR;
}
MS_ASSERT(func_graph != nullptr);
auto manager = func_graph->manager();
MS_ASSERT(manager != nullptr);
auto node_list = TopoSort(func_graph->get_return());
for (auto &node : node_list) {
if (!utils::isa<CNodePtr>(node)) {
continue;
}
auto type = opt::GetCNodeType(node);
if (type == schema::PrimitiveType_Transpose) {
auto transpose_cnode = node->cast<CNodePtr>();
auto typeInput = opt::GetCNodeType(transpose_cnode->input(kTransposeInput));
if (typeInput != schema::PrimitiveType_Conv2D) {
continue;
}
auto primPtr = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(transpose_cnode->input(0));
if (primPtr == nullptr) {
MS_LOG(ERROR) << "Transpose node of onnx need to removed which has not primitiveC";
return RET_ERROR;
}
auto primT = primPtr->GetPrimitiveT();
if (primT == nullptr) {
MS_LOG(ERROR) << "Transpose node of onnx need to removed which has not primitiveC";
return RET_ERROR;
}
std::vector<int32_t> perm = primT->value.AsTranspose()->perm;
if (perm == kPermNCHW) {
manager->Replace(transpose_cnode, transpose_cnode->input(1));
}
} else if (type == schema::PrimitiveType_Conv2D) {
auto conv_node = node->cast<CNodePtr>();
auto typeInput = opt::GetCNodeType(conv_node->input(kTransposeInput));
if (typeInput != schema::PrimitiveType_Transpose) {
continue;
}
auto transpose_cnode = conv_node->input(kTransposeInput)->cast<CNodePtr>();
auto primPtr = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(transpose_cnode->input(0));
if (primPtr == nullptr) {
MS_LOG(ERROR) << "Transpose node of onnx need to removed which has not primitiveC";
return RET_ERROR;
}
auto primT = primPtr->GetPrimitiveT();
if (primT == nullptr) {
MS_LOG(ERROR) << "Transpose node of onnx need to removed which has not primitiveT";
return RET_ERROR;
}
std::vector<int32_t> perm = primT->value.AsTranspose()->perm;
if (perm == kPermNHWC) {
manager->Replace(transpose_cnode, transpose_cnode->input(1));
}
} else {
continue;
}
}
return true;
}
} // namespace mindspore::opt

@ -0,0 +1,36 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_PASS_REMOVE_UNUSED_TRANSPOSE_PASS_H_
#define MINDSPORE_LITE_SRC_PASS_REMOVE_UNUSED_TRANSPOSE_PASS_H_
#include <string>
#include "backend/optimizer/common/pass.h"
#include "tools/converter/converter_flags.h"
using mindspore::lite::converter::FmkType;
namespace mindspore::opt {
class RemoveUnusedTransposeOpPass : public Pass {
public:
RemoveUnusedTransposeOpPass() : Pass("remove_unused_cast_pass") {}
~RemoveUnusedTransposeOpPass() override = default;
void SetFmkType(FmkType fmkType);
bool Run(const FuncGraphPtr &graph) override;
private:
FmkType fmk_type = lite::converter::FmkType_TF;
};
} // namespace mindspore::opt
#endif // MINDSPORE_LITE_SRC_PASS_REMOVE_UNUSED_TRANSPOSE_PASS_H_
Loading…
Cancel
Save