!7570 [MSLITE] Decoupling primitive problems

Merge pull request !7570 from yeyunpeng2020/primitive
pull/7570/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 733ad9cf2a

@ -19,6 +19,7 @@
#include "include/errorcode.h"
#include "src/common/graph_util.h"
#include "include/version.h"
#include "src/ops/ops_register.h"
namespace mindspore::lite {
@ -31,7 +32,12 @@ bool ConvertNodes(const schema::MetaGraph *meta_graph, Model *model) {
}
auto c_node = meta_graph->nodes()->GetAs<schema::CNode>(i);
auto src_prim = c_node->primitive();
#ifdef PRIMITIVE_WRITEABLE
node->primitive_ = PrimitiveC::Create(const_cast<schema::Primitive *>(src_prim));
#else
auto primitive = const_cast<schema::Primitive *>(src_prim);
node->primitive_ = OpsRegistry::GetInstance()->getPrimitiveCreator(primitive->value_type())(primitive);
#endif
if (node->primitive_ == nullptr) {
MS_LOG(ERROR) << "unpack primitive == nullptr!";
delete node;

@ -15,6 +15,7 @@
*/
#include "src/ops/abs.h"
#include "src/ops/ops_register.h"
namespace mindspore {
namespace lite {
@ -27,6 +28,9 @@ int Abs::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::Fl
fbb->Finish(prim_offset);
return RET_OK;
}
PrimitiveC *AbsCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Abs>(primitive); }
Registry AbsRegistry(schema::PrimitiveType_Abs, AbsCreator);
#endif
Registry AbsParameterRegistry(schema::PrimitiveType_Abs, PopulateArithmeticSelf);
} // namespace lite
} // namespace mindspore

@ -16,6 +16,8 @@
#include "src/ops/activation.h"
#include <memory>
#include "src/ops/ops_register.h"
#include "nnacl/fp32/activation.h"
namespace mindspore {
namespace lite {
@ -80,6 +82,30 @@ int Activation::GetType() const { return this->primitive_->value_as_Activation()
float Activation::GetAlpha() const { return this->primitive_->value_as_Activation()->alpha(); }
float Activation::GetMinVal() const { return this->primitive_->value_as_Activation()->min_val(); }
float Activation::GetMaxVal() const { return this->primitive_->value_as_Activation()->max_val(); }
PrimitiveC *ActivationCreator(const schema::Primitive *primitive) {
return PrimitiveC::NewPrimitiveC<Activation>(primitive);
}
Registry ActivationRegistry(schema::PrimitiveType_Activation, ActivationCreator);
#endif
OpParameter *PopulateActivationParameter(const mindspore::lite::PrimitiveC *primitive) {
ActivationParameter *act_param = reinterpret_cast<ActivationParameter *>(malloc(sizeof(ActivationParameter)));
if (act_param == nullptr) {
MS_LOG(ERROR) << "malloc ActivationParameter failed.";
return nullptr;
}
memset(act_param, 0, sizeof(ActivationParameter));
act_param->op_parameter_.type_ = primitive->Type();
auto activation =
reinterpret_cast<mindspore::lite::Activation *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
act_param->type_ = static_cast<int>(activation->GetType());
act_param->alpha_ = activation->GetAlpha();
act_param->min_val_ = activation->GetMinVal();
act_param->max_val_ = activation->GetMaxVal();
return reinterpret_cast<OpParameter *>(act_param);
}
Registry ActivationParameterRegistry(schema::PrimitiveType_Activation, PopulateActivationParameter);
} // namespace lite
} // namespace mindspore

@ -16,6 +16,8 @@
#include "src/ops/activation_grad.h"
#include "src/ops/ops_register.h"
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
@ -74,6 +76,11 @@ int ActivationGrad::UnPackToFlatBuilder(const schema::Primitive *primitive, flat
}
int ActivationGrad::GetType() const { return this->primitive_->value_as_ActivationGrad()->type(); }
float ActivationGrad::GetAlpha() const { return this->primitive_->value_as_ActivationGrad()->alpha(); }
PrimitiveC *ActivationGradCreator(const schema::Primitive *primitive) {
return PrimitiveC::NewPrimitiveC<ActivationGrad>(primitive);
}
Registry ActivationGradRegistry(schema::PrimitiveType_ActivationGrad, ActivationGradCreator);
#endif
} // namespace lite
} // namespace mindspore

