Block caching after random pyfunc

pull/12327/head
Lixia Chen 4 years ago
parent ed7fef5d5e
commit 0667818d9a

@ -30,6 +30,9 @@
#endif #endif
#include "minddata/dataset/kernels/ir/validators.h" #include "minddata/dataset/kernels/ir/validators.h"
#ifdef ENABLE_PYTHON
#include "minddata/dataset/kernels/py_func_op.h"
#endif
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
@ -78,7 +81,12 @@ Status OneHotOperation::ValidateParams() {
std::shared_ptr<TensorOp> OneHotOperation::Build() { return std::make_shared<OneHotOp>(num_classes_); } std::shared_ptr<TensorOp> OneHotOperation::Build() { return std::make_shared<OneHotOp>(num_classes_); }
// PreBuiltOperation // PreBuiltOperation
PreBuiltOperation::PreBuiltOperation(std::shared_ptr<TensorOp> tensor_op) : op_(tensor_op) {} PreBuiltOperation::PreBuiltOperation(std::shared_ptr<TensorOp> tensor_op) : op_(tensor_op) {
// Wraps an already-built TensorOp. When compiled with ENABLE_PYTHON, a wrapped PyFuncOp that reports
// IsRandom() flags this operation as random.
#ifdef ENABLE_PYTHON
// The cast yields nullptr when tensor_op is not a PyFuncOp, so only python functions are inspected.
auto pyfunc_tensor_op = std::dynamic_pointer_cast<PyFuncOp>(tensor_op);
// random_op_ is only ever switched to true here; for any other op its existing value is left untouched.
if (pyfunc_tensor_op && pyfunc_tensor_op->IsRandom()) random_op_ = true;
#endif
}
Status PreBuiltOperation::ValidateParams() { return Status::OK(); } Status PreBuiltOperation::ValidateParams() { return Status::OK(); }

@ -129,5 +129,12 @@ Status PyFuncOp::to_json(nlohmann::json *out_json) {
*out_json = args; *out_json = args;
return Status::OK(); return Status::OK();
} }
bool PyFuncOp::IsRandom() {
  // A python callable that carries no `random` attribute is assumed to be random.
  if (!py::hasattr(py_func_ptr_, "random")) {
    return true;
  }
  // Only an attribute that compares equal to false marks the callable as not random.
  const bool flagged_not_random = (py::reinterpret_borrow<py::bool_>(py_func_ptr_.attr("random")) == false);
  return !flagged_not_random;
}
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

