diff --git a/mindspore/ccsrc/minddata/dataset/api/vision.cc b/mindspore/ccsrc/minddata/dataset/api/vision.cc index 636528c7f1..2ca4a63160 100644 --- a/mindspore/ccsrc/minddata/dataset/api/vision.cc +++ b/mindspore/ccsrc/minddata/dataset/api/vision.cc @@ -23,8 +23,8 @@ #ifndef ENABLE_ANDROID #include "minddata/dataset/kernels/image/auto_contrast_op.h" #include "minddata/dataset/kernels/image/bounding_box_augment_op.h" -#include "minddata/dataset/kernels/image/center_crop_op.h" #endif +#include "minddata/dataset/kernels/image/center_crop_op.h" #include "minddata/dataset/kernels/image/crop_op.h" #ifndef ENABLE_ANDROID #include "minddata/dataset/kernels/image/cutmix_batch_op.h" @@ -94,6 +94,7 @@ std::shared_ptr BoundingBoxAugment(std::shared_ptr< // Input validation return op->ValidateParams() ? op : nullptr; } +#endif // Function to create CenterCropOperation. std::shared_ptr CenterCrop(std::vector size) { @@ -101,7 +102,6 @@ std::shared_ptr CenterCrop(std::vector size) { // Input validation return op->ValidateParams() ? op : nullptr; } -#endif // Function to create CropOperation. std::shared_ptr Crop(std::vector coordinates, std::vector size) { @@ -519,6 +519,7 @@ std::shared_ptr BoundingBoxAugmentOperation::Build() { std::shared_ptr tensor_op = std::make_shared(transform_->Build(), ratio_); return tensor_op; } +#endif // CenterCropOperation CenterCropOperation::CenterCropOperation(std::vector size) : size_(size) {} @@ -558,7 +559,6 @@ std::shared_ptr CenterCropOperation::Build() { return tensor_op; } -#endif // CropOperation. 
CropOperation::CropOperation(std::vector coordinates, std::vector size) : coordinates_(coordinates), size_(size) {} diff --git a/mindspore/ccsrc/minddata/dataset/include/vision.h b/mindspore/ccsrc/minddata/dataset/include/vision.h index 61131f663f..d16655c093 100644 --- a/mindspore/ccsrc/minddata/dataset/include/vision.h +++ b/mindspore/ccsrc/minddata/dataset/include/vision.h @@ -35,8 +35,8 @@ namespace vision { #ifndef ENABLE_ANDROID class AutoContrastOperation; class BoundingBoxAugmentOperation; -class CenterCropOperation; #endif +class CenterCropOperation; class CropOperation; #ifndef ENABLE_ANDROID class CutMixBatchOperation; @@ -96,6 +96,7 @@ std::shared_ptr AutoContrast(float cutoff = 0.0, std::vec /// \return Shared pointer to the current TensorOperation. std::shared_ptr BoundingBoxAugment(std::shared_ptr transform, float ratio = 0.3); +#endif /// \brief Function to create a CenterCrop TensorOperation. /// \notes Crops the input image at the center to the given size. @@ -104,7 +105,7 @@ std::shared_ptr BoundingBoxAugment(std::shared_ptr< /// If size has 2 values, it should be (height, width). /// \return Shared pointer to the current TensorOperation. std::shared_ptr CenterCrop(std::vector size); -#endif + /// \brief Function to create a Crop TensorOp /// \notes Crop an image based on location and crop size /// \param[in] coordinates Starting location of crop. 
Must be a vector of two values, in the form of {x_coor, y_coor} @@ -502,6 +503,8 @@ class BoundingBoxAugmentOperation : public TensorOperation { float ratio_; }; +#endif + class CenterCropOperation : public TensorOperation { public: explicit CenterCropOperation(std::vector size); @@ -515,7 +518,7 @@ class CenterCropOperation : public TensorOperation { private: std::vector size_; }; -#endif + class CropOperation : public TensorOperation { public: CropOperation(std::vector coordinates, std::vector size); diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/center_crop_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/center_crop_op.cc index a27b2cb000..57160ba70c 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/center_crop_op.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/center_crop_op.cc @@ -16,8 +16,13 @@ #include "minddata/dataset/kernels/image/center_crop_op.h" #include #include "utils/ms_utils.h" -#include "minddata/dataset/core/cv_tensor.h" + +#ifndef ENABLE_ANDROID #include "minddata/dataset/kernels/image/image_utils.h" +#else +#include "minddata/dataset/kernels/image/lite_image_utils.h" +#endif + #include "minddata/dataset/util/status.h" namespace mindspore { diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/lite_cv/lite_mat.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/lite_cv/lite_mat.cc index cb4ddce123..eef774fa88 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/lite_cv/lite_mat.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/lite_cv/lite_mat.cc @@ -73,6 +73,20 @@ LiteMat::LiteMat(int width, int height, int channel, LDataType data_type) { Init(width, height, channel, data_type); } +LiteMat::LiteMat(int width, int height, int channel, void *p_data, LDataType data_type) { + data_type_ = data_type; + InitElemSize(data_type); + width_ = width; + height_ = height; + dims_ = 3; + channel_ = channel; + c_step_ = height_ * width_; + size_ = c_step_ * channel_ * elem_size_; + data_ptr_ = p_data; + 
ref_count_ = new int[1]; + *ref_count_ = 0; +} + LiteMat::~LiteMat() { Release(); } int LiteMat::addRef(int *p, int value) { diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/lite_cv/lite_mat.h b/mindspore/ccsrc/minddata/dataset/kernels/image/lite_cv/lite_mat.h index ce0ff058ec..1e0705ec8f 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/lite_cv/lite_mat.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/lite_cv/lite_mat.h @@ -195,6 +195,8 @@ class LiteMat { LiteMat(int width, int height, int channel, LDataType data_type = LDataType::UINT8); + LiteMat(int width, int height, int channel, void *p_data, LDataType data_type = LDataType::UINT8); + ~LiteMat(); LiteMat(const LiteMat &m); diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/lite_image_utils.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/lite_image_utils.cc index a034325b8a..479cd36754 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/lite_image_utils.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/lite_image_utils.cc @@ -229,6 +229,11 @@ Status Crop(const std::shared_ptr &input, std::shared_ptr *outpu if (input->Rank() != 3 && input->Rank() != 2) { RETURN_STATUS_UNEXPECTED("Shape not or "); } + + if (input->type() != DataType::DE_FLOAT32 && input->type() != DataType::DE_UINT8) { /* reject only tensors that are neither float32 nor uint8; with || this test was always true */ + RETURN_STATUS_UNEXPECTED("Only float32, uint8 support in Crop"); + } + // account for integer overflow if (y < 0 || (y + h) > input->shape()[0] || (y + h) < 0) { RETURN_STATUS_UNEXPECTED("Invalid y coordinate value for crop"); @@ -237,18 +242,17 @@ Status Crop(const std::shared_ptr &input, std::shared_ptr *outpu if (x < 0 || (x + w) > input->shape()[1] || (x + w) < 0) { RETURN_STATUS_UNEXPECTED("Invalid x coordinate value for crop"); } - // convert to lite Mat - LiteMat lite_mat_rgb; - // rows = height, this constructor takes: cols,rows - bool ret = InitFromPixel(input->GetBuffer(), LPixelType::RGB, LDataType::UINT8, input->shape()[1], input->shape()[0], - 
CHECK_FAIL_RETURN_UNEXPECTED(ret, "Creation of lite cv failed"); + + LiteMat lite_mat_rgb(input->shape()[1], input->shape()[0], 3, + const_cast(reinterpret_cast(input->GetBuffer())), LDataType::UINT8); + try { TensorShape shape{h, w}; int num_channels = input->shape()[2]; if (input->Rank() == 3) shape = shape.AppendDim(num_channels); LiteMat lite_mat_cut; - ret = Crop(lite_mat_rgb, lite_mat_cut, x, y, x + w, y + h); + + bool ret = Crop(lite_mat_rgb, lite_mat_cut, x, y, x + w, y + h); CHECK_FAIL_RETURN_UNEXPECTED(ret, "Crop failed in lite cv"); // create output Tensor based off of lite_mat_cut std::shared_ptr output_tensor; @@ -287,14 +291,17 @@ Status Normalize(const std::shared_ptr &input, std::shared_ptr * if (input->Rank() != 3) { RETURN_STATUS_UNEXPECTED("Input tensor rank isn't 3"); } - LiteMat lite_mat_rgb; - // rows = height, this constructor takes: cols,rows - bool ret = InitFromPixel(input->GetBuffer(), LPixelType::RGB, LDataType::UINT8, input->shape()[1], input->shape()[0], - lite_mat_rgb); - CHECK_FAIL_RETURN_UNEXPECTED(ret, "Creation of lite cv failed"); + + if (input->type() != DataType::DE_UINT8) { + RETURN_STATUS_UNEXPECTED("Only uint8 support in Normalize"); + } + + LiteMat lite_mat_rgb(input->shape()[1], input->shape()[0], 3, + const_cast(reinterpret_cast(input->GetBuffer())), LDataType::UINT8); + LiteMat lite_mat_float; // change input to float - ret = ConvertTo(lite_mat_rgb, lite_mat_float, 1.0); + bool ret = ConvertTo(lite_mat_rgb, lite_mat_float, 1.0); CHECK_FAIL_RETURN_UNEXPECTED(ret, "Conversion of lite cv to float failed"); mean->Squeeze(); @@ -337,6 +344,9 @@ Status Resize(const std::shared_ptr &input, std::shared_ptr *out if (input->Rank() != 3) { RETURN_STATUS_UNEXPECTED("Input Tensor is not in shape of "); } + if (input->type() != DataType::DE_UINT8) { + RETURN_STATUS_UNEXPECTED("Only uint8 support in Resize"); + } // resize image too large or too small if (output_height == 0 || output_height > input->shape()[0] * 1000 || output_width == 
0 || output_width > input->shape()[1] * 1000) { @@ -345,10 +355,8 @@ Status Resize(const std::shared_ptr &input, std::shared_ptr *out "1000 times the original image; 2) can not be 0."; return Status(StatusCode::kShapeMisMatch, err_msg); } - LiteMat lite_mat_rgb; - bool ret = InitFromPixel(input->GetBuffer(), LPixelType::RGB, LDataType::UINT8, input->shape()[1], input->shape()[0], - lite_mat_rgb); - CHECK_FAIL_RETURN_UNEXPECTED(ret, "Creation of lite cv failed"); + LiteMat lite_mat_rgb(input->shape()[1], input->shape()[0], 3, + const_cast(reinterpret_cast(input->GetBuffer())), LDataType::UINT8); try { TensorShape shape{output_height, output_width}; @@ -356,7 +364,7 @@ Status Resize(const std::shared_ptr &input, std::shared_ptr *out if (input->Rank() == 3) shape = shape.AppendDim(num_channels); LiteMat lite_mat_resize; - ret = ResizeBilinear(lite_mat_rgb, lite_mat_resize, output_width, output_height); + bool ret = ResizeBilinear(lite_mat_rgb, lite_mat_resize, output_width, output_height); CHECK_FAIL_RETURN_UNEXPECTED(ret, "Resize failed in lite cv"); std::shared_ptr output_tensor; RETURN_IF_NOT_OK( @@ -368,5 +376,39 @@ Status Resize(const std::shared_ptr &input, std::shared_ptr *out return Status::OK(); } +Status Pad(const std::shared_ptr &input, std::shared_ptr *output, const int32_t &pad_top, + const int32_t &pad_bottom, const int32_t &pad_left, const int32_t &pad_right, const BorderType &border_types, + uint8_t fill_r, uint8_t fill_g, uint8_t fill_b) { + if (input->Rank() != 3) { + RETURN_STATUS_UNEXPECTED("Input Tensor is not in shape of "); + } + + if (input->type() != DataType::DE_FLOAT32 && input->type() != DataType::DE_UINT8) { /* reject only tensors that are neither float32 nor uint8; with || this test was always true */ + RETURN_STATUS_UNEXPECTED("Only float32, uint8 support in Pad"); + } + + if (pad_top <= 0 || pad_bottom <= 0 || pad_left <= 0 || pad_right <= 0) { + RETURN_STATUS_UNEXPECTED("The pad, top, bottom, left, right must be greater than 0"); + } + + try { + LiteMat lite_mat_rgb(input->shape()[1], input->shape()[0], 3, + 
const_cast(reinterpret_cast(input->GetBuffer())), input->type() != DataType::DE_FLOAT32 ? LDataType::UINT8 : LDataType::FLOAT32); /* wrap the tensor buffer with the element type it really holds */ + + LiteMat lite_mat_pad; + bool ret = Pad(lite_mat_rgb, lite_mat_pad, pad_top, pad_bottom, pad_left, pad_right, + PaddBorderType::PADD_BORDER_CONSTANT, fill_r, fill_g, fill_b); /* border_types is not consulted: a constant border is always requested */ + CHECK_FAIL_RETURN_UNEXPECTED(ret, "Pad failed in lite cv"); + + std::shared_ptr output_tensor; /* the output takes the padded height/width and the input's element type, not the input shape */ + RETURN_IF_NOT_OK(Tensor::CreateFromMemory(TensorShape{lite_mat_pad.height_, lite_mat_pad.width_, input->shape()[2]}, input->type(), + static_cast(lite_mat_pad.data_ptr_), &output_tensor)); + *output = output_tensor; + } catch (std::runtime_error &e) { + RETURN_STATUS_UNEXPECTED("Error in image Pad."); + } + return Status::OK(); +} + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/lite_image_utils.h b/mindspore/ccsrc/minddata/dataset/kernels/image/lite_image_utils.h index 53ce291220..e21e5a47d4 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/lite_image_utils.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/lite_image_utils.h @@ -95,6 +95,20 @@ Status Resize(const std::shared_ptr &input, std::shared_ptr *out int32_t output_width, double fx = 0.0, double fy = 0.0, InterpolationMode mode = InterpolationMode::kLinear); +/// \brief Pads a rank-3 float32 or uint8 image tensor with a constant-colour border and stores the result in output +/// \param input: input Tensor +/// \param output: padded Tensor; its height and width are those of the input plus the requested padding +/// \param pad_top: amount of padding done in top, must be greater than 0 +/// \param pad_bottom: amount of padding done in bottom, must be greater than 0 +/// \param pad_left: amount of padding done in left, must be greater than 0 +/// \param pad_right: amount of padding done in right, must be greater than 0 +/// \param border_types: requested border mode; currently ignored, a constant border is always used +/// \param fill_r: red fill value for pad +/// \param fill_g: green fill value for pad +/// \param fill_b: blue fill value for pad. 
+Status Pad(const std::shared_ptr &input, std::shared_ptr *output, const int32_t &pad_top, + const int32_t &pad_bottom, const int32_t &pad_left, const int32_t &pad_right, const BorderType &border_types, + uint8_t fill_r = 0, uint8_t fill_g = 0, uint8_t fill_b = 0); } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_IMAGE_UTILS_H_ diff --git a/mindspore/lite/minddata/CMakeLists.txt b/mindspore/lite/minddata/CMakeLists.txt index 5fbe263b4f..26eb09c841 100644 --- a/mindspore/lite/minddata/CMakeLists.txt +++ b/mindspore/lite/minddata/CMakeLists.txt @@ -144,7 +144,6 @@ if (BUILD_MINDDATA STREQUAL "full") "${MINDDATA_DIR}/kernels/image/auto_contrast_op.cc" "${MINDDATA_DIR}/kernels/image/bounding_box_op.cc" "${MINDDATA_DIR}/kernels/image/bounding_box_augment_op.cc" - "${MINDDATA_DIR}/kernels/image/center_crop_op.cc" "${MINDDATA_DIR}/kernels/image/concatenate_op.cc" "${MINDDATA_DIR}/kernels/image/cut_out_op.cc" "${MINDDATA_DIR}/kernels/image/cutmix_batch_op.cc" diff --git a/tests/ut/cpp/dataset/image_process_test.cc b/tests/ut/cpp/dataset/image_process_test.cc index 9feac292c9..b0a98f1885 100644 --- a/tests/ut/cpp/dataset/image_process_test.cc +++ b/tests/ut/cpp/dataset/image_process_test.cc @@ -107,6 +107,53 @@ TEST_F(MindDataImageProcess, testRGB) { cv::Mat dst_image(lite_mat_rgb.height_, lite_mat_rgb.width_, CV_8UC3, lite_mat_rgb.data_ptr_); } +TEST_F(MindDataImageProcess, testLoadByMemPtr) { + std::string filename = "data/dataset/apple.jpg"; + cv::Mat image = cv::imread(filename, cv::ImreadModes::IMREAD_COLOR); + + cv::Mat rgba_mat; + cv::cvtColor(image, rgba_mat, CV_BGR2RGB); + + bool ret = false; + int width = rgba_mat.cols; + int height = rgba_mat.rows; + uchar *p_rgb = (uchar *)malloc(width * height * 3 * sizeof(uchar)); + for (int i = 0; i < height; i++) { + const uchar *current = rgba_mat.ptr(i); + for (int j = 0; j < width; j++) { + p_rgb[i * width * 3 + 3 * j + 0] = current[3 * j + 0]; + p_rgb[i * width * 3 
+ 3 * j + 1] = current[3 * j + 1]; + p_rgb[i * width * 3 + 3 * j + 2] = current[3 * j + 2]; + } + } + + LiteMat lite_mat_rgb(width, height, 3, (void *)p_rgb, LDataType::UINT8); + LiteMat lite_mat_resize; + ret = ResizeBilinear(lite_mat_rgb, lite_mat_resize, 256, 256); + ASSERT_TRUE(ret == true); + LiteMat lite_mat_convert_float; + ret = ConvertTo(lite_mat_resize, lite_mat_convert_float, 1.0); + ASSERT_TRUE(ret == true); + + LiteMat lite_mat_crop; + ret = Crop(lite_mat_convert_float, lite_mat_crop, 16, 16, 224, 224); + ASSERT_TRUE(ret == true); + std::vector means = {0.485, 0.456, 0.406}; + std::vector stds = {0.229, 0.224, 0.225}; + LiteMat lite_norm_mat_cut; + ret = SubStractMeanNormalize(lite_mat_crop, lite_norm_mat_cut, means, stds); + + int pad_width = lite_norm_mat_cut.width_ + 40 + 10; /* left + right padding passed to Pad below */ + int pad_height = lite_norm_mat_cut.height_ + 10 + 30; /* top + bottom padding passed to Pad below */ + float *p_rgb_pad = (float *)malloc(pad_width * pad_height * 3 * sizeof(float)); + + LiteMat makeborder(pad_width, pad_height, 3, (void *)p_rgb_pad, LDataType::FLOAT32); + ret = Pad(lite_norm_mat_cut, makeborder, 10, 30, 40, 10, PaddBorderType::PADD_BORDER_CONSTANT, 255, 255, 255); + cv::Mat dst_image(pad_height, pad_width, CV_32FC3, p_rgb_pad); /* the buffer holds three floats per pixel */ + free(p_rgb); + free(p_rgb_pad); +} + TEST_F(MindDataImageProcess, test3C) { std::string filename = "data/dataset/apple.jpg"; cv::Mat image = cv::imread(filename, cv::ImreadModes::IMREAD_COLOR); @@ -512,8 +559,7 @@ TEST_F(MindDataImageProcess, TestSubtractInt8) { LiteMat dst_int8; EXPECT_TRUE(Subtract(src1_int8, src2_int8, dst_int8)); for (size_t i = 0; i < cols; i++) { - EXPECT_EQ(static_cast(expect_int8.data_ptr_)[i].c1, - static_cast(dst_int8.data_ptr_)[i].c1); + EXPECT_EQ(static_cast(expect_int8.data_ptr_)[i].c1, static_cast(dst_int8.data_ptr_)[i].c1); } } @@ -645,8 +691,7 @@ TEST_F(MindDataImageProcess, TestDivideInt8) { LiteMat dst_int8; EXPECT_TRUE(Divide(src1_int8, src2_int8, dst_int8)); for (size_t i = 0; i < cols; i++) { - EXPECT_EQ(static_cast(expect_int8.data_ptr_)[i].c1, - 
static_cast(dst_int8.data_ptr_)[i].c1); + EXPECT_EQ(static_cast(expect_int8.data_ptr_)[i].c1, static_cast(dst_int8.data_ptr_)[i].c1); } }