@ -14,6 +14,8 @@
* limitations under the License.
*/
#include "src/ops/adam.h"
#include "src/ops/ops_register.h"
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
@ -62,6 +64,9 @@ int Adam::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::F
fbb->Finish(prim_offset);
return RET_OK;
}
PrimitiveC *AdamCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Adam>(primitive); }
Registry AdamRegistry(schema::PrimitiveType_Adam, AdamCreator);
#endif
int Adam::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) {

@ -16,6 +16,8 @@
#include "src/ops/add.h"
#include <memory>
#include "src/ops/ops_register.h"
#include "nnacl/arithmetic_common.h"
namespace mindspore {
namespace lite {
@ -71,6 +73,31 @@ int Add::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::Fl
}
int Add::GetActivationType() const { return this->primitive_->value_as_Add()->activationType(); }
PrimitiveC *AddCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Add>(primitive); }
Registry AddRegistry(schema::PrimitiveType_Add, AddCreator);
#endif
OpParameter *PopulateAddParameter(const mindspore::lite::PrimitiveC *primitive) {
ArithmeticParameter *arithmetic_param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
if (arithmetic_param == nullptr) {
MS_LOG(ERROR) << "malloc ArithmeticParameter failed.";
return nullptr;
}
memset(arithmetic_param, 0, sizeof(ArithmeticParameter));
arithmetic_param->op_parameter_.type_ = primitive->Type();
arithmetic_param->broadcasting_ = ((lite::Arithmetic *)primitive)->Broadcasting();
arithmetic_param->ndim_ = ((lite::Arithmetic *)primitive)->NDims();
arithmetic_param->activation_type_ =
reinterpret_cast<mindspore::lite::Add *>(const_cast<mindspore::lite::PrimitiveC *>(primitive))->GetActivationType();
auto tmp_shape = ((lite::Arithmetic *)primitive)->InShape0();
memcpy(arithmetic_param->in_shape0_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
tmp_shape = ((lite::Arithmetic *)primitive)->InShape1();
memcpy(arithmetic_param->in_shape1_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
tmp_shape = ((lite::Arithmetic *)primitive)->OutputShape();
memcpy(arithmetic_param->out_shape_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
return reinterpret_cast<OpParameter *>(arithmetic_param);
}
Registry AddParameterRegistry(schema::PrimitiveType_Add, PopulateAddParameter);
} // namespace lite
} // namespace mindspore

@ -16,6 +16,8 @@
#include "src/ops/addn.h"
#include "src/ops/ops_register.h"
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
@ -62,8 +64,22 @@ int AddN::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::F
}
int AddN::GetN() const { return this->primitive_->value_as_AddN()->N(); }
PrimitiveC *AddNCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<AddN>(primitive); }
Registry AddNRegistry(schema::PrimitiveType_AddN, AddNCreator);
#endif
OpParameter *PopulateAddNParameter(const mindspore::lite::PrimitiveC *primitive) {
OpParameter *addn_param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (addn_param == nullptr) {
MS_LOG(ERROR) << "malloc OpParameter failed.";
return nullptr;
}
memset(addn_param, 0, sizeof(OpParameter));
addn_param->type_ = primitive->Type();
return reinterpret_cast<OpParameter *>(addn_param);
}
Registry AddNParameterRegistry(schema::PrimitiveType_AddN, PopulateAddNParameter);
namespace {
constexpr int kLeastInputNum = 2;
}