@ -51,6 +51,10 @@ class PyFuncOp : public TensorOp {
std::string Name() const override { return kPyFuncOp; } std::string Name() const override { return kPyFuncOp; }
Status to_json(nlohmann::json *out_json) override; Status to_json(nlohmann::json *out_json) override;
/// \brief Check whether this pyfunc op is random (i.e. not deterministic)
/// \return True if this pyfunc op is random; false only when the python callable carries a `random`
///     attribute that compares equal to false. A callable without that attribute is treated as random.
bool IsRandom();
private: private:
py::function py_func_ptr_; py::function py_func_ptr_;
DataType::Type output_type_; DataType::Type output_type_;

@ -552,6 +552,7 @@ class PythonTokenizer:
@check_python_tokenizer @check_python_tokenizer
def __init__(self, tokenizer): def __init__(self, tokenizer):
self.tokenizer = np.vectorize(lambda x: np.array(tokenizer(x), dtype='U'), signature='()->(n)') self.tokenizer = np.vectorize(lambda x: np.array(tokenizer(x), dtype='U'), signature='()->(n)')
self.random = False
def __call__(self, in_array): def __call__(self, in_array):
in_array = to_str(in_array) in_array = to_str(in_array)

@ -21,6 +21,11 @@ from .validators import check_one_hot_op, check_compose_list, check_random_apply
from . import py_transforms_util as util from . import py_transforms_util as util
def not_random(function):
    """
    Mark a callable as not random so that a map using it may be followed by a cache.

    The flag is stored as a ``random = False`` attribute on the callable itself. Callables that
    reject new attributes (bound methods, builtins) are wrapped in a ``functools.partial`` that
    forwards every call unchanged and carries the flag instead.

    Args:
        function (callable): The callable to mark. Can be used as a decorator.

    Returns:
        callable, the same object when the attribute could be set, otherwise a flagged wrapper.

    Raises:
        AttributeError: If `function` rejects the attribute and is not callable.
    """
    try:
        function.random = False
    except AttributeError:
        # Anything that is not callable keeps failing exactly as before.
        if not callable(function):
            raise
        import functools
        # partial objects accept arbitrary attributes and delegate the call as-is.
        function = functools.partial(function)
        function.random = False
    return function
class OneHotOp: class OneHotOp:
""" """
Apply one hot encoding transformation to the input label, make label be more smoothing and continuous. Apply one hot encoding transformation to the input label, make label be more smoothing and continuous.
@ -42,6 +47,7 @@ class OneHotOp:
def __init__(self, num_classes, smoothing_rate=0.0): def __init__(self, num_classes, smoothing_rate=0.0):
self.num_classes = num_classes self.num_classes = num_classes
self.smoothing_rate = smoothing_rate self.smoothing_rate = smoothing_rate
self.random = False
def __call__(self, label): def __call__(self, label):
""" """
@ -114,6 +120,8 @@ class Compose:
@check_compose_list @check_compose_list
def __init__(self, transforms): def __init__(self, transforms):
self.transforms = transforms self.transforms = transforms
if all(hasattr(transform, "random") and not transform.random for transform in self.transforms):
self.random = False
@check_compose_call @check_compose_call
def __call__(self, *args): def __call__(self, *args):

@ -45,6 +45,11 @@ DE_PY_BORDER_TYPE = {Border.CONSTANT: 'constant',
Border.SYMMETRIC: 'symmetric'} Border.SYMMETRIC: 'symmetric'}
def not_random(function):
    """
    Mark a callable as not random so that a map using it may be followed by a cache.

    The flag is stored as a ``random = False`` attribute on the callable itself. Callables that
    reject new attributes (bound methods, builtins) are wrapped in a ``functools.partial`` that
    forwards every call unchanged and carries the flag instead.

    Args:
        function (callable): The callable to mark. Can be used as a decorator.

    Returns:
        callable, the same object when the attribute could be set, otherwise a flagged wrapper.

    Raises:
        AttributeError: If `function` rejects the attribute and is not callable.
    """
    try:
        function.random = False
    except AttributeError:
        # Anything that is not callable keeps failing exactly as before.
        if not callable(function):
            raise
        import functools
        # partial objects accept arbitrary attributes and delegate the call as-is.
        function = functools.partial(function)
        function.random = False
    return function
