Add resizefiller

pull/13279/head
shenwei41 4 years ago
parent 7454ac8ecd
commit d1bb9d470e

@ -735,6 +735,22 @@ std::shared_ptr<TensorOperation> Resize::Parse(const MapTargetDevice &env) {
return std::make_shared<ResizeOperation>(data_->size_, data_->interpolation_);
}
// ResizePreserveAR Transform Operation.
struct ResizePreserveAR::Data {
Data(int32_t height, int32_t width, int32_t img_orientation)
: height_(height), width_(width), img_orientation_(img_orientation) {}
int32_t height_;
int32_t width_;
int32_t img_orientation_;
};
ResizePreserveAR::ResizePreserveAR(int32_t height, int32_t width, int32_t img_orientation)
: data_(std::make_shared<Data>(height, width, img_orientation)) {}
std::shared_ptr<TensorOperation> ResizePreserveAR::Parse() {
return std::make_shared<ResizePreserveAROperation>(data_->height_, data_->width_, data_->img_orientation_);
}
#ifdef ENABLE_ANDROID
// Rotate Transform Operation.
Rotate::Rotate() {}

@ -93,7 +93,7 @@ class CenterCrop final : public TensorTransform {
/// \brief RGB2GRAY TensorTransform.
/// \notes Convert RGB image or color image to grayscale image
class RGB2GRAY : public TensorTransform {
class RGB2GRAY final : public TensorTransform {
public:
/// \brief Constructor.
RGB2GRAY() = default;
@ -244,6 +244,29 @@ class Resize final : public TensorTransform {
std::shared_ptr<Data> data_;
};
/// \brief ResizePreserveAR TensorTransform.
/// \notes Keep the original picture ratio and fill the rest.
class ResizePreserveAR final : public TensorTransform {
public:
/// \brief Constructor.
/// \param[in] height The height of image output value after resizing.
/// \param[in] width The width of image output value after resizing.
/// \param[in] img_orientation Angle method of image rotation.
ResizePreserveAR(int32_t height, int32_t width, int32_t img_orientation = 0);
/// \brief Destructor.
~ResizePreserveAR() = default;
protected:
/// \brief Function to convert TensorTransform object into a TensorOperation object.
/// \return Shared pointer to TensorOperation object.
std::shared_ptr<TensorOperation> Parse() override;
private:
struct Data;
std::shared_ptr<Data> data_;
};
/// \brief Rotate TensorTransform.
/// \notes Rotate the input image using a specified angle id.
class Rotate final : public TensorTransform {

@ -44,6 +44,7 @@ add_library(kernels-image OBJECT
random_sharpness_op.cc
rescale_op.cc
resize_op.cc
resize_preserve_ar_op.cc
rgb_to_gray_op.cc
rgba_to_bgr_op.cc
rgba_to_rgb_op.cc

@ -28,6 +28,19 @@ namespace mindspore {
namespace dataset {
#define CV_PI 3.1415926535897932384626433832795
#define IM_TOOL_EXIF_ORIENTATION_0_DEG 1
#define IM_TOOL_EXIF_ORIENTATION_0_DEG_MIRROR 2
#define IM_TOOL_EXIF_ORIENTATION_180_DEG 3
#define IM_TOOL_EXIF_ORIENTATION_180_DEG_MIRROR 4
#define IM_TOOL_EXIF_ORIENTATION_90_DEG_MIRROR 5
#define IM_TOOL_EXIF_ORIENTATION_90_DEG 6
#define IM_TOOL_EXIF_ORIENTATION_270_DEG_MIRROR 7
#define IM_TOOL_EXIF_ORIENTATION_270_DEG 8
#define NUM_OF_RGB_CHANNELS 9
#define IM_TOOL_DATA_TYPE_FLOAT (1)
#define IM_TOOL_DATA_TYPE_UINT8 (2)
#define IM_TOOL_RETURN_STATUS_SUCCESS (0)
#define IM_TOOL_RETURN_STATUS_INVALID_INPUT (1)
#define INT16_CAST(X) \
static_cast<int16_t>(::std::min(::std::max(static_cast<int>(X + (X >= 0.f ? 0.5f : -0.5f)), -32768), 32767));
@ -140,6 +153,10 @@ bool Sobel(const LiteMat &src, LiteMat &dst, int flag_x, int flag_y, int ksize,
/// \brief Convert RGB image or color image to grayscale image
bool ConvertRgbToGray(const LiteMat &src, LDataType data_type, int w, int h, LiteMat &mat);
/// \brief Resize preserve AR with filler
bool ResizePreserveARWithFiller(LiteMat &src, LiteMat &dst, int h, int w, float (*ratioShiftWShiftH)[3],
float (*invM)[2][3], int img_orientation);
} // namespace dataset
} // namespace mindspore
#endif // IMAGE_PROCESS_H_

@ -63,6 +63,14 @@ struct Point {
Point(float _x, float _y) : x(_x), y(_y) {}
};
typedef struct imageToolsImage {
int w;
int h;
int stride;
int dataType;
void *image_buff;
} imageToolsImage_t;
using BOOL_C1 = Chn1<bool>;
using BOOL_C2 = Chn2<bool>;
using BOOL_C3 = Chn3<bool>;

@ -421,6 +421,43 @@ Status Resize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out
return Status::OK();
}
Status ResizePreserve(const TensorRow &inputs, int32_t height, int32_t width, int32_t img_orientation,
TensorRow *outputs) {
outputs->resize(3);
std::shared_ptr<Tensor> input = inputs[0];
LiteMat lite_mat_src(input->shape()[1], input->shape()[0], input->shape()[2],
const_cast<void *>(reinterpret_cast<const void *>(input->GetBuffer())),
GetLiteCVDataType(input->type()));
LiteMat lite_mat_dst;
std::shared_ptr<Tensor> image_tensor;
TensorShape new_shape = TensorShape({height, width, input->shape()[2]});
RETURN_IF_NOT_OK(Tensor::CreateEmpty(new_shape, DataType(DataType::DE_FLOAT32), &image_tensor));
uint8_t *buffer = reinterpret_cast<uint8_t *>(&(*image_tensor->begin<uint8_t>()));
lite_mat_dst.Init(width, height, input->shape()[2], reinterpret_cast<void *>(buffer), LDataType::FLOAT32);
float ratioShiftWShiftH[3] = {0};
float invM[2][3] = {{0, 0, 0}, {0, 0, 0}};
bool ret =
ResizePreserveARWithFiller(lite_mat_src, lite_mat_dst, height, width, &ratioShiftWShiftH, &invM, img_orientation);
CHECK_FAIL_RETURN_UNEXPECTED(ret, "Resize: bilinear resize failed.");
std::shared_ptr<Tensor> ratio_tensor;
TensorShape ratio_shape = TensorShape({3});
RETURN_IF_NOT_OK(Tensor::CreateFromMemory(ratio_shape, DataType(DataType::DE_FLOAT32),
reinterpret_cast<uint8_t *>(&ratioShiftWShiftH), &ratio_tensor));
std::shared_ptr<Tensor> invM_tensor;
TensorShape invM_shape = TensorShape({2, 3});
RETURN_IF_NOT_OK(Tensor::CreateFromMemory(invM_shape, DataType(DataType::DE_FLOAT32),
reinterpret_cast<uint8_t *>(&invM), &invM_tensor));
(*outputs)[0] = image_tensor;
(*outputs)[1] = ratio_tensor;
(*outputs)[2] = invM_tensor;
return Status::OK();
}
Status RgbToGray(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
if (input->Rank() != 3) {
RETURN_STATUS_UNEXPECTED("RgbToGray: input image is not in shape of <H,W,C>");

@ -95,6 +95,15 @@ Status Resize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out
int32_t output_width, double fx = 0.0, double fy = 0.0,
InterpolationMode mode = InterpolationMode::kLinear);
/// \brief Returns Resized image.
/// \param[in] inputs input TensorRow
/// \param[in] height Height of output
/// \param[in] width Width of output
/// \param[in] img_orientation Angle method of image rotation
/// \param[out] outputs Resized image of shape <height,width,C> and same type as input
Status ResizePreserve(const TensorRow &inputs, int32_t height, int32_t width, int32_t img_orientation,
TensorRow *outputs);
/// \brief Take in a 3 channel image in RBG to GRAY
/// \param[in] input The input image
/// \param[out] output The output image

@ -0,0 +1,39 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/kernels/image/resize_preserve_ar_op.h"
#ifdef ENABLE_ANDROID
#include "minddata/dataset/kernels/image/lite_image_utils.h"
#endif
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
const int32_t ResizePreserveAROp::kDefImgorientation = 0;
ResizePreserveAROp::ResizePreserveAROp(int32_t height, int32_t width, int32_t img_orientation)
: height_(height), width_(width), img_orientation_(img_orientation) {}
Status ResizePreserveAROp::Compute(const TensorRow &inputs, TensorRow *outputs) {
IO_CHECK_VECTOR(inputs, outputs);
#ifdef ENABLE_ANDROID
return ResizePreserve(inputs, height_, width_, img_orientation_, outputs);
#endif
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

@ -0,0 +1,55 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RESIZE_PRESERVE_AR_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RESIZE_PRESERVE_AR_OP_H_
#include <memory>
#include <vector>
#include <string>
#include "minddata/dataset/core/tensor.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/kernels/image/image_utils.h"
#else
#include "minddata/dataset/kernels/image/lite_image_utils.h"
#endif
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
class ResizePreserveAROp : public TensorOp {
public:
// Default values, also used by python_bindings.cc
static const int32_t kDefImgorientation;
ResizePreserveAROp(int32_t height, int32_t width, int32_t img_orientation = kDefImgorientation);
~ResizePreserveAROp() override = default;
Status Compute(const TensorRow &input, TensorRow *output) override;
std::string Name() const override { return kResizePreserveAROp; }
protected:
int32_t height_;
int32_t width_;
int32_t img_orientation_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_RESIZE_PRESERVE_AR_OP_H_

@ -72,6 +72,7 @@
#include "minddata/dataset/kernels/image/rgba_to_bgr_op.h"
#include "minddata/dataset/kernels/image/rgba_to_rgb_op.h"
#endif
#include "minddata/dataset/kernels/image/resize_preserve_ar_op.h"
#include "minddata/dataset/kernels/image/rgb_to_gray_op.h"
#include "minddata/dataset/kernels/image/rotate_op.h"
#ifndef ENABLE_ANDROID
@ -1421,6 +1422,25 @@ Status ResizeOperation::to_json(nlohmann::json *out_json) {
return Status::OK();
}
// ResizePreserveAROperation
ResizePreserveAROperation::ResizePreserveAROperation(int32_t height, int32_t width, int32_t img_orientation)
: height_(height), width_(width), img_orientation_(img_orientation) {}
Status ResizePreserveAROperation::ValidateParams() { return Status::OK(); }
std::shared_ptr<TensorOp> ResizePreserveAROperation::Build() {
return std::make_shared<ResizePreserveAROp>(height_, width_, img_orientation_);
}
Status ResizePreserveAROperation::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["height"] = height_;
args["width"] = width_;
args["img_orientation"] = img_orientation_;
*out_json = args;
return Status::OK();
}
// RotateOperation
RotateOperation::RotateOperation() { rotate_op = std::make_shared<RotateOp>(0); }

@ -71,6 +71,7 @@ constexpr char kRandomVerticalFlipOperation[] = "RandomVerticalFlip";
constexpr char kRandomVerticalFlipWithBBoxOperation[] = "RandomVerticalFlipWithBBox";
constexpr char kRescaleOperation[] = "Rescale";
constexpr char kResizeOperation[] = "Resize";
constexpr char kResizePreserveAROperation[] = "ResizePreserveAR";
constexpr char kResizeWithBBoxOperation[] = "ResizeWithBBox";
constexpr char kRgbaToBgrOperation[] = "RgbaToBgr";
constexpr char kRgbaToRgbOperation[] = "RgbaToRgb";
@ -781,6 +782,26 @@ class ResizeOperation : public TensorOperation {
InterpolationMode interpolation_;
};
class ResizePreserveAROperation : public TensorOperation {
public:
ResizePreserveAROperation(int32_t height, int32_t width, int32_t img_orientation);
~ResizePreserveAROperation() = default;
std::shared_ptr<TensorOp> Build() override;
Status ValidateParams() override;
std::string Name() const override { return kResizePreserveAROperation; }
Status to_json(nlohmann::json *out_json) override;
private:
int32_t height_;
int32_t width_;
int32_t img_orientation_;
};
class ResizeWithBBoxOperation : public TensorOperation {
public:
explicit ResizeWithBBoxOperation(std::vector<int32_t> size, InterpolationMode interpolation_mode);

@ -94,6 +94,7 @@ constexpr char kRandomVerticalFlipWithBBoxOp[] = "RandomVerticalFlipWithBBoxOp";
constexpr char kRescaleOp[] = "RescaleOp";
constexpr char kResizeBilinearOp[] = "ResizeBilinearOp";
constexpr char kResizeOp[] = "ResizeOp";
constexpr char kResizePreserveAROp[] = "ResizePreserveAROp";
constexpr char kResizeWithBBoxOp[] = "ResizeWithBBoxOp";
constexpr char kRgbaToBgrOp[] = "RgbaToBgrOp";
constexpr char kRgbaToRgbOp[] = "RgbaToRgbOp";

@ -88,6 +88,29 @@ class CenterCrop : public TensorTransform {
std::shared_ptr<Data> data_;
};
/// \brief ResizePreserveAR TensorTransform.
/// \notes Keep the original picture ratio and fill the rest.
class ResizePreserveAR final : public TensorTransform {
public:
/// \brief Constructor.
/// \param[in] height The height of image output value after resizing.
/// \param[in] width The width of image output value after resizing.
/// \param[in] img_orientation Angle method of image rotation.
ResizePreserveAR(int32_t height, int32_t width, int32_t img_orientation = 0);
/// \brief Destructor.
~ResizePreserveAR() = default;
protected:
/// \brief Function to convert TensorTransform object into a TensorOperation object.
/// \return Shared pointer to TensorOperation object.
std::shared_ptr<TensorOperation> Parse() override;
private:
struct Data;
std::shared_ptr<Data> data_;
};
/// \brief RGB2GRAY TensorTransform.
/// \notes Convert RGB image or color image to grayscale image
class RGB2GRAY : public TensorTransform {

@ -199,6 +199,7 @@ if(BUILD_MINDDATA STREQUAL "full")
${MINDDATA_DIR}/kernels/image/decode_op.cc
${MINDDATA_DIR}/kernels/image/normalize_op.cc
${MINDDATA_DIR}/kernels/image/resize_op.cc
${MINDDATA_DIR}/kernels/image/resize_preserve_ar_op.cc
${MINDDATA_DIR}/kernels/image/rgb_to_gray_op.cc
${MINDDATA_DIR}/kernels/image/rotate_op.cc
${MINDDATA_DIR}/kernels/image/random_affine_op.cc
@ -282,6 +283,7 @@ elseif(BUILD_MINDDATA STREQUAL "wrapper")
${MINDDATA_DIR}/kernels/image/crop_op.cc
${MINDDATA_DIR}/kernels/image/normalize_op.cc
${MINDDATA_DIR}/kernels/image/resize_op.cc
${MINDDATA_DIR}/kernels/image/resize_preserve_ar_op.cc.cc
${MINDDATA_DIR}/kernels/image/rgb_to_gray_op.cc
${MINDDATA_DIR}/kernels/image/rotate_op.cc
${MINDDATA_DIR}/kernels/data/compose_op.cc
@ -381,6 +383,7 @@ elseif(BUILD_MINDDATA STREQUAL "lite")
"${MINDDATA_DIR}/kernels/image/random_vertical_flip_with_bbox_op.cc"
"${MINDDATA_DIR}/kernels/image/random_sharpness_op.cc"
"${MINDDATA_DIR}/kernels/image/rescale_op.cc"
"${MINDDATA_DIR}/kernels/image/resize_preserve_ar_op.cc"
"${MINDDATA_DIR}/kernels/image/rgb_to_gray_op.cc"
"${MINDDATA_DIR}/kernels/image/rgba_to_bgr_op.cc"
"${MINDDATA_DIR}/kernels/image/rgba_to_rgb_op.cc"

@ -1,11 +1,11 @@
cmake_minimum_required(VERSION 3.14.1)
project(testlenet)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wall -fPIC")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wall -fPIC -std=c++17")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-sign-compare")
set(MD_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mindspore-lite-1.1.0-inference-linux-x64/minddata")
set(MS_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mindspore-lite-1.1.0-inference-linux-x64/")
set(MD_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mindspore-lite-1.2.0-inference-linux-x64/minddata")
set(MS_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mindspore-lite-1.2.0-inference-linux-x64/")
include_directories(${MD_DIR})
include_directories(${MS_DIR})
@ -16,6 +16,17 @@ add_executable(testlenet
)
target_link_libraries(testlenet
${MD_DIR}/lib/libminddata-lite.so
${MD_DIR}/third_party/libjpeg-turbo/lib/libjpeg.so.62
${MD_DIR}/third_party/libjpeg-turbo/lib/libturbojpeg.so.0
${MS_DIR}/lib/libmindspore-lite.so
pthread)
add_executable(testresize
${CMAKE_CURRENT_SOURCE_DIR}/testresize.cpp
)
target_link_libraries(testresize
${MD_DIR}/lib/libminddata-lite.so
${MD_DIR}/third_party/libjpeg-turbo/lib/libjpeg.so.62
${MD_DIR}/third_party/libjpeg-turbo/lib/libturbojpeg.so.0

@ -0,0 +1,68 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <sys/stat.h>
#include <unistd.h>
#include <fstream>
#include <iostream>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "include/datasets.h"
#include "include/iterator.h"
#include "include/vision_lite.h"
#include "include/transforms.h"
#include "include/api/types.h"
using mindspore::dataset::Album;
using mindspore::dataset::Dataset;
using mindspore::dataset::Iterator;
using mindspore::dataset::SequentialSampler;
using mindspore::dataset::TensorTransform;
using mindspore::dataset::vision::ResizePreserveAR;
int main(int argc, char **argv) {
std::string folder_path = "./testAlbum/images";
std::string schema_file = "./testAlbum/datasetSchema.json";
std::vector<std::string> column_names = {"image", "label", "id"};
// Create a Album Dataset
std::shared_ptr<Dataset> ds =
Album(folder_path, schema_file, column_names, true, std::make_shared<SequentialSampler>(0, 1));
ds = ds->SetNumWorkers(1);
std::shared_ptr<TensorTransform> resize(new ResizePreserveAR(1000, 1000));
ds = ds->Map({resize}, {"image"}, {"image", "ratio", "invM"});
std::shared_ptr<Iterator> iter = ds->CreateIterator();
std::unordered_map<std::string, mindspore::MSTensor> row;
iter->GetNextRow(&row);
uint64_t i = 0;
while (row.size() != 0) {
i++;
iter->GetNextRow(&row);
}
iter->Stop();
}

@ -1736,3 +1736,20 @@ TEST_F(MindDataImageProcess, testConvertRgbToGray) {
cv::imwrite("./mindspore_image.jpg", dst_image);
CompareMat(rgb_mat, lite_mat_gray);
}
TEST_F(MindDataImageProcess, testResizePreserveARWithFillerv) {
std::string filename = "data/dataset/apple.jpg";
cv::Mat image = cv::imread(filename, cv::ImreadModes::IMREAD_COLOR);
LiteMat lite_mat_rgb;
lite_mat_rgb.Init(image.cols, image.rows, image.channels(), image.data, LDataType::UINT8);
LiteMat lite_mat_resize;
float ratioShiftWShiftH[3] = {0};
float invM[2][3] = {{0, 0, 0}, {0, 0, 0}};
int h = 1000;
int w = 1000;
bool ret = ResizePreserveARWithFiller(lite_mat_rgb, lite_mat_resize, h, w, &ratioShiftWShiftH, &invM, 0);
ASSERT_TRUE(ret == true);
cv::Mat dst_image(lite_mat_resize.height_, lite_mat_resize.width_, CV_32FC3, lite_mat_resize.data_ptr_);
cv::imwrite("./mindspore_image.jpg", dst_image);
}

Loading…
Cancel
Save