!12123 [MSLITE] Add fp16 gather op

From: @zhanyuan1
Reviewed-by: @zhang_xue_tong,@zhanghaibo5
Signed-off-by: @zhang_xue_tong
pull/12123/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 509bb8a948

@ -0,0 +1,177 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/fp16/gather_fp16.h"
#include <limits>
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "src/runtime/runtime_api.h"
#include "nnacl/fp16/cast_fp16.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_MEMORY_FAILED;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Gather;
namespace mindspore::kernel {
// Kernel preparation entry point. Shape-dependent setup is deferred to
// ReSize() and only performed once shape inference has completed.
int GatherFp16CPUKernel::Init() {
  return InferShapeDone() ? ReSize() : RET_OK;
}
// No shape-dependent state is cached by this kernel, so resizing is a no-op.
int GatherFp16CPUKernel::ReSize() {
  return RET_OK;
}
int GatherFp16CPUKernel::PreProcess() {
if (!InferShapeDone()) {
(const_cast<mindspore::lite::PrimitiveC *>(primitive_))->set_infer_flag(true);
auto ret = (const_cast<mindspore::lite::PrimitiveC *>(primitive_))->InferShape(in_tensors_, out_tensors_);
if (ret != 0) {
(const_cast<mindspore::lite::PrimitiveC *>(primitive_))->set_infer_flag(false);
MS_LOG(ERROR) << "InferShape fail!";
return ret;
}
ret = ReSize();
if (ret != 0) {
MS_LOG(ERROR) << "ReSize fail!ret: " << ret;
return ret;
}
out_tensors_[0]->set_data_type(kNumberTypeFloat16);
}
for (auto *output : this->out_tensors()) {
MS_ASSERT(output != nullptr);
if (output->ElementsNum() >= lite::MAX_MALLOC_SIZE / static_cast<int>(sizeof(int64_t))) {
MS_LOG(ERROR) << "The size of output tensor is too big";
return RET_ERROR;
}
auto ret = output->MallocData();
if (ret != RET_OK) {
MS_LOG(ERROR) << "MallocData failed";
return ret;
}
}
return RET_OK;
}
int GatherFp16CPUKernel::DoGather(int task_id) {
auto input_tensor = in_tensors_.at(0);
auto indices_tensor = in_tensors_.at(1);
auto out_tensor = out_tensors_.at(0);
auto in_shape = input_tensor->shape();
int in_rank = in_shape.size();
int indices_element_size = indices_tensor->ElementsNum();
auto axis = (reinterpret_cast<GatherParameter *>(op_parameter_))->axis_;
const int limit = in_shape.at(axis);
int outer_size = 1, inner_size = 1;
for (int i = 0; i < axis; ++i) {
outer_size *= in_shape.at(i);
}
for (int i = axis + 1; i < in_rank; ++i) {
inner_size *= in_shape.at(i);
}
int stride = UP_DIV(outer_size, op_parameter_->thread_num_);
int count = MSMIN(stride, outer_size - stride * task_id);
auto thread_stride = stride * task_id;
int8_t *int8_in = nullptr;
if (input_tensor->data_type() == kNumberTypeFloat32) {
input_data_ =
reinterpret_cast<float16_t *>(context_->allocator->Malloc(input_tensor->ElementsNum() * sizeof(float16_t)));
Float32ToFloat16(reinterpret_cast<float *>(input_tensor->data_c()), input_data_, input_tensor->ElementsNum());
int8_in = reinterpret_cast<int8_t *>(input_data_);
} else if (input_tensor->data_type() == kNumberTypeFloat16) {
int8_in = reinterpret_cast<int8_t *>(input_tensor->data_c());
} else {
MS_LOG(ERROR) << "input data type error";
return RET_ERROR;
}
int8_t *int8_out = reinterpret_cast<int8_t *>(out_tensor->data_c());
int data_size = lite::DataTypeSize(kNumberTypeFloat16);
int8_in += thread_stride * limit * data_size;
int8_out += thread_stride * indices_element_size * data_size;
int error_code = Gather(int8_in, count, inner_size, limit, indices_data_, indices_element_size, int8_out, data_size);
return error_code;
}
int GatherRunFp16(void *cdata, int task_id) {
auto gather_kernel = reinterpret_cast<GatherFp16CPUKernel *>(cdata);
auto error_code = gather_kernel->DoGather(task_id);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "GatherRun error task_id[" << task_id << "] error_code[" << error_code << "]";
}
return error_code;
}
int GatherFp16CPUKernel::Run() {
auto indices_tensor = in_tensors_.at(1);
int indices_num = indices_tensor->ElementsNum();
bool isIndicesInt32 = indices_tensor->data_type() == kNumberTypeInt32;
int ret = AssignIndicesData(isIndicesInt32, indices_num, indices_tensor);
if (ret != RET_OK) {
MS_LOG(ERROR) << "AssignIndicesData failed, error_code[" << ret << "]";
return ret;
}
ret = ParallelLaunch(this->context_->thread_pool_, GatherRunFp16, this, op_parameter_->thread_num_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Gather function error error_code[" << ret << "]";
}
if (!isIndicesInt32) {
context_->allocator->Free(indices_data_);
indices_data_ = nullptr;
}
if (input_data_) {
context_->allocator->Free(input_data_);
input_data_ = nullptr;
}
return ret;
}
int GatherFp16CPUKernel::AssignIndicesData(bool isIndicesInt32, int indices_num, lite::Tensor *indices_tensor) {
if (!isIndicesInt32) {
if (indices_num >= std::numeric_limits<int>::max() / static_cast<int>(sizeof(int))) {
MS_LOG(ERROR) << "Input indices_num is invalid, indices_num: " << indices_num;
return RET_ERROR;
}
indices_data_ = reinterpret_cast<int32_t *>(context_->allocator->Malloc(sizeof(int32_t) * indices_num));
if (indices_data_ == nullptr) {
MS_LOG(ERROR) << "Memory allocation failed";
return RET_ERROR;
}
if (indices_tensor->data_type() == kNumberTypeInt64) {
for (int i = 0; i < indices_num; i++) {
indices_data_[i] = reinterpret_cast<int64_t *>(indices_tensor->MutableData())[i];
}
} else if (indices_tensor->data_type() == kNumberTypeFloat16) {
for (int i = 0; i < indices_num; i++) {
indices_data_[i] = reinterpret_cast<float16_t *>(indices_tensor->MutableData())[i];
}
} else {
MS_LOG(ERROR) << "The data type of indices tensor is wrong";
return RET_ERROR;
}
} else {
indices_data_ = reinterpret_cast<int32_t *>(indices_tensor->MutableData());
}
return RET_OK;
}
// Register this class as the fp16 CPU implementation of the Gather primitive.
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Gather, LiteKernelCreator<GatherFp16CPUKernel>)
} // namespace mindspore::kernel