class ToTensor: class ToTensor:
""" """
Convert the input NumPy image array or PIL image of shape (H, W, C) to a NumPy ndarray of shape (C, H, W). Convert the input NumPy image array or PIL image of shape (H, W, C) to a NumPy ndarray of shape (C, H, W).
@ -70,6 +75,7 @@ class ToTensor:
def __init__(self, output_type=np.float32): def __init__(self, output_type=np.float32):
self.output_type = output_type self.output_type = output_type
self.random = False
def __call__(self, img): def __call__(self, img):
""" """
@ -105,6 +111,7 @@ class ToType:
def __init__(self, output_type): def __init__(self, output_type):
self.output_type = output_type self.output_type = output_type
self.random = False
def __call__(self, img): def __call__(self, img):
""" """
@ -132,6 +139,9 @@ class HWC2CHW:
... input_columns="image") ... input_columns="image")
""" """
def __init__(self):
# Marks this transform as not random; PyFuncOp::IsRandom() treats a callable as random unless its
# `random` attribute compares equal to false.
self.random = False
def __call__(self, img): def __call__(self, img):
""" """
Call method. Call method.
@ -160,6 +170,9 @@ class ToPIL:
... input_columns="image") ... input_columns="image")
""" """
def __init__(self):
# Marks this transform as not random; PyFuncOp::IsRandom() treats a callable as random unless its
# `random` attribute compares equal to false.
self.random = False
def __call__(self, img): def __call__(self, img):
""" """
Call method. Call method.
@ -187,6 +200,9 @@ class Decode:
... input_columns="image") ... input_columns="image")
""" """
def __init__(self):
# Marks this transform as not random; PyFuncOp::IsRandom() treats a callable as random unless its
# `random` attribute compares equal to false.
self.random = False
def __call__(self, img): def __call__(self, img):
""" """
Call method. Call method.
@ -227,6 +243,7 @@ class Normalize:
def __init__(self, mean, std): def __init__(self, mean, std):
self.mean = mean self.mean = mean
self.std = std self.std = std
self.random = False
def __call__(self, img): def __call__(self, img):
""" """
@ -271,6 +288,7 @@ class NormalizePad:
self.mean = mean self.mean = mean
self.std = std self.std = std
self.dtype = dtype self.dtype = dtype
self.random = False
def __call__(self, img): def __call__(self, img):
""" """
@ -456,6 +474,7 @@ class Resize:
def __init__(self, size, interpolation=Inter.BILINEAR): def __init__(self, size, interpolation=Inter.BILINEAR):
self.size = size self.size = size
self.interpolation = DE_PY_INTER_MODE[interpolation] self.interpolation = DE_PY_INTER_MODE[interpolation]
self.random = False
def __call__(self, img): def __call__(self, img):
""" """
@ -550,6 +569,7 @@ class CenterCrop:
@check_crop @check_crop
def __init__(self, size): def __init__(self, size):
self.size = size self.size = size
self.random = False
def __call__(self, img): def __call__(self, img):
""" """
@ -700,6 +720,7 @@ class FiveCrop:
@check_crop @check_crop
def __init__(self, size): def __init__(self, size):
self.size = size self.size = size
self.random = False
def __call__(self, img): def __call__(self, img):
""" """
@ -744,6 +765,7 @@ class TenCrop:
size = (size, size) size = (size, size)
self.size = size self.size = size
self.use_vertical_flip = use_vertical_flip self.use_vertical_flip = use_vertical_flip
self.random = False
def __call__(self, img): def __call__(self, img):
""" """
@ -781,6 +803,7 @@ class Grayscale:
@check_num_channels @check_num_channels
def __init__(self, num_output_channels=1): def __init__(self, num_output_channels=1):
self.num_output_channels = num_output_channels self.num_output_channels = num_output_channels
self.random = False
def __call__(self, img): def __call__(self, img):
""" """
@ -884,6 +907,7 @@ class Pad:
self.padding = padding self.padding = padding
self.fill_value = fill_value self.fill_value = fill_value
self.padding_mode = DE_PY_BORDER_TYPE[padding_mode] self.padding_mode = DE_PY_BORDER_TYPE[padding_mode]
self.random = False
def __call__(self, img): def __call__(self, img):
""" """
@ -1030,6 +1054,7 @@ class Cutout:
def __init__(self, length, num_patches=1):
    """
    Store the patch configuration.

    Args:
        length: Stored as `self.length` (side length of each square patch per the public Cutout docs).
        num_patches: Stored as `self.num_patches` (number of patches to cut out, default 1).
    """
    self.length = length
    self.num_patches = num_patches
    # Deliberately no `self.random = False`: Cutout places its patches at random on every call, so it
    # must stay classified as random. Flagging it as not random would allow a cache above the map to
    # store one random result and replay it for every epoch.
