diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/lite_cv/lite_mat.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/lite_cv/lite_mat.cc
index ee36d3446c..6c0f46b872 100644
--- a/mindspore/ccsrc/minddata/dataset/kernels/image/lite_cv/lite_mat.cc
+++ b/mindspore/ccsrc/minddata/dataset/kernels/image/lite_cv/lite_mat.cc
@@ -456,9 +456,7 @@ bool Divide(const LiteMat &src_a, const LiteMat &src_b, LiteMat *dst) {
   }
 
   int64_t total_size = src_a.height_ * src_a.width_ * src_a.channel_;
-  if (src_a.data_type_ == LDataType::BOOL) {
-    DivideImpl<bool>(src_a, src_b, *dst, total_size);
-  } else if (src_a.data_type_ == LDataType::INT8) {
+  if (src_a.data_type_ == LDataType::INT8) {
     DivideImpl<int8_t>(src_a, src_b, *dst, total_size);
   } else if (src_a.data_type_ == LDataType::UINT8) {
     DivideImpl<uint8_t>(src_a, src_b, *dst, total_size);
@@ -484,5 +482,122 @@ bool Divide(const LiteMat &src_a, const LiteMat &src_b, LiteMat *dst) {
   return true;
 }
 
+/// \brief Element-wise product of two buffers; the product is stored as-is, without clamping.
+/// \param[in] src0 First input buffer, total_size elements.
+/// \param[in] src1 Second input buffer, total_size elements.
+/// \param[out] dst Output buffer, total_size elements.
+/// \param[in] total_size Number of elements to process.
+template <typename T>
+inline void MultiplyImpl(const T *src0, const T *src1, T *dst, int64_t total_size) {
+  for (int64_t i = 0; i < total_size; i++) {
+    dst[i] = src0[i] * src1[i];
+  }
+}
+
+/// \brief uint8 specialization: products are clamped to 255 instead of wrapping.
+template <>
+inline void MultiplyImpl(const uint8_t *src0, const uint8_t *src1, uint8_t *dst, int64_t total_size) {
+  int64_t x = 0;
+#ifdef USE_NEON
+  // Process 32 elements per iteration; the scalar loop below handles the remainder.
+  const int64_t step = 32;
+  for (; x <= total_size - step; x += step) {
+    uint8x16_t v_src00 = vld1q_u8(src0 + x);
+    uint8x16_t v_src01 = vld1q_u8(src0 + x + 16);
+    uint8x16_t v_src10 = vld1q_u8(src1 + x);
+    uint8x16_t v_src11 = vld1q_u8(src1 + x + 16);
+    // vmull_u8 widens each product to 16 bits, so the temporaries must be uint16x8_t.
+    uint16x8_t v_dst_l, v_dst_h;
+
+    v_dst_l = vmull_u8(vget_low_u8(v_src00), vget_low_u8(v_src10));
+    v_dst_h = vmull_u8(vget_high_u8(v_src00), vget_high_u8(v_src10));
+    // vqmovn_u16 narrows back to 8 bits with saturation.
+    vst1q_u8(dst + x, vcombine_u8(vqmovn_u16(v_dst_l), vqmovn_u16(v_dst_h)));
+
+    v_dst_l = vmull_u8(vget_low_u8(v_src01), vget_low_u8(v_src11));
+    v_dst_h = vmull_u8(vget_high_u8(v_src01), vget_high_u8(v_src11));
+    vst1q_u8(dst + x + 16, vcombine_u8(vqmovn_u16(v_dst_l), vqmovn_u16(v_dst_h)));
+  }
+#endif
+  for (; x < total_size; x++) {
+    int32_t val = src0[x] * src1[x];
+    dst[x] = static_cast<uint8_t>(std::min<int32_t>(std::numeric_limits<uint8_t>::max(), val));
+  }
+}
+
+/// \brief uint16 specialization: products are clamped to 65535 instead of wrapping.
+template <>
+inline void MultiplyImpl(const uint16_t *src0, const uint16_t *src1, uint16_t *dst, int64_t total_size) {
+  for (int64_t i = 0; i < total_size; i++) {
+    // Widen to uint32_t first: uint16_t operands promote to int, and 65535 * 65535 overflows it.
+    uint32_t val = static_cast<uint32_t>(src0[i]) * src1[i];
+    dst[i] = static_cast<uint16_t>(std::min<uint32_t>(std::numeric_limits<uint16_t>::max(), val));
+  }
+}
+
+/// \brief uint32 specialization: products are clamped to UINT32_MAX instead of wrapping.
+template <>
+inline void MultiplyImpl(const uint32_t *src0, const uint32_t *src1, uint32_t *dst, int64_t total_size) {
+  for (int64_t i = 0; i < total_size; i++) {
+    // Multiply in 64 bits; a 32-bit multiply would wrap before the clamp could see the overflow.
+    uint64_t val = static_cast<uint64_t>(src0[i]) * src1[i];
+    dst[i] = static_cast<uint32_t>(std::min<uint64_t>(std::numeric_limits<uint32_t>::max(), val));
+  }
+}
+
+/// \brief Multiplies two images element by element and stores the result in dst.
+/// \param[in] src_a First input image.
+/// \param[in] src_b Second input image; must match src_a in width, height, channel and data type.
+/// \param[out] dst Output image; initialized from src_a's shape and type when empty, otherwise it must match.
+/// \return true on success; false on a null dst, a shape/type mismatch or an unsupported data type.
+bool Multiply(const LiteMat &src_a, const LiteMat &src_b, LiteMat *dst) {
+  if (dst == nullptr) {
+    return false;
+  }
+
+  if (src_a.width_ != src_b.width_ || src_a.height_ != src_b.height_ || src_a.channel_ != src_b.channel_) {
+    return false;
+  }
+
+  if (src_a.data_type_ != src_b.data_type_) {
+    return false;
+  }
+
+  if (dst->IsEmpty()) {
+    dst->Init(src_a.width_, src_a.height_, src_a.channel_, src_a.data_type_);
+  } else if (src_a.width_ != dst->width_ || src_a.height_ != dst->height_ || src_a.channel_ != dst->channel_) {
+    return false;
+  } else if (src_a.data_type_ != dst->data_type_) {
+    return false;
+  }
+
+  // Widen before multiplying so the element count is computed in 64 bits.
+  int64_t total_size = static_cast<int64_t>(src_a.height_) * src_a.width_ * src_a.channel_;
+  if (src_a.data_type_ == LDataType::INT8) {
+    MultiplyImpl<int8_t>(src_a, src_b, *dst, total_size);
+  } else if (src_a.data_type_ == LDataType::UINT8) {
+    MultiplyImpl<uint8_t>(src_a, src_b, *dst, total_size);
+  } else if (src_a.data_type_ == LDataType::INT16) {
+    MultiplyImpl<int16_t>(src_a, src_b, *dst, total_size);
+  } else if (src_a.data_type_ == LDataType::UINT16) {
+    MultiplyImpl<uint16_t>(src_a, src_b, *dst, total_size);
+  } else if (src_a.data_type_ == LDataType::INT32) {
+    MultiplyImpl<int32_t>(src_a, src_b, *dst, total_size);
+  } else if (src_a.data_type_ == LDataType::UINT32) {
+    MultiplyImpl<uint32_t>(src_a, src_b, *dst, total_size);
+  } else if (src_a.data_type_ == LDataType::INT64) {
+    MultiplyImpl<int64_t>(src_a, src_b, *dst, total_size);
+  } else if (src_a.data_type_ == LDataType::UINT64) {
+    MultiplyImpl<uint64_t>(src_a, src_b, *dst, total_size);
+  } else if (src_a.data_type_ == LDataType::FLOAT32) {
+    MultiplyImpl<float>(src_a, src_b, *dst, total_size);
+  } else if (src_a.data_type_ == LDataType::FLOAT64) {
+    MultiplyImpl<double>(src_a, src_b, *dst, total_size);
+  } else {
+    return false;
+  }
+  return true;
+}
+
 }  // namespace dataset
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/lite_cv/lite_mat.h b/mindspore/ccsrc/minddata/dataset/kernels/image/lite_cv/lite_mat.h
index 453a1d6e28..bbeedd9714 100644
--- a/mindspore/ccsrc/minddata/dataset/kernels/image/lite_cv/lite_mat.h
+++ b/mindspore/ccsrc/minddata/dataset/kernels/image/lite_cv/lite_mat.h
@@ -254,6 +254,9 @@ bool Subtract(const LiteMat &src_a, const LiteMat &src_b, LiteMat *dst);
 /// \brief Calculates the division between the two images for each element
 bool Divide(const LiteMat &src_a, const LiteMat &src_b, LiteMat *dst);
 
+/// \brief Calculates the element-wise product of the two images
+bool Multiply(const LiteMat &src_a, const LiteMat &src_b, LiteMat *dst);
+
 }  // namespace dataset
 }  // namespace mindspore
 #endif  // MINI_MAT_H_
