Refactor slice and add fp16 kernel

pull/6932/head
sunsuodong 5 years ago
parent b6a7f8bd71
commit ef330cdffe

mindspore/lite/nnacl/fp16/slice_fp16.c
@@ -0,0 +1,70 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/fp16/slice_fp16.h"
#include <string.h>
#include "nnacl/op_base.h"
#include "nnacl/errorcode.h"
void DoSliceFp16(const float16_t *input, float16_t *output, SliceParameter *param, int thread_id) {
int32_t out_dim1 = param->size_[1];
int32_t out_dim2 = param->size_[2];
int32_t out_dim3 = param->size_[3];
size_t out_stride2 = out_dim3;
size_t out_stride1 = out_stride2 * out_dim2;
size_t out_stride0 = out_stride1 * out_dim1;
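// Split dim 1 across threads: UP_DIV rounds up, and the bounds check in the loop below trims the last thread's share.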
size_t count_per_thread = UP_DIV(out_dim1, param->op_parameter_.thread_num_);
size_t thread_stride = thread_id * count_per_thread;
size_t copy_size = param->size_[3] * sizeof(float16_t);
size_t in_stride2 = param->shape_[3];
size_t in_stride1 = param->shape_[2] * in_stride2;
size_t in_stride0 = param->shape_[1] * in_stride1;
for (int i = 0; i < param->size_[0]; ++i) {
size_t out_offset0 = i * out_stride0;
size_t in_offset0 = (i + param->begin_[0]) * in_stride0 + param->begin_[3];
for (size_t j = 0; j < count_per_thread; ++j) {
size_t k = j + thread_stride;
if (k >= out_dim1) {
break;
}
size_t out_offset1 = k * out_stride1 + out_offset0;
size_t in_offset1 = (k + param->begin_[1]) * in_stride1 + in_offset0;
for (int l = 0; l < out_dim2; ++l) {
size_t out_offset = out_offset1 + l * out_stride2;
size_t in_offset = in_offset1 + (l + param->begin_[2]) * in_stride2;
memcpy(output + out_offset, input + in_offset, copy_size);
}
}
}
}
void DoSliceFp16NoParallel(const float16_t *input, float16_t *output, SliceParameter *param) {
size_t copy_size = param->size_[3] * sizeof(float16_t);
size_t in_stride2 = param->shape_[3];
size_t in_stride1 = param->shape_[2] * in_stride2;
size_t in_stride0 = param->shape_[1] * in_stride1;
size_t out_offset = 0;
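// Walk begin_..end_ over the three outer dims and copy each contiguous dim-3 run with a single memcpy.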
for (int32_t dim0 = param->begin_[0]; dim0 < param->end_[0]; ++dim0) {
size_t in_offset0 = dim0 * in_stride0 + param->begin_[3];
for (size_t dim1 = param->begin_[1]; dim1 < param->end_[1]; ++dim1) {
size_t in_offset1 = dim1 * in_stride1 + in_offset0;
for (int32_t dim2 = param->begin_[2]; dim2 < param->end_[2]; ++dim2) {
size_t in_offset = in_offset1 + dim2 * in_stride2;
memcpy(output + out_offset, input + in_offset, copy_size);
out_offset += param->size_[3];
}
}
}
}
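
For a quick sanity check outside nnacl, the standalone sketch below re-creates the same stride arithmetic and memcpy pattern as DoSliceFp16NoParallel. It is not the library code: it uses plain float instead of float16_t so it compiles without ARM fp16 support, and MiniSliceParam is a hypothetical stand-in for SliceParameter.

#include <stdio.h>
#include <string.h>

/* Hypothetical stand-in for SliceParameter: 4D input shape plus begin/size. */
typedef struct {
  int shape[4];
  int begin[4];
  int size[4];
} MiniSliceParam;

/* Same offset math as DoSliceFp16NoParallel, on float data. */
static void MiniSlice(const float *in, float *out, const MiniSliceParam *p) {
  size_t in_stride2 = p->shape[3];
  size_t in_stride1 = p->shape[2] * in_stride2;
  size_t in_stride0 = p->shape[1] * in_stride1;
  size_t copy_size = p->size[3] * sizeof(float);
  size_t out_offset = 0;
  for (int d0 = p->begin[0]; d0 < p->begin[0] + p->size[0]; ++d0) {
    for (int d1 = p->begin[1]; d1 < p->begin[1] + p->size[1]; ++d1) {
      for (int d2 = p->begin[2]; d2 < p->begin[2] + p->size[2]; ++d2) {
        size_t in_offset = d0 * in_stride0 + d1 * in_stride1 + d2 * in_stride2 + p->begin[3];
        memcpy(out + out_offset, in + in_offset, copy_size);
        out_offset += p->size[3];
      }
    }
  }
}

int main(void) {
  /* 1x2x3x4 input holding 0..23; take the 1x1x2x2 block at begin {0, 1, 1, 1}. */
  float in[24], out[4];
  for (int i = 0; i < 24; ++i) in[i] = (float)i;
  MiniSliceParam p = {{1, 2, 3, 4}, {0, 1, 1, 1}, {1, 1, 2, 2}};
  MiniSlice(in, out, &p);
  for (int i = 0; i < 4; ++i) printf("%g ", out[i]); /* prints: 17 18 21 22 */
  printf("\n");
  return 0;
}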

mindspore/lite/nnacl/fp16/slice_fp16.h
@@ -0,0 +1,34 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_FP16_SLICE_FP16_H_
#define MINDSPORE_LITE_NNACL_FP16_SLICE_FP16_H_
#include "nnacl/op_base.h"
#include "nnacl/slice_parameter.h"
#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif
#ifdef __cplusplus
extern "C" {
#endif
void DoSliceFp16(const float16_t *input, float16_t *output, SliceParameter *param, int thread_id);
void DoSliceFp16NoParallel(const float16_t *input, float16_t *output, SliceParameter *param);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_NNACL_FP16_SLICE_FP16_H_

mindspore/lite/src/runtime/kernel/arm/base/slice_base.cc
@@ -1,114 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/base/slice_base.h"
#include <vector>
#include "src/runtime/kernel/arm/int8/slice_int8.h"
#include "src/runtime/kernel/arm/fp32/slice.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Slice;
namespace mindspore::kernel {
int SliceBaseCPUKernel::Init() { return RET_OK; }
int SliceBaseCPUKernel::ReSize() {
auto input_shape = in_tensors_[0]->shape();
if (input_shape.size() > DIMENSION_4D) {
MS_LOG(ERROR) << "input dimension num should <= " << DIMENSION_4D;
return RET_ERROR;
}
for (size_t i = 0; i < input_shape.size(); ++i) {
param_->shape_[i] = input_shape[i];
}
if (param_->param_length_ < DIMENSION_4D) {
for (int i = param_->param_length_ - 1, j = 1; i >= 0; --i, ++j) {
param_->begin_[DIMENSION_4D - j] = param_->begin_[i];
param_->size_[DIMENSION_4D - j] = param_->size_[i];
}
for (int i = 0; i < DIMENSION_4D - param_->param_length_; i++) {
param_->begin_[i] = 0;
param_->size_[i] = 1;
}
}
param_->param_length_ = DIMENSION_4D;
for (int i = 0; i < DIMENSION_4D; ++i) {
if (param_->size_[i] < 0) {
param_->size_[i] = param_->shape_[i] - param_->begin_[i];
}
param_->end_[i] = param_->begin_[i] + param_->size_[i];
}
return RET_OK;
}
kernel::LiteKernel *CpuSliceInt8KernelCreator(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
const lite::InnerContext *ctx, const kernel::KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive) {
if (opParameter == nullptr) {
MS_LOG(ERROR) << "Input opParameter is nullptr!";
return nullptr;
}
MS_ASSERT(desc.type == schema::PrimitiveType_Slice);
auto *kernel = new (std::nothrow) SliceInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new SliceInt8CPUKernel fail!";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
return nullptr;
}
return kernel;
}
kernel::LiteKernel *CpuSliceFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
const lite::InnerContext *ctx, const kernel::KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive) {
if (opParameter == nullptr) {
MS_LOG(ERROR) << "Input opParameter is nullptr!";
return nullptr;
}
MS_ASSERT(desc.type == schema::PrimitiveType_Slice);
auto *kernel = new (std::nothrow) SliceCPUKernel(opParameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new SliceCPUKernel fail!";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
return nullptr;
}
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Slice, CpuSliceInt8KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Slice, CpuSliceFp32KernelCreator)
} // namespace mindspore::kernel

mindspore/lite/src/runtime/kernel/arm/fp16/slice_fp16.cc
@@ -0,0 +1,91 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/fp16/slice_fp16.h"
#include "src/runtime/kernel/arm/fp16/common_fp16.h"
#include "src/kernel_registry.h"
#include "nnacl/fp16/cast_fp16.h"
#include "nnacl/fp16/slice_fp16.h"
using mindspore::lite::KernelRegistrar;
using mindspore::schema::PrimitiveType_Slice;
namespace mindspore::kernel {
int SliceFp16CPUKernel::SliceParallelRun(int thread_id) {
DoSliceFp16(input_fp16_, output_fp16_, param_, thread_id);
return RET_OK;
}
int SliceFp16CPUKernel::Run() {
auto ret = Prepare();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Prepare fail!ret: " << ret;
return ret;
}
input_fp16_ = ConvertInputFp32toFp16(in_tensors_.at(0), context_);
output_fp16_ = MallocOutputFp16(out_tensors_.at(0), context_);
if (input_fp16_ == nullptr || output_fp16_ == nullptr) {
FreeInputAndOutput();
MS_LOG(ERROR) << "input or output is nullptr";
return RET_ERROR;
}
if (param_->size_[1] < op_parameter_->thread_num_) {
// The no-parallel path must not return early: the fp32 conversion and
// buffer release below still have to run.
DoSliceFp16NoParallel(input_fp16_, output_fp16_, param_);
} else {
ret = ParallelLaunch(this->context_->thread_pool_, SliceLaunch, this, op_parameter_->thread_num_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "slice launch fail! ret: " << ret;
}
}
if (out_tensors_.at(0)->data_type() == kNumberTypeFloat32) {
Float16ToFloat32(output_fp16_, reinterpret_cast<float *>(out_tensors_.at(0)->MutableData()),
out_tensors_.at(0)->ElementsNum());
}
FreeInputAndOutput();
return ret;
}
void SliceFp16CPUKernel::FreeInputAndOutput() {
if (in_tensors_.at(0)->data_type() == kNumberTypeFloat32) {
context_->allocator->Free(input_fp16_);
input_fp16_ = nullptr;
}
if (out_tensors_.at(0)->data_type() == kNumberTypeFloat32) {
context_->allocator->Free(output_fp16_);
output_fp16_ = nullptr;
}
}
kernel::LiteKernel *CpuSliceFp16KernelCreator(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
const lite::InnerContext *ctx, const kernel::KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive) {
auto *kernel = new (std::nothrow) SliceFp16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new SliceFp16CPUKernel fail!";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
delete kernel;
return nullptr;
}
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Slice, CpuSliceFp16KernelCreator)
} // namespace mindspore::kernel
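
The fp16 kernel reuses the fp32 file's SliceLaunch trampoline: ParallelLaunch invokes a plain C callback once per task id, and the callback casts cdata back to the kernel and forwards to SliceParallelRun. Below is a minimal stand-in for that dispatch pattern, with a serial loop in place of the real thread pool; all names here are hypothetical, not lite APIs.

#include <stdio.h>

/* Stand-in for the lite thread pool: run the callback once per task id, serially. */
typedef int (*TaskFunc)(void *cdata, int task_id);
static int FakeParallelLaunch(TaskFunc task, void *cdata, int task_num) {
  for (int id = 0; id < task_num; ++id) {
    int ret = task(cdata, id);
    if (ret != 0) return ret;
  }
  return 0;
}

/* Stand-in for the kernel object; each task handles an UP_DIV share of dim 1. */
typedef struct {
  int out_dim1;
  int thread_num;
} FakeKernel;

static int FakeSliceLaunch(void *cdata, int task_id) {
  FakeKernel *k = (FakeKernel *)cdata;
  int per_thread = (k->out_dim1 + k->thread_num - 1) / k->thread_num; /* UP_DIV */
  int begin = task_id * per_thread;
  int end = begin + per_thread > k->out_dim1 ? k->out_dim1 : begin + per_thread;
  printf("task %d copies dim1 rows [%d, %d)\n", task_id, begin, end);
  return 0;
}

int main(void) {
  FakeKernel k = {10, 4}; /* 10 rows over 4 tasks: [0,3) [3,6) [6,9) [9,10) */
  return FakeParallelLaunch(FakeSliceLaunch, &k, k.thread_num);
}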

mindspore/lite/src/runtime/kernel/arm/base/slice_base.h → mindspore/lite/src/runtime/kernel/arm/fp16/slice_fp16.h
@@ -13,32 +13,29 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_SLICE_FP16_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_SLICE_FP16_H_
#include <vector>
#include "src/lite_kernel.h"
#include "nnacl/slice_parameter.h"
#include "src/runtime/kernel/arm/fp32/slice.h"
namespace mindspore::kernel {
class SliceFp16CPUKernel : public SliceCPUKernel {
public:
SliceFp16CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: SliceCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
~SliceFp16CPUKernel() = default;
int Run() override;
int SliceParallelRun(int thread_id) override;
protected:
void FreeInputAndOutput();
float16_t *input_fp16_ = nullptr;
float16_t *output_fp16_ = nullptr;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_SLICE_FP16_H_

mindspore/lite/src/runtime/kernel/arm/fp32/slice.cc
@@ -14,55 +14,41 @@
* limitations under the License.
*/
#include "src/runtime/kernel/arm/fp32/slice.h"
#include <vector>
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "nnacl/fp32/slice.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"
#include "src/ops/slice.h"
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Slice;
namespace mindspore::kernel {
int SliceLaunch(void *cdata, int task_id) {
if (cdata == nullptr) {
MS_LOG(ERROR) << "Input cdata is nullptr!";
return RET_ERROR;
}
auto kernel = reinterpret_cast<SliceCPUKernel *>(cdata);
return kernel->SliceParallelRun(task_id);
}
int SliceCPUKernel::ReSize() {
auto primitive_slice = reinterpret_cast<const mindspore::lite::Slice *>(primitive_);
auto begin = primitive_slice->GetPostProcessBegin();
auto size = primitive_slice->GetPostProcessSize();
param_->param_length_ = in_tensors_.at(0)->shape().size();
if (param_->param_length_ > DIMENSION_4D) {
MS_LOG(ERROR) << "input dimension num should <= " << DIMENSION_4D;
return RET_ERROR;
}
for (int i = 0; i < param_->param_length_; ++i) {
param_->shape_[i] = in_tensors_.at(0)->DimensionSize(i);
param_->begin_[i] = begin[i];
param_->size_[i] = size[i] < 0 ? param_->shape_[i] - param_->begin_[i] : size[i];
param_->end_[i] = param_->begin_[i] + param_->size_[i];
}
if (param_->param_length_ < DIMENSION_4D) {
PadSliceParameterTo4D(param_);
}
return RET_OK;
}
@@ -77,8 +63,7 @@ int SliceCPUKernel::Init() {
int SliceCPUKernel::SliceParallelRun(int thread_id) {
const float *input_data = reinterpret_cast<const float *>(in_tensors_[0]->MutableData());
float *output_data = reinterpret_cast<float *>(out_tensors_[0]->MutableData());
DoSlice(input_data, output_data, param_, thread_id);
return RET_OK;
}
@@ -88,29 +73,38 @@ int SliceCPUKernel::Run() {
MS_LOG(ERROR) << "Prepare fail!ret: " << ret;
return ret;
}
const float *input_data = reinterpret_cast<const float *>(in_tensors_[0]->MutableData());
float *output_data = reinterpret_cast<float *>(out_tensors_[0]->MutableData());
if (param_->size_[1] < op_parameter_->thread_num_) {
DoSliceNoParallel(input_data, output_data, param_);
return RET_OK;
}
ret = ParallelLaunch(this->context_->thread_pool_, SliceLaunch, this, op_parameter_->thread_num_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "slice launch fail!ret: " << ret;
return RET_ERROR;
}
return RET_OK;
}
kernel::LiteKernel *CpuSliceFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
const lite::InnerContext *ctx, const kernel::KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive) {
auto *kernel = new (std::nothrow) SliceCPUKernel(opParameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new SliceCPUKernel fail!";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
delete kernel;
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
return nullptr;
}
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Slice, CpuSliceFp32KernelCreator)
} // namespace mindspore::kernel
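
ReSize now delegates rank normalization to PadSliceParameterTo4D from nnacl. Its exact body is not part of this diff; the sketch below re-creates what the deleted SliceBaseCPUKernel::ReSize did inline, right-aligning begin_/size_ and padding the leading dimensions with begin 0 / size 1 (MiniParam and MiniPadTo4D are hypothetical names, and the real helper presumably pads shape_ and end_ the same way).

#define DIMENSION_4D 4

/* Hypothetical stand-in for the SliceParameter fields the padding touches. */
typedef struct {
  int param_length_;
  int begin_[DIMENSION_4D];
  int size_[DIMENSION_4D];
} MiniParam;

/* Right-align the existing dims, then pad the front with begin 0 / size 1,
 * mirroring the inline logic removed from SliceBaseCPUKernel::ReSize. */
static void MiniPadTo4D(MiniParam *p) {
  for (int i = p->param_length_ - 1, j = 1; i >= 0; --i, ++j) {
    p->begin_[DIMENSION_4D - j] = p->begin_[i];
    p->size_[DIMENSION_4D - j] = p->size_[i];
  }
  for (int i = 0; i < DIMENSION_4D - p->param_length_; i++) {
    p->begin_[i] = 0;
    p->size_[i] = 1;
  }
  p->param_length_ = DIMENSION_4D;
}

For example, a 2D slice with begin {2, 3} and size {5, 6} becomes begin {0, 0, 2, 3} and size {1, 1, 5, 6}, so the 4D kernels above can treat every input as rank 4.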

mindspore/lite/src/runtime/kernel/arm/fp32/slice.h
@@ -18,22 +18,28 @@
#include <vector>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/base/slice_base.h"
#include "nnacl/slice_parameter.h"
namespace mindspore::kernel {
class SliceCPUKernel : public LiteKernel {
public:
SliceCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {
param_ = reinterpret_cast<SliceParameter *>(op_parameter_);
}
~SliceCPUKernel() = default;
int Init() override;
int ReSize() override;
int Run() override;
virtual int SliceParallelRun(int thread_id);
protected:
SliceParameter *param_;
};
int SliceLaunch(void *cdata, int task_id);
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SLICE_H_

mindspore/lite/src/runtime/kernel/arm/int8/slice_int8.cc
@@ -16,23 +16,19 @@
#include "src/runtime/kernel/arm/int8/slice_int8.h"
#include <limits>
#include "nnacl/slice_parameter.h"
#include "src/kernel_registry.h"
#include "nnacl/int8/slice_int8.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Slice;
namespace mindspore::kernel {
int SliceInt8CPUKernel::Init() {
auto input = in_tensors_.at(0);
auto output = out_tensors_.at(0);
MS_ASSERT(input);
@@ -54,8 +50,6 @@ int SliceInt8CPUKernel::Init() {
return ReSize();
}
int SliceInt8CPUKernel::DoSlice(int task_id) {
const int8_t *input_data = reinterpret_cast<const int8_t *>(in_tensors_[0]->MutableData());
int8_t *output_data = reinterpret_cast<int8_t *>(out_tensors_[0]->MutableData());
@@ -97,4 +91,25 @@ int SliceInt8CPUKernel::Run() {
}
return ret;
}
kernel::LiteKernel *CpuSliceInt8KernelCreator(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
const lite::InnerContext *ctx, const kernel::KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive) {
auto *kernel = new (std::nothrow) SliceInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new SliceInt8CPUKernel fail!";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
delete kernel;
return nullptr;
}
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Slice, CpuSliceInt8KernelCreator)
} // namespace mindspore::kernel

mindspore/lite/src/runtime/kernel/arm/int8/slice_int8.h
@@ -18,20 +18,19 @@
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_SLICE_INT8_H_
#include <vector>
#include "src/runtime/kernel/arm/base/slice_base.h"
#include "src/runtime/kernel/arm/fp32/slice.h"
#include "nnacl/quantization/quantize.h"
namespace mindspore::kernel {
class SliceInt8CPUKernel : public SliceCPUKernel {
public:
SliceInt8CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: SliceCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
~SliceInt8CPUKernel() {}
int Init() override;
int Run() override;
int DoSlice(int task_id);
};