def __call__(self, np_img): def __call__(self, np_img):
""" """
@ -1087,6 +1112,7 @@ class LinearTransformation:
def __init__(self, transformation_matrix, mean_vector): def __init__(self, transformation_matrix, mean_vector):
self.transformation_matrix = transformation_matrix self.transformation_matrix = transformation_matrix
self.mean_vector = mean_vector self.mean_vector = mean_vector
self.random = False
def __call__(self, np_img): def __call__(self, np_img):
""" """
@ -1229,6 +1255,7 @@ class MixUp:
self.batch_size = batch_size self.batch_size = batch_size
self.alpha = alpha self.alpha = alpha
self.is_single = is_single self.is_single = is_single
self.random = False
def __call__(self, image, label): def __call__(self, image, label):
""" """
@ -1268,6 +1295,7 @@ class RgbToHsv:
def __init__(self, is_hwc=False): def __init__(self, is_hwc=False):
self.is_hwc = is_hwc self.is_hwc = is_hwc
self.random = False
def __call__(self, rgb_imgs): def __call__(self, rgb_imgs):
""" """
@ -1304,6 +1332,7 @@ class HsvToRgb:
def __init__(self, is_hwc=False): def __init__(self, is_hwc=False):
self.is_hwc = is_hwc self.is_hwc = is_hwc
self.random = False
def __call__(self, hsv_imgs): def __call__(self, hsv_imgs):
""" """
@ -1414,6 +1443,7 @@ class AutoContrast:
def __init__(self, cutoff=0.0, ignore=None): def __init__(self, cutoff=0.0, ignore=None):
self.cutoff = cutoff self.cutoff = cutoff
self.ignore = ignore self.ignore = ignore
self.random = False
def __call__(self, img): def __call__(self, img):
""" """
@ -1443,6 +1473,9 @@ class Invert:
... input_columns="image") ... input_columns="image")
""" """
def __init__(self):
# Marks this transform as not random; PyFuncOp::IsRandom() treats a callable as random unless its
# `random` attribute compares equal to false.
self.random = False
def __call__(self, img): def __call__(self, img):
""" """
Call method. Call method.
@ -1472,6 +1505,9 @@ class Equalize:
""" """
def __init__(self):
# Marks this transform as not random; PyFuncOp::IsRandom() treats a callable as random unless its
# `random` attribute compares equal to false.
self.random = False
def __call__(self, img): def __call__(self, img):
""" """
Call method. Call method.
@ -1516,6 +1552,7 @@ class UniformAugment:
def __init__(self, transforms, num_ops=2):
    """
    Store the candidate transforms and how many of them are applied per call.

    Args:
        transforms: Stored as `self.transforms` (the list of transforms to select from).
        num_ops: Stored as `self.num_ops` (number of transforms to apply, default 2).
    """
    self.transforms = transforms
    self.num_ops = num_ops
    # Deliberately no `self.random = False`: UniformAugment selects which transforms to apply at
    # random, so it must stay classified as random. Flagging it as not random would allow a cache
    # above the map to store one random selection and replay it for every epoch.
def __call__(self, img): def __call__(self, img):
""" """

@ -318,6 +318,9 @@ HandleRcExit $? 0 0
PytestCmd "test_cache_nomap.py" "test_cache_nomap_failure" 1 PytestCmd "test_cache_nomap.py" "test_cache_nomap_failure" 1
HandleRcExit $? 0 0 HandleRcExit $? 0 0
PytestCmd "test_cache_nomap.py" "test_cache_nomap_pyfunc" 1
HandleRcExit $? 0 0
for i in $(seq 1 3) for i in $(seq 1 3)
do do
test_name="test_cache_nomap_multiple_cache${i}" test_name="test_cache_nomap_multiple_cache${i}"

@ -20,6 +20,7 @@ import pytest
import numpy as np import numpy as np
import mindspore.dataset as ds import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as c_vision import mindspore.dataset.vision.c_transforms as c_vision
import mindspore.dataset.vision.py_transforms as py_vision
from mindspore import log as logger from mindspore import log as logger
from util import save_and_check_md5 from util import save_and_check_md5
@ -481,7 +482,7 @@ def test_cache_map_failure7():
some_cache = ds.DatasetCache(session_id=session_id, size=0) some_cache = ds.DatasetCache(session_id=session_id, size=0)
data = ds.GeneratorDataset(generator_1d, ["data"]) data = ds.GeneratorDataset(generator_1d, ["data"])
data = data.map((lambda x: x), ["data"], cache=some_cache) data = data.map(py_vision.not_random(lambda x: x), ["data"], cache=some_cache)
data = data.repeat(4) data = data.repeat(4)
with pytest.raises(RuntimeError) as e: with pytest.raises(RuntimeError) as e:

