@@ -16,17 +16,15 @@
 This is the test module for mindrecord
 """
 import collections
-import json
-import numpy as np
 import os
-import pytest
 import re
 import string
 
+import numpy as np
+import pytest
+
 import mindspore.dataset as ds
-import mindspore.dataset.transforms.vision.c_transforms as vision
 from mindspore import log as logger
-from mindspore.dataset.transforms.vision import Inter
 from mindspore.mindrecord import FileWriter
 
 FILES_NUM = 4
@@ -52,9 +50,9 @@ def add_and_remove_cv_file():
     writer = FileWriter(CV_FILE_NAME, FILES_NUM)
     data = get_data(CV_DIR_NAME)
     cv_schema_json = {"id": {"type": "int32"},
                       "file_name": {"type": "string"},
                       "label": {"type": "int32"},
                       "data": {"type": "bytes"}}
     writer.add_schema(cv_schema_json, "img_schema")
     writer.add_index(["file_name", "label"])
     writer.write_raw_data(data)
@@ -85,14 +83,14 @@ def add_and_remove_nlp_file():
     writer = FileWriter(NLP_FILE_NAME, FILES_NUM)
     data = [x for x in get_nlp_data(NLP_FILE_POS, NLP_FILE_VOCAB, 10)]
     nlp_schema_json = {"id": {"type": "string"}, "label": {"type": "int32"},
                        "rating": {"type": "float32"},
                        "input_ids": {"type": "int64",
                                      "shape": [-1]},
                        "input_mask": {"type": "int64",
                                       "shape": [1, -1]},
                        "segment_ids": {"type": "int64",
                                        "shape": [2, -1]}
                        }
     writer.set_header_size(1 << 14)
     writer.set_page_size(1 << 15)
     writer.add_schema(nlp_schema_json, "nlp_schema")
@@ -110,6 +108,7 @@ def add_and_remove_nlp_file():
         os.remove("{}".format(x))
         os.remove("{}.db".format(x))
 
+
 def test_cv_minddataset_reader_basic_padded_samples(add_and_remove_cv_file):
     """tutorial for cv minderdataset."""
     columns_list = ["label", "file_name", "data"]
@@ -130,7 +129,7 @@ def test_cv_minddataset_reader_basic_padded_samples(add_and_remove_cv_file):
             if item['label'] == -1:
                 num_padded_iter += 1
                 assert item['file_name'] == bytes(padded_sample['file_name'],
                                                   encoding='utf8')
                 assert item['label'] == padded_sample['label']
                 assert (item['data'] == np.array(list(padded_sample['data']))).all()
             num_iter += 1
@@ -177,6 +176,7 @@ def test_cv_minddataset_partition_padded_samples(add_and_remove_cv_file):
     partitions(5, 5, 3)
     partitions(9, 8, 2)
 
+
 def test_cv_minddataset_partition_padded_samples_multi_epoch(add_and_remove_cv_file):
     """tutorial for cv minddataset."""
     columns_list = ["data", "file_name", "label"]
@@ -248,6 +248,7 @@ def test_cv_minddataset_partition_padded_samples_multi_epoch(add_and_remove_cv_f
     partitions(5, 5, 3)
     partitions(9, 8, 2)
 
+
 def test_cv_minddataset_partition_padded_samples_no_dividsible(add_and_remove_cv_file):
     """tutorial for cv minddataset."""
     columns_list = ["data", "file_name", "label"]
@@ -273,6 +274,7 @@ def test_cv_minddataset_partition_padded_samples_no_dividsible(add_and_remove_cv
     with pytest.raises(RuntimeError):
         partitions(4, 1)
 
+
 def test_cv_minddataset_partition_padded_samples_dataset_size_no_divisible(add_and_remove_cv_file):
     columns_list = ["data", "file_name", "label"]
 
@@ -291,8 +293,10 @@ def test_cv_minddataset_partition_padded_samples_dataset_size_no_divisible(add_a
                                       num_padded=num_padded)
             with pytest.raises(RuntimeError):
                 data_set.get_dataset_size() == 3
 
     partitions(4, 1)
+
+
 def test_cv_minddataset_partition_padded_samples_no_equal_column_list(add_and_remove_cv_file):
     columns_list = ["data", "file_name", "label"]
 
@@ -314,9 +318,11 @@ def test_cv_minddataset_partition_padded_samples_no_equal_column_list(add_and_re
             logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
             logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
             logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+
     with pytest.raises(Exception, match="padded_sample cannot match columns_list."):
         partitions(4, 2)
 
+
 def test_cv_minddataset_partition_padded_samples_no_column_list(add_and_remove_cv_file):
     data = get_data(CV_DIR_NAME)
     padded_sample = data[0]
@@ -336,9 +342,11 @@ def test_cv_minddataset_partition_padded_samples_no_column_list(add_and_remove_c
             logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
             logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
             logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+
     with pytest.raises(Exception, match="padded_sample is specified and requires columns_list as well."):
         partitions(4, 2)
 
+
 def test_cv_minddataset_partition_padded_samples_no_num_padded(add_and_remove_cv_file):
     columns_list = ["data", "file_name", "label"]
     data = get_data(CV_DIR_NAME)
@@ -357,9 +365,11 @@ def test_cv_minddataset_partition_padded_samples_no_num_padded(add_and_remove_cv
             logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
             logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
             logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
+
     with pytest.raises(Exception, match="padded_sample is specified and requires num_padded as well."):
         partitions(4, 2)
 
+
 def test_cv_minddataset_partition_padded_samples_no_padded_samples(add_and_remove_cv_file):
     columns_list = ["data", "file_name", "label"]
     data = get_data(CV_DIR_NAME)
@@ -378,18 +388,18 @@ def test_cv_minddataset_partition_padded_samples_no_padded_samples(add_and_remov
             logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
             logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
             logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
 
     with pytest.raises(Exception, match="num_padded is specified but padded_sample is not."):
         partitions(4, 2)
 
 
 def test_nlp_minddataset_reader_basic_padded_samples(add_and_remove_nlp_file):
     columns_list = ["input_ids", "id", "rating"]
 
     data = [x for x in get_nlp_data(NLP_FILE_POS, NLP_FILE_VOCAB, 10)]
     padded_sample = data[0]
     padded_sample['id'] = "-1"
-    padded_sample['input_ids'] = np.array([-1,-1,-1,-1], dtype=np.int64)
+    padded_sample['input_ids'] = np.array([-1, -1, -1, -1], dtype=np.int64)
     padded_sample['rating'] = 1.0
     num_readers = 4
 
@@ -406,7 +416,9 @@ def test_nlp_minddataset_reader_basic_padded_samples(add_and_remove_nlp_file):
         for item in data_set.create_dict_iterator(num_epochs=1):
             logger.info("-------------- item[id]: {} ------------------------".format(item["id"]))
             logger.info("-------------- item[rating]: {} --------------------".format(item["rating"]))
-            logger.info("-------------- item[input_ids]: {}, shape: {} -----------------".format(item["input_ids"], item["input_ids"].shape))
+            logger.info("-------------- item[input_ids]: {}, shape: {} -----------------".format(
+                item["input_ids"],
+                item["input_ids"].shape))
             if item['id'] == bytes('-1', encoding='utf-8'):
                 num_padded_iter += 1
                 assert item['id'] == bytes(padded_sample['id'], encoding='utf-8')
@@ -420,13 +432,14 @@ def test_nlp_minddataset_reader_basic_padded_samples(add_and_remove_nlp_file):
     partitions(5, 5, 3)
     partitions(9, 8, 2)
 
+
 def test_nlp_minddataset_reader_basic_padded_samples_multi_epoch(add_and_remove_nlp_file):
     columns_list = ["input_ids", "id", "rating"]
 
     data = [x for x in get_nlp_data(NLP_FILE_POS, NLP_FILE_VOCAB, 10)]
     padded_sample = data[0]
     padded_sample['id'] = "-1"
-    padded_sample['input_ids'] = np.array([-1,-1,-1,-1], dtype=np.int64)
+    padded_sample['input_ids'] = np.array([-1, -1, -1, -1], dtype=np.int64)
     padded_sample['rating'] = 1.0
     num_readers = 4
     repeat_size = 3
@@ -451,7 +464,9 @@ def test_nlp_minddataset_reader_basic_padded_samples_multi_epoch(add_and_remove_
         for item in data_set.create_dict_iterator(num_epochs=1):
             logger.info("-------------- item[id]: {} ------------------------".format(item["id"]))
             logger.info("-------------- item[rating]: {} --------------------".format(item["rating"]))
-            logger.info("-------------- item[input_ids]: {}, shape: {} -----------------".format(item["input_ids"], item["input_ids"].shape))
+            logger.info("-------------- item[input_ids]: {}, shape: {} -----------------".format(
+                item["input_ids"],
+                item["input_ids"].shape))
             if item['id'] == bytes('-1', encoding='utf-8'):
                 num_padded_iter += 1
                 assert item['id'] == bytes(padded_sample['id'], encoding='utf-8')
@@ -488,7 +503,7 @@ def test_nlp_minddataset_reader_basic_padded_samples_check_whole_reshuffle_resul
 
     padded_sample = {}
    padded_sample['id'] = "-1"
-    padded_sample['input_ids'] = np.array([-1,-1,-1,-1], dtype=np.int64)
+    padded_sample['input_ids'] = np.array([-1, -1, -1, -1], dtype=np.int64)
     padded_sample['rating'] = 1.0
     num_readers = 4
     repeat_size = 3
@@ -512,14 +527,15 @@ def test_nlp_minddataset_reader_basic_padded_samples_check_whole_reshuffle_resul
             logger.info("-------------- item[id]: {} ------------------------".format(item["id"]))
             logger.info("-------------- item[rating]: {} --------------------".format(item["rating"]))
             logger.info("-------------- item[input_ids]: {}, shape: {} -----------------"
                         .format(item["input_ids"], item["input_ids"].shape))
             if item['id'] == bytes('-1', encoding='utf-8'):
                 num_padded_iter += 1
                 assert item['id'] == bytes(padded_sample['id'], encoding='utf-8')
                 assert (item['input_ids'] == padded_sample['input_ids']).all()
                 assert (item['rating'] == padded_sample['rating']).all()
             # save epoch result
-            epoch_result[partition_id][int(inner_num_iter / dataset_size)][inner_num_iter % dataset_size] = item["id"]
+            epoch_result[partition_id][int(inner_num_iter / dataset_size)][inner_num_iter % dataset_size] = item[
+                "id"]
             num_iter += 1
             inner_num_iter += 1
         assert epoch_result[partition_id][0] not in (epoch_result[partition_id][1], epoch_result[partition_id][2])
@@ -651,6 +667,7 @@ def inputs(vectors, maxlen=50):
     segment = [0] * maxlen
     return input_, mask, segment
 
+
 if __name__ == '__main__':
     test_cv_minddataset_reader_basic_padded_samples(add_and_remove_cv_file)
     test_cv_minddataset_partition_padded_samples(add_and_remove_cv_file)