diff --git a/mindspore/lite/nnacl/transpose.c b/mindspore/lite/nnacl/fp32/transpose.c similarity index 74% rename from mindspore/lite/nnacl/transpose.c rename to mindspore/lite/nnacl/fp32/transpose.c index aa1f8fdb4d..e70f796aff 100644 --- a/mindspore/lite/nnacl/transpose.c +++ b/mindspore/lite/nnacl/fp32/transpose.c @@ -14,12 +14,10 @@ * limitations under the License. */ -#include "nnacl/transpose.h" -#include -#include "nnacl/errorcode.h" +#include "nnacl/fp32/transpose.h" -void TransposeDim2(const float *in_data, float *out_data, const int *strides, int *out_strides, const int *perm, - const int *output_shape, int h_start, int h_end) { +void TransposeDim2Fp32(const float *in_data, float *out_data, const int *strides, int *out_strides, const int *perm, + const int *output_shape, int h_start, int h_end) { const int stride0 = strides[perm[0]]; const int stride1 = strides[perm[1]]; const int output0 = output_shape[0]; @@ -33,8 +31,8 @@ void TransposeDim2(const float *in_data, float *out_data, const int *strides, in } } -void TransposeDim3(const float *in_data, float *out_data, const int *strides, const int *out_strides, const int *perm, - const int *output_shape, int h_start, int h_end) { +void TransposeDim3Fp32(const float *in_data, float *out_data, const int *strides, const int *out_strides, + const int *perm, const int *output_shape, int h_start, int h_end) { const int stride0 = strides[perm[0]]; const int stride1 = strides[perm[1]]; const int stride2 = strides[perm[2]]; @@ -56,8 +54,8 @@ void TransposeDim3(const float *in_data, float *out_data, const int *strides, co } } -void TransposeDim4(const float *in_data, float *out_data, const int *strides, const int *out_strides, const int *perm, - const int *output_shape, int h_start, int h_end) { +void TransposeDim4Fp32(const float *in_data, float *out_data, const int *strides, const int *out_strides, + const int *perm, const int *output_shape, int h_start, int h_end) { const int stride0 = strides[perm[0]]; const int stride1 = strides[perm[1]]; const int stride2 = strides[perm[2]]; @@ -88,8 +86,8 @@ void TransposeDim4(const float *in_data, float *out_data, const int *strides, co } } -void TransposeDim5(const float *in_data, float *out_data, const int *strides, const int *out_strides, const int *perm, - const int *output_shape, int h_start, int h_end) { +void TransposeDim5Fp32(const float *in_data, float *out_data, const int *strides, const int *out_strides, + const int *perm, const int *output_shape, int h_start, int h_end) { const int stride0 = strides[perm[0]]; const int stride1 = strides[perm[1]]; const int stride2 = strides[perm[2]]; @@ -127,8 +125,9 @@ void TransposeDim5(const float *in_data, float *out_data, const int *strides, co } } -void TransposeDims(const float *in_data, float *out_data, const int *strides, const int *out_strides, const int *perm, - const int *output_shape, int h_start, int h_end, int dims, int *size, int *position) { +void TransposeDimsFp32(const float *in_data, float *out_data, const int *strides, const int *out_strides, + const int *perm, const int *output_shape, int h_start, int h_end, int dims, int *size, + int *position) { *(size + dims - 1) = 1; for (int i = dims - 1; i > 0; --i) { *(size + i - 1) = *(size + i) * output_shape[i]; @@ -149,8 +148,8 @@ void TransposeDims(const float *in_data, float *out_data, const int *strides, co } } -int DoTranspose(const float *in_data, float *out_data, int *input_shape, const int *output_shape, - TransposeParameter *transpose_param, int h_start, int h_end, int *size, int *position) { +int DoTransposeFp32(const float *in_data, float *out_data, int *input_shape, const int *output_shape, + TransposeParameter *transpose_param, int h_start, int h_end, int *size, int *position) { if (in_data == NULL || out_data == NULL) { return NNACL_ERR; } @@ -178,16 +177,16 @@ int DoTranspose(const float *in_data, float *out_data, int *input_shape, const i return NNACL_OK; } if (num_axes == 2) { - TransposeDim2(in_data, out_data, strides, out_strides, perm, output_shape, h_start, h_end); + TransposeDim2Fp32(in_data, out_data, strides, out_strides, perm, output_shape, h_start, h_end); } else if (num_axes == 3) { - TransposeDim3(in_data, out_data, strides, out_strides, perm, output_shape, h_start, h_end); + TransposeDim3Fp32(in_data, out_data, strides, out_strides, perm, output_shape, h_start, h_end); } else if (num_axes == 4) { - TransposeDim4(in_data, out_data, strides, out_strides, perm, output_shape, h_start, h_end); + TransposeDim4Fp32(in_data, out_data, strides, out_strides, perm, output_shape, h_start, h_end); } else if (num_axes == 5) { - TransposeDim5(in_data, out_data, strides, out_strides, perm, output_shape, h_start, h_end); + TransposeDim5Fp32(in_data, out_data, strides, out_strides, perm, output_shape, h_start, h_end); } else { - TransposeDims(in_data, out_data, strides, out_strides, perm, output_shape, h_start, h_end, num_axes, size, - position); + TransposeDimsFp32(in_data, out_data, strides, out_strides, perm, output_shape, h_start, h_end, num_axes, size, + position); } return NNACL_OK; } diff --git a/mindspore/lite/nnacl/fp32/transpose.h b/mindspore/lite/nnacl/fp32/transpose.h new file mode 100644 index 0000000000..0ce061740a --- /dev/null +++ b/mindspore/lite/nnacl/fp32/transpose.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_NNACL_FP32_TRANSPOSE_H_ +#define MINDSPORE_LITE_NNACL_FP32_TRANSPOSE_H_ + +#include +#include "nnacl/transpose.h" +#include "nnacl/errorcode.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int DoTransposeFp32(const float *in_data, float *out_data, int *input_shape, const int *output_shape, + TransposeParameter *transpose_param, int h_start, int h_end, int *size, int *position); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_LITE_NNACL_FP32_TRANSPOSE_H_ diff --git a/mindspore/lite/nnacl/int8/l2_norm_int8.c b/mindspore/lite/nnacl/int8/l2_norm_int8.c index f290ff2225..4f86e657bb 100644 --- a/mindspore/lite/nnacl/int8/l2_norm_int8.c +++ b/mindspore/lite/nnacl/int8/l2_norm_int8.c @@ -13,45 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include #include "nnacl/int8/l2_norm_int8.h" +#include #include "nnacl/quantization/fixed_point.h" #include "nnacl/errorcode.h" -void GetSqrtQuantMultiplierExp(int32_t input, int reverse_shift, int32_t *multiplier, int32_t *shift) { - if (input <= 1) { - *multiplier = INT_MAX; - *shift = 0; - } - *shift = 11; - while (input >= (1 << 29)) { - input /= 4; - ++*shift; - } - int max_left_shift_bits = CountLeadingSignBits(input); - int left_shift_bit_pairs = max_left_shift_bits / 2 - 1; - *shift -= left_shift_bit_pairs; - input <<= 2 * left_shift_bit_pairs; - int32_t fixedpoint_f3_input = input >> 1; // sign: 1 bit, integer: 3 bit, fractional: 28 bit - int32_t fp_f3_half_input = SaturatingRoundingMultiplyByPOT(fixedpoint_f3_input, -1); - int32_t fp_f3_half_three = (1 << 28) + (1 << 27); - int32_t tmp = (1 << 28); // one - for (int i = 0; i < 5; i++) { - int32_t tmp3 = Rescale(SaturatingRoundingDoublingHighMul(tmp, SaturatingRoundingDoublingHighMul(tmp, tmp)), 9, 3); - tmp = Rescale(SaturatingRoundingDoublingHighMul(fp_f3_half_three, tmp) - - SaturatingRoundingDoublingHighMul(fp_f3_half_input, tmp3), - 6, 3); - } - const int32_t fp_f0_half_sqrt_2 = 1518500250; // sqrt(2) / 2 - tmp = SaturatingRoundingDoublingHighMul(tmp, fp_f0_half_sqrt_2); - *multiplier = tmp; - if (*shift < 0) { - *multiplier <<= -*shift; - *shift = 0; - } - *shift *= reverse_shift; -} - int L2NormalizationInt8(const int8_t *input_data, int8_t *output_data, const L2NormParameter *param, const L2NormQuantArg *quant_param, const int begin, const int end) { const int inner_size = param->shape_[param->shape_num_ - 1]; diff --git a/mindspore/lite/nnacl/int8/layer_norm_int8.c b/mindspore/lite/nnacl/int8/layer_norm_int8.c new file mode 100644 index 0000000000..1d0e7ce0f4 --- /dev/null +++ b/mindspore/lite/nnacl/int8/layer_norm_int8.c @@ -0,0 +1,64 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/int8/layer_norm_int8.h" + +/* + * origin : (x-mean) / sqrt(variance + epsilon) * gamma + beta + * quant : (x-mean) / sqrt(sum(x * x) - mean * mean) * gamma + beta + * + * */ +int LayerNormInt8(const int8_t *src_data, const int8_t *gamma_data, const int32_t *beta_data, int8_t *dst_data, + bool affine, int outer_size, int inner_size, LayerNormQuantArg *quant_) { + if (src_data == NULL || dst_data == NULL) { + return NNACL_NULL_PTR; + } + + if (affine && (gamma_data == NULL || beta_data == NULL)) { + return NNACL_NULL_PTR; + } + + for (int out_index = 0; out_index < outer_size; out_index++) { + const int8_t *src = src_data + out_index * inner_size; + int8_t *dst = dst_data + out_index * inner_size; + int32_t mean = 0; + int32_t square_mean = 0; + for (int in_index = 0; in_index < inner_size; in_index++) { + int32_t tmp_src = src[in_index] - quant_->in_quant_arg_.zp_; + mean += tmp_src; + square_mean += tmp_src * tmp_src; + } + mean = round(mean / inner_size); + square_mean = round(square_mean / inner_size); + + int32_t variance_value = square_mean - mean * mean; + + int32_t multiplier; + int32_t shift; + GetSqrtQuantMultiplierExp(variance_value, -1, &multiplier, &shift); + + for (int in_index = 0; in_index < inner_size; in_index++) { + int32_t in = src[in_index] - quant_->in_quant_arg_.zp_ - mean; + int32_t tmp = RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(in * (1 << 7), multiplier), -shift); + if (affine) { + tmp = tmp * (gamma_data[in_index] - quant_->gamma_quant_arg_.zp_) + beta_data[in_index]; + } + int32_t out = MultiplyByQuantizedMultiplier(tmp, quant_->multiplier_, quant_->shift_left_, quant_->shift_right_); + dst[in_index] = (int8_t)MSMIN(quant_->output_activation_max_, MSMAX(quant_->output_activation_max_, out)); + } + } + return NNACL_OK; +} diff --git a/mindspore/lite/nnacl/int8/layer_norm_int8.h b/mindspore/lite/nnacl/int8/layer_norm_int8.h new file mode 100644 index 0000000000..15cee4313e --- /dev/null +++ b/mindspore/lite/nnacl/int8/layer_norm_int8.h @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_NNACL_INT8_LAYER_NORM_H_ +#define MINDSPORE_LITE_NNACL_INT8_LAYER_NORM_H_ + +#include "nnacl/errorcode.h" +#include "nnacl/layer_norm_parameter.h" +#include "nnacl/quantization/fixed_point.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int LayerNormInt8(const int8_t *src_data, const int8_t *gamma_data, const int32_t *beta_data, int8_t *dst_data, + bool affine, int outer_size, int inner_size, LayerNormQuantArg *quant_); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_LITE_NNACL_INT8_LAYER_NORM_H_ diff --git a/mindspore/lite/nnacl/int8/transpose_int8.c b/mindspore/lite/nnacl/int8/transpose_int8.c new file mode 100644 index 0000000000..10f2f2b20a --- /dev/null +++ b/mindspore/lite/nnacl/int8/transpose_int8.c @@ -0,0 +1,200 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "nnacl/int8/transpose_int8.h" +void TransposeDim2Int8(const int8_t *in_data, int8_t *out_data, const int *strides, int *out_strides, const int *perm, + const int *output_shape, int h_start, int h_end) { + const int stride0 = strides[perm[0]]; + const int stride1 = strides[perm[1]]; + const int output0 = output_shape[0]; + const int output1 = output_shape[1]; + for (int i = 0; i < output0; ++i) { + int out_stride0_i = i * output1; + int stride0_i = i * 1 * stride0; + for (int j = 0; j < output1; ++j) { + out_data[out_stride0_i + j] = in_data[stride0_i + j * stride1]; + } + } + return; +} + +void TransposeDim3Int8(const int8_t *in_data, int8_t *out_data, const int *strides, const int *out_strides, + const int *perm, const int *output_shape, int h_start, int h_end) { + const int stride0 = strides[perm[0]]; + const int stride1 = strides[perm[1]]; + const int stride2 = strides[perm[2]]; + const int out_stride0 = out_strides[0]; + const int out_stride1 = out_strides[1]; + const int output0 = output_shape[0]; + const int output1 = output_shape[1]; + const int output2 = output_shape[2]; + for (int i = 0; i < output0; ++i) { + int out_stride0_i = i * out_stride0; + int stride0_i = i * stride0; + for (int j = 0; j < output1; ++j) { + int out_stride1_j = j * out_stride1; + int stride1_j = j * stride1; + for (int k = 0; k < output2; ++k) { + out_data[out_stride0_i + out_stride1_j + k] = in_data[stride0_i + stride1_j + k * stride2]; + } + } + } +} + +void TransposeDim4Int8(const int8_t *in_data, int8_t *out_data, const int *strides, const int *out_strides, + const int *perm, const int *output_shape, int h_start, int h_end) { + const int stride0 = strides[perm[0]]; + const int stride1 = strides[perm[1]]; + const int stride2 = strides[perm[2]]; + const int stride3 = strides[perm[3]]; + const int out_stride0 = out_strides[0]; + const int out_stride1 = out_strides[1]; + const int out_stride2 = out_strides[2]; + const int output0 = output_shape[0]; + const int output1 = output_shape[1]; + const int output2 = output_shape[2]; + const int output3 = output_shape[3]; + + for (int i = 0; i < output0; ++i) { + int out_stride0_i = i * out_stride0; + int stride0_i = i * stride0; + for (int j = 0; j < output1; ++j) { + int out_stride1_j = j * out_stride1; + int stride1_j = j * stride1; + for (int k = 0; k < output2; ++k) { + int out_stride2_k = k * out_stride2; + int stride2_k = k * stride2; + for (int m = 0; m < output3; ++m) { + out_data[out_stride0_i + out_stride1_j + out_stride2_k + m] = + in_data[stride0_i + stride1_j + stride2_k + m * stride3]; + } + } + } + } +} + +void TransposeDim5Int8(const int8_t *in_data, int8_t *out_data, const int *strides, const int *out_strides, + const int *perm, const int *output_shape, int h_start, int h_end) { + const int stride0 = strides[perm[0]]; + const int stride1 = strides[perm[1]]; + const int stride2 = strides[perm[2]]; + const int stride3 = strides[perm[3]]; + const int stride4 = strides[perm[4]]; + const int out_stride0 = out_strides[0]; + const int out_stride1 = out_strides[1]; + const int out_stride2 = out_strides[2]; + const int out_stride3 = out_strides[3]; + const int output0 = output_shape[0]; + const int output1 = output_shape[1]; + const int output2 = output_shape[2]; + const int output3 = output_shape[3]; + const int output4 = output_shape[4]; + + for (int i = 0; i < output0; ++i) { + int out_stride0_i = i * out_stride0; + int stride0_i = i * stride0; + for (int j = 0; j < output1; ++j) { + int out_stride1_j = j * out_stride1; + int stride1_j = j * stride1; + for (int k = 0; k < output2; ++k) { + int out_stride2_k = k * out_stride2; + int stride2_k = k * stride2; + for (int m = 0; m < output3; ++m) { + int out_stride3_m = m * out_stride3; + int stride3_m = m * stride3; + for (int n = 0; n < output4; ++n) { + out_data[out_stride0_i + out_stride1_j + out_stride2_k + out_stride3_m + n] = + in_data[stride0_i + stride1_j + stride2_k + stride3_m + n * stride4]; + } + } + } + } + } +} + +void TransposeCommInt8(const int8_t *in_data, int8_t *out_data, const int *strides, const int *out_strides, + const int *perm, const int *output_shape, int h_start, int h_end, int dims, int *size, + int *position) { + *(size + dims - 1) = 1; + for (int i = dims - 1; i > 0; --i) { + *(size + i - 1) = *(size + i) * output_shape[i]; + } + + for (size_t idx = 0; idx < (*size) * output_shape[0]; ++idx) { + int pos = idx; + int output_idx = 0; + int input_idx = 0; + for (int i = 0; i < dims; ++i) { + *(position + i) = pos / *(size + i); + int out_stride = i < dims - 1 ? out_strides[i] : 1; + output_idx += (*(position + i) * out_stride); + input_idx += (*(position + i) * strides[perm[i]]); + pos -= *(position + i) * (*(size + i)); + } + out_data[output_idx] = in_data[input_idx]; + } +} + +int DoTransposeInt8(const int8_t *in_data, int8_t *out_data, int *input_shape, const int *output_shape, + TransposeParameter *transpose_param, int h_start, int h_end, int *dim_size, int *position) { + if (in_data == NULL || out_data == NULL) { + return NNACL_ERR; + } + + int *perm = transpose_param->perm_; + int *strides = transpose_param->strides_; + int *out_strides = transpose_param->out_strides_; + int num_axes = transpose_param->num_axes_; + + if (num_axes < 2) { + return NNACL_ERR; + } + + // check if transpose is needed + bool needTranspose = false; + for (int i = 1; i < num_axes; i++) { + if (perm[i] - perm[i - 1] != 1) { + needTranspose = true; + break; + } + } + + if (!needTranspose) { + (void)memcpy(out_data, in_data, transpose_param->data_size_); + return NNACL_OK; + } + + switch (num_axes) { + case 2: + TransposeDim2Int8(in_data, out_data, strides, out_strides, perm, output_shape, h_start, h_end); + break; + case 3: + TransposeDim3Int8(in_data, out_data, strides, out_strides, perm, output_shape, h_start, h_end); + break; + case 4: + TransposeDim4Int8(in_data, out_data, strides, out_strides, perm, output_shape, h_start, h_end); + break; + case 5: + TransposeDim5Int8(in_data, out_data, strides, out_strides, perm, output_shape, h_start, h_end); + break; + default: + TransposeCommInt8(in_data, out_data, strides, out_strides, perm, output_shape, h_start, h_end, num_axes, dim_size, + position); + break; + } + + return NNACL_OK; +} diff --git a/mindspore/lite/nnacl/int8/transpose_int8.h b/mindspore/lite/nnacl/int8/transpose_int8.h new file mode 100644 index 0000000000..1379ed3c28 --- /dev/null +++ b/mindspore/lite/nnacl/int8/transpose_int8.h @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_NNACL_INT8_TRANSPOSE_INT8_H_ +#define MINDSPORE_LITE_NNACL_INT8_TRANSPOSE_INT8_H_ + +#include +#include "nnacl/transpose.h" +#include "nnacl/errorcode.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int DoTransposeInt8(const int8_t *in_data, int8_t *out_data, int *input_shape, const int *output_shape, + TransposeParameter *transpose_param, int h_start, int h_end, int *dim_size, int *position); + +#ifdef __cplusplus +} +#endif + +#endif // MINDSPORE_LITE_NNACL_INT8_TRANSPOSE_INT8_H_ diff --git a/mindspore/lite/nnacl/layer_norm_parameter.h b/mindspore/lite/nnacl/layer_norm_parameter.h index 8dd513b1a2..fcfc5cbd3a 100644 --- a/mindspore/lite/nnacl/layer_norm_parameter.h +++ b/mindspore/lite/nnacl/layer_norm_parameter.h @@ -17,6 +17,7 @@ #define MINDSPORE_LITE_NNACL_LAYER_NORM_PARAMETER_H_ #include "nnacl/op_base.h" +#include "nnacl/quantization/quantize.h" typedef struct LayerNormParameter { OpParameter op_parameter_; @@ -24,6 +25,21 @@ typedef struct LayerNormParameter { int normalized_dims_; float epsilon_; bool elementwise_affine_; + int thread_count_; + int thread_outsize_; } LayerNormParameter; +typedef struct LayerNormQuantArg { + QuantArg in_quant_arg_; + QuantArg out_quant_arg_; + QuantArg gamma_quant_arg_; + + int32_t multiplier_; + int32_t shift_left_; + int32_t shift_right_; + + int output_activation_min_; + int output_activation_max_; +} LayerNormQuantArg; + #endif // MINDSPORE_LITE_NNACL_LAYER_NORM_PARAMETER_H_ diff --git a/mindspore/lite/nnacl/quantization/fixed_point.c b/mindspore/lite/nnacl/quantization/fixed_point.c index 164cd2170c..52adfa8dec 100644 --- a/mindspore/lite/nnacl/quantization/fixed_point.c +++ b/mindspore/lite/nnacl/quantization/fixed_point.c @@ -202,6 +202,40 @@ int exp_on_negative_values(int a, const int integer_bits) { return result; } +void GetSqrtQuantMultiplierExp(int32_t input, int reverse_shift, int32_t *multiplier, int32_t *shift) { + if (input <= 1) { + *multiplier = INT_MAX; + *shift = 0; + } + *shift = 11; + while (input >= (1 << 29)) { + input /= 4; + ++*shift; + } + int max_left_shift_bits = CountLeadingSignBits(input); + int left_shift_bit_pairs = max_left_shift_bits / 2 - 1; + *shift -= left_shift_bit_pairs; + input <<= 2 * left_shift_bit_pairs; + int32_t fixedpoint_f3_input = input >> 1; // sign: 1 bit, integer: 3 bit, fractional: 28 bit + int32_t fp_f3_half_input = SaturatingRoundingMultiplyByPOT(fixedpoint_f3_input, -1); + int32_t fp_f3_half_three = (1 << 28) + (1 << 27); + int32_t tmp = (1 << 28); // one + for (int i = 0; i < 5; i++) { + int32_t tmp3 = Rescale(SaturatingRoundingDoublingHighMul(tmp, SaturatingRoundingDoublingHighMul(tmp, tmp)), 9, 3); + tmp = Rescale(SaturatingRoundingDoublingHighMul(fp_f3_half_three, tmp) - + SaturatingRoundingDoublingHighMul(fp_f3_half_input, tmp3), + 6, 3); + } + const int32_t fp_f0_half_sqrt_2 = 1518500250; // sqrt(2) / 2 + tmp = SaturatingRoundingDoublingHighMul(tmp, fp_f0_half_sqrt_2); + *multiplier = tmp; + if (*shift < 0) { + *multiplier <<= -*shift; + *shift = 0; + } + *shift *= reverse_shift; +} + #ifdef ENABLE_NEON int32x4_t RoundingDivideByPOTInt32x4(int32x4_t x, int exponent) { const int32x4_t shift_vec = vdupq_n_s32(-exponent); diff --git a/mindspore/lite/nnacl/quantization/fixed_point.h b/mindspore/lite/nnacl/quantization/fixed_point.h index 9e9121a32f..5a2848312f 100644 --- a/mindspore/lite/nnacl/quantization/fixed_point.h +++ b/mindspore/lite/nnacl/quantization/fixed_point.h @@ -52,6 +52,8 @@ int32_t ComputerReciprocal(int32_t x, int x_digits, int *recip_shift); int exp_on_negative_values(int a, const int tIntegerBits); +void GetSqrtQuantMultiplierExp(int32_t input, int reverse_shift, int32_t *multiplier, int32_t *shift); + #ifdef __cplusplus } #endif diff --git a/mindspore/lite/nnacl/transpose.h b/mindspore/lite/nnacl/transpose.h index ed71f0eb37..f17fcb4314 100644 --- a/mindspore/lite/nnacl/transpose.h +++ b/mindspore/lite/nnacl/transpose.h @@ -19,6 +19,8 @@ #include "nnacl/op_base.h" +#define MAX_TRANSPOSE_DIM_SIZE 5 + typedef struct TransposeParameter { OpParameter op_parameter_; int perm_[8]; @@ -29,23 +31,4 @@ typedef struct TransposeParameter { int data_size_; } TransposeParameter; -#ifdef __cplusplus -extern "C" { -#endif -int DoTranspose(const float *in_data, float *out_data, int *input_shape, const int *output_shape, - TransposeParameter *transpose_param, int h_start, int h_end, int *size, int *position); -void TransposeDim2(const float *in_data, float *out_data, const int *strides, int *out_strides, const int *perm, - const int *output_shape, int h_start, int h_end); -void TransposeDim3(const float *in_data, float *out_data, const int *strides, const int *out_strides, const int *perm, - const int *output_shape, int h_start, int h_end); -void TransposeDim4(const float *in_data, float *out_data, const int *strides, const int *out_strides, const int *perm, - const int *output_shape, int h_start, int h_end); -void TransposeDim5(const float *in_data, float *out_data, const int *strides, const int *out_strides, const int *perm, - const int *output_shape, int h_start, int h_end); -void TransposeDims(const float *in_data, float *out_data, const int *strides, const int *out_strides, const int *perm, - const int *output_shape, int h_start, int h_end, int dims, int *size, int *position); -#ifdef __cplusplus -} -#endif - #endif // MINDSPORE_LITE_NNACL_TRANSPOSE_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/layer_norm_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/layer_norm_fp32.cc index 3af55274b4..c727970ba3 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/layer_norm_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/layer_norm_fp32.cc @@ -76,7 +76,7 @@ int LayerNormCPUKernel::Run() { dst_data_ = reinterpret_cast(out_tensors_.at(0)->MutableData()); auto ret = ParallelLaunch(this->context_->thread_pool_, LayerNormRun, this, op_parameter_->thread_num_); if (ret != RET_OK) { - MS_LOG(ERROR) << "FillRun error error_code[" << ret << "]"; + MS_LOG(ERROR) << "LayerNormRun error error_code[" << ret << "]"; return ret; } return RET_OK; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/transpose_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/transpose_fp32.cc index ed9f0ce29f..290ab4e626 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/transpose_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/transpose_fp32.cc @@ -15,12 +15,8 @@ */ #include "src/runtime/kernel/arm/fp32/transpose_fp32.h" - -#include -#include "nnacl/transpose.h" #include "schema/model_generated.h" #include "src/kernel_registry.h" -#include "include/errorcode.h" #include "src/runtime/runtime_api.h" using mindspore::lite::KernelRegistrar; @@ -30,9 +26,6 @@ using mindspore::lite::RET_OP_EXECUTE_FAILURE; using mindspore::schema::PrimitiveType_Transpose; namespace mindspore::kernel { -namespace { -constexpr int maxDimSize = 5; -} // namespace int TransposeCPUKernel::Init() { if (!InferShapeDone()) { @@ -103,9 +96,8 @@ int TransposeCPUKernel::TransposeParallel(int task_id) { size = this->dim_size_ + task_id * param->num_axes_; position = this->position_ + task_id * param->num_axes_; } - - auto ret = DoTranspose(in_data_, out_data_, in_shape_, out_shape_, param, thread_offset, - thread_offset + num_unit_thread, size, position); + auto ret = DoTransposeFp32(in_data_, out_data_, in_shape_, out_shape_, param, thread_offset, + thread_offset + num_unit_thread, size, position); if (ret != RET_OK) { MS_LOG(ERROR) << "Transpose error task_id[" << task_id << "] error_code[" << ret << "]"; return RET_ERROR; @@ -113,7 +105,7 @@ int TransposeCPUKernel::TransposeParallel(int task_id) { return RET_OK; } -int TransposeRun(void *cdata, int task_id) { +int TransposeFp32Run(void *cdata, int task_id) { auto g_kernel = reinterpret_cast(cdata); auto ret = g_kernel->TransposeParallel(task_id); if (ret != RET_OK) { @@ -135,7 +127,7 @@ int TransposeCPUKernel::Run() { in_data_ = reinterpret_cast(in_tensor->MutableData()); out_data_ = reinterpret_cast(out_tensor->MutableData()); int dims = out_tensor->shape().size(); - if (dims > maxDimSize) { + if (dims > MAX_TRANSPOSE_DIM_SIZE) { dim_size_ = reinterpret_cast(context_->allocator->Malloc(dims * thread_h_num_ * sizeof(int))); if (dim_size_ == nullptr) { MS_LOG(ERROR) << "Malloc data failed"; @@ -150,8 +142,8 @@ int TransposeCPUKernel::Run() { } } - auto ret = ParallelLaunch(this->context_->thread_pool_, TransposeRun, this, thread_h_num_); - if (dims > maxDimSize) { + auto ret = ParallelLaunch(this->context_->thread_pool_, TransposeFp32Run, this, thread_h_num_); + if (dims > MAX_TRANSPOSE_DIM_SIZE) { context_->allocator->Free(dim_size_); context_->allocator->Free(position_); dim_size_ = nullptr; @@ -162,7 +154,7 @@ int TransposeCPUKernel::Run() { return ret; } return ret; -} // namespace mindspore::kernel +} kernel::LiteKernel *CpuTransposeFp32KernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *opParameter, diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/transpose_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/transpose_fp32.h index 7c49497c32..e0067a879c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/transpose_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/transpose_fp32.h @@ -18,12 +18,12 @@ #define MINDSPORE_CCSRC_KERNEL_CPU_ARM_FP32_TRANSPOSE_H_ #include +#include "include/errorcode.h" +#include "nnacl/fp32/transpose.h" #include "src/lite_kernel.h" - #include "src/kernel_registry.h" namespace mindspore::kernel { - class TransposeCPUKernel : public LiteKernel { public: explicit TransposeCPUKernel(OpParameter *param, const std::vector &inputs, diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/layer_norm_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/layer_norm_int8.cc new file mode 100644 index 0000000000..e8a4c947af --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/layer_norm_int8.cc @@ -0,0 +1,140 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/runtime/kernel/arm/int8/layer_norm_int8.h" +#include "src/runtime/runtime_api.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::schema::PrimitiveType_LayerNorm; + +namespace mindspore::kernel { +void LayerNormInt8CPUKernel::SetQuantArgs() { + lite::Tensor *input = in_tensors_.at(0); + lite::Tensor *output = out_tensors_.at(0); + + quant_param_.in_quant_arg_.zp_ = input->GetQuantParams().front().zeroPoint; + quant_param_.in_quant_arg_.scale_ = input->GetQuantParams().front().scale; + quant_param_.out_quant_arg_.zp_ = output->GetQuantParams().front().zeroPoint; + quant_param_.out_quant_arg_.scale_ = output->GetQuantParams().front().scale; + + quant_param_.output_activation_min_ = std::numeric_limits::min(); + quant_param_.output_activation_max_ = std::numeric_limits::max(); + + if (param_->elementwise_affine_) { + lite::Tensor *gamma_tensor = out_tensors_.at(1); + quant_param_.gamma_quant_arg_.zp_ = gamma_tensor->GetQuantParams().front().zeroPoint; + quant_param_.gamma_quant_arg_.scale_ = gamma_tensor->GetQuantParams().front().scale; + } + + double in_scale; + if (param_->elementwise_affine_) { + in_scale = static_cast(quant_param_.in_quant_arg_.scale_ * quant_param_.gamma_quant_arg_.scale_); + } else { + in_scale = static_cast(quant_param_.in_quant_arg_.scale_); + } + double real_multiplier = in_scale / static_cast(quant_param_.out_quant_arg_.scale_); + + QuantizeRoundParameter(real_multiplier, &quant_param_.multiplier_, &quant_param_.shift_left_, + &quant_param_.shift_right_); + return; +} + +int LayerNormInt8CPUKernel::Init() { + SetQuantArgs(); + + if (!InferShapeDone()) { + return RET_OK; + } + return ReSize(); +} + +int LayerNormInt8CPUKernel::ReSize() { + auto shape = in_tensors_.front()->shape(); + outer_size_ = 1; + inner_size_ = 1; + for (size_t i = 0; i < shape.size(); ++i) { + if (i + param_->normalized_dims_ < shape.size()) { + outer_size_ *= shape[i]; + } else { + inner_size_ *= shape[i]; + } + } + + param_->thread_count_ = MSMIN(outer_size_, op_parameter_->thread_num_); + param_->thread_outsize_ = UP_DIV(outer_size_, param_->thread_count_); + return RET_OK; +} + +int LayerNormInt8Run(void *cdata, int task_id) { + auto kernel = reinterpret_cast(cdata); + kernel->DoExecute(task_id); + return RET_OK; +} + +int LayerNormInt8CPUKernel::DoExecute(int task_id) { + int current_out_size = outer_size_ - task_id * param_->thread_outsize_; + current_out_size = MSMIN(current_out_size, param_->thread_outsize_); + if (current_out_size <= 0) { + return RET_OK; + } + + const int8_t *thread_src = src_ptr_ + task_id * param_->thread_outsize_ * inner_size_; + int8_t *thread_dst = dst_ptr_ + task_id * param_->thread_outsize_ * inner_size_; + + LayerNormInt8(thread_src, gamma_ptr_, beta_ptr_, thread_dst, param_->elementwise_affine_, current_out_size, + inner_size_, &quant_param_); + return RET_OK; +} + +int LayerNormInt8CPUKernel::Run() { + src_ptr_ = reinterpret_cast(in_tensors_.at(0)->MutableData()); + dst_ptr_ = reinterpret_cast(out_tensors_.at(0)->MutableData()); + if (param_->elementwise_affine_) { + gamma_ptr_ = reinterpret_cast(in_tensors_.at(1)->MutableData()); + beta_ptr_ = reinterpret_cast(in_tensors_.at(2)->MutableData()); + } + + auto ret = ParallelLaunch(this->context_->thread_pool_, LayerNormInt8Run, this, op_parameter_->thread_num_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "LayerNormInt8Run error error_code[" << ret << "]"; + return ret; + } + return RET_OK; +} + +kernel::LiteKernel *CpuLayerNormInt8KernelCreator(const std::vector &inputs, + const std::vector &outputs, OpParameter *parameter, + const lite::InnerContext *ctx, const KernelKey &desc, + const mindspore::lite::PrimitiveC *primitive) { + auto *kernel = new (std::nothrow) LayerNormInt8CPUKernel(parameter, inputs, outputs, ctx, primitive); + if (kernel == nullptr) { + MS_LOG(ERROR) << "kernel is nullptr."; + free(parameter); + return nullptr; + } + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_ + << ", type: " << schema::EnumNamePrimitiveType(static_cast(parameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_LayerNorm, CpuLayerNormInt8KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/layer_norm_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/layer_norm_int8.h new file mode 100644 index 0000000000..e50b07d5b9 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/layer_norm_int8.h @@ -0,0 +1,58 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_LAYERNORM_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_LAYERNORM_INT8_H_ + +#include +#include +#include "nnacl/int8/layer_norm_int8.h" +#include "src/kernel_registry.h" +#include "include/errorcode.h" + +namespace mindspore::kernel { +class LayerNormInt8CPUKernel : public LiteKernel { + public: + LayerNormInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::InnerContext *ctx, + const mindspore::lite::PrimitiveC *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive) { + param_ = reinterpret_cast(parameter); + } + ~LayerNormInt8CPUKernel() override{}; + + int Init() override; + int ReSize() override; + int Run() override; + + public: + int DoExecute(int task_id); + + private: + void SetQuantArgs(); + + private: + LayerNormParameter *param_ = nullptr; + LayerNormQuantArg quant_param_; + int outer_size_; + int inner_size_; + int8_t *src_ptr_ = nullptr; + int8_t *dst_ptr_ = nullptr; + int8_t *gamma_ptr_ = nullptr; + int32_t *beta_ptr_ = nullptr; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_LAYERNORM_INT8_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/transpose_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/transpose_int8.cc new file mode 100644 index 0000000000..089055b30a --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/transpose_int8.cc @@ -0,0 +1,184 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/runtime/kernel/arm/int8/transpose_int8.h" +#include "src/runtime/runtime_api.h" + +using mindspore::lite::KernelRegistrar; +using mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::lite::RET_OP_EXECUTE_FAILURE; +using mindspore::schema::PrimitiveType_Transpose; + +namespace mindspore::kernel { + +TransposeInt8CPUKernel::~TransposeInt8CPUKernel() { return; } + +int TransposeInt8CPUKernel::Init() { + if (!InferShapeDone()) { + return RET_OK; + } + return ReSize(); +} + +int TransposeInt8Run(void *cdata, int task_id) { + auto transpose_int8 = reinterpret_cast(cdata); + auto ret = transpose_int8->DoTranspose(task_id); + if (ret != RET_OK) { + MS_LOG(ERROR) << "DoTranspose error task_id[" << task_id << "] error_code[" << ret << "]"; + return RET_OP_EXECUTE_FAILURE; + } + return RET_OK; +} + +void TransposeInt8CPUKernel::FreeTmpBuf() { + if (!extra_dims_) { + return; + } + if (dim_size_ != nullptr) { + context_->allocator->Free(dim_size_); + dim_size_ = nullptr; + } + if (position_ != nullptr) { + context_->allocator->Free(position_); + position_ = nullptr; + } + return; +} + +int TransposeInt8CPUKernel::MallocTmpBuf() { + if (!extra_dims_) { + return RET_OK; + } + + int dims = out_tensors_[0]->shape().size(); + + dim_size_ = reinterpret_cast(context_->allocator->Malloc(dims * thread_h_num_ * sizeof(int))); + if (dim_size_ == nullptr) { + MS_LOG(ERROR) << "Malloc data failed"; + return RET_ERROR; + } + position_ = reinterpret_cast(context_->allocator->Malloc(dims * thread_h_num_ * sizeof(int))); + if (position_ == nullptr) { + MS_LOG(ERROR) << "Malloc data failed"; + context_->allocator->Free(dim_size_); + dim_size_ = nullptr; + return RET_ERROR; + } + return RET_OK; +} + +int TransposeInt8CPUKernel::ReSize() { + auto in_tensor = in_tensors_.front(); + auto out_tensor = out_tensors_.front(); + auto in_shape = in_tensor->shape(); + auto out_shape = out_tensor->shape(); + + transpose_param_->data_size_ = in_tensor->Size(); + + transpose_param_->strides_[transpose_param_->num_axes_ - 1] = 1; + transpose_param_->out_strides_[transpose_param_->num_axes_ - 1] = 1; + for (int i = transpose_param_->num_axes_ - 2; i >= 0; i--) { + transpose_param_->strides_[i] = in_shape[i + 1] * transpose_param_->strides_[i + 1]; + transpose_param_->out_strides_[i] = out_shape[i + 1] * transpose_param_->out_strides_[i + 1]; + } + + extra_dims_ = out_shape.size() > MAX_TRANSPOSE_DIM_SIZE; + + num_unit_ = static_cast(in_shape.at(transpose_param_->perm_[kNHWC_H])); + thread_h_num_ = MSMIN(thread_num_, num_unit_); + thread_h_stride_ = UP_DIV(num_unit_, thread_h_num_); + return RET_OK; +} + +int TransposeInt8CPUKernel::DoTranspose(int task_id) { + int num_unit_thread = MSMIN(thread_h_stride_, num_unit_ - task_id * thread_h_stride_); + if (num_unit_thread <= 0) { + return RET_OK; + } + int thread_offset = task_id * thread_h_stride_; + + int *dim_size = nullptr; + int *position = nullptr; + if (extra_dims_) { + dim_size = dim_size_ + task_id * transpose_param_->num_axes_; + position = position_ + task_id * transpose_param_->num_axes_; + } + + auto ret = DoTransposeInt8(in_ptr_, out_ptr_, in_shape_, out_shape_, transpose_param_, thread_offset, + thread_offset + num_unit_thread, dim_size, position); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Transpose error task_id[" << task_id << "] error_code[" << ret << "]"; + return RET_ERROR; + } + + return RET_OK; +} + +int TransposeInt8CPUKernel::Run() { + auto in_tensor = in_tensors_.front(); + auto out_tensor = out_tensors_.front(); + + in_ptr_ = reinterpret_cast(in_tensor->data_c()); + out_ptr_ = reinterpret_cast(out_tensor->data_c()); + + in_shape_ = in_tensor->shape().data(); + out_shape_ = out_tensor->shape().data(); + + int ret = MallocTmpBuf(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "MallocTmpBuf error_code[" << ret << "]"; + } + + ret = ParallelLaunch(this->context_->thread_pool_, TransposeInt8Run, this, thread_h_num_); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Tranpose error error_code[" << ret << "]"; + } + + FreeTmpBuf(); + in_shape_ = nullptr; + out_shape_ = nullptr; + return ret; +} + +kernel::LiteKernel *CpuTransposeInt8KernelCreator(const std::vector &inputs, + const std::vector &outputs, OpParameter *opParameter, + const lite::InnerContext *ctx, const kernel::KernelKey &desc, + const mindspore::lite::PrimitiveC *primitive) { + MS_ASSERT(desc.type == schema::PrimitiveType_Transpose); + if (opParameter == nullptr) { + MS_LOG(ERROR) << "desc type is not Transpose"; + return nullptr; + } + auto *kernel = new (std::nothrow) TransposeInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive); + if (kernel == nullptr) { + MS_LOG(ERROR) << "New kernel fails."; + free(opParameter); + return nullptr; + } + + auto ret = kernel->Init(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + delete kernel; + return nullptr; + } + return kernel; +} + +REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Transpose, CpuTransposeInt8KernelCreator) +} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/transpose_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/transpose_int8.h new file mode 100644 index 0000000000..18d3fb9899 --- /dev/null +++ b/mindspore/lite/src/runtime/kernel/arm/int8/transpose_int8.h @@ -0,0 +1,63 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_TRANSPOSE_INT8_H_ +#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_TRANSPOSE_INT8_H_ + +#include +#include "nnacl/int8/transpose_int8.h" +#include "src/kernel_registry.h" +#include "src/lite_kernel.h" +#include "include/errorcode.h" + +namespace mindspore::kernel { +class TransposeInt8CPUKernel : public LiteKernel { + public: + TransposeInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, + const std::vector &outputs, const lite::InnerContext *ctx, + const mindspore::lite::PrimitiveC *primitive) + : LiteKernel(parameter, inputs, outputs, ctx, primitive) { + transpose_param_ = reinterpret_cast(op_parameter_); + } + ~TransposeInt8CPUKernel() override; + + int Init() override; + int ReSize() override; + int Run() override; + + public: + int DoTranspose(int task_id); + + private: + int MallocTmpBuf(); + void FreeTmpBuf(); + + private: + TransposeParameter *transpose_param_; + int8_t *in_ptr_ = nullptr; + int8_t *out_ptr_ = nullptr; + int *in_shape_ = nullptr; + int *out_shape_ = nullptr; + int *dim_size_ = nullptr; + int *position_ = nullptr; + bool extra_dims_ = false; + int thread_num_ = 1; + int thread_h_stride_ = 0; + int thread_h_num_ = 0; + int num_unit_ = 0; +}; +} // namespace mindspore::kernel + +#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_TRANSPOSE_INT8_H_ diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/transpose_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/transpose_fp32_tests.cc index 30551a2ecb..b04c057f1a 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/transpose_fp32_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/transpose_fp32_tests.cc @@ -64,7 +64,7 @@ TEST_F(TestTransposeFp32, TransposeFp32_axes4) { param->out_strides_[i] = out_strides[i]; } - auto ret = DoTranspose(in, out, input_shape, output_shape, param, 0, 3, nullptr, nullptr); + auto ret = DoTransposeFp32(in, out, input_shape, output_shape, param, 0, 3, nullptr, nullptr); ASSERT_EQ(ret, 0); delete param; CompareOutputData(out, correct, 24, 0.000001); @@ -104,7 +104,7 @@ TEST_F(TestTransposeFp32, TransposeFp32_axes3) { param->out_strides_[i] = out_strides[i]; } - auto ret = DoTranspose(in, out, input_shape, output_shape, param, 0, 3, nullptr, nullptr); + auto ret = DoTransposeFp32(in, out, input_shape, output_shape, param, 0, 3, nullptr, nullptr); ASSERT_EQ(ret, 0); delete param; CompareOutputData(out, correct, 24, 0.000001); @@ -145,7 +145,7 @@ TEST_F(TestTransposeFp32, TransposeFp32_axes2) { param->out_strides_[i] = out_strides[i]; } - auto ret = DoTranspose(in, out, input_shape, output_shape, param, 0, 6, nullptr, nullptr); + auto ret = DoTransposeFp32(in, out, input_shape, output_shape, param, 0, 6, nullptr, nullptr); ASSERT_EQ(ret, 0); delete param; CompareOutputData(out, correct, 24, 0.000001);