!9910 [MD][Perf] MindData add NormalizePad for GPU performance

From: @xiefangqi
Reviewed-by: 
Signed-off-by:
pull/9910/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit c4f284b928

@ -31,6 +31,7 @@
#include "minddata/dataset/kernels/image/invert_op.h"
#include "minddata/dataset/kernels/image/mixup_batch_op.h"
#include "minddata/dataset/kernels/image/normalize_op.h"
#include "minddata/dataset/kernels/image/normalize_pad_op.h"
#include "minddata/dataset/kernels/image/pad_op.h"
#include "minddata/dataset/kernels/image/random_affine_op.h"
#include "minddata/dataset/kernels/image/random_color_op.h"
@ -71,6 +72,11 @@ PYBIND_REGISTER(NormalizeOp, 1, ([](const py::module *m) {
.def(py::init<float, float, float, float, float, float>());
}));
// Registers NormalizePadOp with pybind under the Python-visible name "NormalizePadOp".
// The exposed constructor takes six floats followed by a string, matching py::init below.
PYBIND_REGISTER(NormalizePadOp, 1, ([](const py::module *m) {
(void)py::class_<NormalizePadOp, TensorOp, std::shared_ptr<NormalizePadOp>>(*m, "NormalizePadOp")
.def(py::init<float, float, float, float, float, float, std::string>());
}));
PYBIND_REGISTER(
EqualizeOp, 1, ([](const py::module *m) {
(void)py::class_<EqualizeOp, TensorOp, std::shared_ptr<EqualizeOp>>(*m, "EqualizeOp").def(py::init<>());

@ -38,6 +38,7 @@
#include "minddata/dataset/kernels/image/mixup_batch_op.h"
#endif
#include "minddata/dataset/kernels/image/normalize_op.h"
#include "minddata/dataset/kernels/image/normalize_pad_op.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/kernels/image/pad_op.h"
#include "minddata/dataset/kernels/image/random_affine_op.h"
@ -169,6 +170,14 @@ std::shared_ptr<NormalizeOperation> Normalize(std::vector<float> mean, std::vect
}
#ifndef ENABLE_ANDROID
// Function to create NormalizePadOperation.
std::shared_ptr<NormalizePadOperation> NormalizePad(const std::vector<float> &mean, const std::vector<float> &std,
const std::string &dtype) {
auto op = std::make_shared<NormalizePadOperation>(mean, std, dtype);
// Input validation
return op->ValidateParams() ? op : nullptr;
}
// Function to create PadOperation.
std::shared_ptr<PadOperation> Pad(std::vector<int32_t> padding, std::vector<uint8_t> fill_value,
BorderType padding_mode) {
@ -668,7 +677,7 @@ Status NormalizeOperation::ValidateParams() {
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
if (mean_[i] < 0.0f || mean_[i] > 255.0f || CmpFloat(mean_[i], 0.0f)) {
if (mean_[i] < 0.0f || mean_[i] > 255.0f) {
std::string err_msg = "Normalize: mean vector has incorrect value: " + std::to_string(mean_[i]);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
@ -682,6 +691,47 @@ std::shared_ptr<TensorOp> NormalizeOperation::Build() {
}
#ifndef ENABLE_ANDROID
// NormalizePadOperation
// Copies mean, std and dtype into the members as given; no checking happens here
// (ValidateParams() performs the checks).
NormalizePadOperation::NormalizePadOperation(const std::vector<float> &mean, const std::vector<float> &std,
const std::string &dtype)
: mean_(mean), std_(std), dtype_(dtype) {}
// Validates the parameters stored in this operation.
// Fails with a syntax-error Status (and logs the message) when:
//   - mean_ or std_ does not hold exactly three values,
//   - any std value is outside [0, 255] or compares equal to zero via CmpFloat,
//   - any mean value is outside [0, 255],
//   - dtype_ is neither "float32" nor "float16".
// Returns Status::OK() otherwise.
Status NormalizePadOperation::ValidateParams() {
  if (mean_.size() != 3) {
    std::string err_msg = "NormalizePad: mean vector has incorrect size: " + std::to_string(mean_.size());
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  if (std_.size() != 3) {
    std::string err_msg = "NormalizePad: std vector has incorrect size: " + std::to_string(std_.size());
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  // check std/mean value
  // Both vectors have three elements at this point, so one index covers both.
  // size_t matches the unsigned type of size() and avoids a signed/unsigned comparison.
  for (size_t i = 0; i < std_.size(); ++i) {
    if (std_[i] < 0.0f || std_[i] > 255.0f || CmpFloat(std_[i], 0.0f)) {
      std::string err_msg = "NormalizePad: std vector has incorrect value: " + std::to_string(std_[i]);
      MS_LOG(ERROR) << err_msg;
      RETURN_STATUS_SYNTAX_ERROR(err_msg);
    }
    if (mean_[i] < 0.0f || mean_[i] > 255.0f) {
      std::string err_msg = "NormalizePad: mean vector has incorrect value: " + std::to_string(mean_[i]);
      MS_LOG(ERROR) << err_msg;
      RETURN_STATUS_SYNTAX_ERROR(err_msg);
    }
  }
  if (dtype_ != "float32" && dtype_ != "float16") {
    std::string err_msg = "NormalizePad: dtype must be float32 or float16, but got: " + dtype_;
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  return Status::OK();
}
// Builds the runtime NormalizePadOp from the three mean values, the three std values and dtype_.
// Indexes mean_[0..2] and std_[0..2] directly, so both vectors must hold three elements
// (ValidateParams() checks this).
std::shared_ptr<TensorOp> NormalizePadOperation::Build() {
return std::make_shared<NormalizePadOp>(mean_[0], mean_[1], mean_[2], std_[0], std_[1], std_[2], dtype_);
}
// PadOperation
// Stores padding, fill_value and padding_mode in the corresponding members.
PadOperation::PadOperation(std::vector<int32_t> padding, std::vector<uint8_t> fill_value, BorderType padding_mode)
: padding_(padding), fill_value_(fill_value), padding_mode_(padding_mode) {}

@ -42,6 +42,7 @@ constexpr char kEqualizeOperation[] = "Equalize";
constexpr char kHwcToChwOperation[] = "HwcToChw";
constexpr char kInvertOperation[] = "Invert";
constexpr char kMixUpBatchOperation[] = "MixUpBatch";
constexpr char kNormalizePadOperation[] = "NormalizePad";
constexpr char kPadOperation[] = "Pad";
constexpr char kRandomAffineOperation[] = "RandomAffine";
constexpr char kRandomColorAdjustOperation[] = "RandomColorAdjust";
@ -79,6 +80,7 @@ class EqualizeOperation;
class HwcToChwOperation;
class InvertOperation;
class MixUpBatchOperation;
class NormalizePadOperation;
class PadOperation;
class RandomAffineOperation;
class RandomColorOperation;
@ -162,6 +164,19 @@ std::shared_ptr<InvertOperation> Invert();
/// \return Shared pointer to the current TensorOperation.
std::shared_ptr<MixUpBatchOperation> MixUpBatch(float alpha = 1);
/// \brief Function to create a NormalizePad TensorOperation.
/// \notes Normalize the input image with respect to mean and standard deviation and pad an extra
/// channel with value zero.
/// \param[in] mean A vector of mean values for each channel, w.r.t channel order.
/// The mean values must be in range [0.0, 255.0].
/// \param[in] std A vector of standard deviations for each channel, w.r.t. channel order.
/// The standard deviation values must be in range (0.0, 255.0]
/// \param[in] dtype The output datatype of Tensor.
/// The dtype must be "float32" or "float16" (default = "float32").
/// \return Shared pointer to the current TensorOperation, or nullptr if the parameters are invalid.
std::shared_ptr<NormalizePadOperation> NormalizePad(const std::vector<float> &mean, const std::vector<float> &std,
const std::string &dtype = "float32");
/// \brief Function to create a Pad TensorOp
/// \notes Pads the image according to padding parameters
/// \param[in] padding A vector representing the number of pixels to pad the image
@ -587,6 +602,25 @@ class MixUpBatchOperation : public TensorOperation {
float alpha_;
};
// Describes a NormalizePad transform: holds the user-supplied mean, std and output dtype,
// validates them, and builds the runtime TensorOp.
class NormalizePadOperation : public TensorOperation {
public:
// dtype defaults to "float32".
NormalizePadOperation(const std::vector<float> &mean, const std::vector<float> &std,
const std::string &dtype = "float32");
~NormalizePadOperation() = default;
// Creates the runtime TensorOp for this operation.
std::shared_ptr<TensorOp> Build() override;
// Checks the stored parameters; returns a non-OK Status when they are invalid.
Status ValidateParams() override;
// Returns the operation name constant kNormalizePadOperation.
std::string Name() const override { return kNormalizePadOperation; }
private:
std::vector<float> mean_;   // mean values, one per channel
std::vector<float> std_;    // standard deviations, one per channel
std::string dtype_;         // requested output datatype name
};
class PadOperation : public TensorOperation {
public:
PadOperation(std::vector<int32_t> padding, std::vector<uint8_t> fill_value = {0},

@ -81,7 +81,7 @@ std::shared_ptr<DecodeOperation> Decode(bool rgb = true);
/// \brief Function to create a Normalize TensorOperation.
/// \notes Normalize the input image with respect to mean and standard deviation.
/// \param[in] mean A vector of mean values for each channel, w.r.t channel order.
/// The mean values must be in range (0.0, 255.0].
/// The mean values must be in range [0.0, 255.0].
/// \param[in] std A vector of standard deviations for each channel, w.r.t. channel order.
/// The standard deviation values must be in range (0.0, 255.0]
/// \return Shared pointer to the current TensorOperation.

@ -18,6 +18,7 @@ add_library(kernels-image OBJECT
math_utils.cc
mixup_batch_op.cc
normalize_op.cc
normalize_pad_op.cc
pad_op.cc
posterize_op.cc
random_affine_op.cc

@ -630,6 +630,57 @@ Status Normalize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *
}
}
// Normalizes a 3-channel image channel by channel and appends one all-zero channel.
//
// input:  tensor convertible to a rank-3 CVTensor; it must split into exactly three channels.
// output: receives a tensor whose last dimension is one larger than the input's.
// mean:   tensor that squeezes to shape <3> of type DE_FLOAT32 (note: Squeeze() is called on it).
// std:    tensor that squeezes to shape <3> of type DE_FLOAT32 (note: Squeeze() is called on it).
// dtype:  "float16" selects a float16 output; any other value produces float32.
//
// Returns kShapeMisMatch for a malformed mean/std tensor and an unexpected-error Status for
// an unusable input image or an OpenCV failure.
Status NormalizePad(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
                    const std::shared_ptr<Tensor> &mean, const std::shared_ptr<Tensor> &std, const std::string &dtype) {
  std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
  if (!(input_cv->mat().data && input_cv->Rank() == 3)) {
    RETURN_STATUS_UNEXPECTED("Could not convert to CV Tensor");
  }

  // Validate mean/std before allocating the output so invalid parameters fail without
  // paying for an image-sized buffer.
  mean->Squeeze();
  if (mean->type() != DataType::DE_FLOAT32 || mean->Rank() != 1 || mean->shape()[0] != 3) {
    std::string err_msg = "Mean tensor should be of size 3 and type float.";
    return Status(StatusCode::kShapeMisMatch, err_msg);
  }
  std->Squeeze();
  if (std->type() != DataType::DE_FLOAT32 || std->Rank() != 1 || std->shape()[0] != 3) {
    std::string err_msg = "Std tensor should be of size 3 and type float.";
    return Status(StatusCode::kShapeMisMatch, err_msg);
  }

  // Pick the OpenCV depth / single-channel type and the matching tensor type.
  DataType tensor_type = DataType(DataType::DE_FLOAT32);
  int compute_type = CV_32F;
  int channel_type = CV_32FC1;
  if (dtype == "float16") {
    compute_type = CV_16F;
    channel_type = CV_16FC1;
    tensor_type = DataType(DataType::DE_FLOAT16);
  }

  cv::Mat in_image = input_cv->mat();
  std::shared_ptr<CVTensor> output_cv;
  // Same height and width as the input, with one extra channel.
  TensorShape new_shape({input_cv->shape()[0], input_cv->shape()[1], input_cv->shape()[2] + 1});
  RETURN_IF_NOT_OK(CVTensor::CreateEmpty(new_shape, tensor_type, &output_cv));

  try {
    // NOTE: We are assuming the input image is in RGB and the mean
    // and std are in RGB
    std::vector<cv::Mat> rgb;
    cv::split(in_image, rgb);
    if (rgb.size() != 3) {
      RETURN_STATUS_UNEXPECTED("Input image is not in RGB.");
    }
    for (uint8_t i = 0; i < 3; i++) {
      float mean_c, std_c;
      RETURN_IF_NOT_OK(mean->GetItemAt<float>(&mean_c, {i}));
      RETURN_IF_NOT_OK(std->GetItemAt<float>(&std_c, {i}));
      // (x - mean) / std expressed as x * (1 / std) + (-mean / std) for convertTo.
      rgb[i].convertTo(rgb[i], compute_type, 1.0 / std_c, (-mean_c / std_c));
    }
    // Fourth channel: all zeros, same height/width as the image.
    rgb.push_back(cv::Mat::zeros(in_image.rows, in_image.cols, channel_type));
    cv::merge(rgb, output_cv->mat());
    *output = std::static_pointer_cast<Tensor>(output_cv);
    return Status::OK();
  } catch (const cv::Exception &e) {
    // Keep the OpenCV diagnostic instead of discarding it.
    std::string err_msg = "Unexpected error in NormalizePad: " + std::string(e.what());
    RETURN_STATUS_UNEXPECTED(err_msg);
  }
}
Status AdjustBrightness(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const float &alpha) {
try {
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);

@ -185,6 +185,15 @@ Status Rotate(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out
Status Normalize(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
const std::shared_ptr<Tensor> &mean, const std::shared_ptr<Tensor> &std);
/// \brief Returns Normalized and padded image
/// \param input: Tensor of shape <H,W,C> in RGB order and any OpenCv compatible type, see CVTensor.
/// \param mean: Tensor of shape <3> and type DE_FLOAT32 which are mean of each channel in RGB order
/// \param std: Tensor of shape <3> and type DE_FLOAT32 which are std of each channel in RGB order
/// \param dtype: output dtype
/// \param output: Normalized image Tensor and pad an extra channel, return a dtype Tensor
Status NormalizePad(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
const std::shared_ptr<Tensor> &mean, const std::shared_ptr<Tensor> &std, const std::string &dtype);
/// \brief Returns image with adjusted brightness.
/// \param input: Tensor of shape <H,W,3> in RGB order and any OpenCv compatible type, see CVTensor.
/// \param alpha: Alpha value to adjust brightness by. Should be a positive number.

@ -0,0 +1,48 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/kernels/image/normalize_pad_op.h"
#include <random>
#include "minddata/dataset/kernels/image/image_utils.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
// Packs the per-channel mean and std values into two 3-element float tensors and records
// the requested output dtype. A tensor-creation failure is only logged; construction
// still completes.
NormalizePadOp::NormalizePadOp(float mean_r, float mean_g, float mean_b, float std_r, float std_g, float std_b,
                               std::string dtype) {
  Status mean_status = Tensor::CreateFromVector<float>({mean_r, mean_g, mean_b}, &mean_);
  if (mean_status.IsError()) {
    MS_LOG(ERROR) << "Could not create mean tensor.";
  }

  Status std_status = Tensor::CreateFromVector<float>({std_r, std_g, std_b}, &std_);
  if (std_status.IsError()) {
    MS_LOG(ERROR) << "Could not create std tensor.";
  }

  dtype_ = dtype;
}
// Runs the op on one tensor: after IO_CHECK on input/output, forwards to the NormalizePad
// kernel with the stored mean, std and dtype and returns its Status.
Status NormalizePadOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
IO_CHECK(input, output);
// Doing the normalization + pad
return NormalizePad(input, output, mean_, std_, dtype_);
}
// Writes a human-readable description of this op (its name, then the mean and std tensors)
// to `out`. The label names this class rather than the NormalizeOp it was derived from.
void NormalizePadOp::Print(std::ostream &out) const {
  out << "NormalizePadOp, mean: " << *(mean_.get()) << std::endl << "std: " << *(std_.get()) << std::endl;
}
} // namespace dataset
} // namespace mindspore

@ -0,0 +1,49 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_NORMALIZE_PAD_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_NORMALIZE_PAD_OP_H_
#include <memory>
#include <string>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
// TensorOp that normalizes an image with per-channel mean/std and pads an extra channel;
// Compute() delegates to the NormalizePad kernel.
class NormalizePadOp : public TensorOp {
public:
// Takes the mean and std for the three channels (r, g, b) and the output dtype name,
// which defaults to "float32".
NormalizePadOp(float mean_r, float mean_g, float mean_b, float std_r, float std_g, float std_b,
std::string dtype = "float32");
~NormalizePadOp() override = default;
// Writes a description of the op to `out`.
void Print(std::ostream &out) const override;
// Applies the op to `input` and stores the result in `*output`.
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
// Returns the op name constant kNormalizePadOp.
std::string Name() const override { return kNormalizePadOp; }
private:
std::shared_ptr<Tensor> mean_;  // per-channel mean values
std::shared_ptr<Tensor> std_;   // per-channel standard deviations
std::string dtype_;             // output datatype name passed to the kernel
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_IMAGE_NORMALIZE_PAD_OP_H_

@ -62,6 +62,7 @@ constexpr char kHwcToChwOp[] = "HwcToChwOp";
constexpr char kInvertOp[] = "InvertOp";
constexpr char kMixUpBatchOp[] = "MixUpBatchOp";
constexpr char kNormalizeOp[] = "NormalizeOp";
constexpr char kNormalizePadOp[] = "NormalizePadOp";
constexpr char kPadOp[] = "PadOp";
constexpr char kRandomColorAdjustOp[] = "RandomColorAdjustOp";
constexpr char kRandomCropAndResizeOp[] = "RandomCropAndResizeOp";

@ -50,8 +50,8 @@ import mindspore._c_dataengine as cde
from .utils import Inter, Border, ImageBatchFormat
from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \
check_mix_up_batch_c, check_normalize_c, check_random_crop, check_random_color_adjust, check_random_rotation, \
check_range, check_resize, check_rescale, check_pad, check_cutout, \
check_mix_up_batch_c, check_normalize_c, check_normalizepad_c, check_random_crop, check_random_color_adjust, \
check_random_rotation, check_range, check_resize, check_rescale, check_pad, check_cutout, \
check_uniform_augment_cpp, \
check_bounding_box_augment_cpp, check_random_select_subpolicy_op, check_auto_contrast, check_random_affine, \
check_random_solarize, check_soft_dvpp_decode_random_crop_resize_jpeg, check_positive_degrees, FLOAT_MAX_INTEGER, \
@ -319,6 +319,50 @@ class Normalize(cde.NormalizeOp):
return img.as_array()
class NormalizePad(cde.NormalizePadOp):
"""
Normalize the input image with respect to mean and standard deviation then pad an extra channel with value zero.
Args:
mean (sequence): List or tuple of mean values for each channel, with respect to channel order.
The mean values must be in range (0.0, 255.0].
std (sequence): List or tuple of standard deviations for each channel, with respect to channel order.
The standard deviation values must be in range (0.0, 255.0].
dtype (str): Set the output data type of normalized image (default is "float32").
Examples:
>>> import mindspore.dataset.vision.c_transforms as c_vision
>>>
>>> decode_op = c_vision.Decode()
>>> normalize_pad_op = c_vision.NormalizePad(mean=[121.0, 115.0, 100.0], std=[70.0, 68.0, 71.0], dtype="float32")
>>> transforms_list = [decode_op, normalize_pad_op]
>>> data1 = data1.map(operations=transforms_list, input_columns=["image"])
"""
@check_normalizepad_c
def __init__(self, mean, std, dtype="float32"):
self.mean = mean
self.std = std
self.dtype = dtype
super().__init__(*mean, *std, dtype)
def __call__(self, img):
"""
Call method.
Args:
img (NumPy or PIL image): Image array to be normalized and padded.
Returns:
img (NumPy), normalized and padded image array.
"""
if not isinstance(img, (np.ndarray, Image.Image)):
raise TypeError("Input should be NumPy or PIL image, got {}.".format(type(img)))
normalize_pad = cde.Execute(cde.NormalizePadOp(*self.mean, *self.std, self.dtype))
img = normalize_pad(cde.Tensor(np.asarray(img)))
return img.as_array()
class RandomAffine(cde.RandomAffineOp):
"""
Apply Random affine transformation to the input image.

@ -28,7 +28,7 @@ from PIL import Image
from . import py_transforms_util as util
from .c_transforms import parse_padding
from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \
check_normalize_py, check_random_crop, check_random_color_adjust, check_random_rotation, \
check_normalize_py, check_normalizepad_py, check_random_crop, check_random_color_adjust, check_random_rotation, \
check_ten_crop, check_num_channels, check_pad, \
check_random_perspective, check_random_erasing, check_cutout, check_linear_transform, check_random_affine, \
check_mix_up, check_positive_degrees, check_uniform_augment_py, check_auto_contrast
@ -231,6 +231,49 @@ class Normalize:
return util.normalize(img, self.mean, self.std)
class NormalizePad:
"""
Normalize the input NumPy image array of shape (C, H, W) with the given mean and standard deviation
then pad an extra channel with value zero.
The values of the array need to be in the range (0.0, 1.0].
Args:
mean (sequence): List or tuple of mean values for each channel, with respect to channel order.
The mean values must be in the range (0.0, 1.0].
std (sequence): List or tuple of standard deviations for each channel, w.r.t. channel order.
The standard deviation values must be in the range (0.0, 1.0].
dtype (str): Set the output data type of image (default is "float32").
Examples:
>>> import mindspore.dataset.vision.py_transforms as py_vision
>>> from mindspore.dataset.transforms.py_transforms import Compose
>>>
>>> Compose([py_vision.Decode(),
>>> py_vision.RandomHorizontalFlip(0.5),
>>> py_vision.ToTensor(),
>>> py_vision.NormalizePad((0.491, 0.482, 0.447), (0.247, 0.243, 0.262), "float32")])
"""
@check_normalizepad_py
def __init__(self, mean, std, dtype="float32"):
self.mean = mean
self.std = std
self.dtype = dtype
def __call__(self, img):
"""
Call method.
Args:
img (numpy.ndarray): Image array to be normalized and padded.
Returns:
img (numpy.ndarray), normalized image array with an extra channel appended.
"""
# Delegates to util.normalize with pad_channel=True and the configured dtype.
return util.normalize(img, self.mean, self.std, pad_channel=True, dtype=self.dtype)
class RandomCrop:
"""
Crop the input PIL image at a random location.

@ -42,7 +42,7 @@ def is_pil(img):
return isinstance(img, Image.Image)
def normalize(img, mean, std):
def normalize(img, mean, std, pad_channel=False, dtype="float32"):
"""
Normalize the image between [0, 1] with respect to mean and standard deviation.
@ -50,6 +50,8 @@ def normalize(img, mean, std):
img (numpy.ndarray): Image array of shape CHW to be normalized.
mean (list): List of mean values for each channel, w.r.t channel order.
std (list): List of standard deviations for each channel, w.r.t. channel order.
pad_channel (bool): Whether to pad a extra channel with value zero.
dtype (str): Output datatype of normalize, only worked when pad_channel is True. (default is "float32")
Returns:
img (numpy.ndarray), Normalized image.
@ -72,7 +74,13 @@ def normalize(img, mean, std):
mean = np.array(mean, dtype=img.dtype)
std = np.array(std, dtype=img.dtype)
return (img - mean[:, None, None]) / std[:, None, None]
image = (img - mean[:, None, None]) / std[:, None, None]
if pad_channel:
zeros = np.zeros([1, image.shape[1], image.shape[2]], dtype=np.float32)
image = np.concatenate((image, zeros), axis=0)
if dtype == "float16":
image = image.astype(np.float16)
return image
def decode(img):

@ -294,6 +294,40 @@ def check_normalize_py(method):
return new_method
def check_normalizepad_c(method):
    """Decorator that validates the arguments of the C++-backed NormalizePad before calling `method`."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        parsed_args, _ = parse_user_args(method, *args, **kwargs)
        mean, std, dtype = parsed_args
        # mean/std are checked by the helper shared with the C++ Normalize validator.
        check_normalize_c_param(mean, std)
        # dtype must be a string naming one of the two supported output types.
        if not isinstance(dtype, str):
            raise TypeError("dtype should be string.")
        if dtype not in ("float32", "float16"):
            raise ValueError("dtype only support float32 or float16.")
        return method(self, *args, **kwargs)

    return new_method
def check_normalizepad_py(method):
    """Decorator that validates the arguments of the Python NormalizePad before calling `method`."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        parsed_args, _ = parse_user_args(method, *args, **kwargs)
        mean, std, dtype = parsed_args
        # mean/std are checked by the helper shared with the Python Normalize validator.
        check_normalize_py_param(mean, std)
        # dtype must be a string naming one of the two supported output types.
        if not isinstance(dtype, str):
            raise TypeError("dtype should be string.")
        if dtype not in ("float32", "float16"):
            raise ValueError("dtype only support float32 or float16.")
        return method(self, *args, **kwargs)

    return new_method
def check_random_crop(method):
"""Wrapper method to check the parameters of random crop."""

@ -58,11 +58,6 @@ class MyTimeMonitor(Callback):
fps = self.batch_size / step_mseconds *1000 * self.size
print("Epoch time: {:5.3f} ms, fps: {:d} img/sec.".format(step_mseconds, int(fps)), flush=True, end=" ")
def pad(image):
    """
    Append one all-zero channel to an HWC image.

    Args:
        image (numpy.ndarray): Image array with height, width and channel axes.

    Returns:
        numpy.ndarray, the image with one extra zero channel along axis 2.
    """
    # Size the zero plane from the image itself instead of assuming 224x224.
    height, width = image.shape[0], image.shape[1]
    zeros = np.zeros([height, width, 1], dtype=np.uint8)
    output = np.concatenate((image, zeros), axis=2)
    return output
def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="GPU", dtype="fp16"):
ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=4, shuffle=True)
@ -71,24 +66,25 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, target="
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
# define map operations
normalize_op = C.Normalize(mean=mean, std=std)
# create_dataset receives the short dtype names ("fp16" / "fp32"), so match on "fp16" here;
# NormalizePad itself takes the long name "float16" for its output type.
if dtype == "fp16":
    normalize_op = C.NormalizePad(mean=mean, std=std, dtype="float16")
if do_train:
trans = [
C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
C.RandomHorizontalFlip(prob=0.5),
C.Normalize(mean=mean, std=std),
normalize_op,
]
else:
trans = [
C.Decode(),
C.Resize(256),
C.CenterCrop(image_size),
C.Normalize(mean=mean, std=std),
normalize_op,
]
if dtype == "fp32":
trans.append(C.HWC2CHW())
ds = ds.map(operations=trans, input_columns="image", num_parallel_workers=4)
if dtype == "fp16":
ds = ds.map(operations=pad, input_columns="image", num_parallel_workers=4)
ds = ds.map(operations=trans, input_columns="image", num_parallel_workers=8)
# apply batch operations
ds = ds.batch(batch_size, drop_remainder=True)
# apply dataset repeat operation

@ -932,6 +932,70 @@ TEST_F(MindDataTestPipeline, TestNormalizeFail) {
EXPECT_EQ(normalize, nullptr);
}
// Runs NormalizePad (float32) inside a dataset pipeline and checks that every produced
// image has 4 channels and that 20 rows come out (10 samples repeated twice).
TEST_F(MindDataTestPipeline, TestNormalizePad) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestNormalizePad.";
// Create an ImageFolder Dataset
std::string folder_path = datasets_root_path_ + "/testPK/data/";
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(false, 10));
EXPECT_NE(ds, nullptr);
// Create a Repeat operation on ds
int32_t repeat_num = 2;
ds = ds->Repeat(repeat_num);
EXPECT_NE(ds, nullptr);
// Create objects for the tensor ops
std::shared_ptr<TensorOperation> normalizepad = vision::NormalizePad({121.0, 115.0, 100.0}, {70.0, 68.0, 71.0},
"float32");
EXPECT_NE(normalizepad, nullptr);
// Create a Map operation on ds
ds = ds->Map({normalizepad});
EXPECT_NE(ds, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
// Iterate the dataset and get each row
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
iter->GetNextRow(&row);
uint64_t i = 0;
// An empty row marks the end of the data.
while (row.size() != 0) {
i++;
auto image = row["image"];
// Three input channels plus the padded one.
EXPECT_EQ(image->shape()[2], 4);
MS_LOG(INFO) << "Tensor image shape: " << image->shape();
iter->GetNextRow(&row);
}
EXPECT_EQ(i, 20);
// Manually terminate the pipeline
iter->Stop();
}
// Each invalid parameter set below must make the NormalizePad factory return nullptr.
TEST_F(MindDataTestPipeline, TestNormalizePadFail) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestNormalizePadFail with invalid parameters.";

  // std value at 0.0
  std::shared_ptr<TensorOperation> zero_std_op =
    mindspore::dataset::vision::NormalizePad({121.0, 115.0, 100.0}, {0.0, 68.0, 71.0});
  EXPECT_EQ(zero_std_op, nullptr);

  // normalizepad with 2 values (not 3 values) for mean
  std::shared_ptr<TensorOperation> short_mean_op =
    mindspore::dataset::vision::NormalizePad({121.0, 115.0}, {70.0, 68.0, 71.0});
  EXPECT_EQ(short_mean_op, nullptr);

  // normalizepad with 2 values (not 3 values) for standard deviation
  std::shared_ptr<TensorOperation> short_std_op =
    mindspore::dataset::vision::NormalizePad({121.0, 115.0, 100.0}, {68.0, 71.0});
  EXPECT_EQ(short_std_op, nullptr);

  // normalizepad with invalid dtype
  std::shared_ptr<TensorOperation> bad_dtype_op =
    mindspore::dataset::vision::NormalizePad({121.0, 115.0, 100.0}, {68.0, 71.0, 71.0}, "123");
  EXPECT_EQ(bad_dtype_op, nullptr);
}
TEST_F(MindDataTestPipeline, TestPad) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPad.";

@ -0,0 +1,61 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common.h"
#include "common/cvop_common.h"
#include "minddata/dataset/kernels/image/normalize_pad_op.h"
#include "minddata/dataset/core/cv_tensor.h"
#include "utils/log_adapter.h"
#include <opencv2/opencv.hpp>
using namespace mindspore::dataset;
using mindspore::MsLogLevel::INFO;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::LogStream;
// Test fixture for the NormalizePadOp unit tests; built on the shared CVOpCommon fixture.
class MindDataTestNormalizePadOP : public UT::CVOP::CVOpCommon {
public:
MindDataTestNormalizePadOP() : CVOpCommon() {}
};
// NormalizePadOp with float32 output must be one-to-one and compute successfully
// on the fixture's input tensor.
TEST_F(MindDataTestNormalizePadOP, TestFloat32) {
  MS_LOG(INFO) << "Doing TestNormalizePadOp::TestFloat32.";

  // Numbers are from the resnet50 model implementation
  float channel_mean[3] = {121.0, 115.0, 100.0};
  float channel_std[3] = {70.0, 68.0, 71.0};

  // NormalizePad Op
  auto normalize_pad = std::make_unique<NormalizePadOp>(channel_mean[0], channel_mean[1], channel_mean[2],
                                                        channel_std[0], channel_std[1], channel_std[2], "float32");
  EXPECT_TRUE(normalize_pad->OneToOne());

  std::shared_ptr<Tensor> result;
  Status rc = normalize_pad->Compute(input_tensor_, &result);
  EXPECT_TRUE(rc.IsOk());
}
// NormalizePadOp with float16 output must be one-to-one and compute successfully
// on the fixture's input tensor.
TEST_F(MindDataTestNormalizePadOP, TestFloat16) {
  MS_LOG(INFO) << "Doing TestNormalizePadOp::TestFloat16.";

  // Numbers are from the resnet50 model implementation
  float channel_mean[3] = {121.0, 115.0, 100.0};
  float channel_std[3] = {70.0, 68.0, 71.0};

  // NormalizePad Op
  auto normalize_pad = std::make_unique<NormalizePadOp>(channel_mean[0], channel_mean[1], channel_mean[2],
                                                        channel_std[0], channel_std[1], channel_std[2], "float16");
  EXPECT_TRUE(normalize_pad->OneToOne());

  std::shared_ptr<Tensor> result;
  Status rc = normalize_pad->Compute(input_tensor_, &result);
  EXPECT_TRUE(rc.IsOk());
}

@ -0,0 +1,201 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing NormalizePad op in DE
"""
import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.transforms.py_transforms
import mindspore.dataset.vision.c_transforms as c_vision
import mindspore.dataset.vision.py_transforms as py_vision
from mindspore import log as logger
from util import diff_mse, visualize_image
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
GENERATE_GOLDEN = False
def normalizepad_np(image, mean, std):
    """
    NumPy reference for normalize+pad: subtract mean, scale by 1/std, then append a zero channel.
    """
    # DE decodes the image in RGB by default, hence
    # the values here are in RGB
    pixels = np.array(image, np.float32)
    pixels = pixels - np.array(mean)
    pixels = pixels * (1.0 / np.array(std))
    height, width = pixels.shape[0], pixels.shape[1]
    zero_channel = np.zeros([height, width, 1], dtype=np.float32)
    return np.concatenate((pixels, zero_channel), axis=2)
def test_normalizepad_op_c(plot=False):
"""
Test NormalizePad in cpp transformations
"""
logger.info("Test Normalize in cpp")
mean = [121.0, 115.0, 100.0]
std = [70.0, 68.0, 71.0]
# define map operations
decode_op = c_vision.Decode()
normalizepad_op = c_vision.NormalizePad(mean, std)
# First dataset
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
data1 = data1.map(operations=decode_op, input_columns=["image"])
data1 = data1.map(operations=normalizepad_op, input_columns=["image"])
# Second dataset
# Decoded only; it supplies the input for the NumPy reference implementation.
data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
data2 = data2.map(operations=decode_op, input_columns=["image"])
num_iter = 0
for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
image_de_normalized = item1["image"]
image_original = item2["image"]
# Compare the op's output against normalizepad_np applied to the decoded image.
image_np_normalized = normalizepad_np(image_original, mean, std)
mse = diff_mse(image_de_normalized, image_np_normalized)
logger.info("image_{}, mse: {}".format(num_iter + 1, mse))
assert mse < 0.01
if plot:
visualize_image(image_original, image_de_normalized, mse, image_np_normalized)
num_iter += 1
def test_normalizepad_op_py(plot=False):
"""
Test NormalizePad in python transformations
"""
logger.info("Test Normalize in python")
mean = [0.475, 0.45, 0.392]
std = [0.275, 0.267, 0.278]
# define map operations
transforms = [
py_vision.Decode(),
py_vision.ToTensor()
]
transform = mindspore.dataset.transforms.py_transforms.Compose(transforms)
normalizepad_op = py_vision.NormalizePad(mean, std)
# First dataset
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
data1 = data1.map(operations=transform, input_columns=["image"])
data1 = data1.map(operations=normalizepad_op, input_columns=["image"])
# Second dataset
# Decode + ToTensor only; it supplies the input for the NumPy reference implementation.
data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
data2 = data2.map(operations=transform, input_columns=["image"])
num_iter = 0
for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
# Transpose to channel-last, scale by 255 and cast to uint8 before comparing.
image_de_normalized = (item1["image"].transpose(1, 2, 0) * 255).astype(np.uint8)
image_np_normalized = (normalizepad_np(item2["image"].transpose(1, 2, 0), mean, std) * 255).astype(np.uint8)
image_original = (item2["image"].transpose(1, 2, 0) * 255).astype(np.uint8)
mse = diff_mse(image_de_normalized, image_np_normalized)
logger.info("image_{}, mse: {}".format(num_iter + 1, mse))
assert mse < 0.01
if plot:
visualize_image(image_original, image_de_normalized, mse, image_np_normalized)
num_iter += 1
def test_decode_normalizepad_op():
"""
Test Decode op followed by NormalizePad op
"""
logger.info("Test [Decode, Normalize] in one Map")
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image", "label"], num_parallel_workers=1,
shuffle=False)
# define map operations
decode_op = c_vision.Decode()
normalizepad_op = c_vision.NormalizePad([121.0, 115.0, 100.0], [70.0, 68.0, 71.0], "float16")
# apply map operations on images
data1 = data1.map(operations=[decode_op, normalizepad_op], input_columns=["image"])
num_iter = 0
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
logger.info("Looping inside iterator {}".format(num_iter))
# dtype="float16" was requested above, so the output array must be float16.
assert item["image"].dtype == np.float16
num_iter += 1
def test_normalizepad_exception_unequal_size_c():
    """
    Test NormalizePad in c transformation with invalid arguments:
    len(mean) != len(std) raises ValueError, a non-string dtype raises TypeError,
    and an unsupported dtype string raises ValueError.
    """
    logger.info("test_normalize_exception_unequal_size_c")
    try:
        _ = c_vision.NormalizePad([100, 250, 125], [50, 50, 75, 75])
    except ValueError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert str(e) == "Length of mean and std must be equal."
    else:
        # Without this branch the test would pass silently if validation were skipped.
        raise AssertionError("Expected ValueError for mean/std of unequal length.")

    try:
        _ = c_vision.NormalizePad([100, 250, 125], [50, 50, 75], 1)
    except TypeError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert str(e) == "dtype should be string."
    else:
        raise AssertionError("Expected TypeError for a non-string dtype.")

    try:
        _ = c_vision.NormalizePad([100, 250, 125], [50, 50, 75], "")
    except ValueError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert str(e) == "dtype only support float32 or float16."
    else:
        raise AssertionError("Expected ValueError for an unsupported dtype.")
def test_normalizepad_exception_unequal_size_py():
    """
    Test NormalizePad in python transformation with invalid arguments:
    len(mean) != len(std) raises ValueError, a non-string dtype raises TypeError,
    and an unsupported dtype string raises ValueError.
    """
    logger.info("test_normalizepad_exception_unequal_size_py")
    try:
        _ = py_vision.NormalizePad([0.50, 0.30, 0.75], [0.18, 0.32, 0.71, 0.72])
    except ValueError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert str(e) == "Length of mean and std must be equal."
    else:
        # Without this branch the test would pass silently if validation were skipped.
        raise AssertionError("Expected ValueError for mean/std of unequal length.")

    try:
        _ = py_vision.NormalizePad([0.50, 0.30, 0.75], [0.18, 0.32, 0.71], 1)
    except TypeError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert str(e) == "dtype should be string."
    else:
        raise AssertionError("Expected TypeError for a non-string dtype.")

    try:
        _ = py_vision.NormalizePad([0.50, 0.30, 0.75], [0.18, 0.32, 0.71], "")
    except ValueError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert str(e) == "dtype only support float32 or float16."
    else:
        raise AssertionError("Expected ValueError for an unsupported dtype.")
def test_normalizepad_exception_invalid_range_py():
    """
    Test NormalizePad in python transformation: a mean value outside the accepted
    interval (1.25 here) is expected to raise ValueError.
    """
    logger.info("test_normalizepad_exception_invalid_range_py")
    try:
        _ = py_vision.NormalizePad([0.75, 1.25, 0.5], [0.1, 0.18, 1.32])
    except ValueError as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
        assert "Input mean_value is not within the required interval of (0.0 to 1.0)." in str(e)
    else:
        # Without this branch the test would pass silently if validation were skipped.
        raise AssertionError("Expected ValueError for a mean value outside the valid range.")
Loading…
Cancel
Save