@ -14,6 +14,8 @@
* limitations under the License.
*/
#include "src/ops/apply_momentum.h"
#include "src/ops/ops_register.h"
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
@ -67,6 +69,11 @@ int ApplyMomentum::UnPackToFlatBuilder(const schema::Primitive *primitive, flatb
fbb->Finish(prim_offset);
return RET_OK;
}
PrimitiveC *ApplyMomentumCreator(const schema::Primitive *primitive) {
return PrimitiveC::NewPrimitiveC<ApplyMomentum>(primitive);
}
Registry ApplyMomentumRegistry(schema::PrimitiveType_ApplyMomentum, ApplyMomentumCreator);
#endif
int ApplyMomentum::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) {

@ -16,6 +16,9 @@
#include "src/ops/argmax.h"
#include "src/ops/ops_register.h"
#include "nnacl/arg_min_max_parameter.h"
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
@ -52,8 +55,29 @@ int ArgMax::GetTopK() const { return this->primitive_->value_as_ArgMax()->topK()
bool ArgMax::GetKeepDims() const { return this->primitive_->value_as_ArgMax()->keepDims(); }
int ArgMax::GetAxisType() const { return this->primitive_->value_as_ArgMax()->axisType(); }
PrimitiveC *ArgMaxCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<ArgMax>(primitive); }
Registry ArgMaxRegistry(schema::PrimitiveType_ArgMax, ArgMaxCreator);
#endif
OpParameter *PopulateArgMaxParameter(const mindspore::lite::PrimitiveC *primitive) {
ArgMinMaxParameter *arg_param = reinterpret_cast<ArgMinMaxParameter *>(malloc(sizeof(ArgMinMaxParameter)));
if (arg_param == nullptr) {
MS_LOG(ERROR) << "malloc ArgMinMaxParameter failed.";
return nullptr;
}
memset(arg_param, 0, sizeof(ArgMinMaxParameter));
arg_param->op_parameter_.type_ = primitive->Type();
auto param = reinterpret_cast<mindspore::lite::ArgMax *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
arg_param->axis_ = param->GetAxis();
arg_param->topk_ = param->GetTopK();
arg_param->axis_type_ = param->GetAxisType();
arg_param->out_value_ = param->GetOutMaxValue();
arg_param->keep_dims_ = param->GetKeepDims();
return reinterpret_cast<OpParameter *>(arg_param);
}
Registry ArgMaxParameterRegistry(schema::PrimitiveType_ArgMax, PopulateArgMaxParameter);
int ArgMax::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
MS_ASSERT(this->primitive_ != nullptr);
auto input = inputs_.front();

@ -16,6 +16,9 @@
#include "src/ops/argmin.h"
#include "src/ops/ops_register.h"
#include "nnacl/arg_min_max_parameter.h"
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
@ -52,8 +55,29 @@ int ArgMin::GetTopK() const { return this->primitive_->value_as_ArgMin()->topK()
bool ArgMin::GetKeepDims() const { return this->primitive_->value_as_ArgMin()->keepDims(); }
int ArgMin::GetAxisType() const { return this->primitive_->value_as_ArgMin()->axisType(); }
PrimitiveC *ArgMinCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<ArgMin>(primitive); }
Registry ArgMinRegistry(schema::PrimitiveType_ArgMin, ArgMinCreator);
#endif
OpParameter *PopulateArgMinParameter(const mindspore::lite::PrimitiveC *primitive) {
ArgMinMaxParameter *arg_param = reinterpret_cast<ArgMinMaxParameter *>(malloc(sizeof(ArgMinMaxParameter)));
if (arg_param == nullptr) {
MS_LOG(ERROR) << "malloc ArgMinMaxParameter failed.";
return nullptr;
}
memset(arg_param, 0, sizeof(ArgMinMaxParameter));
arg_param->op_parameter_.type_ = primitive->Type();
auto param = reinterpret_cast<mindspore::lite::ArgMin *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
arg_param->axis_ = param->GetAxis();
arg_param->topk_ = param->GetTopK();
arg_param->axis_type_ = param->GetAxisType();
arg_param->out_value_ = param->GetOutMaxValue();
arg_param->keep_dims_ = param->GetKeepDims();
return reinterpret_cast<OpParameter *>(arg_param);
}
Registry ArgMinParameterRegistry(schema::PrimitiveType_ArgMin, PopulateArgMinParameter);
int ArgMin::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) {
MS_ASSERT(this->primitive_ != nullptr);
auto input = inputs_.front();

@ -21,6 +21,29 @@
namespace mindspore {
namespace lite {
OpParameter *PopulateArithmetic(const mindspore::lite::PrimitiveC *primitive) {
ArithmeticParameter *arithmetic_param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
if (arithmetic_param == nullptr) {
MS_LOG(ERROR) << "malloc ArithmeticParameter failed.";
return nullptr;
}
memset(arithmetic_param, 0, sizeof(ArithmeticParameter));
arithmetic_param->op_parameter_.type_ = primitive->Type();
arithmetic_param->broadcasting_ = ((lite::Arithmetic *)primitive)->Broadcasting();
arithmetic_param->ndim_ = ((lite::Arithmetic *)primitive)->NDims();
arithmetic_param->activation_type_ = 0;
auto tmp_shape = ((lite::Arithmetic *)primitive)->InShape0();
memcpy(arithmetic_param->in_shape0_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
tmp_shape = ((lite::Arithmetic *)primitive)->InShape1();
memcpy(arithmetic_param->in_shape1_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
tmp_shape = ((lite::Arithmetic *)primitive)->OutputShape();
memcpy(arithmetic_param->out_shape_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
return reinterpret_cast<OpParameter *>(arithmetic_param);
}
int Arithmetic::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) {
MS_ASSERT(this->primitive_ != nullptr);
if (inputs_.size() != kDoubleNum) {

@ -21,6 +21,7 @@
#include <set>
#include <cmath>
#include "src/ops/primitive_c.h"
#include "nnacl/arithmetic_common.h"
namespace mindspore {
namespace lite {
@ -51,6 +52,8 @@ class Arithmetic : public PrimitiveC {
std::vector<int> in_shape1_;
std::vector<int> out_shape_;
};
OpParameter *PopulateArithmetic(const mindspore::lite::PrimitiveC *primitive);
} // namespace lite
} // namespace mindspore

@ -21,6 +21,7 @@
#include <set>
#include <cmath>
#include "src/ops/primitive_c.h"
#include "nnacl/arithmetic_self_parameter.h"
namespace mindspore {
namespace lite {

@ -17,9 +17,21 @@
#include "src/ops/arithmetic_self.h"
#include "include/errorcode.h"
#include "src/common/log_adapter.h"
#include "src/ops/ops_register.h"
namespace mindspore {
namespace lite {
OpParameter *PopulateArithmeticSelf(const mindspore::lite::PrimitiveC *primitive) {
ArithmeticSelfParameter *arithmetic_self_param =
reinterpret_cast<ArithmeticSelfParameter *>(malloc(sizeof(ArithmeticSelfParameter)));
if (arithmetic_self_param == nullptr) {
MS_LOG(ERROR) << "malloc ArithmeticSelfParameter failed.";
return nullptr;
}
memset(arithmetic_self_param, 0, sizeof(ArithmeticSelfParameter));
arithmetic_self_param->op_parameter_.type_ = primitive->Type();
return reinterpret_cast<OpParameter *>(arithmetic_self_param);
}
int ArithmeticSelf::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
MS_ASSERT(this->primitive_ != nullptr);

@ -19,6 +19,7 @@
#include <vector>
#include "src/ops/primitive_c.h"
#include "nnacl/arithmetic_self_parameter.h"
namespace mindspore {
namespace lite {
@ -37,6 +38,7 @@ class ArithmeticSelf : public PrimitiveC {
#endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
OpParameter *PopulateArithmeticSelf(const mindspore::lite::PrimitiveC *primitive);
} // namespace lite
} // namespace mindspore

@ -17,6 +17,8 @@
#include "src/ops/assign.h"
#include <memory>
#include "src/ops/ops_register.h"
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
@ -56,6 +58,9 @@ int Assign::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers:
fbb->Finish(prim_offset);
return RET_OK;
}
PrimitiveC *AssignCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Assign>(primitive); }
Registry AssignRegistry(schema::PrimitiveType_Assign, AssignCreator);
#endif
int Assign::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) {

@ -16,6 +16,9 @@
#include "src/ops/batch_norm.h"
#include <memory>
#include "src/ops/ops_register.h"
#include "nnacl/batchnorm_parameter.h"
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
@ -60,6 +63,28 @@ int BatchNorm::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffe
}
float BatchNorm::GetEpsilon() const { return this->primitive_->value_as_BatchNorm()->epsilon(); }
PrimitiveC *BatchNormCreator(const schema::Primitive *primitive) {
return PrimitiveC::NewPrimitiveC<BatchNorm>(primitive);
}
Registry BatchNormRegistry(schema::PrimitiveType_BatchNorm, BatchNormCreator);
#endif
OpParameter *PopulateBatchNorm(const mindspore::lite::PrimitiveC *primitive) {
const auto param =
reinterpret_cast<mindspore::lite::BatchNorm *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
BatchNormParameter *batch_norm_param = reinterpret_cast<BatchNormParameter *>(malloc(sizeof(BatchNormParameter)));
if (batch_norm_param == nullptr) {
MS_LOG(ERROR) << "malloc BatchNormParameter failed.";
return nullptr;
}
memset(batch_norm_param, 0, sizeof(BatchNormParameter));
batch_norm_param->op_parameter_.type_ = primitive->Type();
batch_norm_param->epsilon_ = param->GetEpsilon();
batch_norm_param->fused_ = false;
return reinterpret_cast<OpParameter *>(batch_norm_param);
}
Registry BatchNormParameterRegistry(schema::PrimitiveType_BatchNorm, PopulateBatchNorm);
} // namespace lite
} // namespace mindspore

@ -20,6 +20,9 @@
#include "src/common/log_adapter.h"
#include "src/tensor.h"
#include "src/ops/ops_register.h"
#include "nnacl/batch_to_space.h"
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
@ -66,7 +69,49 @@ std::vector<int> BatchToSpace::GetCrops() const {
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
PrimitiveC *BatchToSpaceCreator(const schema::Primitive *primitive) {
return PrimitiveC::NewPrimitiveC<BatchToSpace>(primitive);
}
Registry BatchToSpaceRegistry(schema::PrimitiveType_BatchToSpace, BatchToSpaceCreator);
#endif
OpParameter *PopulateBatchToSpaceParameter(const mindspore::lite::PrimitiveC *primitive) {
BatchToSpaceParameter *batch_space_param =
reinterpret_cast<BatchToSpaceParameter *>(malloc(sizeof(BatchToSpaceParameter)));
if (batch_space_param == nullptr) {
MS_LOG(ERROR) << "malloc BatchToSpaceParameter failed.";
return nullptr;
}
memset(batch_space_param, 0, sizeof(BatchToSpaceParameter));
batch_space_param->op_parameter_.type_ = primitive->Type();
auto param = reinterpret_cast<mindspore::lite::BatchToSpace *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
auto block_shape = param->GetBlockShape();
if (block_shape.size() != BATCH_TO_SPACE_BLOCK_SHAPE_SIZE) {
MS_LOG(ERROR) << "batch_to_space blockShape size should be " << BATCH_TO_SPACE_BLOCK_SHAPE_SIZE;
free(batch_space_param);
return nullptr;
}
auto crops = param->GetCrops();
if (crops.size() != BATCH_TO_SPACE_CROPS_SIZE) {
MS_LOG(ERROR) << "batch_to_space crops size should be " << BATCH_TO_SPACE_CROPS_SIZE;
free(batch_space_param);
return nullptr;
}
for (int i = 0; i < BATCH_TO_SPACE_BLOCK_SHAPE_SIZE; ++i) {
batch_space_param->block_shape_[i] = block_shape[i];
}
for (int i = 0; i < BATCH_TO_SPACE_CROPS_SIZE; ++i) {
batch_space_param->crops_[i] = crops[i];
}
return reinterpret_cast<OpParameter *>(batch_space_param);
}
Registry BatchToSpaceParameterRegistry(schema::PrimitiveType_BatchToSpace, PopulateBatchToSpaceParameter);
Registry BatchToSpaceNDParameterRegistry(schema::PrimitiveType_BatchToSpaceND, PopulateBatchToSpaceParameter);
namespace {
constexpr int kBatchToSpaceOutputNum = 1;
constexpr int kBatchToSpaceInputNum = 1;

@ -16,6 +16,8 @@
#include "src/ops/bias_add.h"
#include <memory>
#include "nnacl/arithmetic_common.h"
#include "src/ops/ops_register.h"
namespace mindspore {
namespace lite {
@ -78,6 +80,22 @@ std::vector<int> BiasAdd::GetAxis() const {
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
PrimitiveC *BiasAddCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<BiasAdd>(primitive); }
Registry BiasAddRegistry(schema::PrimitiveType_BiasAdd, BiasAddCreator);
#endif
OpParameter *PopulateBiasAddParameter(const mindspore::lite::PrimitiveC *primitive) {
ArithmeticParameter *arithmetic_param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
if (arithmetic_param == nullptr) {
MS_LOG(ERROR) << "malloc ArithmeticParameter failed.";
return nullptr;
}
memset(arithmetic_param, 0, sizeof(ArithmeticParameter));
arithmetic_param->op_parameter_.type_ = primitive->Type();
return reinterpret_cast<OpParameter *>(arithmetic_param);
}
Registry BiasAddParameterRegistry(schema::PrimitiveType_BiasAdd, PopulateBiasAddParameter);
} // namespace lite
} // namespace mindspore

@ -16,6 +16,8 @@
#include "src/ops/bias_grad.h"
#include "src/ops/ops_register.h"
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
@ -74,6 +76,11 @@ std::vector<int> BiasGrad::GetAxis() const {
auto fb_vector = this->primitive_->value_as_BiasGrad()->axis();
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
PrimitiveC *BiasGradCreator(const schema::Primitive *primitive) {
return PrimitiveC::NewPrimitiveC<BiasGrad>(primitive);
}
Registry BiasGradRegistry(schema::PrimitiveType_BiasGrad, BiasGradCreator);
#endif
int BiasGrad::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) {

@ -16,6 +16,8 @@
#include "src/ops/bn_grad.h"
#include "src/ops/ops_register.h"
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE

@ -16,6 +16,9 @@
#include "src/ops/broadcast_to.h"
#include "src/ops/ops_register.h"
#include "nnacl/fp32/broadcast_to.h"
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
@ -50,7 +53,32 @@ std::vector<int> BroadcastTo::GetDstShape() const {
return std::vector<int>(fb_vector->begin(), fb_vector->end());
}
PrimitiveC *BroadcastToCreator(const schema::Primitive *primitive) {
return PrimitiveC::NewPrimitiveC<BroadcastTo>(primitive);
}
Registry BroadcastToRegistry(schema::PrimitiveType_BroadcastTo, BroadcastToCreator);
#endif
OpParameter *PopulateBroadcastToParameter(const mindspore::lite::PrimitiveC *primitive) {
BroadcastToParameter *broadcast_param =
reinterpret_cast<BroadcastToParameter *>(malloc(sizeof(BroadcastToParameter)));
if (broadcast_param == nullptr) {
MS_LOG(ERROR) << "malloc BroadcastToParameter failed.";
return nullptr;
}
memset(broadcast_param, 0, sizeof(BroadcastToParameter));
auto param = reinterpret_cast<mindspore::lite::BroadcastTo *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
broadcast_param->op_parameter_.type_ = primitive->Type();
auto dst_shape = param->GetDstShape();
broadcast_param->shape_size_ = dst_shape.size();
for (size_t i = 0; i < broadcast_param->shape_size_; ++i) {
broadcast_param->shape_[i] = dst_shape[i];
}
return reinterpret_cast<OpParameter *>(broadcast_param);
}
Registry BroadcastToParameterRegistry(schema::PrimitiveType_BroadcastTo, PopulateBroadcastToParameter);
namespace {
constexpr int kBroadcastToInputNum = 1;
constexpr int kBroadcastToOutputNum = 1;

@ -16,6 +16,9 @@
#include "src/ops/cast.h"
#include "src/ops/ops_register.h"
#include "nnacl/fp32/cast.h"
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
@ -75,8 +78,26 @@ int Cast::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::F
int Cast::GetSrcT() const { return this->primitive_->value_as_Cast()->srcT(); }
int Cast::GetDstT() const { return this->primitive_->value_as_Cast()->dstT(); }
PrimitiveC *CastCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Cast>(primitive); }
Registry CastRegistry(schema::PrimitiveType_Cast, CastCreator);
#endif
OpParameter *PopulateCastParameter(const mindspore::lite::PrimitiveC *primitive) {
CastParameter *cast_param = reinterpret_cast<CastParameter *>(malloc(sizeof(CastParameter)));
if (cast_param == nullptr) {
MS_LOG(ERROR) << "malloc CastParameter failed.";
return nullptr;
}
memset(cast_param, 0, sizeof(CastParameter));
cast_param->op_parameter_.type_ = primitive->Type();
auto param = reinterpret_cast<mindspore::lite::Cast *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
cast_param->src_type_ = param->GetSrcT();
cast_param->dst_type_ = param->GetDstT();
return reinterpret_cast<OpParameter *>(cast_param);
}
Registry CastParameterRegistry(schema::PrimitiveType_Cast, PopulateCastParameter);
int Cast::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
MS_ASSERT(this->primitive_ != nullptr);
auto input = inputs_.front();

@ -0,0 +1,27 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/ceil.h"
#include "src/ops/ops_register.h"
namespace mindspore {
namespace lite {
Registry CeilParameterRegistry(schema::PrimitiveType_Ceil, PopulateArithmeticSelf);
} // namespace lite
} // namespace mindspore

@ -21,6 +21,7 @@
#include <set>
#include <cmath>
#include "src/ops/arithmetic_self.h"
#include "src/ops/ops_register.h"
namespace mindspore {
namespace lite {
@ -43,6 +44,7 @@ class Ceil : public ArithmeticSelf {
}
#endif
};
} // namespace lite
} // namespace mindspore

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save