# Copyright 2020 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """ Testing concatenate op """ import numpy as np import pytest import mindspore.dataset as ds import mindspore.dataset.transforms.c_transforms as data_trans def test_concatenate_op_all(): def gen(): yield (np.array([5., 6., 7., 8.], dtype=np.float),) prepend_tensor = np.array([1.4, 2., 3., 4., 4.5], dtype=np.float) append_tensor = np.array([9., 10.3, 11., 12.], dtype=np.float) data = ds.GeneratorDataset(gen, column_names=["col"]) concatenate_op = data_trans.Concatenate(0, prepend_tensor, append_tensor) data = data.map(operations=concatenate_op, input_columns=["col"]) expected = np.array([1.4, 2., 3., 4., 4.5, 5., 6., 7., 8., 9., 10.3, 11., 12.]) for data_row in data.create_tuple_iterator(output_numpy=True): np.testing.assert_array_equal(data_row[0], expected) def test_concatenate_op_none(): def gen(): yield (np.array([5., 6., 7., 8.], dtype=np.float),) data = ds.GeneratorDataset(gen, column_names=["col"]) concatenate_op = data_trans.Concatenate() data = data.map(operations=concatenate_op, input_columns=["col"]) for data_row in data.create_tuple_iterator(output_numpy=True): np.testing.assert_array_equal(data_row[0], np.array([5., 6., 7., 8.], dtype=np.float)) def test_concatenate_op_string(): def gen(): yield (np.array(["ss", "ad"], dtype='S'),) prepend_tensor = np.array(["dw", "df"], dtype='S') append_tensor = np.array(["dwsdf", "df"], dtype='S') data = ds.GeneratorDataset(gen, column_names=["col"]) concatenate_op = data_trans.Concatenate(0, prepend_tensor, append_tensor) data = data.map(operations=concatenate_op, input_columns=["col"]) expected = np.array(["dw", "df", "ss", "ad", "dwsdf", "df"], dtype='S') for data_row in data.create_tuple_iterator(output_numpy=True): np.testing.assert_array_equal(data_row[0], expected) def test_concatenate_op_multi_input_string(): prepend_tensor = np.array(["dw", "df"], dtype='S') append_tensor = np.array(["dwsdf", "df"], dtype='S') data = ([["1", "2", "d"]], [["3", "4", "e"]]) data = ds.NumpySlicesDataset(data, column_names=["col1", "col2"]) concatenate_op = data_trans.Concatenate(0, prepend=prepend_tensor, append=append_tensor) data = data.map(operations=concatenate_op, input_columns=["col1", "col2"], column_order=["out1"], output_columns=["out1"]) expected = np.array(["dw", "df", "1", "2", "d", "3", "4", "e", "dwsdf", "df"], dtype='S') for data_row in data.create_tuple_iterator(output_numpy=True): np.testing.assert_array_equal(data_row[0], expected) def test_concatenate_op_multi_input_numeric(): prepend_tensor = np.array([3, 5]) data = ([[1, 2]], [[3, 4]]) data = ds.NumpySlicesDataset(data, column_names=["col1", "col2"]) concatenate_op = data_trans.Concatenate(0, prepend=prepend_tensor) data = data.map(operations=concatenate_op, input_columns=["col1", "col2"], column_order=["out1"], output_columns=["out1"]) expected = np.array([3, 5, 1, 2, 3, 4]) for data_row in data.create_tuple_iterator(output_numpy=True): np.testing.assert_array_equal(data_row[0], expected) def test_concatenate_op_type_mismatch(): def gen(): yield (np.array([3, 4], dtype=np.float),) prepend_tensor = np.array(["ss", "ad"], dtype='S') data = ds.GeneratorDataset(gen, column_names=["col"]) concatenate_op = data_trans.Concatenate(0, prepend_tensor) data = data.map(operations=concatenate_op, input_columns=["col"]) with pytest.raises(RuntimeError) as error_info: for _ in data: pass assert "input datatype does not match" in str(error_info.value) def test_concatenate_op_type_mismatch2(): def gen(): yield (np.array(["ss", "ad"], dtype='S'),) prepend_tensor = np.array([3, 5], dtype=np.float) data = ds.GeneratorDataset(gen, column_names=["col"]) concatenate_op = data_trans.Concatenate(0, prepend_tensor) data = data.map(operations=concatenate_op, input_columns=["col"]) with pytest.raises(RuntimeError) as error_info: for _ in data: pass assert "input datatype does not match" in str(error_info.value) def test_concatenate_op_incorrect_dim(): def gen(): yield (np.array([["ss", "ad"], ["ss", "ad"]], dtype='S'),) prepend_tensor = np.array(["ss", "ss"], dtype='S') concatenate_op = data_trans.Concatenate(0, prepend_tensor) data = ds.GeneratorDataset(gen, column_names=["col"]) data = data.map(operations=concatenate_op, input_columns=["col"]) with pytest.raises(RuntimeError) as error_info: for _ in data: pass assert "only 1D input supported" in str(error_info.value) def test_concatenate_op_wrong_axis(): with pytest.raises(ValueError) as error_info: data_trans.Concatenate(2) assert "only 1D concatenation supported." in str(error_info.value) def test_concatenate_op_negative_axis(): def gen(): yield (np.array([5., 6., 7., 8.], dtype=np.float),) prepend_tensor = np.array([1.4, 2., 3., 4., 4.5], dtype=np.float) append_tensor = np.array([9., 10.3, 11., 12.], dtype=np.float) data = ds.GeneratorDataset(gen, column_names=["col"]) concatenate_op = data_trans.Concatenate(-1, prepend_tensor, append_tensor) data = data.map(operations=concatenate_op, input_columns=["col"]) expected = np.array([1.4, 2., 3., 4., 4.5, 5., 6., 7., 8., 9., 10.3, 11., 12.]) for data_row in data.create_tuple_iterator(output_numpy=True): np.testing.assert_array_equal(data_row[0], expected) def test_concatenate_op_incorrect_input_dim(): prepend_tensor = np.array([["ss", "ad"], ["ss", "ad"]], dtype='S') with pytest.raises(ValueError) as error_info: data_trans.Concatenate(0, prepend_tensor) assert "can only prepend 1D arrays." in str(error_info.value) if __name__ == "__main__": test_concatenate_op_all() test_concatenate_op_none() test_concatenate_op_string() test_concatenate_op_multi_input_string() test_concatenate_op_multi_input_numeric() test_concatenate_op_type_mismatch() test_concatenate_op_type_mismatch2() test_concatenate_op_incorrect_dim() test_concatenate_op_negative_axis() test_concatenate_op_wrong_axis() test_concatenate_op_incorrect_input_dim()