add ops create

pull/4002/head
sunsuodong 5 years ago
parent c17ed236ee
commit 4c04c9a916

@ -122,8 +122,6 @@ lite::Primitive *ModelImpl::CopyPrimitive(const schema::Primitive *srcPrim) {
return new lite::Scale(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_Eltwise:
return new lite::Eltwise(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_Ceil:
return new lite::Ceil(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_Concat:
return new lite::Concat(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_Fill:
@ -148,6 +146,72 @@ lite::Primitive *ModelImpl::CopyPrimitive(const schema::Primitive *srcPrim) {
return new lite::Crop(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_SquaredDifference:
return new lite::SquaredDifference(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_AddN:
return new lite::AddN(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_Abs:
return new lite::Abs(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_Sin:
return new lite::Sin(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_Cos:
return new lite::Cos(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_Log:
return new lite::Log(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_Sqrt:
return new lite::Sqrt(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_Rsqrt:
return new lite::Rsqrt(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_Square:
return new lite::Square(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_Exp:
return new lite::Exp(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_Gather:
return new lite::Gather(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_LocalResponseNormalization:
return new lite::LocalResponseNormalization(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_Maximum:
return new lite::Maximum(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_Minimum:
return new lite::Minimum(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_Pad:
return new lite::Pad(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_StridedSlice:
return new lite::StridedSlice(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_Prelu:
return new lite::Prelu(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_Round:
return new lite::Round(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_ReverseSequence:
return new lite::ReverseSequence(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_LogicalAnd:
return new lite::LogicalAnd(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_LogicalOr:
return new lite::LogicalOr(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_LogicalNot:
return new lite::LogicalNot(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_FloorDiv:
return new lite::FloorDiv(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_FloorMod:
return new lite::FloorMod(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_Equal:
return new lite::Equal(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_NotEqual:
return new lite::NotEqual(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_Less:
return new lite::Less(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_LessEqual:
return new lite::LessEqual(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_Greater:
return new lite::Greater(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_GreaterEqual:
return new lite::GreaterEqual(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_Floor:
return new lite::Floor(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_Ceil:
return new lite::Ceil(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_Split:
return new lite::Split(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_OneHot:
return new lite::OneHot(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_MatMul:
return new lite::MatMul(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_QuantDTypeCast:

@ -108,6 +108,12 @@ class Activation : public Primitive {
const schema::Activation *GetAttribute() const { return this->primitive->value_as_Activation(); }
};
class Prelu : public Activation {
public:
explicit Prelu(schema::Primitive *primitive) : Activation(primitive) {}
const schema::Prelu *GetAttribute() const { return this->primitive->value_as_Prelu(); }
};
class Split : public Primitive {
public:
explicit Split(schema::Primitive *primitive) : Primitive(primitive) {}
@ -259,12 +265,84 @@ class Div : public Arithmetic {
const schema::Div *GetAttribute() const { return this->primitive->value_as_Div(); }
};
class LogicalAnd : public Arithmetic {
public:
explicit LogicalAnd(schema::Primitive *primitive) : Arithmetic(primitive) {}
const schema::LogicalAnd *GetAttribute() const { return this->primitive->value_as_LogicalAnd(); }
};
class LogicalOr : public Arithmetic {
public:
explicit LogicalOr(schema::Primitive *primitive) : Arithmetic(primitive) {}
const schema::LogicalOr *GetAttribute() const { return this->primitive->value_as_LogicalOr(); }
};
class Maximum : public Arithmetic {
public:
explicit Maximum(schema::Primitive *primitive) : Arithmetic(primitive) {}
const schema::Maximum *GetAttribute() const { return this->primitive->value_as_Maximum(); }
};
class Minimum : public Arithmetic {
public:
explicit Minimum(schema::Primitive *primitive) : Arithmetic(primitive) {}
const schema::Minimum *GetAttribute() const { return this->primitive->value_as_Minimum(); }
};
class FloorDiv : public Arithmetic {
public:
explicit FloorDiv(schema::Primitive *primitive) : Arithmetic(primitive) {}
const schema::FloorDiv *GetAttribute() const { return this->primitive->value_as_FloorDiv(); }
};
class FloorMod : public Arithmetic {
public:
explicit FloorMod(schema::Primitive *primitive) : Arithmetic(primitive) {}
const schema::FloorMod *GetAttribute() const { return this->primitive->value_as_FloorMod(); }
};
class SquaredDifference : public Arithmetic {
public:
explicit SquaredDifference(schema::Primitive *primitive) : Arithmetic(primitive) {}
const schema::SquaredDifference *GetAttribute() const { return this->primitive->value_as_SquaredDifference(); }
};
class Equal : public Arithmetic {
public:
explicit Equal(schema::Primitive *primitive) : Arithmetic(primitive) {}
const schema::Equal *GetAttribute() const { return this->primitive->value_as_Equal(); }
};
class NotEqual : public Arithmetic {
public:
explicit NotEqual(schema::Primitive *primitive) : Arithmetic(primitive) {}
const schema::NotEqual *GetAttribute() const { return this->primitive->value_as_NotEqual(); }
};
class Less : public Arithmetic {
public:
explicit Less(schema::Primitive *primitive) : Arithmetic(primitive) {}
const schema::Less *GetAttribute() const { return this->primitive->value_as_Less(); }
};
class LessEqual : public Arithmetic {
public:
explicit LessEqual(schema::Primitive *primitive) : Arithmetic(primitive) {}
const schema::LessEqual *GetAttribute() const { return this->primitive->value_as_LessEqual(); }
};
class Greater : public Arithmetic {
public:
explicit Greater(schema::Primitive *primitive) : Arithmetic(primitive) {}
const schema::Greater *GetAttribute() const { return this->primitive->value_as_Greater(); }
};
class GreaterEqual : public Arithmetic {
public:
explicit GreaterEqual(schema::Primitive *primitive) : Arithmetic(primitive) {}
const schema::GreaterEqual *GetAttribute() const { return this->primitive->value_as_GreaterEqual(); }
};
class Eltwise : public Arithmetic {
public:
explicit Eltwise(schema::Primitive *primitive) : Arithmetic(primitive) {}
@ -331,6 +409,18 @@ class LogicalNot : public ArithmeticSelf {
const schema::LogicalNot *GetAttribute() const { return this->primitive->value_as_LogicalNot(); }
};
class Floor : public ArithmeticSelf {
public:
explicit Floor(schema::Primitive *primitive) : ArithmeticSelf(primitive) {}
const schema::Floor *GetAttribute() const { return this->primitive->value_as_Floor(); }
};
class Ceil : public ArithmeticSelf {
public:
explicit Ceil(schema::Primitive *primitive) : ArithmeticSelf(primitive) {}
const schema::Ceil *GetAttribute() const { return this->primitive->value_as_Ceil(); }
};
class RealDiv : public Arithmetic {
public:
explicit RealDiv(schema::Primitive *primitive) : Arithmetic(primitive) {}
@ -364,12 +454,6 @@ class Cast : public Primitive {
int InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) override;
};
class Ceil : public ArithmeticSelf {
public:
explicit Ceil(schema::Primitive *primitive) : ArithmeticSelf(primitive) {}
const schema::Ceil *GetAttribute() const { return this->primitive->value_as_Ceil(); }
};
class Concat : public Primitive {
public:
explicit Concat(schema::Primitive *primitive) : Primitive(primitive) {}
@ -475,24 +559,6 @@ class Squeeze : public Primitive {
int InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) override;
};
class Floor : public ArithmeticSelf {
public:
explicit Floor(schema::Primitive *primitive) : ArithmeticSelf(primitive) {}
const schema::Floor *GetAttribute() const { return this->primitive->value_as_Floor(); }
};
class FloorDiv : public Arithmetic {
public:
explicit FloorDiv(schema::Primitive *primitive) : Arithmetic(primitive) {}
const schema::Sub *GetAttribute() const { return this->primitive->value_as_Sub(); }
};
class FloorMod : public Arithmetic {
public:
explicit FloorMod(schema::Primitive *primitive) : Arithmetic(primitive) {}
const schema::Sub *GetAttribute() const { return this->primitive->value_as_Sub(); }
};
class Transpose : public Primitive {
public:
explicit Transpose(schema::Primitive *primitive) : Primitive(primitive) {}

Loading…
Cancel
Save