Fix arithmetic comparison ops, MatMul shape inference, LogicalNot kernel registration, and constant-folding fusion

pull/9052/head
gongdaguo 4 years ago
parent 14a51ef727
commit 815b7af9ec

@ -0,0 +1,33 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/arithmetic_compare.h"
namespace mindspore {
namespace lite {
// Runs the shared arithmetic shape inference and, on success, forces the
// output tensor's data type to bool, since comparison ops always yield
// boolean results regardless of the input element type.
int ArithmeticCompare::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
  auto ret = Arithmetic::InferShape(inputs_, outputs_);
  if (ret != RET_OK) {
    // Shape inference failed (or was deferred); propagate the status as-is.
    return ret;
  }
  outputs_.front()->set_data_type(TypeId::kNumberTypeBool);
  return RET_OK;
}
} // namespace lite
} // namespace mindspore

@ -0,0 +1,41 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_C_OPS_ARITHMETIC_COMPARE_H_
#define LITE_MINDSPORE_LITE_C_OPS_ARITHMETIC_COMPARE_H_
#include <vector>
#include <set>
#include <cmath>
#include "src/ops/arithmetic.h"
namespace mindspore {
namespace lite {
// Common base for element-wise comparison operators (Equal, NotEqual, Less,
// LessEqual, Greater, GreaterEqual). Delegates shape inference to Arithmetic
// and overrides the inferred output data type to bool.
class ArithmeticCompare : public Arithmetic {
public:
ArithmeticCompare() = default;
~ArithmeticCompare() = default;
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(ArithmeticCompare, Arithmetic);
// Wraps an existing writeable primitive; ownership semantics follow Arithmetic.
explicit ArithmeticCompare(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
#endif
// Infers the output shape via Arithmetic::InferShape, then sets the output
// tensor's data type to kNumberTypeBool on success; otherwise returns the
// error code unchanged.
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_ARITHMETIC_COMPARE_H_

@ -35,16 +35,6 @@ int Equal::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::
PrimitiveC *EqualCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Equal>(primitive); } PrimitiveC *EqualCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Equal>(primitive); }
Registry EqualRegistry(schema::PrimitiveType_Equal, EqualCreator); Registry EqualRegistry(schema::PrimitiveType_Equal, EqualCreator);
#endif #endif
int Equal::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_shape(input->shape());
output->set_data_type(TypeId::kNumberTypeBool);
output->set_format(input->format());
return RET_OK;
}
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

@ -20,21 +20,20 @@
#include <vector> #include <vector>
#include <set> #include <set>
#include <cmath> #include <cmath>
#include "src/ops/arithmetic.h" #include "src/ops/arithmetic_compare.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
class Equal : public Arithmetic { class Equal : public ArithmeticCompare {
public: public:
Equal() = default; Equal() = default;
~Equal() = default; ~Equal() = default;
#ifdef PRIMITIVE_WRITEABLE #ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Equal, PrimitiveC); MS_DECLARE_PARENT(Equal, ArithmeticCompare);
explicit Equal(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} explicit Equal(schema::PrimitiveT *primitive) : ArithmeticCompare(primitive) {}
#else #else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif #endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

@ -36,16 +36,6 @@ int Greater::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers
PrimitiveC *GreaterCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Greater>(primitive); } PrimitiveC *GreaterCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Greater>(primitive); }
Registry GreaterRegistry(schema::PrimitiveType_Greater, GreaterCreator); Registry GreaterRegistry(schema::PrimitiveType_Greater, GreaterCreator);
#endif #endif
int Greater::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_shape(input->shape());
output->set_data_type(TypeId::kNumberTypeBool);
output->set_format(input->format());
return RET_OK;
}
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

