|
|
|
@ -21,6 +21,13 @@
|
|
|
|
|
#include <cmath>
|
|
|
|
|
#include <vector>
|
|
|
|
|
|
|
|
|
|
#ifdef ENABLE_ANDROID
|
|
|
|
|
#if defined(__arm__) || defined(__aarch64__) || defined(_M_ARM) || defined(_M_ARM64)
|
|
|
|
|
#define USE_NEON
|
|
|
|
|
#include <arm_neon.h>
|
|
|
|
|
#endif
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace dataset {
|
|
|
|
|
|
|
|
|
@ -423,25 +430,59 @@ bool ConvertTo(const LiteMat &src, LiteMat &dst, double scale) {
|
|
|
|
|
if (src.data_type_ != LDataType::UINT8) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (scale < 0.0 || scale > 100) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (dst.IsEmpty()) {
|
|
|
|
|
(void)dst.Init(src.width_, src.height_, src.channel_, LDataType::FLOAT32);
|
|
|
|
|
} else if (dst.height_ != src.height_ || dst.width_ != src.width_ || dst.channel_ != src.channel_) {
|
|
|
|
|
dst.Init(src.width_, src.height_, src.channel_, LDataType::FLOAT32);
|
|
|
|
|
} else if (src.width_ != dst.width_ || src.height_ != dst.height_ || src.channel_ != dst.channel_) {
|
|
|
|
|
return false;
|
|
|
|
|
} else if (dst.data_type_ != LDataType::FLOAT32) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
const unsigned char *src_start_p = src;
|
|
|
|
|
float *dst_start_p = dst;
|
|
|
|
|
for (int h = 0; h < src.height_; h++) {
|
|
|
|
|
for (int w = 0; w < src.width_; w++) {
|
|
|
|
|
uint32_t index = (h * src.width_ + w) * src.channel_;
|
|
|
|
|
for (int c = 0; c < src.channel_; c++) {
|
|
|
|
|
dst_start_p[index + c] = (static_cast<float>(src_start_p[index + c] * scale));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const uint8_t *src_ptr = (const uint8_t *)src;
|
|
|
|
|
float *dst_ptr = reinterpret_cast<float *>(dst.data_ptr_);
|
|
|
|
|
int64_t total_size = src.height_ * src.width_ * src.channel_;
|
|
|
|
|
int64_t x = 0;
|
|
|
|
|
#ifdef USE_NEON
|
|
|
|
|
float32x4_t v_scale = vdupq_n_f32(static_cast<float>(scale));
|
|
|
|
|
float32x4_t v_c = vdupq_n_f32(0.0f);
|
|
|
|
|
const int64_t step = 16;
|
|
|
|
|
for (; x <= total_size - step; x += step) {
|
|
|
|
|
uint8x16_t v_src = vld1q_u8(src_ptr + x);
|
|
|
|
|
uint8x16_t v_dst;
|
|
|
|
|
|
|
|
|
|
uint16x8_t v_l_16x8 = vmovl_u8(vget_low_u8(v_src));
|
|
|
|
|
uint16x8_t v_h_16x8 = vmovl_u8(vget_high_u8(v_src));
|
|
|
|
|
|
|
|
|
|
float32x4_t v_ll_f32x4 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(v_l_16x8)));
|
|
|
|
|
float32x4_t v_lh_f32x4 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(v_l_16x8)));
|
|
|
|
|
float32x4_t v_hl_f32x4 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(v_h_16x8)));
|
|
|
|
|
float32x4_t v_hh_f32x4 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(v_h_16x8)));
|
|
|
|
|
|
|
|
|
|
#if defined(__aarch64__) || defined(_M_ARM64)
|
|
|
|
|
v_ll_f32x4 = vfmaq_f32(v_c, v_ll_f32x4, v_scale);
|
|
|
|
|
v_lh_f32x4 = vfmaq_f32(v_c, v_lh_f32x4, v_scale);
|
|
|
|
|
v_hl_f32x4 = vfmaq_f32(v_c, v_hl_f32x4, v_scale);
|
|
|
|
|
v_hh_f32x4 = vfmaq_f32(v_c, v_hh_f32x4, v_scale);
|
|
|
|
|
#else
|
|
|
|
|
v_ll_f32x4 = vmlaq_f32(v_c, v_ll_f32x4, v_scale);
|
|
|
|
|
v_lh_f32x4 = vmlaq_f32(v_c, v_lh_f32x4, v_scale);
|
|
|
|
|
v_hl_f32x4 = vmlaq_f32(v_c, v_hl_f32x4, v_scale);
|
|
|
|
|
v_hh_f32x4 = vmlaq_f32(v_c, v_hh_f32x4, v_scale);
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
vst1q_f32(dst_ptr + x, v_ll_f32x4);
|
|
|
|
|
vst1q_f32(dst_ptr + x + 4, v_lh_f32x4);
|
|
|
|
|
vst1q_f32(dst_ptr + x + 8, v_hl_f32x4);
|
|
|
|
|
vst1q_f32(dst_ptr + x + 12, v_hh_f32x4);
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
for (; x < total_size; x++) {
|
|
|
|
|
dst_ptr[x] = static_cast<float>(src_ptr[x] * scale);
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|