!4919 Change Primitive to PrimitiveC

Merge pull request !4919 from yeyunpeng2020/primitve_1
pull/4919/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 4f928a4f7e

@ -20,7 +20,7 @@
#include <string>
#include <vector>
#include <memory>
#include "schema/model_generated.h"
#include "src/ops/primitive_c.h"
namespace mindspore {
#define MS_API __attribute__((visibility("default")))

@ -32,11 +32,10 @@ set(ANF_SRC
${ANF_SRC}
${CMAKE_CURRENT_SOURCE_DIR}/ir/meta_tensor_extends.cc
)
add_library(mindspore-lite SHARED ${LITE_SRC} ${ANF_SRC})
file(GLOB_RECURSE C_OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/ops/*.cc)
add_library(mindspore-lite SHARED ${LITE_SRC} ${ANF_SRC} ${C_OPS_SRC})
target_link_libraries(mindspore-lite
cpu_kernel_mid_
c_ops_mid
)
add_subdirectory(runtime/kernel/arm)

@ -1,40 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ir/primitive_t_value.h"
namespace mindspore::lite {
// Builds a schema-level Return primitive and hands it back wrapped in a
// PrimitiveTValue. The wrapper does not own the PrimitiveT (see header);
// the raw allocations here are freed by whoever consumes the schema object.
std::shared_ptr<PrimitiveTValue> GetReturnPrim() {
  auto *prim = new schema::PrimitiveT();
  prim->value.type = schema::PrimitiveType_Return;
  prim->value.value = new schema::ReturnT();
  return std::make_shared<PrimitiveTValue>(prim);
}
// Builds a schema-level MakeTuple primitive wrapped in a PrimitiveTValue.
// Ownership of the raw PrimitiveT stays with the caller/consumer, matching
// the non-owning contract documented on PrimitiveTValue.
std::shared_ptr<PrimitiveTValue> GetMakeTuplePrim() {
  auto *prim = new schema::PrimitiveT();
  prim->value.type = schema::PrimitiveType_MakeTuple;
  prim->value.value = new schema::MakeTupleT();
  return std::make_shared<PrimitiveTValue>(prim);
}
// Builds a schema-level TupleGetItem primitive wrapped in a PrimitiveTValue.
// As with the other factories in this file, the wrapper does not free the
// underlying PrimitiveT.
std::shared_ptr<PrimitiveTValue> GetTupleGetItemPrim() {
  auto *prim = new schema::PrimitiveT();
  prim->value.type = schema::PrimitiveType_TupleGetItem;
  prim->value.value = new schema::TupleGetItemT();
  return std::make_shared<PrimitiveTValue>(prim);
}
} // namespace mindspore::lite

@ -1,91 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_ANF_IMPORTER_PRIMITIVET_H_
#define MINDSPORE_LITE_SRC_ANF_IMPORTER_PRIMITIVET_H_
#include <vector>
#include <memory>
#include "schema/inner/model_generated.h"
#include "ir/value.h"
namespace mindspore::lite {
// Value wrapper around a flatbuffers object-API primitive (schema::PrimitiveT),
// carrying per-tensor quantization parameters alongside the op definition.
// NOTE: the wrapper does NOT own `primitive`; whoever allocated it frees it.
class PrimitiveTValue : public Value {
 public:
  explicit PrimitiveTValue(schema::PrimitiveT *primt) : primitive(primt) {}
  // not responsible to free primitive, the one created the dynamic memory is responsible to free it.
  ~PrimitiveTValue() override = default;
  MS_DECLARE_PARENT(PrimitiveTValue, Value)
  schema::PrimitiveT *GetPrimitiveT() const { return this->primitive; }
  void SetPrimitiveT(schema::PrimitiveT *primIn) { this->primitive = primIn; }
  // Two wrappers compare equal when both wrap primitives of the same op type.
  bool operator==(const Value &rhs) const override {
    if (!rhs.isa<PrimitiveTValue>()) {
      return false;
    }
    // Bind by reference: a plain `auto` here would deep-copy the whole object
    // (including both quant-param vectors) just to read one enum field.
    const auto &other_prim = static_cast<const PrimitiveTValue &>(rhs);
    if (this->primitive == nullptr || other_prim.primitive == nullptr) {
      // Guard against dereferencing a null wrapped primitive.
      return this->primitive == other_prim.primitive;
    }
    return this->primitive->value.type == other_prim.primitive->value.type;
  }
  void SetInputQuantParam(const std::vector<std::vector<schema::QuantParamT>> &input_quant_param) {
    this->input_quant_param_ = input_quant_param;
  }
  void SetOutputQuantParam(const std::vector<std::vector<schema::QuantParamT>> &output_quant_param) {
    this->output_quant_param_ = output_quant_param;
  }
  void ClearInputOutputQuantParam() {
    input_quant_param_.clear();
    output_quant_param_.clear();
  }
  // Sink parameter: move the vector into place instead of copying it.
  void AddInputQuantParam(std::vector<schema::QuantParamT> quant_param) {
    this->input_quant_param_.emplace_back(std::move(quant_param));
  }
  std::vector<std::vector<schema::QuantParamT>> GetInputQuantParams() const { return input_quant_param_; }
  void AddOutputQuantParam(std::vector<schema::QuantParamT> quant_param) {
    this->output_quant_param_.emplace_back(std::move(quant_param));
  }
  std::vector<std::vector<schema::QuantParamT>> GetOutputQuantParams() const { return output_quant_param_; }
  void SetQuantType(schema::QuantType quant_type) { this->quant_type_ = quant_type; }
  schema::QuantType GetQuantType() const { return quant_type_; }

 protected:
  schema::PrimitiveT *primitive = nullptr;                            // non-owning
  std::vector<std::vector<schema::QuantParamT>> input_quant_param_;   // per input tensor
  std::vector<std::vector<schema::QuantParamT>> output_quant_param_;  // per output tensor
  schema::QuantType quant_type_{schema::QuantType_QUANT_NONE};
};
std::shared_ptr<PrimitiveTValue> GetReturnPrim();
std::shared_ptr<PrimitiveTValue> GetMakeTuplePrim();
std::shared_ptr<PrimitiveTValue> GetTupleGetItemPrim();
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_SRC_ANF_IMPORTER_PRIMITIVET_H_

@ -21,11 +21,11 @@
#ifdef ENABLE_ARM
#include <arm_neon.h>
#endif
#include "src/ops/primitive_c.h"
#include "src/runtime/kernel/arm/nnacl/op_base.h"
#include "include/context.h"
#include "src/ir/tensor.h"
#include "include/errorcode.h"
#include "src/ops/primitive_c.h"
#ifdef ENABLE_FP16
using FLOAT_t = float16_t;

@ -14,9 +14,9 @@
* limitations under the License.
*/
#include "src/lite_session.h"
#include <vector>
#include "include/errorcode.h"
#include "src/lite_session.h"
#include "utils/log_adapter.h"
#include "src/scheduler.h"
#include "src/runtime/runtime_api.h"
@ -76,6 +76,7 @@ int LiteSession::ConvertTensors(const lite::Model *model) {
this->tensors_.emplace_back(dstTensor);
}
return RET_OK;
}

@ -21,11 +21,11 @@
#include <vector>
#include <string>
#include <unordered_map>
#include "src/lite_kernel.h"
#include "include/ms_tensor.h"
#include "include/lite_session.h"
#include "include/model.h"
#include "include/context.h"
#include "src/lite_kernel.h"
#include "schema/model_generated.h"
#include "src/executor.h"

@ -14,6 +14,7 @@
* limitations under the License.
*/
#include "include/model.h"
#include "src/ops/unique.h"
#include "src/ops/space_to_batch.h"
#include "src/ops/conv2d.h"
@ -106,8 +107,6 @@
#include "src/ops/squared_difference.h"
#include "src/ops/ceil.h"
#include "src/ops/round.h"
#include "src/ops/primitive_c.h"
#include "include/model.h"
#include "utils/log_adapter.h"
namespace mindspore::lite {

@ -1,3 +0,0 @@
# Collect every op implementation under this directory and build them as a
# single object library that the main mindspore-lite target links against.
file(GLOB_RECURSE C_OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/*.cc)
add_library(c_ops_mid OBJECT ${C_OPS_SRC})

@ -32,7 +32,10 @@ namespace mindspore {
namespace lite {
// Element-wise absolute-value op. Like the other unary ops it derives from
// ArithmeticSelf and only forwards the wrapped primitive to the base class.
class Abs : public ArithmeticSelf {
 public:
#ifdef PRIMITIVE_WRITEABLE
  // Converter/training build: wraps the mutable flatbuffers object API type.
  explicit Abs(schema::PrimitiveT *primitive) : ArithmeticSelf(primitive) {}
#endif
  // Inference build: wraps the read-only flatbuffers table.
  // (The stale OriginPrimitive* overload from before the PrimitiveC refactor
  // is removed — that typedef no longer exists after this change.)
  explicit Abs(schema::Primitive *primitive) : ArithmeticSelf(primitive) {}
};
} // namespace lite
} // namespace mindspore

@ -19,16 +19,16 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int Activation::GetType() const { return this->primitive->value.AsActivation()->type; }
float Activation::GetAlpha() const { return this->primitive->value.AsActivation()->alpha; }
int Activation::GetType() const { return this->primitive_->value.AsActivation()->type; }
float Activation::GetAlpha() const { return this->primitive_->value.AsActivation()->alpha; }
void Activation::SetType(int type) { this->primitive->value.AsActivation()->type = (schema::ActivationType)type; }
void Activation::SetAlpha(float alpha) { this->primitive->value.AsActivation()->alpha = alpha; }
void Activation::SetType(int type) { this->primitive_->value.AsActivation()->type = (schema::ActivationType)type; }
void Activation::SetAlpha(float alpha) { this->primitive_->value.AsActivation()->alpha = alpha; }
#else
int Activation::GetType() const { return this->primitive->value_as_Activation()->type(); }
float Activation::GetAlpha() const { return this->primitive->value_as_Activation()->alpha(); }
int Activation::GetType() const { return this->primitive_->value_as_Activation()->type(); }
float Activation::GetAlpha() const { return this->primitive_->value_as_Activation()->alpha(); }
void Activation::SetType(int type) {}
void Activation::SetAlpha(float alpha) {}

@ -13,26 +13,23 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_C_OPS_ACTIVATION_H_
#define LITE_MINDSPORE_LITE_C_OPS_ACTIVATION_H_
#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif
#ifndef LITE_MINDSPORE_LITE_C_OPS_ACTIVATION_H_
#define LITE_MINDSPORE_LITE_C_OPS_ACTIVATION_H_
namespace mindspore {
namespace lite {
class Activation : public PrimitiveC {
public:
explicit Activation(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit Activation(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#endif
explicit Activation(schema::Primitive *primitive) : PrimitiveC(primitive) {}
int GetType() const;
float GetAlpha() const;
void SetType(int type);

@ -19,15 +19,15 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int ActivationGrad::GetType() const { return this->primitive->value.AsActivationGrad()->type; }
int ActivationGrad::GetType() const { return this->primitive_->value.AsActivationGrad()->type; }
void ActivationGrad::SetType(int type) {
this->primitive->value.AsActivationGrad()->type = (schema::ActivationGradType)type;
this->primitive_->value.AsActivationGrad()->type = (schema::ActivationGradType)type;
}
#else
int ActivationGrad::GetType() const { return this->primitive->value_as_ActivationGrad()->type(); }
int ActivationGrad::GetType() const { return this->primitive_->value_as_ActivationGrad()->type(); }
void ActivationGrad::SetType(int type) {}
#endif

@ -14,25 +14,23 @@
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_C_OPS_ACTIVATION_GRAD_H_
#define LITE_MINDSPORE_LITE_C_OPS_ACTIVATION_GRAD_H_
#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif
#ifndef LITE_MINDSPORE_LITE_C_OPS_ACTIVATION_GRAD_H_
#define LITE_MINDSPORE_LITE_C_OPS_ACTIVATION_GRAD_H_
namespace mindspore {
namespace lite {
class ActivationGrad : public PrimitiveC {
public:
explicit ActivationGrad(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit ActivationGrad(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#endif
explicit ActivationGrad(schema::Primitive *primitive) : PrimitiveC(primitive) {}
int GetType() const;
void SetType(int type);

@ -19,15 +19,15 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int Add::GetActivationType() const { return this->primitive->value.AsAdd()->activationType; }
int Add::GetActivationType() const { return this->primitive_->value.AsAdd()->activationType; }
void Add::SetActivationType(int activation_type) {
this->primitive->value.AsAdd()->activationType = (schema::ActivationType)activation_type;
this->primitive_->value.AsAdd()->activationType = (schema::ActivationType)activation_type;
}
#else
int Add::GetActivationType() const { return this->primitive->value_as_Add()->activationType(); }
int Add::GetActivationType() const { return this->primitive_->value_as_Add()->activationType(); }
void Add::SetActivationType(int activation_type) {}
#endif

@ -14,6 +14,9 @@
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_C_OPS_ADD_H_
#define LITE_MINDSPORE_LITE_C_OPS_ADD_H_
#include <vector>
#include <set>
#include <cmath>
@ -24,14 +27,15 @@
#else
#include "schema/model_generated.h"
#endif
#ifndef LITE_MINDSPORE_LITE_C_OPS_ADD_H_
#define LITE_MINDSPORE_LITE_C_OPS_ADD_H_
namespace mindspore {
namespace lite {
class Add : public Arithmetic {
public:
explicit Add(OriginPrimitive *primitive) : Arithmetic(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit Add(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
#endif
explicit Add(schema::Primitive *primitive) : Arithmetic(primitive) {}
int GetActivationType() const;
void SetActivationType(int activation_type);

@ -19,13 +19,13 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int AddN::GetN() const { return this->primitive->value.AsAddN()->N; }
int AddN::GetN() const { return this->primitive_->value.AsAddN()->N; }
void AddN::SetN(int n) { this->primitive->value.AsAddN()->N = n; }
void AddN::SetN(int n) { this->primitive_->value.AsAddN()->N = n; }
#else
int AddN::GetN() const { return this->primitive->value_as_AddN()->N(); }
int AddN::GetN() const { return this->primitive_->value_as_AddN()->N(); }
void AddN::SetN(int n) {}
#endif
@ -34,7 +34,7 @@ namespace {
constexpr int kLeastInputNum = 2;
}
int AddN::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Tensor *> outputs) {
MS_ASSERT(this->primitive != nullptr);
MS_ASSERT(this->primitive_ != nullptr);
auto input = inputs.front();
MS_ASSERT(input != nullptr);
auto output = outputs.front();

@ -14,25 +14,23 @@
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_C_OPS_ADD_N_H_
#define LITE_MINDSPORE_LITE_C_OPS_ADD_N_H_
#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif
#ifndef LITE_MINDSPORE_LITE_C_OPS_ADD_N_H_
#define LITE_MINDSPORE_LITE_C_OPS_ADD_N_H_
namespace mindspore {
namespace lite {
class AddN : public PrimitiveC {
public:
explicit AddN(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit AddN(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#endif
explicit AddN(schema::Primitive *primitive) : PrimitiveC(primitive) {}
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int GetN() const;
void SetN(int n);

@ -19,25 +19,25 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int ArgMax::GetAxis() const { return this->primitive->value.AsArgMax()->axis; }
bool ArgMax::GetOutMaxValue() const { return this->primitive->value.AsArgMax()->outMaxValue; }
int ArgMax::GetTopK() const { return this->primitive->value.AsArgMax()->topK; }
bool ArgMax::GetKeepDims() const { return this->primitive->value.AsArgMax()->keepDims; }
int ArgMax::GetAxisType() const { return this->primitive->value.AsArgMax()->axisType; }
int ArgMax::GetAxis() const { return this->primitive_->value.AsArgMax()->axis; }
bool ArgMax::GetOutMaxValue() const { return this->primitive_->value.AsArgMax()->outMaxValue; }
int ArgMax::GetTopK() const { return this->primitive_->value.AsArgMax()->topK; }
bool ArgMax::GetKeepDims() const { return this->primitive_->value.AsArgMax()->keepDims; }
int ArgMax::GetAxisType() const { return this->primitive_->value.AsArgMax()->axisType; }
void ArgMax::SetAxis(int axis) { this->primitive->value.AsArgMax()->axis = axis; }
void ArgMax::SetOutMaxValue(bool out_max_value) { this->primitive->value.AsArgMax()->outMaxValue = out_max_value; }
void ArgMax::SetTopK(int top_k) { this->primitive->value.AsArgMax()->topK = top_k; }
void ArgMax::SetKeepDims(bool keep_dims) { this->primitive->value.AsArgMax()->keepDims = keep_dims; }
void ArgMax::SetAxisType(int axis_type) { this->primitive->value.AsArgMax()->axisType = axis_type; }
void ArgMax::SetAxis(int axis) { this->primitive_->value.AsArgMax()->axis = axis; }
void ArgMax::SetOutMaxValue(bool out_max_value) { this->primitive_->value.AsArgMax()->outMaxValue = out_max_value; }
void ArgMax::SetTopK(int top_k) { this->primitive_->value.AsArgMax()->topK = top_k; }
void ArgMax::SetKeepDims(bool keep_dims) { this->primitive_->value.AsArgMax()->keepDims = keep_dims; }
void ArgMax::SetAxisType(int axis_type) { this->primitive_->value.AsArgMax()->axisType = axis_type; }
#else
int ArgMax::GetAxis() const { return this->primitive->value_as_ArgMax()->axis(); }
bool ArgMax::GetOutMaxValue() const { return this->primitive->value_as_ArgMax()->outMaxValue(); }
int ArgMax::GetTopK() const { return this->primitive->value_as_ArgMax()->topK(); }
bool ArgMax::GetKeepDims() const { return this->primitive->value_as_ArgMax()->keepDims(); }
int ArgMax::GetAxisType() const { return this->primitive->value_as_ArgMax()->axisType(); }
int ArgMax::GetAxis() const { return this->primitive_->value_as_ArgMax()->axis(); }
bool ArgMax::GetOutMaxValue() const { return this->primitive_->value_as_ArgMax()->outMaxValue(); }
int ArgMax::GetTopK() const { return this->primitive_->value_as_ArgMax()->topK(); }
bool ArgMax::GetKeepDims() const { return this->primitive_->value_as_ArgMax()->keepDims(); }
int ArgMax::GetAxisType() const { return this->primitive_->value_as_ArgMax()->axisType(); }
void ArgMax::SetAxis(int axis) {}
void ArgMax::SetOutMaxValue(bool out_max_value) {}
@ -47,7 +47,7 @@ void ArgMax::SetAxisType(int axis_type) {}
#endif
int ArgMax::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
MS_ASSERT(this->primitive_ != nullptr);
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.front();

@ -14,25 +14,23 @@
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_C_OPS_ARG_MAX_H_
#define LITE_MINDSPORE_LITE_C_OPS_ARG_MAX_H_
#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif
#ifndef LITE_MINDSPORE_LITE_C_OPS_ARG_MAX_H_
#define LITE_MINDSPORE_LITE_C_OPS_ARG_MAX_H_
namespace mindspore {
namespace lite {
class ArgMax : public PrimitiveC {
public:
explicit ArgMax(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit ArgMax(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#endif
explicit ArgMax(schema::Primitive *primitive) : PrimitiveC(primitive) {}
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int GetAxis() const;

@ -19,25 +19,25 @@
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int ArgMin::GetAxis() const { return this->primitive->value.AsArgMin()->axis; }
bool ArgMin::GetOutMaxValue() const { return this->primitive->value.AsArgMin()->outMaxValue; }
int ArgMin::GetTopK() const { return this->primitive->value.AsArgMin()->topK; }
bool ArgMin::GetKeepDims() const { return this->primitive->value.AsArgMin()->keepDims; }
int ArgMin::GetAxisType() const { return this->primitive->value.AsArgMin()->axisType; }
int ArgMin::GetAxis() const { return this->primitive_->value.AsArgMin()->axis; }
bool ArgMin::GetOutMaxValue() const { return this->primitive_->value.AsArgMin()->outMaxValue; }
int ArgMin::GetTopK() const { return this->primitive_->value.AsArgMin()->topK; }
bool ArgMin::GetKeepDims() const { return this->primitive_->value.AsArgMin()->keepDims; }
int ArgMin::GetAxisType() const { return this->primitive_->value.AsArgMin()->axisType; }
void ArgMin::SetAxis(int axis) { this->primitive->value.AsArgMin()->axis = axis; }
void ArgMin::SetOutMaxValue(bool out_max_value) { this->primitive->value.AsArgMin()->outMaxValue = out_max_value; }
void ArgMin::SetTopK(int top_k) { this->primitive->value.AsArgMin()->topK = top_k; }
void ArgMin::SetKeepDims(bool keep_dims) { this->primitive->value.AsArgMin()->keepDims = keep_dims; }
void ArgMin::SetAxisType(int axis_type) { this->primitive->value.AsArgMin()->axisType = axis_type; }
void ArgMin::SetAxis(int axis) { this->primitive_->value.AsArgMin()->axis = axis; }
void ArgMin::SetOutMaxValue(bool out_max_value) { this->primitive_->value.AsArgMin()->outMaxValue = out_max_value; }
void ArgMin::SetTopK(int top_k) { this->primitive_->value.AsArgMin()->topK = top_k; }
void ArgMin::SetKeepDims(bool keep_dims) { this->primitive_->value.AsArgMin()->keepDims = keep_dims; }
void ArgMin::SetAxisType(int axis_type) { this->primitive_->value.AsArgMin()->axisType = axis_type; }
#else
int ArgMin::GetAxis() const { return this->primitive->value_as_ArgMin()->axis(); }
bool ArgMin::GetOutMaxValue() const { return this->primitive->value_as_ArgMin()->outMaxValue(); }
int ArgMin::GetTopK() const { return this->primitive->value_as_ArgMin()->topK(); }
bool ArgMin::GetKeepDims() const { return this->primitive->value_as_ArgMin()->keepDims(); }
int ArgMin::GetAxisType() const { return this->primitive->value_as_ArgMin()->axisType(); }
int ArgMin::GetAxis() const { return this->primitive_->value_as_ArgMin()->axis(); }
bool ArgMin::GetOutMaxValue() const { return this->primitive_->value_as_ArgMin()->outMaxValue(); }
int ArgMin::GetTopK() const { return this->primitive_->value_as_ArgMin()->topK(); }
bool ArgMin::GetKeepDims() const { return this->primitive_->value_as_ArgMin()->keepDims(); }
int ArgMin::GetAxisType() const { return this->primitive_->value_as_ArgMin()->axisType(); }
void ArgMin::SetAxis(int axis) {}
void ArgMin::SetOutMaxValue(bool out_max_value) {}
@ -47,7 +47,7 @@ void ArgMin::SetAxisType(int axis_type) {}
#endif
int ArgMin::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
MS_ASSERT(this->primitive_ != nullptr);
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.front();

@ -14,25 +14,23 @@
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_C_OPS_ARG_MIN_H_
#define LITE_MINDSPORE_LITE_C_OPS_ARG_MIN_H_
#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif
#ifndef LITE_MINDSPORE_LITE_C_OPS_ARG_MIN_H_
#define LITE_MINDSPORE_LITE_C_OPS_ARG_MIN_H_
namespace mindspore {
namespace lite {
class ArgMin : public PrimitiveC {
public:
explicit ArgMin(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit ArgMin(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#endif
explicit ArgMin(schema::Primitive *primitive) : PrimitiveC(primitive) {}
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
int GetAxis() const;

@ -22,7 +22,7 @@
namespace mindspore {
namespace lite {
int Arithmetic::InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
MS_ASSERT(this->primitive_ != nullptr);
if (inputs_.size() != kDoubleNum) {
MS_LOG(ERROR) << "The number of input must be " << kDoubleNum;
return RET_INPUT_TENSOR_ERROR;

@ -14,25 +14,23 @@
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_C_OPS_ARITHMETIC_H_
#define LITE_MINDSPORE_LITE_C_OPS_ARITHMETIC_H_
#include <vector>
#include <set>
#include <cmath>
#include "ir/dtype/type_id.h"
#include "src/ops/primitive_c.h"
#ifdef PRIMITIVE_WRITEABLE
#include "schema/inner/model_generated.h"
#else
#include "schema/model_generated.h"
#endif
#ifndef LITE_MINDSPORE_LITE_C_OPS_ARITHMETIC_H_
#define LITE_MINDSPORE_LITE_C_OPS_ARITHMETIC_H_
namespace mindspore {
namespace lite {
class Arithmetic : public PrimitiveC {
public:
explicit Arithmetic(OriginPrimitive *primitive) : PrimitiveC(primitive) {}
#ifdef PRIMITIVE_WRITEABLE
explicit Arithmetic(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#endif
explicit Arithmetic(schema::Primitive *primitive) : PrimitiveC(primitive) {}
int InferShape(std::vector<lite::tensor::Tensor *> inputs_, std::vector<lite::tensor::Tensor *> outputs_) override;
bool Broadcasting() { return this->broadcasting_; }

@ -22,7 +22,7 @@ namespace mindspore {
namespace lite {
int ArithmeticSelf::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
MS_ASSERT(this->primitive_ != nullptr);
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.front();

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save