@ -20,21 +20,20 @@
#include <set> #include <set>
#include <cmath> #include <cmath>
#include "src/ops/arithmetic.h" #include "src/ops/arithmetic_compare.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
class Greater : public Arithmetic { class Greater : public ArithmeticCompare {
public: public:
Greater() = default; Greater() = default;
~Greater() = default; ~Greater() = default;
#ifdef PRIMITIVE_WRITEABLE #ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Greater, Arithmetic); MS_DECLARE_PARENT(Greater, ArithmeticCompare);
explicit Greater(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} explicit Greater(schema::PrimitiveT *primitive) : ArithmeticCompare(primitive) {}
#else #else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif #endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

@ -38,16 +38,6 @@ PrimitiveC *GreaterEqualCreator(const schema::Primitive *primitive) {
Registry GreaterEqualRegistry(schema::PrimitiveType_GreaterEqual, GreaterEqualCreator); Registry GreaterEqualRegistry(schema::PrimitiveType_GreaterEqual, GreaterEqualCreator);
#endif #endif
int GreaterEqual::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_shape(input->shape());
output->set_data_type(TypeId::kNumberTypeBool);
output->set_format(input->format());
return RET_OK;
}
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

@ -21,21 +21,20 @@
#include <set> #include <set>
#include <cmath> #include <cmath>
#include "src/ops/arithmetic.h" #include "src/ops/arithmetic_compare.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
class GreaterEqual : public Arithmetic { class GreaterEqual : public ArithmeticCompare {
public: public:
GreaterEqual() = default; GreaterEqual() = default;
~GreaterEqual() = default; ~GreaterEqual() = default;
#ifdef PRIMITIVE_WRITEABLE #ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(GreaterEqual, Arithmetic); MS_DECLARE_PARENT(GreaterEqual, ArithmeticCompare);
explicit GreaterEqual(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} explicit GreaterEqual(schema::PrimitiveT *primitive) : ArithmeticCompare(primitive) {}
#else #else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif #endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

@ -38,16 +38,6 @@ PrimitiveC *LessCreator(const schema::Primitive *primitive) { return PrimitiveC:
Registry LessRegistry(schema::PrimitiveType_Less, LessCreator); Registry LessRegistry(schema::PrimitiveType_Less, LessCreator);
#endif #endif
int Less::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_shape(input->shape());
output->set_data_type(TypeId::kNumberTypeBool);
output->set_format(input->format());
return RET_OK;
}
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

@ -21,21 +21,20 @@
#include <set> #include <set>
#include <cmath> #include <cmath>
#include "src/ops/arithmetic.h" #include "src/ops/arithmetic_compare.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
class Less : public Arithmetic { class Less : public ArithmeticCompare {
public: public:
Less() = default; Less() = default;
~Less() = default; ~Less() = default;
#ifdef PRIMITIVE_WRITEABLE #ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Less, Arithmetic); MS_DECLARE_PARENT(Less, ArithmeticCompare);
explicit Less(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} explicit Less(schema::PrimitiveT *primitive) : ArithmeticCompare(primitive) {}
#else #else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif #endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

@ -37,16 +37,6 @@ PrimitiveC *LessEqualCreator(const schema::Primitive *primitive) {
} }
Registry LessEqualRegistry(schema::PrimitiveType_LessEqual, LessEqualCreator); Registry LessEqualRegistry(schema::PrimitiveType_LessEqual, LessEqualCreator);
#endif #endif
int LessEqual::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_shape(input->shape());
output->set_data_type(TypeId::kNumberTypeBool);
output->set_format(input->format());
return RET_OK;
}
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

@ -21,21 +21,20 @@
#include <set> #include <set>
#include <cmath> #include <cmath>
#include "src/ops/arithmetic.h" #include "src/ops/arithmetic_compare.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
class LessEqual : public Arithmetic { class LessEqual : public ArithmeticCompare {
public: public:
LessEqual() = default; LessEqual() = default;
~LessEqual() = default; ~LessEqual() = default;
#ifdef PRIMITIVE_WRITEABLE #ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(LessEqual, Arithmetic); MS_DECLARE_PARENT(LessEqual, ArithmeticCompare);
explicit LessEqual(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} explicit LessEqual(schema::PrimitiveT *primitive) : ArithmeticCompare(primitive) {}
#else #else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif #endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

@ -112,9 +112,17 @@ int MatMul::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outp
input0->set_shape(a_shape); input0->set_shape(a_shape);
} }
if (a_shape.size() < 2 || b_shape.size() < 2) { bool del_start = false;
MS_LOG(ERROR) << "inputs shape is invalid"; bool del_end = false;
return RET_INPUT_TENSOR_ERROR; if (a_shape.size() == 1) {
a_shape.insert(a_shape.begin(), 1);
input0->set_shape(a_shape);
del_start = true;
}
if (b_shape.size() == 1) {
b_shape.push_back(1);
input1->set_shape(b_shape);
del_end = true;
} }
for (size_t i = 0; i < (a_shape.size() - 2) && i < (b_shape.size() - 2); ++i) { for (size_t i = 0; i < (a_shape.size() - 2) && i < (b_shape.size() - 2); ++i) {
if (a_shape[a_shape.size() - 3 - i] != b_shape[b_shape.size() - 3 - i]) { if (a_shape[a_shape.size() - 3 - i] != b_shape[b_shape.size() - 3 - i]) {
@ -131,6 +139,12 @@ int MatMul::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outp
} }
std::vector<int> c_shape(a_shape); std::vector<int> c_shape(a_shape);
c_shape[c_shape.size() - 1] = b_shape[b_shape.size() - 1]; c_shape[c_shape.size() - 1] = b_shape[b_shape.size() - 1];
if (del_start) {
c_shape.erase(c_shape.begin());
}
if (del_end) {
c_shape.pop_back();
}
output->set_shape(c_shape); output->set_shape(c_shape);
return RET_OK; return RET_OK;
} }

@ -38,16 +38,6 @@ PrimitiveC *NotEqualCreator(const schema::Primitive *primitive) {
Registry NotEqualRegistry(schema::PrimitiveType_NotEqual, NotEqualCreator); Registry NotEqualRegistry(schema::PrimitiveType_NotEqual, NotEqualCreator);
#endif #endif
int NotEqual::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
auto input = inputs_.front();
MS_ASSERT(input != nullptr);
auto output = outputs_.front();
MS_ASSERT(output != nullptr);
output->set_shape(input->shape());
output->set_data_type(TypeId::kNumberTypeBool);
output->set_format(input->format());
return RET_OK;
}
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

@ -21,21 +21,20 @@
#include <set> #include <set>
#include <cmath> #include <cmath>
#include "src/ops/arithmetic.h" #include "src/ops/arithmetic_compare.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
class NotEqual : public Arithmetic { class NotEqual : public ArithmeticCompare {
public: public:
NotEqual() = default; NotEqual() = default;
~NotEqual() = default; ~NotEqual() = default;
#ifdef PRIMITIVE_WRITEABLE #ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(NotEqual, Arithmetic); MS_DECLARE_PARENT(NotEqual, ArithmeticCompare);
explicit NotEqual(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} explicit NotEqual(schema::PrimitiveT *primitive) : ArithmeticCompare(primitive) {}
#else #else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif #endif
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

@ -58,11 +58,11 @@ Registry RealDivParameterRegistry(schema::PrimitiveType_RealDiv, PopulateArithme
Registry LogicalAndParameterRegistry(schema::PrimitiveType_LogicalAnd, PopulateArithmetic); Registry LogicalAndParameterRegistry(schema::PrimitiveType_LogicalAnd, PopulateArithmetic);
Registry ParameterRegistry(schema::PrimitiveType_LogicalOr, PopulateArithmetic); Registry ParameterRegistry(schema::PrimitiveType_LogicalOr, PopulateArithmetic);
Registry EqualParameterRegistry(schema::PrimitiveType_Equal, PopulateArithmetic); Registry EqualParameterRegistry(schema::PrimitiveType_Equal, PopulateArithmetic);
Registry NotEqualParameterRegistry(schema::PrimitiveType_NotEqual, PopulateArithmetic);
Registry LessParameterRegistry(schema::PrimitiveType_Less, PopulateArithmetic); Registry LessParameterRegistry(schema::PrimitiveType_Less, PopulateArithmetic);
Registry LessEqualParameterRegistry(schema::PrimitiveType_LessEqual, PopulateArithmetic);
Registry GreaterParameterRegistry(schema::PrimitiveType_Greater, PopulateArithmetic); Registry GreaterParameterRegistry(schema::PrimitiveType_Greater, PopulateArithmetic);
Registry GreaterEqualParameterRegistry(schema::PrimitiveType_GreaterEqual, PopulateArithmetic); Registry GreaterEqualParameterRegistry(schema::PrimitiveType_GreaterEqual, PopulateArithmetic);
Registry NotEqualParameterRegistry(schema::PrimitiveType_NotEqual, PopulateArithmetic);
Registry LessEqualParameterRegistry(schema::PrimitiveType_LessEqual, PopulateArithmetic);
Registry MaximumParameterRegistry(schema::PrimitiveType_Maximum, PopulateArithmetic); Registry MaximumParameterRegistry(schema::PrimitiveType_Maximum, PopulateArithmetic);
Registry MinimumParameterRegistry(schema::PrimitiveType_Minimum, PopulateArithmetic); Registry MinimumParameterRegistry(schema::PrimitiveType_Minimum, PopulateArithmetic);
Registry FloorDivParameterRegistry(schema::PrimitiveType_FloorDiv, PopulateArithmetic); Registry FloorDivParameterRegistry(schema::PrimitiveType_FloorDiv, PopulateArithmetic);

@ -28,77 +28,82 @@ using mindspore::schema::PrimitiveType_LessEqual;
using mindspore::schema::PrimitiveType_NotEqual; using mindspore::schema::PrimitiveType_NotEqual;
namespace mindspore::kernel { namespace mindspore::kernel {
namespace {
typedef struct {
int primitive_type_;
ArithmeticCompareFp32Func func_;
} TYPE_FUNC_INFO;
} // namespace
ArithmeticCompareFp32Func ArithmeticCompareCPUKernel::GetArithmeticCompareFun(int primitive_type) { int ArithmeticCompareCPUKernel::BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count,
TYPE_FUNC_INFO type_func_table[] = { int out_thread_stride) {
{PrimitiveType_Equal, ElementEqualFp32}, {PrimitiveType_NotEqual, ElementNotEqualFp32}, if (dim > break_pos_) {
{PrimitiveType_Less, ElementLessFp32}, {PrimitiveType_LessEqual, ElementLessEqualFp32}, if (data_type_ == kDataTypeInt) {
{PrimitiveType_Greater, ElementGreaterFp32}, {PrimitiveType_GreaterEqual, ElementGreaterEqualFp32}}; return func_int32_(reinterpret_cast<int *>(input0) + out_thread_stride,
for (size_t i = 0; i < sizeof(type_func_table); i++) { reinterpret_cast<int *>(input1) + out_thread_stride,
if (type_func_table[i].primitive_type_ == primitive_type) { reinterpret_cast<uint8_t *>(output) + out_thread_stride, out_count);
return type_func_table[i].func_;
} }
return func_fp32_(reinterpret_cast<float *>(input0) + out_thread_stride,
reinterpret_cast<float *>(input1) + out_thread_stride,
reinterpret_cast<uint8_t *>(output) + out_thread_stride, out_count);
} }
return nullptr; for (int i = 0; i < arithmeticParameter_->out_shape_[dim]; ++i) {
} int pos0_ = arithmeticParameter_->in_shape0_[dim] == 1 ? 0 : i;
int pos1_ = arithmeticParameter_->in_shape1_[dim] == 1 ? 0 : i;
int ArithmeticCompareCPUKernel::Init() { int error_code;
if (!InferShapeDone()) { if (data_type_ == kDataTypeInt) {
return RET_OK; error_code = BroadcastRun(reinterpret_cast<int *>(input0) + pos0_ * arithmeticParameter_->in_strides0_[dim],
reinterpret_cast<int *>(input1) + pos1_ * arithmeticParameter_->in_strides1_[dim],
reinterpret_cast<uint8_t *>(output) + i * arithmeticParameter_->out_strides_[dim],
dim + 1, out_count, out_thread_stride);
} else {
error_code = BroadcastRun(reinterpret_cast<float *>(input0) + pos0_ * arithmeticParameter_->in_strides0_[dim],
reinterpret_cast<float *>(input1) + pos1_ * arithmeticParameter_->in_strides1_[dim],
reinterpret_cast<uint8_t *>(output) + i * arithmeticParameter_->out_strides_[dim],
dim + 1, out_count, out_thread_stride);
}
if (error_code != RET_OK) {
return error_code;
}
} }
return ReSize(); return RET_OK;
} }
int ArithmeticCompareCPUKernel::ReSize() { return RET_OK; } int ArithmeticCompareCPUKernel::DoArithmetic(int task_id) {
auto element_num = out_tensors_[0]->ElementsNum();
int ArithmeticCompareCPUKernel::DoExecute(int task_id) { MS_ASSERT(thread_count_ != 0);
if (in_tensors_.at(0)->shape() != in_tensors_.at(1)->shape()) { int stride = UP_DIV(element_num, thread_count_);
MS_LOG(ERROR) << "Compare op must inputs have the same shape, support broadcast later! "; int count = MSMIN(stride, element_num - stride * task_id);
return RET_ERROR;
} if (func_fp32_ == nullptr) {
int elements_num = in_tensors_.at(0)->ElementsNum(); MS_LOG(ERROR) << "func_fp32_ function is nullptr!";
int stride = UP_DIV(elements_num, op_parameter_->thread_num_);
int offset = task_id * stride;
int count = MSMIN(stride, elements_num - offset);
if (count <= 0) {
return RET_OK;
}
if (func_ == nullptr) {
MS_LOG(ERROR) << "Run function is null! ";
return RET_ERROR; return RET_ERROR;
} }
// two inputs have the same shape, support broadcast later
auto *input0_ptr = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
auto *input1_ptr = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData());
auto *output_ptr = reinterpret_cast<uint8_t *>(out_tensors_.at(0)->MutableData());
auto ret = func_(input0_ptr + offset, input1_ptr + offset, output_ptr + offset, count);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Run failed, illegal input! ";
}
return ret;
}
int ArithmeticCompareRun(void *cdata, int task_id) { int error_code;
auto kernel = reinterpret_cast<ArithmeticCompareCPUKernel *>(cdata); if (arithmeticParameter_->broadcasting_) { // need broadcast
auto ret = kernel->DoExecute(task_id); stride = UP_DIV(outside_, thread_count_);
if (ret != RET_OK) { int out_count = MSMIN(stride, outside_ - stride * task_id);
MS_LOG(ERROR) << "ArithmeticSelfRuns error task_id[" << task_id << "] error_code[" << ret << "]"; int out_thread_stride = stride * task_id;
if (data_type_ == kDataTypeFloat) {
error_code = BroadcastRun(
reinterpret_cast<float *>(in_tensors_[0]->data_c()), reinterpret_cast<float *>(in_tensors_[1]->data_c()),
reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride);
} else {
error_code = BroadcastRun(
reinterpret_cast<int *>(in_tensors_[0]->data_c()), reinterpret_cast<int *>(in_tensors_[1]->data_c()),
reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride);
}
} else { // no broadcast, neither is scalar, two same shape
if (data_type_ == kDataTypeFloat) {
error_code = func_fp32_(reinterpret_cast<float *>(in_tensors_[0]->data_c()) + stride * task_id,
reinterpret_cast<float *>(in_tensors_[1]->data_c()) + stride * task_id,
reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c()) + stride * task_id, count);
} else {
error_code = func_int32_(reinterpret_cast<int *>(in_tensors_[0]->data_c()) + stride * task_id,
reinterpret_cast<int *>(in_tensors_[1]->data_c()) + stride * task_id,
reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c()) + stride * task_id, count);
}
} }
return ret; if (error_code != RET_OK) {
} return RET_ERROR;
int ArithmeticCompareCPUKernel::Run() {
auto ret = ParallelLaunch(this->context_->thread_pool_, ArithmeticCompareRun, this, op_parameter_->thread_num_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "ArithmeticSelfRun error error_code[" << ret << "]";
} }
return ret; return RET_OK;
} }
kernel::LiteKernel *CpuArithmeticCompareFp32KernelCreator(const std::vector<lite::Tensor *> &inputs, kernel::LiteKernel *CpuArithmeticCompareFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,

@ -18,27 +18,57 @@
#include <vector> #include <vector>
#include "src/runtime/kernel/arm/fp32/arithmetic_fp32.h" #include "src/runtime/kernel/arm/fp32/arithmetic_fp32.h"
#include "nnacl/fp32/arithmetic_compare_fp32.h"
namespace mindspore::kernel { namespace mindspore::kernel {
typedef int (*ArithmeticCompareFp32Func)(const float *input0, const float *input1, uint8_t *output, int element_size); typedef int (*ArithmeticCompareFp32Func)(const float *input0, const float *input1, uint8_t *output, int element_size);
typedef int (*ArithmeticCompareIntFunc)(const int *input0, const int *input1, uint8_t *output, int element_size);
class ArithmeticCompareCPUKernel : public ArithmeticCPUKernel { class ArithmeticCompareCPUKernel : public ArithmeticCPUKernel {
public: public:
explicit ArithmeticCompareCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, explicit ArithmeticCompareCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive) const mindspore::lite::PrimitiveC *primitive)
: ArithmeticCPUKernel(parameter, inputs, outputs, ctx, primitive) { : ArithmeticCPUKernel(parameter, inputs, outputs, ctx, primitive) {
func_ = GetArithmeticCompareFun(parameter->type_); switch (parameter->type_) {
case PrimitiveType_Equal:
func_fp32_ = ElementEqualFp32;
func_int32_ = ElementEqualInt32;
break;
case PrimitiveType_NotEqual:
func_fp32_ = ElementNotEqualFp32;
func_int32_ = ElementNotEqualInt32;
break;
case PrimitiveType_Less:
func_fp32_ = ElementLessFp32;
func_int32_ = ElementLessInt32;
break;
case PrimitiveType_LessEqual:
func_fp32_ = ElementLessEqualFp32;
func_int32_ = ElementLessEqualInt32;
break;
case PrimitiveType_Greater:
func_fp32_ = ElementGreaterFp32;
func_int32_ = ElementGreaterInt32;
break;
case PrimitiveType_GreaterEqual:
func_fp32_ = ElementGreaterEqualFp32;
func_int32_ = ElementGreaterEqualInt32;
break;
default:
MS_LOG(ERROR) << "Error Operator type " << parameter->type_;
func_fp32_ = nullptr;
func_int32_ = nullptr;
break;
}
} }
~ArithmeticCompareCPUKernel() override = default; ~ArithmeticCompareCPUKernel() override = default;
int Init() override; int DoArithmetic(int task_id) override;
int ReSize() override; int BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count, int out_thread_stride) override;
int Run() override;
virtual int DoExecute(int task_id);
private: private:
ArithmeticCompareFp32Func GetArithmeticCompareFun(int primitive_type); ArithmeticCompareFp32Func func_fp32_ = nullptr;
ArithmeticCompareFp32Func func_; ArithmeticCompareIntFunc func_int32_ = nullptr;
}; };
int ArithmeticCompareRun(void *cdata, int task_id); int ArithmeticCompareRun(void *cdata, int task_id);
} // namespace mindspore::kernel } // namespace mindspore::kernel

@ -175,6 +175,15 @@ int ArithmeticCPUKernel::ReSize() {
break; break;
} }
break; break;
case PrimitiveType_Equal:
case PrimitiveType_Less:
case PrimitiveType_Greater:
case PrimitiveType_NotEqual:
case PrimitiveType_LessEqual:
case PrimitiveType_GreaterEqual:
arithmetic_opt_run_ = nullptr;
arithmetic_opt_run_int_ = nullptr;
break;
default: default:
break; break;
} }

