!10013 [MSLITE] Fix bug of converter for mindspore models.

From: @wang_shaocong
Reviewed-by: @hangangqiang,@zhanghaibo5
Signed-off-by: @hangangqiang
pull/10013/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 9ed5168c90

@ -22,7 +22,35 @@
namespace mindspore {
namespace lite {
#ifndef PRIMITIVE_WRITEABLE
#ifdef PRIMITIVE_WRITEABLE
int Equal::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_Equal;
}
if (this->primitive_->value.type != schema::PrimitiveType_Equal) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::EqualT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";
return RET_ERROR;
}
}
return RET_OK;
}
#else
int Equal::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);

@ -31,6 +31,7 @@ class Equal : public ArithmeticCompare {
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Equal, ArithmeticCompare);
explicit Equal(schema::PrimitiveT *primitive) : ArithmeticCompare(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif

@ -23,7 +23,36 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int GatherNd::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_GatherNd;
}
if (this->primitive_->value.type != schema::PrimitiveType_GatherNd) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::GatherNdT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
if (prim.GetAttr("batchDims") != nullptr) {
attr->batchDims = static_cast<int32_t>(GetValue<int64_t>(prim.GetAttr("batchDims")));
}
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";
return RET_ERROR;
}
}
return RET_OK;
}
#else
int GatherNd::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);

@ -32,6 +32,7 @@ class GatherNd : public PrimitiveC {
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(GatherNd, PrimitiveC);
explicit GatherNd(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif

@ -22,8 +22,35 @@
namespace mindspore {
namespace lite {
#ifndef PRIMITIVE_WRITEABLE
#ifdef PRIMITIVE_WRITEABLE
int Greater::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_Greater;
}
if (this->primitive_->value.type != schema::PrimitiveType_Greater) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::GreaterT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";
return RET_ERROR;
}
}
return RET_OK;
}
#else
int Greater::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);

@ -31,6 +31,7 @@ class Greater : public ArithmeticCompare {
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Greater, ArithmeticCompare);
explicit Greater(schema::PrimitiveT *primitive) : ArithmeticCompare(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif

@ -593,6 +593,22 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
return NewPrimitiveC<Div>(prim, inputs, quantType);
} else if (op_type == "Tanh") {
return NewPrimitiveC<Activation>(prim, inputs, quantType);
} else if (op_type == "Equal") {
return NewPrimitiveC<Equal>(prim, inputs, quantType);
} else if (op_type == "TopK") {
return NewPrimitiveC<TopK>(prim, inputs, quantType);
} else if (op_type == "Range") {
return NewPrimitiveC<Range>(prim, inputs, quantType);
} else if (op_type == "Tile") {
return NewPrimitiveC<Tile>(prim, inputs, quantType);
} else if (op_type == "GatherNd") {
return NewPrimitiveC<GatherNd>(prim, inputs, quantType);
} else if (op_type == "Square") {
return NewPrimitiveC<Square>(prim, inputs, quantType);
} else if (op_type == "Sqrt") {
return NewPrimitiveC<Sqrt>(prim, inputs, quantType);
} else if (op_type == "Greater") {
return NewPrimitiveC<Greater>(prim, inputs, quantType);
#ifdef SUPPORT_TRAIN
} else if (op_type == "SoftmaxCrossEntropyWithLogits") {
return NewPrimitiveC<SoftmaxCrossEntropy>(prim, inputs, quantType);
@ -621,8 +637,6 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
return NewPrimitiveC<FlattenGrad>(prim, inputs, quantType);
} else if (op_type == "FusedBatchNormGrad") {
return NewPrimitiveC<BNGrad>(prim, inputs, quantType);
} else if (op_type == "Tile") {
return NewPrimitiveC<Tile>(prim, inputs, quantType);
} else if (op_type == "PowerGrad") {
return NewPrimitiveC<PowerGrad>(prim, inputs, quantType);
} else if (op_type == "SGD") {

@ -14,6 +14,7 @@
* limitations under the License.
*/
#include <algorithm>
#include "src/ops/range.h"
#ifndef PRIMITIVE_WRITEABLE
@ -32,7 +33,43 @@ void Range::SetDType(int d_type) { this->primitive_->value.AsRange()->dType = d_
void Range::SetStart(int start) { this->primitive_->value.AsRange()->start = start; }
void Range::SetLimit(int limit) { this->primitive_->value.AsRange()->limit = limit; }
void Range::SetDelta(int delta) { this->primitive_->value.AsRange()->delta = delta; }
int Range::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_Range;
}
if (this->primitive_->value.type != schema::PrimitiveType_Range) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::RangeT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
this->primitive_->value.value = attr;
attr->dType = 0;
if (prim.GetAttr("start") != nullptr) {
attr->start = static_cast<int32_t>(GetValue<float>(prim.GetAttr("start")));
}
if (prim.GetAttr("limit") != nullptr) {
attr->limit = static_cast<int32_t>(GetValue<float>(prim.GetAttr("limit")));
}
if (prim.GetAttr("delta") != nullptr) {
attr->delta = static_cast<int32_t>(GetValue<float>(prim.GetAttr("delta")));
}
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";
return RET_ERROR;
}
}
return RET_OK;
}
#else
int Range::GetDType() const { return this->primitive_->value_as_Range()->dType(); }

