!9969 fix npu subgraph executor bug

From: @yeyunpeng2020
Reviewed-by: @HilbertDavid,@zhanghaibo5,@zhang_xue_tong
Signed-off-by: @zhang_xue_tong
pull/9969/MERGE
Committed by mindspore-ci-bot via Gitee
commit c3e7bbb4c2

@@ -150,6 +150,7 @@ int Split::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outpu
for (size_t j = 0; j < size_splits_.size() - 1; ++j) {
split_dim_i -= size_splits_[j];
}
size_splits_[i] = split_dim_i;
} else {
split_dim_i = size_splits_[i];
}

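For reference, the Split::InferShape hunk above resolves an unspecified (-1) tail entry of size_splits_ to whatever remains of the split dimension. A standalone sketch of that rule with hypothetical names (the original operates on lite::Tensor shapes):

#include <vector>

// Resolve a -1 (unspecified) last entry of size_splits to the remainder of
// the dimension being split, mirroring the inference rule in the hunk above.
std::vector<int> ResolveLastSplit(std::vector<int> size_splits, int split_dim) {
  int remainder = split_dim;
  for (size_t j = 0; j + 1 < size_splits.size(); ++j) {
    remainder -= size_splits[j];
  }
  size_splits.back() = remainder;
  return size_splits;
}

// e.g. ResolveLastSplit({2, 3, -1}, 10) yields {2, 3, 5}.
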
@@ -24,6 +24,7 @@ ge::Shape ConverterToNPUShape(const std::vector<int> &src_shape) {
}
return ge::Shape({shapes});
}
ge::Format ConverterToNPUFormat(schema::Format format) {
ge::Format ge_format;
switch (format) {
@@ -74,13 +75,14 @@ ge::DataType ConverterToNPUDataType(TypeId type_id) {
}
return data_type;
}
hiai::op::Data *ConverterToNPUData(Tensor *src, const std::string &name) {
auto data = new (std::nothrow) hiai::op::Data(name);
if (data == nullptr) {
MS_LOG(ERROR) << "new data failed.";
return data;
}
ge::TensorDesc tensor_desc(ConverterToNPUShape(src->shape()), ge::FORMAT_NCHW,
ge::TensorDesc tensor_desc(ConverterToNPUShape(src->shape()), ConverterToNPUFormat(src->format()),
ConverterToNPUDataType(src->data_type()));
data->update_input_desc_x(tensor_desc);
return data;
@@ -92,7 +94,7 @@ std::shared_ptr<ge::Tensor> ConverterToNPUTensor(Tensor *src) {
MS_LOG(ERROR) << "new ge_tensor failed.";
return ge_tensor;
}
ge::TensorDesc tensor_desc(ConverterToNPUShape(src->shape()), ge::FORMAT_NCHW,
ge::TensorDesc tensor_desc(ConverterToNPUShape(src->shape()), ConverterToNPUFormat(src->format()),
ConverterToNPUDataType(src->data_type()));
ge_tensor->SetTensorDesc(tensor_desc);
@@ -102,62 +104,7 @@ std::shared_ptr<ge::Tensor> ConverterToNPUTensor(Tensor *src) {
}
return ge_tensor;
}
/*
* mode : Activation mode, with options as follows:
* 0 : Sigmoid
* 1 : ReLU
* 2 : Tanh
* 3 : Clipped ReLU
* 4 : ELU
* 5 : PReLU
* 6 : Abs
* 7 : Relu1
* 8 : Softsign
* 9 : Softplus
* 10 : Hardsigmoid
* 11 : Threshold ReLU
* 12 : Selu
* 13 : Linear
* 14 : Relu6
* 15 : GeLU.
*/
int ConverterToNPUActMode(schema::ActivationType type) {
switch (type) {
case schema::ActivationType_NO_ACTIVATION:
return -1;
case schema::ActivationType_SIGMOID:
return 0;
case schema::ActivationType_RELU:
return 1;
case schema::ActivationType_TANH:
return 2;
case schema::ActivationType_ELU:
return 4;
case schema::ActivationType_LEAKY_RELU:
return 5;
case schema::ActivationType_ABS:
return 6;
case schema::ActivationType_RELU1:
return 7;
case schema::ActivationType_SOFTSIGN:
return 8;
case schema::ActivationType_SOFTPLUS:
return 9;
case schema::ActivationType_HSIGMOID:
return 10;
case schema::ActivationType_THRESHOLDRELU:
return 11;
case schema::ActivationType_SELU:
return 12;
case schema::ActivationType_LINEAR:
return 13;
case schema::ActivationType_RELU6:
return 14;
default:
MS_LOG(ERROR) << "Unsupport activation type to NPU." << type;
return -1;
}
}
// mode : Either 0 (product), 1 (sum), 2 (max), 3 (mean). Defaults to 1 (sum).
int ConverterToNPUEltwiseMode(schema::EltwiseMode mode) {
int mode_num = 1;

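Both TensorDesc changes above stop hardcoding ge::FORMAT_NCHW and instead pass the tensor's actual layout through ConverterToNPUFormat, whose switch body is elided in this diff. A hedged sketch of the mapping it presumably performs; the cases and the fallback are assumptions, and it relies on the schema/HiAI headers already included by npu_converter_utils.cc:

// Assumed mapping only; the real switch lives in npu_converter_utils.cc.
ge::Format ConverterToNPUFormatSketch(schema::Format format) {
  switch (format) {
    case schema::Format_NCHW:
      return ge::FORMAT_NCHW;
    case schema::Format_NHWC:
      return ge::FORMAT_NHWC;
    default:
      return ge::FORMAT_NCHW;  // assumed fallback; real code likely logs an error
  }
}
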
@@ -53,6 +53,7 @@ int NPUExecutor::Run(std::vector<Tensor *> &in_tensors, std::vector<Tensor *> &o
for (int i = 0; i < npu_output_tensors_.size(); ++i) {
memcpy(out_tensors[i]->MutableData(), npu_output_tensors_[i]->GetBuffer(), npu_output_tensors_[i]->GetSize());
out_tensors[i]->ResetRefCount();
}
return RET_OK;

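The single added line, out_tensors[i]->ResetRefCount(), restores each output tensor's reference count after the NPU buffers are copied out, since the NPU subgraph bypasses the per-kernel decrements that would normally consume it. A minimal sketch of the counter semantics assumed here (illustrative only, not the lite::Tensor implementation):

// Names and fields are assumptions for illustration.
struct RefCountedOutput {
  int init_ref_count = 1;  // how many downstream consumers read the tensor
  int ref_count = 0;       // decremented as consumers finish
  void ResetRefCount() { ref_count = init_ref_count; }
};
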
@@ -83,16 +83,22 @@ int SubGraphNpuKernel::BuildNPUInputOp() {
for (auto in_tensor : node->in_tensors()) {
if (IsSubGraphInputTensor(in_tensor)) {
auto tensor_name = node->name() + "_" + std::to_string(count++);
auto shape = in_tensor->shape();
hiai::op::Data *data;
if (trans_nodes.find(node->Type()) != trans_nodes.end()) {
in_tensor->set_shape({shape[0], shape[3], shape[1], shape[2]});
auto shape = in_tensor->shape();
data = new (std::nothrow) hiai::op::Data(tensor_name);
if (data == nullptr) {
MS_LOG(ERROR) << "New data failed.";
return RET_ERROR;
}
ge::TensorDesc tensor_desc(lite::ConverterToNPUShape({shape[0], shape[3], shape[1], shape[2]}),
ge::FORMAT_NCHW, lite::ConverterToNPUDataType(in_tensor->data_type()));
data->update_input_desc_x(tensor_desc);
} else {
data = mindspore::lite::ConverterToNPUData(in_tensor, tensor_name);
}
auto data = mindspore::lite::ConverterToNPUData(in_tensor, tensor_name);
subgraph_input_op_.push_back(*data);
node_input_op.push_back(data);
if (trans_nodes.find(node->Type()) != trans_nodes.end()) {
in_tensor->set_shape(shape);
}
continue;
}
@@ -120,13 +126,11 @@ int SubGraphNpuKernel::BuildNPUInputOp() {
// weight tensor
if (is_weight_tensor) {
if (!(node->Type() == schema::PrimitiveType_Conv2D || node->Type() == schema::PrimitiveType_DeConv2D ||
node->Type() == schema::PrimitiveType_DepthwiseConv2D ||
node->Type() == schema::PrimitiveType_DeDepthwiseConv2D)) {
if (trans_nodes.find(node->Type()) == trans_nodes.end()) {
auto name = node->name() + "_" + std::to_string(count++);
auto weight_const = new (std::nothrow) hiai::op::Const(node->name() + "_" + std::to_string(count++));
if (weight_const == nullptr) {
MS_LOG(ERROR) << "new weight const failed.";
MS_LOG(ERROR) << "New weight const failed.";
return RET_ERROR;
}
auto weight_tensor = mindspore::lite::ConverterToNPUTensor(in_tensor);

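The heart of the subgraph fix: for node types in trans_nodes, the input tensor's NHWC shape is temporarily rewritten to NCHW while the hiai::op::Data op is built, then restored afterwards. The index permutation as a standalone sketch:

#include <vector>

// NHWC {N, H, W, C} -> NCHW {N, C, H, W}, the rewrite applied above.
std::vector<int> NhwcToNchw(const std::vector<int> &shape) {
  return {shape[0], shape[3], shape[1], shape[2]};
}

// e.g. NhwcToNchw({1, 224, 224, 3}) yields {1, 3, 224, 224}.
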
@@ -24,7 +24,7 @@ using mindspore::schema::PrimitiveType_Concat;
namespace mindspore::kernel {
int ConcatNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
OpParameter *opParameter) {
return RET_OK;
return RET_ERROR;
}
int ConcatNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,

@@ -25,7 +25,7 @@ using mindspore::schema::PrimitiveType_DepthwiseConv2D;
namespace mindspore::kernel {
int ConvolutionDepthwiseNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter) {
return RET_OK;
return RET_ERROR;
}
int ConvolutionDepthwiseNPUKernel::SetConvDwParam() {

@@ -24,7 +24,7 @@ using mindspore::schema::PrimitiveType_Conv2D;
namespace mindspore::kernel {
int ConvolutionNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter) {
return RET_OK;
return RET_ERROR;
}
int ConvolutionNPUKernel::SetConvParam() {

@@ -0,0 +1,51 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/npu/matmul_npu.h"
#include "src/kernel_registry.h"
using mindspore::kernel::KERNEL_ARCH::kNPU;
using mindspore::lite::KernelRegistrar;
using mindspore::schema::PrimitiveType_MatMul;
namespace mindspore::kernel {
int MatMulNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
OpParameter *opParameter) {
return RET_OK;
}
int MatMulNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
const std::vector<ge::Operator *> &npu_inputs) {
op_ = new (std::nothrow) hiai::op::MatMul(name_);
if (op_ == nullptr) {
MS_LOG(ERROR) << name_ << " op is nullptr";
return RET_ERROR;
}
op_->set_input_x1(*npu_inputs[0]);
op_->set_input_x2(*npu_inputs[1]);
op_->set_attr_transpose_x1(a_transpose_);
op_->set_attr_transpose_x2(b_transpose_);
return RET_OK;
}
ge::Operator *mindspore::kernel::MatMulNPUKernel::GetNPUOp() { return this->op_; }
MatMulNPUKernel::~MatMulNPUKernel() {
if (op_ != nullptr) {
delete op_;
op_ = nullptr;
}
}
REG_KERNEL(kNPU, kNumberTypeFloat32, PrimitiveType_MatMul, NPUKernelCreator<MatMulNPUKernel>)
} // namespace mindspore::kernel

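The transpose attributes set above tell the NPU to read x1 and/or x2 transposed before multiplying. A plain row-major sketch of that contract (illustrative only; the real computation happens on the device):

#include <vector>

// C = op(A) * op(B) with op(X) = X^T when the matching flag is set.
// A is logically m x k (stored k x m when ta); B is k x n (stored n x k when tb).
std::vector<float> MatMulSketch(const std::vector<float> &a, const std::vector<float> &b,
                                int m, int k, int n, bool ta, bool tb) {
  std::vector<float> c(static_cast<size_t>(m) * n, 0.0f);
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      for (int p = 0; p < k; ++p) {
        float av = ta ? a[p * m + i] : a[i * k + p];
        float bv = tb ? b[j * k + p] : b[p * n + j];
        c[i * n + j] += av * bv;
      }
    }
  }
  return c;
}
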
@@ -0,0 +1,49 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_MATMUL_NPU_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_MATMUL_NPU_H_
#include <vector>
#include "nnacl/matmul_parameter.h"
#include "src/runtime/kernel/npu/npu_kernel.h"
#include "nnacl/softmax_parameter.h"
#include "include/graph/op/all_ops.h"
namespace mindspore::kernel {
class MatMulNPUKernel : public NPUKernel {
public:
MatMulNPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: NPUKernel(parameter, inputs, outputs, ctx, primitive) {
auto matmul_parameter = reinterpret_cast<MatMulParameter *>(parameter);
a_transpose_ = matmul_parameter->a_transpose_;
b_transpose_ = matmul_parameter->b_transpose_;
}
~MatMulNPUKernel() override;
int IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
OpParameter *opParameter) override;
int SetNPUInputs(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
const std::vector<ge::Operator *> &npu_inputs) override;
ge::Operator *GetNPUOp() override;
private:
hiai::op::MatMul *op_ = nullptr;
bool a_transpose_ = false;
bool b_transpose_ = false;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_MATMUL_NPU_H_

@@ -0,0 +1,73 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/npu/pad_npu.h"
#include <memory>
#include "src/kernel_registry.h"
#include "src/runtime/agent/npu/npu_converter_utils.h"
using mindspore::kernel::KERNEL_ARCH::kNPU;
using mindspore::lite::KernelRegistrar;
using mindspore::schema::PrimitiveType_Pad;
namespace mindspore::kernel {
int PadNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
OpParameter *opParameter) {
if (padding_mode_ != schema::PaddingMode_CONSTANT) {
MS_LOG(WARNING) << "NPU only support CONSTANT padding mode";
return RET_ERROR;
}
return RET_OK;
}
int PadNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
const std::vector<ge::Operator *> &npu_inputs) {
op_ = new (std::nothrow) hiai::op::PadV2(name_);
if (op_ == nullptr) {
MS_LOG(ERROR) << name_ << " op is nullptr";
return RET_ERROR;
}
int size = static_cast<int>(paddings_.size() / 2);
ge::TensorDesc padding_tensor_desc(ge::Shape({size, 2}), ge::FORMAT_NCHW, ge::DT_INT32);
ge::TensorPtr padding_tensor = std::make_shared<hiai::Tensor>(padding_tensor_desc);
padding_tensor->SetData(reinterpret_cast<uint8_t *>(paddings_.data()), 2 * size * sizeof(int));  // all 2 * size values
auto paddings = new hiai::op::Const(name_ + "paddings");
paddings->set_attr_value(padding_tensor);
ge::TensorDesc constant_values_tensor_desc(ge::Shape({1}), ge::FORMAT_NCHW, ge::DT_FLOAT);
ge::TensorPtr constant_values_tensor = std::make_shared<hiai::Tensor>(constant_values_tensor_desc);
vector<float> constant_values_data_value = {constant_value_};
constant_values_tensor->SetData(reinterpret_cast<uint8_t *>(constant_values_data_value.data()), 1 * sizeof(float));
auto constant = new hiai::op::Const(name_ + "constant");
constant->set_attr_value(constant_values_tensor);
op_->set_input_x(*npu_inputs[0]);
op_->set_input_constant_values(*constant);
op_->set_input_paddings(*paddings);
return RET_OK;
}
ge::Operator *mindspore::kernel::PadNPUKernel::GetNPUOp() { return this->op_; }
PadNPUKernel::~PadNPUKernel() {
if (op_ != nullptr) {
delete op_;
op_ = nullptr;
}
}
REG_KERNEL(kNPU, kNumberTypeFloat32, PrimitiveType_Pad, NPUKernelCreator<PadNPUKernel>)
} // namespace mindspore::kernel

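PadNPUKernel packs the flat paddings_ vector, front/back pairs per axis, into an {n, 2} int32 constant tensor by a raw byte copy. A standalone sketch of that packing, assuming the vector holds 2 * n entries:

#include <cstdint>
#include <cstring>
#include <vector>

// View 2 * n int32 padding values as the raw bytes of an {n, 2} tensor.
std::vector<uint8_t> PackPaddings(const std::vector<int32_t> &paddings) {
  std::vector<uint8_t> bytes(paddings.size() * sizeof(int32_t));
  std::memcpy(bytes.data(), paddings.data(), bytes.size());
  return bytes;
}
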
@@ -0,0 +1,51 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_PAD_NPU_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_PAD_NPU_H_
#include <vector>
#include "nnacl/pad_parameter.h"
#include "src/ops/pad.h"
#include "src/runtime/kernel/npu/npu_kernel.h"
#include "include/graph/op/all_ops.h"
namespace mindspore::kernel {
class PadNPUKernel : public NPUKernel {
public:
PadNPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: NPUKernel(parameter, inputs, outputs, ctx, primitive) {
auto pad = reinterpret_cast<const mindspore::lite::Pad *>(primitive);
constant_value_ = pad->GetConstantValue();
paddings_ = pad->GetPaddings();
padding_mode_ = pad->GetPaddingMode();
}
~PadNPUKernel() override;
int IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
OpParameter *opParameter) override;
int SetNPUInputs(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
const std::vector<ge::Operator *> &npu_inputs) override;
ge::Operator *GetNPUOp() override;
private:
hiai::op::PadV2 *op_ = nullptr;
std::vector<int> paddings_;
int padding_mode_;
float constant_value_;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_PAD_NPU_H_

@@ -24,7 +24,7 @@ using mindspore::schema::PrimitiveType_Pooling;
namespace mindspore::kernel {
int PoolingNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
OpParameter *opParameter) {
return RET_OK;
return RET_ERROR;
}
int PoolingNPUKernel::SetPoolingParam() {

@@ -0,0 +1,54 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/npu/slice_npu.h"
#include "src/kernel_registry.h"
#include "src/runtime/agent/npu/npu_converter_utils.h"
using mindspore::kernel::KERNEL_ARCH::kNPU;
using mindspore::lite::KernelRegistrar;
using mindspore::schema::PrimitiveType_Slice;
namespace mindspore::kernel {
int SliceNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
OpParameter *opParameter) {
return RET_OK;
}
int SliceNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
const std::vector<ge::Operator *> &npu_inputs) {
op_ = new (std::nothrow) hiai::op::Slice(name_);
if (op_ == nullptr) {
MS_LOG(ERROR) << name_ << " op is nullptr";
return RET_ERROR;
}
op_->set_input_x(*npu_inputs[0]);
op_->set_input_offsets(*npu_inputs[1]);
op_->set_input_size(*npu_inputs[2]);
return RET_OK;
}
ge::Operator *mindspore::kernel::SliceNPUKernel::GetNPUOp() { return this->op_; }
SliceNPUKernel::~SliceNPUKernel() {
if (op_ != nullptr) {
delete op_;
op_ = nullptr;
}
}
REG_KERNEL(kNPU, kNumberTypeFloat32, PrimitiveType_Slice, NPUKernelCreator<SliceNPUKernel>)
} // namespace mindspore::kernel

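The three inputs wired above follow the Slice contract: along each dimension d, copy size[d] elements starting at offsets[d]. A 1-D sketch for illustration:

#include <vector>

// 1-D Slice: `size` elements starting at `offset`.
std::vector<float> Slice1D(const std::vector<float> &x, int offset, int size) {
  return std::vector<float>(x.begin() + offset, x.begin() + offset + size);
}

// e.g. Slice1D({1, 2, 3, 4, 5}, 1, 3) yields {2, 3, 4}.
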
@@ -0,0 +1,42 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_SLICE_NPU_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_SLICE_NPU_H_
#include <vector>
#include "src/ops/slice.h"
#include "src/runtime/kernel/npu/npu_kernel.h"
#include "include/graph/op/all_ops.h"
namespace mindspore::kernel {
class SliceNPUKernel : public NPUKernel {
public:
SliceNPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: NPUKernel(parameter, inputs, outputs, ctx, primitive) {}
~SliceNPUKernel() override;
int IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
OpParameter *opParameter) override;
int SetNPUInputs(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
const std::vector<ge::Operator *> &npu_inputs) override;
ge::Operator *GetNPUOp() override;
private:
hiai::op::Slice *op_ = nullptr;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_SLICE_NPU_H_

@@ -0,0 +1,70 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/npu/split_npu.h"
#include <memory>
#include "src/kernel_registry.h"
#include "src/runtime/agent/npu/npu_converter_utils.h"
using mindspore::kernel::KERNEL_ARCH::kNPU;
using mindspore::lite::KernelRegistrar;
using mindspore::schema::PrimitiveType_Split;
namespace mindspore::kernel {
int SplitNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
OpParameter *opParameter) {
return RET_OK;
}
int SplitNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
const std::vector<ge::Operator *> &npu_inputs) {
op_ = new (std::nothrow) hiai::op::SplitV(name_);
if (op_ == nullptr) {
MS_LOG(ERROR) << name_ << " op is nullptr";
return RET_ERROR;
}
int size = size_splits_.size();
ge::TensorDesc size_splits_tensor_desc(ge::Shape({size}), ge::FORMAT_NCHW, ge::DT_INT32);
ge::TensorPtr size_splits_tensor = std::make_shared<hiai::Tensor>(size_splits_tensor_desc);
size_splits_tensor->SetData(reinterpret_cast<uint8_t *>(size_splits_.data()), size * sizeof(int));
auto size_splits = new hiai::op::Const(name_ + "_size");
size_splits->set_attr_value(size_splits_tensor);
ge::TensorDesc split_dim_tensor_desc(ge::Shape({1}), ge::FORMAT_NCHW, ge::DT_INT32);
ge::TensorPtr split_dim_tensor = std::make_shared<hiai::Tensor>(split_dim_tensor_desc);
vector<int32_t> split_dim_data_value = {split_dim_};
split_dim_tensor->SetData(reinterpret_cast<uint8_t *>(split_dim_data_value.data()), 1 * sizeof(int));
auto split_dim = new hiai::op::Const(name_ + "_dim");
split_dim->set_attr_value(split_dim_tensor);
op_->set_input_x(*npu_inputs[0]);
op_->set_attr_num_split(num_split_);
op_->set_input_split_dim(*split_dim);
op_->set_input_size_splits(*size_splits);
op_->create_dynamic_output_y(num_split_);
return RET_OK;
}
ge::Operator *mindspore::kernel::SplitNPUKernel::GetNPUOp() { return this->op_; }
SplitNPUKernel::~SplitNPUKernel() {
if (op_ != nullptr) {
delete op_;
op_ = nullptr;
}
}
REG_KERNEL(kNPU, kNumberTypeFloat32, PrimitiveType_Split, NPUKernelCreator<SplitNPUKernel>)
} // namespace mindspore::kernel

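SplitV cuts the input along split_dim into num_split pieces whose lengths come from the size_splits constant built above. A 1-D sketch of those semantics, assuming the sizes are non-negative and sum to the input length:

#include <vector>

// Cut x into consecutive pieces of the given sizes.
std::vector<std::vector<float>> SplitV1D(const std::vector<float> &x,
                                         const std::vector<int> &size_splits) {
  std::vector<std::vector<float>> pieces;
  auto it = x.begin();
  for (int s : size_splits) {
    pieces.emplace_back(it, it + s);
    it += s;
  }
  return pieces;
}
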
@@ -0,0 +1,50 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_SPLIT_NPU_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_SPLIT_NPU_H_
#include <vector>
#include "src/ops/split.h"
#include "src/runtime/kernel/npu/npu_kernel.h"
#include "include/graph/op/all_ops.h"
namespace mindspore::kernel {
class SplitNPUKernel : public NPUKernel {
public:
SplitNPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: NPUKernel(parameter, inputs, outputs, ctx, primitive) {
auto split = reinterpret_cast<const mindspore::lite::Split *>(primitive);
num_split_ = split->GetNumberSplit();
size_splits_ = split->GetSizeSplit();
split_dim_ = split->GetSplitDim();
}
~SplitNPUKernel() override;
int IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
OpParameter *opParameter) override;
int SetNPUInputs(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
const std::vector<ge::Operator *> &npu_inputs) override;
ge::Operator *GetNPUOp() override;
private:
hiai::op::SplitV *op_ = nullptr;
int num_split_;
std::vector<int> size_splits_;
int split_dim_;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_SPLIT_NPU_H_

@@ -0,0 +1,62 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/npu/transpose_npu.h"
#include "src/kernel_registry.h"
#include "src/runtime/agent/npu/npu_converter_utils.h"
using mindspore::kernel::KERNEL_ARCH::kNPU;
using mindspore::lite::KernelRegistrar;
using mindspore::schema::PrimitiveType_Nchw2Nhwc;
using mindspore::schema::PrimitiveType_Nhwc2Nchw;
using mindspore::schema::PrimitiveType_Transpose;
namespace mindspore::kernel {
int TransposeNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
OpParameter *opParameter) {
if (conjugate_) {
MS_LOG(ERROR) << "Unsupported conjugate transpose.";
return RET_ERROR;
}
return RET_OK;
}
int TransposeNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs,
const std::vector<ge::Operator *> &npu_inputs) {
op_ = new (std::nothrow) hiai::op::Permute(name_);
if (op_ == nullptr) {
MS_LOG(ERROR) << name_ << " op is nullptr";
return RET_ERROR;
}
op_->set_input_x(*npu_inputs[0]);
op_->set_attr_order(perm_);
return RET_OK;
}
ge::Operator *mindspore::kernel::TransposeNPUKernel::GetNPUOp() { return this->op_; }
TransposeNPUKernel::~TransposeNPUKernel() {
if (op_ != nullptr) {
delete op_;
op_ = nullptr;
}
}
REG_KERNEL(kNPU, kNumberTypeFloat32, PrimitiveType_Transpose, NPUKernelCreator<TransposeNPUKernel>)
// REG_KERNEL(kNPU, kNumberTypeFloat32, PrimitiveType_Nhwc2Nchw, NPUKernelCreator<TransposeNPUKernel>)
// REG_KERNEL(kNPU, kNumberTypeFloat32, PrimitiveType_Nchw2Nhwc, NPUKernelCreator<TransposeNPUKernel>)
} // namespace mindspore::kernel

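hiai::op::Permute reorders axes according to the order attribute filled from perm_. A shape-level sketch of applying such a permutation:

#include <cstdint>
#include <vector>

// Output axis i takes input axis perm[i].
std::vector<int64_t> PermuteShape(const std::vector<int64_t> &shape,
                                  const std::vector<int64_t> &perm) {
  std::vector<int64_t> out(perm.size());
  for (size_t i = 0; i < perm.size(); ++i) {
    out[i] = shape[perm[i]];
  }
  return out;
}

// e.g. PermuteShape({1, 224, 224, 3}, {0, 3, 1, 2}) yields {1, 3, 224, 224}.
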
@@ -0,0 +1,56 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_TRANSPOSE_NPU_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_TRANSPOSE_NPU_H_
#include <vector>
#include "nnacl/transpose.h"
#include "src/runtime/kernel/npu/npu_kernel.h"
#include "include/graph/op/all_ops.h"
namespace mindspore::kernel {
class TransposeNPUKernel : public NPUKernel {
public:
TransposeNPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: NPUKernel(parameter, inputs, outputs, ctx, primitive) {
if (primitive->Type() == schema::PrimitiveType_Transpose) {
auto transpose_parameter = reinterpret_cast<TransposeParameter *>(parameter);
conjugate_ = transpose_parameter->conjugate_;
for (int i = 0; i < transpose_parameter->num_axes_; i++) {
perm_.push_back(transpose_parameter->perm_[i]);
}
} else if (primitive->Type() == schema::PrimitiveType_Nchw2Nhwc) {
perm_ = {0, 2, 3, 1};
} else if (primitive->Type() == schema::PrimitiveType_Nhwc2Nchw) {
perm_ = {0, 3, 1, 2};
}
}
~TransposeNPUKernel() override;
int IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
OpParameter *opParameter) override;
int SetNPUInputs(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
const std::vector<ge::Operator *> &npu_inputs) override;
ge::Operator *GetNPUOp() override;
private:
hiai::op::Permute *op_ = nullptr;
std::vector<int64_t> perm_;
bool conjugate_ = false;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_TRANSPOSE_NPU_H_

@ -0,0 +1,66 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/npu/unsqueeze_npu.h"
#include <memory>
#include "src/kernel_registry.h"
#include "src/runtime/agent/npu/npu_converter_utils.h"
using mindspore::kernel::KERNEL_ARCH::kNPU;
using mindspore::lite::KernelRegistrar;
using mindspore::schema::PrimitiveType_Unsqueeze;
namespace mindspore::kernel {
int UnsqueezeNPUKernel::IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
OpParameter *opParameter) {
if (inputs[0]->shape().size() > 3) {
MS_LOG(WARNING) << "The dimension of output not support bigger than 4.";
return RET_ERROR;
}
return RET_OK;
}
int UnsqueezeNPUKernel::SetNPUInputs(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs,
const std::vector<ge::Operator *> &npu_inputs) {
op_ = new (std::nothrow) hiai::op::ExpandDims(name_);
if (op_ == nullptr) {
MS_LOG(ERROR) << name_ << " op is nullptr";
return RET_ERROR;
}
int size = axis_.size();
ge::TensorDesc desc(ge::Shape({size}), ge::FORMAT_NCHW, ge::DT_INT32);
ge::TensorPtr tensor = std::make_shared<hiai::Tensor>(desc);
tensor->SetData(reinterpret_cast<uint8_t *>(axis_.data()), size * sizeof(int));
auto axis = new hiai::op::Const(name_ + "_axis");
axis->set_attr_value(tensor);
op_->set_input_x(*npu_inputs[0]);
op_->set_input_axis(*axis);
return RET_OK;
}
ge::Operator *mindspore::kernel::UnsqueezeNPUKernel::GetNPUOp() { return this->op_; }
UnsqueezeNPUKernel::~UnsqueezeNPUKernel() {
if (op_ != nullptr) {
delete op_;
op_ = nullptr;
}
}
REG_KERNEL(kNPU, kNumberTypeFloat32, PrimitiveType_Unsqueeze, NPUKernelCreator<UnsqueezeNPUKernel>)
} // namespace mindspore::kernel

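Unsqueeze maps onto ExpandDims: a length-1 dimension is inserted at each requested axis. A shape-level sketch for one normalized axis (0 <= axis <= rank):

#include <vector>

// Insert a length-1 dimension at `axis`.
std::vector<int> ExpandDim(std::vector<int> shape, int axis) {
  shape.insert(shape.begin() + axis, 1);
  return shape;
}

// e.g. ExpandDim({3, 4}, 0) yields {1, 3, 4}.
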
@@ -0,0 +1,46 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_UNSQUEEZE_NPU_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_UNSQUEEZE_NPU_H_
#include <vector>
#include "src/ops/unsqueeze.h"
#include "src/runtime/kernel/npu/npu_kernel.h"
#include "include/graph/op/all_ops.h"
namespace mindspore::kernel {
class UnsqueezeNPUKernel : public NPUKernel {
public:
UnsqueezeNPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: NPUKernel(parameter, inputs, outputs, ctx, primitive) {
auto unsqueeze = reinterpret_cast<const mindspore::lite::Unsqueeze *>(primitive);
axis_ = unsqueeze->GetAxis();
}
~UnsqueezeNPUKernel() override;
int IsSupport(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
OpParameter *opParameter) override;
int SetNPUInputs(const std::vector<lite::Tensor *> &inputs, const std::vector<lite::Tensor *> &outputs,
const std::vector<ge::Operator *> &npu_inputs) override;
ge::Operator *GetNPUOp() override;
private:
hiai::op::ExpandDims *op_ = nullptr;
vector<int> axis_;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_NPU_UNSQUEEZE_NPU_H_