diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/splice_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/splice_fp32.cc new file mode 100644 index 0000000000..25adbbbb5e --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/splice_fp32.cc @@ -0,0 +1,61 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/fp32/splice_fp32.h" +#include "schema/model_generated.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" +#include "src/common/log_adapter.h" +#include "nnacl/fp32/splice_fp32.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_Splice; +namespace mindspore::kernel { +int SpliceCPUKernel::Init() { return RET_OK; } + +int SpliceCPUKernel::ReSize() { return RET_OK; } + +int SpliceCPUKernel::Run() { + lite::Tensor *input_tensor = in_tensors_.front(); + lite::Tensor *output_tensor = out_tensors_.front(); + if (input_tensor->data_c() == nullptr || output_tensor->data_c() == nullptr) { + MS_LOG(ERROR) << "splice kernel input or output data is nullptr"; + return RET_ERROR; + } + std::vector src_shape = input_tensor->shape(); + std::vector dst_shape = output_tensor->shape(); + if (src_shape.size() != dst_shape.size() || src_shape.size() != kInputSize2 || dst_shape.size() != kInputSize2) { + MS_LOG(ERROR) << "splice kernel src_shape size not equal to dst_shape size"; + return RET_ERROR; + } + int src_row = src_shape.at(kWeightIndex); + int dst_row = dst_shape.at(kWeightIndex); + int src_col = src_shape.at(kBiasIndex); + int dst_col = dst_shape.at(kBiasIndex); + if (src_col * parameter_->context_dim_ != dst_col) { + MS_LOG(ERROR) << "splice kernel src_col not match dst_col"; + return RET_ERROR; + } + auto input_data = reinterpret_cast(input_tensor->data_c()); + auto output_data = reinterpret_cast(output_tensor->data_c()); + SpliceFp32(input_data, src_row, src_col, parameter_, output_data, dst_row, dst_col); + return RET_OK; +} +REG_KERNEL(kCPU, kNumberTypeFloat, PrimitiveType_Splice, LiteKernelCreator) +REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Splice, LiteKernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/splice_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/splice_fp32.h new file mode 100644 index 0000000000..c52106bd41 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/splice_fp32.h @@ -0,0 +1,40 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SPLICE_SPLICE_FP32_H +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SPLICE_SPLICE_FP32_H +#include +#include "src/lite_kernel.h" +#include "nnacl/splice_parameter.h" +namespace mindspore::kernel { +class SpliceCPUKernel final : public LiteKernel { + public: + SpliceCPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::InnerContext *ctx) + : LiteKernel(parameter, inputs, outputs, ctx) { + parameter_ = reinterpret_cast(parameter); + } + ~SpliceCPUKernel() override = default; + + int Init() override; + int ReSize() override; + int Run() override; + + private: + SpliceParameter *parameter_{nullptr}; +}; +} // namespace mindspore::kernel +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SPLICE_SPLICE_FP32_H