Fix bugs of minddata transforms

pull/12883/head
YangLuo 4 years ago
parent 853354b07f
commit 74431ad117

@ -846,6 +846,12 @@ class SoftDvppDecodeRandomCropResizeJpeg : public TensorTransform {
/// \param[in] size A vector representing the output size of the resized image.
/// If size is a single value, smaller edge of the image will be resized to this value with
/// the same image aspect ratio. If size has 2 values, it should be (height, width).
/// \param[in] scale Range [min, max) of respective size of the original
/// size to be cropped (default=(0.08, 1.0)).
/// \param[in] ratio Range [min, max) of aspect ratio to be cropped
/// (default=(3. / 4., 4. / 3.)).
/// \param[in] max_attempts The maximum number of attempts to propose a valid
/// crop_area (default=10). If exceeded, fall back to use center_crop instead.
SoftDvppDecodeRandomCropResizeJpeg(std::vector<int32_t> size, std::vector<float> scale = {0.08, 1.0},
std::vector<float> ratio = {3. / 4., 4. / 3.}, int32_t max_attempts = 10);

@ -72,6 +72,10 @@ bool CheckTensorShape(const std::shared_ptr<Tensor> &tensor, const int &channel)
Status Flip(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, int flip_code) {
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(std::move(input));
if (input_cv->Rank() == 1 || input_cv->mat().dims > 2) {
RETURN_STATUS_UNEXPECTED("Flip: input tensor is not in shape of <H,W,C> or <H,W>.");
}
std::shared_ptr<CVTensor> output_cv;
RETURN_IF_NOT_OK(CVTensor::CreateEmpty(input_cv->shape(), input_cv->type(), &output_cv));
@ -583,9 +587,13 @@ Status Rotate(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out
if (!input_cv->mat().data) {
RETURN_STATUS_UNEXPECTED("Rotate: load image failed.");
}
if (input_cv->Rank() == 1 || input_cv->mat().dims > 2) {
RETURN_STATUS_UNEXPECTED("Rotate: input tensor is not in shape of <H,W,C> or <H,W>.");
}
cv::Mat input_img = input_cv->mat();
if (input_img.cols > (MAX_INT_PRECISION * 2) || input_img.rows > (MAX_INT_PRECISION * 2)) {
RETURN_STATUS_UNEXPECTED("Rotate: image is too large and center not precise");
RETURN_STATUS_UNEXPECTED("Rotate: image is too large and center is not precise.");
}
// default to center of image
if (fx == -1 && fy == -1) {
@ -728,7 +736,7 @@ Status AdjustBrightness(const std::shared_ptr<Tensor> &input, std::shared_ptr<Te
}
int num_channels = input_cv->shape()[2];
if (input_cv->Rank() != 3 || num_channels != 3) {
RETURN_STATUS_UNEXPECTED("AdjustBrightness: image shape is not <H,W,C>.");
RETURN_STATUS_UNEXPECTED("AdjustBrightness: image shape is not <H,W,C> or channel is not 3.");
}
std::shared_ptr<CVTensor> output_cv;
RETURN_IF_NOT_OK(CVTensor::CreateEmpty(input_cv->shape(), input_cv->type(), &output_cv));
@ -749,7 +757,7 @@ Status AdjustContrast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tens
}
int num_channels = input_cv->shape()[2];
if (input_cv->Rank() != 3 || num_channels != 3) {
RETURN_STATUS_UNEXPECTED("AdjustContrast: image shape is not <H,W,C>.");
RETURN_STATUS_UNEXPECTED("AdjustContrast: image shape is not <H,W,C> or channel is not 3.");
}
cv::Mat gray, output_img;
cv::cvtColor(input_img, gray, CV_RGB2GRAY);
@ -854,7 +862,7 @@ Status AdjustSaturation(const std::shared_ptr<Tensor> &input, std::shared_ptr<Te
}
int num_channels = input_cv->shape()[2];
if (input_cv->Rank() != 3 || num_channels != 3) {
RETURN_STATUS_UNEXPECTED("AdjustSaturation: image shape is not <H,W,C>.");
RETURN_STATUS_UNEXPECTED("AdjustSaturation: image shape is not <H,W,C> or channel is not 3.");
}
std::shared_ptr<CVTensor> output_cv;
RETURN_IF_NOT_OK(CVTensor::CreateEmpty(input_cv->shape(), input_cv->type(), &output_cv));
@ -882,7 +890,7 @@ Status AdjustHue(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *
}
int num_channels = input_cv->shape()[2];
if (input_cv->Rank() != 3 || num_channels != 3) {
RETURN_STATUS_UNEXPECTED("AdjustHue: image shape is not <H,W,C>.");
RETURN_STATUS_UNEXPECTED("AdjustHue: image shape is not <H,W,C> or channel is not 3.");
}
std::shared_ptr<CVTensor> output_cv;
RETURN_IF_NOT_OK(CVTensor::CreateEmpty(input_cv->shape(), input_cv->type(), &output_cv));
@ -956,7 +964,7 @@ Status Erase(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *outp
RETURN_STATUS_UNEXPECTED("CutOut: load image failed.");
}
if (input_cv->Rank() != 3 || num_channels != 3) {
RETURN_STATUS_UNEXPECTED("CutOut: image shape is not <H,W,C> or <H,W>.");
RETURN_STATUS_UNEXPECTED("CutOut: image shape is not <H,W,C> or channel is not 3.");
}
cv::Mat input_img = input_cv->mat();
int32_t image_h = input_cv->shape()[0];
@ -1016,6 +1024,12 @@ Status Pad(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output
try {
// input image
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
// validate rank
if (input_cv->Rank() == 1 || input_cv->mat().dims > 2) {
RETURN_STATUS_UNEXPECTED("Pad: input tensor is not in shape of <H,W,C> or <H,W>.");
}
// get the border type in openCV
auto b_type = GetCVBorderType(border_types);
// output image
@ -1106,6 +1120,10 @@ Status Affine(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out
InterpolationMode interpolation, uint8_t fill_r, uint8_t fill_g, uint8_t fill_b) {
try {
std::shared_ptr<CVTensor> input_cv = CVTensor::AsCVTensor(input);
if (input_cv->Rank() != 3 || input_cv->shape()[2] != 3) {
RETURN_STATUS_UNEXPECTED("Affine: image shape is not <H,W,C> or channel is not 3.");
}
cv::Mat affine_mat(mat);
affine_mat = affine_mat.reshape(1, {2, 3});

@ -25,8 +25,8 @@ RandomColorOp::RandomColorOp(float t_lb, float t_ub) : rnd_(GetSeed()), dist_(t_
Status RandomColorOp::Compute(const std::shared_ptr<Tensor> &in, std::shared_ptr<Tensor> *out) {
IO_CHECK(in, out);
if (in->Rank() != 3) {
RETURN_STATUS_UNEXPECTED("RandomColor: image shape is not <H,W,C>.");
if (in->Rank() != 3 || in->shape()[2] != 3) {
RETURN_STATUS_UNEXPECTED("RandomColor: image shape is not <H,W,C> or channel is not 3.");
}
// 0.5 pixel precision assuming an 8 bit image
const auto eps = 0.00195;

@ -34,8 +34,8 @@ Status SharpnessOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_pt
RETURN_STATUS_UNEXPECTED("Sharpness: load image failed.");
}
if (input_cv->Rank() != 3 && input_cv->Rank() != 2) {
RETURN_STATUS_UNEXPECTED("Sharpness: image shape is not <H,W,C> or <H,W>");
if (input_cv->Rank() == 1 || input_cv->mat().dims > 2) {
RETURN_STATUS_UNEXPECTED("Sharpness: input tensor is not in shape of <H,W,C> or <H,W>.");
}
/// creating a smoothing filter. 1, 1, 1,

@ -39,16 +39,6 @@ Status SolarizeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr
RETURN_STATUS_UNEXPECTED("Solarize: load image failed.");
}
if (input_cv->Rank() != 2 && input_cv->Rank() != 3) {
RETURN_STATUS_UNEXPECTED("Solarize: image shape is not <H,W,C> or <H,W>.");
}
if (input_cv->Rank() == 3) {
int num_channels = input_cv->shape()[2];
if (num_channels != 3 && num_channels != 1) {
RETURN_STATUS_UNEXPECTED("Solarize: image shape is not <H,W,C>.");
}
}
std::shared_ptr<CVTensor> mask_mat_tensor;
std::shared_ptr<CVTensor> output_cv_tensor;
RETURN_IF_NOT_OK(CVTensor::CreateFromMat(input_cv->mat(), &mask_mat_tensor));

@ -160,8 +160,8 @@ Status ValidateVectorRatio(const std::string &op_name, const std::vector<float>
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
RETURN_IF_NOT_OK(ValidateScalar(op_name, "scale", ratio[0], {0}, true));
RETURN_IF_NOT_OK(ValidateScalar(op_name, "scale", ratio[1], {0}, true));
RETURN_IF_NOT_OK(ValidateScalar(op_name, "ratio", ratio[0], {0}, true));
RETURN_IF_NOT_OK(ValidateScalar(op_name, "ratio", ratio[1], {0}, true));
if (ratio[1] < ratio[0]) {
std::string err_msg = op_name + ": ratio must be in the format of (min, max).";
MS_LOG(ERROR) << op_name + ": ratio must be in the format of (min, max), but got: " << ratio;

@ -187,7 +187,7 @@ BoundingBoxAugmentOperation::BoundingBoxAugmentOperation(std::shared_ptr<TensorO
// Validate the parameters of the BoundingBoxAugment operation.
// Runs the transform check and the ratio checks in order, each wrapped in RETURN_IF_NOT_OK,
// and returns Status::OK() once all of them have been executed.
Status BoundingBoxAugmentOperation::ValidateParams() {
// The single wrapped transform is passed as a one-element list to the vector validator.
RETURN_IF_NOT_OK(ValidateVectorTransforms("BoundingBoxAugment", {transform_}));
// NOTE(review): ratio_ is validated by both ValidateProbability and ValidateScalar below;
// this looks redundant - confirm whether only one of the two checks is intended.
RETURN_IF_NOT_OK(ValidateProbability("BoundingBoxAugment", ratio_));
// Checks ratio_ against the bounds {0.0, 1.0}; presumably the two false flags make both
// bounds inclusive - TODO confirm against ValidateScalar's signature.
RETURN_IF_NOT_OK(ValidateScalar("BoundingBoxAugment", "ratio", ratio_, {0.0, 1.0}, false, false));
return Status::OK();
}
@ -1566,7 +1566,8 @@ Status UniformAugOperation::ValidateParams() {
// transforms
RETURN_IF_NOT_OK(ValidateVectorTransforms("UniformAug", transforms_));
if (num_ops_ > transforms_.size()) {
std::string err_msg = "UniformAug: num_ops is greater than transforms size, but got: " + std::to_string(num_ops_);
std::string err_msg =
"UniformAug: num_ops must be less than or equal to transforms size, but got: " + std::to_string(num_ops_);
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}

@ -387,5 +387,13 @@ def check_tensor_op(param, param_name):
raise TypeError("{0} is neither a c_transform op (TensorOperation) nor a callable pyfunc.".format(param_name))
def check_c_tensor_op(param, param_name):
    """
    Check that param is a c_transform op or a callable Python function, but not a py_transform op.

    Args:
        param: The object to validate.
        param_name (str): Name used for param in the error message.

    Raises:
        TypeError: If param is a callable py_transform op, or if it is neither callable,
            nor an object with a truthy ``parse`` attribute, nor a ``cde.TensorOp``.
    """
    # The previous test `getattr(param, 'parse', True)` was truthy both when `parse` was missing
    # (default True) and when it existed (bound method), so every callable was rejected.
    # NOTE(review): py_transform ops are recognised by the module their class is defined in;
    # this assumes those modules are named `...py_transforms` - TODO confirm.
    defining_module = getattr(type(param), "__module__", "") or ""
    if callable(param) and "py_transform" in defining_module:
        raise TypeError("{0} is a py_transform op which is not allow to use.".format(param_name))
    # Same three conditions as before; the cheap attribute checks run first.
    if not callable(param) and not getattr(param, 'parse', None) and not isinstance(param, cde.TensorOp):
        raise TypeError("{0} is neither a c_transform op (TensorOperation) nor a callable pyfunc.".format(param_name))
def replace_none(value, default):
    """Return ``default`` when ``value`` is None; otherwise return ``value`` unchanged."""
    if value is None:
        return default
    return value

@ -34,5 +34,5 @@ __all__ = ["CelebADataset", "Cifar100Dataset", "Cifar10Dataset", "CLUEDataset",
"GeneratorDataset", "GraphData", "ImageFolderDataset", "ManifestDataset", "MindDataset", "MnistDataset",
"NumpySlicesDataset", "PaddedDataset", "TextFileDataset", "TFRecordDataset", "VOCDataset",
"DistributedSampler", "PKSampler", "RandomSampler", "SequentialSampler", "SubsetRandomSampler",
"WeightedRandomSampler",
"WeightedRandomSampler", "SubsetSampler",
"config", "DatasetCache", "Schema", "zip"]

@ -52,8 +52,8 @@ import mindspore.common.dtype as mstype
from .utils import JiebaMode, NormalizeForm, to_str, SPieceTokenizerOutType, SPieceTokenizerLoadType
from .validators import check_lookup, check_jieba_add_dict, \
check_jieba_add_word, check_jieba_init, check_with_offsets, check_unicode_script_tokenizer, \
check_wordpiece_tokenizer, check_regex_tokenizer, check_basic_tokenizer, check_ngram, check_pair_truncate, \
check_to_number, check_bert_tokenizer, check_python_tokenizer, check_slidingwindow
check_wordpiece_tokenizer, check_regex_replace, check_regex_tokenizer, check_basic_tokenizer, check_ngram, \
check_pair_truncate, check_to_number, check_bert_tokenizer, check_python_tokenizer, check_slidingwindow
from ..core.datatypes import mstype_to_detype
from ..core.validator_helpers import replace_none
from ..transforms.c_transforms import TensorOperation
@ -756,6 +756,7 @@ if platform.system().lower() != 'windows':
>>> text_file_dataset = text_file_dataset.map(operations=replace_op)
"""
@check_regex_replace
def __init__(self, pattern, replace, replace_all=True):
self.pattern = pattern
self.replace = replace

@ -216,6 +216,20 @@ def check_wordpiece_tokenizer(method):
return new_method
def check_regex_replace(method):
    """
    Decorator that type-checks the arguments of RegexReplace before calling ``method``.

    ``pattern`` and ``replace`` are checked against ``str`` and ``replace_all`` against ``bool``,
    in that order, using ``type_check``; the wrapped method is then called with the original
    arguments and its result is returned.
    """

    @wraps(method)
    def checked(self, *args, **kwargs):
        (pattern, replace, replace_all), _ = parse_user_args(method, *args, **kwargs)
        expectations = (
            (pattern, (str,), "pattern"),
            (replace, (str,), "replace"),
            (replace_all, (bool,), "replace_all"),
        )
        for value, allowed_types, arg_name in expectations:
            type_check(value, allowed_types, arg_name)
        return method(self, *args, **kwargs)

    return checked
def check_regex_tokenizer(method):
"""Wrapper method to check the parameter of RegexTokenizer."""

@ -133,9 +133,9 @@ class _SliceOption(cde.SliceOption):
1. :py:obj:`int`: Slice this index only along the dimension. Negative index is supported.
2. :py:obj:`list(int)`: Slice these indices along the dimension. Negative indices are supported.
3. :py:obj:`slice`: Slice the generated indices from the slice object along the dimension.
4. :py:obj:`None`: Slice the whole dimension. Similar to `:` in Python indexing.
5. :py:obj:`Ellipsis`: Slice the whole dimension. Similar to `:` in Python indexing.
6. :py:obj:`boolean`: Slice the whole dimension. Similar to `:` in Python indexing.
4. :py:obj:`None`: Slice the whole dimension. Similar to :py:obj:`:` in Python indexing.
5. :py:obj:`Ellipsis`: Slice the whole dimension. Similar to :py:obj:`:` in Python indexing.
6. :py:obj:`boolean`: Slice the whole dimension. Similar to :py:obj:`:` in Python indexing.
"""
@check_slice_option
@ -165,8 +165,8 @@ class Slice(cde.SliceOp):
2. :py:obj:`list(int)`: Slice these indices along the first dimension. Negative indices are supported.
3. :py:obj:`slice`: Slice the generated indices from the slice object along the first dimension.
Similar to start:stop:step.
4. :py:obj:`None`: Slice the whole dimension. Similar to `:` in Python indexing.
5. :py:obj:`Ellipsis`: Slice the whole dimension. Similar to `:` in Python indexing.
4. :py:obj:`None`: Slice the whole dimension. Similar to :py:obj:`:` in Python indexing.
5. :py:obj:`Ellipsis`: Slice the whole dimension, same result as `None`.
Examples:
>>> # Data before

@ -271,7 +271,7 @@ class Decode(ImageTensorOperation):
img (NumPy), Decoded image.
"""
if not isinstance(img, np.ndarray) or img.ndim != 1 or img.dtype.type is np.str_:
raise TypeError("Input should be an encoded image with 1-D NumPy type, got {}.".format(type(img)))
raise TypeError("Input should be an encoded image in 1-D NumPy format, got {}.".format(type(img)))
return super().__call__(img)
def parse(self):
@ -763,6 +763,14 @@ class RandomCropDecodeResize(ImageTensorOperation):
DE_C_INTER_MODE[self.interpolation],
self.max_attempts)
def __call__(self, img):
    """
    Call method for eager execution on a single encoded image.

    Args:
        img (numpy.ndarray): Encoded image given as a 1-D NumPy array of uint8.

    Returns:
        The value returned by the parent class ``__call__`` for ``img``.

    Raises:
        TypeError: If img is not a NumPy array, is not 1-D, or its dtype is not uint8.
    """
    if not isinstance(img, np.ndarray):
        raise TypeError("Input should be an encoded image in 1-D NumPy format, got {}.".format(type(img)))
    if img.ndim != 1 or img.dtype.type is not np.uint8:
        raise TypeError("Input should be an encoded image with uint8 type in 1-D NumPy format, " +
                        "got format:{}, dtype:{}.".format(type(img), img.dtype.type))
    # Return the parent's result (it was previously discarded, so callers got None) and pass
    # img positionally, the same way Decode.__call__ does.
    return super().__call__(img)
class RandomCropWithBBox(ImageTensorOperation):
"""
@ -1164,8 +1172,8 @@ class RandomSharpness(ImageTensorOperation):
degree of 1.0 gives the original image, and degree of 2.0 gives a sharpened image.
Args:
degrees (tuple, optional): Range of random sharpness adjustment degrees. It should be in (min, max) format.
If min=max, then it is a single fixed magnitude operation (default = (0.1, 1.9)).
degrees (Union[list, tuple], optional): Range of random sharpness adjustment degrees. It should be in
(min, max) format. If min=max, then it is a single fixed magnitude operation (default = (0.1, 1.9)).
Raises:
TypeError : If degrees is not a list or tuple.

@ -21,7 +21,7 @@ from mindspore._c_dataengine import TensorOp, TensorOperation
from mindspore.dataset.core.validator_helpers import check_value, check_uint8, FLOAT_MAX_INTEGER, check_pos_float32, \
check_float32, check_2tuple, check_range, check_positive, INT32_MAX, parse_user_args, type_check, type_check_list, \
check_tensor_op, UINT8_MAX, check_value_normalize_std
check_c_tensor_op, UINT8_MAX, check_value_normalize_std
from .utils import Inter, Border, ImageBatchFormat
@ -727,7 +727,7 @@ def check_random_select_subpolicy_op(method):
raise ValueError("policy[{0}] can not be empty.".format(sub_ind))
for op_ind, tp in enumerate(sub):
check_2tuple(tp, "policy[{0}][{1}]".format(sub_ind, op_ind))
check_tensor_op(tp[0], "op of (op, prob) in policy[{0}][{1}]".format(sub_ind, op_ind))
check_c_tensor_op(tp[0], "op of (op, prob) in policy[{0}][{1}]".format(sub_ind, op_ind))
check_value(tp[1], (0, 1), "prob of (op, prob) policy[{0}][{1}]".format(sub_ind, op_ind))
return method(self, *args, **kwargs)

@ -43,11 +43,12 @@ TEST_F(MindDataTestPipeline, TestTFRecordDatasetBasic) {
EXPECT_NE(ds, nullptr);
// Create objects for the tensor ops
std::shared_ptr<TensorTransform> decode_op = std::make_shared<vision::Decode>();
std::shared_ptr<TensorTransform> random_horizontal_flip_op = std::make_shared<vision::RandomHorizontalFlip>(0.5);
EXPECT_NE(random_horizontal_flip_op, nullptr);
// Create a Map operation on ds
ds = ds->Map({random_horizontal_flip_op}, {}, {}, {"image"});
ds = ds->Map({decode_op, random_horizontal_flip_op}, {}, {}, {"image"});
EXPECT_NE(ds, nullptr);
// Create a Batch operation on ds

@ -95,14 +95,14 @@ def test_eager_exceptions():
img = C.Decode()(img)
assert False
except TypeError as e:
assert "Input should be an encoded image with 1-D NumPy type" in str(e)
assert "Input should be an encoded image in 1-D NumPy format" in str(e)
try:
img = np.array(["a", "b", "c"])
img = C.Decode()(img)
assert False
except TypeError as e:
assert "Input should be an encoded image with 1-D NumPy type" in str(e)
assert "Input should be an encoded image in 1-D NumPy format" in str(e)
try:
img = cv2.imread("../data/dataset/apple.jpg")

@ -239,7 +239,7 @@ def test_random_color_c_errors():
with pytest.raises(RuntimeError) as error_info:
for _ in enumerate(mnist_ds):
pass
assert "Invalid number of channels in input image" in str(error_info.value)
assert "image shape is not <H,W,C> or channel is not 3" in str(error_info.value)
if __name__ == "__main__":

Loading…
Cancel
Save