From a22763b612ced9746d76db4af647d842688d3e64 Mon Sep 17 00:00:00 2001 From: hesham Date: Sun, 14 Jun 2020 13:27:17 -0400 Subject: [PATCH] Bug in CIFAR after removing GetMutabble buffer Fixes # I1KIPC --- .../engine/datasetops/source/cifar_op.cc | 9 ++++-- tests/ut/python/dataset/test_cifarop.py | 30 +++++++++++++++++++ 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.cc index f13af59199..542f193a8a 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/cifar_op.cc @@ -367,9 +367,14 @@ Status CifarOp::ParseCifarData() { TensorShape({kCifarImageHeight, kCifarImageWidth, kCifarImageChannel}), data_schema_->column(0).type())); auto itr = image_tensor->begin(); - for (; itr != image_tensor->end(); itr++) { - *itr = block[cur_block_index++]; + uint32_t total_pix = kCifarImageHeight * kCifarImageWidth; + for (int pix = 0; pix < total_pix; ++pix) { + for (int ch = 0; ch < kCifarImageChannel; ++ch) { + *itr = block[cur_block_index + ch * total_pix + pix]; + itr++; + } } + cur_block_index += total_pix * kCifarImageChannel; cifar_image_label_pairs_.emplace_back(std::make_pair(image_tensor, labels)); } RETURN_IF_NOT_OK(cifar_raw_data_block_->PopFront(&block)); diff --git a/tests/ut/python/dataset/test_cifarop.py b/tests/ut/python/dataset/test_cifarop.py index 41777e1cea..e944f8703d 100644 --- a/tests/ut/python/dataset/test_cifarop.py +++ b/tests/ut/python/dataset/test_cifarop.py @@ -12,6 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +import os + +import numpy as np + import mindspore.dataset as ds from mindspore import log as logger @@ -26,6 +30,20 @@ DATA_DIR_10 = "../data/dataset/testCifar10Data" DATA_DIR_100 = "../data/dataset/testCifar100Data" +def load_cifar(path): + raw = np.empty(0, dtype=np.uint8) + for file_name in os.listdir(path): + if file_name.endswith(".bin"): + with open(os.path.join(path, file_name), mode='rb') as file: + raw = np.append(raw, np.fromfile(file, dtype=np.uint8), axis=0) + raw = raw.reshape(-1, 3073) + labels = raw[:, 0] + images = raw[:, 1:] + images = images.reshape(-1, 3, 32, 32) + images = images.transpose(0, 2, 3, 1) + return images, labels + + def test_case_dataset_cifar10(): """ dataset parameter @@ -56,6 +74,18 @@ def test_case_dataset_cifar100(): assert num_iter == 100 +def test_reading_cifar10(): + """ + Validate CIFAR10 image readings + """ + data1 = ds.Cifar10Dataset(DATA_DIR_10, 100, shuffle=False) + images, labels = load_cifar(DATA_DIR_10) + for i, d in enumerate(data1.create_dict_iterator()): + np.testing.assert_array_equal(d["image"], images[i]) + np.testing.assert_array_equal(d["label"], labels[i]) + + if __name__ == '__main__': test_case_dataset_cifar10() test_case_dataset_cifar100() + test_reading_cifar10()