|
|
@ -12,10 +12,10 @@
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
# limitations under the License.
|
|
|
|
# ==============================================================================
|
|
|
|
# ==============================================================================
|
|
|
|
from util import save_and_check
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import mindspore.dataset as ds
|
|
|
|
import mindspore.dataset as ds
|
|
|
|
from mindspore import log as logger
|
|
|
|
from mindspore import log as logger
|
|
|
|
|
|
|
|
from util import save_and_check
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
DATA_DIR = ["../data/dataset/testTFTestAllTypes/test.data"]
|
|
|
|
DATA_DIR = ["../data/dataset/testTFTestAllTypes/test.data"]
|
|
|
|
SCHEMA_DIR = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
|
|
|
|
SCHEMA_DIR = "../data/dataset/testTFTestAllTypes/datasetSchema.json"
|
|
|
@ -24,7 +24,7 @@ COLUMNS = ["col_1d", "col_2d", "col_3d", "col_binary", "col_float",
|
|
|
|
GENERATE_GOLDEN = False
|
|
|
|
GENERATE_GOLDEN = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def skip_test_case_0():
|
|
|
|
def test_2ops_repeat_shuffle():
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Test Repeat then Shuffle
|
|
|
|
Test Repeat then Shuffle
|
|
|
|
"""
|
|
|
|
"""
|
|
|
@ -43,11 +43,11 @@ def skip_test_case_0():
|
|
|
|
ds.config.set_seed(seed)
|
|
|
|
ds.config.set_seed(seed)
|
|
|
|
data1 = data1.shuffle(buffer_size=buffer_size)
|
|
|
|
data1 = data1.shuffle(buffer_size=buffer_size)
|
|
|
|
|
|
|
|
|
|
|
|
filename = "test_case_0_result.npz"
|
|
|
|
filename = "test_2ops_repeat_shuffle.npz"
|
|
|
|
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
|
|
|
|
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def skip_test_case_0_reverse():
|
|
|
|
def skip_test_2ops_shuffle_repeat():
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Test Shuffle then Repeat
|
|
|
|
Test Shuffle then Repeat
|
|
|
|
"""
|
|
|
|
"""
|
|
|
@ -67,11 +67,11 @@ def skip_test_case_0_reverse():
|
|
|
|
data1 = data1.shuffle(buffer_size=buffer_size)
|
|
|
|
data1 = data1.shuffle(buffer_size=buffer_size)
|
|
|
|
data1 = data1.repeat(repeat_count)
|
|
|
|
data1 = data1.repeat(repeat_count)
|
|
|
|
|
|
|
|
|
|
|
|
filename = "test_case_0_reverse_result.npz"
|
|
|
|
filename = "test_2ops_shuffle_repeat.npz"
|
|
|
|
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
|
|
|
|
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_case_1():
|
|
|
|
def test_2ops_repeat_batch():
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Test Repeat then Batch
|
|
|
|
Test Repeat then Batch
|
|
|
|
"""
|
|
|
|
"""
|
|
|
@ -87,11 +87,11 @@ def test_case_1():
|
|
|
|
data1 = data1.repeat(repeat_count)
|
|
|
|
data1 = data1.repeat(repeat_count)
|
|
|
|
data1 = data1.batch(batch_size, drop_remainder=True)
|
|
|
|
data1 = data1.batch(batch_size, drop_remainder=True)
|
|
|
|
|
|
|
|
|
|
|
|
filename = "test_case_1_result.npz"
|
|
|
|
filename = "test_2ops_repeat_batch.npz"
|
|
|
|
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
|
|
|
|
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_case_1_reverse():
|
|
|
|
def test_2ops_batch_repeat():
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Test Batch then Repeat
|
|
|
|
Test Batch then Repeat
|
|
|
|
"""
|
|
|
|
"""
|
|
|
@ -107,11 +107,11 @@ def test_case_1_reverse():
|
|
|
|
data1 = data1.batch(batch_size, drop_remainder=True)
|
|
|
|
data1 = data1.batch(batch_size, drop_remainder=True)
|
|
|
|
data1 = data1.repeat(repeat_count)
|
|
|
|
data1 = data1.repeat(repeat_count)
|
|
|
|
|
|
|
|
|
|
|
|
filename = "test_case_1_reverse_result.npz"
|
|
|
|
filename = "test_2ops_batch_repeat.npz"
|
|
|
|
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
|
|
|
|
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_case_2():
|
|
|
|
def test_2ops_batch_shuffle():
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Test Batch then Shuffle
|
|
|
|
Test Batch then Shuffle
|
|
|
|
"""
|
|
|
|
"""
|
|
|
@ -130,11 +130,11 @@ def test_case_2():
|
|
|
|
ds.config.set_seed(seed)
|
|
|
|
ds.config.set_seed(seed)
|
|
|
|
data1 = data1.shuffle(buffer_size=buffer_size)
|
|
|
|
data1 = data1.shuffle(buffer_size=buffer_size)
|
|
|
|
|
|
|
|
|
|
|
|
filename = "test_case_2_result.npz"
|
|
|
|
filename = "test_2ops_batch_shuffle.npz"
|
|
|
|
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
|
|
|
|
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_case_2_reverse():
|
|
|
|
def test_2ops_shuffle_batch():
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Test Shuffle then Batch
|
|
|
|
Test Shuffle then Batch
|
|
|
|
"""
|
|
|
|
"""
|
|
|
@ -153,5 +153,14 @@ def test_case_2_reverse():
|
|
|
|
data1 = data1.shuffle(buffer_size=buffer_size)
|
|
|
|
data1 = data1.shuffle(buffer_size=buffer_size)
|
|
|
|
data1 = data1.batch(batch_size, drop_remainder=True)
|
|
|
|
data1 = data1.batch(batch_size, drop_remainder=True)
|
|
|
|
|
|
|
|
|
|
|
|
filename = "test_case_2_reverse_result.npz"
|
|
|
|
filename = "test_2ops_shuffle_batch.npz"
|
|
|
|
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
|
|
|
|
save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
|
|
|
test_2ops_repeat_shuffle()
|
|
|
|
|
|
|
|
#test_2ops_shuffle_repeat()
|
|
|
|
|
|
|
|
test_2ops_repeat_batch()
|
|
|
|
|
|
|
|
test_2ops_batch_repeat()
|
|
|
|
|
|
|
|
test_2ops_batch_shuffle()
|
|
|
|
|
|
|
|
test_2ops_shuffle_batch()
|
|
|
|