allow both image-idx and image.idx in mnist

pull/6225/head
Zirui Wu 5 years ago
parent d0e49c5cf8
commit 61004697c0

@ -362,8 +362,8 @@ Status MnistOp::ParseMnistData() {
}
Status MnistOp::WalkAllFiles() {
const std::string kImageExtension = "idx3-ubyte";
const std::string kLabelExtension = "idx1-ubyte";
const std::string img_ext = "idx3-ubyte";
const std::string lbl_ext = "idx1-ubyte";
const std::string train_prefix = "train";
const std::string test_prefix = "t10k";
@ -374,13 +374,13 @@ Status MnistOp::WalkAllFiles() {
if (dir_it != nullptr) {
while (dir_it->hasNext()) {
Path file = dir_it->next();
std::string filename = file.Basename();
if (filename.find(prefix + "-images-" + kImageExtension) != std::string::npos) {
std::string fname = file.Basename(); // name of the mnist file
if ((fname.find(prefix + "-images") != std::string::npos) && (fname.find(img_ext) != std::string::npos)) {
image_names_.push_back(file.toString());
MS_LOG(INFO) << "Mnist operator found image file at " << filename << ".";
} else if (filename.find(prefix + "-labels-" + kLabelExtension) != std::string::npos) {
MS_LOG(INFO) << "Mnist operator found image file at " << fname << ".";
} else if ((fname.find(prefix + "-labels") != std::string::npos) && (fname.find(lbl_ext) != std::string::npos)) {
label_names_.push_back(file.toString());
MS_LOG(INFO) << "Mnist Operator found label file at " << filename << ".";
MS_LOG(INFO) << "Mnist Operator found label file at " << fname << ".";
}
}
} else {

Loading…
Cancel
Save