# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import mindspore._c_dataengine as cde
import numpy as np
import pytest

from mindspore.dataset.text import to_str, to_bytes

import mindspore.dataset as ds
import mindspore.common.dtype as mstype


# pylint: disable=comparison-with-itself
def test_basic():
    """Build a dataengine Tensor from a numpy byte-string array and read it back unchanged."""
    x = np.array([["ab", "cde", "121"], ["x", "km", "789"]], dtype='S')
    n = cde.Tensor(x)
    arr = n.as_array()
    np.testing.assert_array_equal(x, arr)

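
# Hedged addition, not in the original suite: to_bytes is imported above but never used here.
# Assuming to_str/to_bytes simply decode/encode between numpy byte-string ('S') and unicode
# ('U') arrays, a round-trip sketch would look like this:
def test_to_str_to_bytes_roundtrip():
    x = np.array([["ab", "cde", "121"], ["x", "km", "789"]], dtype='S')
    s = to_str(x)  # bytes -> unicode array
    assert s.dtype.kind == 'U'
    b = to_bytes(s)  # unicode -> bytes array
    assert b.dtype.kind == 'S'
    np.testing.assert_array_equal(x, b)
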
def compare(strings):
    arr = np.array(strings, dtype='S')

    def gen():
        yield arr,

    data = ds.GeneratorDataset(gen, column_names=["col"])

    for d in data:
        np.testing.assert_array_equal(d[0], arr)

def test_generator():
    compare(["ab"])
    compare(["ab", "cde", "121"])
    compare([["ab", "cde", "121"], ["x", "km", "789"]])

def test_batching_strings():
    def gen():
        yield np.array(["ab", "cde", "121"], dtype='S'),

    data = ds.GeneratorDataset(gen, column_names=["col"]).batch(10)

    # Batching string tensors is not supported, so iterating the batched dataset must raise.
    with pytest.raises(RuntimeError) as info:
        for _ in data:
            pass
    assert "[Batch ERROR] Batch does not support" in str(info.value)

def test_map():
    def gen():
        yield np.array(["ab cde 121"], dtype='S'),

    data = ds.GeneratorDataset(gen, column_names=["col"])

    def split(b):
        # Decode the byte-string input, split on whitespace, and return a byte-string array.
        s = to_str(b)
        splits = s.item().split()
        return np.array(splits, dtype='S')

    data = data.map(input_columns=["col"], operations=split)
    expected = np.array(["ab", "cde", "121"], dtype='S')
    for d in data:
        np.testing.assert_array_equal(d[0], expected)

def test_map2():
    def gen():
        yield np.array(["ab cde 121"], dtype='S'),

    data = ds.GeneratorDataset(gen, column_names=["col"])

    def upper(b):
        out = np.char.upper(b)
        return out

    data = data.map(input_columns=["col"], operations=upper)
    expected = np.array(["AB CDE 121"], dtype='S')
    for d in data:
        np.testing.assert_array_equal(d[0], expected)

# Expected contents of the text TFRecord/MindRecord test data used by the tests below.
line = np.array(["This is a text file.",
                 "Be happy every day.",
                 "Good luck to everyone."])

words = np.array([["This", "text", "file", "a"],
                  ["Be", "happy", "day", "b"],
                  ["女", "", "everyone", "c"]])

chinese = np.array(["今天天气太好了我们一起去外面玩吧",
                    "男默女泪",
                    "江州市长江大桥参加了长江大桥的通车仪式"])

def test_tfrecord1():
    s = ds.Schema()
    s.add_column("line", "string", [])
    s.add_column("words", "string", [-1])
    s.add_column("chinese", "string", [])

    data = ds.TFRecordDataset("../data/dataset/testTextTFRecord/text.tfrecord", shuffle=False, schema=s)

    for i, d in enumerate(data.create_dict_iterator()):
        assert d["line"].shape == line[i].shape
        assert d["words"].shape == words[i].shape
        assert d["chinese"].shape == chinese[i].shape
        np.testing.assert_array_equal(line[i], to_str(d["line"]))
        np.testing.assert_array_equal(words[i], to_str(d["words"]))
        np.testing.assert_array_equal(chinese[i], to_str(d["chinese"]))

def test_tfrecord2():
    data = ds.TFRecordDataset("../data/dataset/testTextTFRecord/text.tfrecord", shuffle=False,
                              schema='../data/dataset/testTextTFRecord/datasetSchema.json')
    for i, d in enumerate(data.create_dict_iterator()):
        assert d["line"].shape == line[i].shape
        assert d["words"].shape == words[i].shape
        assert d["chinese"].shape == chinese[i].shape
        np.testing.assert_array_equal(line[i], to_str(d["line"]))
        np.testing.assert_array_equal(words[i], to_str(d["words"]))
        np.testing.assert_array_equal(chinese[i], to_str(d["chinese"]))

def test_tfrecord3():
    s = ds.Schema()
    s.add_column("line", mstype.string, [])
    s.add_column("words", mstype.string, [-1, 2])
    s.add_column("chinese", mstype.string, [])

    data = ds.TFRecordDataset("../data/dataset/testTextTFRecord/text.tfrecord", shuffle=False, schema=s)

    for i, d in enumerate(data.create_dict_iterator()):
        assert d["line"].shape == line[i].shape
        assert d["words"].shape == words[i].reshape([2, 2]).shape
        assert d["chinese"].shape == chinese[i].shape
        np.testing.assert_array_equal(line[i], to_str(d["line"]))
        np.testing.assert_array_equal(words[i].reshape([2, 2]), to_str(d["words"]))
        np.testing.assert_array_equal(chinese[i], to_str(d["chinese"]))

def create_text_mindrecord():
    # method to create a mindrecord with string data, used to generate testTextMindRecord/test.mindrecord
    from mindspore.mindrecord import FileWriter

    mindrecord_file_name = "test.mindrecord"
    data = [{"english": "This is a text file.",
             "chinese": "今天天气太好了我们一起去外面玩吧"},
            {"english": "Be happy every day.",
             "chinese": "男默女泪"},
            {"english": "Good luck to everyone.",
             "chinese": "江州市长江大桥参加了长江大桥的通车仪式"},
            ]
    writer = FileWriter(mindrecord_file_name)
    schema = {"english": {"type": "string"},
              "chinese": {"type": "string"},
              }
    writer.add_schema(schema)
    writer.write_raw_data(data)
    writer.commit()

def test_mindrecord():
    data = ds.MindDataset("../data/dataset/testTextMindRecord/test.mindrecord", shuffle=False)

    for i, d in enumerate(data.create_dict_iterator()):
        assert d["english"].shape == line[i].shape
        assert d["chinese"].shape == chinese[i].shape
        np.testing.assert_array_equal(line[i], to_str(d["english"]))
        np.testing.assert_array_equal(chinese[i], to_str(d["chinese"]))

if __name__ == '__main__':
    test_generator()
    test_basic()
    test_batching_strings()
    test_map()
    test_map2()
    test_tfrecord1()
    test_tfrecord2()
    test_tfrecord3()
    test_mindrecord()