@ -167,19 +167,21 @@ class ArithmeticCPUKernel : public LiteKernel {
int PreProcess() override; int PreProcess() override;
int ReSize() override; int ReSize() override;
int Run() override; int Run() override;
int DoArithmetic(int task_id); virtual int DoArithmetic(int task_id);
virtual int BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count, int out_thread_stride);
private: protected:
int BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count, int out_thread_stride);
int break_pos_ = 0; int break_pos_ = 0;
int outside_ = 0; int outside_ = 0;
int thread_count_ = 1; int thread_count_ = 1;
ArithmeticParameter *arithmeticParameter_ = nullptr; ArithmeticParameter *arithmeticParameter_ = nullptr;
LiteDataType data_type_ = kDataTypeFloat;
private:
ArithmeticRun arithmetic_run_ = nullptr; ArithmeticRun arithmetic_run_ = nullptr;
ArithmeticOptRun arithmetic_opt_run_ = nullptr; ArithmeticOptRun arithmetic_opt_run_ = nullptr;
ArithmeticIntRun arithmetic_run_int_ = nullptr; ArithmeticIntRun arithmetic_run_int_ = nullptr;
ArithmeticOptIntRun arithmetic_opt_run_int_ = nullptr; ArithmeticOptIntRun arithmetic_opt_run_int_ = nullptr;
LiteDataType data_type_ = kDataTypeFloat;
}; };
} // namespace mindspore::kernel } // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_H_ #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_H_

