!9018 [MSLITE]add news models to the entrance guard and fix some bugs
From: @probiotics_53 Reviewed-by: @zhanghaibo5,@zhang_xue_tong Signed-off-by: @zhanghaibo5pull/9018/MERGE
commit
0552eff0c8
@ -0,0 +1,64 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "tools/optimizer/graph/update_conv2d_param_pass.h"
|
||||
#include <memory>
|
||||
#include "mindspore/lite/include/errorcode.h"
|
||||
#include "src/ops/primitive_c.h"
|
||||
|
||||
namespace mindspore::opt {
|
||||
bool UpdateConv2DParamPass::Run(const FuncGraphPtr &func_graph) {
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
auto manager = func_graph->manager();
|
||||
MS_ASSERT(manager != nullptr);
|
||||
auto node_list = TopoSort(func_graph->get_return());
|
||||
int status = RET_OK;
|
||||
for (auto &node : node_list) {
|
||||
if (!utils::isa<CNodePtr>(node)) {
|
||||
continue;
|
||||
}
|
||||
auto type = opt::GetCNodeType(node);
|
||||
if (type == schema::PrimitiveType_DepthwiseConv2D) {
|
||||
auto dwconv2d_cnode = node->cast<CNodePtr>();
|
||||
auto primitive_c = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(dwconv2d_cnode->input(0));
|
||||
if (primitive_c == nullptr) {
|
||||
MS_LOG(ERROR) << "Depthwise conv2D node has no primitiveC.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto primT = primitive_c->primitiveT();
|
||||
if (primT == nullptr) {
|
||||
MS_LOG(ERROR) << "Depthwise conv2D node has no primitiveT.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
int channel_in = primT->value.AsDepthwiseConv2D()->channelIn;
|
||||
if (channel_in == -1) {
|
||||
auto input_node = node->cast<CNodePtr>()->input(lite::kAnfPopulaterInputNumTwo);
|
||||
MS_ASSERT(input_node != nullptr);
|
||||
if (input_node->isa<Parameter>()) {
|
||||
auto param_node = input_node->cast<ParameterPtr>();
|
||||
auto param = param_node->default_param();
|
||||
auto weight = std::dynamic_pointer_cast<ParamValueLite>(param);
|
||||
primT->value.AsDepthwiseConv2D()->channelIn = weight->tensor_shape().at(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
|
||||
MS_LOG(ERROR) << "remove identity pass is failed.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace mindspore::opt
|
@ -0,0 +1,31 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_PASS_UPDATE_CONV2D_PARAM_PASS_H_
|
||||
#define MINDSPORE_LITE_SRC_PASS_UPDATE_CONV2D_PARAM_PASS_H_
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "backend/optimizer/common/pass.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
|
||||
namespace mindspore::opt {
|
||||
class UpdateConv2DParamPass : public Pass {
|
||||
public:
|
||||
UpdateConv2DParamPass() : Pass("update_conv2d_param_pass") {}
|
||||
~UpdateConv2DParamPass() override = default;
|
||||
bool Run(const FuncGraphPtr &graph) override;
|
||||
};
|
||||
} // namespace mindspore::opt
|
||||
#endif // MINDSPORE_LITE_SRC_PASS_UPDATE_CONV2D_PARAM_PASS_H_
|
Loading…
Reference in new issue