fp32 optimize

pull/12334/head
lzk 4 years ago
parent c2b5490375
commit 37a30fb0fb

@@ -231,6 +231,8 @@ endif()
 if(NOT PLATFORM_ARM32 AND NOT PLATFORM_ARM64)
     if("${X86_64_SIMD}" STREQUAL "sse")
         add_compile_definitions(ENABLE_SSE)
+        set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -msse4.1")
+        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -msse4.1")
     endif()
     if("${X86_64_SIMD}" STREQUAL "avx")
        add_compile_definitions(ENABLE_SSE)
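
The new -msse4.1 flags matter because the SSE kernels below use SSE4.1 intrinsics (_mm_blendv_ps, _mm_min_epi32, _mm_max_epi32, _mm_mul_epi32), not just baseline SSE. A minimal sketch of a compile-time guard (hypothetical, not part of this commit) that would surface a missing flag early on GCC/Clang:

#if defined(ENABLE_SSE) && !defined(__SSE4_1__)
#error "ENABLE_SSE requires SSE4.1 support; compile with -msse4.1"
#endif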

@@ -20,10 +20,17 @@
 int Fp32Relu(const float *src, int length, float *dst) {
   int i = 0;
-#ifdef ENABLE_ARM
-  float32x4_t zero_4 = vdupq_n_f32(0.0f);
+#if defined(ENABLE_AVX)
+  MS_FLOAT32X8 zero_8 = MS_MOV256_F32(0.0f);
+  for (; i < length - 8; i += 8) {
+    MS_ST256_F32(dst + i, MS_MAX256_F32(MS_LD256_F32(src + i), zero_8));
+  }
+#endif
+#if defined(ENABLE_SSE) || defined(ENABLE_ARM)
+  MS_FLOAT32X4 zero = MS_MOVQ_F32(0.0f);
   for (; i < length - 4; i += 4) {
-    vst1q_f32(dst + i, vmaxq_f32(vld1q_f32(src + i), zero_4));
+    MS_STQ_F32(dst + i, MS_MAXQ_F32(MS_LDQ_F32(src + i), zero));
   }
 #endif
   for (; i < length; ++i) {
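
The rewritten Fp32Relu shows the tiered pattern this commit applies throughout: an 8-wide AVX body, a 4-wide SSE/NEON body over the remainder, then a scalar tail. A self-contained sketch of the same pattern using raw SSE intrinsics (illustrative names, not from the commit):

#include <xmmintrin.h> // SSE

static void relu_sse_sketch(const float *src, int length, float *dst) {
  int i = 0;
  __m128 zero = _mm_set1_ps(0.0f);
  for (; i <= length - 4; i += 4) { // 4-wide vector body
    _mm_storeu_ps(dst + i, _mm_max_ps(_mm_loadu_ps(src + i), zero));
  }
  for (; i < length; ++i) { // scalar tail handles the remainder
    dst[i] = src[i] > 0.0f ? src[i] : 0.0f;
  }
}

Note the loop bound: with `i < length - 8` as in this hunk, a length that is an exact multiple of the vector width leaves one full vector to the scalar tail; the Swish hunk further down uses `<=`, which does not.
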
@@ -34,13 +41,24 @@ int Fp32Relu(const float *src, int length, float *dst) {
 int Fp32Relu6(const float *src, int length, float *dst) {
   int i = 0;
-#ifdef ENABLE_ARM
-  float32x4_t zero_4 = vdupq_n_f32(0.0f);
-  float32x4_t six_4 = vdupq_n_f32(6.0f);
+#if defined(ENABLE_AVX)
+  MS_FLOAT32X8 zero_8 = MS_MOV256_F32(0.0f);
+  MS_FLOAT32X8 six_8 = MS_MOV256_F32(6.0f);
+  for (; i < length - 8; i += 8) {
+    MS_FLOAT32X8 dst_tmp = MS_MAX256_F32(MS_LD256_F32(src + i), zero_8);
+    dst_tmp = MS_MIN256_F32(dst_tmp, six_8);
+    MS_ST256_F32(dst + i, dst_tmp);
+  }
+#endif
+#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
+  MS_FLOAT32X4 zero = MS_MOVQ_F32(0.0f);
+  MS_FLOAT32X4 six = MS_MOVQ_F32(6.0f);
   for (; i < length - 4; i += 4) {
-    float32x4_t dst_4 = vmaxq_f32(vld1q_f32(src + i), zero_4);
-    dst_4 = vminq_f32(dst_4, six_4);
-    vst1q_f32(dst + i, dst_4);
+    MS_FLOAT32X4 dst_tmp = MS_MAXQ_F32(MS_LDQ_F32(src + i), zero);
+    dst_tmp = MS_MINQ_F32(dst_tmp, six);
+    MS_STQ_F32(dst + i, dst_tmp);
   }
 #endif
   for (; i < length; ++i) {
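
Relu6 composes the same max with a min, clamping to [0, 6]. A scalar reference for the vector body (illustrative, not from the commit):

#include <math.h>

/* Relu6(x) = min(max(x, 0), 6) */
static float relu6_ref(float x) { return fminf(fmaxf(x, 0.0f), 6.0f); }
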
@@ -55,14 +73,21 @@ int Fp32Relu6(const float *src, int length, float *dst) {
 int LRelu(const float *src, int length, float *dst, float alpha) {
   int i = 0;
-#ifdef ENABLE_ARM64
-  float32x4_t alpha_4 = vdupq_n_f32(alpha);
+#if defined(ENABLE_AVX)
+  for (; i < length - 8; i += 8) {
+    MS_FLOAT32X8 src_tmp = MS_LD256_F32(src + i);
+    MS_FLOAT32X8 mul_tmp = MS_MUL256_N_F32(src_tmp, alpha);
+    MS_FLOAT32X8 mask = MS_CMP256_PS(src_tmp, MS_MOV256_F32(0.0f), 30);
+    MS_ST256_F32(dst + i, MS_BLEND256_PS(mul_tmp, src_tmp, mask));
+  }
+#endif
+#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
   for (; i < length - 4; i += 4) {
-    float32x4_t src_4 = vld1q_f32(src + i);
-    float32x4_t mul_4 = vmulq_f32(src_4, alpha_4);
-    uint32x4_t flag = vclezq_f32(src_4);
-    float32x4_t dst_4 = vbslq_f32(flag, mul_4, src_4);
-    vst1q_f32(dst + i, dst_4);
+    MS_FLOAT32X4 src_tmp = MS_LDQ_F32(src + i);
+    MS_FLOAT32X4 mul_tmp = MS_MULQ_N_F32(src_tmp, alpha);
+    MS_FLOAT32X4 mask = MS_CMPGTQ_PS(src_tmp, MS_MOVQ_F32(0.0f));
+    MS_STQ_F32(dst + i, MS_BLENDQ_PS(mul_tmp, src_tmp, mask));
   }
 #endif
   for (; i < length; ++i) {
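
LRelu replaces the NEON compare-and-select (vclezq_f32 plus vbslq_f32) with the portable compare/blend macros. The immediate 30 passed to MS_CMP256_PS is the AVX predicate _CMP_GT_OQ, so the mask is set where src > 0; MS_BLEND256_PS then keeps src there and alpha * src elsewhere. Scalar equivalent (illustrative):

/* keep x where x > 0, otherwise scale by alpha */
static float lrelu_ref(float x, float alpha) { return x > 0.0f ? x : alpha * x; }
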
@@ -73,11 +98,18 @@ int LRelu(const float *src, int length, float *dst, float alpha) {
 int Sigmoid(const float *src, int length, float *dst) {
   int i = 0;
-#ifdef ENABLE_ARM64
-  int count = (length / C4NUM) * C4NUM;
-  for (; i < count; i += C4NUM) {
-    simd_exp(vnegq_f32(vld1q_f32(src + i)), dst + i);
-    vst1q_f32(dst + i, vdivq_f32(vdupq_n_f32(1.0f), vaddq_f32(vdupq_n_f32(1.0f), vld1q_f32(dst + i))));
+#if defined(ENABLE_AVX)
+  for (; i < length - 8; i += 8) {
+    simd_exp_avx(-(MS_LD256_F32(src + i)), dst + i);
+    MS_ST256_F32(dst + i,
+                 MS_DIV256_F32(MS_MOV256_F32(1.0f), MS_ADD256_F32(MS_MOV256_F32(1.0f), MS_LD256_F32(dst + i))));
+  }
+#endif
+#if defined(ENABLE_ARM64) || defined(ENABLE_SSE)
+  for (; i < length - 4; i += 4) {
+    simd_exp(-(MS_LDQ_F32(src + i)), dst + i);
+    MS_STQ_F32(dst + i, MS_DIVQ_F32(MS_MOVQ_F32(1.0f), MS_ADDQ_F32(MS_MOVQ_F32(1.0f), MS_LDQ_F32(dst + i))));
   }
 #endif
   for (; i < length; ++i) {
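
Sigmoid now funnels both x86 widths through the shared simd_exp helpers (see the header hunk below): e^-x is written into dst in place, then dst is reused to form 1 / (1 + e^-x). Scalar reference (illustrative):

#include <math.h>

/* sigmoid(x) = 1 / (1 + e^-x); the vector code uses dst as scratch for e^-x */
static float sigmoid_ref(float x) { return 1.0f / (1.0f + expf(-x)); }
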
@@ -102,26 +134,40 @@ float TanhOpt(float src) {
 int Tanh(const float *src, int length, float *dst) {
   int i = 0;
-#ifdef ENABLE_ARM64
-  static float32x4_t paramv[] = {{378.0f, 378.0f, 378.0f, 378.0f},
-                                 {17325.0f, 17325.0f, 17325.0f, 17325.0f},
-                                 {135135.0f, 135135.0f, 135135.0f, 135135.0f},
-                                 {28.0f, 28.0f, 28.0f, 28.0f},
-                                 {3150.0f, 3150.0f, 3150.0f, 3150.0f},
-                                 {62370.0f, 62370.0f, 62370.0f, 62370.0f}};
-  float32x4_t neg_one = {-1.0f, -1.0f, -1.0f, -1.0f};
-  float32x4_t pos_one = {1.0f, 1.0f, 1.0f, 1.0f};
-  int count = (length / C4NUM) * C4NUM;
-  for (; i < count; i += C4NUM) {
-    float32x4_t input = vld1q_f32(src + i);
-    float32x4_t square = vmulq_f32(input, input);
-    float32x4_t a = vmulq_f32(
-      vaddq_f32(vmulq_f32(vaddq_f32(vmulq_f32(vaddq_f32(square, paramv[0]), square), paramv[1]), square), paramv[2]),
-      input);
-    float32x4_t b = vaddq_f32(
-      vmulq_f32(vaddq_f32(vmulq_f32(vaddq_f32(vmulq_f32(paramv[3], square), paramv[4]), square), paramv[5]), square),
-      paramv[2]);
-    vst1q_f32(dst + i, vminq_f32(vmaxq_f32(vdivq_f32(a, b), neg_one), pos_one));
+#if defined(ENABLE_ARM) || defined(ENABLE_SSE) || defined(ENABLE_AVX)
+  const int cnt = 6;
+  float data[] = {378.0f, 17325.0f, 135135.0f, 28.0f, 3150.0f, 62370.0f};
+#endif
+#if defined(ENABLE_AVX)
+  MS_FLOAT32X8 neg_one_8 = {-1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f};
+  MS_FLOAT32X8 pos_one_8 = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
+  MS_FLOAT32X8 param256[6];
+  for (int j = 0; j < cnt; ++j) {
+    param256[j] = MS_MOV256_F32(data[j]);
+  }
+  for (; i < length - 8; i += 8) {
+    MS_FLOAT32X8 input = MS_LD256_F32(src + i);
+    MS_FLOAT32X8 square = input * input;
+    MS_FLOAT32X8 a = (((square + param256[0]) * square + param256[1]) * square + param256[2]) * input;
+    MS_FLOAT32X8 b = ((param256[3] * square + param256[4]) * square + param256[5]) * square + param256[2];
+    MS_ST256_F32(dst + i, MS_MIN256_F32(MS_MAX256_F32(a / b, neg_one_8), pos_one_8));
+  }
+#endif
+#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
+  MS_FLOAT32X4 param[6];
+  MS_FLOAT32X4 neg_one = {-1.0f, -1.0f, -1.0f, -1.0f};
+  MS_FLOAT32X4 pos_one = {1.0f, 1.0f, 1.0f, 1.0f};
+  for (int j = 0; j < cnt; ++j) {
+    param[j] = MS_MOVQ_F32(data[j]);
+  }
+  for (; i < length - 4; i += 4) {
+    MS_FLOAT32X4 input = MS_LDQ_F32(src + i);
+    MS_FLOAT32X4 square = input * input;
+    MS_FLOAT32X4 a = (((square + param[0]) * square + param[1]) * square + param[2]) * input;
+    MS_FLOAT32X4 b = ((param[3] * square + param[4]) * square + param[5]) * square + param[2];
+    MS_STQ_F32(dst + i, MS_MINQ_F32(MS_MAXQ_F32(a / b, neg_one), pos_one));
   }
 #endif
   for (; i < length; ++i) {
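
The nested multiply-adds in Tanh evaluate a rational approximation, tanh(x) ~= x(x^6 + 378x^4 + 17325x^2 + 135135) / (28x^6 + 3150x^4 + 62370x^2 + 135135), clamped to [-1, 1]. A scalar reference matching the vector body (illustrative):

#include <math.h>

static float tanh_ref(float x) {
  float s = x * x;
  float a = (((s + 378.0f) * s + 17325.0f) * s + 135135.0f) * x;
  float b = ((28.0f * s + 3150.0f) * s + 62370.0f) * s + 135135.0f;
  return fminf(fmaxf(a / b, -1.0f), 1.0f); /* clamp to [-1, 1] */
}
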
@@ -142,12 +188,21 @@ int Swish(const float *src, int length, float *dst) {
     return NNACL_ERR;
   }
   int index = 0;
-#ifdef ENABLE_NEON
-  for (; index <= length - C4NUM; index += C4NUM) {
-    float32x4_t src_value = vld1q_f32(src + index);
-    float32x4_t sigmoid_value = vld1q_f32(dst + index);
-    float32x4_t result = vmulq_f32(src_value, sigmoid_value);
-    vst1q_f32(dst + index, result);
+#if defined(ENABLE_AVX)
+  for (; index <= length - 8; index += 8) {
+    MS_FLOAT32X8 src_value = MS_LD256_F32(src + index);
+    MS_FLOAT32X8 sigmoid_value = MS_LD256_F32(dst + index);
+    MS_FLOAT32X8 result = MS_MUL256_F32(src_value, sigmoid_value);
+    MS_ST256_F32(dst + index, result);
+  }
+#endif
+#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
+  for (; index <= length - 4; index += 4) {
+    MS_FLOAT32X4 src_value = MS_LDQ_F32(src + index);
+    MS_FLOAT32X4 sigmoid_value = MS_LDQ_F32(dst + index);
+    MS_FLOAT32X4 result = MS_MULQ_F32(src_value, sigmoid_value);
+    MS_STQ_F32(dst + index, result);
   }
 #endif
   for (; index < length; index++) {

File diff suppressed because it is too large

@@ -38,26 +38,49 @@ extern "C" {
 int Exp(const float *input_data, float *output_data, const ExpParameter *parameter, int task_id);
 void ExpFp32(const float *src, float *dst, int num);
-#ifdef ENABLE_ARM64
-static inline void simd_exp(float32x4_t input4, float *dst) {
-  static float32x4_t maxv = {88.0f, 88.0f, 88.0f, 88.0f};
-  static float32x4_t minv = {-88.0f, -88.0f, -88.0f, -88.0f};
-  static float32x4_t paramv[] = {{0.693147f, 0.693147f, 0.693147f, 0.693147f},
+#if defined(ENABLE_ARM64) || defined(ENABLE_SSE)
+static inline void simd_exp(MS_FLOAT32X4 input, float *dst) {
+  static MS_FLOAT32X4 maxv = {88.0f, 88.0f, 88.0f, 88.0f};
+  static MS_FLOAT32X4 minv = {-88.0f, -88.0f, -88.0f, -88.0f};
+  static MS_FLOAT32X4 param[] = {{0.693147f, 0.693147f, 0.693147f, 0.693147f},
                                  {1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120},
                                  {1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24},
                                  {1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6},
                                  {0.5f, 0.5f, 0.5f, 0.5f},
                                  {1.0f, 1.0f, 1.0f, 1.0f}};
-  input4 = vmaxq_f32(minv, vminq_f32(maxv, input4));
-  int32x4_t integer4 = vcvtq_s32_f32(vdivq_f32(input4, paramv[0]));
-  float32x4_t decimal4 = vsubq_f32(input4, vmulq_f32(vcvtq_f32_s32(integer4), paramv[0]));
-  int32x4_t int_exp4 = vshlq_s32(vaddq_s32(integer4, vdupq_n_s32(127)), vdupq_n_s32(23));
-  vst1q_f32(dst, vld1q_f32((float32_t *)(&int_exp4)));
-  float32x4_t decimal_exp4 = vaddq_f32(paramv[2], vmulq_f32(decimal4, paramv[1]));
-  decimal_exp4 = vmulq_f32(decimal4, vaddq_f32(paramv[3], vmulq_f32(decimal4, decimal_exp4)));
-  decimal_exp4 = vaddq_f32(paramv[5], vmulq_f32(decimal4, vaddq_f32(paramv[4], decimal_exp4)));
-  decimal_exp4 = vaddq_f32(paramv[5], vmulq_f32(decimal4, decimal_exp4));
-  vst1q_f32(dst, vmulq_f32(vld1q_f32(dst), decimal_exp4));
+  input = MS_MAXQ_F32(minv, MS_MINQ_F32(input, maxv));
+  MS_INT32X4 integer = MS_CVTQPS_EPI32(input / param[0]);
+  MS_FLOAT32X4 decimal = input - MS_CVTQEPI32_PS(integer) * param[0];
+  MS_INT32X4 int_exp = MS_SLLIQ_EPI32(MS_ADDQ_EPI32(integer, MS_MOVQ_EPI32(127)), 23);
+  memcpy(dst, &int_exp, sizeof(int32_t) * 4);
+  MS_FLOAT32X4 decimal_exp =
+    param[5] +
+    decimal * (param[5] + decimal * (param[4] + decimal * (param[3] + decimal * (param[2] + decimal * param[1]))));
+  MS_STQ_F32(dst, decimal_exp * MS_LDQ_F32(dst));
 }
 #endif
+#if defined(ENABLE_AVX)
+static inline void simd_exp_avx(MS_FLOAT32X8 input, float *dst) {
+  static MS_FLOAT32X8 maxv = {88.0f, 88.0f, 88.0f, 88.0f, 88.0f, 88.0f, 88.0f, 88.0f};
+  static MS_FLOAT32X8 minv = {-88.0f, -88.0f, -88.0f, -88.0f, -88.0f, -88.0f, -88.0f, -88.0f};
+  static MS_FLOAT32X8 param[] = {
+    {0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f, 0.693147f},
+    {1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120, 1.0f / 120},
+    {1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24, 1.0f / 24},
+    {1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6, 1.0f / 6},
+    {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f},
+    {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}};
+  input = MS_MAX256_F32(minv, MS_MIN256_F32(input, maxv));
+  MS_INT32X8 integer = MS_CVT256PS_EPI32(input / param[0]);
+  MS_FLOAT32X8 decimal = input - MS_CVT256EPI32_PS(integer) * param[0];
+  MS_INT32X8 int_exp = MS_SLLI256_EPI32(MS_ADD256_EPI32(integer, MS_MOV256_EPI32(127)), 23);
+  memcpy(dst, &int_exp, sizeof(int32_t) * 8);
+  MS_FLOAT32X8 decimal_exp =
+    param[5] +
+    decimal * (param[5] + decimal * (param[4] + decimal * (param[3] + decimal * (param[2] + decimal * param[1]))));
+  MS_ST256_F32(dst, decimal_exp * MS_LD256_F32(dst));
+}
+#endif
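
simd_exp implements the standard range-reduction trick: clamp x to [-88, 88] (roughly the finite range of float exp), split x = k*ln2 + r with k truncated, build 2^k by writing (k + 127) << 23 directly into the IEEE-754 exponent field, and approximate e^r with a fifth-order Taylor polynomial. A scalar sketch (illustrative, not from the commit):

#include <math.h>
#include <stdint.h>
#include <string.h>

static float exp_sketch(float x) {
  x = fmaxf(-88.0f, fminf(88.0f, x));
  int32_t k = (int32_t)(x / 0.693147f);   /* k = trunc(x / ln2) */
  float r = x - (float)k * 0.693147f;     /* remainder, |r| < ln2 */
  int32_t bits = (k + 127) << 23;         /* 2^k via the exponent field */
  float pow2k;
  memcpy(&pow2k, &bits, sizeof(pow2k));   /* same reinterpretation as the memcpy above */
  float er = 1.0f + r * (1.0f + r * (0.5f + r * (1.0f / 6 + r * (1.0f / 24 + r * (1.0f / 120)))));
  return pow2k * er;                      /* e^x = 2^k * e^r */
}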

File diff suppressed because it is too large

File diff suppressed because it is too large

@@ -107,28 +107,105 @@ typedef enum CalFixedMultiplierMode {
 #ifdef ENABLE_ARM
 #define MS_FLOAT32X4 float32x4_t
+#define MS_INT32X4 int32x4_t
 #define MS_LDQ_F32 vld1q_f32
+#define MS_LDQ_EPI32 vld1q_s32
 #define MS_ADDQ_F32 vaddq_f32
+#define MS_ADDQ_EPI32 vaddq_s32
 #define MS_MOVQ_F32 vmovq_n_f32
+#define MS_MOVQ_EPI32 vmovq_n_s32
 #define MS_DUPQ_F32 vdupq_n_f32 // It is recommended to replace with MS_MOVQ_F32.
 #define MS_SUBQ_F32 vsubq_f32
 #define MS_MLAQ_F32(src1, src2, src3) vmlaq_f32(src1, src2, src3)
 #define MS_STQ_F32 vst1q_f32
+#define MS_STQ_EPI32 vst1q_s32
 #define MS_MAXQ_F32 vmaxq_f32
+#define MS_MAXQ_EPI32 vmaxq_s32
 #define MS_MINQ_F32 vminq_f32
-#define MS_MULQ_F32(src1, src2) vmulq_n_f32(src1, src2)
-#elif defined(ENABLE_SSE)
+#define MS_MINQ_EPI32 vminq_s32
+#define MS_MULQ_F32(src1, src2) vmulq_f32(src1, src2)
+#define MS_MULQ_EPI32(src1, src2) vmulq_s32(src1, src2)
+#ifdef ENABLE_ARM64
+#define MS_DIVQ_F32(src1, src2) vdivq_f32(src1, src2)
+#else
+#define MS_DIVQ_F32(src1, src2) vmulq_f32(src1, vrecpeq_f32(src2))
+#endif
+#define MS_MULQ_N_F32(src1, src2) vmulq_n_f32(src1, src2)
+#define MS_MULQ_N_EPI32(src1, src2) vmulq_n_s32(src1, src2)
+#define MS_DIVQ_N_F32(src1, src2) vdivq_n_f32(src1, src2)
+#define MS_SLLIQ_EPI32(src1, src2) vshlq_s32(src1, vmovq_n_s32(src2))
+#define MS_CVTQPS_EPI32(src) vcvtq_s32_f32(src)
+#define MS_CVTQEPI32_PS(src) vcvtq_f32_s32(src)
+#define MS_CMPGTQ_PS(src1, src2) vcgtq_f32(src1, src2)
+#define MS_CMPGTQ_EPI32(src1, src2) vcgtq_s32(src1, src2)
+// Note: Compared with X86, the vbslq_f32 parameters are the opposite of _mm_blendv_ps
+#define MS_BLENDQ_PS(src1, src2, src3) vbslq_f32(src3, src2, src1)
+#define MS_BLENDQ_EPI32(src1, src2, src3) vbslq_s32(src3, src2, src1)
+#elif defined(ENABLE_AVX)
+#define MS_FLOAT32X8 __m256
+#define MS_INT32X8 __m256i
+#define MS_LD256_F32 _mm256_loadu_ps
+#define MS_LD256_EPI32(src) _mm256_loadu_si256((__m256i const *)(src))
+#define MS_ADD256_F32 _mm256_add_ps
+#define MS_ADD256_EPI32 _mm256_add_epi32
+#define MS_MOV256_F32 _mm256_set1_ps
+#define MS_MOV256_EPI32 _mm256_set1_epi32
+#define MS_DUP256_F32 _mm256_load_ps1 // It is recommended to replace with MS_MOV256_F32.
+#define MS_MLA256_F32(src1, src2, src3) _mm256_add_ps(src1, _mm256_mul_ps(src2, src3))
+#define MS_ST256_F32 _mm256_storeu_ps
+#define MS_ST256_EPI32(src1, src2) _mm256_storeu_si256((__m256i *)(src1), src2)
+#define MS_SUB256_F32 _mm256_sub_ps
+#define MS_MAX256_F32 _mm256_max_ps
+#define MS_MAX256_EPI32 _mm256_max_epi32
+#define MS_MIN256_F32 _mm256_min_ps
+#define MS_MIN256_EPI32 _mm256_min_epi32
+#define MS_MUL256_F32(src1, src2) _mm256_mul_ps(src1, src2)
+#define MS_MUL256_EPI32(src1, src2) _mm256_mul_epi32(src1, src2)
+#define MS_DIV256_F32(src1, src2) _mm256_div_ps(src1, src2)
+#define MS_MUL256_N_F32(src1, src2) _mm256_mul_ps(src1, _mm256_set1_ps(src2))
+#define MS_MUL256_N_EPI32(src1, src2) _mm256_mul_epi32(src1, _mm256_set1_epi32(src2))
+#define MS_DIV256_N_F32(src1, src2) _mm256_div_ps(src1, _mm256_set1_ps(src2))
+#define MS_SLLI256_EPI32(src1, src2) _mm256_slli_epi32(src1, src2)
+#define MS_CVT256PS_EPI32(src) _mm256_cvttps_epi32(src) // truncate float to int
+#define MS_CVT256EPI32_PS(src) _mm256_cvtepi32_ps(src)
+#define MS_CMP256_PS(src1, src2, src3) _mm256_cmp_ps(src1, src2, src3)
+#define MS_CMPGT256_EPI32(src1, src2) _mm256_cmpgt_epi32(src1, src2)
+#define MS_BLEND256_PS(src1, src2, src3) _mm256_blendv_ps(src1, src2, src3)
+#define MS_BLEND256_EPI32(src1, src2, src3) _mm256_blendv_epi8(src1, src2, src3)
+#endif
+#if defined(ENABLE_SSE)
 #define MS_FLOAT32X4 __m128
+#define MS_INT32X4 __m128i
 #define MS_LDQ_F32 _mm_loadu_ps
+#define MS_LDQ_EPI32(src) _mm_loadu_si128((__m128i const *)(src))
 #define MS_ADDQ_F32 _mm_add_ps
-#define MS_MOVQ_F32 _mm_set_ps1
+#define MS_ADDQ_EPI32 _mm_add_epi32
+#define MS_MOVQ_F32 _mm_set1_ps
+#define MS_MOVQ_EPI32 _mm_set1_epi32
 #define MS_DUPQ_F32 _mm_load_ps1 // It is recommended to replace with MS_MOVQ_F32.
 #define MS_MLAQ_F32(src1, src2, src3) _mm_add_ps(src1, _mm_mul_ps(src2, src3))
 #define MS_STQ_F32 _mm_storeu_ps
+#define MS_STQ_EPI32(src1, src2) _mm_storeu_si128((__m128i *)(src1), src2)
 #define MS_SUBQ_F32 _mm_sub_ps
 #define MS_MAXQ_F32 _mm_max_ps
+#define MS_MAXQ_EPI32 _mm_max_epi32
 #define MS_MINQ_F32 _mm_min_ps
-#define MS_MULQ_F32(src1, src2) _mm_mul_ps(src1, _mm_set_ps1(src2))
+#define MS_MINQ_EPI32 _mm_min_epi32
+#define MS_MULQ_F32(src1, src2) _mm_mul_ps(src1, src2)
+#define MS_MULQ_EPI32(src1, src2) _mm_mul_epi32(src1, src2)
+#define MS_DIVQ_F32(src1, src2) _mm_div_ps(src1, src2)
+#define MS_MULQ_N_F32(src1, src2) _mm_mul_ps(src1, _mm_set1_ps(src2))
+#define MS_MULQ_N_EPI32(src1, src2) _mm_mul_epi32(src1, _mm_set1_epi32(src2))
+#define MS_DIVQ_N_F32(src1, src2) _mm_div_ps(src1, _mm_set1_ps(src2))
+#define MS_SLLIQ_EPI32(src1, src2) _mm_slli_epi32(src1, src2)
+#define MS_CVTQPS_EPI32(src) _mm_cvttps_epi32(src) // truncate float to int
+#define MS_CVTQEPI32_PS(src) _mm_cvtepi32_ps(src)
+#define MS_CMPGTQ_PS(src1, src2) _mm_cmpgt_ps(src1, src2)
+#define MS_CMPGTQ_EPI32(src1, src2) _mm_cmpgt_epi32(src1, src2)
+#define MS_BLENDQ_PS(src1, src2, src3) _mm_blendv_ps(src1, src2, src3)
+#define MS_BLENDQ_EPI32(src1, src2, src3) _mm_blendv_epi8(src1, src2, src3)
 #endif
 #endif // MINDSPORE_LITE_NNACL_OP_BASE_H_
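
With these macros, each fp32 kernel is written once per vector width and compiles against NEON, SSE, or AVX. Two caveats worth noting: on 32-bit ARM, MS_DIVQ_F32 falls back to vrecpeq_f32, a low-precision reciprocal estimate rather than a true divide; and vdivq_n_f32 (in MS_DIVQ_N_F32) is not a standard NEON intrinsic, so that macro appears usable only on the SSE path. An illustrative kernel in the new vocabulary (the wrapper function itself is hypothetical, not from the commit):

#if defined(ENABLE_ARM) || defined(ENABLE_SSE)
static void scale_bias_q(const float *src, float scale, float bias, int length, float *dst) {
  int i = 0;
  MS_FLOAT32X4 bias_v = MS_MOVQ_F32(bias);
  for (; i <= length - 4; i += 4) { // one 4-wide body for both ISAs
    MS_STQ_F32(dst + i, MS_ADDQ_F32(MS_MULQ_N_F32(MS_LDQ_F32(src + i), scale), bias_v));
  }
  for (; i < length; ++i) { // scalar tail
    dst[i] = src[i] * scale + bias;
  }
}
#endif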

@@ -179,7 +179,7 @@ int TensorListGetItem::InferShape(std::vector<lite::Tensor *> inputs_, std::vect
       MS_LOG(ERROR) << "element_shape_ is not fullyDefined!";
       return RET_ERROR;
     }
-    output->set_data_type(input0->data_type());
+    output->set_data_type(input0->tensors_data_type());
     output->set_shape(element_shape_);
   }
   output->set_format(input0->GetTensor(index_)->format());
