!6080 [MS][LITE][Develop]support nnacl internal kernels

Merge pull request !6080 from chenjianping/lite_dev2
pull/6080/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 9e104137ac

@ -8,8 +8,10 @@ file(GLOB_RECURSE C_SRC ${CMAKE_CURRENT_SOURCE_DIR}/*.cc)
# Kernel sources for this target: the nnacl C kernels (top level, fp32,
# fp32_grad, int8, quantization) plus the fp32 kernel wrappers under src/kernel.
# NOTE(review): src/kernel/fp32_grad/*.cc is not listed here; presumably it is
# picked up by the recursive C_SRC glob over *.cc - confirm.
file(GLOB KERNEL_SRC
${CMAKE_CURRENT_SOURCE_DIR}/../nnacl/*.c
${CMAKE_CURRENT_SOURCE_DIR}/../nnacl/fp32/*.c
${CMAKE_CURRENT_SOURCE_DIR}/../nnacl/fp32_grad/*.c
${CMAKE_CURRENT_SOURCE_DIR}/../nnacl/int8/*.c
${CMAKE_CURRENT_SOURCE_DIR}/../nnacl/quantization/*.c
${CMAKE_CURRENT_SOURCE_DIR}/src/kernel/fp32/*.cc
)
# opt_op_handler.c matches the nnacl/*.c pattern above; drop it from the list.
list(REMOVE_ITEM KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/../nnacl/opt_op_handler.c)

@ -84,7 +84,7 @@ typedef struct LiteSession {
/// \param[in] inputs Define the new inputs shape.
///
/// \return STATUS as an error code of resize inputs, STATUS is defined in errorcode.h.
int Resize(const TensorPtrVector &inputs);
int Resize(const TensorPtrVector &inputs, Int32VectorVector dims);
} LiteSession;
#endif // MINDSPORE_LITE_INCLUDE_LITE_SESSION_H

@ -27,5 +27,6 @@ using String = std::string;
// Shorthand aliases for standard-library vector types.
using StringVector = std::vector<std::string>;
using ShapeVector = std::vector<int>;
using NodePtrVector = std::vector<struct Node *>;
using Int32Vector = std::vector<int32_t>;
// A vector of Int32Vector; LiteSession::Resize takes one as its `dims` argument.
using Int32VectorVector = std::vector<Int32Vector>;
#endif // MINDSPORE_LITE_INCLUDE_LITE_UTILS_H_

@ -27,6 +27,183 @@ enum NodeType {
NodeType_MAX = NodeType_CNode
};
/// \brief Operator kinds.
///
/// No enumerator has an explicit value, so each one's numeric value is its
/// position in this list (Concat == 0). Reordering or inserting in the middle
/// therefore renumbers every later enumerator.
enum KernelType {
Concat,
SoftMax,
Activation,
Conv2D,
FusedBatchNorm,
BatchNorm,
BiasAdd,
Pooling,
ROIPooling,
DepthwiseConv2D,
DeDepthwiseConv2D,
Resize,
DetectionPostProcess,
FullConnection,
Mean,
DeConv2D,
Scale,
Reshape,
Eltwise,
NetOutput,
Add,
Sub,
MatMul,
StridedSlice,
Power,
Slice,
Stack,
Mul,
RealDiv,
Pad,
Maximum,
Minimum,
PReLU,
LeakyReLU,
ArgMax,
ArgMin,
Exp,
Crop,
Range,
Rsqrt,
ExpandDims,
Tile,
Cast,
Shape,
Nchw2Nhwc,
Nhwc2Nchw,
QuantDTypeCast,
Split,
Permute,
FakeQuantWithMinMaxVars,
Equal,
Less,
Greater,
NotEqual,
LessEqual,
GreaterEqual,
Min,
Floor,
Abs,
Neg,
Cos,
Sin,
Sqrt,
Square,
Constant,
Log,
Tan,
Atan,
Asin,
Clip,
Transpose,
Squeeze,
Unsqueeze,
Upsample,
Dropout,
Broadcast,
BroadcastTo,
Lrn,
ZerosLike,
TopK,
SpaceToDepth,
SpaceToBatch,
SparseToDense,
ReverseSequence,
Rank,
Gather,
GatherNd,
Fill,
Elu,
DepthToSpace,
BatchToSpace,
AddN,
Ceil,
EmbeddingLookup,
EmbeddingLookupSparse,
FloorDiv,
FloorMod,
L2Norm,
LocalResponseNormalization,
MatrixDiag,
Reduce,
Reverse,
Round,
Select,
Scatter,
ScatterND,
ConstantOfShape,
Unique,
Unstack,
LogicalAnd,
LogicalOr,
LogicalXor,
LogicalNot,
OnnxInt8Quantize,
OnnxInt8Dequantize,
FakeQuantWithMinMax,
FakeQuantWithMinMaxPerChannel,
BatchNormFold,
MulFold,
AddFold,
SquaredDifference,
Flatten,
FlattenGrad,
TupleGetItem,
Div,
Where,
OneHot,
Lstm,
Conv2DGradFilter,
Conv2DGradInput,
PoolingGrad,
BNGrad,
BNGradInput,
ApplyMomentum,
BiasGrad,
SoftmaxCrossEntropy,
AddGrad,
SubGrad,
MulGrad,
DivGrad,
PowerGrad,
ActivationGrad,
PriorBox,
SpaceToBatchND,
Depend,
Return,
MakeTuple,
ToFormat,
Proposal,
Custom,
BlackBox,
NegGrad,
LogGrad,
BatchToSpaceND,
};
/// \brief Activation function selector with explicitly pinned numeric values.
///
/// The activation kernels compare an activation parameter's type_ field
/// against these enumerators to pick the function to run.
enum ActivationType {
NO_ACTIVATION = 0,
RELU = 1,
SIGMOID = 2,
RELU6 = 3,
ELU = 4,
LEAKY_RELU = 5,
ABS = 6,
RELU1 = 7,
SOFTSIGN = 8,
SOFTPLUS = 9,
TANH = 10,
SELU = 11,
HSWISH = 12,
HSIGMOID = 13,
THRESHOLDRELU = 14,
LINEAR = 15,
// Spelled "UNKNOW" (sic); the identifier is part of the public enum.
UNKNOW = 16
};
typedef struct Node {
String name_;
NodeType node_type_;

@ -0,0 +1,49 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "internal/src/kernel/fp32/activation.h"
#include "internal/include/errorcode.h"
#include "internal/include/ms_tensor.h"
#include "nnacl/fp32/activation.h"
#include "utils/log_adapter.h"
#include "nnacl/errorcode.h"
/// \brief Run an fp32 activation over in_tensors[0] and write the result to out_tensors[0].
///
/// \param[in] in_tensors Input tensors; only in_tensors[0] is read.
/// \param[out] out_tensors Output tensors; only out_tensors[0] is written.
/// \param[in] node Node whose primitive_ is interpreted as an ActivationParameter.
/// \param[in] allocator Not used by this kernel.
///
/// \return RET_OK on success; RET_PARAM_INVALID when a tensor, its data or the node parameter is
///         missing, or when the activation type is not handled here; RET_ERROR when the nnacl
///         routine returns something other than NNACL_OK.
int DoActivation(TensorPtrVector in_tensors, TensorPtrVector out_tensors, Node *node,
                 mindspore::lite::Allocator *allocator) {
  // Validate everything that is dereferenced below instead of crashing on a malformed graph.
  if (in_tensors.empty() || out_tensors.empty() || in_tensors[0] == nullptr || out_tensors[0] == nullptr) {
    MS_LOG(ERROR) << "input or output tensor is missing!";
    return RET_PARAM_INVALID;
  }
  if (in_tensors[0]->data_ == nullptr || out_tensors[0]->data_ == nullptr) {
    MS_LOG(ERROR) << "input or output data is NULL!";
    return RET_PARAM_INVALID;
  }
  if (node == nullptr || node->primitive_ == nullptr) {
    MS_LOG(ERROR) << "node's primitive is NULL!";
    return RET_PARAM_INVALID;
  }
  ActivationParameter *param = reinterpret_cast<ActivationParameter *>(node->primitive_);
  int ret = RET_OK;
  size_t length = in_tensors[0]->ElementsNum();
  float *input_addr = reinterpret_cast<float *>(in_tensors[0]->data_);
  float *output_addr = reinterpret_cast<float *>(out_tensors[0]->data_);
  if (param->type_ == ActivationType::RELU) {
    ret = Fp32Relu(input_addr, length, output_addr);
  } else if (param->type_ == ActivationType::SIGMOID) {
    ret = Sigmoid(input_addr, length, output_addr);
  } else if (param->type_ == ActivationType::RELU6) {
    ret = Fp32Relu6(input_addr, length, output_addr);
  } else if (param->type_ == ActivationType::LEAKY_RELU) {
    float alpha = param->alpha_;
    ret = LRelu(input_addr, length, output_addr, alpha);
  } else {
    MS_LOG(ERROR) << "Unsupport activation type " << param->type_;
    return RET_PARAM_INVALID;
  }
  if (ret != NNACL_OK) {
    MS_LOG(ERROR) << "do activation(" << param->type_ << ") fail!ret: " << ret;
    return RET_ERROR;
  }
  return RET_OK;
}

@ -0,0 +1,26 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_ACTIVATION_H_
#define MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_ACTIVATION_H_
#include "internal/include/model.h"
#include "src/runtime/allocator.h"
/// \brief Apply the fp32 activation selected by the node's ActivationParameter to in_tensors[0],
///        writing the result to out_tensors[0].
///
/// \param[in] in_tensors Input tensors.
/// \param[out] out_tensors Output tensors.
/// \param[in] node Node carrying the activation parameter in primitive_.
/// \param[in] allocator Allocator handed to every kernel entry point.
///
/// \return RET_OK on success, otherwise an error code from errorcode.h.
int DoActivation(TensorPtrVector in_tensors, TensorPtrVector out_tensors, Node *node,
mindspore::lite::Allocator *allocator);
#endif // MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_ACTIVATION_H_

@ -0,0 +1,41 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "internal/src/kernel/fp32/arithmetic_self.h"
#include "internal/include/errorcode.h"
#include "internal/include/ms_tensor.h"
#include "utils/log_adapter.h"
#include "nnacl/fp32/arithmetic_self.h"
/// \brief Run an element-wise unary fp32 op (Log or Neg) from in_tensors[0] into out_tensors[0].
///
/// \param[in] in_tensors Input tensors; only in_tensors[0] is read.
/// \param[out] out_tensors Output tensors; only out_tensors[0] is written.
/// \param[in] node Node whose primitive_->type_ selects the operation.
/// \param[in] allocator Not used by this kernel.
///
/// \return RET_OK on success; RET_PARAM_INVALID when a tensor, its data or the node parameter is
///         missing, or when the kernel type is neither Log nor Neg; RET_ERROR when the nnacl
///         routine returns something other than NNACL_OK.
int DoArithmeticSelf(TensorPtrVector in_tensors, TensorPtrVector out_tensors, Node *node,
                     mindspore::lite::Allocator *allocator) {
  // Validate everything that is dereferenced below instead of crashing on a malformed graph.
  if (in_tensors.empty() || out_tensors.empty() || in_tensors[0] == nullptr || out_tensors[0] == nullptr) {
    MS_LOG(ERROR) << "input or output tensor is missing!";
    return RET_PARAM_INVALID;
  }
  if (in_tensors[0]->data_ == nullptr || out_tensors[0]->data_ == nullptr) {
    MS_LOG(ERROR) << "input or output data is NULL!";
    return RET_PARAM_INVALID;
  }
  if (node == nullptr || node->primitive_ == nullptr) {
    MS_LOG(ERROR) << "node's primitive is NULL!";
    return RET_PARAM_INVALID;
  }
  size_t data_size = in_tensors[0]->ElementsNum();
  OpParameter *param = node->primitive_;
  float *input_addr = reinterpret_cast<float *>(in_tensors[0]->data_);
  float *output_addr = reinterpret_cast<float *>(out_tensors[0]->data_);
  int ret = NNACL_OK;
  if (param->type_ == KernelType::Log) {
    ret = ElementLog(input_addr, output_addr, data_size);
  } else if (param->type_ == KernelType::Neg) {
    ret = ElementNegative(input_addr, output_addr, data_size);
  } else {
    MS_LOG(ERROR) << "Unsupport kernel type: " << param->type_;
    return RET_PARAM_INVALID;
  }
  if (ret != NNACL_OK) {
    MS_LOG(ERROR) << "do arithmetic " << param->type_ << " fail!ret: " << ret;
    return RET_ERROR;
  }
  return RET_OK;
}

@ -0,0 +1,26 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_ARITHMETIC_SELF_H_
#define MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_ARITHMETIC_SELF_H_
#include "internal/include/model.h"
#include "src/runtime/allocator.h"
/// \brief Run the element-wise unary fp32 op selected by node->primitive_->type_ (Log or Neg)
///        from in_tensors[0] into out_tensors[0].
///
/// \param[in] in_tensors Input tensors.
/// \param[out] out_tensors Output tensors.
/// \param[in] node Node whose primitive type selects the operation.
/// \param[in] allocator Allocator handed to every kernel entry point.
///
/// \return RET_OK on success, otherwise an error code from errorcode.h.
int DoArithmeticSelf(TensorPtrVector in_tensors, TensorPtrVector out_tensors, Node *node,
mindspore::lite::Allocator *allocator);
#endif // MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_ARITHMETIC_SELF_H_

@ -0,0 +1,145 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "internal/src/kernel/fp32/matmul.h"
#include "nnacl/fp32/matmul.h"
#include "internal/include/errorcode.h"
#include "internal/include/ms_tensor.h"
#include "utils/log_adapter.h"
// Working buffers that DoMatMul allocates for one call.
struct MatMulCPUKernelData {
  float *a_c12_ptr_;  // matrix A after repacking by MatMulInitMatrixA
  float *b_r8_ptr_;   // matrix B after repacking by MatMulInitMatrixB
  float *bias_ptr_;   // bias buffer of col_8_ floats
};
// Repack each batch of matrix A from src_ptr into dst_ptr.
// Source batches are deep_ * row_ floats apart, destination batches deep_ * row_12_ floats apart.
void MatMulInitMatrixA(float *src_ptr, float *dst_ptr, MatMulParameter *params) {
  for (int batch_idx = 0; batch_idx < params->batch; ++batch_idx) {
    float *batch_in = src_ptr + batch_idx * params->deep_ * params->row_;
    float *batch_out = dst_ptr + batch_idx * params->deep_ * params->row_12_;
    if (!params->a_transpose_) {
      RowMajor2Col12Major(batch_in, batch_out, params->row_, params->deep_);
      continue;
    }
    // A is stored transposed, so it is packed with the row-12 routine instead.
    RowMajor2Row12Major(batch_in, batch_out, params->deep_, params->row_);
  }
}
// Repack each batch of matrix B from src_ptr into dst_ptr.
// Source batches are deep_ * col_ floats apart, destination batches deep_ * col_8_ floats apart.
void MatMulInitMatrixB(float *src_ptr, float *dst_ptr, MatMulParameter *params) {
  for (int batch_idx = 0; batch_idx < params->batch; ++batch_idx) {
    float *batch_in = src_ptr + batch_idx * params->deep_ * params->col_;
    float *batch_out = dst_ptr + batch_idx * params->deep_ * params->col_8_;
    if (!params->b_transpose_) {
      RowMajor2Row8Major(batch_in, batch_out, params->deep_, params->col_);
      continue;
    }
    // B is stored transposed, so it is packed with the col-8 routine instead.
    RowMajor2Col8Major(batch_in, batch_out, params->col_, params->deep_);
  }
}
// Return every non-NULL buffer in kernel_data to the allocator, then free the struct itself.
// A NULL kernel_data is a no-op.
void FreeMatMulKernelData(MatMulCPUKernelData *kernel_data, mindspore::lite::Allocator *allocator) {
  if (kernel_data != NULL) {
    // Same release order as the field order: A, B, bias.
    float **buffers[] = {&kernel_data->a_c12_ptr_, &kernel_data->b_r8_ptr_, &kernel_data->bias_ptr_};
    for (float **slot : buffers) {
      if (*slot == NULL) {
        continue;
      }
      allocator->Free(*slot);
      *slot = NULL;
    }
    free(kernel_data);
  }
}
/// \brief Batched fp32 matrix multiply: out_tensors[0] = in_tensors[0] x in_tensors[1] (+ in_tensors[2]).
///
/// Both operands are repacked into temporary buffers obtained from \p allocator; every temporary
/// is released before returning, on success and on failure.
///
/// \param[in] in_tensors Matrix A, matrix B and, when exactly three tensors are given, a bias.
/// \param[out] out_tensors out_tensors[0] receives the product; its shape supplies row_ and col_.
/// \param[in,out] node primitive_ is interpreted as a MatMulParameter; its batch/row/col/deep
///                fields are filled in from the tensor shapes.
/// \param[in] allocator Allocator used for the temporary packed buffers; must not be NULL.
///
/// \return RET_OK on success; RET_PARAM_INVALID when a tensor, its data, the node parameter or
///         the allocator is missing; RET_INPUT_TENSOR_ERROR when a shape is unusable;
///         RET_MEMORY_FAILED when an allocation fails.
int DoMatMul(TensorPtrVector in_tensors, TensorPtrVector out_tensors, Node *node,
             mindspore::lite::Allocator *allocator) {
  if (in_tensors.size() < 2 || out_tensors.empty() || in_tensors[0] == NULL || in_tensors[1] == NULL ||
      out_tensors[0] == NULL) {
    MS_LOG(ERROR) << "input or output tensor is missing!";
    return RET_PARAM_INVALID;
  }
  if (in_tensors[0]->data_ == NULL || in_tensors[1]->data_ == NULL) {
    MS_LOG(ERROR) << "input data is NULL!";
    return RET_PARAM_INVALID;
  }
  if (out_tensors[0]->data_ == NULL) {
    MS_LOG(ERROR) << "output data is NULL!";
    return RET_PARAM_INVALID;
  }
  if (allocator == NULL) {
    MS_LOG(ERROR) << "input allocator is NULL!";
    return RET_PARAM_INVALID;
  }
  if (node == NULL || node->primitive_ == NULL) {
    MS_LOG(ERROR) << "node's primitive is NULL!";
    return RET_PARAM_INVALID;
  }
  std::vector<int> a_shape = in_tensors[0]->shape_;
  std::vector<int> c_shape = out_tensors[0]->shape_;
  // The last two dimensions are indexed below; without this check size() - 2 wraps around.
  if (a_shape.size() < 2 || c_shape.size() < 2) {
    MS_LOG(ERROR) << "matmul needs input and output shapes with at least 2 dimensions";
    return RET_INPUT_TENSOR_ERROR;
  }
  const bool has_bias = in_tensors.size() == 3;
  if (has_bias) {
    if (in_tensors[2] == NULL || in_tensors[2]->data_ == NULL) {
      MS_LOG(ERROR) << "bias data is NULL!";
      return RET_PARAM_INVALID;
    }
    std::vector<int> bias_shape = in_tensors[2]->shape_;
    if (bias_shape.empty() || bias_shape[bias_shape.size() - 1] != c_shape[c_shape.size() - 1]) {
      MS_LOG(ERROR) << "The bias' dimension is not equal with column";
      return RET_INPUT_TENSOR_ERROR;
    }
  }
  int batch = 1;
  for (size_t i = 0; i + 2 < a_shape.size(); ++i) {
    batch *= a_shape[i];
  }
  MatMulParameter *params = reinterpret_cast<MatMulParameter *>(node->primitive_);
  params->batch = batch;
  params->row_ = c_shape[c_shape.size() - 2];
  params->col_ = c_shape[c_shape.size() - 1];
  params->deep_ = params->a_transpose_ ? a_shape[a_shape.size() - 2] : a_shape[a_shape.size() - 1];
  // A negative dimension (for example an unresolved shape) would turn into a huge allocation size.
  if (params->batch < 0 || params->row_ < 0 || params->col_ < 0 || params->deep_ < 0) {
    MS_LOG(ERROR) << "matmul got a negative dimension";
    return RET_INPUT_TENSOR_ERROR;
  }
  params->row_12_ = UP_ROUND(params->row_, C12NUM);
  params->col_8_ = UP_ROUND(params->col_, 8);

  // Compute sizes in size_t so large shapes do not overflow int arithmetic.
  const size_t a_bytes =
    static_cast<size_t>(params->batch) * static_cast<size_t>(params->row_12_) * params->deep_ * sizeof(float);
  const size_t b_bytes =
    static_cast<size_t>(params->batch) * static_cast<size_t>(params->col_8_) * params->deep_ * sizeof(float);
  const size_t bias_bytes = static_cast<size_t>(params->col_8_) * sizeof(float);

  MatMulCPUKernelData *kernel_data = static_cast<MatMulCPUKernelData *>(malloc(sizeof(MatMulCPUKernelData)));
  if (kernel_data == NULL) {
    MS_LOG(ERROR) << "malloc kernel data fail!";
    return RET_MEMORY_FAILED;
  }
  // malloc leaves the fields indeterminate; FreeMatMulKernelData relies on NULL for "not allocated".
  kernel_data->a_c12_ptr_ = NULL;
  kernel_data->b_r8_ptr_ = NULL;
  kernel_data->bias_ptr_ = NULL;

  kernel_data->a_c12_ptr_ = reinterpret_cast<float *>(allocator->Malloc(a_bytes));
  if (kernel_data->a_c12_ptr_ == NULL) {
    FreeMatMulKernelData(kernel_data, allocator);
    return RET_MEMORY_FAILED;
  }
  // Zero the whole buffer (all batches), so the padding up to row_12_ is defined for every batch.
  memset(kernel_data->a_c12_ptr_, 0, a_bytes);

  kernel_data->b_r8_ptr_ = reinterpret_cast<float *>(allocator->Malloc(b_bytes));
  if (kernel_data->b_r8_ptr_ == NULL) {
    FreeMatMulKernelData(kernel_data, allocator);
    return RET_MEMORY_FAILED;
  }
  memset(kernel_data->b_r8_ptr_, 0, b_bytes);

  kernel_data->bias_ptr_ = reinterpret_cast<float *>(allocator->Malloc(bias_bytes));
  if (kernel_data->bias_ptr_ == NULL) {
    FreeMatMulKernelData(kernel_data, allocator);
    return RET_MEMORY_FAILED;
  }
  // Without a bias tensor the kernel still receives a bias buffer, filled with zeros.
  memset(kernel_data->bias_ptr_, 0, bias_bytes);
  if (has_bias) {
    memcpy(kernel_data->bias_ptr_, in_tensors[2]->data_, params->col_ * sizeof(float));
  }

  MatMulInitMatrixA(reinterpret_cast<float *>(in_tensors[0]->data_), kernel_data->a_c12_ptr_, params);
  MatMulInitMatrixB(reinterpret_cast<float *>(in_tensors[1]->data_), kernel_data->b_r8_ptr_, params);

  float *c_src = reinterpret_cast<float *>(out_tensors[0]->data_);
  for (int i = 0; i < params->batch; ++i) {
    float *a_ptr = kernel_data->a_c12_ptr_ + i * params->row_12_ * params->deep_;
    float *b_ptr = kernel_data->b_r8_ptr_ + i * params->deep_ * params->col_8_;
    float *c_ptr = c_src + i * params->row_ * params->col_;
    MatMulOpt(a_ptr, b_ptr, c_ptr, kernel_data->bias_ptr_, ActType_No, params->deep_, params->row_, params->col_,
              params->col_, OutType_Nhwc);
  }
  // The packed buffers are only needed for this call; release them instead of leaking them.
  FreeMatMulKernelData(kernel_data, allocator);
  return RET_OK;
}

@ -0,0 +1,26 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_MATMUL_H_
#define MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_MATMUL_H_
#include "internal/include/model.h"
#include "src/runtime/allocator.h"
/// \brief Batched fp32 matrix multiply of in_tensors[0] by in_tensors[1], with an optional bias
///        in in_tensors[2], written to out_tensors[0].
///
/// \param[in] in_tensors Matrix A, matrix B and optionally a bias tensor.
/// \param[out] out_tensors Output tensors; out_tensors[0] receives the product.
/// \param[in,out] node Node whose primitive_ is a MatMulParameter; its shape fields are updated.
/// \param[in] allocator Allocator used for temporary buffers; must not be NULL.
///
/// \return RET_OK on success, otherwise an error code from errorcode.h.
int DoMatMul(TensorPtrVector in_tensors, TensorPtrVector out_tensors, Node *node,
mindspore::lite::Allocator *allocator);
#endif // MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_MATMUL_H_

@ -0,0 +1,50 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "internal/src/kernel/fp32_grad/activation_grad.h"
#include "internal/include/errorcode.h"
#include "internal/include/ms_tensor.h"
#include "nnacl/fp32_grad/activation_grad.h"
#include "utils/log_adapter.h"
#include "nnacl/errorcode.h"
/// \brief Run the fp32 gradient of an activation.
///
/// in_tensors[0] is forwarded to the nnacl routine as dy, in_tensors[1] as x, and out_tensors[0]
/// receives dx. The element count is taken from in_tensors[0].
///
/// \param[in] in_tensors Input tensors; the first two are read.
/// \param[out] out_tensors Output tensors; only out_tensors[0] is written.
/// \param[in] node Node whose primitive_ is interpreted as an ActivationGradParameter.
/// \param[in] allocator Not used by this kernel.
///
/// \return RET_OK on success; RET_PARAM_INVALID when a tensor, its data or the node parameter is
///         missing, or when the activation type is not handled here; RET_ERROR when the nnacl
///         routine returns something other than NNACL_OK.
int DoActivationGrad(TensorPtrVector in_tensors, TensorPtrVector out_tensors, Node *node,
                     mindspore::lite::Allocator *allocator) {
  // Both dy and x are required; check before indexing in_tensors[1].
  if (in_tensors.size() < 2 || out_tensors.empty() || in_tensors[0] == nullptr || in_tensors[1] == nullptr ||
      out_tensors[0] == nullptr) {
    MS_LOG(ERROR) << "input or output tensor is missing!";
    return RET_PARAM_INVALID;
  }
  if (in_tensors[0]->data_ == nullptr || in_tensors[1]->data_ == nullptr || out_tensors[0]->data_ == nullptr) {
    MS_LOG(ERROR) << "input or output data is NULL!";
    return RET_PARAM_INVALID;
  }
  if (node == nullptr || node->primitive_ == nullptr) {
    MS_LOG(ERROR) << "node's primitive is NULL!";
    return RET_PARAM_INVALID;
  }
  ActivationGradParameter *param = reinterpret_cast<ActivationGradParameter *>(node->primitive_);
  int ret = RET_OK;
  size_t length = in_tensors[0]->ElementsNum();
  float *dy_data = reinterpret_cast<float *>(in_tensors[0]->data_);
  float *x_data = reinterpret_cast<float *>(in_tensors[1]->data_);
  float *dx_data = reinterpret_cast<float *>(out_tensors[0]->data_);
  if (param->type_ == ActivationType::RELU) {
    ret = ReluGrad(dy_data, x_data, length, dx_data);
  } else if (param->type_ == ActivationType::SIGMOID) {
    ret = SigmoidGrad(dy_data, x_data, length, dx_data);
  } else if (param->type_ == ActivationType::RELU6) {
    ret = Relu6Grad(dy_data, x_data, length, dx_data);
  } else if (param->type_ == ActivationType::LEAKY_RELU) {
    float alpha = param->alpha_;
    ret = LReluGrad(dy_data, x_data, length, dx_data, alpha);
  } else {
    MS_LOG(ERROR) << "Unsupport activation type " << param->type_;
    return RET_PARAM_INVALID;
  }
  if (ret != NNACL_OK) {
    MS_LOG(ERROR) << "do activation(" << param->type_ << ") fail!ret: " << ret;
    return RET_ERROR;
  }
  return RET_OK;
}

@ -0,0 +1,26 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_GRAD_ACTIVATION_GRAD_H_
#define MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_GRAD_ACTIVATION_GRAD_H_
#include "internal/include/model.h"
#include "src/runtime/allocator.h"
/// \brief Run the fp32 gradient of the activation selected by the node's ActivationGradParameter.
///        in_tensors[0] and in_tensors[1] are read; out_tensors[0] is written.
///
/// \param[in] in_tensors Input tensors.
/// \param[out] out_tensors Output tensors.
/// \param[in] node Node carrying the activation-grad parameter in primitive_.
/// \param[in] allocator Allocator handed to every kernel entry point.
///
/// \return RET_OK on success, otherwise an error code from errorcode.h.
int DoActivationGrad(TensorPtrVector in_tensors, TensorPtrVector out_tensors, Node *node,
mindspore::lite::Allocator *allocator);
#endif // MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_GRAD_ACTIVATION_GRAD_H_

@ -0,0 +1,45 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "internal/src/kernel/fp32_grad/arithmetic_self_grad.h"
#include "internal/include/errorcode.h"
#include "internal/include/ms_tensor.h"
#include "utils/log_adapter.h"
#include "nnacl/fp32/arithmetic_self.h"
#include "nnacl/fp32/arithmetic.h"
/// \brief Run the fp32 gradient of an element-wise unary op (LogGrad or NegGrad).
///
/// LogGrad calls ElementDiv(dy, x, dx) with dy = in_tensors[0] and x = in_tensors[1];
/// NegGrad calls ElementNegative(dy, dx) and does not touch in_tensors[1].
///
/// \param[in] in_tensors Input tensors; in_tensors[1] is only required for LogGrad.
/// \param[out] out_tensors Output tensors; only out_tensors[0] is written.
/// \param[in] node Node whose primitive_->type_ selects the operation.
/// \param[in] allocator Not used by this kernel.
///
/// \return RET_OK on success; RET_PARAM_INVALID when a tensor, its data or the node parameter is
///         missing, or when the kernel type is neither LogGrad nor NegGrad; RET_ERROR when the
///         nnacl routine returns something other than NNACL_OK.
int DoArithmeticGradSelf(TensorPtrVector in_tensors, TensorPtrVector out_tensors, Node *node,
                         mindspore::lite::Allocator *allocator) {
  if (in_tensors.empty() || out_tensors.empty() || in_tensors[0] == nullptr || out_tensors[0] == nullptr) {
    MS_LOG(ERROR) << "input or output tensor is missing!";
    return RET_PARAM_INVALID;
  }
  if (in_tensors[0]->data_ == nullptr || out_tensors[0]->data_ == nullptr) {
    MS_LOG(ERROR) << "input or output data is NULL!";
    return RET_PARAM_INVALID;
  }
  if (node == nullptr || node->primitive_ == nullptr) {
    MS_LOG(ERROR) << "node's primitive is NULL!";
    return RET_PARAM_INVALID;
  }
  size_t data_size = in_tensors[0]->ElementsNum();
  OpParameter *param = node->primitive_;
  float *dy_data = reinterpret_cast<float *>(in_tensors[0]->data_);
  float *dx_data = reinterpret_cast<float *>(out_tensors[0]->data_);
  int ret = NNACL_OK;
  if (param->type_ == KernelType::LogGrad) {
    // Only this branch needs x, so the second input is validated and read here rather than
    // unconditionally.
    if (in_tensors.size() < 2 || in_tensors[1] == nullptr || in_tensors[1]->data_ == nullptr) {
      MS_LOG(ERROR) << "input data is NULL!";
      return RET_PARAM_INVALID;
    }
    float *x_data = reinterpret_cast<float *>(in_tensors[1]->data_);
    ret = ElementDiv(dy_data, x_data, dx_data, data_size);
  } else if (param->type_ == KernelType::NegGrad) {
    ret = ElementNegative(dy_data, dx_data, data_size);
  } else {
    MS_LOG(ERROR) << "Unsupport kernel type: " << param->type_;
    return RET_PARAM_INVALID;
  }
  if (ret != NNACL_OK) {
    MS_LOG(ERROR) << "do arithmetic " << param->type_ << " fail!ret: " << ret;
    return RET_ERROR;
  }
  return RET_OK;
}

@ -0,0 +1,26 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_GRAD_ARITHMETIC_SELF_GRAD_H_
#define MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_GRAD_ARITHMETIC_SELF_GRAD_H_
#include "internal/include/model.h"
#include "src/runtime/allocator.h"
/// \brief Run the fp32 gradient of the element-wise unary op selected by node->primitive_->type_
///        (LogGrad or NegGrad), writing the result to out_tensors[0].
///
/// \param[in] in_tensors Input tensors.
/// \param[out] out_tensors Output tensors.
/// \param[in] node Node whose primitive type selects the operation.
/// \param[in] allocator Allocator handed to every kernel entry point.
///
/// \return RET_OK on success, otherwise an error code from errorcode.h.
int DoArithmeticGradSelf(TensorPtrVector in_tensors, TensorPtrVector out_tensors, Node *node,
mindspore::lite::Allocator *allocator);
#endif // MINDSPORE_LITE_INTERNAL_SRC_KERNEL_FP32_GRAD_ARITHMETIC_SELF_GRAD_H_

@ -17,6 +17,13 @@
#include "internal/include/model.h"
#include "internal/include/ms_tensor.h"
#include "src/runtime/allocator.h"
#include "internal/include/errorcode.h"
#include "utils/log_adapter.h"
#include "internal/src/kernel/fp32/activation.h"
#include "internal/src/kernel/fp32/arithmetic_self.h"
#include "internal/src/kernel/fp32/matmul.h"
#include "internal/src/kernel/fp32_grad/arithmetic_self_grad.h"
#include "internal/src/kernel/fp32_grad/activation_grad.h"
// File-local session state. Both pointers start out null (static storage is zero-initialised;
// the initialisers only make that explicit).
static Context *g_Ctx = nullptr;
static Model *g_Model = nullptr;
@ -58,11 +65,56 @@ TensorPtrVector LiteSession::GetOutputs() const {
int LiteSession::RunGraph() {
// invoke nnacl kernel
return 0;
NodePtrVector nodes = g_Model->nodes_;
size_t nodes_size = nodes.size();
for (size_t i = 0; i < nodes_size; ++i) {
auto node = nodes[i];
if (node->primitive_ == nullptr) {
MS_LOG(ERROR) << "node's primitive is NULL!";
return RET_ERROR;
}
TensorPtrVector in_tensors;
for (size_t j = 0; j < node->input_indices_.size(); ++j) {
in_tensors.push_back(g_Model->all_tensors_[node->input_indices_[j]]);
}
TensorPtrVector out_tensors;
for (size_t j = 0; j < node->output_indices_.size(); ++j) {
out_tensors.push_back(g_Model->all_tensors_[node->output_indices_[j]]);
}
int type = node->primitive_->type_;
int ret = RET_ERROR;
switch (type) {
case KernelType::MatMul:
ret = DoMatMul(in_tensors, out_tensors, node, &allocator);
break;
case KernelType::Activation:
ret = DoActivation(in_tensors, out_tensors, node, &allocator);
break;
case KernelType::Log:
case KernelType::Neg:
ret = DoArithmeticSelf(in_tensors, out_tensors, node, &allocator);
break;
case KernelType::LogGrad:
case KernelType::NegGrad:
ret = DoArithmeticGradSelf(in_tensors, out_tensors, node, &allocator);
break;
case KernelType::ActivationGrad:
ret = DoActivationGrad(in_tensors, out_tensors, node, &allocator);
break;
default:
MS_LOG(ERROR) << "Unsupport kernel type: " << type;
return RET_PARAM_INVALID;
}
if (ret != RET_OK) {
MS_LOG(ERROR) << "run kernel fail!ret: " << ret;
return ret;
}
}
return RET_OK;
}
// Stub: always returns an empty list.
StringVector LiteSession::GetOutputTensorNames() const {
  StringVector names;
  return names;
}
// Stub: ignores tensor_name and always returns a null pointer.
MSTensor *LiteSession::GetOutputByTensorName(const String &tensor_name) const {
  (void)tensor_name;
  return nullptr;
}
// Stub: ignores inputs and always returns 0.
int LiteSession::Resize(const TensorPtrVector &inputs) {
  (void)inputs;
  return 0;
}
// Stub: ignores inputs and dims and always returns 0.
int LiteSession::Resize(const TensorPtrVector &inputs, Int32VectorVector dims) {
  (void)inputs;
  (void)dims;
  return 0;
}

Loading…
Cancel
Save