add splice_fp32 kernel to lite runtime/arm/fp32

pull/14714/head
z00512249 4 years ago
parent 68676a0d99
commit 2616aede84

@ -0,0 +1,61 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/fp32/splice_fp32.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/common/log_adapter.h"
#include "nnacl/fp32/splice_fp32.h"
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Splice;
namespace mindspore::kernel {
int SpliceCPUKernel::Init() { return RET_OK; }
int SpliceCPUKernel::ReSize() { return RET_OK; }
int SpliceCPUKernel::Run() {
lite::Tensor *input_tensor = in_tensors_.front();
lite::Tensor *output_tensor = out_tensors_.front();
if (input_tensor->data_c() == nullptr || output_tensor->data_c() == nullptr) {
MS_LOG(ERROR) << "splice kernel input or output data is nullptr";
return RET_ERROR;
}
std::vector<int> src_shape = input_tensor->shape();
std::vector<int> dst_shape = output_tensor->shape();
if (src_shape.size() != dst_shape.size() || src_shape.size() != kInputSize2 || dst_shape.size() != kInputSize2) {
MS_LOG(ERROR) << "splice kernel src_shape size not equal to dst_shape size";
return RET_ERROR;
}
int src_row = src_shape.at(kWeightIndex);
int dst_row = dst_shape.at(kWeightIndex);
int src_col = src_shape.at(kBiasIndex);
int dst_col = dst_shape.at(kBiasIndex);
if (src_col * parameter_->context_dim_ != dst_col) {
MS_LOG(ERROR) << "splice kernel src_col not match dst_col";
return RET_ERROR;
}
auto input_data = reinterpret_cast<float *>(input_tensor->data_c());
auto output_data = reinterpret_cast<float *>(output_tensor->data_c());
SpliceFp32(input_data, src_row, src_col, parameter_, output_data, dst_row, dst_col);
return RET_OK;
}
REG_KERNEL(kCPU, kNumberTypeFloat, PrimitiveType_Splice, LiteKernelCreator<SpliceCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Splice, LiteKernelCreator<SpliceCPUKernel>)
} // namespace mindspore::kernel

@ -0,0 +1,40 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SPLICE_SPLICE_FP32_H
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SPLICE_SPLICE_FP32_H
#include <vector>
#include "src/lite_kernel.h"
#include "nnacl/splice_parameter.h"
namespace mindspore::kernel {
class SpliceCPUKernel final : public LiteKernel {
public:
SpliceCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
: LiteKernel(parameter, inputs, outputs, ctx) {
parameter_ = reinterpret_cast<SpliceParameter *>(parameter);
}
~SpliceCPUKernel() override = default;
int Init() override;
int ReSize() override;
int Run() override;
private:
SpliceParameter *parameter_{nullptr};
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_SPLICE_SPLICE_FP32_H
Loading…
Cancel
Save