|
|
|
@ -14,7 +14,7 @@
|
|
|
|
|
# ==============================================================================
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
|
import mindspore.dataset as ds
|
|
|
|
|
from mindspore import log as logger
|
|
|
|
|
|
|
|
|
@ -163,7 +163,6 @@ def test_sync_exception_01():
|
|
|
|
|
"""
|
|
|
|
|
logger.info("test_sync_exception_01")
|
|
|
|
|
shuffle_size = 4
|
|
|
|
|
batch_size = 10
|
|
|
|
|
|
|
|
|
|
dataset = ds.GeneratorDataset(gen, column_names=["input"])
|
|
|
|
|
|
|
|
|
@ -171,11 +170,9 @@ def test_sync_exception_01():
|
|
|
|
|
dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
|
|
|
|
|
dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
dataset = dataset.shuffle(shuffle_size)
|
|
|
|
|
except Exception as e:
|
|
|
|
|
assert "shuffle" in str(e)
|
|
|
|
|
dataset = dataset.batch(batch_size)
|
|
|
|
|
with pytest.raises(RuntimeError) as e:
|
|
|
|
|
dataset.shuffle(shuffle_size)
|
|
|
|
|
assert "No shuffle after sync operators" in str(e.value)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_sync_exception_02():
|
|
|
|
@ -183,7 +180,6 @@ def test_sync_exception_02():
|
|
|
|
|
Test sync: with duplicated condition name
|
|
|
|
|
"""
|
|
|
|
|
logger.info("test_sync_exception_02")
|
|
|
|
|
batch_size = 6
|
|
|
|
|
|
|
|
|
|
dataset = ds.GeneratorDataset(gen, column_names=["input"])
|
|
|
|
|
|
|
|
|
@ -192,11 +188,9 @@ def test_sync_exception_02():
|
|
|
|
|
|
|
|
|
|
dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
dataset = dataset.sync_wait(num_batch=2, condition_name="every batch")
|
|
|
|
|
except Exception as e:
|
|
|
|
|
assert "name" in str(e)
|
|
|
|
|
dataset = dataset.batch(batch_size)
|
|
|
|
|
with pytest.raises(RuntimeError) as e:
|
|
|
|
|
dataset.sync_wait(num_batch=2, condition_name="every batch")
|
|
|
|
|
assert "Condition name is already in use" in str(e.value)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_sync_exception_03():
|
|
|
|
@ -209,12 +203,9 @@ def test_sync_exception_03():
|
|
|
|
|
|
|
|
|
|
aug = Augment(0)
|
|
|
|
|
# try to create dataset with batch_size < 0
|
|
|
|
|
try:
|
|
|
|
|
dataset = dataset.sync_wait(condition_name="every batch", num_batch=-1, callback=aug.update)
|
|
|
|
|
except Exception as e:
|
|
|
|
|
assert "num_batch" in str(e)
|
|
|
|
|
|
|
|
|
|
dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
|
|
|
|
|
with pytest.raises(ValueError) as e:
|
|
|
|
|
dataset.sync_wait(condition_name="every batch", num_batch=-1, callback=aug.update)
|
|
|
|
|
assert "num_batch need to be greater than 0." in str(e.value)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_sync_exception_04():
|
|
|
|
@ -230,14 +221,13 @@ def test_sync_exception_04():
|
|
|
|
|
dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update)
|
|
|
|
|
dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
|
|
|
|
|
count = 0
|
|
|
|
|
try:
|
|
|
|
|
with pytest.raises(RuntimeError) as e:
|
|
|
|
|
for _ in dataset.create_dict_iterator():
|
|
|
|
|
count += 1
|
|
|
|
|
data = {"loss": count}
|
|
|
|
|
# dataset.disable_sync()
|
|
|
|
|
dataset.sync_update(condition_name="every batch", num_batch=-1, data=data)
|
|
|
|
|
except Exception as e:
|
|
|
|
|
assert "batch" in str(e)
|
|
|
|
|
assert "Sync_update batch size can only be positive" in str(e.value)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_sync_exception_05():
|
|
|
|
|
"""
|
|
|
|
@ -251,15 +241,15 @@ def test_sync_exception_05():
|
|
|
|
|
# try to create dataset with batch_size < 0
|
|
|
|
|
dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update)
|
|
|
|
|
dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess])
|
|
|
|
|
try:
|
|
|
|
|
with pytest.raises(RuntimeError) as e:
|
|
|
|
|
for _ in dataset.create_dict_iterator():
|
|
|
|
|
dataset.disable_sync()
|
|
|
|
|
count += 1
|
|
|
|
|
data = {"loss": count}
|
|
|
|
|
dataset.disable_sync()
|
|
|
|
|
dataset.sync_update(condition_name="every", data=data)
|
|
|
|
|
except Exception as e:
|
|
|
|
|
assert "name" in str(e)
|
|
|
|
|
assert "Condition name not found" in str(e.value)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
test_simple_sync_wait()
|
|
|
|
|