@ -17,11 +17,13 @@ Testing cache operator with non-mappable datasets
""" """
import os import os
import itertools import itertools
import numpy as np
import pytest import pytest
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
import mindspore.dataset as ds import mindspore.dataset as ds
import mindspore.dataset.text as text import mindspore.dataset.text as text
import mindspore.dataset.vision.c_transforms as c_vision import mindspore.dataset.vision.c_transforms as c_vision
import mindspore.dataset.vision.py_transforms as py_vision
from mindspore import log as logger from mindspore import log as logger
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
@ -41,6 +43,9 @@ CLUE_DATA_DIR = '../data/dataset/testCLUE/afqmc/train.json'
CSV_DATA_DIR = '../data/dataset/testCSV/1.csv' CSV_DATA_DIR = '../data/dataset/testCSV/1.csv'
TEXT_FILE_DATA_DIR = "../data/dataset/testTextFileDataset/1.txt" TEXT_FILE_DATA_DIR = "../data/dataset/testTextFileDataset/1.txt"
PYFUNC_DATA_DIR = ["../data/dataset/testPyfuncMap/data.data"]
PYFUNC_SCHEMA_DIR = "../data/dataset/testPyfuncMap/schema.json"
GENERATE_GOLDEN = False GENERATE_GOLDEN = False
@ -1633,7 +1638,7 @@ def test_cache_nomap_clue2():
some_cache = ds.DatasetCache(session_id=session_id, size=0) some_cache = ds.DatasetCache(session_id=session_id, size=0)
ds1 = ds.CLUEDataset(CLUE_DATA_DIR, task='AFQMC', usage='train', num_samples=2) ds1 = ds.CLUEDataset(CLUE_DATA_DIR, task='AFQMC', usage='train', num_samples=2)
ds1 = ds1.map((lambda x: x), ["label"], cache=some_cache) ds1 = ds1.map(py_vision.not_random(lambda x: x), ["label"], cache=some_cache)
num_epoch = 4 num_epoch = 4
iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True) iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)
@ -1710,7 +1715,7 @@ def test_cache_nomap_csv2():
ds1 = ds.CSVDataset(CSV_DATA_DIR, column_defaults=["1", "2", "3", "4"], ds1 = ds.CSVDataset(CSV_DATA_DIR, column_defaults=["1", "2", "3", "4"],
column_names=['col1', 'col2', 'col3', 'col4'], num_samples=2) column_names=['col1', 'col2', 'col3', 'col4'], num_samples=2)
ds1 = ds1.map((lambda x: x), ["col1"], cache=some_cache) ds1 = ds1.map(py_vision.not_random(lambda x: x), ["col1"], cache=some_cache)
num_epoch = 4 num_epoch = 4
iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True) iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)
@ -2124,6 +2129,139 @@ def test_cache_nomap_failure5():
logger.info('test_cache_nomap_failure5 Ended.\n') logger.info('test_cache_nomap_failure5 Ended.\n')
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_pyfunc_lambda():
    """
    Cache placed directly above a map op whose operations are python lambdas.

    The pipeline is only accepted when every lambda is wrapped by 'py_vision.not_random';
    an unwrapped lambda must make the pipeline raise a RuntimeError.

        Cache
          |
        Map(lambda function1, lambda function2)
          |
        TFRecord
    """
    logger.info("Test cache nomap pyfunc lambda")
    if "SESSION_ID" not in os.environ:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    session_id = int(os.environ['SESSION_ID'])

    # Accepted case: both lambdas are flagged as not random. This dataset has 12 records in it.
    some_cache = ds.DatasetCache(session_id=session_id, size=0)
    data1 = ds.TFRecordDataset(PYFUNC_DATA_DIR, PYFUNC_SCHEMA_DIR, shuffle=False)
    transforms = [py_vision.not_random(lambda x: x + x), py_vision.not_random(lambda x: x - 1)]
    data1 = data1.map(operations=transforms, input_columns="col0", cache=some_cache)

    num_iter = sum(1 for _ in data1.create_dict_iterator(num_epochs=1))
    assert num_iter == 12

    # Rejected case: a bare lambda carries no `random` flag and is therefore treated as random.
    other_cache = ds.DatasetCache(session_id=session_id, size=0)
    ds2 = ds.TFRecordDataset(PYFUNC_DATA_DIR, PYFUNC_SCHEMA_DIR, shuffle=False)
    ds2 = ds2.map(operations=[(lambda x: x + x)], input_columns=["col0"], cache=other_cache)

    with pytest.raises(RuntimeError) as e:
        for _ in ds2.create_dict_iterator():
            pass
    assert "MapNode containing random operation is not supported as a descendant of cache" in str(e.value)

    logger.info("test_cache_nomap_pyfunc_lambda Ended.\n")
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_pyfunc_builtin():
    """
    Cache placed directly above a map op built from the python transforms shipped with py_vision.

    The pipeline must raise a RuntimeError as soon as one of the transforms is a random operation.

        Cache
          |
        Map([builtin pyfunc1, builtin pyfunc2])
          |
        TFRecord
    """
    logger.info("Test cache nomap pyfunc builtin")
    if "SESSION_ID" not in os.environ:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    session_id = int(os.environ['SESSION_ID'])

    # Accepted case: Decode and ToTensor are both flagged as not random. This dataset has 3 records only.
    some_cache = ds.DatasetCache(session_id=session_id, size=0)
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
    ds1 = ds1.map(operations=[py_vision.Decode(), py_vision.ToTensor()], input_columns=["image"], cache=some_cache)

    num_iter = sum(1 for _ in ds1.create_dict_iterator(num_epochs=1))
    assert num_iter == 3

    # Rejected case: RandomCrop in the middle of the list makes the whole map random.
    other_cache = ds.DatasetCache(session_id=session_id, size=0)
    ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
    ds2 = ds2.map(operations=[py_vision.Decode(), py_vision.RandomCrop(224), py_vision.ToTensor()],
                  input_columns=["image"], cache=other_cache)

    with pytest.raises(RuntimeError) as e:
        for _ in ds2.create_dict_iterator():
            pass
    assert "MapNode containing random operation is not supported as a descendant of cache" in str(e.value)

    logger.info("test_cache_nomap_pyfunc_builtin Ended.\n")
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_pyfunc_function():
    """
    Cache placed directly above a map op whose operations are user-defined python functions.

    The pipeline is only accepted when every function is decorated with 'py_vision.not_random';
    an undecorated function must make the pipeline raise a RuntimeError.

        Cache
          |
        Map([function1, function2])
          |
        TFRecord
    """

    @py_vision.not_random
    def not_random_func(x):
        return np.ones(x.shape, dtype=x.dtype)

    def normal_func(x):
        return np.ones(x.shape, dtype=x.dtype)

    logger.info("Test cache nomap pyfunc function")
    if "SESSION_ID" not in os.environ:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    session_id = int(os.environ['SESSION_ID'])

    # Accepted case: only decorated functions are used. This dataset has 3 records only.
    some_cache = ds.DatasetCache(session_id=session_id, size=0)
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
    ds1 = ds1.map(operations=[not_random_func, not_random_func], input_columns=["image"], cache=some_cache)

    num_iter = sum(1 for _ in ds1.create_dict_iterator(num_epochs=1))
    assert num_iter == 3

    # Rejected case: the undecorated function carries no `random` flag and is treated as random.
    other_cache = ds.DatasetCache(session_id=session_id, size=0)
    ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
    ds2 = ds2.map(operations=[not_random_func, normal_func], input_columns=["image"], cache=other_cache)

    with pytest.raises(RuntimeError) as e:
        for _ in ds2.create_dict_iterator():
            pass
    assert "MapNode containing random operation is not supported as a descendant of cache" in str(e.value)

    logger.info("test_cache_nomap_pyfunc_function Ended.\n")
if __name__ == '__main__': if __name__ == '__main__':
# This is just a list of tests, don't try to run these tests with 'python test_cache_nomap.py' # This is just a list of tests, don't try to run these tests with 'python test_cache_nomap.py'
# since cache server is required to be brought up first # since cache server is required to be brought up first
@ -2180,3 +2318,6 @@ if __name__ == '__main__':
test_cache_nomap_failure3() test_cache_nomap_failure3()
test_cache_nomap_failure4() test_cache_nomap_failure4()
test_cache_nomap_failure5() test_cache_nomap_failure5()
test_cache_nomap_pyfunc_lambda()
test_cache_nomap_pyfunc_builtin()
test_cache_nomap_pyfunc_function()

Loading…
Cancel
Save