diff --git a/mindspore/ccsrc/minddata/dataset/api/vision.cc b/mindspore/ccsrc/minddata/dataset/api/vision.cc index 093d6726fd..0c50415e77 100644 --- a/mindspore/ccsrc/minddata/dataset/api/vision.cc +++ b/mindspore/ccsrc/minddata/dataset/api/vision.cc @@ -134,6 +134,9 @@ std::shared_ptr CenterCrop::Parse(const MapTargetDevice &env) { return std::make_shared(data_->size_); } +// RGB2GRAY Transform Operation. +std::shared_ptr RGB2GRAY::Parse() { return std::make_shared(); } + // Crop Transform Operation. struct Crop::Data { Data(const std::vector &coordinates, const std::vector &size) diff --git a/mindspore/ccsrc/minddata/dataset/include/vision_lite.h b/mindspore/ccsrc/minddata/dataset/include/vision_lite.h index 7f28dbe1eb..a24b46d9ba 100644 --- a/mindspore/ccsrc/minddata/dataset/include/vision_lite.h +++ b/mindspore/ccsrc/minddata/dataset/include/vision_lite.h @@ -91,6 +91,22 @@ class CenterCrop : public TensorTransform { std::shared_ptr data_; }; +/// \brief RGB2GRAY TensorTransform. +/// \notes Convert RGB image or color image to grayscale image +class RGB2GRAY : public TensorTransform { + public: + /// \brief Constructor. + RGB2GRAY() = default; + + /// \brief Destructor. + ~RGB2GRAY() = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; +}; + /// \brief Crop TensorTransform. /// \notes Crop an image based on location and crop size class Crop : public TensorTransform { diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt index 1edce28b9a..36bd1b3b5c 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt @@ -44,6 +44,7 @@ add_library(kernels-image OBJECT random_sharpness_op.cc rescale_op.cc resize_op.cc + rgb_to_gray_op.cc rgba_to_bgr_op.cc rgba_to_rgb_op.cc sharpness_op.cc diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc index ba99ae864e..8de575c107 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.cc @@ -1096,6 +1096,23 @@ Status RgbaToBgr(const std::shared_ptr &input, std::shared_ptr * } } +Status RgbToGray(const std::shared_ptr &input, std::shared_ptr *output) { + try { + std::shared_ptr input_cv = CVTensor::AsCVTensor(std::move(input)); + if (input_cv->Rank() != 3 || input_cv->shape()[2] != 3) { + RETURN_STATUS_UNEXPECTED("RgbToGray: image shape is not or channel is not 3."); + } + TensorShape out_shape = TensorShape({input_cv->shape()[0], input_cv->shape()[1]}); + std::shared_ptr output_cv; + RETURN_IF_NOT_OK(CVTensor::CreateEmpty(out_shape, input_cv->type(), &output_cv)); + cv::cvtColor(input_cv->mat(), output_cv->mat(), static_cast(cv::COLOR_RGB2GRAY)); + *output = std::static_pointer_cast(output_cv); + return Status::OK(); + } catch (const cv::Exception &e) { + RETURN_STATUS_UNEXPECTED("RgbToGray: " + std::string(e.what())); + } +} + Status GetJpegImageInfo(const std::shared_ptr &input, int *img_width, int *img_height) { struct jpeg_decompress_struct cinfo {}; struct JpegErrorManagerCustom jerr {}; diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h index 518cdee0ab..3a3ab405a7 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/image_utils.h @@ -293,6 +293,12 @@ Status RgbaToRgb(const std::shared_ptr &input, std::shared_ptr * /// \return Status code Status RgbaToBgr(const std::shared_ptr &input, std::shared_ptr *output); +/// \brief Take in a 3 channel image in RBG to GRAY +/// \param[in] input The input image +/// \param[out] output The output image +/// \return Status code +Status RgbToGray(const std::shared_ptr &input, std::shared_ptr *output); + /// \brief Get jpeg image width and height /// \param input: CVTensor containing the not decoded image 1D bytes /// \param img_width: the jpeg image width diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/lite_cv/image_process.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/lite_cv/image_process.cc index 64116e712c..057ec63d1b 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/lite_cv/image_process.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/lite_cv/image_process.cc @@ -1672,5 +1672,25 @@ bool GetAffineTransform(std::vector src_point, std::vector dst_poi return true; } +bool ConvertRgbToGray(const LiteMat &src, LDataType data_type, int w, int h, LiteMat &mat) { + if (data_type == LDataType::UINT8) { + if (mat.IsEmpty()) { + mat.Init(w, h, 1, LDataType::UINT8); + } + unsigned char *ptr = mat; + const unsigned char *data_ptr = src; + for (int y = 0; y < h; y++) { + for (int x = 0; x < w; x++) { + *ptr = (data_ptr[2] * B2GRAY + data_ptr[1] * G2GRAY + data_ptr[0] * R2GRAY + GRAYSHIFT_DELTA) >> GRAYSHIFT; + ptr++; + data_ptr += 3; + } + } + } else { + return false; + } + return true; +} + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/lite_cv/image_process.h b/mindspore/ccsrc/minddata/dataset/kernels/image/lite_cv/image_process.h index 5606bf5029..324b9dcd96 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/lite_cv/image_process.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/lite_cv/image_process.h @@ -137,6 +137,9 @@ bool ConvRowCol(const LiteMat &src, const LiteMat &kx, const LiteMat &ky, LiteMa /// \brief Filter the image by a Sobel kernel bool Sobel(const LiteMat &src, LiteMat &dst, int flag_x, int flag_y, int ksize, PaddBorderType pad_type); +/// \brief Convert RGB image or color image to grayscale image +bool ConvertRgbToGray(const LiteMat &src, LDataType data_type, int w, int h, LiteMat &mat); + } // namespace dataset } // namespace mindspore #endif // IMAGE_PROCESS_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/lite_image_utils.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/lite_image_utils.cc index 4e399e8413..66db262648 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/lite_image_utils.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/lite_image_utils.cc @@ -421,6 +421,40 @@ Status Resize(const std::shared_ptr &input, std::shared_ptr *out return Status::OK(); } +Status RgbToGray(const std::shared_ptr &input, std::shared_ptr *output) { + if (input->Rank() != 3) { + RETURN_STATUS_UNEXPECTED("RgbToGray: input image is not in shape of "); + } + if (input->type() != DataType::DE_UINT8) { + RETURN_STATUS_UNEXPECTED("RgbToGray: image datatype is not uint8."); + } + + try { + int output_height = input->shape()[0]; + int output_width = input->shape()[1]; + + LiteMat lite_mat_rgb(input->shape()[1], input->shape()[0], input->shape()[2], + const_cast(reinterpret_cast(input->GetBuffer())), + GetLiteCVDataType(input->type())); + LiteMat lite_mat_convert; + std::shared_ptr output_tensor; + TensorShape new_shape = TensorShape({output_height, output_width, 1}); + RETURN_IF_NOT_OK(Tensor::CreateEmpty(new_shape, input->type(), &output_tensor)); + uint8_t *buffer = reinterpret_cast(&(*output_tensor->begin())); + lite_mat_convert.Init(output_width, output_height, 1, reinterpret_cast(buffer), + GetLiteCVDataType(input->type())); + + bool ret = + ConvertRgbToGray(lite_mat_rgb, GetLiteCVDataType(input->type()), output_width, output_height, lite_mat_convert); + CHECK_FAIL_RETURN_UNEXPECTED(ret, "RgbToGray: RGBToGRAY failed."); + + *output = output_tensor; + } catch (std::runtime_error &e) { + RETURN_STATUS_UNEXPECTED("RgbToGray: " + std::string(e.what())); + } + return Status::OK(); +} + Status Pad(const std::shared_ptr &input, std::shared_ptr *output, const int32_t &pad_top, const int32_t &pad_bottom, const int32_t &pad_left, const int32_t &pad_right, const BorderType &border_types, uint8_t fill_r, uint8_t fill_g, uint8_t fill_b) { diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/lite_image_utils.h b/mindspore/ccsrc/minddata/dataset/kernels/image/lite_image_utils.h index c89fe321cf..c3dc9abdc2 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/lite_image_utils.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/lite_image_utils.h @@ -95,6 +95,12 @@ Status Resize(const std::shared_ptr &input, std::shared_ptr *out int32_t output_width, double fx = 0.0, double fy = 0.0, InterpolationMode mode = InterpolationMode::kLinear); +/// \brief Take in a 3 channel image in RBG to GRAY +/// \param[in] input The input image +/// \param[out] output The output image +/// \return Status code +Status RgbToGray(const std::shared_ptr &input, std::shared_ptr *output); + /// \brief Pads the input image and puts the padded image in the output /// \param[in] input: input Tensor /// \param[out] output: padded Tensor diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/rgb_to_gray_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/rgb_to_gray_op.cc new file mode 100644 index 0000000000..cec8e5dcb1 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/rgb_to_gray_op.cc @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/kernels/image/rgb_to_gray_op.h" +#ifndef ENABLE_ANDROID +#include "minddata/dataset/kernels/image/image_utils.h" +#else +#include "minddata/dataset/kernels/image/lite_image_utils.h" +#endif + +namespace mindspore { +namespace dataset { + +Status RgbToGrayOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + return RgbToGray(input, output); +} + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/rgb_to_gray_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/rgb_to_gray_op.h new file mode 100644 index 0000000000..6972e7e525 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/rgb_to_gray_op.h @@ -0,0 +1,42 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_RGB_TO_GRAY_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_RGB_TO_GRAY_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +class RgbToGrayOp : public TensorOp { + public: + RgbToGrayOp() = default; + + ~RgbToGrayOp() override = default; + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + std::string Name() const override { return kRgbToGrayOp; } +}; +} // namespace dataset +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_RGB_TO_GRAY_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/kernels/ir/vision/vision_ir.cc b/mindspore/ccsrc/minddata/dataset/kernels/ir/vision/vision_ir.cc index 771bbec289..e799277e07 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/ir/vision/vision_ir.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/ir/vision/vision_ir.cc @@ -72,6 +72,7 @@ #include "minddata/dataset/kernels/image/rgba_to_bgr_op.h" #include "minddata/dataset/kernels/image/rgba_to_rgb_op.h" #endif +#include "minddata/dataset/kernels/image/rgb_to_gray_op.h" #include "minddata/dataset/kernels/image/rotate_op.h" #ifndef ENABLE_ANDROID #include "minddata/dataset/kernels/image/soft_dvpp/soft_dvpp_decode_random_crop_resize_jpeg_op.h" @@ -232,6 +233,11 @@ Status CenterCropOperation::to_json(nlohmann::json *out_json) { return Status::OK(); } +// RGB2GRAYOperation +Status RgbToGrayOperation::ValidateParams() { return Status::OK(); } + +std::shared_ptr RgbToGrayOperation::Build() { return std::make_shared(); } + // CropOperation. CropOperation::CropOperation(std::vector coordinates, std::vector size) : coordinates_(coordinates), size_(size) {} diff --git a/mindspore/ccsrc/minddata/dataset/kernels/ir/vision/vision_ir.h b/mindspore/ccsrc/minddata/dataset/kernels/ir/vision/vision_ir.h index 953c0fecf7..b18ca219f5 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/ir/vision/vision_ir.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/ir/vision/vision_ir.h @@ -74,6 +74,7 @@ constexpr char kResizeOperation[] = "Resize"; constexpr char kResizeWithBBoxOperation[] = "ResizeWithBBox"; constexpr char kRgbaToBgrOperation[] = "RgbaToBgr"; constexpr char kRgbaToRgbOperation[] = "RgbaToRgb"; +constexpr char kRgbToGrayOperation[] = "RgbToGray"; constexpr char kRotateOperation[] = "Rotate"; constexpr char kSoftDvppDecodeRandomCropResizeJpegOperation[] = "SoftDvppDecodeRandomCropResizeJpeg"; constexpr char kSoftDvppDecodeResizeJpegOperation[] = "SoftDvppDecodeResizeJpeg"; @@ -163,6 +164,19 @@ class CenterCropOperation : public TensorOperation { std::vector size_; }; +class RgbToGrayOperation : public TensorOperation { + public: + RgbToGrayOperation() = default; + + ~RgbToGrayOperation() = default; + + std::shared_ptr Build() override; + + Status ValidateParams() override; + + std::string Name() const override { return kRgbToGrayOperation; } +}; + class CropOperation : public TensorOperation { public: CropOperation(std::vector coordinates, std::vector size); diff --git a/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h b/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h index 6a5a173d6a..b564376d76 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h @@ -97,6 +97,7 @@ constexpr char kResizeOp[] = "ResizeOp"; constexpr char kResizeWithBBoxOp[] = "ResizeWithBBoxOp"; constexpr char kRgbaToBgrOp[] = "RgbaToBgrOp"; constexpr char kRgbaToRgbOp[] = "RgbaToRgbOp"; +constexpr char kRgbToGrayOp[] = "RgbToGrayOp"; constexpr char kSharpnessOp[] = "SharpnessOp"; constexpr char kSolarizeOp[] = "SolarizeOp"; constexpr char kSwapRedBlueOp[] = "SwapRedBlueOp"; diff --git a/mindspore/ccsrc/minddata/dataset/liteapi/include/vision_lite.h b/mindspore/ccsrc/minddata/dataset/liteapi/include/vision_lite.h index 34c645dd94..6587e394d5 100644 --- a/mindspore/ccsrc/minddata/dataset/liteapi/include/vision_lite.h +++ b/mindspore/ccsrc/minddata/dataset/liteapi/include/vision_lite.h @@ -88,6 +88,22 @@ class CenterCrop : public TensorTransform { std::shared_ptr data_; }; +/// \brief RGB2GRAY TensorTransform. +/// \notes Convert RGB image or color image to grayscale image +class RGB2GRAY : public TensorTransform { + public: + /// \brief Constructor. + RGB2GRAY() = default; + + /// \brief Destructor. + ~RGB2GRAY() = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; +}; + /// \brief Crop TensorTransform. /// \notes Crop an image based on location and crop size class Crop : public TensorTransform { diff --git a/mindspore/lite/minddata/CMakeLists.txt b/mindspore/lite/minddata/CMakeLists.txt index 38bc9e3134..dd99a551f7 100644 --- a/mindspore/lite/minddata/CMakeLists.txt +++ b/mindspore/lite/minddata/CMakeLists.txt @@ -191,13 +191,14 @@ if(BUILD_MINDDATA STREQUAL "full") ${MINDDATA_DIR}/util/cond_var.cc ${MINDDATA_DIR}/engine/data_schema.cc ${MINDDATA_DIR}/kernels/tensor_op.cc + ${MINDDATA_DIR}/kernels/image/affine_op.cc ${MINDDATA_DIR}/kernels/image/lite_image_utils.cc ${MINDDATA_DIR}/kernels/image/center_crop_op.cc ${MINDDATA_DIR}/kernels/image/crop_op.cc ${MINDDATA_DIR}/kernels/image/decode_op.cc ${MINDDATA_DIR}/kernels/image/normalize_op.cc - ${MINDDATA_DIR}/kernels/image/affine_op.cc ${MINDDATA_DIR}/kernels/image/resize_op.cc + ${MINDDATA_DIR}/kernels/image/rgb_to_gray_op.cc ${MINDDATA_DIR}/kernels/image/rotate_op.cc ${MINDDATA_DIR}/kernels/image/random_affine_op.cc ${MINDDATA_DIR}/kernels/image/math_utils.cc @@ -279,6 +280,7 @@ elseif(BUILD_MINDDATA STREQUAL "wrapper") ${MINDDATA_DIR}/kernels/image/crop_op.cc ${MINDDATA_DIR}/kernels/image/normalize_op.cc ${MINDDATA_DIR}/kernels/image/resize_op.cc + ${MINDDATA_DIR}/kernels/image/rgb_to_gray_op.cc ${MINDDATA_DIR}/kernels/image/rotate_op.cc ${MINDDATA_DIR}/kernels/data/compose_op.cc ${MINDDATA_DIR}/kernels/data/duplicate_op.cc @@ -377,6 +379,7 @@ elseif(BUILD_MINDDATA STREQUAL "lite") "${MINDDATA_DIR}/kernels/image/random_vertical_flip_with_bbox_op.cc" "${MINDDATA_DIR}/kernels/image/random_sharpness_op.cc" "${MINDDATA_DIR}/kernels/image/rescale_op.cc" + "${MINDDATA_DIR}/kernels/image/rgb_to_gray_op.cc" "${MINDDATA_DIR}/kernels/image/rgba_to_bgr_op.cc" "${MINDDATA_DIR}/kernels/image/rgba_to_rgb_op.cc" "${MINDDATA_DIR}/kernels/image/sharpness_op.cc" diff --git a/tests/ut/cpp/dataset/c_api_vision_r_to_z_test.cc b/tests/ut/cpp/dataset/c_api_vision_r_to_z_test.cc index 5c6859051d..88715b71cd 100644 --- a/tests/ut/cpp/dataset/c_api_vision_r_to_z_test.cc +++ b/tests/ut/cpp/dataset/c_api_vision_r_to_z_test.cc @@ -195,3 +195,39 @@ TEST_F(MindDataTestPipeline, TestResizeWithBBoxSuccess) { // Manually terminate the pipeline iter->Stop(); } + +TEST_F(MindDataTestPipeline, TestRGB2GRAYSucess) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRGB2GRAYSucess."; + // Create an ImageFolder Dataset + std::string folder_path = datasets_root_path_ + "/testPK/data/"; + std::shared_ptr ds = ImageFolder(folder_path, true, std::make_shared(0, 1)); + EXPECT_NE(ds, nullptr); + + // Create objects for the tensor ops + std::shared_ptr convert(new mindspore::dataset::vision::RGB2GRAY()); + + ds = ds->Map({convert}); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the result of the above dataset + // This will trigger the creation of the Execution Tree and launch it. + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_NE(iter, nullptr); + + // Iterate the dataset and get each row + std::unordered_map row; + iter->GetNextRow(&row); + + uint64_t i = 0; + while (row.size() != 0) { + i++; + auto image = row["image"]; + MS_LOG(INFO) << "Tensor image shape: " << image.Shape(); + iter->GetNextRow(&row); + } + + EXPECT_EQ(i, 1); + + // Manually terminate the pipeline + iter->Stop(); +} diff --git a/tests/ut/cpp/dataset/image_process_test.cc b/tests/ut/cpp/dataset/image_process_test.cc index b8baea089f..437f41affc 100644 --- a/tests/ut/cpp/dataset/image_process_test.cc +++ b/tests/ut/cpp/dataset/image_process_test.cc @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - + #include "common/common.h" #include "lite_cv/lite_mat.h" #include "lite_cv/image_process.h" @@ -1714,3 +1714,25 @@ TEST_F(MindDataImageProcess, TestSobelFlag) { distance_x = sqrt(distance_x / total_size); EXPECT_EQ(distance_x, 0.0f); } + +TEST_F(MindDataImageProcess, testConvertRgbToGray) { + std::string filename = "data/dataset/apple.jpg"; + cv::Mat image = cv::imread(filename, cv::ImreadModes::IMREAD_COLOR); + cv::Mat rgb_mat; + cv::Mat rgb_mat1; + + cv::cvtColor(image, rgb_mat, CV_BGR2GRAY); + cv::imwrite("./opencv_image.jpg", rgb_mat); + + cv::cvtColor(image, rgb_mat1, CV_BGR2RGB); + + LiteMat lite_mat_rgb; + lite_mat_rgb.Init(rgb_mat1.cols, rgb_mat1.rows, rgb_mat1.channels(), rgb_mat1.data, LDataType::UINT8); + LiteMat lite_mat_gray; + bool ret = ConvertRgbToGray(lite_mat_rgb, LDataType::UINT8, image.cols, image.rows, lite_mat_gray); + ASSERT_TRUE(ret == true); + + cv::Mat dst_image(lite_mat_gray.height_, lite_mat_gray.width_, CV_8UC1, lite_mat_gray.data_ptr_); + cv::imwrite("./mindspore_image.jpg", dst_image); + CompareMat(rgb_mat, lite_mat_gray); +}