@ -146,4 +146,5 @@ REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Sqrt, CpuArithmeticSelfInt8Kerne
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Rsqrt, CpuArithmeticSelfInt8KernelCreator) REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Rsqrt, CpuArithmeticSelfInt8KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Square, CpuArithmeticSelfInt8KernelCreator) REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Square, CpuArithmeticSelfInt8KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_LogicalNot, CpuArithmeticSelfInt8KernelCreator) REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_LogicalNot, CpuArithmeticSelfInt8KernelCreator)
REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_LogicalNot, CpuArithmeticSelfInt8KernelCreator)
} // namespace mindspore::kernel } // namespace mindspore::kernel

@ -135,17 +135,17 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
std::find_if(onnx_graph.initializer().begin(), onnx_graph.initializer().end(), std::find_if(onnx_graph.initializer().begin(), onnx_graph.initializer().end(),
[onnx_conv_weight](const onnx::TensorProto &proto) { return proto.name() == onnx_conv_weight; }); [onnx_conv_weight](const onnx::TensorProto &proto) { return proto.name() == onnx_conv_weight; });
if (nodeIter == onnx_graph.initializer().end()) { if (nodeIter == onnx_graph.initializer().end()) {
MS_LOG(ERROR) << "not find node: " << onnx_conv_weight; MS_LOG(WARNING) << "not find node: " << onnx_conv_weight;
return RET_ERROR; } else {
} std::vector<int> weight_shape;
std::vector<int> weight_shape; auto size = (*nodeIter).dims_size();
auto size = (*nodeIter).dims_size(); weight_shape.reserve(size);
weight_shape.reserve(size); for (int i = 0; i < size; ++i) {
for (int i = 0; i < size; ++i) { weight_shape.emplace_back((*nodeIter).dims(i));
weight_shape.emplace_back((*nodeIter).dims(i)); }
attr->channelOut = weight_shape[0];
attr->channelIn = weight_shape[1] * attr->group;
} }
attr->channelOut = weight_shape[0];
attr->channelIn = weight_shape[1] * attr->group;
} else { } else {
auto nodeIter = auto nodeIter =
std::find_if(onnx_graph.node().begin(), onnx_graph.node().end(), std::find_if(onnx_graph.node().begin(), onnx_graph.node().end(),

@ -231,15 +231,6 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An
output_tensors[m]->AddQuantParam(quant_arg); output_tensors[m]->AddQuantParam(quant_arg);
} }
} }
// here, input_tensor's format need to be transposed nhwc according to fmkType,
// but for the time being, we only transpose the tensor with 0/1/2/3D.
// Others should be added in future.
for (auto &input_tensor : input_tensors) {
input_tensor->set_format(schema::Format::Format_NHWC);
if (input_tensor->shape().size() == 4) {
MS_LOG(INFO) << "init input_tensor format to nhwc";
}
}
lite_primitive->InferShape(input_tensors, output_tensors); lite_primitive->InferShape(input_tensors, output_tensors);
auto primitive = lite_primitive.get(); auto primitive = lite_primitive.get();
MS_ASSERT(primitive != nullptr); MS_ASSERT(primitive != nullptr);

Loading…
Cancel
Save