!909 Adding fix for set seed

Merge pull request !909 from EricZ/test_random
pull/909/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 79f087c2af

@ -15,7 +15,7 @@
"""
The configuration manager.
"""
import random
import mindspore._c_dataengine as cde
INT32_MAX = 2147483647
@ -32,6 +32,12 @@ class ConfigurationManager:
"""
Set the seed to be used in any random generator. This is used to produce deterministic results.
Note:
This set_seed function sets the seed in the python random library function for deterministic
python augmentations using randomness. This set_seed function should be called with every
iterator created to reset the random seed. In our pipeline this does not guarantee
deterministic results with num_parallel_workers > 1.
Args:
seed(int): seed to be set
@ -47,6 +53,7 @@ class ConfigurationManager:
if seed < 0 or seed > UINT32_MAX:
raise ValueError("Seed given is not within the required range")
self.config.set_seed(seed)
random.seed(seed)
def get_seed(self):
"""

File diff suppressed because it is too large Load Diff

@ -36,6 +36,7 @@ def test_textline_dataset_all_file():
assert(count == 5)
def test_textline_dataset_totext():
ds.config.set_num_parallel_workers(4)
data = ds.TextFileDataset(DATA_ALL_FILE, shuffle=False)
count = 0
line = ["This is a text file.", "Another file.", "Be happy every day.", "End of file.", "Good luck to everyone."]

@ -37,7 +37,7 @@ def visualize(first, mse, second):
plt.subplot(142)
plt.imshow(second)
plt.title("py random_color_jitter image")
plt.title("py random_color_adjust image")
plt.subplot(143)
plt.imshow(first - second)
@ -50,20 +50,20 @@ def diff_mse(in1, in2):
return mse * 100
def test_random_color_jitter_op_brightness():
def test_random_color_adjust_op_brightness():
"""
Test RandomColorAdjust op
"""
logger.info("test_random_color_jitter_op")
logger.info("test_random_color_adjust_op")
# First dataset
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
decode_op = c_vision.Decode()
random_jitter_op = c_vision.RandomColorAdjust((0.8, 0.8), (1, 1), (1, 1), (0, 0))
random_adjust_op = c_vision.RandomColorAdjust((0.8, 0.8), (1, 1), (1, 1), (0, 0))
ctrans = [decode_op,
random_jitter_op,
random_adjust_op,
]
data1 = data1.map(input_columns=["image"], operations=ctrans)
@ -100,20 +100,20 @@ def test_random_color_jitter_op_brightness():
# visualize(c_image, mse, py_image)
def test_random_color_jitter_op_contrast():
def test_random_color_adjust_op_contrast():
"""
Test RandomColorAdjust op
"""
logger.info("test_random_color_jitter_op")
logger.info("test_random_color_adjust_op")
# First dataset
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
decode_op = c_vision.Decode()
random_jitter_op = c_vision.RandomColorAdjust((1, 1), (0.5, 0.5), (1, 1), (0, 0))
random_adjust_op = c_vision.RandomColorAdjust((1, 1), (0.5, 0.5), (1, 1), (0, 0))
ctrans = [decode_op,
random_jitter_op
random_adjust_op
]
data1 = data1.map(input_columns=["image"], operations=ctrans)
@ -156,20 +156,20 @@ def test_random_color_jitter_op_contrast():
# visualize(c_image, mse, py_image)
def test_random_color_jitter_op_saturation():
def test_random_color_adjust_op_saturation():
"""
Test RandomColorAdjust op
"""
logger.info("test_random_color_jitter_op")
logger.info("test_random_color_adjust_op")
# First dataset
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
decode_op = c_vision.Decode()
random_jitter_op = c_vision.RandomColorAdjust((1, 1), (1, 1), (0.5, 0.5), (0, 0))
random_adjust_op = c_vision.RandomColorAdjust((1, 1), (1, 1), (0.5, 0.5), (0, 0))
ctrans = [decode_op,
random_jitter_op
random_adjust_op
]
data1 = data1.map(input_columns=["image"], operations=ctrans)
@ -209,20 +209,20 @@ def test_random_color_jitter_op_saturation():
# visualize(c_image, mse, py_image)
def test_random_color_jitter_op_hue():
def test_random_color_adjust_op_hue():
"""
Test RandomColorAdjust op
"""
logger.info("test_random_color_jitter_op")
logger.info("test_random_color_adjust_op")
# First dataset
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
decode_op = c_vision.Decode()
random_jitter_op = c_vision.RandomColorAdjust((1, 1), (1, 1), (1, 1), (0.2, 0.2))
random_adjust_op = c_vision.RandomColorAdjust((1, 1), (1, 1), (1, 1), (0.2, 0.2))
ctrans = [decode_op,
random_jitter_op,
random_adjust_op,
]
data1 = data1.map(input_columns=["image"], operations=ctrans)
@ -264,7 +264,7 @@ def test_random_color_jitter_op_hue():
if __name__ == "__main__":
test_random_color_jitter_op_brightness()
test_random_color_jitter_op_contrast()
test_random_color_jitter_op_saturation()
test_random_color_jitter_op_hue()
test_random_color_adjust_op_brightness()
test_random_color_adjust_op_contrast()
test_random_color_adjust_op_saturation()
test_random_color_adjust_op_hue()

