!10324 [lite] add argmax, layernorm, batchmatmul for mindir

From: @xu_anyue
Reviewed-by: 
Signed-off-by:
pull/10324/MERGE
mindspore-ci-bot, committed by Gitee
commit 00455b9559

@@ -261,6 +261,7 @@ union PrimitiveType {
     Reciprocal,
     Merge,
     Mod,
+    GeLU,
 }
 enum QuantType: int {

@@ -34,7 +34,36 @@ void ArgMax::SetOutMaxValue(bool out_max_value) { this->primitive_->value.AsArgM
 void ArgMax::SetTopK(int top_k) { this->primitive_->value.AsArgMax()->topK = top_k; }
 void ArgMax::SetKeepDims(bool keep_dims) { this->primitive_->value.AsArgMax()->keepDims = keep_dims; }
 void ArgMax::SetAxisType(int axis_type) { this->primitive_->value.AsArgMax()->axisType = axis_type; }
+int ArgMax::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
+  if (this->primitive_ == nullptr) {
+    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
+    if (this->primitive_ == nullptr) {
+      MS_LOG(ERROR) << "new primitive error";
+      return RET_ERROR;
+    }
+    this->primitive_->value.type = schema::PrimitiveType_ArgMax;
+  }
+  if (this->primitive_->value.type != schema::PrimitiveType_ArgMax) {
+    MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type;
+    return RET_ERROR;
+  }
+  if (this->primitive_->value.value == nullptr) {
+    auto argmax_attr = new (std::nothrow) schema::ArgMaxT();
+    if (argmax_attr == nullptr) {
+      MS_LOG(ERROR) << "new primitive value.value error";
+      return RET_ERROR;
+    }
+    if (prim.GetAttr("axis") != nullptr) {
+      argmax_attr->axis = static_cast<int32_t>(GetValue<int64_t>(prim.GetAttr("axis")));
+    }
+    if (prim.GetAttr("keep_dims") != nullptr) {
+      argmax_attr->keepDims = static_cast<bool>(GetValue<bool>(prim.GetAttr("keep_dims")));
+    }
+    argmax_attr->outMaxValue = false;
+    this->primitive_->value.value = argmax_attr;
+  }
+  return RET_OK;
+}
 #else
 int ArgMax::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
   MS_ASSERT(nullptr != primitive);

@@ -37,6 +37,7 @@ class ArgMax : public PrimitiveC {
   void SetTopK(int top_k);
   void SetKeepDims(bool keep_dims);
   void SetAxisType(int axis_type);
+  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
 #else
   int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
 #endif

@@ -61,6 +61,7 @@ int ArgMin::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
     if (prim.GetAttr("keep_dims") != nullptr) {
       attr->keepDims = static_cast<bool>(GetValue<bool>(prim.GetAttr("keep_dims")));
     }
+    attr->outMaxValue = false;
   }
   return RET_OK;
 }

@@ -0,0 +1,53 @@ (new file)
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/gelu.h"
#include <memory>
#include "include/errorcode.h"
#include "src/common/log_adapter.h"
#include "src/tensor.h"
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int GeLU::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_GeLU;
}
if (this->primitive_->value.type != schema::PrimitiveType_GeLU) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
this->primitive_->value.value = new (std::nothrow) schema::GeLUT();
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
}
return RET_OK;
}
#endif
} // namespace lite
} // namespace mindspore
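For reference only: this commit adds the converter plumbing for the op, not a kernel. The function the op denotes is the standard Gaussian error linear unit; a minimal sketch of the reference math (hypothetical helper, not code from this commit):

#include <cmath>

// gelu(x) = x * Phi(x) = 0.5 * x * (1 + erf(x / sqrt(2)))  (standard definition)
inline float GeluRef(float x) { return 0.5f * x * (1.0f + std::erf(x / std::sqrt(2.0f))); }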

@@ -0,0 +1,40 @@ (new file)
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_C_OPS_GELU_H_
#define LITE_MINDSPORE_LITE_C_OPS_GELU_H_

#include <vector>
#include <set>
#include <cmath>
#include "src/ops/primitive_c.h"

namespace mindspore {
namespace lite {
class GeLU : public PrimitiveC {
 public:
  GeLU() = default;
  ~GeLU() = default;
#ifdef PRIMITIVE_WRITEABLE
  MS_DECLARE_PARENT(GeLU, PrimitiveC);
  explicit GeLU(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#endif
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_GELU_H_
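With the header in place, the flatbuffer-side factory in primitive_c.cc (see the PrimitiveType_GeLU case further down) can materialize the op; a minimal hedged sketch of that construction path (wrapper function is illustrative only):

#include "src/ops/gelu.h"

// Mirrors the new case added to PrimitiveC::Create(schema::PrimitiveT *).
mindspore::lite::PrimitiveC *MakeGelu(mindspore::schema::PrimitiveT *primitive) {
  return new (std::nothrow) mindspore::lite::GeLU(primitive);
}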

@@ -35,7 +35,42 @@ void LayerNorm::SetEpsilon(float epsilon) { this->primitive_->value.AsLayerNorm(
 void LayerNorm::SetElementwiseAffine(bool elementwiseAffine) {
   this->primitive_->value.AsLayerNorm()->elementwiseAffine = elementwiseAffine;
 }
+int LayerNorm::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
+  if (this->primitive_ == nullptr) {
+    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
+    if (this->primitive_ == nullptr) {
+      MS_LOG(ERROR) << "new primitive error";
+      return RET_ERROR;
+    }
+    this->primitive_->value.type = schema::PrimitiveType_LayerNorm;
+  }
+  if (this->primitive_->value.type != schema::PrimitiveType_LayerNorm) {
+    MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type;
+    return RET_ERROR;
+  }
+  if (this->primitive_->value.value == nullptr) {
+    auto layer_norm_attr = new (std::nothrow) schema::LayerNormT();
+    if (layer_norm_attr == nullptr) {
+      MS_LOG(ERROR) << "new primitive value.value error";
+      return RET_ERROR;
+    }
+    auto value_attr = prim.GetAttr("epsilon");
+    if (value_attr != nullptr) {
+      layer_norm_attr->epsilon = GetValue<float>(value_attr);
+    } else {
+      layer_norm_attr->epsilon = 1e-7;
+    }
+    value_attr = prim.GetAttr("normalized_shape");
+    if (value_attr != nullptr) {
+      layer_norm_attr->normalizedShape = CastToInt(value_attr);
+    }
+    if (inputs.size() == 3) {
+      layer_norm_attr->elementwiseAffine = true;
+    }
+    this->primitive_->value.value = layer_norm_attr;
+  }
+  return RET_OK;
+}
 #else
 int LayerNorm::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
   MS_ASSERT(nullptr != primitive);

@@ -100,13 +135,12 @@ int LayerNorm::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite:
     return RET_PARAM_INVALID;
   }
   if (normlized_shape_.empty()) {
-    // instance norm -> layernorm
+    // instance norm -> layernorm only for nchw
     if (input->format() == schema::Format_NCHW) {
       normlized_shape_.insert(normlized_shape_.begin(), input_shape.begin() + 2, input_shape.end());
       elementwise_mode_ = 1;
     } else {
-      MS_LOG(INFO) << "normalized_shape attr invalid";
-      return RET_PARAM_INVALID;
+      normlized_shape_.insert(normlized_shape_.begin(), input_shape.begin() + 1, input_shape.end());
     }
   }
   size_t first_index = input_shape.size() - normlized_shape_.size();
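The InferShape change above can be read as: with an empty normalized_shape attribute, NCHW inputs still normalize over the trailing spatial dims, and any other layout now falls back to normalizing over everything after the batch axis instead of failing. A sketch of that shape rule (helper name is illustrative, not from this commit):

#include <vector>

// Empty-normalized_shape fallback, as reconstructed from InferShape above.
std::vector<int> DefaultNormalizedShape(const std::vector<int> &input_shape, bool is_nchw) {
  size_t skip = is_nchw ? 2 : 1;  // NCHW: skip N and C; other layouts: skip only N
  return std::vector<int>(input_shape.begin() + skip, input_shape.end());
}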

@@ -35,6 +35,7 @@ class LayerNorm : public PrimitiveC {
   void SetNormalizedShape(const std::vector<int> &normalizedShape);
   void SetEpsilon(float epsilon);
   void SetElementwiseAffine(bool elementwiseAffine);
+  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
 #else
   int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
 #endif

@@ -160,6 +160,7 @@
 #include "src/ops/merge.h"
 #include "src/ops/switch.h"
 #include "src/ops/partial.h"
+#include "src/ops/gelu.h"
 #ifdef SUPPORT_TRAIN
 #include "src/ops/neg_grad.h"
@@ -330,9 +331,28 @@ void PrimitiveC::PopulaterOutputQuantParam(const Primitive &prim, bool narrowRan
 void PrimitiveC::PopulaterQuantParam(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
   auto narrow_range = prim.GetAttr("narrow_range");
-  bool narrowRangeQuantParam = narrow_range != nullptr && GetValue<bool>(narrow_range);
+  bool narrowRangeQuantParam = false;
+  if (narrow_range != nullptr) {
+    if (utils::isa<tensor::TensorPtr>(narrow_range)) {
+      auto narrow_range_tensor = narrow_range->cast<tensor::TensorPtr>();
+      narrowRangeQuantParam = *reinterpret_cast<bool *>(narrow_range_tensor->data_c());
+    } else if (utils::isa<ImmTraits<bool>::type>(narrow_range)) {
+      narrowRangeQuantParam = GetValue<bool>(narrow_range);
+    } else {
+      MS_LOG(ERROR) << "valueptr is invalid.";
+      return;
+    }
+  }
   auto num_bits = prim.GetAttr("num_bits");
-  int32_t numbitsRangeQuantParam = num_bits != nullptr ? GetValue<int64_t>(num_bits) : 8;
+  int32_t numbitsRangeQuantParam = 8;
+  if (num_bits != nullptr) {
+    if (utils::isa<tensor::TensorPtr>(num_bits)) {
+      auto num_bits_tensor = num_bits->cast<tensor::TensorPtr>();
+      numbitsRangeQuantParam = *reinterpret_cast<int64_t *>(num_bits_tensor->data_c());
+    } else if (utils::isa<ImmTraits<int64_t>::type>(num_bits)) {
+      numbitsRangeQuantParam = GetValue<int64_t>(num_bits);
+    }
+  }
   PopulaterInputQuantParam(prim, inputs, narrowRangeQuantParam, numbitsRangeQuantParam);
   PopulaterOutputQuantParam(prim, narrowRangeQuantParam, numbitsRangeQuantParam);
 }
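The expanded PopulaterQuantParam guards against a MindIR attribute arriving either as a tensor or as an immediate scalar; a condensed sketch of that dual-representation read (helper is hypothetical, using only the calls visible above):

// Hypothetical helper condensing the narrow_range / num_bits handling.
template <typename T>
T ReadScalarAttr(const ValuePtr &value, T fallback) {
  if (value == nullptr) return fallback;
  if (utils::isa<tensor::TensorPtr>(value)) {  // attribute exported as a tensor
    auto tensor = value->cast<tensor::TensorPtr>();
    return *reinterpret_cast<T *>(tensor->data_c());
  }
  if (utils::isa<typename ImmTraits<T>::type>(value)) {  // attribute exported as an immediate
    return GetValue<T>(value);
  }
  return fallback;
}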
@@ -511,7 +531,7 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
     return NewPrimitiveC<FusedBatchNorm>(prim, inputs, quantType);
   } else if (op_type == "make_tuple") {
     return NewPrimitiveC<MakeTuple>(prim, inputs, quantType);
-  } else if (op_type == "MatMul") {
+  } else if (op_type == "MatMul" || op_type == "BatchMatMul") {
     return NewPrimitiveC<MatMul>(prim, inputs, quantType);
   } else if (op_type == "Mul") {
     return NewPrimitiveC<Mul>(prim, inputs, quantType);
@@ -601,7 +621,7 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
     return NewPrimitiveC<TopK>(prim, inputs, quantType);
   } else if (op_type == "Mod") {
     return NewPrimitiveC<Mod>(prim, inputs, quantType);
-  } else if (op_type == "ArgMinWithValue") {
+  } else if (op_type == "ArgMin" || op_type == "ArgMinWithValue") {
     return NewPrimitiveC<ArgMin>(prim, inputs, quantType);
   } else if (op_type == "Range") {
     return NewPrimitiveC<Range>(prim, inputs, quantType);
@@ -621,6 +641,12 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
     return NewPrimitiveC<Partial>(prim, inputs, quantType);
   } else if (op_type == "Merge") {
     return NewPrimitiveC<Merge>(prim, inputs, quantType);
+  } else if (op_type == "LayerNorm") {
+    return NewPrimitiveC<LayerNorm>(prim, inputs, quantType);
+  } else if (op_type == "ArgMax" || op_type == "ArgMaxWithValue") {
+    return NewPrimitiveC<ArgMax>(prim, inputs, quantType);
+  } else if (op_type == "Gelu") {
+    return NewPrimitiveC<GeLU>(prim, inputs, quantType);
 #ifdef SUPPORT_TRAIN
   } else if (op_type == "SoftmaxCrossEntropyWithLogits") {
@@ -965,6 +991,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
       return new (std::nothrow) Partial(primitive);
     case schema::PrimitiveType_Assert:
       return new (std::nothrow) AssertOP(primitive);
+    case schema::PrimitiveType_GeLU:
+      return new (std::nothrow) GeLU(primitive);
 #ifdef SUPPORT_TRAIN
     case schema::PrimitiveType_ActivationGrad:
       return new (std::nothrow) ActivationGrad(primitive);
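Taken together, the Create() changes route the new MindIR op names onto existing lite primitives; summarized as a table (illustrative structure, not code from this commit):

#include <map>
#include <string>

// MindIR op name -> lite primitive, as introduced by the hunks above.
static const std::map<std::string, std::string> kNewMindirRouting = {
    {"BatchMatMul", "MatMul"},  // reuses the existing MatMul primitive
    {"ArgMin", "ArgMin"},       // previously only "ArgMinWithValue" was matched
    {"ArgMax", "ArgMax"},
    {"ArgMaxWithValue", "ArgMax"},
    {"LayerNorm", "LayerNorm"},
    {"Gelu", "GeLU"},           // schema::PrimitiveType_GeLU
};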

@@ -199,6 +199,7 @@ if(ENABLE_CONVERTER)
     ${LITE_DIR}/tools/optimizer/graph/infershape_pass.cc
     ${LITE_DIR}/tools/optimizer/graph/slice_prepose_pass.cc
     ${LITE_DIR}/tools/optimizer/graph/mindir_adjust_pass.cc
+    ${LITE_DIR}/tools/optimizer/graph/mindir_inputs_adjust_pass.cc
     ${LITE_DIR}/tools/optimizer/graph/onnx_inputs_adjust_pass.cc
     ${LITE_DIR}/tools/optimizer/graph/while_pass.cc
 )

@@ -737,7 +737,8 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s
     meta_graphT->allTensors.emplace_back(msTensor);
     if (IsPrimitiveCNode(cnode, schema::PrimitiveType_Conv2D) ||
         IsPrimitiveCNode(cnode, schema::PrimitiveType_DepthwiseConv2D) ||
-        IsPrimitiveCNode(cnode, schema::PrimitiveType_FusedBatchNorm)) {
+        IsPrimitiveCNode(cnode, schema::PrimitiveType_FusedBatchNorm) ||
+        IsPrimitiveCNode(cnode, schema::PrimitiveType_LayerNorm)) {
       break;
     }
 #endif

@@ -60,6 +60,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
     ../optimizer/graph/mindir_adjust_pass.cc
     ../optimizer/graph/onnx_inputs_adjust_pass.cc
     ../optimizer/graph/while_pass.cc
+    ../optimizer/graph/mindir_inputs_adjust_pass.cc
 )
 add_subdirectory(../anf_importer anf_importer)

@@ -30,6 +30,7 @@
 #include "tools/optimizer/fusion/sigmoid_mul_fusion.h"
 #include "tools/optimizer/fusion/conv_conv_fusion.h"
 #include "tools/optimizer/graph/mindir_adjust_pass.h"
+#include "tools/optimizer/graph/mindir_inputs_adjust_pass.h"
 #include "tools/optimizer/graph/identity_remove_pass.h"
 #include "tools/optimizer/graph/weight_format_hardcode_pass.h"
 #include "tools/optimizer/graph/weight_format_transform_pass.h"
@@ -77,6 +78,12 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap
       ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
       return nullptr;
     }
+    auto mindir_inputs_adjust_pass = std::make_shared<opt::MindirInputAdjustOpPass>();
+    if (!mindir_inputs_adjust_pass->Run(old_graph)) {
+      MS_LOG(ERROR) << "mindir inputs adjust failed.";
+      ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
+      return nullptr;
+    }
   }
   // onnx pre adjustment

@@ -0,0 +1,236 @@ (new file)
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/graph/mindir_inputs_adjust_pass.h"
#include <vector>
#include <memory>
#include "src/common/log_adapter.h"
#include "src/ops/primitive_c.h"
#include "src/tensor.h"
using mindspore::lite::PrimitiveC;
namespace mindspore {
namespace opt {
namespace {
template <typename T>
void CopyAttrForArgMinMax(T *left, T *right) {
MS_ASSERT(left != null && right != nullptr);
left->axis = right->axis;
left->outMaxValue = right->outMaxValue;
left->axisType = right->axisType;
left->keepDims = right->keepDims;
left->topK = right->topK;
}
} // namespace
bool MindirInputAdjustOpPass::CheckCNodeIsArgMinMax(const CNodePtr &cnode) {
  MS_ASSERT(cnode != nullptr);
  auto prim_node = cnode->inputs().at(0);
  MS_ASSERT(prim_node != nullptr);
  auto prim_value_node = prim_node->cast<ValueNodePtr>();
  if (prim_value_node == nullptr) {
    MS_LOG(DEBUG) << "cnode first input is not valueNode.";
    return false;
  }
  auto value = prim_value_node->value();
  MS_ASSERT(value != nullptr);
  auto prim_c = value->cast<PrimitiveCPtr>();
  if (prim_c == nullptr) {
    MS_LOG(DEBUG) << "prim is not primitiveC.";
    return false;
  }
  auto prim = prim_c->primitiveT();
  MS_ASSERT(prim != nullptr);
  return prim->value.type == schema::PrimitiveType_ArgMax || prim->value.type == schema::PrimitiveType_ArgMin;
}
int MindirInputAdjustOpPass::AdjustArgMinMaxInputs(std::vector<AnfNodePtr> *inputs, bool index_or_value) {
  MS_ASSERT(inputs != nullptr);
  auto prim_node = inputs->at(0);
  MS_ASSERT(prim_node != nullptr);
  auto prim_value_node = prim_node->cast<ValueNodePtr>();
  if (prim_value_node == nullptr) {
    MS_LOG(ERROR) << "cnode first input is not valueNode.";
    return lite::RET_ERROR;
  }
  auto prim_value = prim_value_node->value();
  if (prim_value == nullptr) {
    MS_LOG(ERROR) << "valueNode value is nullptr.";
    return lite::RET_ERROR;
  }
  auto prim_c = prim_value->cast<PrimitiveCPtr>();
  if (prim_c == nullptr) {
    MS_LOG(ERROR) << "value is not primitiveC.";
    return lite::RET_ERROR;
  }
  auto prim = prim_c->primitiveT();
  MS_ASSERT(prim != nullptr && prim->value.value != nullptr);
  auto attr = prim->value.value;
  if (prim->value.type == schema::PrimitiveType_ArgMax) {
    reinterpret_cast<schema::ArgMaxT *>(attr)->outMaxValue = index_or_value;
  } else if (prim->value.type == schema::PrimitiveType_ArgMin) {
    reinterpret_cast<schema::ArgMinT *>(attr)->outMaxValue = index_or_value;
  }
  return lite::RET_OK;
}
int MindirInputAdjustOpPass::CopyPrimitiveCForArgMinMax(std::vector<AnfNodePtr> *inputs) {
  MS_ASSERT(inputs != nullptr);
  auto prim_node = inputs->at(0);
  MS_ASSERT(prim_node != nullptr);
  auto prim_value_node = prim_node->cast<ValueNodePtr>();
  if (prim_value_node == nullptr) {
    MS_LOG(ERROR) << "cnode first input is not valueNode.";
    return lite::RET_ERROR;
  }
  auto prim_value = prim_value_node->value();
  if (prim_value == nullptr) {
    MS_LOG(ERROR) << "valueNode value is nullptr.";
    return lite::RET_ERROR;
  }
  auto prim_c = prim_value->cast<PrimitiveCPtr>();
  if (prim_c == nullptr) {
    MS_LOG(ERROR) << "value is not primitiveC.";
    return lite::RET_ERROR;
  }
  auto prim = prim_c->primitiveT();
  MS_ASSERT(prim != nullptr && prim->value.value != nullptr);
  auto primitive = std::make_unique<schema::PrimitiveT>();
  if (prim->value.type == schema::PrimitiveType_ArgMax) {
    primitive->value.type = schema::PrimitiveType_ArgMax;
    auto attr = std::make_unique<schema::ArgMaxT>();
    CopyAttrForArgMinMax<schema::ArgMaxT>(attr.get(), reinterpret_cast<schema::ArgMaxT *>(prim->value.value));
    primitive->value.value = attr.release();
  } else {
    primitive->value.type = schema::PrimitiveType_ArgMin;
    auto attr = std::make_unique<schema::ArgMinT>();
    CopyAttrForArgMinMax<schema::ArgMinT>(attr.get(), reinterpret_cast<schema::ArgMinT *>(prim->value.value));
    primitive->value.value = attr.release();
  }
  auto primitive_c = PrimitiveC::Create(primitive.release());
  auto value_node = NewValueNode(std::shared_ptr<PrimitiveC>(primitive_c));
  inputs->erase(inputs->begin());
  inputs->insert(inputs->begin(), value_node);
  return lite::RET_OK;
}
int MindirInputAdjustOpPass::BuildCNodeForArgMinMax(const FuncGraphPtr &graph, const CNodePtr &tuple_get_item,
                                                    const CNodePtr &argmin_max) {
  MS_ASSERT(graph != nullptr && tuple_get_item != nullptr && argmin_max != nullptr);
  auto inputs = argmin_max->inputs();
  if (CopyPrimitiveCForArgMinMax(&inputs) != lite::RET_OK) {
    MS_LOG(ERROR) << "copy argmin or argmax failed.";
    return lite::RET_ERROR;
  }
  if (AdjustArgMinMaxInputs(&inputs, false) != lite::RET_OK) {
    MS_LOG(ERROR) << "adjust argmin or argmax attr failed.";
    return lite::RET_ERROR;
  }
  auto new_cnode = graph->NewCNode(inputs);
  new_cnode->set_fullname_with_scope(argmin_max->fullname_with_scope() + "_index");
  auto type_ptr = TypeIdToType(kTypeUnknown);
  std::vector<int64_t> shape_vector;
  new_cnode->set_abstract(std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector));
  auto manager = graph->manager();
  MS_ASSERT(manager != nullptr);
  manager->Replace(tuple_get_item, new_cnode);
  return lite::RET_OK;
}
int MindirInputAdjustOpPass::AdjustArgMinMax(const FuncGraphPtr &graph, const CNodePtr &tuple_get_item,
                                             const CNodePtr &argmin_max) {
  MS_ASSERT(graph != nullptr && tuple_get_item != nullptr && argmin_max != nullptr);
  auto inputs = argmin_max->inputs();
  if (AdjustArgMinMaxInputs(&inputs, true) != lite::RET_OK) {
    MS_LOG(ERROR) << "adjust argmin or argmax attr failed.";
    return lite::RET_ERROR;
  }
  auto type_ptr = TypeIdToType(kTypeUnknown);
  std::vector<int64_t> shape_vector;
  auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vector);
  argmin_max->set_abstract(abstract_tensor);
  auto manager = graph->manager();
  MS_ASSERT(manager != nullptr);
  manager->Replace(tuple_get_item, argmin_max);
  return lite::RET_OK;
}
int MindirInputAdjustOpPass::AdjustTupleGetItemWithArgMinMax(const FuncGraphPtr &graph, const CNodePtr &cnode) {
  MS_ASSERT(graph != nullptr && cnode != nullptr);
  auto inputs = cnode->inputs();
  if (inputs.size() != 3) {
    MS_LOG(ERROR) << "tupleGetItem inputs size is invalid: " << inputs.size();
    return lite::RET_ERROR;
  }
  auto argmin_max = inputs.at(1);
  MS_ASSERT(argmin_max != nullptr);
  auto argmin_max_cnode = argmin_max->cast<CNodePtr>();
  if (argmin_max_cnode == nullptr) {
    MS_LOG(ERROR) << "the second input is not a cnode.";
    return lite::RET_ERROR;
  }
  if (!CheckCNodeIsArgMinMax(argmin_max_cnode)) {
    MS_LOG(DEBUG) << "tuple_get_item first input is neither argmin nor argmax.";
    return lite::RET_OK;
  }
  auto index_vnode = inputs.at(2);
  auto value_node = index_vnode->cast<ValueNodePtr>();
  if (value_node == nullptr) {
    MS_LOG(ERROR) << "TupleGetItem's input 2 is not valuenode";
    return lite::RET_ERROR;
  }
  int index = lite::CastToInt(value_node->value()).front();
  if (index == 0) {
    if (BuildCNodeForArgMinMax(graph, cnode, argmin_max_cnode) != lite::RET_OK) {
      MS_LOG(ERROR) << "build new cnode failed.";
      return lite::RET_ERROR;
    }
  } else if (index == 1) {
    if (AdjustArgMinMax(graph, cnode, argmin_max_cnode) != lite::RET_OK) {
      MS_LOG(ERROR) << "adjust argmin_max failed.";
      return lite::RET_ERROR;
    }
  }
  return lite::RET_OK;
}
bool MindirInputAdjustOpPass::Run(const FuncGraphPtr &graph) {
  MS_ASSERT(graph != nullptr);
  auto manager = Manage(graph, true);
  if (manager == nullptr) {
    MS_LOG(ERROR) << "manager is nullptr.";
    return false;
  }
  auto node_list = TopoSort(graph->get_return());
  int status = lite::RET_OK;
  for (auto &node : node_list) {
    auto cnode = node->cast<CNodePtr>();
    if (cnode == nullptr) {
      MS_LOG(DEBUG) << "node is not cnode.";
      continue;
    }
    auto type = opt::GetCNodeType(node);
    if (type == schema::PrimitiveType_TupleGetItem) {
      status = AdjustTupleGetItemWithArgMinMax(graph, cnode);
    }
    if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
      MS_LOG(ERROR) << "adjust input pass is failed.";
      return false;
    }
  }
  return true;
}
} // namespace opt
} // namespace mindspore
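In short, the pass rewrites TupleGetItem consumers of the two-output ArgMinWithValue/ArgMaxWithValue nodes into single-output lite ops; a sketch of the before/after graph (illustration only):

// before:  out = ArgMaxWithValue(x)       after:  idx = ArgMax(x)  // outMaxValue = false
//          idx = TupleGetItem(out, 0)             val = ArgMax(x)  // outMaxValue = true
//          val = TupleGetItem(out, 1)
//
// index 0: BuildCNodeForArgMinMax clones the primitive into a "<name>_index" node
// with outMaxValue = false; index 1: AdjustArgMinMax flips outMaxValue to true on
// the original node and replaces the TupleGetItem with it.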

@@ -0,0 +1,41 @@ (new file)
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_MINDIR_INPUTS_ADJUST_PASS_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_MINDIR_INPUTS_ADJUST_PASS_H_

#include <string>
#include <vector>
#include "backend/optimizer/common/pass.h"
#include "tools/converter/converter_flags.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "src/param_value_lite.h"

namespace mindspore::opt {
class MindirInputAdjustOpPass : public Pass {
 public:
  MindirInputAdjustOpPass() : Pass("mindir_inputs_adjust_pass") {}
  ~MindirInputAdjustOpPass() override = default;
  bool CheckCNodeIsArgMinMax(const CNodePtr &cnode);
  int AdjustArgMinMaxInputs(std::vector<AnfNodePtr> *inputs, bool index_or_value);
  int CopyPrimitiveCForArgMinMax(std::vector<AnfNodePtr> *inputs);
  int BuildCNodeForArgMinMax(const FuncGraphPtr &graph, const CNodePtr &tuple_get_item, const CNodePtr &argmin_max);
  int AdjustArgMinMax(const FuncGraphPtr &graph, const CNodePtr &tuple_get_item, const CNodePtr &argmin_max);
  int AdjustTupleGetItemWithArgMinMax(const FuncGraphPtr &graph, const CNodePtr &cnode);
  bool Run(const FuncGraphPtr &graph) override;
};
} // namespace mindspore::opt
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_MINDIR_INPUTS_ADJUST_PASS_H_

@@ -43,8 +43,6 @@ class OnnxInputAdjustOpPass : public Pass {
   STATUS ReplaceConstant(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
   STATUS ReplaceTransposeWithGraphInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
   bool Run(const FuncGraphPtr &func_graph) override;
-
- private:
 };
 } // namespace mindspore::opt
 #endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_ONNX_INPUTS_ADJUST_PASS_H_
