!11727 [MD] Add vision ops and py_transform ops support to Serdes save

From: @tina_mengting_zhang
Reviewed-by: 
Signed-off-by:
pull/11727/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 544da6f845

(File diff suppressed because it is too large.)

@ -155,8 +155,14 @@ Status MapNode::to_json(nlohmann::json *out_json) {
for (auto op : operations_) {
nlohmann::json op_args;
RETURN_IF_NOT_OK(op->to_json(&op_args));
op_args["tensor_op_name"] = op->Name();
ops.push_back(op_args);
if (op->Name() == "PyFuncOp") {
ops.push_back(op_args);
} else {
nlohmann::json op_item;
op_item["tensor_op_params"] = op_args;
op_item["tensor_op_name"] = op->Name();
ops.push_back(op_item);
}
}
args["operations"] = ops;
std::transform(callbacks_.begin(), callbacks_.end(), std::back_inserter(cbs),
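For reference, here is a minimal sketch of the per-op entries this loop now writes into the "operations" array, expressed as Python literals (the field names come from the diff; the op names and parameter values are purely illustrative):

operations = [
    # A C++ tensor op gets wrapped: its own to_json output goes under "tensor_op_params",
    # with the op name recorded alongside it.
    {"tensor_op_name": "Resize", "tensor_op_params": {"size": [224, 224]}},
    # A PyFuncOp's to_json output is appended as-is; it already carries its name,
    # its pickled payload, and the is_python_front_end_op marker.
    {"tensor_op_name": "Compose", "tensor_op_params": "<pickled transform>", "is_python_front_end_op": True},
]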

@ -519,6 +519,8 @@ class AutoContrastOperation : public TensorOperation {
std::string Name() const override { return kAutoContrastOperation; }
Status to_json(nlohmann::json *out_json) override;
private:
float cutoff_;
std::vector<uint32_t> ignore_;
@ -536,6 +538,8 @@ class BoundingBoxAugmentOperation : public TensorOperation {
std::string Name() const override { return kBoundingBoxAugmentOperation; }
Status to_json(nlohmann::json *out_json) override;
private:
std::shared_ptr<TensorOperation> transform_;
float ratio_;
@ -553,6 +557,8 @@ class CutMixBatchOperation : public TensorOperation {
std::string Name() const override { return kCutMixBatchOperation; }
Status to_json(nlohmann::json *out_json) override;
private:
float alpha_;
float prob_;
@ -571,6 +577,8 @@ class CutOutOperation : public TensorOperation {
std::string Name() const override { return kCutOutOperation; }
Status to_json(nlohmann::json *out_json) override;
private:
int32_t length_;
int32_t num_patches_;
@ -638,6 +646,8 @@ class MixUpBatchOperation : public TensorOperation {
std::string Name() const override { return kMixUpBatchOperation; }
Status to_json(nlohmann::json *out_json) override;
private:
float alpha_;
};
@ -655,6 +665,8 @@ class NormalizePadOperation : public TensorOperation {
std::string Name() const override { return kNormalizePadOperation; }
Status to_json(nlohmann::json *out_json) override;
private:
std::vector<float> mean_;
std::vector<float> std_;
@ -698,6 +710,8 @@ class RandomAffineOperation : public TensorOperation {
std::string Name() const override { return kRandomAffineOperation; }
Status to_json(nlohmann::json *out_json) override;
private:
std::vector<float_t> degrees_; // min_degree, max_degree
std::vector<float_t> translate_range_; // maximum x translation percentage, maximum y translation percentage
@ -719,6 +733,8 @@ class RandomColorOperation : public TensorOperation {
std::string Name() const override { return kRandomColorOperation; }
Status to_json(nlohmann::json *out_json) override;
private:
float t_lb_;
float t_ub_;
@ -788,6 +804,8 @@ class RandomResizedCropOperation : public TensorOperation {
std::string Name() const override { return kRandomResizedCropOperation; }
Status to_json(nlohmann::json *out_json) override;
protected:
std::vector<int32_t> size_;
std::vector<float> scale_;
@ -808,6 +826,8 @@ class RandomCropDecodeResizeOperation : public RandomResizedCropOperation {
std::shared_ptr<TensorOp> Build() override;
std::string Name() const override { return kRandomCropDecodeResizeOperation; }
Status to_json(nlohmann::json *out_json) override;
};
class RandomCropWithBBoxOperation : public TensorOperation {
@ -824,6 +844,8 @@ class RandomCropWithBBoxOperation : public TensorOperation {
std::string Name() const override { return kRandomCropWithBBoxOperation; }
Status to_json(nlohmann::json *out_json) override;
private:
std::vector<int32_t> size_;
std::vector<int32_t> padding_;
@ -844,6 +866,8 @@ class RandomHorizontalFlipOperation : public TensorOperation {
std::string Name() const override { return kRandomHorizontalFlipOperation; }
Status to_json(nlohmann::json *out_json) override;
private:
float probability_;
};
@ -860,6 +884,8 @@ class RandomHorizontalFlipWithBBoxOperation : public TensorOperation {
std::string Name() const override { return kRandomHorizontalFlipWithBBoxOperation; }
Status to_json(nlohmann::json *out_json) override;
private:
float probability_;
};
@ -876,6 +902,8 @@ class RandomPosterizeOperation : public TensorOperation {
std::string Name() const override { return kRandomPosterizeOperation; }
Status to_json(nlohmann::json *out_json) override;
private:
std::vector<uint8_t> bit_range_;
};
@ -892,6 +920,8 @@ class RandomResizeOperation : public TensorOperation {
std::string Name() const override { return kRandomResizeOperation; }
Status to_json(nlohmann::json *out_json) override;
private:
std::vector<int32_t> size_;
};
@ -908,6 +938,8 @@ class RandomResizeWithBBoxOperation : public TensorOperation {
std::string Name() const override { return kRandomResizeWithBBoxOperation; }
Status to_json(nlohmann::json *out_json) override;
private:
std::vector<int32_t> size_;
};
@ -927,6 +959,8 @@ class RandomResizedCropWithBBoxOperation : public TensorOperation {
std::string Name() const override { return kRandomResizedCropWithBBoxOperation; }
Status to_json(nlohmann::json *out_json) override;
private:
std::vector<int32_t> size_;
std::vector<float> scale_;
@ -971,6 +1005,8 @@ class RandomSelectSubpolicyOperation : public TensorOperation {
std::string Name() const override { return kRandomSelectSubpolicyOperation; }
Status to_json(nlohmann::json *out_json) override;
private:
std::vector<std::vector<std::pair<std::shared_ptr<TensorOperation>, double>>> policy_;
};
@ -987,6 +1023,8 @@ class RandomSharpnessOperation : public TensorOperation {
std::string Name() const override { return kRandomSharpnessOperation; }
Status to_json(nlohmann::json *out_json) override;
private:
std::vector<float> degrees_;
};
@ -1003,6 +1041,8 @@ class RandomSolarizeOperation : public TensorOperation {
std::string Name() const override { return kRandomSolarizeOperation; }
Status to_json(nlohmann::json *out_json) override;
private:
std::vector<uint8_t> threshold_;
};
@ -1019,6 +1059,8 @@ class RandomVerticalFlipOperation : public TensorOperation {
std::string Name() const override { return kRandomVerticalFlipOperation; }
Status to_json(nlohmann::json *out_json) override;
private:
float probability_;
};
@ -1035,6 +1077,8 @@ class RandomVerticalFlipWithBBoxOperation : public TensorOperation {
std::string Name() const override { return kRandomVerticalFlipWithBBoxOperation; }
Status to_json(nlohmann::json *out_json) override;
private:
float probability_;
};
@ -1071,6 +1115,8 @@ class ResizeWithBBoxOperation : public TensorOperation {
std::string Name() const override { return kResizeWithBBoxOperation; }
Status to_json(nlohmann::json *out_json) override;
private:
std::vector<int32_t> size_;
InterpolationMode interpolation_;
@ -1115,6 +1161,8 @@ class SoftDvppDecodeRandomCropResizeJpegOperation : public TensorOperation {
std::string Name() const override { return kSoftDvppDecodeRandomCropResizeJpegOperation; }
Status to_json(nlohmann::json *out_json) override;
private:
std::vector<int32_t> size_;
std::vector<float> scale_;
@ -1134,6 +1182,8 @@ class SoftDvppDecodeResizeJpegOperation : public TensorOperation {
std::string Name() const override { return kSoftDvppDecodeResizeJpegOperation; }
Status to_json(nlohmann::json *out_json) override;
private:
std::vector<int32_t> size_;
};
@ -1163,6 +1213,8 @@ class UniformAugOperation : public TensorOperation {
std::string Name() const override { return kUniformAugOperation; }
Status to_json(nlohmann::json *out_json) override;
private:
std::vector<std::shared_ptr<TensorOperation>> transforms_;
int32_t num_ops_;

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -119,5 +119,15 @@ Status PyFuncOp::CastOutput(const py::object &ret_py_obj, TensorRow *output) {
return Status::OK();
}
Status PyFuncOp::to_json(nlohmann::json *out_json) {
nlohmann::json args;
auto package = pybind11::module::import("pickle");
auto module = package.attr("dumps");
args["tensor_op_params"] = module(py_func_ptr_, 0).cast<std::string>();
args["tensor_op_name"] = py_func_ptr_.attr("__class__").attr("__name__").cast<std::string>();
args["is_python_front_end_op"] = true;
*out_json = args;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
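A rough Python-side equivalent of the round trip above, assuming py_op is any picklable Python transform (protocol 0 keeps the pickle as printable ASCII, matching the cast to std::string in PyFuncOp::to_json; the helper names here are just for illustration):

import pickle

def dump_py_op(py_op):
    # What PyFuncOp::to_json stores under "tensor_op_params" (sketch of the pybind11 call).
    return pickle.dumps(py_op, 0).decode()

def load_py_op(params_str):
    # What construct_tensor_ops() in serdes.py does for entries flagged with is_python_front_end_op.
    return pickle.loads(params_str.encode())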

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -49,6 +49,7 @@ class PyFuncOp : public TensorOp {
/// \return Status
Status CastOutput(const py::object &ret_py_obj, TensorRow *output);
std::string Name() const override { return kPyFuncOp; }
Status to_json(nlohmann::json *out_json) override;
private:
py::function py_func_ptr_;

@ -1,4 +1,4 @@
# Copyright 2019 Huawei Technologies Co., Ltd
# Copyright 2019-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -17,12 +17,13 @@ Functions to support dataset serialize and deserialize.
"""
import json
import os
import pickle
import sys
import mindspore.common.dtype as mstype
from mindspore import log as logger
from . import datasets as de
from ..vision.utils import Inter, Border
from ..vision.utils import Inter, Border, ImageBatchFormat
def serialize(dataset, json_filepath=""):
@ -277,7 +278,7 @@ def create_dataset_operation_node(node, dataset_op):
tensor_ops = construct_tensor_ops(node.get('operations'))
pyobj = de.Dataset().map(tensor_ops, node.get('input_columns'), node.get('output_columns'),
node.get('column_order'), node.get('num_parallel_workers'),
True, node.get('callbacks'))
False, None, node.get('callbacks'))
elif dataset_op == 'Project':
pyobj = de.Dataset().project(node['columns'])
@ -344,66 +345,60 @@ def construct_tensor_ops(operations):
"""Instantiate tensor op object(s) based on the information from dictionary['operations']"""
result = []
for op in operations:
op_name = op['tensor_op_name']
op_module_vis = sys.modules["mindspore.dataset.vision.c_transforms"]
op_module_trans = sys.modules["mindspore.dataset.transforms.c_transforms"]
if op_name == "HwcToChw": op_name = "HWC2CHW"
if hasattr(op_module_vis, op_name):
op_class = getattr(op_module_vis, op_name)
elif hasattr(op_module_trans, op_name[:-2]):
op_name = op_name[:-2] # to remove op from the back of the name
op_class = getattr(op_module_trans, op_name)
else:
raise RuntimeError(op_name + " is not yet supported by deserialize().")
# Transforms Ops (in alphabetical order)
if op_name == 'OneHot':
result.append(op_class(op['num_classes']))
elif op_name == 'TypeCast':
result.append(op_class(to_mstype(op['data_type'])))
# Vision Ops (in alphabetical order)
elif op_name == 'CenterCrop':
result.append(op_class(op['size']))
elif op_name == 'Decode':
result.append(op_class(op.get('rgb')))
elif op_name == 'HWC2CHW':
result.append(op_class())
elif op_name == 'Normalize':
result.append(op_class(op['mean'], op['std']))
elif op_name == 'Pad':
result.append(op_class(op['padding'], tuple(op['fill_value']), Border(to_border_mode(op['padding_mode']))))
elif op_name == 'RandomColorAdjust':
result.append(op_class(op.get('brightness'), op.get('contrast'), op.get('saturation'),
op.get('hue')))
elif op_name == 'RandomCrop':
result.append(op_class(op['size'], op.get('padding'), op.get('pad_if_needed'),
tuple(op.get('fill_value')), Border(to_border_mode(op.get('padding_mode')))))
elif op_name == 'RandomRotation':
result.append(op_class(op['degrees'], to_interpolation_mode(op.get('interpolation_mode')), op.get('expand'),
tuple(op.get('center')), tuple(op.get('fill_value'))))
elif op_name == 'Rescale':
result.append(op_class(op['rescale'], op['shift']))
elif op_name == 'Resize':
result.append(op_class(op['size'], to_interpolation_mode(op.get('interpolation'))))
op_name = op.get('tensor_op_name')
op_params = op.get('tensor_op_params')
if op.get('is_python_front_end_op'): # check if it's a py_transform op
result.append(pickle.loads(op_params.encode()))
else:
raise ValueError("Tensor op name is unknown: {}.".format(op_name))
if op_name == "HwcToChw": op_name = "HWC2CHW"
if op_name == "UniformAug": op_name = "UniformAugment"
op_module_vis = sys.modules["mindspore.dataset.vision.c_transforms"]
op_module_trans = sys.modules["mindspore.dataset.transforms.c_transforms"]
if hasattr(op_module_vis, op_name):
op_class = getattr(op_module_vis, op_name, None)
elif hasattr(op_module_trans, op_name[:-2]):
op_name = op_name[:-2] # to remove op from the back of the name
op_class = getattr(op_module_trans, op_name, None)
else:
raise RuntimeError(op_name + " is not yet supported by deserialize().")
if op_params is None: # If no parameter is specified, call it directly
result.append(op_class())
else:
# Input parameter type cast
for key, val in op_params.items():
if key in ['center', 'fill_value']:
op_params[key] = tuple(val)
elif key in ['interpolation', 'resample']:
op_params[key] = Inter(to_interpolation_mode(val))
elif key in ['padding_mode']:
op_params[key] = Border(to_border_mode(val))
elif key in ['data_type']:
op_params[key] = to_mstype(val)
elif key in ['image_batch_format']:
op_params[key] = to_image_batch_format(val)
elif key in ['policy']:
op_params[key] = to_policy(val)
elif key in ['transform', 'transforms']:
op_params[key] = construct_tensor_ops(val)
result.append(op_class(**op_params))
return result
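A hypothetical call against the format written by MapNode::to_json (the op entries below are made up for illustration, not taken from a real pipeline file):

ops_json = [
    {"tensor_op_name": "Decode", "tensor_op_params": {"rgb": True}},
    {"tensor_op_name": "CenterCrop", "tensor_op_params": {"size": [32, 32]}},
]
tensor_ops = construct_tensor_ops(ops_json)
# -> [c_transforms Decode instance, CenterCrop instance], ready to pass back to .map()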
def to_policy(op_list):
policy_tensor_ops = []
for policy_list in op_list:
sub_policy_tensor_ops = []
for policy_item in policy_list:
sub_policy_tensor_ops.append(
(construct_tensor_ops(policy_item.get('tensor_op')), policy_item.get('prob')))
policy_tensor_ops.append(sub_policy_tensor_ops)
return policy_tensor_ops
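The nested structure to_policy expects mirrors RandomSelectSubpolicyOperation's policy_ member (a list of sub-policies, each holding (transforms, prob) pairs). An illustrative input, assuming each 'tensor_op' field holds a list of serialized op entries:

policy_json = [
    [{"tensor_op": [{"tensor_op_name": "RandomRotation", "tensor_op_params": {"degrees": [-10.0, 10.0]}}],
      "prob": 0.5}],
]
# to_policy(policy_json) rebuilds the (transforms, prob) pairs for RandomSelectSubpolicy.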
def to_shuffle_mode(shuffle):
if shuffle == 2: return "global"
if shuffle == 1: return "file"
@ -446,7 +441,12 @@ def to_mstype(data_type):
}[data_type]
def to_image_batch_format(image_batch_format):
return {
0: ImageBatchFormat.NHWC,
1: ImageBatchFormat.NCHW
}[image_batch_format]
def check_and_replace_input(input_value, expect, replace):
if input_value == expect:
return replace
return input_value
return replace if input_value == expect else input_value

@ -28,7 +28,9 @@ from util import config_get_set_num_parallel_workers, config_get_set_seed
import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as c
import mindspore.dataset.transforms.py_transforms as py
import mindspore.dataset.vision.c_transforms as vision
import mindspore.dataset.vision.py_transforms as py_vision
from mindspore import log as logger
from mindspore.dataset.vision import Inter
@ -351,7 +353,7 @@ def test_serdes_voc_dataset(remove_json_files=True):
def test_serdes_to_device(remove_json_files=True):
"""
Test serdes on VOC dataset pipeline.
Test serdes on transfer dataset pipeline.
"""
data_dir = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
schema_file = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
@ -360,6 +362,43 @@ def test_serdes_to_device(remove_json_files=True):
util_check_serialize_deserialize_file(data1, "transfer_dataset_pipeline", remove_json_files)
def test_serdes_pyvision(remove_json_files=True):
"""
Test serdes on py_transform pipeline.
"""
data_dir = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
schema_file = "../data/dataset/test_tf_file_3_images/datasetSchema.json"
data1 = ds.TFRecordDataset(data_dir, schema_file, columns_list=["image", "label"], shuffle=False)
transforms = [
py_vision.Decode(),
py_vision.CenterCrop([32, 32]),
py_vision.ToTensor()
]
data1 = data1.map(operations=py.Compose(transforms), input_columns=["image"])
util_check_serialize_deserialize_file(data1, "pyvision_dataset_pipeline", remove_json_files)
def test_serdes_uniform_augment(remove_json_files=True):
"""
Test serdes on uniform augment.
"""
data_dir = "../data/dataset/testPK/data"
data = ds.ImageFolderDataset(dataset_dir=data_dir, shuffle=False)
ds.config.set_seed(1)
transforms_ua = [vision.RandomHorizontalFlip(),
vision.RandomVerticalFlip(),
vision.RandomColor(),
vision.RandomSharpness(),
vision.Invert(),
vision.AutoContrast(),
vision.Equalize()]
transforms_all = [vision.Decode(), vision.Resize(size=[224, 224]),
vision.UniformAugment(transforms=transforms_ua, num_ops=5)]
data = data.map(operations=transforms_all, input_columns="image", num_parallel_workers=1)
util_check_serialize_deserialize_file(data, "uniform_augment_pipeline", remove_json_files)
def test_serdes_exception():
"""
Test exception case in serdes