diff --git a/tests/ut/cpp/dataset/image_process_test.cc b/tests/ut/cpp/dataset/image_process_test.cc
index 717d494770..a460f02a44 100644
--- a/tests/ut/cpp/dataset/image_process_test.cc
+++ b/tests/ut/cpp/dataset/image_process_test.cc
@@ -789,3 +789,60 @@ TEST_F(MindDataImageProcess, TestDivideFloat) {
                     static_cast<FLOAT32_C1 *>(dst_float.data_ptr_)[i].c1);
   }
 }
+
+TEST_F(MindDataImageProcess, TestMultiplyUint8) {
+  const size_t cols = 4;
+  // Test uint8
+  LiteMat src1_uint8(1, cols);
+  LiteMat src2_uint8(1, cols);
+  LiteMat expect_uint8(1, cols);
+  for (size_t i = 0; i < cols; i++) {
+    static_cast<UINT8_C1 *>(src1_uint8.data_ptr_)[i] = 8;
+    static_cast<UINT8_C1 *>(src2_uint8.data_ptr_)[i] = 4;
+    static_cast<UINT8_C1 *>(expect_uint8.data_ptr_)[i] = 32;
+  }
+  LiteMat dst_uint8;
+  EXPECT_TRUE(Multiply(src1_uint8, src2_uint8, &dst_uint8));
+  for (size_t i = 0; i < cols; i++) {
+    EXPECT_EQ(static_cast<UINT8_C1 *>(expect_uint8.data_ptr_)[i].c1,
+              static_cast<UINT8_C1 *>(dst_uint8.data_ptr_)[i].c1);
+  }
+}
+
+TEST_F(MindDataImageProcess, TestMultiplyUInt16) {
+  const size_t cols = 4;
+  // Test uint16: 60000 * 2 exceeds the type's range, so the result must saturate at 65535
+  LiteMat src1_uint16(1, cols, LDataType(LDataType::UINT16));
+  LiteMat src2_uint16(1, cols, LDataType(LDataType::UINT16));
+  LiteMat expect_uint16(1, cols, LDataType(LDataType::UINT16));
+  for (size_t i = 0; i < cols; i++) {
+    static_cast<UINT16_C1 *>(src1_uint16.data_ptr_)[i] = 60000;
+    static_cast<UINT16_C1 *>(src2_uint16.data_ptr_)[i] = 2;
+    static_cast<UINT16_C1 *>(expect_uint16.data_ptr_)[i] = 65535;
+  }
+  LiteMat dst_uint16;
+  EXPECT_TRUE(Multiply(src1_uint16, src2_uint16, &dst_uint16));
+  for (size_t i = 0; i < cols; i++) {
+    EXPECT_EQ(static_cast<UINT16_C1 *>(expect_uint16.data_ptr_)[i].c1,
+              static_cast<UINT16_C1 *>(dst_uint16.data_ptr_)[i].c1);
+  }
+}
+
+TEST_F(MindDataImageProcess, TestMultiplyFloat) {
+  const size_t cols = 4;
+  // Test float
+  LiteMat src1_float(1, cols, LDataType(LDataType::FLOAT32));
+  LiteMat src2_float(1, cols, LDataType(LDataType::FLOAT32));
+  LiteMat expect_float(1, cols, LDataType(LDataType::FLOAT32));
+  for (size_t i = 0; i < cols; i++) {
+    static_cast<FLOAT32_C1 *>(src1_float.data_ptr_)[i] = 30.0f;
+    static_cast<FLOAT32_C1 *>(src2_float.data_ptr_)[i] = -2.0f;
+    static_cast<FLOAT32_C1 *>(expect_float.data_ptr_)[i] = -60.0f;
+  }
+  LiteMat dst_float;
+  EXPECT_TRUE(Multiply(src1_float, src2_float, &dst_float));
+  for (size_t i = 0; i < cols; i++) {
+    EXPECT_FLOAT_EQ(static_cast<FLOAT32_C1 *>(expect_float.data_ptr_)[i].c1,
+                    static_cast<FLOAT32_C1 *>(dst_float.data_ptr_)[i].c1);
+  }
+}