fix caffe resize bugs

pull/11736/head
yeyunpeng 4 years ago
parent a84a5215ca
commit 24642d61a3

@ -1,5 +1,5 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -32,8 +32,8 @@ OpParameter *PopulateResizeParameter(const mindspore::lite::PrimitiveC *primitiv
resize_param->op_parameter_.type_ = primitive->Type();
auto param = reinterpret_cast<mindspore::lite::Resize *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
resize_param->method_ = static_cast<int>(param->GetMethod());
resize_param->new_height_ = param->GetNewHeight();
resize_param->new_width_ = param->GetNewWidth();
resize_param->new_height_ = param->new_height();
resize_param->new_width_ = param->new_width();
resize_param->coordinate_transform_mode_ = param->GetCoordinateTransformMode();
resize_param->preserve_aspect_ratio_ = param->GetPreserveAspectRatio();
return reinterpret_cast<OpParameter *>(resize_param);

@ -1,5 +1,5 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -122,6 +122,8 @@ Registry ResizeRegistry(schema::PrimitiveType_Resize, ResizeCreator);
namespace {
constexpr int kInputRank = 4;
} // namespace
int64_t Resize::new_height() const { return new_height_; }
int64_t Resize::new_width() const { return new_width_; }
int Resize::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) {
MS_ASSERT(this->primitive_ != nullptr);
auto input = inputs_.front();
@ -145,15 +147,27 @@ int Resize::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Te
std::vector<int> output_shape;
output_shape.push_back(input->Batch());
if (inputs_.size() == kDoubleNum) {
auto shape_tensor = inputs_.at(1);
auto ret = CalculateNewHeightAndWidth(inputs_);
if (ret == RET_OK) {
output_shape.push_back(new_height_);
output_shape.push_back(new_width_);
output_shape.push_back(input->Channel());
output->set_shape(output_shape);
}
return ret;
}
int Resize::CalculateNewHeightAndWidth(const std::vector<lite::Tensor *> &inputs) {
auto input = inputs.front();
if (inputs.size() == kDoubleNum) {
auto shape_tensor = inputs.at(1);
if (shape_tensor->data_c() == nullptr) {
MS_LOG(INFO) << "Do infer shape in runtime.";
return RET_INFER_INVALID;
}
size_t shape_size = shape_tensor->ElementsNum();
switch (shape_size) {
case kInputRank: {
case kQuadrupleNum: {
if (shape_tensor->data_type() == kNumberTypeInt32) {
auto data = reinterpret_cast<int32_t *>(shape_tensor->data_c());
if (data == nullptr) {
@ -162,12 +176,12 @@ int Resize::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Te
}
switch (shape_tensor->format()) {
case schema::Format_NCHW:
output_shape.push_back(data[2]);
output_shape.push_back(data[3]);
new_height_ = data[2];
new_width_ = data[3];
break;
case schema::Format_NHWC:
output_shape.push_back(data[1]);
output_shape.push_back(data[2]);
new_height_ = data[1];
new_width_ = data[2];
break;
default:
MS_LOG(INFO) << "Resize don't support tensor format.";
@ -181,12 +195,12 @@ int Resize::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Te
}
switch (shape_tensor->format()) {
case schema::Format_NCHW:
output_shape.push_back(data[2] * input->Height());
output_shape.push_back(data[3] * input->Width());
new_height_ = data[2] * input->Height();
new_width_ = data[3] * input->Width();
break;
case schema::Format_NHWC:
output_shape.push_back(data[1] * input->Height());
output_shape.push_back(data[2] * input->Width());
new_height_ = data[1] * input->Height();
new_width_ = data[2] * input->Width();
break;
default:
MS_LOG(INFO) << "Resize don't support tensor format.";
@ -195,36 +209,52 @@ int Resize::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Te
}
break;
}
default: {
case kDoubleNum: {
auto data = reinterpret_cast<int32_t *>(shape_tensor->data_c());
if (data == nullptr) {
MS_LOG(INFO) << "Resize op size can't cast float.";
return RET_INFER_INVALID;
}
for (size_t i = 0; i < shape_size; i++) {
output_shape.push_back(data[i]);
new_height_ = data[0];
new_width_ = data[1];
break;
}
case kSingleNum: {
// caffe zoom_factor
int scale;
if (shape_tensor->data_type() == kNumberTypeInt32) {
auto data = reinterpret_cast<int *>(shape_tensor->data_c());
if (data == nullptr) {
MS_LOG(INFO) << "Resize op size can't cast int.";
return RET_INFER_INVALID;
}
scale = data[0];
} else {
MS_LOG(ERROR) << "Unsupported data type:" << shape_tensor->data_type();
return RET_INFER_ERR;
}
new_height_ = input->Height() + (input->Height() - 1) * (scale - 1);
new_width_ = input->Width() + (input->Width() - 1) * (scale - 1);
break;
}
default: {
MS_LOG(ERROR) << "Unsupported shape size:" << shape_size;
return RET_INFER_ERR;
}
}
} else if (inputs_.size() == kSingleNum) {
auto new_height = GetNewHeight();
auto new_width = GetNewWidth();
output_shape.push_back(new_height);
output_shape.push_back(new_width);
} else if (inputs_.size() == kQuadrupleNum) {
if (inputs_[3]->data_c() == nullptr) {
} else if (inputs.size() == kSingleNum) {
new_height_ = GetNewHeight();
new_width_ = GetNewWidth();
} else if (inputs.size() == kQuadrupleNum) {
if (inputs[3]->data_c() == nullptr) {
return RET_INFER_INVALID;
}
output_shape.push_back(static_cast<int *>(inputs_.at(3)->data_c())[0]);
output_shape.push_back(static_cast<int *>(inputs_.at(3)->data_c())[1]);
new_height_ = static_cast<int *>(inputs.at(3)->data_c())[0];
new_height_ = static_cast<int *>(inputs.at(3)->data_c())[1];
} else {
MS_LOG(ERROR) << "inputs tensor size invalid.";
return RET_INFER_ERR;
}
output_shape.push_back(input->Channel());
output->set_shape(output_shape);
return RET_OK;
}
} // namespace lite

@ -1,5 +1,5 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -49,6 +49,14 @@ class Resize : public PrimitiveC {
int64_t GetNewWidth() const;
bool GetPreserveAspectRatio() const;
int GetCoordinateTransformMode() const;
int64_t new_height() const;
int64_t new_width() const;
private:
int CalculateNewHeightAndWidth(const std::vector<lite::Tensor *> &inputs);
int64_t new_height_;
int64_t new_width_;
};
} // namespace lite
} // namespace mindspore

@ -78,7 +78,8 @@ domi::ModelBufferData *SubGraphNpuKernel::BuildIRModel() {
}
int SubGraphNpuKernel::Run() {
return reinterpret_cast<lite::NPUExecutor *>(this->executor_)->Run(in_tensors_, out_tensors_, out_nodes_, nodes_);
return reinterpret_cast<lite::NPUExecutor *>(this->executor_)
->Run(in_tensors_, out_tensor_sorted_, out_nodes_, nodes_);
}
int SubGraphNpuKernel::BuildNPUInputOp() {
@ -156,6 +157,14 @@ std::vector<ge::Operator> SubGraphNpuKernel::GetNPUNodes(const vector<kernel::Li
int SubGraphNpuKernel::BuildNPUOutputOp() {
subgraph_output_op_.clear();
subgraph_output_op_ = GetNPUNodes(out_nodes_);
out_tensor_sorted_.resize(out_tensors_.size());
int i = 0;
for (auto node : out_nodes_) {
for (auto tensor : node->out_tensors()) {
if (std::find(out_tensors_.begin(), out_tensors_.end(), tensor) != out_tensors_.end())
this->out_tensor_sorted_[i++] = tensor;
}
}
if (subgraph_output_op_.empty()) {
MS_LOG(ERROR) << "NPU subgraph output op is empty.";
return RET_ERROR;

@ -74,6 +74,8 @@ class SubGraphNpuKernel : public SubGraphKernel {
std::vector<ge::Operator> subgraph_input_op_;
std::vector<ge::Operator> subgraph_output_op_;
std::vector<lite::Tensor *> out_tensor_sorted_;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_AGENT_SUBGRAPH_NPU_KERNEL_H_

@ -62,9 +62,9 @@ int ResizeBaseCPUKernel::CheckParameters() {
MS_LOG(INFO) << "Out shape is not assigned";
const_shape_ = false;
} else {
auto ret = CalculateLinearNewHeightWidth();
if (ret != RET_OK) {
return ret;
if (InferShapeDone()) {
new_height_ = out_tensors_.at(0)->shape().at(1);
new_width_ = out_tensors_.at(0)->shape().at(2);
}
const_shape_ = true;
}
@ -78,52 +78,6 @@ int ResizeBaseCPUKernel::CheckParameters() {
return RET_OK;
}
int ResizeBaseCPUKernel::CalculateLinearNewHeightWidth() {
if (method_ != static_cast<int>(schema::ResizeMethod_LINEAR)) {
return RET_OK;
}
if (in_tensors_.size() != 2) {
return RET_ERROR;
}
auto input_tensor = in_tensors_.at(0);
auto shape_scale_tensor = in_tensors_.at(1);
if (shape_scale_tensor->data_type() == kNumberTypeFloat32) {
// float type means scale
float *shape_scale = reinterpret_cast<float *>(shape_scale_tensor->data_c());
if (shape_scale == nullptr) {
return RET_ERROR;
}
if (shape_scale_tensor->format() == schema::Format_NHWC) {
new_height_ = input_tensor->Height() * shape_scale[1];
new_width_ = input_tensor->Width() * shape_scale[2];
} else if (shape_scale_tensor->format() == schema::Format_NCHW) {
new_height_ = input_tensor->Height() * shape_scale[2];
new_width_ = input_tensor->Width() * shape_scale[3];
} else {
MS_LOG(ERROR) << "resize not support format " << shape_scale_tensor->format();
return RET_ERROR;
}
} else if (shape_scale_tensor->data_type() == kNumberTypeInt32) {
// int32 type means real shape
int32_t *shape_data = reinterpret_cast<int32_t *>(shape_scale_tensor->data_c());
if (shape_data == nullptr) {
return RET_ERROR;
}
if (shape_scale_tensor->format() == schema::Format_NHWC) {
new_height_ = shape_data[1];
new_width_ = shape_data[2];
} else if (shape_scale_tensor->format() == schema::Format_NCHW) {
new_height_ = shape_data[2];
new_width_ = shape_data[3];
} else {
MS_LOG(ERROR) << "resize not support format " << shape_scale_tensor->format();
return RET_ERROR;
}
}
return RET_OK;
}
int ResizeBaseCPUKernel::CheckInputsOuputs() {
if (in_tensors_.size() <= lite::kQuadrupleNum) {
for (size_t i = 0; i < in_tensors_.size(); i++) {

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -47,7 +47,6 @@ class ResizeBaseCPUKernel : public LiteKernel {
private:
int CheckParameters();
int CheckInputsOuputs();
int CalculateLinearNewHeightWidth();
};
} // namespace mindspore::kernel

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -44,7 +44,9 @@ int ScaleNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs, cons
op_->set_attr_axis(scale_parameter_->axis_);
op_->set_input_x(*npu_inputs[0]);
op_->set_input_scale(*npu_inputs[1]);
op_->set_input_bias(*npu_inputs[2]);
if (npu_inputs[2] != nullptr) {
op_->set_input_bias(*npu_inputs[2]);
}
return RET_OK;
}

@ -218,6 +218,7 @@ if(ENABLE_CONVERTER)
${LITE_DIR}/tools/optimizer/graph/if_pass.cc
${LITE_DIR}/tools/optimizer/graph/functionalize_control_op_pass.cc
${LITE_DIR}/tools/optimizer/graph/functionalize_while.cc
${LITE_DIR}/tools/optimizer/graph/inputs_adjust_pass.cc
)
endif()
### train

@ -69,6 +69,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
../optimizer/graph/mindir_inputs_adjust_pass.cc
../optimizer/graph/functionalize_control_op_pass.cc
../optimizer/graph/functionalize_while.cc
../optimizer/graph/inputs_adjust_pass.cc
)
add_subdirectory(../anf_importer anf_importer)

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -49,6 +49,7 @@
#include "tools/optimizer/graph/while_pass.h"
#include "tools/optimizer/graph/if_pass.h"
#include "tools/optimizer/graph/functionalize_control_op_pass.h"
#include "tools/optimizer/graph/inputs_adjust_pass.h"
#include "tools/converter/quantizer/post_training_quantizer.h"
#include "tools/converter/quantizer/quant_cast.h"
#include "tools/converter/quantizer/huffman_encode.h"
@ -124,6 +125,7 @@ int AnfTransform::AddGraphPass(const std::shared_ptr<opt::GraphOptimizer> &optim
auto slice_prepose_pass = std::make_shared<opt::SlicePreposePass>();
slice_prepose_pass->SetFmkType(config->fmk);
graph_pm->AddPass(slice_prepose_pass);
graph_pm->AddPass(std::make_shared<opt::InputAdjustPass>());
optimizer->AddPassManager(graph_pm);
return RET_OK;
}

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -27,9 +27,9 @@ PrimitiveC *CaffeInterpParser::ParseLitePrimitive(const caffe::LayerParameter &p
return nullptr;
}
const caffe::InterpParameter &interpParam = proto.interp_param();
if (interpParam.has_height()) {
int64_t height = interpParam.height();
const caffe::InterpParameter &interp_param = proto.interp_param();
if (interp_param.has_height()) {
int64_t height = interp_param.height();
if (height < 0) {
MS_LOG(ERROR) << "Interp height must be > 0";
return nullptr;
@ -37,8 +37,8 @@ PrimitiveC *CaffeInterpParser::ParseLitePrimitive(const caffe::LayerParameter &p
attr->newHeight = height;
}
if (interpParam.has_width()) {
int64_t width = interpParam.width();
if (interp_param.has_width()) {
int64_t width = interp_param.width();
if (width < 0) {
MS_LOG(ERROR) << "Interp width must be > 0";
return nullptr;
@ -50,7 +50,11 @@ PrimitiveC *CaffeInterpParser::ParseLitePrimitive(const caffe::LayerParameter &p
auto primitive = std::make_unique<schema::PrimitiveT>();
primitive->value.type = schema::PrimitiveType_Resize;
primitive->value.value = attr.release();
return PrimitiveC::Create(primitive.release());
auto primitive_c = PrimitiveC::Create(primitive.release());
if (interp_param.has_zoom_factor()) {
primitive_c->AddAttr("zoom_factor", MakeValue(interp_param.zoom_factor()));
}
return primitive_c;
}
CaffeNodeRegistrar g_caffeInterpParser("Interp", new CaffeInterpParser());

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -34,14 +34,15 @@ PrimitiveC *CaffeScaleParser::ParseLitePrimitive(const caffe::LayerParameter &pr
}
const caffe::ScaleParameter &scaleParam = weight.scale_param();
attr->axis = 1;
if (scaleParam.has_axis()) {
uint32_t axis_index = 1;
if (GetAxisIndex(scaleParam.axis(), &axis_index)) {
MS_LOG(ERROR) << "scale get axis failed for layer " << weight.name().c_str();
return nullptr;
}
attr->axis = axis_index;
}
attr->axis = 1;
auto primitive = std::make_unique<schema::PrimitiveT>();
primitive->value.type = schema::PrimitiveType_Scale;
primitive->value.value = attr.release();

File diff suppressed because it is too large Load Diff

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -19,6 +19,7 @@
#include <memory>
#include <vector>
#include <string>
#include "src/ops/primitive_c.h"
#include "ir/anf.h"
#include "ir/func_graph.h"
@ -40,6 +41,8 @@ bool IsRealCNodeKernel(const AnfNodePtr &node);
bool IsGraphKernel(const AnfNodePtr &node);
bool CheckInputs(const CNodePtr &cnode);
int CheckIfFuncGraphIsNull(const FuncGraphPtr &graph);
int CheckIfAnfNodeIsNull(const AnfNodePtr &node);
@ -121,6 +124,19 @@ template <typename T>
static lite::STATUS TransFilterFormat(const ParamValueLitePtr &tensor, kTransFilterType type);
STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_format);
ParameterPtr BuildIntValueParameterNode(const FuncGraphPtr &func_graph, const int32_t &data,
const std::string &node_name);
ParameterPtr BuildIntVecParameterNode(const FuncGraphPtr &func_graph, const std::vector<int32_t> &data,
const std::string &node_name);
ParameterPtr BuildIntVec2DParameterNode(const FuncGraphPtr &func_graph, const std::vector<std::vector<int32_t>> &data,
const std::string &node_name);
ParameterPtr BuildFloatValueParameterNode(const FuncGraphPtr &func_graph, const float &data,
const std::string &node_name);
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_PASS_COMMON_GLLO_UTILS_H_

@ -0,0 +1,109 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/graph/inputs_adjust_pass.h"
#include <vector>
#include <string>
#include <memory>
#include "src/ops/primitive_c.h"
namespace mindspore::opt {
STATUS InputAdjustPass::AddAttrToInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int input_num,
const std::string &attr_name, int flag) {
MS_ASSERT(cnode != nullptr);
if (!CheckInputs(cnode)) {
MS_LOG(ERROR) << "input is invalid.";
return lite::RET_INPUT_TENSOR_ERROR;
}
auto primitive_c = GetValueNode<PrimitiveCPtr>(cnode->input(0));
auto value_ptr = primitive_c->GetAttr(attr_name);
if (value_ptr == nullptr) {
MS_LOG(DEBUG) << "there is no attr :" << attr_name;
return lite::RET_NO_CHANGE;
}
auto inputs = cnode->inputs();
if (static_cast<int>(inputs.size()) > input_num) {
primitive_c->EraseAttr(attr_name);
MS_LOG(DEBUG) << "input num has been meet, which is " << inputs.size();
return lite::RET_OK;
} else if (static_cast<int>(inputs.size()) < input_num) {
MS_LOG(ERROR) << "input num is invalid.";
return lite::RET_ERROR;
}
switch (flag) {
case 1: {
auto value_data = GetValue<int32_t>(value_ptr);
auto param_node =
BuildIntValueParameterNode(func_graph, value_data, cnode->fullname_with_scope() + "_" + attr_name);
inputs.push_back(param_node);
break;
}
case 2: {
auto value_data = GetValue<std::vector<int32_t>>(value_ptr);
auto param_node =
BuildIntVecParameterNode(func_graph, value_data, cnode->fullname_with_scope() + "_" + attr_name);
inputs.push_back(param_node);
break;
}
case 3: {
auto value_data = GetValue<std::vector<std::vector<int32_t>>>(value_ptr);
auto param_node =
BuildIntVec2DParameterNode(func_graph, value_data, cnode->fullname_with_scope() + "_" + attr_name);
inputs.push_back(param_node);
break;
}
case 4: {
auto value_data = GetValue<float>(value_ptr);
auto param_node =
BuildFloatValueParameterNode(func_graph, value_data, cnode->fullname_with_scope() + "_" + attr_name);
inputs.push_back(param_node);
break;
}
default: {
MS_LOG(ERROR) << "Error attr flag";
return lite::RET_ERROR;
}
}
cnode->set_inputs(inputs);
return lite::RET_OK;
}
bool InputAdjustPass::Run(const FuncGraphPtr &func_graph) {
MS_ASSERT(func_graph != nullptr);
auto manager = Manage(func_graph, true);
if (manager == nullptr) {
MS_LOG(ERROR) << "manager is nullptr.";
return lite::RET_NULL_PTR;
}
auto node_list = TopoSort(func_graph->get_return());
STATUS status = lite::RET_OK;
for (auto &node : node_list) {
auto cnode = node->cast<CNodePtr>();
if (cnode == nullptr) {
continue;
}
if (GetCNodeType(node) == schema::PrimitiveType_Resize) {
status = AddAttrToInput(func_graph, cnode, 2, "zoom_factor", 1);
}
if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
MS_LOG(ERROR) << "adjust input pass is failed.";
return false;
}
}
return true;
}
} // namespace mindspore::opt

@ -0,0 +1,39 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_INPUTS_ADJUST_PASS_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_INPUTS_ADJUST_PASS_H_
#include <vector>
#include <string>
#include "tools/optimizer/common/gllo_utils.h"
#include "backend/optimizer/common/pass.h"
#include "src/param_value_lite.h"
#include "mindspore/lite/include/errorcode.h"
using mindspore::lite::STATUS;
namespace mindspore::opt {
class InputAdjustPass : public Pass {
public:
InputAdjustPass() : Pass("input_adjust") {}
~InputAdjustPass() override = default;
static STATUS AddAttrToInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, int input_num,
const std::string &attr_name, int flag);
bool Run(const FuncGraphPtr &func_graph) override;
};
} // namespace mindspore::opt
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_INPUTS_ADJUST_PASS_H_
Loading…
Cancel
Save