@ -17,8 +17,8 @@ Testing RandomCropAndResize op in DE
"""
import matplotlib.pyplot as plt
import mindspore.dataset.transforms.vision.c_transforms as vision
from mindspore import log as logger
from mindspore import log as logger
import mindspore.dataset as ds
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
@ -45,9 +45,9 @@ def visualize(a, mse, original):
def test_random_crop_op():
"""
Test RandomCropAndResize op
Test RandomCrop Op
"""
logger.info("test_random_crop_and_resize_op")
logger.info("test_random_crop_op")
# First dataset
data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False)
@ -67,3 +67,4 @@ def test_random_crop_op():
if __name__ == "__main__":
test_random_crop_op()

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import mindspore.dataset as ds
from mindspore import log as logger
@ -34,9 +35,9 @@ def test_rename():
for i, item in enumerate(data.create_dict_iterator()):
logger.info("item[mask] is {}".format(item["masks"]))
assert item["masks"].all() == item["input_ids"].all()
np.testing.assert_equal (item["masks"], item["input_ids"])
logger.info("item[seg_ids] is {}".format(item["seg_ids"]))
assert item["segment_ids"].all() == item["seg_ids"].all()
np.testing.assert_equal (item["segment_ids"], item["seg_ids"])
# need to consume the data in the buffer
num_iter += 1
logger.info("Number of data in data: {}".format(num_iter))

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
from util import save_and_check
import mindspore.dataset as ds
@ -117,6 +118,27 @@ def test_shuffle_05():
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
def test_shuffle_06():
"""
Test shuffle: with set seed, both datasets
"""
logger.info("test_shuffle_06")
# define parameters
buffer_size = 13
seed = 1
# apply dataset operations
data1 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
ds.config.set_seed(seed)
data1 = data1.shuffle(buffer_size=buffer_size)
data2 = ds.TFRecordDataset(DATA_DIR, shuffle=ds.Shuffle.FILES)
data2 = data2.shuffle(buffer_size=buffer_size)
for item1, item2 in zip(data1.create_dict_iterator(), data2.create_dict_iterator()):
np.testing.assert_equal (item1, item2)
def test_shuffle_exception_01():
"""
Test shuffle exception: buffer_size<0
@ -231,6 +253,7 @@ if __name__ == '__main__':
test_shuffle_03()
test_shuffle_04()
test_shuffle_05()
test_shuffle_06()
test_shuffle_exception_01()
test_shuffle_exception_02()
test_shuffle_exception_03()

Loading…
Cancel
Save