diff --git a/mindspore/lite/src/model.cc b/mindspore/lite/src/model.cc index 4563607bd1..644cb5da32 100644 --- a/mindspore/lite/src/model.cc +++ b/mindspore/lite/src/model.cc @@ -19,6 +19,7 @@ #include "include/errorcode.h" #include "src/common/graph_util.h" #include "include/version.h" +#include "src/ops/ops_register.h" namespace mindspore::lite { @@ -31,7 +32,12 @@ bool ConvertNodes(const schema::MetaGraph *meta_graph, Model *model) { } auto c_node = meta_graph->nodes()->GetAs(i); auto src_prim = c_node->primitive(); +#ifdef PRIMITIVE_WRITEABLE node->primitive_ = PrimitiveC::Create(const_cast(src_prim)); +#else + auto primitive = const_cast(src_prim); + node->primitive_ = OpsRegistry::GetInstance()->getPrimitiveCreator(primitive->value_type())(primitive); +#endif if (node->primitive_ == nullptr) { MS_LOG(ERROR) << "unpack primitive == nullptr!"; delete node; diff --git a/mindspore/lite/src/ops/abs.cc b/mindspore/lite/src/ops/abs.cc index 1416513b06..3e9c3a19d0 100644 --- a/mindspore/lite/src/ops/abs.cc +++ b/mindspore/lite/src/ops/abs.cc @@ -15,6 +15,7 @@ */ #include "src/ops/abs.h" +#include "src/ops/ops_register.h" namespace mindspore { namespace lite { @@ -27,6 +28,9 @@ int Abs::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::Fl fbb->Finish(prim_offset); return RET_OK; } +PrimitiveC *AbsCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry AbsRegistry(schema::PrimitiveType_Abs, AbsCreator); #endif +Registry AbsParameterRegistry(schema::PrimitiveType_Abs, PopulateArithmeticSelf); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/activation.cc b/mindspore/lite/src/ops/activation.cc index 185645ec84..428f526713 100644 --- a/mindspore/lite/src/ops/activation.cc +++ b/mindspore/lite/src/ops/activation.cc @@ -16,6 +16,8 @@ #include "src/ops/activation.h" #include +#include "src/ops/ops_register.h" +#include "nnacl/fp32/activation.h" namespace mindspore { namespace lite { @@ -80,6 +82,30 @@ int Activation::GetType() const { return this->primitive_->value_as_Activation() float Activation::GetAlpha() const { return this->primitive_->value_as_Activation()->alpha(); } float Activation::GetMinVal() const { return this->primitive_->value_as_Activation()->min_val(); } float Activation::GetMaxVal() const { return this->primitive_->value_as_Activation()->max_val(); } + +PrimitiveC *ActivationCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry ActivationRegistry(schema::PrimitiveType_Activation, ActivationCreator); #endif +OpParameter *PopulateActivationParameter(const mindspore::lite::PrimitiveC *primitive) { + ActivationParameter *act_param = reinterpret_cast(malloc(sizeof(ActivationParameter))); + if (act_param == nullptr) { + MS_LOG(ERROR) << "malloc ActivationParameter failed."; + return nullptr; + } + memset(act_param, 0, sizeof(ActivationParameter)); + act_param->op_parameter_.type_ = primitive->Type(); + auto activation = + reinterpret_cast(const_cast(primitive)); + act_param->type_ = static_cast(activation->GetType()); + act_param->alpha_ = activation->GetAlpha(); + act_param->min_val_ = activation->GetMinVal(); + act_param->max_val_ = activation->GetMaxVal(); + return reinterpret_cast(act_param); +} + +Registry ActivationParameterRegistry(schema::PrimitiveType_Activation, PopulateActivationParameter); + } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/activation_grad.cc b/mindspore/lite/src/ops/activation_grad.cc index f9c45a9bea..3dcf2ab216 100644 --- a/mindspore/lite/src/ops/activation_grad.cc +++ b/mindspore/lite/src/ops/activation_grad.cc @@ -16,6 +16,8 @@ #include "src/ops/activation_grad.h" +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -74,6 +76,11 @@ int ActivationGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flat } int ActivationGrad::GetType() const { return this->primitive_->value_as_ActivationGrad()->type(); } float ActivationGrad::GetAlpha() const { return this->primitive_->value_as_ActivationGrad()->alpha(); } + +PrimitiveC *ActivationGradCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry ActivationGradRegistry(schema::PrimitiveType_ActivationGrad, ActivationGradCreator); #endif } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/adam.cc b/mindspore/lite/src/ops/adam.cc index 51eaf4648f..a3a2b8d922 100644 --- a/mindspore/lite/src/ops/adam.cc +++ b/mindspore/lite/src/ops/adam.cc @@ -14,6 +14,8 @@ * limitations under the License. */ #include "src/ops/adam.h" +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -62,6 +64,9 @@ int Adam::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::F fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *AdamCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry AdamRegistry(schema::PrimitiveType_Adam, AdamCreator); #endif int Adam::InferShape(std::vector inputs, std::vector outputs) { diff --git a/mindspore/lite/src/ops/add.cc b/mindspore/lite/src/ops/add.cc index 251ebf4df7..1d92a6625f 100644 --- a/mindspore/lite/src/ops/add.cc +++ b/mindspore/lite/src/ops/add.cc @@ -16,6 +16,8 @@ #include "src/ops/add.h" #include +#include "src/ops/ops_register.h" +#include "nnacl/arithmetic_common.h" namespace mindspore { namespace lite { @@ -71,6 +73,31 @@ int Add::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::Fl } int Add::GetActivationType() const { return this->primitive_->value_as_Add()->activationType(); } +PrimitiveC *AddCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry AddRegistry(schema::PrimitiveType_Add, AddCreator); #endif + +OpParameter *PopulateAddParameter(const mindspore::lite::PrimitiveC *primitive) { + ArithmeticParameter *arithmetic_param = reinterpret_cast(malloc(sizeof(ArithmeticParameter))); + if (arithmetic_param == nullptr) { + MS_LOG(ERROR) << "malloc ArithmeticParameter failed."; + return nullptr; + } + memset(arithmetic_param, 0, sizeof(ArithmeticParameter)); + arithmetic_param->op_parameter_.type_ = primitive->Type(); + arithmetic_param->broadcasting_ = ((lite::Arithmetic *)primitive)->Broadcasting(); + arithmetic_param->ndim_ = ((lite::Arithmetic *)primitive)->NDims(); + arithmetic_param->activation_type_ = + reinterpret_cast(const_cast(primitive))->GetActivationType(); + auto tmp_shape = ((lite::Arithmetic *)primitive)->InShape0(); + memcpy(arithmetic_param->in_shape0_, static_cast(tmp_shape.data()), tmp_shape.size() * sizeof(int)); + tmp_shape = ((lite::Arithmetic *)primitive)->InShape1(); + memcpy(arithmetic_param->in_shape1_, static_cast(tmp_shape.data()), tmp_shape.size() * sizeof(int)); + tmp_shape = ((lite::Arithmetic *)primitive)->OutputShape(); + memcpy(arithmetic_param->out_shape_, static_cast(tmp_shape.data()), tmp_shape.size() * sizeof(int)); + return reinterpret_cast(arithmetic_param); +} +Registry AddParameterRegistry(schema::PrimitiveType_Add, PopulateAddParameter); + } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/addn.cc b/mindspore/lite/src/ops/addn.cc index 65564bb7aa..11e8556058 100644 --- a/mindspore/lite/src/ops/addn.cc +++ b/mindspore/lite/src/ops/addn.cc @@ -16,6 +16,8 @@ #include "src/ops/addn.h" +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -62,8 +64,22 @@ int AddN::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::F } int AddN::GetN() const { return this->primitive_->value_as_AddN()->N(); } +PrimitiveC *AddNCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry AddNRegistry(schema::PrimitiveType_AddN, AddNCreator); #endif +OpParameter *PopulateAddNParameter(const mindspore::lite::PrimitiveC *primitive) { + OpParameter *addn_param = reinterpret_cast(malloc(sizeof(OpParameter))); + if (addn_param == nullptr) { + MS_LOG(ERROR) << "malloc OpParameter failed."; + return nullptr; + } + memset(addn_param, 0, sizeof(OpParameter)); + addn_param->type_ = primitive->Type(); + return reinterpret_cast(addn_param); +} +Registry AddNParameterRegistry(schema::PrimitiveType_AddN, PopulateAddNParameter); + namespace { constexpr int kLeastInputNum = 2; } diff --git a/mindspore/lite/src/ops/apply_momentum.cc b/mindspore/lite/src/ops/apply_momentum.cc index fc625ad85c..56c50f00e4 100644 --- a/mindspore/lite/src/ops/apply_momentum.cc +++ b/mindspore/lite/src/ops/apply_momentum.cc @@ -14,6 +14,8 @@ * limitations under the License. */ #include "src/ops/apply_momentum.h" +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -67,6 +69,11 @@ int ApplyMomentum::UnPackToFlatBuilder(const schema::Primitive *primitive, flatb fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *ApplyMomentumCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry ApplyMomentumRegistry(schema::PrimitiveType_ApplyMomentum, ApplyMomentumCreator); #endif int ApplyMomentum::InferShape(std::vector inputs, std::vector outputs) { diff --git a/mindspore/lite/src/ops/argmax.cc b/mindspore/lite/src/ops/argmax.cc index dd2550f813..562b7a39b5 100644 --- a/mindspore/lite/src/ops/argmax.cc +++ b/mindspore/lite/src/ops/argmax.cc @@ -16,6 +16,9 @@ #include "src/ops/argmax.h" +#include "src/ops/ops_register.h" +#include "nnacl/arg_min_max_parameter.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -52,8 +55,29 @@ int ArgMax::GetTopK() const { return this->primitive_->value_as_ArgMax()->topK() bool ArgMax::GetKeepDims() const { return this->primitive_->value_as_ArgMax()->keepDims(); } int ArgMax::GetAxisType() const { return this->primitive_->value_as_ArgMax()->axisType(); } +PrimitiveC *ArgMaxCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry ArgMaxRegistry(schema::PrimitiveType_ArgMax, ArgMaxCreator); #endif +OpParameter *PopulateArgMaxParameter(const mindspore::lite::PrimitiveC *primitive) { + ArgMinMaxParameter *arg_param = reinterpret_cast(malloc(sizeof(ArgMinMaxParameter))); + if (arg_param == nullptr) { + MS_LOG(ERROR) << "malloc ArgMinMaxParameter failed."; + return nullptr; + } + memset(arg_param, 0, sizeof(ArgMinMaxParameter)); + arg_param->op_parameter_.type_ = primitive->Type(); + auto param = reinterpret_cast(const_cast(primitive)); + arg_param->axis_ = param->GetAxis(); + arg_param->topk_ = param->GetTopK(); + arg_param->axis_type_ = param->GetAxisType(); + arg_param->out_value_ = param->GetOutMaxValue(); + arg_param->keep_dims_ = param->GetKeepDims(); + return reinterpret_cast(arg_param); +} + +Registry ArgMaxParameterRegistry(schema::PrimitiveType_ArgMax, PopulateArgMaxParameter); + int ArgMax::InferShape(std::vector inputs_, std::vector outputs_) { MS_ASSERT(this->primitive_ != nullptr); auto input = inputs_.front(); diff --git a/mindspore/lite/src/ops/argmin.cc b/mindspore/lite/src/ops/argmin.cc index ce9e8f5ee2..5bd0250bf1 100644 --- a/mindspore/lite/src/ops/argmin.cc +++ b/mindspore/lite/src/ops/argmin.cc @@ -16,6 +16,9 @@ #include "src/ops/argmin.h" +#include "src/ops/ops_register.h" +#include "nnacl/arg_min_max_parameter.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -52,8 +55,29 @@ int ArgMin::GetTopK() const { return this->primitive_->value_as_ArgMin()->topK() bool ArgMin::GetKeepDims() const { return this->primitive_->value_as_ArgMin()->keepDims(); } int ArgMin::GetAxisType() const { return this->primitive_->value_as_ArgMin()->axisType(); } +PrimitiveC *ArgMinCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry ArgMinRegistry(schema::PrimitiveType_ArgMin, ArgMinCreator); #endif +OpParameter *PopulateArgMinParameter(const mindspore::lite::PrimitiveC *primitive) { + ArgMinMaxParameter *arg_param = reinterpret_cast(malloc(sizeof(ArgMinMaxParameter))); + if (arg_param == nullptr) { + MS_LOG(ERROR) << "malloc ArgMinMaxParameter failed."; + return nullptr; + } + memset(arg_param, 0, sizeof(ArgMinMaxParameter)); + arg_param->op_parameter_.type_ = primitive->Type(); + auto param = reinterpret_cast(const_cast(primitive)); + arg_param->axis_ = param->GetAxis(); + arg_param->topk_ = param->GetTopK(); + arg_param->axis_type_ = param->GetAxisType(); + arg_param->out_value_ = param->GetOutMaxValue(); + arg_param->keep_dims_ = param->GetKeepDims(); + return reinterpret_cast(arg_param); +} + +Registry ArgMinParameterRegistry(schema::PrimitiveType_ArgMin, PopulateArgMinParameter); + int ArgMin::InferShape(std::vector inputs_, std::vector outputs_) { MS_ASSERT(this->primitive_ != nullptr); auto input = inputs_.front(); diff --git a/mindspore/lite/src/ops/arithmetic.cc b/mindspore/lite/src/ops/arithmetic.cc index d020e32451..120d6eaf8f 100644 --- a/mindspore/lite/src/ops/arithmetic.cc +++ b/mindspore/lite/src/ops/arithmetic.cc @@ -21,6 +21,29 @@ namespace mindspore { namespace lite { + +OpParameter *PopulateArithmetic(const mindspore::lite::PrimitiveC *primitive) { + ArithmeticParameter *arithmetic_param = reinterpret_cast(malloc(sizeof(ArithmeticParameter))); + if (arithmetic_param == nullptr) { + MS_LOG(ERROR) << "malloc ArithmeticParameter failed."; + return nullptr; + } + memset(arithmetic_param, 0, sizeof(ArithmeticParameter)); + arithmetic_param->op_parameter_.type_ = primitive->Type(); + arithmetic_param->broadcasting_ = ((lite::Arithmetic *)primitive)->Broadcasting(); + arithmetic_param->ndim_ = ((lite::Arithmetic *)primitive)->NDims(); + + arithmetic_param->activation_type_ = 0; + + auto tmp_shape = ((lite::Arithmetic *)primitive)->InShape0(); + memcpy(arithmetic_param->in_shape0_, static_cast(tmp_shape.data()), tmp_shape.size() * sizeof(int)); + tmp_shape = ((lite::Arithmetic *)primitive)->InShape1(); + memcpy(arithmetic_param->in_shape1_, static_cast(tmp_shape.data()), tmp_shape.size() * sizeof(int)); + tmp_shape = ((lite::Arithmetic *)primitive)->OutputShape(); + memcpy(arithmetic_param->out_shape_, static_cast(tmp_shape.data()), tmp_shape.size() * sizeof(int)); + return reinterpret_cast(arithmetic_param); +} + int Arithmetic::InferShape(std::vector inputs_, std::vector outputs_) { MS_ASSERT(this->primitive_ != nullptr); if (inputs_.size() != kDoubleNum) { diff --git a/mindspore/lite/src/ops/arithmetic.h b/mindspore/lite/src/ops/arithmetic.h index 78283c82e8..6ee2d51694 100644 --- a/mindspore/lite/src/ops/arithmetic.h +++ b/mindspore/lite/src/ops/arithmetic.h @@ -21,6 +21,7 @@ #include #include #include "src/ops/primitive_c.h" +#include "nnacl/arithmetic_common.h" namespace mindspore { namespace lite { @@ -51,6 +52,8 @@ class Arithmetic : public PrimitiveC { std::vector in_shape1_; std::vector out_shape_; }; + +OpParameter *PopulateArithmetic(const mindspore::lite::PrimitiveC *primitive); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/arithmetic_grad.h b/mindspore/lite/src/ops/arithmetic_grad.h index 3cc4fb7d4d..dc851666c5 100644 --- a/mindspore/lite/src/ops/arithmetic_grad.h +++ b/mindspore/lite/src/ops/arithmetic_grad.h @@ -21,6 +21,7 @@ #include #include #include "src/ops/primitive_c.h" +#include "nnacl/arithmetic_self_parameter.h" namespace mindspore { namespace lite { diff --git a/mindspore/lite/src/ops/arithmetic_self.cc b/mindspore/lite/src/ops/arithmetic_self.cc index 0cdeeca9d8..798fa9fd5a 100644 --- a/mindspore/lite/src/ops/arithmetic_self.cc +++ b/mindspore/lite/src/ops/arithmetic_self.cc @@ -17,9 +17,21 @@ #include "src/ops/arithmetic_self.h" #include "include/errorcode.h" #include "src/common/log_adapter.h" +#include "src/ops/ops_register.h" namespace mindspore { namespace lite { +OpParameter *PopulateArithmeticSelf(const mindspore::lite::PrimitiveC *primitive) { + ArithmeticSelfParameter *arithmetic_self_param = + reinterpret_cast(malloc(sizeof(ArithmeticSelfParameter))); + if (arithmetic_self_param == nullptr) { + MS_LOG(ERROR) << "malloc ArithmeticSelfParameter failed."; + return nullptr; + } + memset(arithmetic_self_param, 0, sizeof(ArithmeticSelfParameter)); + arithmetic_self_param->op_parameter_.type_ = primitive->Type(); + return reinterpret_cast(arithmetic_self_param); +} int ArithmeticSelf::InferShape(std::vector inputs_, std::vector outputs_) { MS_ASSERT(this->primitive_ != nullptr); diff --git a/mindspore/lite/src/ops/arithmetic_self.h b/mindspore/lite/src/ops/arithmetic_self.h index 17bbfdc6d9..71ae5ea30a 100644 --- a/mindspore/lite/src/ops/arithmetic_self.h +++ b/mindspore/lite/src/ops/arithmetic_self.h @@ -19,6 +19,7 @@ #include #include "src/ops/primitive_c.h" +#include "nnacl/arithmetic_self_parameter.h" namespace mindspore { namespace lite { @@ -37,6 +38,7 @@ class ArithmeticSelf : public PrimitiveC { #endif int InferShape(std::vector inputs_, std::vector outputs_) override; }; +OpParameter *PopulateArithmeticSelf(const mindspore::lite::PrimitiveC *primitive); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/assign.cc b/mindspore/lite/src/ops/assign.cc index ca59337651..456b6802e8 100644 --- a/mindspore/lite/src/ops/assign.cc +++ b/mindspore/lite/src/ops/assign.cc @@ -17,6 +17,8 @@ #include "src/ops/assign.h" #include +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -56,6 +58,9 @@ int Assign::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers: fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *AssignCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry AssignRegistry(schema::PrimitiveType_Assign, AssignCreator); #endif int Assign::InferShape(std::vector inputs, std::vector outputs) { diff --git a/mindspore/lite/src/ops/batch_norm.cc b/mindspore/lite/src/ops/batch_norm.cc index 5ae19f1a76..e2ca9d43f2 100644 --- a/mindspore/lite/src/ops/batch_norm.cc +++ b/mindspore/lite/src/ops/batch_norm.cc @@ -16,6 +16,9 @@ #include "src/ops/batch_norm.h" #include +#include "src/ops/ops_register.h" +#include "nnacl/batchnorm_parameter.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -60,6 +63,28 @@ int BatchNorm::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffe } float BatchNorm::GetEpsilon() const { return this->primitive_->value_as_BatchNorm()->epsilon(); } +PrimitiveC *BatchNormCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry BatchNormRegistry(schema::PrimitiveType_BatchNorm, BatchNormCreator); #endif + +OpParameter *PopulateBatchNorm(const mindspore::lite::PrimitiveC *primitive) { + const auto param = + reinterpret_cast(const_cast(primitive)); + BatchNormParameter *batch_norm_param = reinterpret_cast(malloc(sizeof(BatchNormParameter))); + if (batch_norm_param == nullptr) { + MS_LOG(ERROR) << "malloc BatchNormParameter failed."; + return nullptr; + } + memset(batch_norm_param, 0, sizeof(BatchNormParameter)); + batch_norm_param->op_parameter_.type_ = primitive->Type(); + batch_norm_param->epsilon_ = param->GetEpsilon(); + batch_norm_param->fused_ = false; + return reinterpret_cast(batch_norm_param); +} + +Registry BatchNormParameterRegistry(schema::PrimitiveType_BatchNorm, PopulateBatchNorm); + } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/batch_to_space.cc b/mindspore/lite/src/ops/batch_to_space.cc index c723e03e1c..e55edbc9f1 100644 --- a/mindspore/lite/src/ops/batch_to_space.cc +++ b/mindspore/lite/src/ops/batch_to_space.cc @@ -20,6 +20,9 @@ #include "src/common/log_adapter.h" #include "src/tensor.h" +#include "src/ops/ops_register.h" +#include "nnacl/batch_to_space.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -66,7 +69,49 @@ std::vector BatchToSpace::GetCrops() const { return std::vector(fb_vector->begin(), fb_vector->end()); } +PrimitiveC *BatchToSpaceCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry BatchToSpaceRegistry(schema::PrimitiveType_BatchToSpace, BatchToSpaceCreator); #endif + +OpParameter *PopulateBatchToSpaceParameter(const mindspore::lite::PrimitiveC *primitive) { + BatchToSpaceParameter *batch_space_param = + reinterpret_cast(malloc(sizeof(BatchToSpaceParameter))); + if (batch_space_param == nullptr) { + MS_LOG(ERROR) << "malloc BatchToSpaceParameter failed."; + return nullptr; + } + memset(batch_space_param, 0, sizeof(BatchToSpaceParameter)); + batch_space_param->op_parameter_.type_ = primitive->Type(); + auto param = reinterpret_cast(const_cast(primitive)); + auto block_shape = param->GetBlockShape(); + if (block_shape.size() != BATCH_TO_SPACE_BLOCK_SHAPE_SIZE) { + MS_LOG(ERROR) << "batch_to_space blockShape size should be " << BATCH_TO_SPACE_BLOCK_SHAPE_SIZE; + free(batch_space_param); + return nullptr; + } + + auto crops = param->GetCrops(); + if (crops.size() != BATCH_TO_SPACE_CROPS_SIZE) { + MS_LOG(ERROR) << "batch_to_space crops size should be " << BATCH_TO_SPACE_CROPS_SIZE; + free(batch_space_param); + return nullptr; + } + + for (int i = 0; i < BATCH_TO_SPACE_BLOCK_SHAPE_SIZE; ++i) { + batch_space_param->block_shape_[i] = block_shape[i]; + } + + for (int i = 0; i < BATCH_TO_SPACE_CROPS_SIZE; ++i) { + batch_space_param->crops_[i] = crops[i]; + } + return reinterpret_cast(batch_space_param); +} + +Registry BatchToSpaceParameterRegistry(schema::PrimitiveType_BatchToSpace, PopulateBatchToSpaceParameter); +Registry BatchToSpaceNDParameterRegistry(schema::PrimitiveType_BatchToSpaceND, PopulateBatchToSpaceParameter); + namespace { constexpr int kBatchToSpaceOutputNum = 1; constexpr int kBatchToSpaceInputNum = 1; diff --git a/mindspore/lite/src/ops/bias_add.cc b/mindspore/lite/src/ops/bias_add.cc index bb7059e1ab..ec0e54040e 100644 --- a/mindspore/lite/src/ops/bias_add.cc +++ b/mindspore/lite/src/ops/bias_add.cc @@ -16,6 +16,8 @@ #include "src/ops/bias_add.h" #include +#include "nnacl/arithmetic_common.h" +#include "src/ops/ops_register.h" namespace mindspore { namespace lite { @@ -78,6 +80,22 @@ std::vector BiasAdd::GetAxis() const { return std::vector(fb_vector->begin(), fb_vector->end()); } +PrimitiveC *BiasAddCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry BiasAddRegistry(schema::PrimitiveType_BiasAdd, BiasAddCreator); #endif + +OpParameter *PopulateBiasAddParameter(const mindspore::lite::PrimitiveC *primitive) { + ArithmeticParameter *arithmetic_param = reinterpret_cast(malloc(sizeof(ArithmeticParameter))); + if (arithmetic_param == nullptr) { + MS_LOG(ERROR) << "malloc ArithmeticParameter failed."; + return nullptr; + } + memset(arithmetic_param, 0, sizeof(ArithmeticParameter)); + arithmetic_param->op_parameter_.type_ = primitive->Type(); + + return reinterpret_cast(arithmetic_param); +} +Registry BiasAddParameterRegistry(schema::PrimitiveType_BiasAdd, PopulateBiasAddParameter); + } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/bias_grad.cc b/mindspore/lite/src/ops/bias_grad.cc index 6da4712224..3da20349b6 100644 --- a/mindspore/lite/src/ops/bias_grad.cc +++ b/mindspore/lite/src/ops/bias_grad.cc @@ -16,6 +16,8 @@ #include "src/ops/bias_grad.h" +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -74,6 +76,11 @@ std::vector BiasGrad::GetAxis() const { auto fb_vector = this->primitive_->value_as_BiasGrad()->axis(); return std::vector(fb_vector->begin(), fb_vector->end()); } + +PrimitiveC *BiasGradCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry BiasGradRegistry(schema::PrimitiveType_BiasGrad, BiasGradCreator); #endif int BiasGrad::InferShape(std::vector inputs, std::vector outputs) { diff --git a/mindspore/lite/src/ops/bn_grad.cc b/mindspore/lite/src/ops/bn_grad.cc index d623117327..ef450a809e 100644 --- a/mindspore/lite/src/ops/bn_grad.cc +++ b/mindspore/lite/src/ops/bn_grad.cc @@ -16,6 +16,8 @@ #include "src/ops/bn_grad.h" +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE diff --git a/mindspore/lite/src/ops/broadcast_to.cc b/mindspore/lite/src/ops/broadcast_to.cc index 5e483a8abe..a4741e2954 100644 --- a/mindspore/lite/src/ops/broadcast_to.cc +++ b/mindspore/lite/src/ops/broadcast_to.cc @@ -16,6 +16,9 @@ #include "src/ops/broadcast_to.h" +#include "src/ops/ops_register.h" +#include "nnacl/fp32/broadcast_to.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -50,7 +53,32 @@ std::vector BroadcastTo::GetDstShape() const { return std::vector(fb_vector->begin(), fb_vector->end()); } +PrimitiveC *BroadcastToCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry BroadcastToRegistry(schema::PrimitiveType_BroadcastTo, BroadcastToCreator); #endif + +OpParameter *PopulateBroadcastToParameter(const mindspore::lite::PrimitiveC *primitive) { + BroadcastToParameter *broadcast_param = + reinterpret_cast(malloc(sizeof(BroadcastToParameter))); + if (broadcast_param == nullptr) { + MS_LOG(ERROR) << "malloc BroadcastToParameter failed."; + return nullptr; + } + memset(broadcast_param, 0, sizeof(BroadcastToParameter)); + auto param = reinterpret_cast(const_cast(primitive)); + broadcast_param->op_parameter_.type_ = primitive->Type(); + auto dst_shape = param->GetDstShape(); + broadcast_param->shape_size_ = dst_shape.size(); + for (size_t i = 0; i < broadcast_param->shape_size_; ++i) { + broadcast_param->shape_[i] = dst_shape[i]; + } + return reinterpret_cast(broadcast_param); +} + +Registry BroadcastToParameterRegistry(schema::PrimitiveType_BroadcastTo, PopulateBroadcastToParameter); + namespace { constexpr int kBroadcastToInputNum = 1; constexpr int kBroadcastToOutputNum = 1; diff --git a/mindspore/lite/src/ops/cast.cc b/mindspore/lite/src/ops/cast.cc index 41ff55115b..348abd0c99 100644 --- a/mindspore/lite/src/ops/cast.cc +++ b/mindspore/lite/src/ops/cast.cc @@ -16,6 +16,9 @@ #include "src/ops/cast.h" +#include "src/ops/ops_register.h" +#include "nnacl/fp32/cast.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -75,8 +78,26 @@ int Cast::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::F int Cast::GetSrcT() const { return this->primitive_->value_as_Cast()->srcT(); } int Cast::GetDstT() const { return this->primitive_->value_as_Cast()->dstT(); } +PrimitiveC *CastCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry CastRegistry(schema::PrimitiveType_Cast, CastCreator); #endif +OpParameter *PopulateCastParameter(const mindspore::lite::PrimitiveC *primitive) { + CastParameter *cast_param = reinterpret_cast(malloc(sizeof(CastParameter))); + if (cast_param == nullptr) { + MS_LOG(ERROR) << "malloc CastParameter failed."; + return nullptr; + } + memset(cast_param, 0, sizeof(CastParameter)); + cast_param->op_parameter_.type_ = primitive->Type(); + auto param = reinterpret_cast(const_cast(primitive)); + cast_param->src_type_ = param->GetSrcT(); + cast_param->dst_type_ = param->GetDstT(); + return reinterpret_cast(cast_param); +} + +Registry CastParameterRegistry(schema::PrimitiveType_Cast, PopulateCastParameter); + int Cast::InferShape(std::vector inputs_, std::vector outputs_) { MS_ASSERT(this->primitive_ != nullptr); auto input = inputs_.front(); diff --git a/mindspore/lite/src/ops/ceil.cc b/mindspore/lite/src/ops/ceil.cc new file mode 100644 index 0000000000..848f92e2dc --- /dev/null +++ b/mindspore/lite/src/ops/ceil.cc @@ -0,0 +1,27 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/ceil.h" + +#include "src/ops/ops_register.h" + +namespace mindspore { +namespace lite { + +Registry CeilParameterRegistry(schema::PrimitiveType_Ceil, PopulateArithmeticSelf); + +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/ceil.h b/mindspore/lite/src/ops/ceil.h index 70093e00ca..3a7ebfa739 100644 --- a/mindspore/lite/src/ops/ceil.h +++ b/mindspore/lite/src/ops/ceil.h @@ -21,6 +21,7 @@ #include #include #include "src/ops/arithmetic_self.h" +#include "src/ops/ops_register.h" namespace mindspore { namespace lite { @@ -43,6 +44,7 @@ class Ceil : public ArithmeticSelf { } #endif }; + } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/clip.cc b/mindspore/lite/src/ops/clip.cc index 08e654337a..f928601e47 100644 --- a/mindspore/lite/src/ops/clip.cc +++ b/mindspore/lite/src/ops/clip.cc @@ -16,6 +16,9 @@ #include "src/ops/clip.h" +#include "src/ops/ops_register.h" +#include "nnacl/clip.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -42,6 +45,24 @@ int Clip::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::F float Clip::GetMax() const { return this->primitive_->value_as_Clip()->max(); } float Clip::GetMin() const { return this->primitive_->value_as_Clip()->min(); } +PrimitiveC *ClipCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry ClipRegistry(schema::PrimitiveType_Clip, ClipCreator); #endif +OpParameter *PopulateClipParameter(const mindspore::lite::PrimitiveC *primitive) { + ClipParameter *act_param = reinterpret_cast(malloc(sizeof(ClipParameter))); + if (act_param == nullptr) { + MS_LOG(ERROR) << "malloc ClipParameter failed."; + return nullptr; + } + memset(act_param, 0, sizeof(ClipParameter)); + act_param->op_parameter_.type_ = primitive->Type(); + auto activation = reinterpret_cast(const_cast(primitive)); + act_param->min_val_ = activation->GetMin(); + act_param->max_val_ = activation->GetMax(); + return reinterpret_cast(act_param); +} + +Registry ClipParameterRegistry(schema::PrimitiveType_Clip, PopulateClipParameter); + } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/concat.cc b/mindspore/lite/src/ops/concat.cc index 31a63fb002..27e29bc8c0 100644 --- a/mindspore/lite/src/ops/concat.cc +++ b/mindspore/lite/src/ops/concat.cc @@ -19,6 +19,8 @@ #include "include/errorcode.h" #include "src/common/log_adapter.h" #include "src/tensor.h" +#include "src/ops/ops_register.h" +#include "nnacl/concat_parameter.h" namespace mindspore { namespace lite { @@ -76,8 +78,26 @@ int Concat::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers: int Concat::GetAxis() const { return this->primitive_->value_as_Concat()->axis(); } int Concat::GetN() const { return this->primitive_->value_as_Concat()->n(); } +PrimitiveC *ConcatCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry ConcatRegistry(schema::PrimitiveType_Concat, ConcatCreator); + #endif +OpParameter *PopulateConcatParameter(const mindspore::lite::PrimitiveC *primitive) { + ConcatParameter *concat_param = reinterpret_cast(malloc(sizeof(ConcatParameter))); + if (concat_param == nullptr) { + MS_LOG(ERROR) << "malloc ConcatParameter failed."; + return nullptr; + } + memset(concat_param, 0, sizeof(ConcatParameter)); + concat_param->op_parameter_.type_ = primitive->Type(); + auto param = reinterpret_cast(const_cast(primitive)); + concat_param->axis_ = param->GetAxis(); + return reinterpret_cast(concat_param); +} + +Registry ConcatParameterRegistry(schema::PrimitiveType_Concat, PopulateConcatParameter); + namespace { constexpr int kConcatOutputNum = 1; } diff --git a/mindspore/lite/src/ops/constant_of_shape.cc b/mindspore/lite/src/ops/constant_of_shape.cc index b6e6018394..f1e7e60bdb 100644 --- a/mindspore/lite/src/ops/constant_of_shape.cc +++ b/mindspore/lite/src/ops/constant_of_shape.cc @@ -18,6 +18,8 @@ #include "include/errorcode.h" #include "src/common/log_adapter.h" #include "src/tensor.h" +#include "src/ops/ops_register.h" +#include "nnacl/fp32/constant_of_shape.h" namespace mindspore::lite { namespace { @@ -45,8 +47,29 @@ int ConstantOfShape::UnPackToFlatBuilder(const schema::Primitive *primitive, fla } float ConstantOfShape::GetValue() const { return this->primitive_->value_as_ConstantOfShape()->value(); } +PrimitiveC *ConstantOfShapeCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry ConstantOfShapeRegistry(schema::PrimitiveType_ConstantOfShape, ConstantOfShapeCreator); + #endif +OpParameter *PopulateConstantOfShapeParameter(const mindspore::lite::PrimitiveC *primitive) { + auto attr = + reinterpret_cast(const_cast(primitive)); + ConstantOfShapeParameter *param = + reinterpret_cast(malloc(sizeof(ConstantOfShapeParameter))); + if (param == nullptr) { + MS_LOG(ERROR) << "malloc ConstantOfShapeParameter failed."; + return nullptr; + } + memset(param, 0, sizeof(ConstantOfShapeParameter)); + param->op_parameter_.type_ = primitive->Type(); + param->value_ = attr->GetValue(); + return reinterpret_cast(param); +} +Registry ConstantOfShapeParameterRegistry(schema::PrimitiveType_ConstantOfShape, PopulateConstantOfShapeParameter); + int ConstantOfShape::InferShape(std::vector inputs_, std::vector outputs_) { if (inputs_.size() != kShapeInputNum) { MS_LOG(ERROR) << "inputs to ConstantOfShape operator should be 1, but " << inputs_.size() << " is given."; diff --git a/mindspore/lite/src/ops/conv2d.cc b/mindspore/lite/src/ops/conv2d.cc index f5a74cdecf..a4a0c6ca17 100644 --- a/mindspore/lite/src/ops/conv2d.cc +++ b/mindspore/lite/src/ops/conv2d.cc @@ -24,9 +24,10 @@ #include "src/common/log_adapter.h" #ifdef PRIMITIVE_WRITEABLE #include - #include "tools/converter/quantizer/quantize_util.h" #endif +#include "nnacl/conv_parameter.h" +#include "src/ops/ops_register.h" namespace mindspore { namespace lite { @@ -320,7 +321,51 @@ int Conv2D::GetDilateH() const { return this->primitive_->value_as_Conv2D()->dil bool Conv2D::GetHasBias() const { return this->primitive_->value_as_Conv2D()->hasBias(); } int Conv2D::GetActivationType() const { return this->primitive_->value_as_Conv2D()->activationType(); } +PrimitiveC *Conv2DCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry Conv2DRegistry(schema::PrimitiveType_Conv2D, Conv2DCreator); #endif +OpParameter *PopulateConvParameter(const mindspore::lite::PrimitiveC *primitive) { + ConvParameter *conv_param = reinterpret_cast(malloc(sizeof(ConvParameter))); + if (conv_param == nullptr) { + MS_LOG(ERROR) << "malloc ConvParameter failed."; + return nullptr; + } + memset(conv_param, 0, sizeof(ConvParameter)); + conv_param->op_parameter_.type_ = primitive->Type(); + auto conv_primitive = + reinterpret_cast(const_cast(primitive)); + conv_param->kernel_h_ = conv_primitive->GetKernelH(); + conv_param->kernel_w_ = conv_primitive->GetKernelW(); + conv_param->group_ = conv_primitive->GetGroup(); + conv_param->stride_h_ = conv_primitive->GetStrideH(); + conv_param->stride_w_ = conv_primitive->GetStrideW(); + + auto conv2d_lite_primitive = (lite::Conv2D *)primitive; + conv_param->pad_u_ = conv2d_lite_primitive->PadUp(); + conv_param->pad_d_ = conv2d_lite_primitive->PadDown(); + conv_param->pad_l_ = conv2d_lite_primitive->PadLeft(); + conv_param->pad_r_ = conv2d_lite_primitive->PadRight(); + conv_param->dilation_h_ = conv_primitive->GetDilateH(); + conv_param->dilation_w_ = conv_primitive->GetDilateW(); + conv_param->input_channel_ = conv_primitive->GetChannelIn(); + conv_param->output_channel_ = conv_primitive->GetChannelOut(); + conv_param->group_ = conv_primitive->GetGroup(); + auto act_type = conv_primitive->GetActivationType(); + switch (act_type) { + case schema::ActivationType_RELU: + conv_param->act_type_ = ActType_Relu; + break; + case schema::ActivationType_RELU6: + conv_param->act_type_ = ActType_Relu6; + break; + default: + conv_param->act_type_ = ActType_No; + break; + } + return reinterpret_cast(conv_param); +} +Registry Conv2DParameterRegistry(schema::PrimitiveType_Conv2D, PopulateConvParameter); + void Conv2D::ConvInferShape(int input_h, int input_w, int *output_h, int *output_w) { MS_ASSERT(this->primitive_ != nullptr); int kernel_w = GetKernelW(); diff --git a/mindspore/lite/src/ops/conv2d_grad_filter.cc b/mindspore/lite/src/ops/conv2d_grad_filter.cc index 42b1283ead..5665ec971f 100644 --- a/mindspore/lite/src/ops/conv2d_grad_filter.cc +++ b/mindspore/lite/src/ops/conv2d_grad_filter.cc @@ -15,6 +15,7 @@ */ #include "src/ops/conv2d_grad_filter.h" +#include "src/ops/ops_register.h" namespace mindspore { namespace lite { @@ -176,6 +177,10 @@ int Conv2DGradFilter::GetActivationType() const { return this->primitive_->value_as_Conv2DGradFilter()->activationType(); } +PrimitiveC *Conv2DGradFilterCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry conv2DGradFilterRegistry(schema::PrimitiveType_Conv2DGradFilter, Conv2DGradFilterCreator); #endif int Conv2DGradFilter::InferShape(std::vector inputs, std::vector outputs) { diff --git a/mindspore/lite/src/ops/conv2d_grad_input.cc b/mindspore/lite/src/ops/conv2d_grad_input.cc index a49f62a3e0..7b2fefd811 100644 --- a/mindspore/lite/src/ops/conv2d_grad_input.cc +++ b/mindspore/lite/src/ops/conv2d_grad_input.cc @@ -16,6 +16,8 @@ #include "src/ops/conv2d_grad_input.h" #include "src/ops/group_conv2d_grad_input.h" +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -178,6 +180,10 @@ int Conv2DGradInput::GetActivationType() const { return this->primitive_->value_as_Conv2DGradInput()->activationType(); } +PrimitiveC *Conv2DGradInputCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry Conv2DGradInputRegistry(schema::PrimitiveType_Conv2DGradInput, Conv2DGradInputCreator); #endif int Conv2DGradInput::InferShape(std::vector inputs, std::vector outputs) { diff --git a/mindspore/lite/src/ops/cos.cc b/mindspore/lite/src/ops/cos.cc index 373b121d97..c01aec3588 100644 --- a/mindspore/lite/src/ops/cos.cc +++ b/mindspore/lite/src/ops/cos.cc @@ -16,6 +16,8 @@ #include "src/ops/cos.h" +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifndef PRIMITIVE_WRITEABLE @@ -27,6 +29,10 @@ int Cos::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::Fl fbb->Finish(prim_offset); return RET_OK; } +PrimitiveC *CosCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry CosRegistry(schema::PrimitiveType_Cos, CosCreator); #endif +Registry CosParameterRegistry(schema::PrimitiveType_Cos, PopulateArithmeticSelf); + } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/crop.cc b/mindspore/lite/src/ops/crop.cc index 4cca858775..1d88d4361e 100644 --- a/mindspore/lite/src/ops/crop.cc +++ b/mindspore/lite/src/ops/crop.cc @@ -16,6 +16,9 @@ #include "src/ops/crop.h" +#include "src/ops/ops_register.h" +#include "nnacl/crop_parameter.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -51,7 +54,33 @@ std::vector Crop::GetOffsets() const { return std::vector(fb_vector->begin(), fb_vector->end()); } +PrimitiveC *CropCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry CropRegistry(schema::PrimitiveType_Crop, CropCreator); #endif + +OpParameter *PopulateCropParameter(const mindspore::lite::PrimitiveC *primitive) { + auto param = reinterpret_cast(const_cast(primitive)); + auto param_offset = param->GetOffsets(); + if (param_offset.size() > CROP_OFFSET_MAX_SIZE) { + MS_LOG(ERROR) << "crop_param offset size(" << param_offset.size() << ") should <= " << CROP_OFFSET_MAX_SIZE; + return nullptr; + } + CropParameter *crop_param = reinterpret_cast(malloc(sizeof(CropParameter))); + if (crop_param == nullptr) { + MS_LOG(ERROR) << "malloc CropParameter failed."; + return nullptr; + } + memset(crop_param, 0, sizeof(CropParameter)); + crop_param->op_parameter_.type_ = primitive->Type(); + crop_param->axis_ = param->GetAxis(); + crop_param->offset_size_ = param_offset.size(); + for (size_t i = 0; i < param_offset.size(); ++i) { + crop_param->offset_[i] = param_offset[i]; + } + return reinterpret_cast(crop_param); +} +Registry CropParameterRegistry(schema::PrimitiveType_Crop, PopulateCropParameter); + namespace { constexpr int kCropOutputNum = 1; constexpr int kCropInputNum = 2; diff --git a/mindspore/lite/src/ops/custom_extract_features.cc b/mindspore/lite/src/ops/custom_extract_features.cc index faa6c51a48..afaaf010e0 100644 --- a/mindspore/lite/src/ops/custom_extract_features.cc +++ b/mindspore/lite/src/ops/custom_extract_features.cc @@ -17,6 +17,8 @@ #include "src/common/string_util.h" +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -31,8 +33,25 @@ int CustomExtractFeatures::UnPackToFlatBuilder(const schema::Primitive *primitiv fbb->Finish(prim_offset); return RET_OK; } +PrimitiveC *CustomExtractFeaturesCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry CustomExtractFeaturesRegistry(schema::PrimitiveType_CustomExtractFeatures, CustomExtractFeaturesCreator); #endif +OpParameter *PopulateExtractFeaturesParameter(const mindspore::lite::PrimitiveC *primitive) { + OpParameter *param = reinterpret_cast(malloc(sizeof(OpParameter))); + if (param == nullptr) { + MS_LOG(ERROR) << "new OpParameter failed."; + return nullptr; + } + memset(param, 0, sizeof(OpParameter)); + param->type_ = primitive->Type(); + return param; +} +Registry CustomExtractFeaturesParameterRegistry(schema::PrimitiveType_CustomExtractFeatures, + PopulateExtractFeaturesParameter); + int CustomExtractFeatures::InferShape(std::vector inputs_, std::vector outputs_) { auto input = inputs_.at(0); auto output0 = outputs_.at(0); diff --git a/mindspore/lite/src/ops/custom_normalize.cc b/mindspore/lite/src/ops/custom_normalize.cc index c720a9afaf..565b90fc14 100644 --- a/mindspore/lite/src/ops/custom_normalize.cc +++ b/mindspore/lite/src/ops/custom_normalize.cc @@ -17,6 +17,8 @@ #include "src/common/string_util.h" +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -30,7 +32,25 @@ int CustomNormalize::UnPackToFlatBuilder(const schema::Primitive *primitive, fla fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *CustomNormalizeCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry CustomNormalizeRegistry(schema::PrimitiveType_CustomNormalize, CustomNormalizeCreator); #endif + +OpParameter *PopulateCustomNormalizeParameter(const mindspore::lite::PrimitiveC *primitive) { + OpParameter *param = reinterpret_cast(malloc(sizeof(OpParameter))); + if (param == nullptr) { + MS_LOG(ERROR) << "new OpParameter failed."; + return nullptr; + } + memset(param, 0, sizeof(OpParameter)); + param->type_ = primitive->Type(); + return param; +} +Registry CustomNormalizeParameterRegistry(schema::PrimitiveType_CustomNormalize, PopulateCustomNormalizeParameter); + int CustomNormalize::InferShape(std::vector inputs_, std::vector outputs_) { auto input = inputs_.at(0); auto output = outputs_.at(0); diff --git a/mindspore/lite/src/ops/custom_predict.cc b/mindspore/lite/src/ops/custom_predict.cc index 05cbeac13c..72625aef38 100644 --- a/mindspore/lite/src/ops/custom_predict.cc +++ b/mindspore/lite/src/ops/custom_predict.cc @@ -15,6 +15,9 @@ */ #include "src/ops/custom_predict.h" +#include "src/ops/ops_register.h" +#include "nnacl/predict_parameter.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -45,7 +48,27 @@ int CustomPredict::UnPackToFlatBuilder(const schema::Primitive *primitive, flatb fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *CustomPredictCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry CustomPredictRegistry(schema::PrimitiveType_CustomPredict, CustomPredictCreator); #endif +OpParameter *PopulateCustomPredictParameter(const mindspore::lite::PrimitiveC *primitive) { + PredictParameter *param = reinterpret_cast(malloc(sizeof(PredictParameter))); + if (param == nullptr) { + MS_LOG(ERROR) << "malloc param failed."; + return nullptr; + } + memset(param, 0, sizeof(PredictParameter)); + param->op_parameter_.type_ = primitive->Type(); + auto prim = reinterpret_cast(const_cast(primitive)); + param->output_num = prim->GetOutputNum(); + param->weight_threshold = prim->GetWeightThreshold(); + return reinterpret_cast(param); +} +Registry CustomPredictParameterRegistry(schema::PrimitiveType_CustomPredict, PopulateCustomPredictParameter); + int CustomPredict::InferShape(std::vector inputs_, std::vector outputs_) { auto input = inputs_.at(0); auto output0 = outputs_.at(0); diff --git a/mindspore/lite/src/ops/deconv2d.cc b/mindspore/lite/src/ops/deconv2d.cc index f3b6bfc289..bc9b39e706 100644 --- a/mindspore/lite/src/ops/deconv2d.cc +++ b/mindspore/lite/src/ops/deconv2d.cc @@ -25,6 +25,9 @@ #include "tools/converter/quantizer/quantize_util.h" #endif +#include "src/ops/ops_register.h" +#include "nnacl/conv_parameter.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -295,7 +298,51 @@ int DeConv2D::GetDilateH() const { return this->primitive_->value_as_DeConv2D()- bool DeConv2D::GetHasBias() const { return this->primitive_->value_as_DeConv2D()->hasBias(); } int DeConv2D::GetActivationType() const { return this->primitive_->value_as_DeConv2D()->activationType(); } +PrimitiveC *DeConv2DCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry DeConv2DRegistry(schema::PrimitiveType_DeConv2D, DeConv2DCreator); #endif + +OpParameter *PopulateDeconvParameter(const mindspore::lite::PrimitiveC *primitive) { + ConvParameter *conv_param = reinterpret_cast(malloc(sizeof(ConvParameter))); + if (conv_param == nullptr) { + MS_LOG(ERROR) << "malloc ConvParameter failed."; + return nullptr; + } + memset(conv_param, 0, sizeof(ConvParameter)); + conv_param->op_parameter_.type_ = primitive->Type(); + auto conv_primitive = + reinterpret_cast(const_cast(primitive)); + conv_param->kernel_h_ = conv_primitive->GetKernelH(); + conv_param->kernel_w_ = conv_primitive->GetKernelW(); + conv_param->stride_h_ = conv_primitive->GetStrideH(); + conv_param->stride_w_ = conv_primitive->GetStrideW(); + + auto deconv_lite_primitive = (lite::DeConv2D *)primitive; + conv_param->pad_u_ = deconv_lite_primitive->PadUp(); + conv_param->pad_d_ = deconv_lite_primitive->PadDown(); + conv_param->pad_l_ = deconv_lite_primitive->PadLeft(); + conv_param->pad_r_ = deconv_lite_primitive->PadRight(); + conv_param->dilation_h_ = conv_primitive->GetDilateH(); + conv_param->dilation_w_ = conv_primitive->GetDilateW(); + auto act_type = conv_primitive->GetActivationType(); + switch (act_type) { + case schema::ActivationType_RELU: + conv_param->act_type_ = ActType_Relu; + break; + case schema::ActivationType_RELU6: + conv_param->act_type_ = ActType_Relu6; + break; + default: + conv_param->act_type_ = ActType_No; + break; + } + return reinterpret_cast(conv_param); +} + +Registry DeConv2DParameterRegistry(schema::PrimitiveType_DeConv2D, PopulateDeconvParameter); + int DeConv2D::InferShape(std::vector inputs_, std::vector outputs_) { MS_ASSERT(this->primitive_ != nullptr); auto input = inputs_.front(); diff --git a/mindspore/lite/src/ops/dedepthwise_conv2d.cc b/mindspore/lite/src/ops/dedepthwise_conv2d.cc index 634ff21ba8..6af0cf523a 100644 --- a/mindspore/lite/src/ops/dedepthwise_conv2d.cc +++ b/mindspore/lite/src/ops/dedepthwise_conv2d.cc @@ -16,6 +16,9 @@ #include "src/ops/dedepthwise_conv2d.h" +#include "src/ops/ops_register.h" +#include "nnacl/conv_parameter.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -109,7 +112,51 @@ int DeDepthwiseConv2D::GetActivationType() const { return this->primitive_->value_as_DeDepthwiseConv2D()->activationType(); } +PrimitiveC *DeDepthwiseConv2DCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry DeDepthwiseConv2DRegistry(schema::PrimitiveType_DeDepthwiseConv2D, DeDepthwiseConv2DCreator); #endif + +OpParameter *PopulateDeconvDwParameter(const mindspore::lite::PrimitiveC *primitive) { + ConvParameter *conv_param = reinterpret_cast(malloc(sizeof(ConvParameter))); + if (conv_param == nullptr) { + MS_LOG(ERROR) << "malloc ConvParameter failed."; + return nullptr; + } + memset(conv_param, 0, sizeof(ConvParameter)); + conv_param->op_parameter_.type_ = primitive->Type(); + auto conv_primitive = + reinterpret_cast(const_cast(primitive)); + conv_param->kernel_h_ = conv_primitive->GetKernelH(); + conv_param->kernel_w_ = conv_primitive->GetKernelW(); + conv_param->stride_h_ = conv_primitive->GetStrideH(); + conv_param->stride_w_ = conv_primitive->GetStrideW(); + + auto deconvdw_lite_primitive = (mindspore::lite::DeDepthwiseConv2D *)primitive; + conv_param->pad_u_ = deconvdw_lite_primitive->PadUp(); + conv_param->pad_d_ = deconvdw_lite_primitive->PadDown(); + conv_param->pad_l_ = deconvdw_lite_primitive->PadLeft(); + conv_param->pad_r_ = deconvdw_lite_primitive->PadRight(); + conv_param->dilation_h_ = conv_primitive->GetDilateH(); + conv_param->dilation_w_ = conv_primitive->GetDilateW(); + auto act_type = conv_primitive->GetActivationType(); + switch (act_type) { + case schema::ActivationType_RELU: + conv_param->act_type_ = ActType_Relu; + break; + case schema::ActivationType_RELU6: + conv_param->act_type_ = ActType_Relu6; + break; + default: + conv_param->act_type_ = ActType_No; + break; + } + return reinterpret_cast(conv_param); +} + +Registry DeDepthwiseConv2DParameterRegistry(schema::PrimitiveType_DeDepthwiseConv2D, PopulateDeconvDwParameter); + int DeDepthwiseConv2D::InferShape(std::vector inputs_, std::vector outputs_) { if (inputs_.size() != kDoubleNum && inputs_.size() != kMultiNum) { MS_LOG(ERROR) << "inputs number is invalid"; diff --git a/mindspore/lite/src/ops/depend.cc b/mindspore/lite/src/ops/depend.cc index b199d7474e..20313a9b85 100644 --- a/mindspore/lite/src/ops/depend.cc +++ b/mindspore/lite/src/ops/depend.cc @@ -17,6 +17,8 @@ #include #include +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -52,6 +54,9 @@ int Depend::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers: fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *DependCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry DependRegistry(schema::PrimitiveType_Depend, DependCreator); #endif } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/depth_to_space.cc b/mindspore/lite/src/ops/depth_to_space.cc index e062c112ad..e2ab49cc6b 100644 --- a/mindspore/lite/src/ops/depth_to_space.cc +++ b/mindspore/lite/src/ops/depth_to_space.cc @@ -16,6 +16,9 @@ #include "src/ops/depth_to_space.h" #include "src/common/common.h" +#include "src/ops/ops_register.h" +#include "nnacl/depth_to_space_parameter.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -42,7 +45,29 @@ int DepthToSpace::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbu int DepthToSpace::GetBlockSize() const { return this->primitive_->value_as_DepthToSpace()->blockSize(); } int DepthToSpace::GetFormat() const { return this->primitive_->value_as_DepthToSpace()->format(); } +PrimitiveC *DepthToSpaceCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry DepthToSpaceRegistry(schema::PrimitiveType_DepthToSpace, DepthToSpaceCreator); + #endif + +OpParameter *PopulateDepthToSpaceParameter(const mindspore::lite::PrimitiveC *primitive) { + DepthToSpaceParameter *depth_space_param = + reinterpret_cast(malloc(sizeof(DepthToSpaceParameter))); + if (depth_space_param == nullptr) { + MS_LOG(ERROR) << "malloc DepthToSpaceParameter failed."; + return nullptr; + } + memset(depth_space_param, 0, sizeof(DepthToSpaceParameter)); + auto param = reinterpret_cast(const_cast(primitive)); + depth_space_param->op_parameter_.type_ = primitive->Type(); + depth_space_param->block_size_ = param->GetBlockSize(); + return reinterpret_cast(depth_space_param); +} + +Registry DepthToSpaceParameterRegistry(schema::PrimitiveType_DepthToSpace, PopulateDepthToSpaceParameter); + namespace { constexpr int kDepthToSpaceOutputNum = 1; constexpr int kDepthToSpaceInputNum = 1; diff --git a/mindspore/lite/src/ops/depthwise_conv2d.cc b/mindspore/lite/src/ops/depthwise_conv2d.cc index dde0bd67bb..cd6b8eca9a 100644 --- a/mindspore/lite/src/ops/depthwise_conv2d.cc +++ b/mindspore/lite/src/ops/depthwise_conv2d.cc @@ -21,6 +21,9 @@ #ifdef PRIMITIVE_WRITEABLE #include "tools/converter/quantizer/quantize_util.h" #endif +#include "src/ops/ops_register.h" +#include "nnacl/conv_parameter.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -191,7 +194,54 @@ int DepthwiseConv2D::GetActivationType() const { return this->primitive_->value_as_DepthwiseConv2D()->activationType(); } +PrimitiveC *DepthWiseConv2DCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry DepthWiseConv2DRegistry(schema::PrimitiveType_DepthwiseConv2D, DepthWiseConv2DCreator); + #endif + +OpParameter *PopulateConvDwParameter(const mindspore::lite::PrimitiveC *primitive) { + ConvParameter *conv_param = reinterpret_cast(malloc(sizeof(ConvParameter))); + if (conv_param == nullptr) { + MS_LOG(ERROR) << "malloc ConvParameter failed."; + return nullptr; + } + memset(conv_param, 0, sizeof(ConvParameter)); + conv_param->op_parameter_.type_ = primitive->Type(); + + auto conv_primitive = + reinterpret_cast(const_cast(primitive)); + conv_param->kernel_h_ = conv_primitive->GetKernelH(); + conv_param->kernel_w_ = conv_primitive->GetKernelW(); + conv_param->stride_h_ = conv_primitive->GetStrideH(); + conv_param->stride_w_ = conv_primitive->GetStrideW(); + + auto convdw_lite_primitive = (lite::DepthwiseConv2D *)primitive; + conv_param->pad_u_ = convdw_lite_primitive->PadUp(); + conv_param->pad_d_ = convdw_lite_primitive->PadDown(); + conv_param->pad_l_ = convdw_lite_primitive->PadLeft(); + conv_param->pad_r_ = convdw_lite_primitive->PadRight(); + conv_param->input_channel_ = convdw_lite_primitive->GetInputChannel(); + conv_param->dilation_h_ = conv_primitive->GetDilateH(); + conv_param->dilation_w_ = conv_primitive->GetDilateW(); + auto act_type = conv_primitive->GetActivationType(); + switch (act_type) { + case schema::ActivationType_RELU: + conv_param->act_type_ = ActType_Relu; + break; + case schema::ActivationType_RELU6: + conv_param->act_type_ = ActType_Relu6; + break; + default: + conv_param->act_type_ = ActType_No; + break; + } + return reinterpret_cast(conv_param); +} + +Registry DepthwiseConv2DParameterRegistry(schema::PrimitiveType_DepthwiseConv2D, PopulateConvDwParameter); + int DepthwiseConv2D::InferShape(std::vector inputs_, std::vector outputs_) { if (inputs_.size() != kDoubleNum && inputs_.size() != kMultiNum) { MS_LOG(ERROR) << "inputs number is invalid"; diff --git a/mindspore/lite/src/ops/detection_post_process.cc b/mindspore/lite/src/ops/detection_post_process.cc index cf31065a5b..eed4fb9b66 100644 --- a/mindspore/lite/src/ops/detection_post_process.cc +++ b/mindspore/lite/src/ops/detection_post_process.cc @@ -16,6 +16,9 @@ #include "src/ops/detection_post_process.h" +#include "src/ops/ops_register.h" +#include "nnacl/detection_post_process_parameter.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -139,7 +142,38 @@ bool DetectionPostProcess::GetOutQuantized() const { return this->primitive_->value_as_DetectionPostProcess()->OutQuantized(); } +PrimitiveC *DetectionPostProcessCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry DetectionPostProcessRegistry(schema::PrimitiveType_DetectionPostProcess, DetectionPostProcessCreator); #endif +OpParameter *PopulateDetectionPostProcessParameter(const mindspore::lite::PrimitiveC *primitive) { + DetectionPostProcessParameter *detection_post_process_parameter = + reinterpret_cast(malloc(sizeof(DetectionPostProcessParameter))); + if (detection_post_process_parameter == nullptr) { + MS_LOG(ERROR) << "malloc EluParameter failed."; + return nullptr; + } + memset(detection_post_process_parameter, 0, sizeof(DetectionPostProcessParameter)); + detection_post_process_parameter->op_parameter_.type_ = primitive->Type(); + auto param = + reinterpret_cast(const_cast(primitive)); + detection_post_process_parameter->h_scale_ = param->GetHScale(); + detection_post_process_parameter->w_scale_ = param->GetWScale(); + detection_post_process_parameter->x_scale_ = param->GetXScale(); + detection_post_process_parameter->y_scale_ = param->GetYScale(); + detection_post_process_parameter->nms_iou_threshold_ = param->GetNmsIouThreshold(); + detection_post_process_parameter->nms_score_threshold_ = param->GetNmsScoreThreshold(); + detection_post_process_parameter->max_detections_ = param->GetMaxDetections(); + detection_post_process_parameter->detections_per_class_ = param->GetDetectionsPerClass(); + detection_post_process_parameter->max_classes_per_detection_ = param->GetMaxClassesPerDetection(); + detection_post_process_parameter->num_classes_ = param->GetNumClasses(); + detection_post_process_parameter->use_regular_nms_ = param->GetUseRegularNms(); + return reinterpret_cast(detection_post_process_parameter); +} +Registry DetectionPostProcessParameterRegistry(schema::PrimitiveType_DetectionPostProcess, + PopulateDetectionPostProcessParameter); + namespace { constexpr int kDetectionPostProcessOutputNum = 4; constexpr int kDetectionPostProcessInputNum = 3; diff --git a/mindspore/lite/src/ops/div.cc b/mindspore/lite/src/ops/div.cc index 12300fe9ec..a9ce630bb2 100644 --- a/mindspore/lite/src/ops/div.cc +++ b/mindspore/lite/src/ops/div.cc @@ -16,6 +16,8 @@ #include "src/ops/div.h" +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -41,6 +43,30 @@ int Div::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::Fl } int Div::GetActivationType() const { return this->primitive_->value_as_Div()->activationType(); } +PrimitiveC *DivCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC
(primitive); } +Registry DivRegistry(schema::PrimitiveType_Div, DivCreator); #endif +OpParameter *PopulateDivParameter(const mindspore::lite::PrimitiveC *primitive) { + ArithmeticParameter *arithmetic_param = reinterpret_cast(malloc(sizeof(ArithmeticParameter))); + if (arithmetic_param == nullptr) { + MS_LOG(ERROR) << "malloc ArithmeticParameter failed."; + return nullptr; + } + memset(arithmetic_param, 0, sizeof(ArithmeticParameter)); + arithmetic_param->op_parameter_.type_ = primitive->Type(); + arithmetic_param->broadcasting_ = ((lite::Arithmetic *)primitive)->Broadcasting(); + arithmetic_param->ndim_ = ((lite::Arithmetic *)primitive)->NDims(); + arithmetic_param->activation_type_ = + reinterpret_cast(const_cast(primitive))->GetActivationType(); + auto tmp_shape = ((lite::Arithmetic *)primitive)->InShape0(); + memcpy(arithmetic_param->in_shape0_, static_cast(tmp_shape.data()), tmp_shape.size() * sizeof(int)); + tmp_shape = ((lite::Arithmetic *)primitive)->InShape1(); + memcpy(arithmetic_param->in_shape1_, static_cast(tmp_shape.data()), tmp_shape.size() * sizeof(int)); + tmp_shape = ((lite::Arithmetic *)primitive)->OutputShape(); + memcpy(arithmetic_param->out_shape_, static_cast(tmp_shape.data()), tmp_shape.size() * sizeof(int)); + return reinterpret_cast(arithmetic_param); +} +Registry DivParameterRegistry(schema::PrimitiveType_Div, PopulateDivParameter); + } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/dropout.cc b/mindspore/lite/src/ops/dropout.cc index 7cce7ad1cc..94451e4bb1 100644 --- a/mindspore/lite/src/ops/dropout.cc +++ b/mindspore/lite/src/ops/dropout.cc @@ -16,6 +16,8 @@ #include "src/ops/dropout.h" +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -39,6 +41,8 @@ int Dropout::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers } float Dropout::GetRatio() const { return this->primitive_->value_as_Dropout()->ratio(); } +PrimitiveC *DropoutCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry DropoutRegistry(schema::PrimitiveType_Dropout, DropoutCreator); #endif } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/eltwise.cc b/mindspore/lite/src/ops/eltwise.cc index 3760cd9f85..2598514af7 100644 --- a/mindspore/lite/src/ops/eltwise.cc +++ b/mindspore/lite/src/ops/eltwise.cc @@ -16,6 +16,9 @@ #include "src/ops/eltwise.h" +#include "src/ops/ops_register.h" +#include "nnacl/arithmetic_common.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -39,6 +42,35 @@ int Eltwise::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers } int Eltwise::GetMode() const { return this->primitive_->value_as_Eltwise()->mode(); } +PrimitiveC *EltwiseCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry EltwiseRegistry(schema::PrimitiveType_Eltwise, EltwiseCreator); #endif + +OpParameter *PopulateEltwiseParameter(const mindspore::lite::PrimitiveC *primitive) { + ArithmeticParameter *arithmetic_param = reinterpret_cast(malloc(sizeof(ArithmeticParameter))); + if (arithmetic_param == nullptr) { + MS_LOG(ERROR) << "malloc ArithmeticParameter failed."; + return nullptr; + } + memset(arithmetic_param, 0, sizeof(ArithmeticParameter)); + auto eltwise = reinterpret_cast(const_cast(primitive)); + switch (eltwise->GetMode()) { + case schema::EltwiseMode_PROD: + arithmetic_param->op_parameter_.type_ = schema::PrimitiveType_Mul; + break; + case schema::EltwiseMode_SUM: + arithmetic_param->op_parameter_.type_ = schema::PrimitiveType_Add; + break; + case schema::EltwiseMode_MAXIMUM: + arithmetic_param->op_parameter_.type_ = schema::PrimitiveType_Maximum; + break; + default: + free(arithmetic_param); + return nullptr; + } + return reinterpret_cast(arithmetic_param); +} +Registry EltwiseParameterRegistry(schema::PrimitiveType_Eltwise, PopulateEltwiseParameter); + } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/elu.cc b/mindspore/lite/src/ops/elu.cc index eefafa5bf9..d68a22e5ed 100644 --- a/mindspore/lite/src/ops/elu.cc +++ b/mindspore/lite/src/ops/elu.cc @@ -16,6 +16,8 @@ #include "src/ops/elu.h" #include +#include "nnacl/fp32/elu.h" +#include "src/ops/ops_register.h" namespace mindspore { namespace lite { @@ -61,6 +63,22 @@ int Elu::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::Fl } float Elu::GetAlpha() const { return this->primitive_->value_as_Elu()->alpha(); } +PrimitiveC *EluCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry EluRegistry(schema::PrimitiveType_Elu, EluCreator); #endif + +OpParameter *PopulateEluParameter(const mindspore::lite::PrimitiveC *primitive) { + EluParameter *elu_parameter = reinterpret_cast(malloc(sizeof(EluParameter))); + if (elu_parameter == nullptr) { + MS_LOG(ERROR) << "malloc EluParameter failed."; + return nullptr; + } + memset(elu_parameter, 0, sizeof(EluParameter)); + elu_parameter->op_parameter_.type_ = primitive->Type(); + auto param = reinterpret_cast(const_cast(primitive)); + elu_parameter->alpha_ = param->GetAlpha(); + return reinterpret_cast(elu_parameter); +} +Registry EluParameterRegistry(schema::PrimitiveType_Elu, PopulateEluParameter); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/embedding_lookup.cc b/mindspore/lite/src/ops/embedding_lookup.cc index 38079af394..e028f130d9 100644 --- a/mindspore/lite/src/ops/embedding_lookup.cc +++ b/mindspore/lite/src/ops/embedding_lookup.cc @@ -16,6 +16,9 @@ #include "src/ops/embedding_lookup.h" +#include "src/ops/ops_register.h" +#include "nnacl/fp32/embedding_lookup.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -41,8 +44,35 @@ int EmbeddingLookup::UnPackToFlatBuilder(const schema::Primitive *primitive, fla } float EmbeddingLookup::GetMaxNorm() const { return this->primitive_->value_as_EmbeddingLookup()->maxNorm(); } +PrimitiveC *EmbeddingLookupCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry EmbeddingLookupRegistry(schema::PrimitiveType_EmbeddingLookup, EmbeddingLookupCreator); #endif +OpParameter *PopulateEmbeddingLookupParameter(const mindspore::lite::PrimitiveC *primitive) { + EmbeddingLookupParameter *embedding_lookup_parameter = + reinterpret_cast(malloc(sizeof(EmbeddingLookupParameter))); + if (embedding_lookup_parameter == nullptr) { + MS_LOG(ERROR) << "malloc EmbeddingLookupParameter failed."; + return nullptr; + } + memset(embedding_lookup_parameter, 0, sizeof(EmbeddingLookupParameter)); + embedding_lookup_parameter->op_parameter_.type_ = primitive->Type(); + auto param = + reinterpret_cast(const_cast(primitive)); + embedding_lookup_parameter->max_norm_ = param->GetMaxNorm(); + if (embedding_lookup_parameter->max_norm_ < 0) { + MS_LOG(ERROR) << "Embedding lookup max norm should be positive number, got " + << embedding_lookup_parameter->max_norm_; + free(embedding_lookup_parameter); + return nullptr; + } + return reinterpret_cast(embedding_lookup_parameter); +} + +Registry EmbeddingLookupParameterRegistry(schema::PrimitiveType_EmbeddingLookup, PopulateEmbeddingLookupParameter); + int EmbeddingLookup::InferShape(std::vector inputs_, std::vector outputs_) { MS_ASSERT(this->primitive_ != nullptr); if (inputs_.size() < kDoubleNum) { diff --git a/mindspore/lite/src/ops/embedding_lookup_sparse.cc b/mindspore/lite/src/ops/embedding_lookup_sparse.cc index b981defde2..5dd0c55d6c 100644 --- a/mindspore/lite/src/ops/embedding_lookup_sparse.cc +++ b/mindspore/lite/src/ops/embedding_lookup_sparse.cc @@ -16,6 +16,8 @@ #include "src/ops/embedding_lookup_sparse.h" +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -76,6 +78,10 @@ float EmbeddingLookupSparse::GetMaxNortm() const { return this->primitive_->value_as_EmbeddingLookupSparse()->maxNortm(); } +PrimitiveC *EmbeddingLookupSparseCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry EmbeddingLookupSparseRegistry(schema::PrimitiveType_EmbeddingLookupSparse, EmbeddingLookupSparseCreator); #endif } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/equal.cc b/mindspore/lite/src/ops/equal.cc index b7525d7902..4db93bd4bd 100644 --- a/mindspore/lite/src/ops/equal.cc +++ b/mindspore/lite/src/ops/equal.cc @@ -16,6 +16,8 @@ #include "src/ops/equal.h" +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifndef PRIMITIVE_WRITEABLE @@ -28,6 +30,8 @@ int Equal::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers:: return RET_OK; } +PrimitiveC *EqualCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry EqualRegistry(schema::PrimitiveType_Equal, EqualCreator); #endif int Equal::InferShape(std::vector inputs_, std::vector outputs_) { auto input = inputs_.front(); @@ -39,5 +43,6 @@ int Equal::InferShape(std::vector inputs_, std::vector outpu output->SetFormat(input->GetFormat()); return RET_OK; } +Registry EqualParameterRegistry(schema::PrimitiveType_Equal, PopulateArithmetic); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/exp.cc b/mindspore/lite/src/ops/exp.cc index 0a1f88ffc1..967d7775d8 100644 --- a/mindspore/lite/src/ops/exp.cc +++ b/mindspore/lite/src/ops/exp.cc @@ -16,6 +16,9 @@ #include "src/ops/exp.h" +#include "src/ops/ops_register.h" +#include "src/ops/arithmetic_self.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -71,6 +74,10 @@ float Exp::GetBase() const { return this->primitive_->value_as_Exp()->base(); } float Exp::GetScale() const { return this->primitive_->value_as_Exp()->scale(); } float Exp::GetShift() const { return this->primitive_->value_as_Exp()->shift(); } +PrimitiveC *ExpCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry ExpRegistry(schema::PrimitiveType_Exp, ExpCreator); #endif +Registry ExpParameterRegistry(schema::PrimitiveType_Exp, PopulateArithmeticSelf); + } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/expand_dims.cc b/mindspore/lite/src/ops/expand_dims.cc index 9952793931..a57ec68c4a 100644 --- a/mindspore/lite/src/ops/expand_dims.cc +++ b/mindspore/lite/src/ops/expand_dims.cc @@ -16,6 +16,9 @@ #include "src/ops/expand_dims.h" +#include "src/ops/ops_register.h" +#include "nnacl/fp32/expandDims.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -40,8 +43,27 @@ int ExpandDims::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuff } int ExpandDims::GetDim() const { return this->primitive_->value_as_ExpandDims()->dim(); } +PrimitiveC *ExpandDimsCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry ExpandDimsRegistry(schema::PrimitiveType_ExpandDims, ExpandDimsCreator); #endif +OpParameter *PopulateExpandDimsParameter(const mindspore::lite::PrimitiveC *primitive) { + auto param = reinterpret_cast(const_cast(primitive)); + ExpandDimsParameter *expand_dims_param = reinterpret_cast(malloc(sizeof(ExpandDimsParameter))); + if (expand_dims_param == nullptr) { + MS_LOG(ERROR) << "malloc ExpandDimsParameter failed."; + return nullptr; + } + memset(expand_dims_param, 0, sizeof(ExpandDimsParameter)); + expand_dims_param->op_parameter_.type_ = primitive->Type(); + expand_dims_param->dim_ = param->GetDim(); + return reinterpret_cast(expand_dims_param); +} + +Registry ExpandDimsParameterRegistry(schema::PrimitiveType_ExpandDims, PopulateExpandDimsParameter); + int ExpandDims::InferShape(std::vector inputs_, std::vector outputs_) { MS_ASSERT(this->primitive_ != nullptr); auto input = inputs_.front(); diff --git a/mindspore/lite/src/ops/fake_quant_with_min_max_vars.cc b/mindspore/lite/src/ops/fake_quant_with_min_max_vars.cc index 5bc8b3eb26..dfa565eeb0 100644 --- a/mindspore/lite/src/ops/fake_quant_with_min_max_vars.cc +++ b/mindspore/lite/src/ops/fake_quant_with_min_max_vars.cc @@ -16,6 +16,8 @@ #include "src/ops/fake_quant_with_min_max_vars.h" +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -54,6 +56,10 @@ int FakeQuantWithMinMaxVars::GetNumBits() const { return this->primitive_->value_as_FakeQuantWithMinMaxVars()->numBits(); } +PrimitiveC *FakeQuantWithMinMaxVarsCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry FakeQuantWithMinMaxVarsRegistry(schema::PrimitiveType_FakeQuantWithMinMaxVars, FakeQuantWithMinMaxVarsCreator); #endif } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/fill.cc b/mindspore/lite/src/ops/fill.cc index 9b1c8da7f7..e8a2d68432 100644 --- a/mindspore/lite/src/ops/fill.cc +++ b/mindspore/lite/src/ops/fill.cc @@ -16,6 +16,9 @@ #include "src/ops/fill.h" +#include "src/ops/ops_register.h" +#include "nnacl/fp32/fill.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -48,8 +51,30 @@ std::vector Fill::GetDims() const { return std::vector(fb_vector->begin(), fb_vector->end()); } +PrimitiveC *FillCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry FillRegistry(schema::PrimitiveType_Fill, FillCreator); #endif +OpParameter *PopulateFillParameter(const mindspore::lite::PrimitiveC *primitive) { + const auto param = reinterpret_cast(const_cast(primitive)); + FillParameter *fill_param = reinterpret_cast(malloc(sizeof(FillParameter))); + if (fill_param == nullptr) { + MS_LOG(ERROR) << "malloc FillParameter failed."; + return nullptr; + } + memset(fill_param, 0, sizeof(FillParameter)); + fill_param->op_parameter_.type_ = primitive->Type(); + auto flatDims = param->GetDims(); + fill_param->num_dims_ = flatDims.size(); + int i = 0; + for (auto iter = flatDims.begin(); iter != flatDims.end(); iter++) { + fill_param->dims_[i++] = *iter; + } + return reinterpret_cast(fill_param); +} + +Registry FillParameterRegistry(schema::PrimitiveType_Fill, PopulateFillParameter); + int Fill::InferShape(std::vector inputs_, std::vector outputs_) { MS_ASSERT(this->primitive_ != nullptr); auto input = inputs_.front(); diff --git a/mindspore/lite/src/ops/flatten.cc b/mindspore/lite/src/ops/flatten.cc index b77df96150..65b523b652 100644 --- a/mindspore/lite/src/ops/flatten.cc +++ b/mindspore/lite/src/ops/flatten.cc @@ -17,6 +17,9 @@ #include "src/ops/flatten.h" #include +#include "src/ops/ops_register.h" +#include "nnacl/flatten.h" + namespace mindspore { namespace lite { @@ -86,6 +89,22 @@ int Flatten::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers fbb->Finish(prim_offset); return RET_OK; } +PrimitiveC *FlattenCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry FlattenRegistry(schema::PrimitiveType_Flatten, FlattenCreator); #endif + +OpParameter *PopulateFlattenParameter(const mindspore::lite::PrimitiveC *primitive) { + FlattenParameter *flatten_param = reinterpret_cast(malloc(sizeof(FlattenParameter))); + if (flatten_param == nullptr) { + MS_LOG(ERROR) << "malloc FlattenParameter failed."; + return nullptr; + } + memset(flatten_param, 0, sizeof(FlattenParameter)); + flatten_param->op_parameter_.type_ = primitive->Type(); + return reinterpret_cast(flatten_param); +} + +Registry FlattenParameterRegistry(schema::PrimitiveType_Flatten, PopulateFlattenParameter); + } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/flatten_grad.cc b/mindspore/lite/src/ops/flatten_grad.cc index 15ec161d58..f80af5c78c 100644 --- a/mindspore/lite/src/ops/flatten_grad.cc +++ b/mindspore/lite/src/ops/flatten_grad.cc @@ -17,6 +17,8 @@ #include "src/ops/flatten_grad.h" #include +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { int FlattenGrad::InferShape(std::vector inputs_, std::vector outputs_) { @@ -85,6 +87,10 @@ int FlattenGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuf fbb->Finish(prim_offset); return RET_OK; } +PrimitiveC *FlattenGradCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry FlattenGradRegistry(schema::PrimitiveType_FlattenGrad, FlattenGradCreator); #endif } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/floor.cc b/mindspore/lite/src/ops/floor.cc index d284b102d9..1f544f2513 100644 --- a/mindspore/lite/src/ops/floor.cc +++ b/mindspore/lite/src/ops/floor.cc @@ -16,6 +16,8 @@ #include "src/ops/floor.h" +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifndef PRIMITIVE_WRITEABLE @@ -28,7 +30,10 @@ int Floor::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers:: fbb->Finish(prim_offset); return RET_OK; } - +PrimitiveC *FloorCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry FloorRegistry(schema::PrimitiveType_Floor, FloorCreator); #endif +Registry FloorParameterRegistry(schema::PrimitiveType_Floor, PopulateArithmeticSelf); + } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/floor_div.cc b/mindspore/lite/src/ops/floor_div.cc index 9eff6de98a..ee737a7e59 100644 --- a/mindspore/lite/src/ops/floor_div.cc +++ b/mindspore/lite/src/ops/floor_div.cc @@ -16,6 +16,8 @@ #include "src/ops/floor_div.h" +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifndef PRIMITIVE_WRITEABLE @@ -29,6 +31,11 @@ int FloorDiv::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffer return RET_OK; } +PrimitiveC *FloorDivCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry FloorDivRegistry(schema::PrimitiveType_FloorDiv, FloorDivCreator); #endif +Registry FloorDivParameterRegistry(schema::PrimitiveType_FloorDiv, PopulateArithmetic); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/floor_mod.cc b/mindspore/lite/src/ops/floor_mod.cc index f903620655..d633f364d8 100644 --- a/mindspore/lite/src/ops/floor_mod.cc +++ b/mindspore/lite/src/ops/floor_mod.cc @@ -16,6 +16,8 @@ #include "src/ops/floor_mod.h" +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifndef PRIMITIVE_WRITEABLE @@ -28,7 +30,11 @@ int FloorMod::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffer fbb->Finish(prim_offset); return RET_OK; } - +PrimitiveC *FloorModCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry FloorModRegistry(schema::PrimitiveType_FloorMod, FloorModCreator); #endif +Registry FloorModParameterRegistry(schema::PrimitiveType_FloorMod, PopulateArithmetic); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/full_connection.cc b/mindspore/lite/src/ops/full_connection.cc index d0661869fd..ad045f4f28 100644 --- a/mindspore/lite/src/ops/full_connection.cc +++ b/mindspore/lite/src/ops/full_connection.cc @@ -16,6 +16,9 @@ #include "src/ops/full_connection.h" +#include "src/ops/ops_register.h" +#include "nnacl/matmul_parameter.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -51,7 +54,38 @@ int FullConnection::GetAxis() const { return this->primitive_->value_as_FullConn bool FullConnection::GetUseAxis() const { return this->primitive_->value_as_FullConnection()->useAxis(); } int FullConnection::GetActivationType() const { return this->primitive_->value_as_FullConnection()->activationType(); } +PrimitiveC *FullConnectionCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry FullConnectionRegistry(schema::PrimitiveType_FullConnection, FullConnectionCreator); #endif + +OpParameter *PopulateFullconnectionParameter(const mindspore::lite::PrimitiveC *primitive) { + auto param = + reinterpret_cast(const_cast(primitive)); + MatMulParameter *matmul_param = reinterpret_cast(malloc(sizeof(MatMulParameter))); + if (matmul_param == nullptr) { + MS_LOG(ERROR) << "malloc MatMulParameter failed."; + return nullptr; + } + memset(matmul_param, 0, sizeof(MatMulParameter)); + matmul_param->op_parameter_.type_ = primitive->Type(); + matmul_param->b_transpose_ = true; + matmul_param->a_transpose_ = false; + matmul_param->has_bias_ = param->GetHasBias(); + if (param->GetActivationType() == schema::ActivationType_RELU) { + matmul_param->act_type_ = ActType_Relu; + } else if (param->GetActivationType() == schema::ActivationType_RELU6) { + matmul_param->act_type_ = ActType_Relu6; + } else { + matmul_param->act_type_ = ActType_No; + } + + return reinterpret_cast(matmul_param); +} + +Registry FullConnectionParameterRegistry(schema::PrimitiveType_FullConnection, PopulateFullconnectionParameter); + int FullConnection::InferShape(std::vector inputs_, std::vector outputs_) { MS_ASSERT(this->primitive_ != nullptr); auto input0 = inputs_.front(); diff --git a/mindspore/lite/src/ops/fused_batchnorm.cc b/mindspore/lite/src/ops/fused_batchnorm.cc index a08451af54..94a3e97e5f 100644 --- a/mindspore/lite/src/ops/fused_batchnorm.cc +++ b/mindspore/lite/src/ops/fused_batchnorm.cc @@ -16,6 +16,9 @@ #include "src/ops/fused_batchnorm.h" +#include "src/ops/ops_register.h" +#include "nnacl/batchnorm_parameter.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -72,7 +75,29 @@ float FusedBatchNorm::GetEpsilon() const { return this->primitive_->value_as_Fus float FusedBatchNorm::GetMomentum() const { return this->primitive_->value_as_FusedBatchNorm()->momentum(); } int FusedBatchNorm::GetSpatial() const { return this->primitive_->value_as_FusedBatchNorm()->spatial(); } +PrimitiveC *FusedBatchNormCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry FusedBatchNormRegistry(schema::PrimitiveType_FusedBatchNorm, FusedBatchNormCreator); #endif +OpParameter *PopulateFusedBatchNorm(const mindspore::lite::PrimitiveC *primitive) { + BatchNormParameter *batch_norm_param = reinterpret_cast(malloc(sizeof(BatchNormParameter))); + if (batch_norm_param == nullptr) { + MS_LOG(ERROR) << "malloc BatchNormParameter failed."; + return nullptr; + } + memset(batch_norm_param, 0, sizeof(BatchNormParameter)); + batch_norm_param->op_parameter_.type_ = primitive->Type(); + auto param = + reinterpret_cast(const_cast(primitive)); + batch_norm_param->epsilon_ = param->GetEpsilon(); + batch_norm_param->momentum_ = param->GetMomentum(); + batch_norm_param->fused_ = true; + return reinterpret_cast(batch_norm_param); +} + +Registry FusedBatchNormParameterRegistry(schema::PrimitiveType_FusedBatchNorm, PopulateFusedBatchNorm); + int FusedBatchNorm::InferShape(std::vector inputs_, std::vector outputs_) { for (size_t i = 0; i < inputs_.size(); i++) { if (outputs_.size() <= i) break; diff --git a/mindspore/lite/src/ops/gather.cc b/mindspore/lite/src/ops/gather.cc index a0ee20942e..56cda0a19f 100644 --- a/mindspore/lite/src/ops/gather.cc +++ b/mindspore/lite/src/ops/gather.cc @@ -19,6 +19,9 @@ #include "src/common/log_adapter.h" #include "src/tensor.h" +#include "src/ops/ops_register.h" +#include "nnacl/gather_parameter.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -46,8 +49,25 @@ int Gather::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers: int Gather::GetAxis() const { return this->primitive_->value_as_Gather()->axis(); } int Gather::GetBatchDims() const { return this->primitive_->value_as_Gather()->batchDims(); } +PrimitiveC *GatherCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry GatherRegistry(schema::PrimitiveType_Gather, GatherCreator); #endif +OpParameter *PopulateGatherParameter(const mindspore::lite::PrimitiveC *primitive) { + auto gather_attr = reinterpret_cast(const_cast(primitive)); + GatherParameter *gather_param = reinterpret_cast(malloc(sizeof(GatherParameter))); + if (gather_param == nullptr) { + MS_LOG(ERROR) << "malloc GatherParameter failed."; + return nullptr; + } + memset(gather_param, 0, sizeof(GatherParameter)); + gather_param->op_parameter_.type_ = primitive->Type(); + gather_param->axis_ = gather_attr->GetAxis(); + gather_param->batchDims_ = gather_attr->GetBatchDims(); + return reinterpret_cast(gather_param); +} +Registry GatherParameterRegistry(schema::PrimitiveType_Gather, PopulateGatherParameter); + int Gather::InferShape(std::vector inputs_, std::vector outputs_) { MS_ASSERT(this->primitive_ != nullptr); if (inputs_.size() != kDoubleNum) { diff --git a/mindspore/lite/src/ops/gather_nd.cc b/mindspore/lite/src/ops/gather_nd.cc index 0456b02d0a..ae3376f7ca 100644 --- a/mindspore/lite/src/ops/gather_nd.cc +++ b/mindspore/lite/src/ops/gather_nd.cc @@ -16,6 +16,9 @@ #include "src/ops/gather_nd.h" +#include "src/ops/ops_register.h" +#include "nnacl/fp32/gatherNd.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -40,8 +43,28 @@ int GatherNd::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffer } int GatherNd::GetBatchDims() const { return this->primitive_->value_as_GatherNd()->batchDims(); } +PrimitiveC *GatherNdCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry GatherNdRegistry(schema::PrimitiveType_GatherNd, GatherNdCreator); #endif +OpParameter *PopulateGatherNdParameter(const mindspore::lite::PrimitiveC *primitive) { + GatherNdParameter *gather_nd_param = reinterpret_cast(malloc(sizeof(GatherNdParameter))); + if (gather_nd_param == nullptr) { + MS_LOG(ERROR) << "malloc GatherNdParameter failed."; + return nullptr; + } + memset(gather_nd_param, 0, sizeof(GatherNdParameter)); + gather_nd_param->op_parameter_.type_ = primitive->Type(); + auto gatherNd_attr = + reinterpret_cast(const_cast(primitive)); + gather_nd_param->batchDims_ = gatherNd_attr->GetBatchDims(); + return reinterpret_cast(gather_nd_param); +} + +Registry GatherNdParameterRegistry(schema::PrimitiveType_GatherNd, PopulateGatherNdParameter); + int GatherNd::InferShape(std::vector inputs_, std::vector outputs_) { MS_ASSERT(this->primitive_ != nullptr); if (inputs_.size() != kDoubleNum) { diff --git a/mindspore/lite/src/ops/greater.cc b/mindspore/lite/src/ops/greater.cc index 0d7bf7f555..72828eacd8 100644 --- a/mindspore/lite/src/ops/greater.cc +++ b/mindspore/lite/src/ops/greater.cc @@ -16,6 +16,8 @@ #include "src/ops/greater.h" +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifndef PRIMITIVE_WRITEABLE @@ -28,6 +30,9 @@ int Greater::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *GreaterCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry GreaterRegistry(schema::PrimitiveType_Greater, GreaterCreator); #endif int Greater::InferShape(std::vector inputs_, std::vector outputs_) { auto input = inputs_.front(); @@ -39,5 +44,6 @@ int Greater::InferShape(std::vector inputs_, std::vector out output->SetFormat(input->GetFormat()); return RET_OK; } +Registry GreaterParameterRegistry(schema::PrimitiveType_Greater, PopulateArithmetic); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/greater_equal.cc b/mindspore/lite/src/ops/greater_equal.cc index ce30ba1742..203e8969f5 100644 --- a/mindspore/lite/src/ops/greater_equal.cc +++ b/mindspore/lite/src/ops/greater_equal.cc @@ -16,6 +16,8 @@ #include "src/ops/greater_equal.h" +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifndef PRIMITIVE_WRITEABLE @@ -27,6 +29,12 @@ int GreaterEqual::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbu fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *GreaterEqualCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry GreaterEqualRegistry(schema::PrimitiveType_GreaterEqual, GreaterEqualCreator); + #endif int GreaterEqual::InferShape(std::vector inputs_, std::vector outputs_) { auto input = inputs_.front(); @@ -38,5 +46,6 @@ int GreaterEqual::InferShape(std::vector inputs_, std::vectorSetFormat(input->GetFormat()); return RET_OK; } +Registry GreaterEqualParameterRegistry(schema::PrimitiveType_GreaterEqual, PopulateArithmetic); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/group_conv2d_grad_input.cc b/mindspore/lite/src/ops/group_conv2d_grad_input.cc index 82c1606b49..47f0a14ecd 100644 --- a/mindspore/lite/src/ops/group_conv2d_grad_input.cc +++ b/mindspore/lite/src/ops/group_conv2d_grad_input.cc @@ -16,6 +16,8 @@ #include "src/ops/group_conv2d_grad_input.h" +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -127,6 +129,10 @@ int GroupConv2DGradInput::GetActivationType() const { return this->primitive_->value_as_GroupConv2DGradInput()->activationType(); } +PrimitiveC *GroupConv2DGradInputCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry GroupConv2DGradInputRegistry(schema::PrimitiveType_GroupConv2DGradInput, GroupConv2DGradInputCreator); #endif int GroupConv2DGradInput::InferShape(std::vector inputs, std::vector outputs) { diff --git a/mindspore/lite/src/ops/hashtable_lookup.cc b/mindspore/lite/src/ops/hashtable_lookup.cc index 949fb029a0..48a107c055 100644 --- a/mindspore/lite/src/ops/hashtable_lookup.cc +++ b/mindspore/lite/src/ops/hashtable_lookup.cc @@ -17,6 +17,8 @@ #include "src/common/string_util.h" +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -30,7 +32,24 @@ int HashtableLookup::UnPackToFlatBuilder(const schema::Primitive *primitive, fla fbb->Finish(prim_offset); return RET_OK; } +PrimitiveC *HashtableLookupCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry HashtableLookupRegistry(schema::PrimitiveType_HashtableLookup, HashtableLookupCreator); #endif + +OpParameter *PopulateHashtableLookupParameter(const mindspore::lite::PrimitiveC *primitive) { + OpParameter *param = reinterpret_cast(malloc(sizeof(OpParameter))); + if (param == nullptr) { + MS_LOG(ERROR) << "new OpParameter failed."; + return nullptr; + } + memset(param, 0, sizeof(OpParameter)); + param->type_ = primitive->Type(); + return param; +} +Registry HashtableLookupParameterRegistry(schema::PrimitiveType_HashtableLookup, PopulateHashtableLookupParameter); + int HashtableLookup::InferShape(std::vector inputs_, std::vector outputs_) { auto input = inputs_.at(0); auto values = inputs_.at(2); diff --git a/mindspore/lite/src/ops/l2_norm.cc b/mindspore/lite/src/ops/l2_norm.cc index 4fc1486daa..030e55e061 100644 --- a/mindspore/lite/src/ops/l2_norm.cc +++ b/mindspore/lite/src/ops/l2_norm.cc @@ -16,6 +16,9 @@ #include "src/ops/l2_norm.h" +#include "src/ops/ops_register.h" +#include "nnacl/l2_norm_parameter.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -57,6 +60,44 @@ std::vector L2Norm::GetAxis() const { float L2Norm::GetEpsilon() const { return this->primitive_->value_as_L2Norm()->epsilon(); } int L2Norm::GetActivationType() const { return this->primitive_->value_as_L2Norm()->activationType(); } +PrimitiveC *L2NormCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry L2NormRegistry(schema::PrimitiveType_L2Norm, L2NormCreator); #endif +OpParameter *PopulateL2NormParameter(const mindspore::lite::PrimitiveC *primitive) { + L2NormParameter *l2_norm_parameter = reinterpret_cast(malloc(sizeof(L2NormParameter))); + if (l2_norm_parameter == nullptr) { + MS_LOG(ERROR) << "malloc L2NormParameter failed."; + return nullptr; + } + memset(l2_norm_parameter, 0, sizeof(L2NormParameter)); + l2_norm_parameter->op_parameter_.type_ = primitive->Type(); + auto param = reinterpret_cast(const_cast(primitive)); + auto axis_vec = param->GetAxis(); + l2_norm_parameter->axis_num_ = axis_vec.size(); + l2_norm_parameter->axis_ = reinterpret_cast(malloc(axis_vec.size() * sizeof(int))); + if (l2_norm_parameter->axis_ == nullptr) { + MS_LOG(ERROR) << "malloc axis_ data failed"; + free(l2_norm_parameter); + return nullptr; + } + for (size_t i = 0; i < axis_vec.size(); i++) { + l2_norm_parameter->axis_[i] = axis_vec[i]; + } + if (param->GetEpsilon() < 1e-6) { + l2_norm_parameter->epsilon_ = 1e-6; + } else { + l2_norm_parameter->epsilon_ = param->GetEpsilon(); + } + if (param->GetActivationType() == static_cast(schema::ActivationType_RELU)) { + l2_norm_parameter->act_type_ = ActType_Relu; + } else if (param->GetActivationType() == static_cast(schema::ActivationType_RELU6)) { + l2_norm_parameter->act_type_ = ActType_Relu6; + } else { + l2_norm_parameter->act_type_ = ActType_No; + } + return reinterpret_cast(l2_norm_parameter); +} +Registry L2NormParameterRegistry(schema::PrimitiveType_L2Norm, PopulateL2NormParameter); + } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/leaky_relu.cc b/mindspore/lite/src/ops/leaky_relu.cc index 7164372f6b..dfc7069093 100644 --- a/mindspore/lite/src/ops/leaky_relu.cc +++ b/mindspore/lite/src/ops/leaky_relu.cc @@ -16,6 +16,8 @@ #include "src/ops/leaky_relu.h" +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -42,6 +44,10 @@ int LeakyReLU::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffe fbb->Finish(prim_offset); return RET_OK; } +PrimitiveC *LeakyReLUCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry LeakyReLURegistry(schema::PrimitiveType_LeakyReLU, LeakyReLUCreator); #endif } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/less.cc b/mindspore/lite/src/ops/less.cc index efc6e2af9d..a6c87ecbf8 100644 --- a/mindspore/lite/src/ops/less.cc +++ b/mindspore/lite/src/ops/less.cc @@ -16,6 +16,8 @@ #include "src/ops/less.h" +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -29,6 +31,10 @@ int Less::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::F fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *LessCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry LessRegistry(schema::PrimitiveType_Less, LessCreator); + #endif int Less::InferShape(std::vector inputs_, std::vector outputs_) { auto input = inputs_.front(); @@ -40,5 +46,6 @@ int Less::InferShape(std::vector inputs_, std::vector output output->SetFormat(input->GetFormat()); return RET_OK; } +Registry LessParameterRegistry(schema::PrimitiveType_Less, PopulateArithmetic); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/less_equal.cc b/mindspore/lite/src/ops/less_equal.cc index 0acc83c213..ec5df2c1db 100644 --- a/mindspore/lite/src/ops/less_equal.cc +++ b/mindspore/lite/src/ops/less_equal.cc @@ -16,6 +16,8 @@ #include "src/ops/less_equal.h" +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -28,6 +30,10 @@ int LessEqual::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffe fbb->Finish(prim_offset); return RET_OK; } +PrimitiveC *LessEqualCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry LessEqualRegistry(schema::PrimitiveType_LessEqual, LessEqualCreator); #endif int LessEqual::InferShape(std::vector inputs_, std::vector outputs_) { auto input = inputs_.front(); @@ -39,5 +45,6 @@ int LessEqual::InferShape(std::vector inputs_, std::vector o output->SetFormat(input->GetFormat()); return RET_OK; } +Registry LessEqualParameterRegistry(schema::PrimitiveType_LessEqual, PopulateArithmetic); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/local_response_normalization.cc b/mindspore/lite/src/ops/local_response_normalization.cc index 518567d785..92e9231eb8 100644 --- a/mindspore/lite/src/ops/local_response_normalization.cc +++ b/mindspore/lite/src/ops/local_response_normalization.cc @@ -16,6 +16,9 @@ #include "src/ops/local_response_normalization.h" +#include "src/ops/ops_register.h" +#include "nnacl/fp32/local_response_norm.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -76,6 +79,34 @@ int LocalResponseNormalization::UnPackToFlatBuilder(const schema::Primitive *pri return RET_OK; } +PrimitiveC *LocalResponseNormalizationCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry LocalResponseNormalizationRegistry(schema::PrimitiveType_LocalResponseNormalization, + LocalResponseNormalizationCreator); + #endif + +OpParameter *PopulateLocalResponseNormParameter(const mindspore::lite::PrimitiveC *primitive) { + auto local_response_norm_attr = reinterpret_cast( + const_cast(primitive)); + LocalResponseNormParameter *lrn_param = + reinterpret_cast(malloc(sizeof(LocalResponseNormParameter))); + if (lrn_param == nullptr) { + MS_LOG(ERROR) << "malloc LocalResponseNormParameter failed."; + return nullptr; + } + memset(lrn_param, 0, sizeof(LocalResponseNormParameter)); + lrn_param->op_parameter_.type_ = primitive->Type(); + lrn_param->depth_radius_ = local_response_norm_attr->GetDepthRadius(); + lrn_param->bias_ = local_response_norm_attr->GetBias(); + lrn_param->alpha_ = local_response_norm_attr->GetAlpha(); + lrn_param->beta_ = local_response_norm_attr->GetBeta(); + return reinterpret_cast(lrn_param); +} + +Registry LocalResponseNormalizationParameterRegistry(schema::PrimitiveType_LocalResponseNormalization, + PopulateLocalResponseNormParameter); + } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/log.cc b/mindspore/lite/src/ops/log.cc index 6567f194f0..6835c6d19a 100644 --- a/mindspore/lite/src/ops/log.cc +++ b/mindspore/lite/src/ops/log.cc @@ -17,6 +17,8 @@ #include "src/ops/log.h" #include +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -50,6 +52,11 @@ int Log::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::Fl fbb->Finish(prim_offset); return RET_OK; } +PrimitiveC *LogCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry LogRegistry(schema::PrimitiveType_Log, LogCreator); + #endif +Registry LogParameterRegistry(schema::PrimitiveType_Log, PopulateArithmeticSelf); + } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/log_grad.cc b/mindspore/lite/src/ops/log_grad.cc index a290ff5f1c..3e797c7fcf 100644 --- a/mindspore/lite/src/ops/log_grad.cc +++ b/mindspore/lite/src/ops/log_grad.cc @@ -16,6 +16,9 @@ #include "src/ops/log_grad.h" +#include "src/ops/ops_register.h" +#include "src/ops/arithmetic_self.h" + namespace mindspore { namespace lite { #ifndef PRIMITIVE_WRITEABLE @@ -32,6 +35,11 @@ int LogGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *LogGradCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry LogGradRegistry(schema::PrimitiveType_LogGrad, LogGradCreator); #endif +Registry LogGradParameterRegistry(schema::PrimitiveType_LogGrad, PopulateArithmeticSelf); + } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/logical_and.cc b/mindspore/lite/src/ops/logical_and.cc index 8cc73dfe50..c6e43984ae 100644 --- a/mindspore/lite/src/ops/logical_and.cc +++ b/mindspore/lite/src/ops/logical_and.cc @@ -16,6 +16,8 @@ #include "src/ops/logical_and.h" +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -28,6 +30,13 @@ int LogicalAnd::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuff fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *LogicalAndCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry LogicalAndRegistry(schema::PrimitiveType_LogicalAnd, LogicalAndCreator); #endif + +Registry LogicalAndParameterRegistry(schema::PrimitiveType_LogicalAnd, PopulateArithmetic); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/logical_not.cc b/mindspore/lite/src/ops/logical_not.cc index c67869f932..3847e2b6a8 100644 --- a/mindspore/lite/src/ops/logical_not.cc +++ b/mindspore/lite/src/ops/logical_not.cc @@ -16,6 +16,8 @@ #include "src/ops/logical_not.h" +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -28,6 +30,12 @@ int LogicalNot::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuff fbb->Finish(prim_offset); return RET_OK; } +PrimitiveC *LogicalNotCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry LogicalNotRegistry(schema::PrimitiveType_LogicalNot, LogicalNotCreator); #endif +Registry LogicalNotParameterRegistry(schema::PrimitiveType_LogicalNot, PopulateArithmeticSelf); + } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/logical_or.cc b/mindspore/lite/src/ops/logical_or.cc index 2d8f73f040..1fc125b059 100644 --- a/mindspore/lite/src/ops/logical_or.cc +++ b/mindspore/lite/src/ops/logical_or.cc @@ -16,6 +16,8 @@ #include "src/ops/logical_or.h" +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -28,6 +30,12 @@ int LogicalOr::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffe fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *LogicalOrCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry LogicalOrRegistry(schema::PrimitiveType_LogicalOr, LogicalOrCreator); #endif +Registry LogicalOrParameterRegistry(schema::PrimitiveType_LogicalOr, PopulateArithmetic); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/lrn.cc b/mindspore/lite/src/ops/lrn.cc index 55d7745c9f..c4bde64483 100644 --- a/mindspore/lite/src/ops/lrn.cc +++ b/mindspore/lite/src/ops/lrn.cc @@ -16,6 +16,8 @@ #include "src/ops/lrn.h" +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -49,6 +51,9 @@ int Lrn::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::Fl fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *LrnCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry LrnRegistry(schema::PrimitiveType_Lrn, LrnCreator); #endif } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/lsh_projection.cc b/mindspore/lite/src/ops/lsh_projection.cc index 2fd733047b..b206df98c2 100644 --- a/mindspore/lite/src/ops/lsh_projection.cc +++ b/mindspore/lite/src/ops/lsh_projection.cc @@ -17,6 +17,8 @@ #include "nnacl/lsh_projection_parameter.h" +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -38,7 +40,29 @@ int LshProjection::UnPackToFlatBuilder(const schema::Primitive *primitive, flatb fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *LshProjectionCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry LshProjectionRegistry(schema::PrimitiveType_LshProjection, LshProjectionCreator); + #endif + +OpParameter *PopulateLshProjectionParameter(const mindspore::lite::PrimitiveC *primitive) { + LshProjectionParameter *lsh_project_param = + reinterpret_cast(malloc(sizeof(LshProjectionParameter))); + if (lsh_project_param == nullptr) { + MS_LOG(ERROR) << "malloc LshProjectionParameter failed."; + return nullptr; + } + memset(lsh_project_param, 0, sizeof(LshProjectionParameter)); + lsh_project_param->op_parameter_.type_ = primitive->Type(); + auto param = reinterpret_cast(const_cast(primitive)); + lsh_project_param->lsh_type_ = param->GetLshType(); + return reinterpret_cast(lsh_project_param); +} +Registry LshProjectionParameterRegistry(schema::PrimitiveType_LshProjection, PopulateLshProjectionParameter); + int LshProjection::InferShape(std::vector inputs_, std::vector outputs_) { if (inputs_.size() != kDoubleNum && inputs_.size() != kMultiNum) { MS_LOG(ERROR) << "inputs to LshProjection operator should be 2 or 3, but " << inputs_.size() << " is given."; diff --git a/mindspore/lite/src/ops/lstm.cc b/mindspore/lite/src/ops/lstm.cc index eec7b1b915..ef24a0f2bc 100644 --- a/mindspore/lite/src/ops/lstm.cc +++ b/mindspore/lite/src/ops/lstm.cc @@ -16,6 +16,9 @@ #include "src/ops/lstm.h" +#include "src/ops/ops_register.h" +#include "nnacl/fp32/lstm.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -39,8 +42,31 @@ int Lstm::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::F fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *LstmCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry LstmRegistry(schema::PrimitiveType_Lstm, LstmCreator); + #endif +OpParameter *PopulateLstmParameter(const mindspore::lite::PrimitiveC *primitive) { + LstmParameter *lstm_param = reinterpret_cast(malloc(sizeof(LstmParameter))); + if (lstm_param == nullptr) { + MS_LOG(ERROR) << "malloc LstmParameter failed."; + return nullptr; + } + memset(lstm_param, 0, sizeof(LstmParameter)); + lstm_param->op_parameter_.type_ = primitive->Type(); + auto param = reinterpret_cast(const_cast(primitive)); + if (param == nullptr) { + free(lstm_param); + MS_LOG(ERROR) << "get Lstm param nullptr."; + return nullptr; + } + lstm_param->bidirectional_ = param->GetBidirection(); + return reinterpret_cast(lstm_param); +} +Registry LstmParameterRegistry(schema::PrimitiveType_Lstm, PopulateLstmParameter); + const int kLstmInputNum = 6; const int kLstmOutputNum = 3; int Lstm::InferShape(std::vector inputs_, std::vector outputs_) { diff --git a/mindspore/lite/src/ops/make_tuple.cc b/mindspore/lite/src/ops/make_tuple.cc index 78ca0b1084..b149679867 100644 --- a/mindspore/lite/src/ops/make_tuple.cc +++ b/mindspore/lite/src/ops/make_tuple.cc @@ -18,6 +18,8 @@ #include #include +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -57,6 +59,11 @@ int MakeTuple::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffe fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *MakeTupleCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry MakeTupleRegistry(schema::PrimitiveType_MakeTuple, MakeTupleCreator); #endif } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/matmul.cc b/mindspore/lite/src/ops/matmul.cc index 3d1c38045b..e91d51a4d7 100644 --- a/mindspore/lite/src/ops/matmul.cc +++ b/mindspore/lite/src/ops/matmul.cc @@ -21,6 +21,9 @@ #include "tools/converter/quantizer/quantize_util.h" #endif +#include "src/ops/ops_register.h" +#include "nnacl/matmul_parameter.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -86,8 +89,27 @@ int MatMul::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers: return RET_OK; } +PrimitiveC *MatMulCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry MatMulRegistry(schema::PrimitiveType_MatMul, MatMulCreator); #endif +OpParameter *PopulateMatMulParameter(const mindspore::lite::PrimitiveC *primitive) { + auto param = reinterpret_cast(const_cast(primitive)); + MatMulParameter *matmul_param = reinterpret_cast(malloc(sizeof(MatMulParameter))); + if (matmul_param == nullptr) { + MS_LOG(ERROR) << "malloc MatMulParameter failed."; + return nullptr; + } + memset(matmul_param, 0, sizeof(MatMulParameter)); + matmul_param->op_parameter_.type_ = primitive->Type(); + matmul_param->b_transpose_ = param->GetTransposeB(); + matmul_param->a_transpose_ = param->GetTransposeA(); + matmul_param->has_bias_ = false; + matmul_param->act_type_ = ActType_No; + return reinterpret_cast(matmul_param); +} +Registry MatMulParameterRegistry(schema::PrimitiveType_MatMul, PopulateMatMulParameter); + int MatMul::InferShape(std::vector inputs_, std::vector outputs_) { MS_ASSERT(this->primitive_ != nullptr); auto input0 = inputs_.front(); diff --git a/mindspore/lite/src/ops/matrix_diag.cc b/mindspore/lite/src/ops/matrix_diag.cc index b92094c85c..54958d6712 100644 --- a/mindspore/lite/src/ops/matrix_diag.cc +++ b/mindspore/lite/src/ops/matrix_diag.cc @@ -16,6 +16,8 @@ #include "src/ops/matrix_diag.h" +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -52,6 +54,11 @@ int MatrixDiag::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuff return RET_OK; } +PrimitiveC *MatrixDiagCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry MatrixDiagRegistry(schema::PrimitiveType_MatrixDiag, MatrixDiagCreator); + #endif } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/maximum.cc b/mindspore/lite/src/ops/maximum.cc index 55224dc589..20be10444a 100644 --- a/mindspore/lite/src/ops/maximum.cc +++ b/mindspore/lite/src/ops/maximum.cc @@ -23,6 +23,8 @@ #include "tools/converter/quantizer/quantize_util.h" #endif +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -62,6 +64,10 @@ int Maximum::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *MaximumCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry MaximumRegistry(schema::PrimitiveType_Maximum, MaximumCreator); #endif +Registry MaximumParameterRegistry(schema::PrimitiveType_Maximum, PopulateArithmetic); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/mean.cc b/mindspore/lite/src/ops/mean.cc index fa51105c52..2e827b5fc6 100644 --- a/mindspore/lite/src/ops/mean.cc +++ b/mindspore/lite/src/ops/mean.cc @@ -16,6 +16,9 @@ #include "src/ops/mean.h" +#include "src/ops/ops_register.h" +#include "nnacl/reduce_parameter.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -53,8 +56,36 @@ int Mean::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::F return RET_OK; } +PrimitiveC *MeanCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry MeanRegistry(schema::PrimitiveType_Mean, MeanCreator); #endif +OpParameter *PopulateMeanParameter(const mindspore::lite::PrimitiveC *primitive) { + ReduceParameter *mean_param = reinterpret_cast(malloc(sizeof(ReduceParameter))); + if (mean_param == nullptr) { + MS_LOG(ERROR) << "malloc ReduceParameter failed."; + return nullptr; + } + memset(mean_param, 0, sizeof(ReduceParameter)); + mean_param->op_parameter_.type_ = primitive->Type(); + auto mean = reinterpret_cast(const_cast(primitive)); + mean_param->keep_dims_ = mean->GetKeepDims(); + auto axisVector = mean->GetAxis(); + if (axisVector.size() > REDUCE_MAX_AXES_NUM) { + MS_LOG(ERROR) << "Reduce axes size " << axisVector.size() << " exceed limit " << REDUCE_MAX_AXES_NUM; + free(mean_param); + return nullptr; + } + mean_param->num_axes_ = static_cast(axisVector.size()); + int i = 0; + for (auto iter = axisVector.begin(); iter != axisVector.end(); iter++) { + mean_param->axes_[i++] = *iter; + } + mean_param->mode_ = static_cast(schema::ReduceMode_ReduceMean); + return reinterpret_cast(mean_param); +} +Registry MeanParameterRegistry(schema::PrimitiveType_Mean, PopulateMeanParameter); + namespace { constexpr size_t kInputSize = 1; constexpr size_t kOutputSize = 1; diff --git a/mindspore/lite/src/ops/minimum.cc b/mindspore/lite/src/ops/minimum.cc index c2c8c8fd56..8d584e754a 100644 --- a/mindspore/lite/src/ops/minimum.cc +++ b/mindspore/lite/src/ops/minimum.cc @@ -16,6 +16,8 @@ #include "src/ops/minimum.h" +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -28,6 +30,10 @@ int Minimum::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers fbb->Finish(prim_offset); return RET_OK; } +PrimitiveC *MinimumCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry MinimumRegistry(schema::PrimitiveType_Minimum, MinimumCreator); #endif + +Registry MinimumParameterRegistry(schema::PrimitiveType_Minimum, PopulateArithmetic); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/mul.cc b/mindspore/lite/src/ops/mul.cc index f6d2fab12a..d7a9dad330 100644 --- a/mindspore/lite/src/ops/mul.cc +++ b/mindspore/lite/src/ops/mul.cc @@ -16,6 +16,8 @@ #include "src/ops/mul.h" #include +#include "nnacl/arithmetic_common.h" +#include "src/ops/ops_register.h" namespace mindspore { namespace lite { @@ -72,6 +74,30 @@ int Mul::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::Fl return RET_OK; } +PrimitiveC *MulCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry MulRegistry(schema::PrimitiveType_Mul, MulCreator); #endif +OpParameter *PopulateMulParameter(const mindspore::lite::PrimitiveC *primitive) { + ArithmeticParameter *arithmetic_param = reinterpret_cast(malloc(sizeof(ArithmeticParameter))); + if (arithmetic_param == nullptr) { + MS_LOG(ERROR) << "malloc ArithmeticParameter failed."; + return nullptr; + } + memset(arithmetic_param, 0, sizeof(ArithmeticParameter)); + arithmetic_param->op_parameter_.type_ = primitive->Type(); + arithmetic_param->broadcasting_ = ((lite::Arithmetic *)primitive)->Broadcasting(); + arithmetic_param->ndim_ = ((lite::Arithmetic *)primitive)->NDims(); + arithmetic_param->activation_type_ = + reinterpret_cast(const_cast(primitive))->GetActivationType(); + auto tmp_shape = ((lite::Arithmetic *)primitive)->InShape0(); + memcpy(arithmetic_param->in_shape0_, static_cast(tmp_shape.data()), tmp_shape.size() * sizeof(int)); + tmp_shape = ((lite::Arithmetic *)primitive)->InShape1(); + memcpy(arithmetic_param->in_shape1_, static_cast(tmp_shape.data()), tmp_shape.size() * sizeof(int)); + tmp_shape = ((lite::Arithmetic *)primitive)->OutputShape(); + memcpy(arithmetic_param->out_shape_, static_cast(tmp_shape.data()), tmp_shape.size() * sizeof(int)); + return reinterpret_cast(arithmetic_param); +} +Registry MulParameterRegistry(schema::PrimitiveType_Mul, PopulateMulParameter); + } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/nchw2nhwc.cc b/mindspore/lite/src/ops/nchw2nhwc.cc index e573fbf9f7..3413042baf 100644 --- a/mindspore/lite/src/ops/nchw2nhwc.cc +++ b/mindspore/lite/src/ops/nchw2nhwc.cc @@ -17,6 +17,9 @@ #include "src/ops/nchw2nhwc.h" #include "src/common/common.h" +#include "src/ops/ops_register.h" +#include "nnacl/transpose.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -29,8 +32,29 @@ int Nchw2Nhwc::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffe fbb->Finish(prim_offset); return RET_OK; } +PrimitiveC *Nchw2NhwcCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry Nchw2NhwcRegistry(schema::PrimitiveType_Nchw2Nhwc, Nchw2NhwcCreator); #endif +OpParameter *PopulateNchw2NhwcParameter(const mindspore::lite::PrimitiveC *primitive) { + TransposeParameter *parameter = reinterpret_cast(malloc(sizeof(TransposeParameter))); + if (parameter == nullptr) { + MS_LOG(ERROR) << "malloc OpParameter failed."; + return nullptr; + } + memset(parameter, 0, sizeof(OpParameter)); + parameter->op_parameter_.type_ = primitive->Type(); + parameter->num_axes_ = 4; + parameter->perm_[0] = 0; + parameter->perm_[1] = 2; + parameter->perm_[2] = 3; + parameter->perm_[3] = 1; + return reinterpret_cast(parameter); +} +Registry Nchw2NhwcParameterRegistry(schema::PrimitiveType_Nchw2Nhwc, PopulateNchw2NhwcParameter); + int Nchw2Nhwc::InferShape(std::vector inputs_, std::vector outputs_) { MS_ASSERT(this->primitive_ != nullptr); auto input = inputs_.front(); diff --git a/mindspore/lite/src/ops/neg.cc b/mindspore/lite/src/ops/neg.cc index 90fce187c4..8e51f870fc 100644 --- a/mindspore/lite/src/ops/neg.cc +++ b/mindspore/lite/src/ops/neg.cc @@ -16,6 +16,8 @@ #include "src/ops/neg.h" +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -52,6 +54,9 @@ int Neg::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::Fl return RET_OK; } +PrimitiveC *NegCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry NegRegistry(schema::PrimitiveType_Neg, NegCreator); #endif +Registry NegParameterRegistry(schema::PrimitiveType_Neg, PopulateArithmeticSelf); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/neg_grad.cc b/mindspore/lite/src/ops/neg_grad.cc index b5fad4919b..d244388e6a 100644 --- a/mindspore/lite/src/ops/neg_grad.cc +++ b/mindspore/lite/src/ops/neg_grad.cc @@ -16,6 +16,8 @@ #include "src/ops/neg_grad.h" +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifndef PRIMITIVE_WRITEABLE @@ -28,6 +30,11 @@ int NegGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers return RET_OK; } +PrimitiveC *NegGradCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry NegGradRegistry(schema::PrimitiveType_NegGrad, NegGradCreator); + #endif +Registry NegGradParameterRegistry(schema::PrimitiveType_NegGrad, PopulateArithmeticSelf); + } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/nhwc2nchw.cc b/mindspore/lite/src/ops/nhwc2nchw.cc index c3a03fc453..76022ef857 100644 --- a/mindspore/lite/src/ops/nhwc2nchw.cc +++ b/mindspore/lite/src/ops/nhwc2nchw.cc @@ -17,6 +17,9 @@ #include "src/ops/nhwc2nchw.h" #include "src/common/common.h" +#include "src/ops/ops_register.h" +#include "nnacl/transpose.h" + namespace mindspore { namespace lite { @@ -30,8 +33,30 @@ int Nhwc2Nchw::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffe fbb->Finish(prim_offset); return RET_OK; } +PrimitiveC *Nhwc2NchwCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry Nhwc2NchwRegistry(schema::PrimitiveType_Nhwc2Nchw, Nhwc2NchwCreator); #endif +OpParameter *PopulateNhwc2NchwParameter(const mindspore::lite::PrimitiveC *primitive) { + TransposeParameter *parameter = reinterpret_cast(malloc(sizeof(TransposeParameter))); + if (parameter == nullptr) { + MS_LOG(ERROR) << "malloc OpParameter failed."; + return nullptr; + } + memset(parameter, 0, sizeof(OpParameter)); + parameter->op_parameter_.type_ = primitive->Type(); + parameter->num_axes_ = 4; + parameter->perm_[0] = 0; + parameter->perm_[1] = 3; + parameter->perm_[2] = 1; + parameter->perm_[3] = 2; + return reinterpret_cast(parameter); +} + +Registry Nhwc2NchwParameterRegistry(schema::PrimitiveType_Nhwc2Nchw, PopulateNhwc2NchwParameter); + int Nhwc2Nchw::InferShape(std::vector inputs_, std::vector outputs_) { MS_ASSERT(this->primitive_ != nullptr); auto input = inputs_.front(); diff --git a/mindspore/lite/src/ops/not_equal.cc b/mindspore/lite/src/ops/not_equal.cc index 31904273e3..bb36ba8394 100644 --- a/mindspore/lite/src/ops/not_equal.cc +++ b/mindspore/lite/src/ops/not_equal.cc @@ -16,6 +16,8 @@ #include "src/ops/not_equal.h" +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -28,6 +30,11 @@ int NotEqual::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffer fbb->Finish(prim_offset); return RET_OK; } +PrimitiveC *NotEqualCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry NotEqualRegistry(schema::PrimitiveType_NotEqual, NotEqualCreator); + #endif int NotEqual::InferShape(std::vector inputs_, std::vector outputs_) { auto input = inputs_.front(); @@ -39,5 +46,6 @@ int NotEqual::InferShape(std::vector inputs_, std::vector ou output->SetFormat(input->GetFormat()); return RET_OK; } +Registry NotEqualParameterRegistry(schema::PrimitiveType_NotEqual, PopulateArithmetic); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/one_hot.cc b/mindspore/lite/src/ops/one_hot.cc index f5235b7c19..d99a725581 100644 --- a/mindspore/lite/src/ops/one_hot.cc +++ b/mindspore/lite/src/ops/one_hot.cc @@ -16,6 +16,9 @@ #include "src/ops/one_hot.h" +#include "src/ops/ops_register.h" +#include "nnacl/fp32/one_hot.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -71,8 +74,30 @@ int OneHot::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers: fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *OneHotCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry OneHotRegistry(schema::PrimitiveType_OneHot, OneHotCreator); #endif +OpParameter *PopulateOneHotParameter(const mindspore::lite::PrimitiveC *primitive) { + OneHotParameter *one_hot_param = reinterpret_cast(malloc(sizeof(OneHotParameter))); + if (one_hot_param == nullptr) { + MS_LOG(ERROR) << "malloc OneHotParameter failed."; + return nullptr; + } + memset(one_hot_param, 0, sizeof(OneHotParameter)); + one_hot_param->op_parameter_.type_ = primitive->Type(); + auto param = reinterpret_cast(const_cast(primitive)); + if (param == nullptr) { + free(one_hot_param); + MS_LOG(ERROR) << "get OneHot param nullptr."; + return nullptr; + } + one_hot_param->axis_ = param->GetAxis(); + return reinterpret_cast(one_hot_param); +} +Registry OneHotParameterRegistry(schema::PrimitiveType_OneHot, PopulateOneHotParameter); + namespace { constexpr size_t kOneHotInputNum = 4; } diff --git a/mindspore/lite/src/ops/ops_register.h b/mindspore/lite/src/ops/ops_register.h new file mode 100644 index 0000000000..78dfd0de2b --- /dev/null +++ b/mindspore/lite/src/ops/ops_register.h @@ -0,0 +1,71 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LITE_MINDSPORE_LITE_C_OPS_OP_REGISTER_H +#define LITE_MINDSPORE_LITE_C_OPS_OP_REGISTER_H + +#include +#include "src/ops/primitive_c.h" +namespace mindspore { +namespace lite { +class OpsRegistry { + public: + static OpsRegistry *GetInstance() { + static OpsRegistry registry; + return ®istry; + } + + void insertPrimitiveCMap(schema::PrimitiveType type, PrimitiveCCreator creator) { + primitive_creators[type] = creator; + } + PrimitiveCCreator getPrimitiveCreator(schema::PrimitiveType type) { + if (primitive_creators.find(type) != primitive_creators.end()) { + return primitive_creators[type]; + } else { + MS_LOG(ERROR) << "Unsupported primitive type in Create : " << schema::EnumNamePrimitiveType(type); + return nullptr; + } + } + + void insertParameterMap(schema::PrimitiveType type, ParameterCreator creator) { parameter_creators[type] = creator; } + + ParameterCreator getParameterCreator(schema::PrimitiveType type) { + if (parameter_creators.find(type) != parameter_creators.end()) { + return parameter_creators[type]; + } else { + MS_LOG(ERROR) << "Unsupported parameter type in Create : " << schema::EnumNamePrimitiveType(type); + return nullptr; + } + } + + protected: + std::map primitive_creators; + std::map parameter_creators; +}; + +class Registry { + public: + Registry(schema::PrimitiveType primitive_type, PrimitiveCCreator creator) { + OpsRegistry::GetInstance()->insertPrimitiveCMap(primitive_type, creator); + } + Registry(schema::PrimitiveType primitive_type, ParameterCreator creator) { + OpsRegistry::GetInstance()->insertParameterMap(primitive_type, creator); + } +}; + +} // namespace lite +} // namespace mindspore +#endif // LITE_MINDSPORE_LITE_C_OPS_OP_REGISTER_H diff --git a/mindspore/lite/src/ops/p_relu.cc b/mindspore/lite/src/ops/p_relu.cc index 2174e80baa..6cfe20ed57 100644 --- a/mindspore/lite/src/ops/p_relu.cc +++ b/mindspore/lite/src/ops/p_relu.cc @@ -16,6 +16,9 @@ #include "src/ops/p_relu.h" +#include "src/ops/ops_register.h" +#include "nnacl/prelu_parameter.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -46,6 +49,24 @@ int PReLU::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers:: return RET_OK; } +PrimitiveC *PReLUCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry PReLURegistry(schema::PrimitiveType_PReLU, PReLUCreator); + #endif + +OpParameter *PopulatePReLUParameter(const mindspore::lite::PrimitiveC *primitive) { + auto param = reinterpret_cast(const_cast(primitive)); + PReluParameter *prelu_param = reinterpret_cast(malloc(sizeof(PReluParameter))); + if (prelu_param == nullptr) { + MS_LOG(ERROR) << "malloc PReluParameter failed."; + return nullptr; + } + memset(prelu_param, 0, sizeof(PReluParameter)); + prelu_param->op_parameter_.type_ = primitive->Type(); + prelu_param->channelShared = param->GetChannelShared(); + return reinterpret_cast(prelu_param); +} +Registry PReLUParameterRegistry(schema::PrimitiveType_PReLU, PopulatePReLUParameter); + } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/pad.cc b/mindspore/lite/src/ops/pad.cc index 4d6868da58..65dc8b9190 100644 --- a/mindspore/lite/src/ops/pad.cc +++ b/mindspore/lite/src/ops/pad.cc @@ -16,6 +16,9 @@ #include "src/ops/pad.h" +#include "src/ops/ops_register.h" +#include "nnacl/pad_parameter.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -57,7 +60,42 @@ int Pad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::Fl fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *PadCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry PadRegistry(schema::PrimitiveType_Pad, PadCreator); #endif +OpParameter *PopulatePadParameter(const mindspore::lite::PrimitiveC *primitive) { + PadParameter *pad_param = reinterpret_cast(malloc(sizeof(PadParameter))); + if (pad_param == nullptr) { + MS_LOG(ERROR) << "malloc PadParameter failed."; + return nullptr; + } + memset(pad_param, 0, sizeof(PadParameter)); + pad_param->op_parameter_.type_ = primitive->Type(); + auto pad_node = reinterpret_cast(const_cast(primitive)); + pad_param->pad_mode_ = pad_node->GetPaddingMode(); + if (pad_param->pad_mode_ == static_cast(schema::PaddingMode_CONSTANT)) { + pad_param->constant_value_ = pad_node->GetConstantValue(); + auto size = pad_node->GetPaddings().size(); + if (size > MAX_PAD_SIZE) { + MS_LOG(ERROR) << "Invalid padding size: " << size; + free(pad_param); + return nullptr; + } + + for (size_t i = 0; i < MAX_PAD_SIZE - size; ++i) { + pad_param->paddings_[i] = 0; + } + for (size_t i = 0; i < size; i++) { + pad_param->paddings_[MAX_PAD_SIZE - size + i] = pad_node->GetPaddings()[i]; + } + pad_param->padding_length = MAX_PAD_SIZE; + } + + return reinterpret_cast(pad_param); +} +Registry PadParameterRegistry(schema::PrimitiveType_Pad, PopulatePadParameter); + int Pad::InferShape(std::vector inputs, std::vector outputs) { MS_ASSERT(this->primitive_ != nullptr); if (this->primitive_ == nullptr) { diff --git a/mindspore/lite/src/ops/permute.cc b/mindspore/lite/src/ops/permute.cc index d51c99ebe8..176c1c85b0 100644 --- a/mindspore/lite/src/ops/permute.cc +++ b/mindspore/lite/src/ops/permute.cc @@ -16,6 +16,8 @@ #include "src/ops/permute.h" +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -50,6 +52,9 @@ int Permute::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *PermuteCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry PermuteRegistry(schema::PrimitiveType_Permute, PermuteCreator); #endif } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/pooling.cc b/mindspore/lite/src/ops/pooling.cc index 1f97dae91b..24bb472f4d 100644 --- a/mindspore/lite/src/ops/pooling.cc +++ b/mindspore/lite/src/ops/pooling.cc @@ -19,6 +19,9 @@ #include #include +#include "src/ops/ops_register.h" +#include "nnacl/pooling_parameter.h" + namespace mindspore { namespace lite { @@ -158,8 +161,73 @@ int Pooling::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers return RET_OK; } +PrimitiveC *PoolingCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry PoolingRegistry(schema::PrimitiveType_Pooling, PoolingCreator); + #endif +OpParameter *PopulatePoolingParameter(const mindspore::lite::PrimitiveC *primitive) { + auto pooling_primitive = + reinterpret_cast(const_cast(primitive)); + PoolingParameter *pooling_param = reinterpret_cast(malloc(sizeof(PoolingParameter))); + if (pooling_param == nullptr) { + MS_LOG(ERROR) << "malloc PoolingParameter failed."; + return nullptr; + } + memset(pooling_param, 0, sizeof(PoolingParameter)); + pooling_param->op_parameter_.type_ = primitive->Type(); + pooling_param->global_ = pooling_primitive->GetGlobal(); + pooling_param->window_w_ = pooling_primitive->GetWindowW(); + pooling_param->window_h_ = pooling_primitive->GetWindowH(); + auto pooling_lite_primitive = (lite::Pooling *)primitive; + pooling_param->pad_u_ = pooling_lite_primitive->PadUp(); + pooling_param->pad_d_ = pooling_lite_primitive->PadDown(); + pooling_param->pad_l_ = pooling_lite_primitive->PadLeft(); + pooling_param->pad_r_ = pooling_lite_primitive->PadRight(); + pooling_param->stride_w_ = pooling_primitive->GetStrideW(); + pooling_param->stride_h_ = pooling_primitive->GetStrideH(); + pooling_param->avg_mode_ = pooling_primitive->GetAvgMode(); + + auto is_global = pooling_primitive->GetGlobal(); + pooling_param->global_ = is_global; + auto pool_mode = pooling_primitive->GetPoolingMode(); + switch (pool_mode) { + case schema::PoolMode_MAX_POOLING: + pooling_param->pool_mode_ = PoolMode_MaxPool; + break; + case schema::PoolMode_MEAN_POOLING: + pooling_param->pool_mode_ = PoolMode_AvgPool; + break; + default: + pooling_param->pool_mode_ = PoolMode_No; + break; + } + + auto round_mode = pooling_primitive->GetRoundMode(); + switch (round_mode) { + case schema::RoundMode_FLOOR: + pooling_param->round_mode_ = RoundMode_Floor; + break; + case schema::RoundMode_CEIL: + pooling_param->round_mode_ = RoundMode_Ceil; + break; + default: + pooling_param->round_mode_ = RoundMode_No; + break; + } + + if (pooling_primitive->GetActivationType() == schema::ActivationType_RELU) { + pooling_param->act_type_ = ActType_Relu; + } else if (pooling_primitive->GetActivationType() == schema::ActivationType_RELU6) { + pooling_param->act_type_ = ActType_Relu6; + } else { + pooling_param->act_type_ = ActType_No; + } + return reinterpret_cast(pooling_param); +} + +Registry PoolingParameterRegistry(schema::PrimitiveType_Pooling, PopulatePoolingParameter); + int Pooling::PadUp() const { return this->pad_u_; } int Pooling::PadDown() const { return this->pad_d_; } int Pooling::PadLeft() const { return this->pad_l_; } diff --git a/mindspore/lite/src/ops/pooling_grad.cc b/mindspore/lite/src/ops/pooling_grad.cc index dc100de7d3..db9618556a 100644 --- a/mindspore/lite/src/ops/pooling_grad.cc +++ b/mindspore/lite/src/ops/pooling_grad.cc @@ -16,6 +16,8 @@ #include "src/ops/pooling_grad.h" +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -142,6 +144,11 @@ int PoolingGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuf fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *PoolingGradCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry PoolingGradRegistry(schema::PrimitiveType_PoolingGrad, PoolingGradCreator); #endif int PoolingGrad::InferShape(std::vector inputs_, std::vector outputs_) { diff --git a/mindspore/lite/src/ops/power.cc b/mindspore/lite/src/ops/power.cc index afc01bd57b..b9ea0904e3 100644 --- a/mindspore/lite/src/ops/power.cc +++ b/mindspore/lite/src/ops/power.cc @@ -16,6 +16,9 @@ #include "src/ops/power.h" +#include "src/ops/ops_register.h" +#include "nnacl/power_parameter.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -45,8 +48,28 @@ int Power::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers:: fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *PowerCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry PowerRegistry(schema::PrimitiveType_Power, PowerCreator); #endif +OpParameter *PopulatePowerParameter(const mindspore::lite::PrimitiveC *primitive) { + PowerParameter *power_param = reinterpret_cast(malloc(sizeof(PowerParameter))); + if (power_param == nullptr) { + MS_LOG(ERROR) << "malloc PowerParameter failed."; + return nullptr; + } + memset(power_param, 0, sizeof(PowerParameter)); + power_param->op_parameter_.type_ = primitive->Type(); + auto power = reinterpret_cast(const_cast(primitive)); + power_param->power_ = power->GetPower(); + power_param->scale_ = power->GetScale(); + power_param->shift_ = power->GetShift(); + return reinterpret_cast(power_param); +} + +Registry PowerParameterRegistry(schema::PrimitiveType_Power, PopulatePowerParameter); + int Power::InferShape(std::vector inputs, std::vector outputs) { MS_ASSERT(this->primitive_ != nullptr); auto x_tensor = inputs[0]; diff --git a/mindspore/lite/src/ops/power_grad.cc b/mindspore/lite/src/ops/power_grad.cc index 5529e1055a..1178de428f 100644 --- a/mindspore/lite/src/ops/power_grad.cc +++ b/mindspore/lite/src/ops/power_grad.cc @@ -16,6 +16,8 @@ #include "src/ops/power_grad.h" +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -77,6 +79,11 @@ int PowerGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffe fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *PowerGradCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry PowerGradRegistry(schema::PrimitiveType_PowerGrad, PowerGradCreator); #endif } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc index c5b862fb08..10406976b9 100644 --- a/mindspore/lite/src/ops/primitive_c.cc +++ b/mindspore/lite/src/ops/primitive_c.cc @@ -15,8 +15,10 @@ */ #include "src/ops/primitive_c.h" +#ifdef PRIMITIVE_WRITEABLE #include #include +#include "tools/converter/quantizer/quantize_util.h" #include "src/ops/space_to_batch.h" #include "src/ops/space_to_batch_nd.h" #include "src/ops/conv2d.h" @@ -133,9 +135,6 @@ #include "src/ops/custom_normalize.h" #include "src/ops/custom_extract_features.h" #include "src/ops/upsample.h" -#ifdef PRIMITIVE_WRITEABLE -#include "tools/converter/quantizer/quantize_util.h" -#endif #ifdef SUPPORT_TRAIN #include "src/ops/neg_grad.h" @@ -158,6 +157,7 @@ #include "src/ops/assign.h" #endif +#endif namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -775,280 +775,6 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { return nullptr; } #else -PrimitiveC *PrimitiveC::Create(const schema::Primitive *primitive) { - MS_ASSERT(primitive); - auto op_type = primitive->value_type(); - switch (op_type) { - case schema::PrimitiveType_SoftMax: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Activation: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Conv2D: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_DeConv2D: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Reduce: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Pooling: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_ROIPooling: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_DepthwiseConv2D: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_FusedBatchNorm: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_BatchNorm: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_FullConnection: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Power: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Pad: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Range: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Mul: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Add: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Sub: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Div: - return NewPrimitiveC
(primitive); - case schema::PrimitiveType_BiasAdd: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_ExpandDims: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_ArgMax: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_ArgMin: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Cast: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Reshape: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Scale: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Eltwise: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Ceil: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Concat: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Fill: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Nhwc2Nchw: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Nchw2Nhwc: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Transpose: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Slice: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Squeeze: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Flatten: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Mean: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Stack: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Crop: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_SquaredDifference: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_AddN: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Abs: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Sin: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Cos: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Log: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Neg: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Sqrt: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Rsqrt: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Square: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Exp: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Gather: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_GatherNd: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_LocalResponseNormalization: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Maximum: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Minimum: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_StridedSlice: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_LeakyReLU: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_PReLU: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Round: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Reverse: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_ReverseSequence: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_LogicalAnd: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_LogicalOr: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_LogicalNot: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_FloorDiv: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_FloorMod: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Equal: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_NotEqual: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Less: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_LessEqual: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Greater: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_GreaterEqual: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Floor: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Split: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_OneHot: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_PriorBox: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_SpaceToDepth: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Tile: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Resize: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Unstack: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Unique: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_TopK: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_MatMul: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_QuantDTypeCast: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_EmbeddingLookup: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Elu: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_DeDepthwiseConv2D: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Shape: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Unsqueeze: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_BatchToSpace: - case schema::PrimitiveType_BatchToSpaceND: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_SpaceToBatch: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_SpaceToBatchND: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_BroadcastTo: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_DepthToSpace: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Lstm: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_ZerosLike: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_MakeTuple: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Where: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_ScatterND: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_ConstantOfShape: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_L2Norm: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_SparseToDense: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_DetectionPostProcess: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Dropout: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_RealDiv: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_LshProjection: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_HashtableLookup: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_SkipGram: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Clip: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_CustomPredict: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_CustomNormalize: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_CustomExtractFeatures: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Upsample: - return NewPrimitiveC(primitive); - -#ifdef SUPPORT_TRAIN - case schema::PrimitiveType_ActivationGrad: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_PoolingGrad: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Conv2DGradFilter: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Conv2DGradInput: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_GroupConv2DGradInput: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_BiasGrad: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_ApplyMomentum: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_BNGrad: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_AddGrad: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_SubGrad: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_MulGrad: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_DivGrad: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_SoftmaxCrossEntropy: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_NegGrad: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_LogGrad: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Sgd: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Adam: - return NewPrimitiveC(primitive); - case schema::PrimitiveType_Assign: - return NewPrimitiveC(primitive); -#endif - default: - MS_LOG(ERROR) << "Unsupported primitive type in Create : " << schema::EnumNamePrimitiveType(op_type); - break; - } - return nullptr; -} void PrimitiveC::SetQuantType(schema::QuantType quant_type) { this->quant_type_ = quant_type; } schema::QuantType PrimitiveC::GetQuantType() const { return quant_type_; } #endif diff --git a/mindspore/lite/src/ops/primitive_c.h b/mindspore/lite/src/ops/primitive_c.h index 13532a91ef..631c67a978 100644 --- a/mindspore/lite/src/ops/primitive_c.h +++ b/mindspore/lite/src/ops/primitive_c.h @@ -27,6 +27,7 @@ #else #include "schema/model_generated.h" #endif +#include "nnacl/op_base.h" #include "src/tensor.h" #include "include/errorcode.h" #include "src/common/log_adapter.h" @@ -160,7 +161,6 @@ class PrimitiveC { void SetQuantType(schema::QuantType quant_type); schema::QuantType GetQuantType() const; - protected: template ::value>> static PrimitiveC *NewPrimitiveC(const schema::Primitive *primitive) { auto primc = new T(); @@ -177,6 +177,7 @@ class PrimitiveC { return primc; } + protected: int UnPackSchemaPrimitive(const schema::Primitive *primitive) { flatbuffers::FlatBufferBuilder fbb(1024); if (UnPackToFlatBuilder(primitive, &fbb) != RET_OK) { @@ -212,7 +213,11 @@ class PrimitiveC { bool infer_flag_ = true; schema::QuantType quant_type_{schema::QuantType_QUANT_NONE}; }; + #endif +typedef PrimitiveC *(*PrimitiveCCreator)(const schema::Primitive *primitive); +typedef OpParameter *(*ParameterCreator)(const PrimitiveC *primitive); + } // namespace lite } // namespace mindspore #endif // MINDSPORE_LITE_SRC_OPS_PRIMITIVE_C_H_ diff --git a/mindspore/lite/src/ops/prior_box.cc b/mindspore/lite/src/ops/prior_box.cc index 3f9d53abac..36de9bf067 100644 --- a/mindspore/lite/src/ops/prior_box.cc +++ b/mindspore/lite/src/ops/prior_box.cc @@ -16,6 +16,9 @@ #include "src/ops/prior_box.h" +#include "src/ops/ops_register.h" +#include "nnacl/prior_box.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -114,8 +117,70 @@ int PriorBox::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffer fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *PriorBoxCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry PriorBoxRegistry(schema::PrimitiveType_PriorBox, PriorBoxCreator); #endif +OpParameter *PopulatePriorBoxParameter(const mindspore::lite::PrimitiveC *primitive) { + PriorBoxParameter *prior_box_param = reinterpret_cast(malloc(sizeof(PriorBoxParameter))); + if (prior_box_param == nullptr) { + MS_LOG(ERROR) << "malloc PriorBoxParameter failed."; + return nullptr; + } + memset(prior_box_param, 0, sizeof(PriorBoxParameter)); + prior_box_param->op_parameter_.type_ = primitive->Type(); + auto prior_box_attr = + reinterpret_cast(const_cast(primitive)); + + if (prior_box_attr->GetMinSizes().size() > PRIOR_BOX_MAX_NUM) { + MS_LOG(ERROR) << "PriorBox min_sizes size exceeds max num " << PRIOR_BOX_MAX_NUM << ", got " + << prior_box_attr->GetMinSizes(); + free(prior_box_param); + return nullptr; + } + prior_box_param->min_sizes_size = prior_box_attr->GetMinSizes().size(); + if (prior_box_attr->GetMaxSizes().size() > PRIOR_BOX_MAX_NUM) { + MS_LOG(ERROR) << "PriorBox max_sizes size exceeds max num " << PRIOR_BOX_MAX_NUM << ", got " + << prior_box_attr->GetMaxSizes(); + free(prior_box_param); + return nullptr; + } + prior_box_param->max_sizes_size = prior_box_attr->GetMaxSizes().size(); + memcpy(prior_box_param->max_sizes, prior_box_attr->GetMaxSizes().data(), + prior_box_attr->GetMaxSizes().size() * sizeof(int32_t)); + memcpy(prior_box_param->min_sizes, prior_box_attr->GetMinSizes().data(), + prior_box_attr->GetMinSizes().size() * sizeof(int32_t)); + + if (prior_box_attr->GetAspectRatios().size() > PRIOR_BOX_MAX_NUM) { + MS_LOG(ERROR) << "PriorBox aspect_ratios size exceeds max num " << PRIOR_BOX_MAX_NUM << ", got " + << prior_box_attr->GetAspectRatios(); + free(prior_box_param); + return nullptr; + } + prior_box_param->aspect_ratios_size = prior_box_attr->GetAspectRatios().size(); + memcpy(prior_box_param->aspect_ratios, prior_box_attr->GetAspectRatios().data(), + prior_box_attr->GetAspectRatios().size() * sizeof(float)); + if (prior_box_attr->GetVariances().size() != PRIOR_BOX_VAR_NUM) { + MS_LOG(ERROR) << "PriorBox variances size should be " << PRIOR_BOX_VAR_NUM << ", got " + << prior_box_attr->GetVariances().size(); + free(prior_box_param); + return nullptr; + } + memcpy(prior_box_param->variances, prior_box_attr->GetVariances().data(), PRIOR_BOX_VAR_NUM * sizeof(float)); + prior_box_param->flip = prior_box_attr->GetFlip(); + prior_box_param->clip = prior_box_attr->GetClip(); + prior_box_param->offset = prior_box_attr->GetOffset(); + prior_box_param->image_size_h = prior_box_attr->GetImageSizeH(); + prior_box_param->image_size_w = prior_box_attr->GetImageSizeW(); + prior_box_param->step_h = prior_box_attr->GetStepH(); + prior_box_param->step_w = prior_box_attr->GetStepW(); + return reinterpret_cast(prior_box_param); +} +Registry PriorBoxParameterRegistry(schema::PrimitiveType_PriorBox, PopulatePriorBoxParameter); + namespace { constexpr int kPriorBoxPoints = 4; constexpr int kPriorBoxN = 1; diff --git a/mindspore/lite/src/ops/quant.cc b/mindspore/lite/src/ops/quant.cc index 3e9500ce1a..e5ad6aed3e 100644 --- a/mindspore/lite/src/ops/quant.cc +++ b/mindspore/lite/src/ops/quant.cc @@ -18,6 +18,8 @@ #include #include +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE diff --git a/mindspore/lite/src/ops/quant_dtype_cast.cc b/mindspore/lite/src/ops/quant_dtype_cast.cc index 6f2270ac84..b9dfc7e06d 100644 --- a/mindspore/lite/src/ops/quant_dtype_cast.cc +++ b/mindspore/lite/src/ops/quant_dtype_cast.cc @@ -16,6 +16,9 @@ #include "src/ops/quant_dtype_cast.h" +#include "src/ops/ops_register.h" +#include "nnacl/int8/quant_dtype_cast.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -42,8 +45,30 @@ int QuantDTypeCast::UnPackToFlatBuilder(const schema::Primitive *primitive, flat fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *QuantDTypeCastCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry QuantDTypeCastRegistry(schema::PrimitiveType_QuantDTypeCast, QuantDTypeCastCreator); #endif +OpParameter *PopulateQuantDTypeCastParameter(const mindspore::lite::PrimitiveC *primitive) { + QuantDTypeCastParameter *parameter = + reinterpret_cast(malloc(sizeof(QuantDTypeCastParameter))); + if (parameter == nullptr) { + MS_LOG(ERROR) << "malloc QuantDTypeCastParameter failed."; + return nullptr; + } + memset(parameter, 0, sizeof(QuantDTypeCastParameter)); + parameter->op_parameter_.type_ = primitive->Type(); + auto quant_dtype_cast_param = + reinterpret_cast(const_cast(primitive)); + parameter->srcT = quant_dtype_cast_param->GetSrcT(); + parameter->dstT = quant_dtype_cast_param->GetDstT(); + return reinterpret_cast(parameter); +} +Registry QuantDTypeCastParameterRegistry(schema::PrimitiveType_QuantDTypeCast, PopulateQuantDTypeCastParameter); + int QuantDTypeCast::InferShape(std::vector inputs_, std::vector outputs_) { MS_ASSERT(this->primitive_ != nullptr); auto input = inputs_.front(); diff --git a/mindspore/lite/src/ops/range.cc b/mindspore/lite/src/ops/range.cc index 65b594bf42..e5c64c7ad6 100644 --- a/mindspore/lite/src/ops/range.cc +++ b/mindspore/lite/src/ops/range.cc @@ -16,6 +16,9 @@ #include "src/ops/range.h" +#include "src/ops/ops_register.h" +#include "nnacl/fp32/range.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -48,8 +51,28 @@ int Range::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers:: fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *RangeCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry RangeRegistry(schema::PrimitiveType_Range, RangeCreator); #endif +OpParameter *PopulateRangeParameter(const mindspore::lite::PrimitiveC *primitive) { + auto range_attr = reinterpret_cast(const_cast(primitive)); + RangeParameter *range_param = reinterpret_cast(malloc(sizeof(RangeParameter))); + if (range_param == nullptr) { + MS_LOG(ERROR) << "malloc RangeParameter failed."; + return nullptr; + } + memset(range_param, 0, sizeof(RangeParameter)); + range_param->op_parameter_.type_ = primitive->Type(); + range_param->start_ = range_attr->GetStart(); + range_param->limit_ = range_attr->GetLimit(); + range_param->delta_ = range_attr->GetDelta(); + range_param->dType_ = range_attr->GetDType(); + return reinterpret_cast(range_param); +} +Registry RangeParameterRegistry(schema::PrimitiveType_Range, PopulateRangeParameter); + int Range::InferShape(std::vector inputs_, std::vector outputs_) { MS_ASSERT(this->primitive_ != nullptr); auto input = inputs_.front(); diff --git a/mindspore/lite/src/ops/rank.cc b/mindspore/lite/src/ops/rank.cc index a95efb2006..b4a1fed042 100644 --- a/mindspore/lite/src/ops/rank.cc +++ b/mindspore/lite/src/ops/rank.cc @@ -16,6 +16,8 @@ #include "src/ops/rank.h" +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -28,6 +30,9 @@ int Rank::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::F fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *RankCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry RankRegistry(schema::PrimitiveType_Rank, RankCreator); #endif int Rank::InferShape(std::vector inputs_, std::vector outputs_) { MS_ASSERT(this->primitive_ != nullptr); diff --git a/mindspore/lite/src/ops/real_div.cc b/mindspore/lite/src/ops/real_div.cc index 0de4bc3427..fd3f0efbda 100644 --- a/mindspore/lite/src/ops/real_div.cc +++ b/mindspore/lite/src/ops/real_div.cc @@ -15,6 +15,8 @@ */ #include "src/ops/real_div.h" +#include "src/ops/ops_register.h" +#include "nnacl/arithmetic_common.h" namespace mindspore { namespace lite { @@ -43,6 +45,8 @@ int RealDiv::UnPackAttr(const Primitive &prim, const std::vector &in return RET_OK; } +Registry RealDivParameterRegistry(schema::PrimitiveType_RealDiv, PopulateArithmetic); + #else int RealDiv::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { MS_ASSERT(nullptr != primitive); @@ -52,6 +56,8 @@ int RealDiv::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers fbb->Finish(prim_offset); return RET_OK; } +PrimitiveC *RealDivCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry RealDivRegistry(schema::PrimitiveType_RealDiv, RealDivCreator); #endif } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/reduce.cc b/mindspore/lite/src/ops/reduce.cc index 2fd893f405..4e10b0d33d 100644 --- a/mindspore/lite/src/ops/reduce.cc +++ b/mindspore/lite/src/ops/reduce.cc @@ -17,6 +17,9 @@ #include "src/ops/reduce.h" #include +#include "src/ops/ops_register.h" +#include "nnacl/reduce_parameter.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -124,8 +127,40 @@ int Reduce::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers: fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *ReduceCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry ReduceRegistry(schema::PrimitiveType_Reduce, ReduceCreator); #endif +OpParameter *PopulateReduceParameter(const mindspore::lite::PrimitiveC *primitive) { + ReduceParameter *reduce_param = reinterpret_cast(malloc(sizeof(ReduceParameter))); + if (reduce_param == nullptr) { + MS_LOG(ERROR) << "malloc ReduceParameter failed."; + return nullptr; + } + memset(reduce_param, 0, sizeof(ReduceParameter)); + reduce_param->op_parameter_.type_ = primitive->Type(); + auto reduce = reinterpret_cast(const_cast(primitive)); + reduce_param->keep_dims_ = reduce->GetKeepDims(); + reduce_param->reduce_to_end_ = reduce->GetReduceToEnd(); + reduce_param->coeff = reduce->GetCoeff(); + auto axisVector = reduce->GetAxes(); + if (axisVector.size() > REDUCE_MAX_AXES_NUM) { + MS_LOG(ERROR) << "Reduce axes size " << axisVector.size() << " exceed limit " << REDUCE_MAX_AXES_NUM; + free(reduce_param); + return nullptr; + } + reduce_param->num_axes_ = static_cast(axisVector.size()); + int i = 0; + for (auto iter = axisVector.begin(); iter != axisVector.end(); iter++) { + reduce_param->axes_[i++] = *iter; + } + reduce_param->mode_ = static_cast(reduce->GetMode()); + return reinterpret_cast(reduce_param); +} + +Registry ReduceParameterRegistry(schema::PrimitiveType_Reduce, PopulateReduceParameter); + namespace { constexpr size_t kInputSize = 1; constexpr size_t kOutputSize = 1; diff --git a/mindspore/lite/src/ops/reshape.cc b/mindspore/lite/src/ops/reshape.cc index 78c4254f10..2b4b0ca863 100644 --- a/mindspore/lite/src/ops/reshape.cc +++ b/mindspore/lite/src/ops/reshape.cc @@ -20,6 +20,8 @@ #include "include/errorcode.h" #include "src/common/log_adapter.h" #include "src/tensor.h" +#include "src/ops/ops_register.h" +#include "nnacl/reshape_parameter.h" namespace mindspore { namespace lite { @@ -100,8 +102,24 @@ int Reshape::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *ReshapeCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry ReshapeRegistry(schema::PrimitiveType_Reshape, ReshapeCreator); #endif +OpParameter *PopulateReshapeParameter(const mindspore::lite::PrimitiveC *primitive) { + ReshapeParameter *reshape_param = reinterpret_cast(malloc(sizeof(ReshapeParameter))); + if (reshape_param == nullptr) { + MS_LOG(ERROR) << "malloc ReshapeParameter failed."; + return nullptr; + } + memset(reshape_param, 0, sizeof(ReshapeParameter)); + reshape_param->op_parameter_.type_ = primitive->Type(); + return reinterpret_cast(reshape_param); +} + +Registry ReshapeParameterRegistry(schema::PrimitiveType_Reshape, PopulateReshapeParameter); + int Reshape::CalNewShape(const Tensor *in_tensor, std::vector *out_shape) const { size_t in_shape_size = 1; for (size_t i = 0; i < in_tensor->shape().size(); i++) { diff --git a/mindspore/lite/src/ops/resize.cc b/mindspore/lite/src/ops/resize.cc index b509190923..9aa740a846 100644 --- a/mindspore/lite/src/ops/resize.cc +++ b/mindspore/lite/src/ops/resize.cc @@ -16,6 +16,9 @@ #include "src/ops/resize.h" +#include "src/ops/ops_register.h" +#include "nnacl/resize_parameter.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -93,7 +96,30 @@ int Resize::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers: fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *ResizeCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry ResizeRegistry(schema::PrimitiveType_Resize, ResizeCreator); #endif + +OpParameter *PopulateResizeParameter(const mindspore::lite::PrimitiveC *primitive) { + ResizeParameter *resize_param = reinterpret_cast(malloc(sizeof(ResizeParameter))); + if (resize_param == nullptr) { + MS_LOG(ERROR) << "malloc ResizeParameter failed."; + return nullptr; + } + memset(resize_param, 0, sizeof(ResizeParameter)); + resize_param->op_parameter_.type_ = primitive->Type(); + auto param = reinterpret_cast(const_cast(primitive)); + resize_param->method_ = static_cast(param->GetMethod()); + resize_param->new_height_ = param->GetNewHeight(); + resize_param->new_width_ = param->GetNewWidth(); + resize_param->align_corners_ = param->GetAlignCorners(); + resize_param->preserve_aspect_ratio_ = param->GetPreserveAspectRatio(); + return reinterpret_cast(resize_param); +} + +Registry ResizeParameterRegistry(schema::PrimitiveType_Resize, PopulateResizeParameter); + namespace { constexpr int kInputRank = 4; } // namespace diff --git a/mindspore/lite/src/ops/return.cc b/mindspore/lite/src/ops/return.cc index 43dd4349e7..4a13f41716 100644 --- a/mindspore/lite/src/ops/return.cc +++ b/mindspore/lite/src/ops/return.cc @@ -17,6 +17,8 @@ #include "src/ops/return.h" #include +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -47,6 +49,9 @@ int Return::UnPackAttr(const Primitive &prim, const std::vector &inp } return RET_OK; } +#else +PrimitiveC *ReturnCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry ReturnRegistry(schema::PrimitiveType_Return, ReturnCreator); #endif namespace { diff --git a/mindspore/lite/src/ops/reverse.cc b/mindspore/lite/src/ops/reverse.cc index 11bb4388d5..8bb9180a0c 100644 --- a/mindspore/lite/src/ops/reverse.cc +++ b/mindspore/lite/src/ops/reverse.cc @@ -16,6 +16,9 @@ #include "src/ops/reverse.h" +#include "src/ops/ops_register.h" +#include "nnacl/fp32/reverse.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -48,6 +51,31 @@ int Reverse::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *ReverseCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry ReverseRegistry(schema::PrimitiveType_Reverse, ReverseCreator); + #endif +OpParameter *PopulateReverseParameter(const mindspore::lite::PrimitiveC *primitive) { + auto reverse_attr = + reinterpret_cast(const_cast(primitive)); + ReverseParameter *reverse_param = reinterpret_cast(malloc(sizeof(ReverseParameter))); + if (reverse_param == nullptr) { + MS_LOG(ERROR) << "malloc ReverseParameter failed."; + return nullptr; + } + memset(reverse_param, 0, sizeof(ReverseParameter)); + reverse_param->op_parameter_.type_ = primitive->Type(); + auto flatAxis = reverse_attr->GetAxis(); + reverse_param->num_axis_ = flatAxis.size(); + int i = 0; + for (auto iter = flatAxis.begin(); iter != flatAxis.end(); iter++) { + reverse_param->axis_[i++] = *iter; + } + return reinterpret_cast(reverse_param); +} + +Registry ReverseParameterRegistry(schema::PrimitiveType_Reverse, PopulateReverseParameter); + } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/reverse_sequence.cc b/mindspore/lite/src/ops/reverse_sequence.cc index 6fceb8b1cb..8ffacf792f 100644 --- a/mindspore/lite/src/ops/reverse_sequence.cc +++ b/mindspore/lite/src/ops/reverse_sequence.cc @@ -16,6 +16,9 @@ #include "src/ops/reverse_sequence.h" +#include "src/ops/ops_register.h" +#include "nnacl/reverse_sequence.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -45,8 +48,31 @@ int ReverseSequence::UnPackToFlatBuilder(const schema::Primitive *primitive, fla fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *ReverseSequenceCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry ReverseSequenceRegistry(schema::PrimitiveType_ReverseSequence, ReverseSequenceCreator); + #endif +OpParameter *PopulateReverseSequenceParameter(const mindspore::lite::PrimitiveC *primitive) { + ReverseSequenceParameter *reverse_sequence_param = + reinterpret_cast(malloc(sizeof(ReverseSequenceParameter))); + if (reverse_sequence_param == nullptr) { + MS_LOG(ERROR) << "malloc ReverseSequenceParameter failed."; + return nullptr; + } + memset(reverse_sequence_param, 0, sizeof(ReverseSequenceParameter)); + auto param = + reinterpret_cast(const_cast(primitive)); + reverse_sequence_param->op_parameter_.type_ = primitive->Type(); + reverse_sequence_param->seq_axis_ = param->GetSeqAxis(); + reverse_sequence_param->batch_axis_ = param->GetBatchAxis(); + return reinterpret_cast(reverse_sequence_param); +} +Registry ReverseSequenceParameterRegistry(schema::PrimitiveType_ReverseSequence, PopulateReverseSequenceParameter); + int ReverseSequence::InferShape(std::vector inputs, std::vector outputs) { auto input = inputs.front(); auto output = outputs.front(); diff --git a/mindspore/lite/src/ops/roi_pooling.cc b/mindspore/lite/src/ops/roi_pooling.cc index 90608afd85..e1e3cd69fb 100644 --- a/mindspore/lite/src/ops/roi_pooling.cc +++ b/mindspore/lite/src/ops/roi_pooling.cc @@ -16,6 +16,9 @@ #include "src/ops/roi_pooling.h" +#include "src/ops/ops_register.h" +#include "nnacl/fp32/roi_pooling.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -47,8 +50,31 @@ int ROIPooling::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuff fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *ROIPoolingCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry ROIPoolingRegistry(schema::PrimitiveType_ROIPooling, ROIPoolingCreator); #endif +OpParameter *PopulateROIPoolingParameter(const mindspore::lite::PrimitiveC *primitive) { + const auto param = + reinterpret_cast(const_cast(primitive)); + ROIPoolingParameter *roi_pooling_param = reinterpret_cast(malloc(sizeof(ROIPoolingParameter))); + if (roi_pooling_param == nullptr) { + MS_LOG(ERROR) << "malloc ROIPoolingParameter failed."; + return nullptr; + } + memset(roi_pooling_param, 0, sizeof(ROIPoolingParameter)); + roi_pooling_param->op_parameter_.type_ = primitive->Type(); + roi_pooling_param->pooledH_ = param->GetPooledW(); + roi_pooling_param->pooledW_ = param->GetPooledW(); + roi_pooling_param->scale_ = param->GetScale(); + return reinterpret_cast(roi_pooling_param); +} + +Registry ROIPoolingParameterRegistry(schema::PrimitiveType_ROIPooling, PopulateROIPoolingParameter); + int ROIPooling::InferShape(std::vector inputs_, std::vector outputs_) { MS_ASSERT(this->primitive_ != nullptr); if (inputs_.size() != kDoubleNum) { diff --git a/mindspore/lite/src/ops/round.cc b/mindspore/lite/src/ops/round.cc index ae3167597c..33e902cf4e 100644 --- a/mindspore/lite/src/ops/round.cc +++ b/mindspore/lite/src/ops/round.cc @@ -16,6 +16,8 @@ #include "src/ops/round.h" +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -28,6 +30,12 @@ int Round::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers:: fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *RoundCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry RoundRegistry(schema::PrimitiveType_Round, RoundCreator); + #endif +Registry RoundParameterRegistry(schema::PrimitiveType_Round, PopulateArithmeticSelf); + } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/rsqrt.cc b/mindspore/lite/src/ops/rsqrt.cc index 742aed2953..c3de4f79e9 100644 --- a/mindspore/lite/src/ops/rsqrt.cc +++ b/mindspore/lite/src/ops/rsqrt.cc @@ -16,6 +16,8 @@ #include "src/ops/rsqrt.h" +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -29,6 +31,10 @@ int Rsqrt::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers:: fbb->Finish(prim_offset); return RET_OK; } +PrimitiveC *RsqrtCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry RsqrtRegistry(schema::PrimitiveType_Rsqrt, RsqrtCreator); #endif +Registry RsqrtParameterRegistry(schema::PrimitiveType_Rsqrt, PopulateArithmeticSelf); + } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/scale.cc b/mindspore/lite/src/ops/scale.cc index 41256f7a77..5a852c209c 100644 --- a/mindspore/lite/src/ops/scale.cc +++ b/mindspore/lite/src/ops/scale.cc @@ -16,6 +16,9 @@ #include "src/ops/scale.h" +#include "src/ops/ops_register.h" +#include "nnacl/scale.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -43,6 +46,29 @@ int Scale::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers:: fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *ScaleCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry ScaleRegistry(schema::PrimitiveType_Scale, ScaleCreator); #endif + +OpParameter *PopulateScaleParameter(const mindspore::lite::PrimitiveC *primitive) { + if (primitive == nullptr) { + MS_LOG(ERROR) << "input primitive is nullptr"; + return nullptr; + } + ScaleParameter *scale_param = reinterpret_cast(malloc(sizeof(ScaleParameter))); + if (scale_param == nullptr) { + MS_LOG(ERROR) << "malloc ScaleParameter failed."; + return nullptr; + } + memset(scale_param, 0, sizeof(ScaleParameter)); + scale_param->op_parameter_.type_ = primitive->Type(); + auto param = reinterpret_cast(const_cast(primitive)); + scale_param->axis_ = param->GetAxis(); + scale_param->activation_type_ = param->GetActivationType(); + return reinterpret_cast(scale_param); +} +Registry ScaleParameterRegistry(schema::PrimitiveType_Scale, PopulateScaleParameter); + } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/scatter_nd.cc b/mindspore/lite/src/ops/scatter_nd.cc index 121ae4608f..2c3aa94675 100644 --- a/mindspore/lite/src/ops/scatter_nd.cc +++ b/mindspore/lite/src/ops/scatter_nd.cc @@ -16,6 +16,9 @@ #include "src/ops/scatter_nd.h" +#include "src/ops/ops_register.h" +#include "nnacl/scatter_nd.h" + namespace mindspore { namespace lite { @@ -73,5 +76,18 @@ int ScatterND::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffe return RET_OK; } #endif + +OpParameter *PopulateScatterNDParameter(const mindspore::lite::PrimitiveC *primitive) { + ScatterNDParameter *scatter_nd_param = reinterpret_cast(malloc(sizeof(ScatterNDParameter))); + if (scatter_nd_param == nullptr) { + MS_LOG(ERROR) << "malloc ScatterNDParameter failed."; + return nullptr; + } + memset(scatter_nd_param, 0, sizeof(ScatterNDParameter)); + scatter_nd_param->op_parameter_.type_ = primitive->Type(); + return reinterpret_cast(scatter_nd_param); +} +Registry ScatterNDParameterRegistry(schema::PrimitiveType_ScatterND, PopulateScatterNDParameter); + } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/sgd.cc b/mindspore/lite/src/ops/sgd.cc index 038d1e0d68..b3a96e4df2 100644 --- a/mindspore/lite/src/ops/sgd.cc +++ b/mindspore/lite/src/ops/sgd.cc @@ -14,6 +14,8 @@ * limitations under the License. */ #include "src/ops/sgd.h" +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -70,6 +72,10 @@ int Sgd::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::Fl fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *SgdCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry SgdRegistry(schema::PrimitiveType_Sgd, SgdCreator); + #endif int Sgd::InferShape(std::vector inputs, std::vector outputs) { diff --git a/mindspore/lite/src/ops/shape.cc b/mindspore/lite/src/ops/shape.cc index c0b8d69423..a32f4ef7fe 100644 --- a/mindspore/lite/src/ops/shape.cc +++ b/mindspore/lite/src/ops/shape.cc @@ -19,6 +19,9 @@ #include "src/common/log_adapter.h" #include "src/tensor.h" +#include "src/ops/ops_register.h" +#include "nnacl/shape.h" + namespace mindspore { namespace lite { @@ -58,6 +61,22 @@ int Shape::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers:: fbb->Finish(prim_offset); return RET_OK; } +PrimitiveC *ShapeCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry ShapeRegistry(schema::PrimitiveType_Shape, ShapeCreator); #endif + +OpParameter *PopulateShapeParameter(const mindspore::lite::PrimitiveC *primitive) { + ShapeParameter *shape_param = reinterpret_cast(malloc(sizeof(ShapeParameter))); + if (shape_param == nullptr) { + MS_LOG(ERROR) << "malloc ShapeParameter failed."; + return nullptr; + } + memset(shape_param, 0, sizeof(ShapeParameter)); + shape_param->op_parameter_.type_ = primitive->Type(); + return reinterpret_cast(shape_param); +} + +Registry ShapeParameterRegistry(schema::PrimitiveType_Shape, PopulateShapeParameter); + } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/sin.cc b/mindspore/lite/src/ops/sin.cc index eb363cf531..e7f4baee7c 100644 --- a/mindspore/lite/src/ops/sin.cc +++ b/mindspore/lite/src/ops/sin.cc @@ -16,6 +16,8 @@ #include "src/ops/sin.h" +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -29,6 +31,12 @@ int Sin::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::Fl fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *SinCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry SinRegistry(schema::PrimitiveType_Sin, SinCreator); + #endif +Registry SinParameterRegistry(schema::PrimitiveType_Sin, PopulateArithmeticSelf); + } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/skip_gram.cc b/mindspore/lite/src/ops/skip_gram.cc index 90609ab31d..490619ad4b 100644 --- a/mindspore/lite/src/ops/skip_gram.cc +++ b/mindspore/lite/src/ops/skip_gram.cc @@ -16,6 +16,9 @@ #include "src/ops/skip_gram.h" +#include "src/ops/ops_register.h" +#include "nnacl/fp32/skip_gram.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -50,7 +53,26 @@ int SkipGram::GetNgramSize() const { return this->primitive_->value_as_SkipGram( int SkipGram::GetMaxSkipSize() const { return this->primitive_->value_as_SkipGram()->maxSkipSize(); } bool SkipGram::GetIncludeAllNgrams() const { return this->primitive_->value_as_SkipGram()->includeAllGrams(); } +PrimitiveC *SkipGramCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry SkipGramRegistry(schema::PrimitiveType_SkipGram, SkipGramCreator); #endif +OpParameter *PopulateSkipGramParameter(const mindspore::lite::PrimitiveC *primitive) { + SkipGramParameter *skipGramParameter = reinterpret_cast(malloc(sizeof(SkipGramParameter))); + if (skipGramParameter == nullptr) { + MS_LOG(ERROR) << "malloc SkipGramParameter failed."; + return nullptr; + } + memset(skipGramParameter, 0, sizeof(SkipGramParameter)); + skipGramParameter->op_parameter_.type_ = primitive->Type(); + auto param = reinterpret_cast(const_cast(primitive)); + skipGramParameter->ngram_size = param->GetNgramSize(); + skipGramParameter->max_skip_size = param->GetMaxSkipSize(); + skipGramParameter->include_all_ngrams = param->GetIncludeAllNgrams(); + return reinterpret_cast(skipGramParameter); +} +Registry SkipGramParameterRegistry(schema::PrimitiveType_SkipGram, PopulateSkipGramParameter); int SkipGram::InferShape(std::vector inputs_, std::vector outputs_) { MS_ASSERT(this->primitive_ != nullptr); diff --git a/mindspore/lite/src/ops/slice.cc b/mindspore/lite/src/ops/slice.cc index 6441504a2b..ea5f279f1c 100644 --- a/mindspore/lite/src/ops/slice.cc +++ b/mindspore/lite/src/ops/slice.cc @@ -19,6 +19,9 @@ #include "src/common/log_adapter.h" #include "src/tensor.h" +#include "src/ops/ops_register.h" +#include "nnacl/slice_parameter.h" + namespace mindspore { namespace lite { namespace { @@ -146,8 +149,36 @@ int Slice::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers:: fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *SliceCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry SliceRegistry(schema::PrimitiveType_Slice, SliceCreator); + #endif +OpParameter *PopulateSliceParameter(const mindspore::lite::PrimitiveC *primitive) { + SliceParameter *slice_param = reinterpret_cast(malloc(sizeof(SliceParameter))); + if (slice_param == nullptr) { + MS_LOG(ERROR) << "malloc SliceParameter failed."; + return nullptr; + } + memset(slice_param, 0, sizeof(SliceParameter)); + auto param = reinterpret_cast(const_cast(primitive)); + slice_param->op_parameter_.type_ = primitive->Type(); + auto param_begin = param->GetPostProcessBegin(); + auto param_size = param->GetPostProcessSize(); + if (param_begin.size() != param_size.size()) { + free(slice_param); + return nullptr; + } + slice_param->param_length_ = static_cast(param_begin.size()); + for (int32_t i = 0; i < slice_param->param_length_; ++i) { + slice_param->begin_[i] = param_begin[i]; + slice_param->size_[i] = param_size[i]; + } + return reinterpret_cast(slice_param); +} +Registry SliceParameterRegistry(schema::PrimitiveType_Slice, PopulateSliceParameter); + std::vector Slice::GetPostProcessBegin() const { return this->begin; } std::vector Slice::GetPostProcessSize() const { return this->size; } int Slice::InferShape(std::vector inputs, std::vector outputs) { diff --git a/mindspore/lite/src/ops/softmax.cc b/mindspore/lite/src/ops/softmax.cc index 08b6c24233..187395b8d1 100644 --- a/mindspore/lite/src/ops/softmax.cc +++ b/mindspore/lite/src/ops/softmax.cc @@ -15,6 +15,8 @@ */ #include "src/ops/softmax.h" +#include "src/ops/ops_register.h" +#include "nnacl/softmax_parameter.h" namespace mindspore { namespace lite { @@ -69,8 +71,27 @@ int SoftMax::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *SoftMaxCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry SoftMaxRegistry(schema::PrimitiveType_SoftMax, SoftMaxCreator); #endif +OpParameter *PopulateSoftmaxParameter(const mindspore::lite::PrimitiveC *primitive) { + auto softmax_primitive = + reinterpret_cast(const_cast(primitive)); + SoftmaxParameter *softmax_param = reinterpret_cast(malloc(sizeof(SoftmaxParameter))); + if (softmax_param == nullptr) { + MS_LOG(ERROR) << "malloc SoftmaxParameter failed."; + return nullptr; + } + memset(softmax_param, 0, sizeof(SoftmaxParameter)); + softmax_param->op_parameter_.type_ = primitive->Type(); + softmax_param->axis_ = softmax_primitive->GetAxis(); + return reinterpret_cast(softmax_param); +} + +Registry SoftMaxParameterRegistry(schema::PrimitiveType_SoftMax, PopulateSoftmaxParameter); + int SoftMax::InferShape(std::vector inputs_, std::vector outputs_) { MS_ASSERT(this->primitive_ != nullptr); auto input = inputs_.front(); diff --git a/mindspore/lite/src/ops/softmax_cross_entropy.cc b/mindspore/lite/src/ops/softmax_cross_entropy.cc index 5f8099f3a4..c7e78b0930 100644 --- a/mindspore/lite/src/ops/softmax_cross_entropy.cc +++ b/mindspore/lite/src/ops/softmax_cross_entropy.cc @@ -16,6 +16,8 @@ #include "src/ops/softmax_cross_entropy.h" +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -78,6 +80,11 @@ int SoftmaxCrossEntropy::UnPackToFlatBuilder(const schema::Primitive *primitive, fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *SoftmaxCrossEntropyCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry SoftmaxCrossEntropyRegistry(schema::PrimitiveType_SoftmaxCrossEntropy, SoftmaxCrossEntropyCreator); #endif int SoftmaxCrossEntropy::InferShape(std::vector inputs, std::vector outputs) { diff --git a/mindspore/lite/src/ops/space_to_batch.cc b/mindspore/lite/src/ops/space_to_batch.cc index bc0ea6e6de..0629be39f2 100644 --- a/mindspore/lite/src/ops/space_to_batch.cc +++ b/mindspore/lite/src/ops/space_to_batch.cc @@ -17,6 +17,9 @@ #include "src/ops/space_to_batch.h" #include "src/common/common.h" +#include "src/ops/ops_register.h" +#include "nnacl/fp32/space_to_batch.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -65,7 +68,30 @@ int SpaceToBatch::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbu fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *SpaceToBatchCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry SpaceToBatchRegistry(schema::PrimitiveType_SpaceToBatch, SpaceToBatchCreator); + #endif +OpParameter *PopulateSpaceToBatchParameter(const mindspore::lite::PrimitiveC *primitive) { + SpaceToBatchParameter *space_batch_param = + reinterpret_cast(malloc(sizeof(SpaceToBatchParameter))); + if (space_batch_param == nullptr) { + MS_LOG(ERROR) << "malloc SpaceToBatchParameter failed."; + return nullptr; + } + memset(space_batch_param, 0, sizeof(SpaceToBatchParameter)); + space_batch_param->op_parameter_.type_ = primitive->Type(); + auto block_sizes = ((mindspore::lite::SpaceToBatch *)primitive)->BlockSizes(); + memcpy(space_batch_param->block_sizes_, (block_sizes.data()), block_sizes.size() * sizeof(int)); + auto paddings = ((mindspore::lite::SpaceToBatch *)primitive)->Paddings(); + memcpy(space_batch_param->paddings_, (paddings.data()), paddings.size() * sizeof(int)); + return reinterpret_cast(space_batch_param); +} +Registry SpaceToBatchParameterRegistry(schema::PrimitiveType_SpaceToBatch, PopulateSpaceToBatchParameter); + namespace { constexpr int kSpaceToBatchNDOutputNum = 1; constexpr int kSpaceToBatchNDInputNum = 1; diff --git a/mindspore/lite/src/ops/space_to_batch_nd.cc b/mindspore/lite/src/ops/space_to_batch_nd.cc index 91a5a4366a..59c89b6726 100644 --- a/mindspore/lite/src/ops/space_to_batch_nd.cc +++ b/mindspore/lite/src/ops/space_to_batch_nd.cc @@ -17,6 +17,9 @@ #include "src/ops/space_to_batch_nd.h" #include "src/common/common.h" +#include "src/ops/ops_register.h" +#include "nnacl/fp32/space_to_batch.h" + namespace mindspore { namespace lite { namespace { @@ -76,8 +79,29 @@ int SpaceToBatchND::UnPackToFlatBuilder(const schema::Primitive *primitive, flat return RET_OK; } +PrimitiveC *SpaceToBatchNDCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry SpaceToBatchNDRegistry(schema::PrimitiveType_SpaceToBatchND, SpaceToBatchNDCreator); + #endif // PRIMITIVE_WRITEABLE +OpParameter *PopulateSpaceToBatchNDParameter(const mindspore::lite::PrimitiveC *primitive) { + auto *space_batch_param_nd = new (std::nothrow) SpaceToBatchParameter(); + if (space_batch_param_nd == nullptr) { + MS_LOG(ERROR) << "new SpaceToBatchParameter failed."; + return nullptr; + } + + space_batch_param_nd->op_parameter_.type_ = primitive->Type(); + auto block_sizes = ((mindspore::lite::SpaceToBatchND *)primitive)->GetBlockShape(); + memcpy(space_batch_param_nd->block_sizes_, (block_sizes.data()), block_sizes.size() * sizeof(int)); + auto paddings = ((mindspore::lite::SpaceToBatchND *)primitive)->GetPaddings(); + memcpy(space_batch_param_nd->paddings_, (paddings.data()), paddings.size() * sizeof(int)); + return reinterpret_cast(space_batch_param_nd); +} +Registry SpaceToBatchNDParameterRegistry(schema::PrimitiveType_SpaceToBatchND, PopulateSpaceToBatchNDParameter); + int SpaceToBatchND::InferShape(std::vector inputs, std::vector outputs) { if (outputs.size() != kSpaceToBatchNDOutputNum || inputs.size() != kSpaceToBatchNDInputNum) { MS_LOG(ERROR) << "Invalid output/input size! output size: " << outputs.size() << ",input size: " << inputs.size(); diff --git a/mindspore/lite/src/ops/space_to_depth.cc b/mindspore/lite/src/ops/space_to_depth.cc index f095b6d45e..2e2b53ed1a 100644 --- a/mindspore/lite/src/ops/space_to_depth.cc +++ b/mindspore/lite/src/ops/space_to_depth.cc @@ -17,6 +17,9 @@ #include "src/ops/space_to_depth.h" #include "src/common/common.h" +#include "src/ops/ops_register.h" +#include "nnacl/fp32/space_to_depth.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -43,7 +46,33 @@ int SpaceToDepth::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbu fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *SpaceToDepthCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry SpaceToDepthRegistry(schema::PrimitiveType_SpaceToDepth, SpaceToDepthCreator); #endif +OpParameter *PopulateSpaceToDepthParameter(const mindspore::lite::PrimitiveC *primitive) { + SpaceToDepthParameter *space_depth_param = + reinterpret_cast(malloc(sizeof(SpaceToDepthParameter))); + if (space_depth_param == nullptr) { + MS_LOG(ERROR) << "malloc SpaceToDepthParameter failed."; + return nullptr; + } + memset(space_depth_param, 0, sizeof(SpaceToDepthParameter)); + space_depth_param->op_parameter_.type_ = primitive->Type(); + auto param = reinterpret_cast(const_cast(primitive)); + space_depth_param->op_parameter_.type_ = primitive->Type(); + space_depth_param->block_size_ = param->GetBlockSize(); + if (param->GetFormat() != schema::Format::Format_NHWC) { + MS_LOG(ERROR) << "Currently only NHWC format is supported."; + free(space_depth_param); + return nullptr; + } + return reinterpret_cast(space_depth_param); +} +Registry SpaceToDepthParameterRegistry(schema::PrimitiveType_SpaceToDepth, PopulateSpaceToDepthParameter); + namespace { constexpr int kSpaceToDepthOutputNum = 1; constexpr int kSpaceToDepthInputNum = 1; diff --git a/mindspore/lite/src/ops/sparse_to_dense.cc b/mindspore/lite/src/ops/sparse_to_dense.cc index 04a745843d..35be40d322 100644 --- a/mindspore/lite/src/ops/sparse_to_dense.cc +++ b/mindspore/lite/src/ops/sparse_to_dense.cc @@ -16,6 +16,9 @@ #include "src/ops/sparse_to_dense.h" +#include "src/ops/ops_register.h" +#include "nnacl/sparse_to_dense_parameter.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -41,8 +44,29 @@ int SparseToDense::UnPackToFlatBuilder(const schema::Primitive *primitive, flatb fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *SparseToDenseCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry SparseToDenseRegistry(schema::PrimitiveType_SparseToDense, SparseToDenseCreator); #endif +OpParameter *PopulateSparseToDenseParameter(const mindspore::lite::PrimitiveC *primitive) { + SparseToDenseParameter *sparse_to_dense_param = + reinterpret_cast(malloc(sizeof(SparseToDenseParameter))); + if (sparse_to_dense_param == nullptr) { + MS_LOG(ERROR) << "malloc SparseToDenseParameter failed."; + return nullptr; + } + memset(sparse_to_dense_param, 0, sizeof(SparseToDenseParameter)); + sparse_to_dense_param->op_parameter_.type_ = primitive->Type(); + auto param = reinterpret_cast(const_cast(primitive)); + sparse_to_dense_param->validate_indices_ = param->GetValidateIndices(); + return reinterpret_cast(sparse_to_dense_param); +} + +Registry SparseToDenseParameterRegistry(schema::PrimitiveType_SparseToDense, PopulateSparseToDenseParameter); + int SparseToDense::InferShape(std::vector inputs_, std::vector outputs_) { MS_ASSERT(this->primitive_ != nullptr); MS_ASSERT(output_shape != nullptr); diff --git a/mindspore/lite/src/ops/split.cc b/mindspore/lite/src/ops/split.cc index d10cb05676..f57634b94d 100644 --- a/mindspore/lite/src/ops/split.cc +++ b/mindspore/lite/src/ops/split.cc @@ -16,6 +16,9 @@ #include "src/ops/split.h" +#include "src/ops/ops_register.h" +#include "nnacl/split_parameter.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -88,8 +91,32 @@ int Split::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers:: fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *SplitCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry SplitRegistry(schema::PrimitiveType_Split, SplitCreator); #endif +OpParameter *PopulateSplitParameter(const mindspore::lite::PrimitiveC *primitive) { + SplitParameter *split_param = reinterpret_cast(malloc(sizeof(SplitParameter))); + if (split_param == nullptr) { + MS_LOG(ERROR) << "malloc SplitParameter failed."; + return nullptr; + } + memset(split_param, 0, sizeof(SplitParameter)); + auto param = reinterpret_cast(const_cast(primitive)); + split_param->op_parameter_.type_ = primitive->Type(); + split_param->num_split_ = param->GetNumberSplit(); + auto split_sizes_vector_ = param->GetSizeSplits(); + int i = 0; + for (auto iter = split_sizes_vector_.begin(); iter != split_sizes_vector_.end(); iter++) { + split_param->split_sizes_[i++] = *iter; + } + split_param->split_dim_ = param->GetSplitDim(); + split_param->num_split_ = param->GetNumberSplit(); + return reinterpret_cast(split_param); +} +Registry SplitParameterRegistry(schema::PrimitiveType_Split, PopulateSplitParameter); + namespace { constexpr int kSplitInputNum = 1; } // namespace diff --git a/mindspore/lite/src/ops/sqrt.cc b/mindspore/lite/src/ops/sqrt.cc index 6035b5db8c..036776a1a4 100644 --- a/mindspore/lite/src/ops/sqrt.cc +++ b/mindspore/lite/src/ops/sqrt.cc @@ -16,6 +16,8 @@ #include "src/ops/sqrt.h" +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -29,6 +31,11 @@ int Sqrt::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::F fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *SqrtCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry SqrtRegistry(schema::PrimitiveType_Sqrt, SqrtCreator); #endif +Registry SqrtParameterRegistry(schema::PrimitiveType_Sqrt, PopulateArithmeticSelf); + } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/square.cc b/mindspore/lite/src/ops/square.cc index 89f5ba8dbf..b0a32b1658 100644 --- a/mindspore/lite/src/ops/square.cc +++ b/mindspore/lite/src/ops/square.cc @@ -16,6 +16,8 @@ #include "src/ops/square.h" +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -29,6 +31,11 @@ int Square::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers: fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *SquareCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry SquareRegistry(schema::PrimitiveType_Square, SquareCreator); #endif +Registry SquareGradParameterRegistry(schema::PrimitiveType_Square, PopulateArithmeticSelf); + } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/squared_difference.cc b/mindspore/lite/src/ops/squared_difference.cc index b602b29600..3a75d91469 100644 --- a/mindspore/lite/src/ops/squared_difference.cc +++ b/mindspore/lite/src/ops/squared_difference.cc @@ -16,6 +16,8 @@ #include "src/ops/squared_difference.h" +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -29,6 +31,13 @@ int SquaredDifference::UnPackToFlatBuilder(const schema::Primitive *primitive, f fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *SquaredDifferenceCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry SquaredDifferenceRegistry(schema::PrimitiveType_SquaredDifference, SquaredDifferenceCreator); + #endif +Registry SquaredDifferenceParameterRegistry(schema::PrimitiveType_SquaredDifference, PopulateArithmetic); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/squeeze.cc b/mindspore/lite/src/ops/squeeze.cc index 522e00206d..eee558b2ff 100644 --- a/mindspore/lite/src/ops/squeeze.cc +++ b/mindspore/lite/src/ops/squeeze.cc @@ -16,6 +16,9 @@ #include "src/ops/squeeze.h" +#include "src/ops/ops_register.h" +#include "nnacl/squeeze.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -77,8 +80,23 @@ int Squeeze::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *SqueezeCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry SqueezeRegistry(schema::PrimitiveType_Squeeze, SqueezeCreator); #endif +OpParameter *PopulateSqueezeParameter(const mindspore::lite::PrimitiveC *primitive) { + SqueezeParameter *squeeze_param = reinterpret_cast(malloc(sizeof(SqueezeParameter))); + if (squeeze_param == nullptr) { + MS_LOG(ERROR) << "malloc SqueezeParameter failed."; + return nullptr; + } + memset(squeeze_param, 0, sizeof(SqueezeParameter)); + squeeze_param->op_parameter_.type_ = primitive->Type(); + return reinterpret_cast(squeeze_param); +} +Registry SqueezeParameterRegistry(schema::PrimitiveType_Squeeze, PopulateSqueezeParameter); + namespace { constexpr int kSqueezeInputNum = 1; constexpr int kSqueezeOutputNum = 1; diff --git a/mindspore/lite/src/ops/stack.cc b/mindspore/lite/src/ops/stack.cc index c7752ada1a..ec20f03635 100644 --- a/mindspore/lite/src/ops/stack.cc +++ b/mindspore/lite/src/ops/stack.cc @@ -16,6 +16,9 @@ #include "src/ops/stack.h" +#include "src/ops/ops_register.h" +#include "nnacl/stack_parameter.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -54,8 +57,26 @@ int Stack::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers:: fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *StackCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry StackRegistry(schema::PrimitiveType_Stack, StackCreator); + #endif +OpParameter *PopulateStackParameter(const mindspore::lite::PrimitiveC *primitive) { + StackParameter *stack_param = reinterpret_cast(malloc(sizeof(StackParameter))); + if (stack_param == nullptr) { + MS_LOG(ERROR) << "malloc StackParameter failed."; + return nullptr; + } + memset(stack_param, 0, sizeof(StackParameter)); + auto param = reinterpret_cast(const_cast(primitive)); + stack_param->op_parameter_.type_ = primitive->Type(); + stack_param->axis_ = param->GetAxis(); + return reinterpret_cast(stack_param); +} +Registry StackParameterRegistry(schema::PrimitiveType_Stack, PopulateStackParameter); + namespace { constexpr int kStackOutputNum = 1; constexpr int kStackMinInputNum = 1; diff --git a/mindspore/lite/src/ops/strided_slice.cc b/mindspore/lite/src/ops/strided_slice.cc index b7905e0bef..21a3db8b2f 100644 --- a/mindspore/lite/src/ops/strided_slice.cc +++ b/mindspore/lite/src/ops/strided_slice.cc @@ -16,6 +16,9 @@ #include "src/ops/strided_slice.h" +#include "src/ops/ops_register.h" +#include "nnacl/strided_slice.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -158,7 +161,37 @@ int StridedSlice::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbu fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *StridedSliceCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry StridedSliceRegistry(schema::PrimitiveType_StridedSlice, StridedSliceCreator); #endif + +OpParameter *PopulateStridedSliceParameter(const mindspore::lite::PrimitiveC *primitive) { + StridedSliceParameter *strided_slice_param = + reinterpret_cast(malloc(sizeof(StridedSliceParameter))); + if (strided_slice_param == nullptr) { + MS_LOG(ERROR) << "malloc StridedSliceParameter failed."; + return nullptr; + } + memset(strided_slice_param, 0, sizeof(StridedSliceParameter)); + strided_slice_param->op_parameter_.type_ = primitive->Type(); + auto n_dims = ((lite::StridedSlice *)primitive)->NDims(); + strided_slice_param->num_axes_ = n_dims; + auto begin = ((lite::StridedSlice *)primitive)->GetBegins(); + memcpy(strided_slice_param->begins_, (begin.data()), begin.size() * sizeof(int)); + auto end = ((lite::StridedSlice *)primitive)->GetEnds(); + memcpy(strided_slice_param->ends_, (end.data()), end.size() * sizeof(int)); + auto stride = ((lite::StridedSlice *)primitive)->GetStrides(); + memcpy(strided_slice_param->strides_, (stride.data()), stride.size() * sizeof(int)); + auto in_shape = ((lite::StridedSlice *)primitive)->GetInShape(); + memcpy(strided_slice_param->in_shape_, (in_shape.data()), in_shape.size() * sizeof(int)); + return reinterpret_cast(strided_slice_param); +} + +Registry StridedSliceParameterRegistry(schema::PrimitiveType_StridedSlice, PopulateStridedSliceParameter); + namespace { constexpr size_t kStridedSliceOutputNum = 1; constexpr size_t kStridedSliceInputNum = 1; diff --git a/mindspore/lite/src/ops/sub.cc b/mindspore/lite/src/ops/sub.cc index bee2131df8..a5e5292015 100644 --- a/mindspore/lite/src/ops/sub.cc +++ b/mindspore/lite/src/ops/sub.cc @@ -16,6 +16,9 @@ #include "src/ops/sub.h" +#include "src/ops/ops_register.h" +#include "nnacl/arithmetic_common.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -41,6 +44,31 @@ int Sub::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::Fl fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *SubCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry SubRegistry(schema::PrimitiveType_Sub, SubCreator); + #endif +OpParameter *PopulateSubParameter(const mindspore::lite::PrimitiveC *primitive) { + ArithmeticParameter *arithmetic_param = reinterpret_cast(malloc(sizeof(ArithmeticParameter))); + if (arithmetic_param == nullptr) { + MS_LOG(ERROR) << "malloc ArithmeticParameter failed."; + return nullptr; + } + memset(arithmetic_param, 0, sizeof(ArithmeticParameter)); + arithmetic_param->op_parameter_.type_ = primitive->Type(); + arithmetic_param->broadcasting_ = ((lite::Arithmetic *)primitive)->Broadcasting(); + arithmetic_param->ndim_ = ((lite::Arithmetic *)primitive)->NDims(); + arithmetic_param->activation_type_ = + reinterpret_cast(const_cast(primitive))->GetActivationType(); + auto tmp_shape = ((lite::Arithmetic *)primitive)->InShape0(); + memcpy(arithmetic_param->in_shape0_, static_cast(tmp_shape.data()), tmp_shape.size() * sizeof(int)); + tmp_shape = ((lite::Arithmetic *)primitive)->InShape1(); + memcpy(arithmetic_param->in_shape1_, static_cast(tmp_shape.data()), tmp_shape.size() * sizeof(int)); + tmp_shape = ((lite::Arithmetic *)primitive)->OutputShape(); + memcpy(arithmetic_param->out_shape_, static_cast(tmp_shape.data()), tmp_shape.size() * sizeof(int)); + return reinterpret_cast(arithmetic_param); +} +Registry SubParameterRegistry(schema::PrimitiveType_Sub, PopulateSubParameter); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/tile.cc b/mindspore/lite/src/ops/tile.cc index 4779384d6f..786d950aa6 100644 --- a/mindspore/lite/src/ops/tile.cc +++ b/mindspore/lite/src/ops/tile.cc @@ -17,6 +17,9 @@ #include "src/ops/tile.h" #include +#include "src/ops/ops_register.h" +#include "nnacl/fp32/tile.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -108,8 +111,30 @@ int Tile::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::F fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *TileCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry TileRegistry(schema::PrimitiveType_Tile, TileCreator); #endif +OpParameter *PopulateTileParameter(const mindspore::lite::PrimitiveC *primitive) { + TileParameter *tile_param = reinterpret_cast(malloc(sizeof(TileParameter))); + if (tile_param == nullptr) { + MS_LOG(ERROR) << "malloc TileParameter failed."; + return nullptr; + } + memset(tile_param, 0, sizeof(TileParameter)); + tile_param->op_parameter_.type_ = primitive->Type(); + auto param = reinterpret_cast(const_cast(primitive)); + auto multiples = param->GetMultiples(); + tile_param->in_dim_ = multiples.size(); + for (int i = 0; i < tile_param->in_dim_; ++i) { + tile_param->multiples_[i] = multiples[i]; + } + return reinterpret_cast(tile_param); +} + +Registry TileParameterRegistry(schema::PrimitiveType_Tile, PopulateTileParameter); + int Tile::InferShape(std::vector inputs_, std::vector outputs_) { MS_ASSERT(this->primitive_ != nullptr); auto input = inputs_.front(); diff --git a/mindspore/lite/src/ops/topk.cc b/mindspore/lite/src/ops/topk.cc index 7ed1782574..02a02fa7a6 100644 --- a/mindspore/lite/src/ops/topk.cc +++ b/mindspore/lite/src/ops/topk.cc @@ -16,6 +16,9 @@ #include "src/ops/topk.h" +#include "src/ops/ops_register.h" +#include "nnacl/fp32/topk.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -42,8 +45,27 @@ int TopK::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::F fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *TopKCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry TopKRegistry(schema::PrimitiveType_TopK, TopKCreator); + #endif +OpParameter *PopulateTopKParameter(const mindspore::lite::PrimitiveC *primitive) { + TopkParameter *topk_param = reinterpret_cast(malloc(sizeof(TopkParameter))); + if (topk_param == nullptr) { + MS_LOG(ERROR) << "malloc TopkParameter failed."; + return nullptr; + } + memset(topk_param, 0, sizeof(TopkParameter)); + topk_param->op_parameter_.type_ = primitive->Type(); + auto param = reinterpret_cast(const_cast(primitive)); + topk_param->k_ = param->GetK(); + topk_param->sorted_ = param->GetSorted(); + return reinterpret_cast(topk_param); +} +Registry TopKParameterRegistry(schema::PrimitiveType_TopK, PopulateTopKParameter); + int TopK::InferShape(std::vector inputs_, std::vector outputs_) { MS_ASSERT(this->primitive_ != nullptr); if (inputs_.size() != kSingleNum || outputs_.size() != kDoubleNum) { diff --git a/mindspore/lite/src/ops/transpose.cc b/mindspore/lite/src/ops/transpose.cc index 279a6953c7..c19610ef67 100644 --- a/mindspore/lite/src/ops/transpose.cc +++ b/mindspore/lite/src/ops/transpose.cc @@ -19,6 +19,9 @@ #include "include/errorcode.h" #include "src/common/log_adapter.h" +#include "src/ops/ops_register.h" +#include "nnacl/transpose.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -101,8 +104,35 @@ int Transpose::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffe fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *TransposeCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry TransposeRegistry(schema::PrimitiveType_Transpose, TransposeCreator); + #endif +OpParameter *PopulateTransposeParameter(const mindspore::lite::PrimitiveC *primitive) { + TransposeParameter *transpose_param = reinterpret_cast(malloc(sizeof(TransposeParameter))); + if (transpose_param == nullptr) { + MS_LOG(ERROR) << "malloc TransposeParameter failed."; + return nullptr; + } + memset(transpose_param, 0, sizeof(TransposeParameter)); + auto param = reinterpret_cast(const_cast(primitive)); + transpose_param->op_parameter_.type_ = primitive->Type(); + auto perm_vector_ = param->GetPerm(); + int i = 0; + for (auto iter = perm_vector_.begin(); iter != perm_vector_.end(); iter++) { + transpose_param->perm_[i++] = *iter; + } + transpose_param->num_axes_ = i; + transpose_param->conjugate_ = param->GetConjugate(); + return reinterpret_cast(transpose_param); +} + +Registry TransposeParameterRegistry(schema::PrimitiveType_Transpose, PopulateTransposeParameter); + int Transpose::InferShape(std::vector inputs_, std::vector outputs_) { MS_ASSERT(this->primitive_ != nullptr); auto input = inputs_.front(); diff --git a/mindspore/lite/src/ops/tuple_get_item.cc b/mindspore/lite/src/ops/tuple_get_item.cc index ac35bb6100..8bc68bc4e3 100644 --- a/mindspore/lite/src/ops/tuple_get_item.cc +++ b/mindspore/lite/src/ops/tuple_get_item.cc @@ -18,6 +18,8 @@ #include #include +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -57,6 +59,10 @@ int TupleGetItem::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbu fbb->Finish(prim_offset); return RET_OK; } +PrimitiveC *TupleGetItemCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry TupleGetItemRegistry(schema::PrimitiveType_TupleGetItem, TupleGetItemCreator); #endif } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/unique.cc b/mindspore/lite/src/ops/unique.cc index 073a8771a9..44d190a047 100644 --- a/mindspore/lite/src/ops/unique.cc +++ b/mindspore/lite/src/ops/unique.cc @@ -16,6 +16,9 @@ #include "src/ops/unique.h" +#include "src/ops/ops_register.h" +#include "nnacl/fp32/unique.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -39,8 +42,24 @@ int Unique::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers: fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *UniqueCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry UniqueRegistry(schema::PrimitiveType_Unique, UniqueCreator); #endif +OpParameter *PopulateUniqueParameter(const mindspore::lite::PrimitiveC *primitive) { + UniqueParameter *unique_param = reinterpret_cast(malloc(sizeof(UniqueParameter))); + if (unique_param == nullptr) { + MS_LOG(ERROR) << "malloc UniqueParameter failed."; + return nullptr; + } + memset(unique_param, 0, sizeof(UniqueParameter)); + unique_param->op_parameter_.type_ = primitive->Type(); + return reinterpret_cast(unique_param); +} + +Registry UniqueParameterRegistry(schema::PrimitiveType_Unique, PopulateUniqueParameter); + int Unique::InferShape(std::vector inputs_, std::vector outputs_) { MS_ASSERT(this->primitive_ != nullptr); if (inputs_.size() != kSingleNum || outputs_.size() != kDoubleNum) { diff --git a/mindspore/lite/src/ops/unsqueeze.cc b/mindspore/lite/src/ops/unsqueeze.cc index 7e7e460cbb..62b8674f5f 100644 --- a/mindspore/lite/src/ops/unsqueeze.cc +++ b/mindspore/lite/src/ops/unsqueeze.cc @@ -19,6 +19,9 @@ #include "src/common/log_adapter.h" #include "src/tensor.h" +#include "src/ops/ops_register.h" +#include "nnacl/fp32/unsqueeze.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -51,8 +54,35 @@ int Unsqueeze::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffe fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *UnsqueezeCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry UnsqueezeRegistry(schema::PrimitiveType_Unsqueeze, UnsqueezeCreator); + #endif +OpParameter *PopulateUnsqueezeParameter(const mindspore::lite::PrimitiveC *primitive) { + auto unsqueeze_attr = + reinterpret_cast(const_cast(primitive)); + UnsqueezeParameter *unsqueeze_param = reinterpret_cast(malloc(sizeof(UnsqueezeParameter))); + if (unsqueeze_param == nullptr) { + MS_LOG(ERROR) << "malloc UnsqueezeParameter failed."; + return nullptr; + } + memset(unsqueeze_param, 0, sizeof(UnsqueezeParameter)); + unsqueeze_param->op_parameter_.type_ = primitive->Type(); + auto flatAxis = unsqueeze_attr->GetAxis(); + unsqueeze_param->num_dim_ = flatAxis.size(); + int i = 0; + for (auto iter = flatAxis.begin(); iter != flatAxis.end(); iter++) { + unsqueeze_param->dims_[i++] = *iter; + } + return reinterpret_cast(unsqueeze_param); +} + +Registry UnsqueezeParameterRegistry(schema::PrimitiveType_Unsqueeze, PopulateUnsqueezeParameter); + int Unsqueeze::InferShape(std::vector inputs_, std::vector outputs_) { MS_ASSERT(this->primitive_ != nullptr); auto input = inputs_.front(); diff --git a/mindspore/lite/src/ops/unstack.cc b/mindspore/lite/src/ops/unstack.cc index 1e892b3651..0102bed99f 100644 --- a/mindspore/lite/src/ops/unstack.cc +++ b/mindspore/lite/src/ops/unstack.cc @@ -15,6 +15,8 @@ */ #include "src/ops/unstack.h" +#include "src/ops/ops_register.h" +#include "nnacl/unstack.h" namespace mindspore { namespace lite { @@ -42,8 +44,26 @@ int Unstack::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *UnstackCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry UnstackRegistry(schema::PrimitiveType_Unstack, UnstackCreator); #endif +OpParameter *PopulateUnstackParameter(const mindspore::lite::PrimitiveC *primitive) { + UnstackParameter *unstack_param = reinterpret_cast(malloc(sizeof(UnstackParameter))); + if (unstack_param == nullptr) { + MS_LOG(ERROR) << "malloc UnstackParameter failed."; + return nullptr; + } + memset(unstack_param, 0, sizeof(UnstackParameter)); + auto param = reinterpret_cast(const_cast(primitive)); + unstack_param->op_parameter_.type_ = primitive->Type(); + unstack_param->num_ = param->GetNum(); + unstack_param->axis_ = param->GetAxis(); + return reinterpret_cast(unstack_param); +} +Registry UnstackParameterRegistry(schema::PrimitiveType_Unstack, PopulateUnstackParameter); + int Unstack::InferShape(std::vector inputs, std::vector outputs) { auto input = inputs.at(0); MS_ASSERT(input != nullptr); diff --git a/mindspore/lite/src/ops/upsample.cc b/mindspore/lite/src/ops/upsample.cc index 10c9af70d8..776496ce19 100644 --- a/mindspore/lite/src/ops/upsample.cc +++ b/mindspore/lite/src/ops/upsample.cc @@ -17,6 +17,8 @@ #include "src/ops/upsample.h" #include +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -52,6 +54,11 @@ int Upsample::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffer fbb->Finish(prim_offset); return RET_OK; } +PrimitiveC *UpsampleCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry UpsampleRegistry(schema::PrimitiveType_Upsample, UpsampleCreator); + #endif } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/where.cc b/mindspore/lite/src/ops/where.cc index 39155899c7..4b3707148b 100644 --- a/mindspore/lite/src/ops/where.cc +++ b/mindspore/lite/src/ops/where.cc @@ -16,6 +16,8 @@ #include "src/ops/where.h" +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -50,6 +52,10 @@ int Where::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers:: fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *WhereCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC(primitive); } +Registry WhereRegistry(schema::PrimitiveType_Where, WhereCreator); + #endif int Where::InferShape(std::vector inputs_, std::vector outputs_) { diff --git a/mindspore/lite/src/ops/zeros_like.cc b/mindspore/lite/src/ops/zeros_like.cc index b57384432a..527fd21735 100644 --- a/mindspore/lite/src/ops/zeros_like.cc +++ b/mindspore/lite/src/ops/zeros_like.cc @@ -16,6 +16,8 @@ #include "src/ops/zeros_like.h" +#include "src/ops/ops_register.h" + namespace mindspore { namespace lite { @@ -30,6 +32,12 @@ int ZerosLike::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffe fbb->Finish(prim_offset); return RET_OK; } + +PrimitiveC *ZerosLikeCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry ZerosLikeRegistry(schema::PrimitiveType_ZerosLike, ZerosLikeCreator); + #endif int ZerosLike::InferShape(std::vector inputs_, std::vector outputs_) {