Bug in weak reference.

Add new test cases
pull/434/head
hesham 5 years ago
parent fb6c7ba2e1
commit 3c02c82771

@ -28,10 +28,10 @@ ITERATORS_LIST = list()
def _cleanup(): def _cleanup():
for itr in ITERATORS_LIST: for itr_ref in ITERATORS_LIST:
iter_ref = itr() itr = itr_ref()
if itr is not None: if itr is not None:
iter_ref.release() itr.release()
def alter_tree(node): def alter_tree(node):

@ -13,8 +13,10 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
import numpy as np import numpy as np
import pytest
import mindspore.dataset as ds import mindspore.dataset as ds
from mindspore.dataset.engine.iterators import ITERATORS_LIST, _cleanup
DATA_DIR = ["../data/dataset/testTFTestAllTypes/test.data"] DATA_DIR = ["../data/dataset/testTFTestAllTypes/test.data"]
SCHEMA_DIR = "../data/dataset/testTFTestAllTypes/datasetSchema.json" SCHEMA_DIR = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
@ -41,3 +43,41 @@ def test_case_iterator():
check(COLUMNS[0:7]) check(COLUMNS[0:7])
check(COLUMNS[7:8]) check(COLUMNS[7:8])
check(COLUMNS[0:2:8]) check(COLUMNS[0:2:8])
def test_iterator_weak_ref():
ITERATORS_LIST.clear()
data = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR)
itr1 = data.create_tuple_iterator()
itr2 = data.create_tuple_iterator()
itr3 = data.create_tuple_iterator()
assert len(ITERATORS_LIST) == 3
assert sum(itr() is not None for itr in ITERATORS_LIST) == 3
del itr1
assert len(ITERATORS_LIST) == 3
assert sum(itr() is not None for itr in ITERATORS_LIST) == 2
del itr2
assert len(ITERATORS_LIST) == 3
assert sum(itr() is not None for itr in ITERATORS_LIST) == 1
del itr3
assert len(ITERATORS_LIST) == 3
assert sum(itr() is not None for itr in ITERATORS_LIST) == 0
itr1 = data.create_tuple_iterator()
itr2 = data.create_tuple_iterator()
itr3 = data.create_tuple_iterator()
_cleanup()
with pytest.raises(AttributeError) as info:
itr2.get_next()
assert "object has no attribute 'depipeline'" in str(info.value)
del itr1
assert len(ITERATORS_LIST) == 6
assert sum(itr() is not None for itr in ITERATORS_LIST) == 2
_cleanup()

Loading…
Cancel
Save