diff --git a/mindspore/lite/src/model_impl.cc b/mindspore/lite/src/model_impl.cc index f8ccd4980b..078bae0f41 100644 --- a/mindspore/lite/src/model_impl.cc +++ b/mindspore/lite/src/model_impl.cc @@ -122,8 +122,6 @@ lite::Primitive *ModelImpl::CopyPrimitive(const schema::Primitive *srcPrim) { return new lite::Scale(const_cast(srcPrim)); case schema::PrimitiveType_Eltwise: return new lite::Eltwise(const_cast(srcPrim)); - case schema::PrimitiveType_Ceil: - return new lite::Ceil(const_cast(srcPrim)); case schema::PrimitiveType_Concat: return new lite::Concat(const_cast(srcPrim)); case schema::PrimitiveType_Fill: @@ -148,6 +146,72 @@ lite::Primitive *ModelImpl::CopyPrimitive(const schema::Primitive *srcPrim) { return new lite::Crop(const_cast(srcPrim)); case schema::PrimitiveType_SquaredDifference: return new lite::SquaredDifference(const_cast(srcPrim)); + case schema::PrimitiveType_AddN: + return new lite::AddN(const_cast(srcPrim)); + case schema::PrimitiveType_Abs: + return new lite::Abs(const_cast(srcPrim)); + case schema::PrimitiveType_Sin: + return new lite::Sin(const_cast(srcPrim)); + case schema::PrimitiveType_Cos: + return new lite::Cos(const_cast(srcPrim)); + case schema::PrimitiveType_Log: + return new lite::Log(const_cast(srcPrim)); + case schema::PrimitiveType_Sqrt: + return new lite::Sqrt(const_cast(srcPrim)); + case schema::PrimitiveType_Rsqrt: + return new lite::Rsqrt(const_cast(srcPrim)); + case schema::PrimitiveType_Square: + return new lite::Square(const_cast(srcPrim)); + case schema::PrimitiveType_Exp: + return new lite::Exp(const_cast(srcPrim)); + case schema::PrimitiveType_Gather: + return new lite::Gather(const_cast(srcPrim)); + case schema::PrimitiveType_LocalResponseNormalization: + return new lite::LocalResponseNormalization(const_cast(srcPrim)); + case schema::PrimitiveType_Maximum: + return new lite::Maximum(const_cast(srcPrim)); + case schema::PrimitiveType_Minimum: + return new lite::Minimum(const_cast(srcPrim)); + case schema::PrimitiveType_Pad: + return new lite::Pad(const_cast(srcPrim)); + case schema::PrimitiveType_StridedSlice: + return new lite::StridedSlice(const_cast(srcPrim)); + case schema::PrimitiveType_Prelu: + return new lite::Prelu(const_cast(srcPrim)); + case schema::PrimitiveType_Round: + return new lite::Round(const_cast(srcPrim)); + case schema::PrimitiveType_ReverseSequence: + return new lite::ReverseSequence(const_cast(srcPrim)); + case schema::PrimitiveType_LogicalAnd: + return new lite::LogicalAnd(const_cast(srcPrim)); + case schema::PrimitiveType_LogicalOr: + return new lite::LogicalOr(const_cast(srcPrim)); + case schema::PrimitiveType_LogicalNot: + return new lite::LogicalNot(const_cast(srcPrim)); + case schema::PrimitiveType_FloorDiv: + return new lite::FloorDiv(const_cast(srcPrim)); + case schema::PrimitiveType_FloorMod: + return new lite::FloorMod(const_cast(srcPrim)); + case schema::PrimitiveType_Equal: + return new lite::Equal(const_cast(srcPrim)); + case schema::PrimitiveType_NotEqual: + return new lite::NotEqual(const_cast(srcPrim)); + case schema::PrimitiveType_Less: + return new lite::Less(const_cast(srcPrim)); + case schema::PrimitiveType_LessEqual: + return new lite::LessEqual(const_cast(srcPrim)); + case schema::PrimitiveType_Greater: + return new lite::Greater(const_cast(srcPrim)); + case schema::PrimitiveType_GreaterEqual: + return new lite::GreaterEqual(const_cast(srcPrim)); + case schema::PrimitiveType_Floor: + return new lite::Floor(const_cast(srcPrim)); + case schema::PrimitiveType_Ceil: + return new lite::Ceil(const_cast(srcPrim)); + case schema::PrimitiveType_Split: + return new lite::Split(const_cast(srcPrim)); + case schema::PrimitiveType_OneHot: + return new lite::OneHot(const_cast(srcPrim)); case schema::PrimitiveType_MatMul: return new lite::MatMul(const_cast(srcPrim)); case schema::PrimitiveType_QuantDTypeCast: diff --git a/mindspore/lite/src/ops/ops.h b/mindspore/lite/src/ops/ops.h index a4e0f2c35a..d69b00ed8d 100644 --- a/mindspore/lite/src/ops/ops.h +++ b/mindspore/lite/src/ops/ops.h @@ -108,6 +108,12 @@ class Activation : public Primitive { const schema::Activation *GetAttribute() const { return this->primitive->value_as_Activation(); } }; +class Prelu : public Activation { + public: + explicit Prelu(schema::Primitive *primitive) : Activation(primitive) {} + const schema::Prelu *GetAttribute() const { return this->primitive->value_as_Prelu(); } +}; + class Split : public Primitive { public: explicit Split(schema::Primitive *primitive) : Primitive(primitive) {} @@ -259,12 +265,84 @@ class Div : public Arithmetic { const schema::Div *GetAttribute() const { return this->primitive->value_as_Div(); } }; +class LogicalAnd : public Arithmetic { + public: + explicit LogicalAnd(schema::Primitive *primitive) : Arithmetic(primitive) {} + const schema::LogicalAnd *GetAttribute() const { return this->primitive->value_as_LogicalAnd(); } +}; + +class LogicalOr : public Arithmetic { + public: + explicit LogicalOr(schema::Primitive *primitive) : Arithmetic(primitive) {} + const schema::LogicalOr *GetAttribute() const { return this->primitive->value_as_LogicalOr(); } +}; + +class Maximum : public Arithmetic { + public: + explicit Maximum(schema::Primitive *primitive) : Arithmetic(primitive) {} + const schema::Maximum *GetAttribute() const { return this->primitive->value_as_Maximum(); } +}; + +class Minimum : public Arithmetic { + public: + explicit Minimum(schema::Primitive *primitive) : Arithmetic(primitive) {} + const schema::Minimum *GetAttribute() const { return this->primitive->value_as_Minimum(); } +}; + +class FloorDiv : public Arithmetic { + public: + explicit FloorDiv(schema::Primitive *primitive) : Arithmetic(primitive) {} + const schema::FloorDiv *GetAttribute() const { return this->primitive->value_as_FloorDiv(); } +}; + +class FloorMod : public Arithmetic { + public: + explicit FloorMod(schema::Primitive *primitive) : Arithmetic(primitive) {} + const schema::FloorMod *GetAttribute() const { return this->primitive->value_as_FloorMod(); } +}; + class SquaredDifference : public Arithmetic { public: explicit SquaredDifference(schema::Primitive *primitive) : Arithmetic(primitive) {} const schema::SquaredDifference *GetAttribute() const { return this->primitive->value_as_SquaredDifference(); } }; +class Equal : public Arithmetic { + public: + explicit Equal(schema::Primitive *primitive) : Arithmetic(primitive) {} + const schema::Equal *GetAttribute() const { return this->primitive->value_as_Equal(); } +}; + +class NotEqual : public Arithmetic { + public: + explicit NotEqual(schema::Primitive *primitive) : Arithmetic(primitive) {} + const schema::NotEqual *GetAttribute() const { return this->primitive->value_as_NotEqual(); } +}; + +class Less : public Arithmetic { + public: + explicit Less(schema::Primitive *primitive) : Arithmetic(primitive) {} + const schema::Less *GetAttribute() const { return this->primitive->value_as_Less(); } +}; + +class LessEqual : public Arithmetic { + public: + explicit LessEqual(schema::Primitive *primitive) : Arithmetic(primitive) {} + const schema::LessEqual *GetAttribute() const { return this->primitive->value_as_LessEqual(); } +}; + +class Greater : public Arithmetic { + public: + explicit Greater(schema::Primitive *primitive) : Arithmetic(primitive) {} + const schema::Greater *GetAttribute() const { return this->primitive->value_as_Greater(); } +}; + +class GreaterEqual : public Arithmetic { + public: + explicit GreaterEqual(schema::Primitive *primitive) : Arithmetic(primitive) {} + const schema::GreaterEqual *GetAttribute() const { return this->primitive->value_as_GreaterEqual(); } +}; + class Eltwise : public Arithmetic { public: explicit Eltwise(schema::Primitive *primitive) : Arithmetic(primitive) {} @@ -331,6 +409,18 @@ class LogicalNot : public ArithmeticSelf { const schema::LogicalNot *GetAttribute() const { return this->primitive->value_as_LogicalNot(); } }; +class Floor : public ArithmeticSelf { + public: + explicit Floor(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} + const schema::Floor *GetAttribute() const { return this->primitive->value_as_Floor(); } +}; + +class Ceil : public ArithmeticSelf { + public: + explicit Ceil(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} + const schema::Ceil *GetAttribute() const { return this->primitive->value_as_Ceil(); } +}; + class RealDiv : public Arithmetic { public: explicit RealDiv(schema::Primitive *primitive) : Arithmetic(primitive) {} @@ -364,12 +454,6 @@ class Cast : public Primitive { int InferShape(std::vector inputs_, std::vector outputs_) override; }; -class Ceil : public ArithmeticSelf { - public: - explicit Ceil(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} - const schema::Ceil *GetAttribute() const { return this->primitive->value_as_Ceil(); } -}; - class Concat : public Primitive { public: explicit Concat(schema::Primitive *primitive) : Primitive(primitive) {} @@ -475,24 +559,6 @@ class Squeeze : public Primitive { int InferShape(std::vector inputs_, std::vector outputs_) override; }; -class Floor : public ArithmeticSelf { - public: - explicit Floor(schema::Primitive *primitive) : ArithmeticSelf(primitive) {} - const schema::Floor *GetAttribute() const { return this->primitive->value_as_Floor(); } -}; - -class FloorDiv : public Arithmetic { - public: - explicit FloorDiv(schema::Primitive *primitive) : Arithmetic(primitive) {} - const schema::Sub *GetAttribute() const { return this->primitive->value_as_Sub(); } -}; - -class FloorMod : public Arithmetic { - public: - explicit FloorMod(schema::Primitive *primitive) : Arithmetic(primitive) {} - const schema::Sub *GetAttribute() const { return this->primitive->value_as_Sub(); } -}; - class Transpose : public Primitive { public: explicit Transpose(schema::Primitive *primitive) : Primitive(primitive) {}