decoupling populate_parameter problem

pull/7622/head
yeyunpeng 4 years ago
parent 4df56b6c1e
commit 4918e404de

@ -31,7 +31,6 @@ set(LITE_SRC
${CMAKE_CURRENT_SOURCE_DIR}/kernel_registry.cc
${CMAKE_CURRENT_SOURCE_DIR}/lite_kernel.cc
${CMAKE_CURRENT_SOURCE_DIR}/sub_graph_kernel.cc
${CMAKE_CURRENT_SOURCE_DIR}/populate_parameter.cc
${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cc
${CMAKE_CURRENT_SOURCE_DIR}/lite_session.cc
${CMAKE_CURRENT_SOURCE_DIR}/model.cc
@ -65,7 +64,7 @@ if (SUPPORT_TRAIN)
)
endif ()
file(GLOB_RECURSE C_OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/ops/*.cc)
file(GLOB_RECURSE C_OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/ops/*.cc ${CMAKE_CURRENT_SOURCE_DIR}/ops/populate/*.cc)
add_subdirectory(runtime/kernel/arm)
add_library(mindspore-lite SHARED ${LITE_SRC} ${C_OPS_SRC})

@ -15,8 +15,7 @@
*/
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/populate_parameter.h"
#include "src/ops/populate/populate_register.h"
#ifdef ENABLE_ARM64
#include <asm/hwcap.h>
#include "common/utils.h"
@ -107,7 +106,8 @@ kernel::LiteKernel *KernelRegistry::GetKernel(const std::vector<Tensor *> &in_te
const InnerContext *ctx, const kernel::KernelKey &key) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != ctx);
auto parameter = kernel::PopulateParameter(primitive);
auto parameter =
PopulateRegistry::GetInstance()->getParameterCreator(schema::PrimitiveType(primitive->Type()))(primitive);
if (parameter == nullptr) {
MS_LOG(ERROR) << "PopulateParameter return nullptr, type: "
<< schema::EnumNamePrimitiveType((schema::PrimitiveType)primitive->Type());

@ -16,7 +16,9 @@
#include "src/model_common.h"
#include "include/version.h"
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif
namespace mindspore::lite {
bool ConvertNodes(const schema::MetaGraph *meta_graph, Model *model) {

@ -15,7 +15,9 @@
*/
#include "src/ops/abs.h"
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif
namespace mindspore {
namespace lite {
@ -31,6 +33,6 @@ int Abs::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::Fl
PrimitiveC *AbsCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Abs>(primitive); }
Registry AbsRegistry(schema::PrimitiveType_Abs, AbsCreator);
#endif
Registry AbsParameterRegistry(schema::PrimitiveType_Abs, PopulateArithmeticSelf);
} // namespace lite
} // namespace mindspore

@ -16,8 +16,9 @@
#include "src/ops/activation.h"
#include <memory>
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#include "nnacl/fp32/activation.h"
#endif
namespace mindspore {
namespace lite {
@ -88,24 +89,6 @@ PrimitiveC *ActivationCreator(const schema::Primitive *primitive) {
}
Registry ActivationRegistry(schema::PrimitiveType_Activation, ActivationCreator);
#endif
OpParameter *PopulateActivationParameter(const mindspore::lite::PrimitiveC *primitive) {
ActivationParameter *act_param = reinterpret_cast<ActivationParameter *>(malloc(sizeof(ActivationParameter)));
if (act_param == nullptr) {
MS_LOG(ERROR) << "malloc ActivationParameter failed.";
return nullptr;
}
memset(act_param, 0, sizeof(ActivationParameter));
act_param->op_parameter_.type_ = primitive->Type();
auto activation =
reinterpret_cast<mindspore::lite::Activation *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
act_param->type_ = static_cast<int>(activation->GetType());
act_param->alpha_ = activation->GetAlpha();
act_param->min_val_ = activation->GetMinVal();
act_param->max_val_ = activation->GetMaxVal();
return reinterpret_cast<OpParameter *>(act_param);
}
Registry ActivationParameterRegistry(schema::PrimitiveType_Activation, PopulateActivationParameter);
} // namespace lite
} // namespace mindspore

@ -16,7 +16,9 @@
#include "src/ops/activation_grad.h"
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif
namespace mindspore {
namespace lite {

@ -14,7 +14,9 @@
* limitations under the License.
*/
#include "src/ops/adam.h"
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif
namespace mindspore {
namespace lite {

@ -16,8 +16,9 @@
#include "src/ops/add.h"
#include <memory>
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#include "nnacl/arithmetic_common.h"
#endif
namespace mindspore {
namespace lite {
@ -77,27 +78,5 @@ PrimitiveC *AddCreator(const schema::Primitive *primitive) { return PrimitiveC::
Registry AddRegistry(schema::PrimitiveType_Add, AddCreator);
#endif
OpParameter *PopulateAddParameter(const mindspore::lite::PrimitiveC *primitive) {
ArithmeticParameter *arithmetic_param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
if (arithmetic_param == nullptr) {
MS_LOG(ERROR) << "malloc ArithmeticParameter failed.";
return nullptr;
}
memset(arithmetic_param, 0, sizeof(ArithmeticParameter));
arithmetic_param->op_parameter_.type_ = primitive->Type();
arithmetic_param->broadcasting_ = ((lite::Arithmetic *)primitive)->Broadcasting();
arithmetic_param->ndim_ = ((lite::Arithmetic *)primitive)->NDims();
arithmetic_param->activation_type_ =
reinterpret_cast<mindspore::lite::Add *>(const_cast<mindspore::lite::PrimitiveC *>(primitive))->GetActivationType();
auto tmp_shape = ((lite::Arithmetic *)primitive)->InShape0();
memcpy(arithmetic_param->in_shape0_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
tmp_shape = ((lite::Arithmetic *)primitive)->InShape1();
memcpy(arithmetic_param->in_shape1_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
tmp_shape = ((lite::Arithmetic *)primitive)->OutputShape();
memcpy(arithmetic_param->out_shape_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
return reinterpret_cast<OpParameter *>(arithmetic_param);
}
Registry AddParameterRegistry(schema::PrimitiveType_Add, PopulateAddParameter);
} // namespace lite
} // namespace mindspore

@ -16,7 +16,9 @@
#include "src/ops/addn.h"
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif
namespace mindspore {
namespace lite {
@ -68,18 +70,6 @@ PrimitiveC *AddNCreator(const schema::Primitive *primitive) { return PrimitiveC:
Registry AddNRegistry(schema::PrimitiveType_AddN, AddNCreator);
#endif
OpParameter *PopulateAddNParameter(const mindspore::lite::PrimitiveC *primitive) {
OpParameter *addn_param = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
if (addn_param == nullptr) {
MS_LOG(ERROR) << "malloc OpParameter failed.";
return nullptr;
}
memset(addn_param, 0, sizeof(OpParameter));
addn_param->type_ = primitive->Type();
return reinterpret_cast<OpParameter *>(addn_param);
}
Registry AddNParameterRegistry(schema::PrimitiveType_AddN, PopulateAddNParameter);
namespace {
constexpr int kLeastInputNum = 2;
}

@ -14,7 +14,9 @@
* limitations under the License.
*/
#include "src/ops/apply_momentum.h"
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif
namespace mindspore {
namespace lite {

@ -16,8 +16,9 @@
#include "src/ops/argmax.h"
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#include "nnacl/arg_min_max_parameter.h"
#endif
namespace mindspore {
namespace lite {
@ -59,25 +60,6 @@ PrimitiveC *ArgMaxCreator(const schema::Primitive *primitive) { return Primitive
Registry ArgMaxRegistry(schema::PrimitiveType_ArgMax, ArgMaxCreator);
#endif
OpParameter *PopulateArgMaxParameter(const mindspore::lite::PrimitiveC *primitive) {
ArgMinMaxParameter *arg_param = reinterpret_cast<ArgMinMaxParameter *>(malloc(sizeof(ArgMinMaxParameter)));
if (arg_param == nullptr) {
MS_LOG(ERROR) << "malloc ArgMinMaxParameter failed.";
return nullptr;
}
memset(arg_param, 0, sizeof(ArgMinMaxParameter));
arg_param->op_parameter_.type_ = primitive->Type();
auto param = reinterpret_cast<mindspore::lite::ArgMax *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
arg_param->axis_ = param->GetAxis();
arg_param->topk_ = param->GetTopK();
arg_param->axis_type_ = param->GetAxisType();
arg_param->out_value_ = param->GetOutMaxValue();
arg_param->keep_dims_ = param->GetKeepDims();
return reinterpret_cast<OpParameter *>(arg_param);
}
Registry ArgMaxParameterRegistry(schema::PrimitiveType_ArgMax, PopulateArgMaxParameter);
int ArgMax::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
MS_ASSERT(this->primitive_ != nullptr);
auto input = inputs_.front();

@ -16,8 +16,9 @@
#include "src/ops/argmin.h"
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#include "nnacl/arg_min_max_parameter.h"
#endif
namespace mindspore {
namespace lite {
@ -59,25 +60,6 @@ PrimitiveC *ArgMinCreator(const schema::Primitive *primitive) { return Primitive
Registry ArgMinRegistry(schema::PrimitiveType_ArgMin, ArgMinCreator);
#endif
OpParameter *PopulateArgMinParameter(const mindspore::lite::PrimitiveC *primitive) {
ArgMinMaxParameter *arg_param = reinterpret_cast<ArgMinMaxParameter *>(malloc(sizeof(ArgMinMaxParameter)));
if (arg_param == nullptr) {
MS_LOG(ERROR) << "malloc ArgMinMaxParameter failed.";
return nullptr;
}
memset(arg_param, 0, sizeof(ArgMinMaxParameter));
arg_param->op_parameter_.type_ = primitive->Type();
auto param = reinterpret_cast<mindspore::lite::ArgMin *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
arg_param->axis_ = param->GetAxis();
arg_param->topk_ = param->GetTopK();
arg_param->axis_type_ = param->GetAxisType();
arg_param->out_value_ = param->GetOutMaxValue();
arg_param->keep_dims_ = param->GetKeepDims();
return reinterpret_cast<OpParameter *>(arg_param);
}
Registry ArgMinParameterRegistry(schema::PrimitiveType_ArgMin, PopulateArgMinParameter);
int ArgMin::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) {
MS_ASSERT(this->primitive_ != nullptr);
auto input = inputs_.front();

@ -22,28 +22,6 @@
namespace mindspore {
namespace lite {
OpParameter *PopulateArithmetic(const mindspore::lite::PrimitiveC *primitive) {
ArithmeticParameter *arithmetic_param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
if (arithmetic_param == nullptr) {
MS_LOG(ERROR) << "malloc ArithmeticParameter failed.";
return nullptr;
}
memset(arithmetic_param, 0, sizeof(ArithmeticParameter));
arithmetic_param->op_parameter_.type_ = primitive->Type();
arithmetic_param->broadcasting_ = ((lite::Arithmetic *)primitive)->Broadcasting();
arithmetic_param->ndim_ = ((lite::Arithmetic *)primitive)->NDims();
arithmetic_param->activation_type_ = 0;
auto tmp_shape = ((lite::Arithmetic *)primitive)->InShape0();
memcpy(arithmetic_param->in_shape0_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
tmp_shape = ((lite::Arithmetic *)primitive)->InShape1();
memcpy(arithmetic_param->in_shape1_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
tmp_shape = ((lite::Arithmetic *)primitive)->OutputShape();
memcpy(arithmetic_param->out_shape_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
return reinterpret_cast<OpParameter *>(arithmetic_param);
}
int Arithmetic::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) {
MS_ASSERT(this->primitive_ != nullptr);
if (inputs_.size() != kDoubleNum) {

@ -52,8 +52,6 @@ class Arithmetic : public PrimitiveC {
std::vector<int> in_shape1_;
std::vector<int> out_shape_;
};
OpParameter *PopulateArithmetic(const mindspore::lite::PrimitiveC *primitive);
} // namespace lite
} // namespace mindspore

@ -17,21 +17,12 @@
#include "src/ops/arithmetic_self.h"
#include "include/errorcode.h"
#include "src/common/log_adapter.h"
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif
namespace mindspore {
namespace lite {
OpParameter *PopulateArithmeticSelf(const mindspore::lite::PrimitiveC *primitive) {
ArithmeticSelfParameter *arithmetic_self_param =
reinterpret_cast<ArithmeticSelfParameter *>(malloc(sizeof(ArithmeticSelfParameter)));
if (arithmetic_self_param == nullptr) {
MS_LOG(ERROR) << "malloc ArithmeticSelfParameter failed.";
return nullptr;
}
memset(arithmetic_self_param, 0, sizeof(ArithmeticSelfParameter));
arithmetic_self_param->op_parameter_.type_ = primitive->Type();
return reinterpret_cast<OpParameter *>(arithmetic_self_param);
}
int ArithmeticSelf::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
MS_ASSERT(this->primitive_ != nullptr);

@ -17,7 +17,9 @@
#include "src/ops/assign.h"
#include <memory>
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif
namespace mindspore {
namespace lite {

@ -16,8 +16,9 @@
#include "src/ops/batch_norm.h"
#include <memory>
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#include "nnacl/batchnorm_parameter.h"
#endif
namespace mindspore {
namespace lite {
@ -69,22 +70,5 @@ PrimitiveC *BatchNormCreator(const schema::Primitive *primitive) {
Registry BatchNormRegistry(schema::PrimitiveType_BatchNorm, BatchNormCreator);
#endif
OpParameter *PopulateBatchNorm(const mindspore::lite::PrimitiveC *primitive) {
const auto param =
reinterpret_cast<mindspore::lite::BatchNorm *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
BatchNormParameter *batch_norm_param = reinterpret_cast<BatchNormParameter *>(malloc(sizeof(BatchNormParameter)));
if (batch_norm_param == nullptr) {
MS_LOG(ERROR) << "malloc BatchNormParameter failed.";
return nullptr;
}
memset(batch_norm_param, 0, sizeof(BatchNormParameter));
batch_norm_param->op_parameter_.type_ = primitive->Type();
batch_norm_param->epsilon_ = param->GetEpsilon();
batch_norm_param->fused_ = false;
return reinterpret_cast<OpParameter *>(batch_norm_param);
}
Registry BatchNormParameterRegistry(schema::PrimitiveType_BatchNorm, PopulateBatchNorm);
} // namespace lite
} // namespace mindspore

@ -20,8 +20,9 @@
#include "src/common/log_adapter.h"
#include "src/tensor.h"
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#include "nnacl/batch_to_space.h"
#endif
namespace mindspore {
namespace lite {
@ -75,43 +76,6 @@ PrimitiveC *BatchToSpaceCreator(const schema::Primitive *primitive) {
Registry BatchToSpaceRegistry(schema::PrimitiveType_BatchToSpace, BatchToSpaceCreator);
#endif
OpParameter *PopulateBatchToSpaceParameter(const mindspore::lite::PrimitiveC *primitive) {
BatchToSpaceParameter *batch_space_param =
reinterpret_cast<BatchToSpaceParameter *>(malloc(sizeof(BatchToSpaceParameter)));
if (batch_space_param == nullptr) {
MS_LOG(ERROR) << "malloc BatchToSpaceParameter failed.";
return nullptr;
}
memset(batch_space_param, 0, sizeof(BatchToSpaceParameter));
batch_space_param->op_parameter_.type_ = primitive->Type();
auto param = reinterpret_cast<mindspore::lite::BatchToSpace *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
auto block_shape = param->GetBlockShape();
if (block_shape.size() != BATCH_TO_SPACE_BLOCK_SHAPE_SIZE) {
MS_LOG(ERROR) << "batch_to_space blockShape size should be " << BATCH_TO_SPACE_BLOCK_SHAPE_SIZE;
free(batch_space_param);
return nullptr;
}
auto crops = param->GetCrops();
if (crops.size() != BATCH_TO_SPACE_CROPS_SIZE) {
MS_LOG(ERROR) << "batch_to_space crops size should be " << BATCH_TO_SPACE_CROPS_SIZE;
free(batch_space_param);
return nullptr;
}
for (int i = 0; i < BATCH_TO_SPACE_BLOCK_SHAPE_SIZE; ++i) {
batch_space_param->block_shape_[i] = block_shape[i];
}
for (int i = 0; i < BATCH_TO_SPACE_CROPS_SIZE; ++i) {
batch_space_param->crops_[i] = crops[i];
}
return reinterpret_cast<OpParameter *>(batch_space_param);
}
Registry BatchToSpaceParameterRegistry(schema::PrimitiveType_BatchToSpace, PopulateBatchToSpaceParameter);
Registry BatchToSpaceNDParameterRegistry(schema::PrimitiveType_BatchToSpaceND, PopulateBatchToSpaceParameter);
namespace {
constexpr int kBatchToSpaceOutputNum = 1;
constexpr int kBatchToSpaceInputNum = 1;

@ -16,8 +16,10 @@
#include "src/ops/bias_add.h"
#include <memory>
#include "nnacl/arithmetic_common.h"
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif
namespace mindspore {
namespace lite {
@ -84,18 +86,5 @@ PrimitiveC *BiasAddCreator(const schema::Primitive *primitive) { return Primitiv
Registry BiasAddRegistry(schema::PrimitiveType_BiasAdd, BiasAddCreator);
#endif
OpParameter *PopulateBiasAddParameter(const mindspore::lite::PrimitiveC *primitive) {
ArithmeticParameter *arithmetic_param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
if (arithmetic_param == nullptr) {
MS_LOG(ERROR) << "malloc ArithmeticParameter failed.";
return nullptr;
}
memset(arithmetic_param, 0, sizeof(ArithmeticParameter));
arithmetic_param->op_parameter_.type_ = primitive->Type();
return reinterpret_cast<OpParameter *>(arithmetic_param);
}
Registry BiasAddParameterRegistry(schema::PrimitiveType_BiasAdd, PopulateBiasAddParameter);
} // namespace lite
} // namespace mindspore

@ -16,7 +16,9 @@
#include "src/ops/bias_grad.h"
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif
namespace mindspore {
namespace lite {

@ -16,7 +16,9 @@
#include "src/ops/bn_grad.h"
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif
namespace mindspore {
namespace lite {

@ -16,8 +16,9 @@
#include "src/ops/broadcast_to.h"
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#include "nnacl/fp32/broadcast_to.h"
#endif
namespace mindspore {
namespace lite {
@ -59,26 +60,6 @@ PrimitiveC *BroadcastToCreator(const schema::Primitive *primitive) {
Registry BroadcastToRegistry(schema::PrimitiveType_BroadcastTo, BroadcastToCreator);
#endif
OpParameter *PopulateBroadcastToParameter(const mindspore::lite::PrimitiveC *primitive) {
BroadcastToParameter *broadcast_param =
reinterpret_cast<BroadcastToParameter *>(malloc(sizeof(BroadcastToParameter)));
if (broadcast_param == nullptr) {
MS_LOG(ERROR) << "malloc BroadcastToParameter failed.";
return nullptr;
}
memset(broadcast_param, 0, sizeof(BroadcastToParameter));
auto param = reinterpret_cast<mindspore::lite::BroadcastTo *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
broadcast_param->op_parameter_.type_ = primitive->Type();
auto dst_shape = param->GetDstShape();
broadcast_param->shape_size_ = dst_shape.size();
for (size_t i = 0; i < broadcast_param->shape_size_; ++i) {
broadcast_param->shape_[i] = dst_shape[i];
}
return reinterpret_cast<OpParameter *>(broadcast_param);
}
Registry BroadcastToParameterRegistry(schema::PrimitiveType_BroadcastTo, PopulateBroadcastToParameter);
namespace {
constexpr int kBroadcastToInputNum = 1;
constexpr int kBroadcastToOutputNum = 1;

@ -16,8 +16,9 @@
#include "src/ops/cast.h"
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#include "nnacl/fp32/cast.h"
#endif
namespace mindspore {
namespace lite {
@ -82,22 +83,6 @@ PrimitiveC *CastCreator(const schema::Primitive *primitive) { return PrimitiveC:
Registry CastRegistry(schema::PrimitiveType_Cast, CastCreator);
#endif
OpParameter *PopulateCastParameter(const mindspore::lite::PrimitiveC *primitive) {
CastParameter *cast_param = reinterpret_cast<CastParameter *>(malloc(sizeof(CastParameter)));
if (cast_param == nullptr) {
MS_LOG(ERROR) << "malloc CastParameter failed.";
return nullptr;
}
memset(cast_param, 0, sizeof(CastParameter));
cast_param->op_parameter_.type_ = primitive->Type();
auto param = reinterpret_cast<mindspore::lite::Cast *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
cast_param->src_type_ = param->GetSrcT();
cast_param->dst_type_ = param->GetDstT();
return reinterpret_cast<OpParameter *>(cast_param);
}
Registry CastParameterRegistry(schema::PrimitiveType_Cast, PopulateCastParameter);
int Cast::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
MS_ASSERT(this->primitive_ != nullptr);
auto input = inputs_.front();

@ -21,7 +21,6 @@
#include <set>
#include <cmath>
#include "src/ops/arithmetic_self.h"
#include "src/ops/ops_register.h"
namespace mindspore {
namespace lite {

@ -16,7 +16,9 @@
#include "src/ops/clip.h"
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif
#include "nnacl/clip.h"
namespace mindspore {
@ -48,21 +50,6 @@ float Clip::GetMin() const { return this->primitive_->value_as_Clip()->min(); }
PrimitiveC *ClipCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Clip>(primitive); }
Registry ClipRegistry(schema::PrimitiveType_Clip, ClipCreator);
#endif
OpParameter *PopulateClipParameter(const mindspore::lite::PrimitiveC *primitive) {
ClipParameter *act_param = reinterpret_cast<ClipParameter *>(malloc(sizeof(ClipParameter)));
if (act_param == nullptr) {
MS_LOG(ERROR) << "malloc ClipParameter failed.";
return nullptr;
}
memset(act_param, 0, sizeof(ClipParameter));
act_param->op_parameter_.type_ = primitive->Type();
auto activation = reinterpret_cast<mindspore::lite::Clip *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
act_param->min_val_ = activation->GetMin();
act_param->max_val_ = activation->GetMax();
return reinterpret_cast<OpParameter *>(act_param);
}
Registry ClipParameterRegistry(schema::PrimitiveType_Clip, PopulateClipParameter);
} // namespace lite
} // namespace mindspore

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save