@ -0,0 +1,49 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_GATHER_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_GATHER_H_
#include <arm_neon.h>
#include <vector>
#include "include/errorcode.h"
#include "src/lite_kernel.h"
#include "nnacl/gather_parameter.h"
#include "nnacl/base/gather_base.h"
namespace mindspore::kernel {
// CPU kernel implementing the Gather op on fp16 data. Accepts fp32 or fp16
// input (fp32 is converted to a scratch fp16 buffer) and int32/int64/fp16
// indices (non-int32 indices are converted to an int32 buffer).
class GatherFp16CPUKernel : public LiteKernel {
 public:
GatherFp16CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
~GatherFp16CPUKernel() = default;
// Defers shape-dependent setup until shape inference has completed.
int Init() override;
int ReSize() override;
// Also forces the output dtype to fp16 before allocating outputs.
int PreProcess() override;
int Run() override;
// Gathers one thread's share of the outer slices; called via ParallelLaunch.
int DoGather(int task_id);
 private:
// int32 view of the indices; allocator-owned only when converted from a
// non-int32 tensor (freed in Run()), otherwise aliases the tensor's storage.
int *indices_data_ = nullptr;
int AssignIndicesData(bool isIndicesInt32, int indices_num, lite::Tensor *indices_tensor);
// Scratch fp16 copy of an fp32 input tensor; allocated in DoGather, freed in Run().
float16_t *input_data_ = nullptr;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_GATHER_H_
Loading…
Cancel
Save