@ -36,6 +36,7 @@ class Range : public PrimitiveC {
void SetStart(int start);
void SetLimit(int limit);
void SetDelta(int delta);
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif

@ -23,6 +23,33 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int Sqrt::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_Sqrt;
}
if (this->primitive_->value.type != schema::PrimitiveType_Sqrt) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::SqrtT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";
return RET_ERROR;
}
}
return RET_OK;
}
#else
int Sqrt::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);

@ -32,6 +32,7 @@ class Sqrt : public ArithmeticSelf {
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Sqrt, ArithmeticSelf);
explicit Sqrt(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif

@ -23,6 +23,33 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int Square::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_Square;
}
if (this->primitive_->value.type != schema::PrimitiveType_Square) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::SquareT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
this->primitive_->value.value = attr;
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";
return RET_ERROR;
}
}
return RET_OK;
}
#else
int Square::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);

@ -31,6 +31,7 @@ class Square : public ArithmeticSelf {
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Square, ArithmeticSelf);
explicit Square(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif

@ -52,12 +52,6 @@ int Tile::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &input
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
if (prim.GetAttr("dims") == nullptr) {
MS_LOG(INFO) << "Tile's attr dims is set to default";
attr->dims = {1};
} else {
attr->dims = CastToInt(prim.GetAttr("dims"));
}
if (inputs.size() == kAnfPopulaterInputNumTwo) {
auto inputNode = inputs[kAnfPopulaterInputNumOne];
MS_ASSERT(inputNode != nullptr);
@ -80,6 +74,15 @@ int Tile::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &input
}
}
}
if (prim.GetAttr("dims") == nullptr) {
MS_LOG(INFO) << "Tile's attr dims is set to default. The operator in mindspore has no attribute"
"named dims and all the dimensions needs to be multiplied by default.";
for (size_t i = 0; i < attr->multiples.size(); i++) {
attr->dims.push_back(i);
}
} else {
attr->dims = CastToInt(prim.GetAttr("dims"));
}
this->primitive_->value.value = attr;
}
return RET_OK;

@ -28,7 +28,38 @@ bool TopK::GetSorted() const { return this->primitive_->value.AsTopK()->sorted;
void TopK::SetK(int k) { this->primitive_->value.AsTopK()->k = k; }
void TopK::SetSorted(bool sorted) { this->primitive_->value.AsTopK()->sorted = sorted; }
int TopK::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_TopK;
}
if (this->primitive_->value.type != schema::PrimitiveType_TopK) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::TopKT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new primitiveT value failed";
return RET_ERROR;
}
this->primitive_->value.value = attr;
// the k value of mindspore models is one of inputs instead of an attribute.
attr->k = 0;
if (prim.GetAttr("sorted") != nullptr) {
attr->sorted = GetValue<bool>(prim.GetAttr("sorted"));
}
if (this->primitive_->value.value == nullptr) {
MS_LOG(ERROR) << "primitive value is nullptr";
return RET_ERROR;
}
}
return RET_OK;
}
#else
int TopK::GetK() const { return this->primitive_->value_as_TopK()->k(); }
@ -60,7 +91,7 @@ int TopK::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output
}
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
if (input->format() != schema::Format::Format_NHWC) {
if (input->shape().size() == kDimension_4d && input->format() != schema::Format::Format_NHWC) {
MS_LOG(ERROR) << "topk only support NHWC now!";
return RET_FORMAT_ERR;
}
@ -76,7 +107,16 @@ int TopK::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output
return RET_INFER_INVALID;
}
auto out_shape = input->shape();
out_shape.at(out_shape.size() - 1) = GetK();
if (inputs_.size() == kSingleNum) {
out_shape.at(out_shape.size() - 1) = GetK();
} else if (inputs_.size() == kDoubleNum) {
if (inputs_.at(1)->data_c() == nullptr) {
return RET_INFER_INVALID;
} else {
int *data = reinterpret_cast<int32_t *>(inputs_.at(1)->data_c());
out_shape.at(out_shape.size() - 1) = *data;
}
}
if (inputs_.size() == kDoubleNum && inputs_.at(1)->data_c() != nullptr) {
out_shape.at(out_shape.size() - 1) = reinterpret_cast<int *>(inputs_.at(1)->data_c())[0];
}

@ -34,6 +34,7 @@ class TopK : public PrimitiveC {
explicit TopK(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
void SetK(int k);
void SetSorted(bool sorted);
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif

@ -152,4 +152,5 @@ kernel::LiteKernel *CpuGatherNdFp32KernelCreator(const std::vector<lite::Tensor
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_GatherNd, CpuGatherNdFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_GatherNd, CpuGatherNdFp32KernelCreator)
} // namespace mindspore::kernel

Loading…
Cancel
Save