!7708 [MSLITE][Develop] Refactor arithmetic populate and add eltwise int8 kernel

Merge pull request !7708 from sunsuodong/eltwise_int8
pull/7708/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit ee1865605d

@ -39,11 +39,11 @@ class Arithmetic : public PrimitiveC {
} }
#endif #endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
bool Broadcasting() { return this->broadcasting_; } bool Broadcasting() const { return this->broadcasting_; }
int NDims() { return this->ndim_; } int NDims() const { return this->ndim_; }
std::vector<int> InShape0() { return this->in_shape0_; } std::vector<int> InShape0() const { return this->in_shape0_; }
std::vector<int> InShape1() { return this->in_shape1_; } std::vector<int> InShape1() const { return this->in_shape1_; }
std::vector<int> OutputShape() { return this->out_shape_; } std::vector<int> OutputShape() const { return this->out_shape_; }
protected: protected:
bool broadcasting_ = false; bool broadcasting_ = false;

@ -21,20 +21,20 @@
#include <set> #include <set>
#include <cmath> #include <cmath>
#include "src/ops/primitive_c.h" #include "src/ops/primitive_c.h"
#include "src/ops/arithmetic.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
class Eltwise : public PrimitiveC { class Eltwise : public Arithmetic {
public: public:
#ifdef PRIMITIVE_WRITEABLE #ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Eltwise, PrimitiveC); MS_DECLARE_PARENT(Eltwise, Arithmetic);
Eltwise() = default; Eltwise() = default;
explicit Eltwise(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} explicit Eltwise(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
void SetMode(int mode); void SetMode(int mode);
#else #else
Eltwise() = default; Eltwise() = default;
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif #endif
int GetMode() const; int GetMode() const;

@ -1,46 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/add.h"
#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"
#include "nnacl/arithmetic_common.h"
namespace mindspore {
namespace lite {
OpParameter *PopulateAddParameter(const mindspore::lite::PrimitiveC *primitive) {
ArithmeticParameter *arithmetic_param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
if (arithmetic_param == nullptr) {
MS_LOG(ERROR) << "malloc ArithmeticParameter failed.";
return nullptr;
}
memset(arithmetic_param, 0, sizeof(ArithmeticParameter));
arithmetic_param->op_parameter_.type_ = primitive->Type();
arithmetic_param->broadcasting_ = ((lite::Arithmetic *)primitive)->Broadcasting();
arithmetic_param->ndim_ = ((lite::Arithmetic *)primitive)->NDims();
arithmetic_param->activation_type_ =
reinterpret_cast<mindspore::lite::Add *>(const_cast<mindspore::lite::PrimitiveC *>(primitive))->GetActivationType();
auto tmp_shape = ((lite::Arithmetic *)primitive)->InShape0();
memcpy(arithmetic_param->in_shape0_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
tmp_shape = ((lite::Arithmetic *)primitive)->InShape1();
memcpy(arithmetic_param->in_shape1_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
tmp_shape = ((lite::Arithmetic *)primitive)->OutputShape();
memcpy(arithmetic_param->out_shape_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
return reinterpret_cast<OpParameter *>(arithmetic_param);
}
Registry AddParameterRegistry(schema::PrimitiveType_Add, PopulateAddParameter);
} // namespace lite
} // namespace mindspore

@ -13,8 +13,13 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include "src/ops/arithmetic.h" #include "src/ops/arithmetic.h"
#include "src/ops/add.h"
#include "src/ops/sub.h"
#include "src/ops/mul.h"
#include "src/ops/div.h"
#include "src/ops/eltwise.h"
#include "src/ops/greater_equal.h"
#include "src/common/log_adapter.h" #include "src/common/log_adapter.h"
#include "src/tensor.h" #include "src/tensor.h"
#include "src/ops/primitive_c.h" #include "src/ops/primitive_c.h"
@ -22,27 +27,98 @@
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
ArithmeticParameter *PopulateArithmeticCommonPara(const mindspore::lite::PrimitiveC *primitive) {
ArithmeticParameter *param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
if (param == nullptr) {
MS_LOG(ERROR) << "malloc ArithmeticParameter failed.";
return nullptr;
}
memset(param, 0, sizeof(ArithmeticParameter));
param->op_parameter_.type_ = primitive->Type();
param->broadcasting_ = reinterpret_cast<const lite::Arithmetic *>(primitive)->Broadcasting();
param->ndim_ = reinterpret_cast<const lite::Arithmetic *>(primitive)->NDims();
param->activation_type_ = 0;
auto tmp_shape = reinterpret_cast<const lite::Arithmetic *>(primitive)->InShape0();
memcpy(param->in_shape0_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
tmp_shape = reinterpret_cast<const lite::Arithmetic *>(primitive)->InShape1();
memcpy(param->in_shape1_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
tmp_shape = reinterpret_cast<const lite::Arithmetic *>(primitive)->OutputShape();
memcpy(param->out_shape_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
return param;
}
OpParameter *PopulateArithmetic(const mindspore::lite::PrimitiveC *primitive) { OpParameter *PopulateArithmetic(const mindspore::lite::PrimitiveC *primitive) {
ArithmeticParameter *arithmetic_param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter))); ArithmeticParameter *param = PopulateArithmeticCommonPara(primitive);
if (arithmetic_param == nullptr) { if (param == nullptr) {
MS_LOG(ERROR) << "malloc ArithmeticParameter failed."; MS_LOG(ERROR) << "PopulateArithmeticCommonPara failed.";
return nullptr;
}
return reinterpret_cast<OpParameter *>(param);
}
OpParameter *PopulateAddParameter(const mindspore::lite::PrimitiveC *primitive) {
ArithmeticParameter *param = PopulateArithmeticCommonPara(primitive);
if (param == nullptr) {
MS_LOG(ERROR) << "PopulateArithmeticCommonPara failed.";
return nullptr;
}
param->activation_type_ = reinterpret_cast<const mindspore::lite::Add *>(primitive)->GetActivationType();
return reinterpret_cast<OpParameter *>(param);
}
OpParameter *PopulateSubParameter(const mindspore::lite::PrimitiveC *primitive) {
ArithmeticParameter *param = PopulateArithmeticCommonPara(primitive);
if (param == nullptr) {
MS_LOG(ERROR) << "PopulateArithmeticCommonPara failed.";
return nullptr; return nullptr;
} }
memset(arithmetic_param, 0, sizeof(ArithmeticParameter)); param->activation_type_ = reinterpret_cast<const mindspore::lite::Sub *>(primitive)->GetActivationType();
arithmetic_param->op_parameter_.type_ = primitive->Type(); return reinterpret_cast<OpParameter *>(param);
arithmetic_param->broadcasting_ = ((lite::Arithmetic *)primitive)->Broadcasting(); }
arithmetic_param->ndim_ = ((lite::Arithmetic *)primitive)->NDims();
arithmetic_param->activation_type_ = 0; OpParameter *PopulateMulParameter(const mindspore::lite::PrimitiveC *primitive) {
ArithmeticParameter *param = PopulateArithmeticCommonPara(primitive);
if (param == nullptr) {
MS_LOG(ERROR) << "PopulateArithmeticCommonPara failed.";
return nullptr;
}
param->activation_type_ = reinterpret_cast<const mindspore::lite::Mul *>(primitive)->GetActivationType();
return reinterpret_cast<OpParameter *>(param);
}
auto tmp_shape = ((lite::Arithmetic *)primitive)->InShape0(); OpParameter *PopulateDivParameter(const mindspore::lite::PrimitiveC *primitive) {
memcpy(arithmetic_param->in_shape0_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int)); ArithmeticParameter *param = PopulateArithmeticCommonPara(primitive);
tmp_shape = ((lite::Arithmetic *)primitive)->InShape1(); if (param == nullptr) {
memcpy(arithmetic_param->in_shape1_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int)); MS_LOG(ERROR) << "PopulateArithmeticCommonPara failed.";
tmp_shape = ((lite::Arithmetic *)primitive)->OutputShape(); return nullptr;
memcpy(arithmetic_param->out_shape_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int)); }
return reinterpret_cast<OpParameter *>(arithmetic_param); param->activation_type_ = reinterpret_cast<const mindspore::lite::Div *>(primitive)->GetActivationType();
return reinterpret_cast<OpParameter *>(param);
}
OpParameter *PopulateEltwiseParameter(const mindspore::lite::PrimitiveC *primitive) {
ArithmeticParameter *param = PopulateArithmeticCommonPara(primitive);
if (param == nullptr) {
MS_LOG(ERROR) << "PopulateArithmeticCommonPara failed.";
return nullptr;
}
auto eltwise = reinterpret_cast<const mindspore::lite::Eltwise *>(primitive);
switch (eltwise->GetMode()) {
case schema::EltwiseMode_PROD:
param->op_parameter_.type_ = schema::PrimitiveType_Mul;
break;
case schema::EltwiseMode_SUM:
param->op_parameter_.type_ = schema::PrimitiveType_Add;
break;
case schema::EltwiseMode_MAXIMUM:
param->op_parameter_.type_ = schema::PrimitiveType_Maximum;
break;
default:
free(param);
return nullptr;
}
return reinterpret_cast<OpParameter *>(param);
} }
Registry RealDivParameterRegistry(schema::PrimitiveType_RealDiv, PopulateArithmetic); Registry RealDivParameterRegistry(schema::PrimitiveType_RealDiv, PopulateArithmetic);
@ -51,6 +127,7 @@ Registry ParameterRegistry(schema::PrimitiveType_LogicalOr, PopulateArithmetic);
Registry EqualParameterRegistry(schema::PrimitiveType_Equal, PopulateArithmetic); Registry EqualParameterRegistry(schema::PrimitiveType_Equal, PopulateArithmetic);
Registry LessParameterRegistry(schema::PrimitiveType_Less, PopulateArithmetic); Registry LessParameterRegistry(schema::PrimitiveType_Less, PopulateArithmetic);
Registry GreaterParameterRegistry(schema::PrimitiveType_Greater, PopulateArithmetic); Registry GreaterParameterRegistry(schema::PrimitiveType_Greater, PopulateArithmetic);
Registry GreaterEqualParameterRegistry(schema::PrimitiveType_GreaterEqual, PopulateArithmetic);
Registry NotEqualParameterRegistry(schema::PrimitiveType_NotEqual, PopulateArithmetic); Registry NotEqualParameterRegistry(schema::PrimitiveType_NotEqual, PopulateArithmetic);
Registry LessEqualParameterRegistry(schema::PrimitiveType_LessEqual, PopulateArithmetic); Registry LessEqualParameterRegistry(schema::PrimitiveType_LessEqual, PopulateArithmetic);
Registry MaximumParameterRegistry(schema::PrimitiveType_Maximum, PopulateArithmetic); Registry MaximumParameterRegistry(schema::PrimitiveType_Maximum, PopulateArithmetic);
@ -58,5 +135,10 @@ Registry MinimumParameterRegistry(schema::PrimitiveType_Minimum, PopulateArithme
Registry FloorDivParameterRegistry(schema::PrimitiveType_FloorDiv, PopulateArithmetic); Registry FloorDivParameterRegistry(schema::PrimitiveType_FloorDiv, PopulateArithmetic);
Registry FloorModParameterRegistry(schema::PrimitiveType_FloorMod, PopulateArithmetic); Registry FloorModParameterRegistry(schema::PrimitiveType_FloorMod, PopulateArithmetic);
Registry SquaredDifferenceParameterRegistry(schema::PrimitiveType_SquaredDifference, PopulateArithmetic); Registry SquaredDifferenceParameterRegistry(schema::PrimitiveType_SquaredDifference, PopulateArithmetic);
Registry AddParameterRegistry(schema::PrimitiveType_Add, PopulateAddParameter);
Registry SubParameterRegistry(schema::PrimitiveType_Sub, PopulateSubParameter);
Registry MulParameterRegistry(schema::PrimitiveType_Mul, PopulateMulParameter);
Registry DivParameterRegistry(schema::PrimitiveType_Div, PopulateDivParameter);
Registry EltwiseParameterRegistry(schema::PrimitiveType_Eltwise, PopulateEltwiseParameter);
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

@ -1,47 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/div.h"
#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"
namespace mindspore {
namespace lite {
OpParameter *PopulateDivParameter(const mindspore::lite::PrimitiveC *primitive) {
ArithmeticParameter *arithmetic_param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
if (arithmetic_param == nullptr) {
MS_LOG(ERROR) << "malloc ArithmeticParameter failed.";
return nullptr;
}
memset(arithmetic_param, 0, sizeof(ArithmeticParameter));
arithmetic_param->op_parameter_.type_ = primitive->Type();
arithmetic_param->broadcasting_ = ((lite::Arithmetic *)primitive)->Broadcasting();
arithmetic_param->ndim_ = ((lite::Arithmetic *)primitive)->NDims();
arithmetic_param->activation_type_ =
reinterpret_cast<mindspore::lite::Div *>(const_cast<mindspore::lite::PrimitiveC *>(primitive))->GetActivationType();
auto tmp_shape = ((lite::Arithmetic *)primitive)->InShape0();
memcpy(arithmetic_param->in_shape0_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
tmp_shape = ((lite::Arithmetic *)primitive)->InShape1();
memcpy(arithmetic_param->in_shape1_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
tmp_shape = ((lite::Arithmetic *)primitive)->OutputShape();
memcpy(arithmetic_param->out_shape_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
return reinterpret_cast<OpParameter *>(arithmetic_param);
}
Registry DivParameterRegistry(schema::PrimitiveType_Div, PopulateDivParameter);
} // namespace lite
} // namespace mindspore

@ -1,52 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/eltwise.h"
#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"
#include "nnacl/arithmetic_common.h"
namespace mindspore {
namespace lite {
OpParameter *PopulateEltwiseParameter(const mindspore::lite::PrimitiveC *primitive) {
ArithmeticParameter *arithmetic_param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
if (arithmetic_param == nullptr) {
MS_LOG(ERROR) << "malloc ArithmeticParameter failed.";
return nullptr;
}
memset(arithmetic_param, 0, sizeof(ArithmeticParameter));
auto eltwise = reinterpret_cast<mindspore::lite::Eltwise *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
switch (eltwise->GetMode()) {
case schema::EltwiseMode_PROD:
arithmetic_param->op_parameter_.type_ = schema::PrimitiveType_Mul;
break;
case schema::EltwiseMode_SUM:
arithmetic_param->op_parameter_.type_ = schema::PrimitiveType_Add;
break;
case schema::EltwiseMode_MAXIMUM:
arithmetic_param->op_parameter_.type_ = schema::PrimitiveType_Maximum;
break;
default:
free(arithmetic_param);
return nullptr;
}
return reinterpret_cast<OpParameter *>(arithmetic_param);
}
Registry EltwiseParameterRegistry(schema::PrimitiveType_Eltwise, PopulateEltwiseParameter);
} // namespace lite
} // namespace mindspore

@ -1,48 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/mul.h"
#include "nnacl/arithmetic_common.h"
#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"
namespace mindspore {
namespace lite {
OpParameter *PopulateMulParameter(const mindspore::lite::PrimitiveC *primitive) {
ArithmeticParameter *arithmetic_param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
if (arithmetic_param == nullptr) {
MS_LOG(ERROR) << "malloc ArithmeticParameter failed.";
return nullptr;
}
memset(arithmetic_param, 0, sizeof(ArithmeticParameter));
arithmetic_param->op_parameter_.type_ = primitive->Type();
arithmetic_param->broadcasting_ = ((lite::Arithmetic *)primitive)->Broadcasting();
arithmetic_param->ndim_ = ((lite::Arithmetic *)primitive)->NDims();
arithmetic_param->activation_type_ =
reinterpret_cast<mindspore::lite::Mul *>(const_cast<mindspore::lite::PrimitiveC *>(primitive))->GetActivationType();
auto tmp_shape = ((lite::Arithmetic *)primitive)->InShape0();
memcpy(arithmetic_param->in_shape0_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
tmp_shape = ((lite::Arithmetic *)primitive)->InShape1();
memcpy(arithmetic_param->in_shape1_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
tmp_shape = ((lite::Arithmetic *)primitive)->OutputShape();
memcpy(arithmetic_param->out_shape_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
return reinterpret_cast<OpParameter *>(arithmetic_param);
}
Registry MulParameterRegistry(schema::PrimitiveType_Mul, PopulateMulParameter);
} // namespace lite
} // namespace mindspore

@ -1,47 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/sub.h"
#include "src/ops/primitive_c.h"
#include "src/ops/populate/populate_register.h"
#include "nnacl/arithmetic_common.h"
namespace mindspore {
namespace lite {
OpParameter *PopulateSubParameter(const mindspore::lite::PrimitiveC *primitive) {
ArithmeticParameter *arithmetic_param = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
if (arithmetic_param == nullptr) {
MS_LOG(ERROR) << "malloc ArithmeticParameter failed.";
return nullptr;
}
memset(arithmetic_param, 0, sizeof(ArithmeticParameter));
arithmetic_param->op_parameter_.type_ = primitive->Type();
arithmetic_param->broadcasting_ = ((lite::Arithmetic *)primitive)->Broadcasting();
arithmetic_param->ndim_ = ((lite::Arithmetic *)primitive)->NDims();
arithmetic_param->activation_type_ =
reinterpret_cast<mindspore::lite::Sub *>(const_cast<mindspore::lite::PrimitiveC *>(primitive))->GetActivationType();
auto tmp_shape = ((lite::Arithmetic *)primitive)->InShape0();
memcpy(arithmetic_param->in_shape0_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
tmp_shape = ((lite::Arithmetic *)primitive)->InShape1();
memcpy(arithmetic_param->in_shape1_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
tmp_shape = ((lite::Arithmetic *)primitive)->OutputShape();
memcpy(arithmetic_param->out_shape_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
return reinterpret_cast<OpParameter *>(arithmetic_param);
}
Registry SubParameterRegistry(schema::PrimitiveType_Sub, PopulateSubParameter);
} // namespace lite
} // namespace mindspore

@ -15,6 +15,8 @@
*/ */
#include "src/runtime/kernel/arm/int8/arithmetic_int8.h" #include "src/runtime/kernel/arm/int8/arithmetic_int8.h"
#include "src/runtime/kernel/arm/int8/add_int8.h"
#include "src/runtime/kernel/arm/int8/mul_int8.h"
#include "nnacl/arithmetic_common.h" #include "nnacl/arithmetic_common.h"
#include "schema/model_generated.h" #include "schema/model_generated.h"
#include "src/kernel_registry.h" #include "src/kernel_registry.h"
@ -27,11 +29,14 @@ using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK; using mindspore::lite::RET_OK;
using mindspore::lite::RET_PARAM_INVALID; using mindspore::lite::RET_PARAM_INVALID;
using mindspore::schema::PrimitiveType_Add;
using mindspore::schema::PrimitiveType_Eltwise;
using mindspore::schema::PrimitiveType_Equal; using mindspore::schema::PrimitiveType_Equal;
using mindspore::schema::PrimitiveType_Greater; using mindspore::schema::PrimitiveType_Greater;
using mindspore::schema::PrimitiveType_GreaterEqual; using mindspore::schema::PrimitiveType_GreaterEqual;
using mindspore::schema::PrimitiveType_Less; using mindspore::schema::PrimitiveType_Less;
using mindspore::schema::PrimitiveType_LessEqual; using mindspore::schema::PrimitiveType_LessEqual;
using mindspore::schema::PrimitiveType_Mul;
using mindspore::schema::PrimitiveType_NotEqual; using mindspore::schema::PrimitiveType_NotEqual;
namespace mindspore::kernel { namespace mindspore::kernel {
@ -159,11 +164,15 @@ kernel::LiteKernel *CpuArithmeticInt8KernelCreator(const std::vector<lite::Tenso
const std::vector<lite::Tensor *> &outputs, OpParameter *parameter, const std::vector<lite::Tensor *> &outputs, OpParameter *parameter,
const lite::InnerContext *ctx, const kernel::KernelKey &desc, const lite::InnerContext *ctx, const kernel::KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive) { const mindspore::lite::PrimitiveC *primitive) {
if (parameter == nullptr) { kernel::LiteKernel *kernel = nullptr;
MS_LOG(ERROR) << "Input parameter is null!"; if (desc.type == PrimitiveType_Eltwise && static_cast<schema::PrimitiveType>(parameter->type_) == PrimitiveType_Add) {
return nullptr; kernel = new (std::nothrow) QuantizedAddCPUKernel(parameter, inputs, outputs, ctx, primitive);
} else if (desc.type == PrimitiveType_Eltwise &&
static_cast<schema::PrimitiveType>(parameter->type_) == PrimitiveType_Mul) {
kernel = new (std::nothrow) MulInt8CPUKernel(parameter, inputs, outputs, ctx, primitive);
} else {
kernel = new (std::nothrow) ArithmeticInt8CPUKernel(parameter, inputs, outputs, ctx, primitive);
} }
auto kernel = new (std::nothrow) ArithmeticInt8CPUKernel(parameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) { if (kernel == nullptr) {
MS_LOG(ERROR) << "Create ArithmeticInt8CPUKernel failed, name: " << parameter->name_; MS_LOG(ERROR) << "Create ArithmeticInt8CPUKernel failed, name: " << parameter->name_;
free(parameter); free(parameter);
@ -185,5 +194,5 @@ REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Less, CpuArithmeticInt8KernelCre
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_LessEqual, CpuArithmeticInt8KernelCreator) REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_LessEqual, CpuArithmeticInt8KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Greater, CpuArithmeticInt8KernelCreator) REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Greater, CpuArithmeticInt8KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_GreaterEqual, CpuArithmeticInt8KernelCreator) REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_GreaterEqual, CpuArithmeticInt8KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Eltwise, CpuArithmeticInt8KernelCreator)
} // namespace mindspore::kernel } // namespace mindspore::kernel

Loading…
Cancel
Save