|
|
|
@ -495,9 +495,7 @@ bool Divide(const LiteMat &src_a, const LiteMat &src_b, LiteMat *dst) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int64_t total_size = src_a.height_ * src_a.width_ * src_a.channel_;
|
|
|
|
|
if (src_a.data_type_ == LDataType::BOOL) {
|
|
|
|
|
DivideImpl<bool>(src_a, src_b, *dst, total_size);
|
|
|
|
|
} else if (src_a.data_type_ == LDataType::INT8) {
|
|
|
|
|
if (src_a.data_type_ == LDataType::INT8) {
|
|
|
|
|
DivideImpl<int8_t>(src_a, src_b, *dst, total_size);
|
|
|
|
|
} else if (src_a.data_type_ == LDataType::UINT8) {
|
|
|
|
|
DivideImpl<uint8_t>(src_a, src_b, *dst, total_size);
|
|
|
|
@ -523,5 +521,102 @@ bool Divide(const LiteMat &src_a, const LiteMat &src_b, LiteMat *dst) {
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
inline void MultiplyImpl(const T *src0, const T *src1, T *dst, int64_t total_size) {
|
|
|
|
|
for (size_t i = 0; i < total_size; i++) {
|
|
|
|
|
dst[i] = src0[i] * src1[i];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
inline void MultiplyImpl(const uint8_t *src0, const uint8_t *src1, uint8_t *dst, int64_t total_size) {
|
|
|
|
|
int64_t x = 0;
|
|
|
|
|
#ifdef USE_NEON
|
|
|
|
|
const int64_t step = 32;
|
|
|
|
|
for (; x <= total_size - step; x += step) {
|
|
|
|
|
uint8x16_t v_src00 = vld1q_u8(src0 + x);
|
|
|
|
|
uint8x16_t v_src01 = vld1q_u8(src0 + x + 16);
|
|
|
|
|
uint8x16_t v_src10 = vld1q_u8(src1 + x);
|
|
|
|
|
uint8x16_t v_src11 = vld1q_u8(src1 + x + 16);
|
|
|
|
|
uint8x16_t v_dst_l, v_dst_h;
|
|
|
|
|
|
|
|
|
|
v_dst_l = vmull_u8(vget_low_u8(v_src00), vget_low_u8(v_src10));
|
|
|
|
|
v_dst_h = vmull_u8(vget_high_u8(v_src00), vget_high_u8(v_src10));
|
|
|
|
|
vst1q_u8(dst + x, vcombine_u8(vqmovn_u16(v_dst_l), vqmovn_u16(v_dst_h)));
|
|
|
|
|
|
|
|
|
|
v_dst_l = vmull_u8(vget_low_u8(v_src01), vget_low_u8(v_src11));
|
|
|
|
|
v_dst_h = vmull_u8(vget_high_u8(v_src01), vget_high_u8(v_src11));
|
|
|
|
|
vst1q_u8(dst + x + 16, vcombine_u8(vqmovn_u16(v_dst_l), vqmovn_u16(v_dst_h)));
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
for (; x < total_size; x++) {
|
|
|
|
|
int32_t val = src0[x] * src1[x];
|
|
|
|
|
dst[x] = std::max<int32_t>(std::numeric_limits<uint8_t>::min(),
|
|
|
|
|
std::min<int32_t>(std::numeric_limits<uint8_t>::max(), val));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
inline void MultiplyImpl(const uint16_t *src0, const uint16_t *src1, uint16_t *dst, int64_t total_size) {
|
|
|
|
|
for (size_t i = 0; i < total_size; i++) {
|
|
|
|
|
int32_t val = src0[i] * src1[i];
|
|
|
|
|
dst[i] = std::max<int32_t>(std::numeric_limits<uint16_t>::min(),
|
|
|
|
|
std::min<int32_t>(std::numeric_limits<uint16_t>::max(), val));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
inline void MultiplyImpl(const uint32_t *src0, const uint32_t *src1, uint32_t *dst, int64_t total_size) {
|
|
|
|
|
for (size_t i = 0; i < total_size; i++) {
|
|
|
|
|
int64_t val = src0[i] * src1[i];
|
|
|
|
|
dst[i] = std::max<int64_t>(std::numeric_limits<uint32_t>::min(),
|
|
|
|
|
std::min<int64_t>(std::numeric_limits<uint32_t>::max(), val));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool Multiply(const LiteMat &src_a, const LiteMat &src_b, LiteMat *dst) {
|
|
|
|
|
if (src_a.width_ != src_b.width_ || src_a.height_ != src_b.height_ || src_a.channel_ != src_b.channel_) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (src_a.data_type_ != src_b.data_type_) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (dst->IsEmpty()) {
|
|
|
|
|
dst->Init(src_a.width_, src_a.height_, src_a.channel_, src_a.data_type_);
|
|
|
|
|
} else if (src_a.width_ != dst->width_ || src_a.height_ != dst->height_ || src_a.channel_ != dst->channel_) {
|
|
|
|
|
return false;
|
|
|
|
|
} else if (src_a.data_type_ != dst->data_type_) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int64_t total_size = src_a.height_ * src_a.width_ * src_a.channel_;
|
|
|
|
|
if (src_a.data_type_ == LDataType::INT8) {
|
|
|
|
|
MultiplyImpl<int8_t>(src_a, src_b, *dst, total_size);
|
|
|
|
|
} else if (src_a.data_type_ == LDataType::UINT8) {
|
|
|
|
|
MultiplyImpl<uint8_t>(src_a, src_b, *dst, total_size);
|
|
|
|
|
} else if (src_a.data_type_ == LDataType::INT16) {
|
|
|
|
|
MultiplyImpl<int16_t>(src_a, src_b, *dst, total_size);
|
|
|
|
|
} else if (src_a.data_type_ == LDataType::UINT16) {
|
|
|
|
|
MultiplyImpl<uint16_t>(src_a, src_b, *dst, total_size);
|
|
|
|
|
} else if (src_a.data_type_ == LDataType::INT32) {
|
|
|
|
|
MultiplyImpl<int32_t>(src_a, src_b, *dst, total_size);
|
|
|
|
|
} else if (src_a.data_type_ == LDataType::UINT32) {
|
|
|
|
|
MultiplyImpl<uint32_t>(src_a, src_b, *dst, total_size);
|
|
|
|
|
} else if (src_a.data_type_ == LDataType::INT64) {
|
|
|
|
|
MultiplyImpl<int64_t>(src_a, src_b, *dst, total_size);
|
|
|
|
|
} else if (src_a.data_type_ == LDataType::UINT64) {
|
|
|
|
|
MultiplyImpl<uint64_t>(src_a, src_b, *dst, total_size);
|
|
|
|
|
} else if (src_a.data_type_ == LDataType::FLOAT32) {
|
|
|
|
|
MultiplyImpl<float>(src_a, src_b, *dst, total_size);
|
|
|
|
|
} else if (src_a.data_type_ == LDataType::FLOAT64) {
|
|
|
|
|
MultiplyImpl<double>(src_a, src_b, *dst, total_size);
|
|
|
|
|
} else {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace dataset
|
|
|
|
|
} // namespace mindspore
|
|
|
|
|