diff --git a/mindspore/ccsrc/dataset/api/de_pipeline.cc b/mindspore/ccsrc/dataset/api/de_pipeline.cc
index ce70476423..78fcdb7dd4 100644
--- a/mindspore/ccsrc/dataset/api/de_pipeline.cc
+++ b/mindspore/ccsrc/dataset/api/de_pipeline.cc
@@ -15,6 +15,7 @@
  */
 #include "dataset/api/de_pipeline.h"
+#include <algorithm>
 #include
 #include
@@ -45,7 +46,7 @@
 namespace mindspore {
 namespace dataset {
-using pFunction = Status (DEPipeline::*)(const py::dict &, std::shared_ptr<DatasetOp> *);
+using pFunction = Status (DEPipeline::*)(const py::dict &, std::shared_ptr<DatasetOp> *, std::shared_ptr<DatasetOp> *);
 static std::unordered_map<OpName, pFunction> g_parse_op_func_ = {
   {kShuffle, &DEPipeline::ParseShuffleOp},
@@ -107,18 +108,44 @@ DEPipeline::~DEPipeline() {
 }
 // Function to add a Node to the Execution Tree.
-Status DEPipeline::AddNodeToTree(const OpName &op_name, const py::dict &args, DsOpPtr *out) {
-  // For each operator, Parse through the list of arguments,
-  // then call the respective builder/constructor.
+Status DEPipeline::AddNodeToTree(const OpName &op_name, const py::dict &args, py::dict *output) {
+  // For each operator, parse through the list of arguments, then call the respective builder/constructor.
+  // Note that each call to the parse function may result in building more than one dataset operator.
+  // For example, one call to ParseNNNOp may result in multiple internal C nodes:
+  //   nodeA
+  //     |
+  //   nodeB
+  //     |
+  //   nodeC
+  // However, the python side dataset is more abstract, and it does not know about the potential subtree that
+  // is being built here. Since the python api is hooking tree nodes together (parent/child hookups), the
+  // python side needs to know about nodeA and nodeC to be able to appropriately hook up parent and child
+  // to this subtree.
+  // Thus, it is required that both the top-most parent and bottom-most child are returned from the parse
+  // function.
+  DsOpPtr top = nullptr;
+  DsOpPtr bottom = nullptr;
   auto iter = g_parse_op_func_.find(op_name);
   if (iter != g_parse_op_func_.end()) {
     pFunction func = iter->second;
-    RETURN_IF_NOT_OK((this->*func)(args, out));
+    RETURN_IF_NOT_OK((this->*func)(args, &top, &bottom));
+
+    if (top == nullptr) {
+      RETURN_STATUS_UNEXPECTED("An operator was parsed but it did not produce a C node.");
+    }
+
+    // It is not required that the parse function always produces the bottom pointer. If it's still null,
+    // then set top and bottom to be the same operator.
+    if (bottom == nullptr) bottom = top;
+
+    // Pack these pointers into a py dict so that we can return both back to python.
+    (*output)["top"] = top;
+    (*output)["bottom"] = bottom;
   } else {
     RETURN_STATUS_UNEXPECTED("No such Op");
   }
   // Associate current dataset op node with the tree.
-  RETURN_IF_NOT_OK(tree_->AssociateNode(*out));
+  RETURN_IF_NOT_OK(tree_->AssociateNode(top));
   return Status::OK();
 }
 // Function to add a child and parent relationship.
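For orientation, here is a minimal sketch (not part of the patch) of the contract that AddNodeToTree now expects from every parse function: *top must always be set, *bottom is set only when the function builds a multi-node subtree, and every node the function creates should be associated with the tree. The operator FooOp, its builder, and ParseFooOp are hypothetical placeholders.

Status DEPipeline::ParseFooOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
                              std::shared_ptr<DatasetOp> *bottom) {
  // Build the primary operator for this dataset node (FooOp is a stand-in).
  std::shared_ptr<FooOp> foo_op;
  RETURN_IF_NOT_OK(FooOp::Builder().Build(&foo_op));
  RETURN_IF_NOT_OK(tree_->AssociateNode(foo_op));
  *top = foo_op;  // single-node case: leave *bottom null; AddNodeToTree copies top into bottom

  // Multi-node case: if another op (e.g. a ShuffleOp or ProjectOp) is stacked over foo_op,
  // report both ends of the subtree so python can attach parents above *top and children below *bottom:
  //   *top = wrapper_op;
  //   *bottom = foo_op;
  return Status::OK();
}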
@@ -300,7 +327,8 @@ Status DEPipeline::SetBatchParameters(const py::dict &args) { return Status::OK(); } -Status DEPipeline::ParseShuffleOp(const py::dict &args, std::shared_ptr *ptr) { +Status DEPipeline::ParseShuffleOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { std::shared_ptr builder = std::make_shared(); if (!args["buffer_size"].is_none()) { (void)builder->SetShuffleSize(ToInt(args["buffer_size"])); @@ -322,7 +350,7 @@ Status DEPipeline::ParseShuffleOp(const py::dict &args, std::shared_ptr op; RETURN_IF_NOT_OK(builder->Build(&op)); - *ptr = op; + *top = op; return Status::OK(); } @@ -350,7 +378,8 @@ Status DEPipeline::BuildMindrecordSamplerChain(const py::handle &handle, return Status::OK(); } -Status DEPipeline::ParseMindRecordOp(const py::dict &args, std::shared_ptr *ptr) { +Status DEPipeline::ParseMindRecordOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { if (args["dataset_file"].is_none()) { std::string err_msg = "Error: at least one of dataset_files is missing"; RETURN_STATUS_UNEXPECTED(err_msg); @@ -403,13 +432,15 @@ Status DEPipeline::ParseMindRecordOp(const py::dict &args, std::shared_ptr op; RETURN_IF_NOT_OK(builder->Build(&op)); num_rows_ = op->num_rows(); - *ptr = op; + *top = op; return Status::OK(); } -Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr *ptr) { - std::shared_ptr builder = std::make_shared(); +Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + MapOp::Builder map_builder; std::vector> tensor_op_list; + std::vector project_columns; if (args["operations"].is_none()) RETURN_STATUS_UNEXPECTED("Error: 'operations' is not set. \n"); @@ -419,15 +450,15 @@ Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr * if (!value.is_none()) { if (key == "input_columns") { std::vector in_col_names = ToStringVector(args["input_columns"]); - (void)builder->SetInColNames(in_col_names); + (void)map_builder.SetInColNames(in_col_names); } else if (key == "output_columns") { - (void)builder->SetOutColNames(ToStringVector(value)); + (void)map_builder.SetOutColNames(ToStringVector(value)); } else if (key == "columns_order") { - (void)builder->SetColOrder(ToStringVector(value)); + project_columns = ToStringVector(value); } else if (key == "num_parallel_workers") { - (void)builder->SetNumWorkers(ToInt(value)); + (void)map_builder.SetNumWorkers(ToInt(value)); } else if (key == "prefetch_size") { - (void)builder->SetOpConnectorSize(ToInt(value)); + (void)map_builder.SetOpConnectorSize(ToInt(value)); } else if (key == "operations") { py::handle tensor_ops = args["operations"]; // operation can be a list of TensorOps or a single TensorOp. 
@@ -445,20 +476,34 @@ Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr * } } if (tensor_op_list.empty()) RETURN_STATUS_UNEXPECTED("Error: tensor_op is invalid or not set."); - (void)builder->SetTensorFuncs(std::move(tensor_op_list)); + (void)map_builder.SetTensorFuncs(std::move(tensor_op_list)); } else { RETURN_STATUS_UNEXPECTED("Error: Unhandled key: " + key); } } } - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *ptr = op; + std::shared_ptr map_op; + RETURN_IF_NOT_OK(map_builder.Build(&map_op)); + RETURN_IF_NOT_OK(tree_->AssociateNode(map_op)); + *top = map_op; + + // Add a project op over top of the map if the user wanted to reposition the columns + if (!project_columns.empty()) { + ProjectOp::Builder proj_builder(project_columns); + std::shared_ptr proj_op; + RETURN_IF_NOT_OK(proj_builder.Build(&proj_op)); + RETURN_IF_NOT_OK(tree_->AssociateNode(proj_op)); + RETURN_IF_NOT_OK(proj_op->AddChild(map_op)); + *top = proj_op; + *bottom = map_op; + } + return Status::OK(); } -Status DEPipeline::ParseFilterOp(const py::dict &args, std::shared_ptr *ptr) { +Status DEPipeline::ParseFilterOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { std::shared_ptr builder = std::make_shared(); if (args["predicate"].is_none()) { @@ -489,11 +534,12 @@ Status DEPipeline::ParseFilterOp(const py::dict &args, std::shared_ptr op; RETURN_IF_NOT_OK(builder->Build(&op)); - *ptr = op; + *top = op; return Status::OK(); } -Status DEPipeline::ParseRepeatOp(const py::dict &args, std::shared_ptr *ptr) { +Status DEPipeline::ParseRepeatOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { if (args["count"].is_none()) { std::string err_msg = "Error: count is invalid or not set."; RETURN_STATUS_UNEXPECTED(err_msg); @@ -501,22 +547,24 @@ Status DEPipeline::ParseRepeatOp(const py::dict &args, std::shared_ptr op; RETURN_IF_NOT_OK(RepeatOp::Builder(ToInt(args["count"])).Build(&op)); - *ptr = op; + *top = op; return Status::OK(); } -Status DEPipeline::ParseSkipOp(const py::dict &args, std::shared_ptr *ptr) { +Status DEPipeline::ParseSkipOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { if (args["count"].is_none()) { std::string err_msg = "Error: count is invalid or not set."; RETURN_STATUS_UNEXPECTED(err_msg); } std::shared_ptr op; RETURN_IF_NOT_OK(SkipOp::Builder(ToInt(args["count"])).Build(&op)); - *ptr = op; + *top = op; return Status::OK(); } -Status DEPipeline::ParseGeneratorOp(const py::dict &args, std::shared_ptr *ptr) { +Status DEPipeline::ParseGeneratorOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { std::shared_ptr builder = std::make_shared(); for (auto arg : args) { std::string key = py::str(arg.first); @@ -538,11 +586,12 @@ Status DEPipeline::ParseGeneratorOp(const py::dict &args, std::shared_ptr op; RETURN_IF_NOT_OK(builder->Build(&op)); - *ptr = op; + *top = op; return Status::OK(); } -Status DEPipeline::ParseBatchOp(const py::dict &args, std::shared_ptr *ptr) { +Status DEPipeline::ParseBatchOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { std::shared_ptr builder; if (py::isinstance(args["batch_size"])) { batch_size_ = ToInt(args["batch_size"]); @@ -582,11 +631,12 @@ Status DEPipeline::ParseBatchOp(const py::dict &args, std::shared_ptr std::shared_ptr op; RETURN_IF_NOT_OK(builder->Build(&op)); - *ptr = op; + *top = op; return Status::OK(); } -Status DEPipeline::ParseBucketBatchByLengthOp(const py::dict &args, std::shared_ptr *ptr) { +Status 
DEPipeline::ParseBucketBatchByLengthOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { std::vector mandatory_arguments = {"length_dependent_columns", "bucket_boundaries", "bucket_batch_sizes"}; for (auto name : mandatory_arguments) { @@ -632,11 +682,12 @@ Status DEPipeline::ParseBucketBatchByLengthOp(const py::dict &args, std::shared_ std::shared_ptr op; RETURN_IF_NOT_OK(builder->Build(&op)); - *ptr = op; + *top = op; return Status::OK(); } -Status DEPipeline::ParseBarrierOp(const py::dict &args, std::shared_ptr *ptr) { +Status DEPipeline::ParseBarrierOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { std::shared_ptr builder = std::make_shared(); // Right now barrier should only take num_rows_per_buffer = 1 // The reason for this is because having it otherwise can lead to blocking issues @@ -656,11 +707,12 @@ Status DEPipeline::ParseBarrierOp(const py::dict &args, std::shared_ptr op; RETURN_IF_NOT_OK(builder->Build(&op)); - *ptr = op; + *top = op; return Status::OK(); } -Status DEPipeline::ParseDeviceQueueOp(const py::dict &args, std::shared_ptr *ptr) { +Status DEPipeline::ParseDeviceQueueOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { int32_t prefetch_size = 0; if (args.contains("prefetch_size")) { if (args["prefetch_size"].is_none()) { @@ -687,11 +739,12 @@ Status DEPipeline::ParseDeviceQueueOp(const py::dict &args, std::shared_ptr op; RETURN_IF_NOT_OK(builder->Build(&op)); - *ptr = op; + *top = op; return Status::OK(); } -Status DEPipeline::ParseRenameOp(const py::dict &args, std::shared_ptr *ptr) { +Status DEPipeline::ParseRenameOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { std::vector in_col_names; std::vector out_col_names; std::shared_ptr builder = std::make_shared(); @@ -718,48 +771,57 @@ Status DEPipeline::ParseRenameOp(const py::dict &args, std::shared_ptrSetOutColNames(out_col_names); std::shared_ptr op; RETURN_IF_NOT_OK(builder->Build(&op)); - *ptr = op; + *top = op; return Status::OK(); } -Status DEPipeline::ParseTakeOp(const py::dict &args, std::shared_ptr *ptr) { +Status DEPipeline::ParseTakeOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { if (args["count"].is_none()) { std::string err_msg = "Error: count is invalid or not set."; RETURN_STATUS_UNEXPECTED(err_msg); } std::shared_ptr op; RETURN_IF_NOT_OK(TakeOp::Builder(ToInt(args["count"])).Build(&op)); - *ptr = op; + *top = op; return Status::OK(); } -Status DEPipeline::ParseZipOp(const py::dict &args, std::shared_ptr *ptr) { +Status DEPipeline::ParseZipOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { std::shared_ptr builder = std::make_shared(); std::shared_ptr op; RETURN_IF_NOT_OK(builder->Build(&op)); - *ptr = op; + *top = op; return Status::OK(); } -Status DEPipeline::ParseConcatOp(const py::dict &args, std::shared_ptr *ptr) { +Status DEPipeline::ParseConcatOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { std::shared_ptr builder = std::make_shared(); std::shared_ptr op; RETURN_IF_NOT_OK(builder->Build(&op)); - *ptr = op; + *top = op; return Status::OK(); } -Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr *ptr) { +Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { // Required arguments + std::vector files_list; std::shared_ptr builder = std::make_shared(); if (!args["dataset_files"].is_none()) { - 
(void)builder->SetDatasetFilesList(ToStringVector(args["dataset_files"])); + files_list = ToStringVector(args["dataset_files"]); + (void)builder->SetDatasetFilesList(files_list); } else { std::string err_msg = "Error: at least one of dataset_files or schema_file is missing"; RETURN_STATUS_UNEXPECTED(err_msg); } std::vector columns_to_load; bool schema_exists = false; + bool shuffle_required = false; + int64_t num_devices = 0; + int64_t total_rows = 0; // Optional arguments for (auto arg : args) { std::string key = py::str(arg.first); @@ -773,13 +835,15 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptrSetShuffleFiles(ToBool(value)); } else if (key == "shuffle_global") { - (void)builder->SetShuffleGlobal(ToBool(value)); + shuffle_required = ToBool(value); } else if (key == "schema_file_path" || key == "schema_json_string") { schema_exists = true; } else if (key == "num_samples") { - (void)builder->setTotalRows(ToInt(value)); + total_rows = ToInt(value); + (void)builder->setTotalRows(total_rows); } else if (key == "num_shards") { - (void)builder->SetNumDevices(ToInt(value)); + num_devices = ToInt(value); + (void)builder->SetNumDevices(num_devices); } else if (key == "shard_id") { (void)builder->SetDeviceId(ToInt(value)); } else if (key == "shard_equal_rows") { @@ -796,13 +860,33 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptrSetDataSchema(std::move(schema)); } - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *ptr = op; + std::shared_ptr tf_op; + RETURN_IF_NOT_OK(builder->Build(&tf_op)); + RETURN_IF_NOT_OK(tree_->AssociateNode(tf_op)); + *top = tf_op; + + if (shuffle_required) { + const bool estimate = true; + const int64_t workers = 8; + std::shared_ptr shuffle_op = nullptr; + int64_t shuffle_size = 0; + int64_t num_rows = 0; + + // First, get the number of rows in the dataset via estimate and then compute the shuffle size + RETURN_IF_NOT_OK(TFReaderOp::CountTotalRows(&num_rows, files_list, workers, estimate)); + RETURN_IF_NOT_OK(ComputeShuffleSize(files_list.size(), num_devices, num_rows, total_rows, &shuffle_size)); + + // Add the shuffle op over top of this op and return the subtree (top/bottom) to caller + RETURN_IF_NOT_OK(AddShuffleOp(shuffle_size, tf_op, &shuffle_op)); + *top = shuffle_op; + *bottom = tf_op; + } + return Status::OK(); } -Status DEPipeline::ParseProjectOp(const py::dict &args, std::shared_ptr *ptr) { +Status DEPipeline::ParseProjectOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { if (args["columns"].is_none()) { std::string err_msg = "Error: columns is missing"; RETURN_STATUS_UNEXPECTED(err_msg); @@ -811,11 +895,12 @@ Status DEPipeline::ParseProjectOp(const py::dict &args, std::shared_ptr builder = std::make_shared(columns_to_project); std::shared_ptr op; RETURN_IF_NOT_OK(builder->Build(&op)); - *ptr = op; + *top = op; return Status::OK(); } -Status DEPipeline::ParseImageFolderOp(const py::dict &args, std::shared_ptr *ptr) { +Status DEPipeline::ParseImageFolderOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { // Required arguments if (args["dataset_dir"].is_none()) { std::string err_msg = "Error: No dataset path specified"; RETURN_STATUS_UNEXPECTED(err_msg); @@ -846,11 +931,12 @@ Status DEPipeline::ParseImageFolderOp(const py::dict &args, std::shared_ptr op; RETURN_IF_NOT_OK(builder->Build(&op)); - *ptr = op; + *top = op; return Status::OK(); } -Status DEPipeline::ParseManifestOp(const py::dict &args, std::shared_ptr *ptr) { +Status DEPipeline::ParseManifestOp(const py::dict &args,
std::shared_ptr *top, + std::shared_ptr *bottom) { // Required arguments if (args["dataset_file"].is_none()) { std::string err_msg = "Error: No dataset files specified for manifest"; @@ -881,11 +967,12 @@ Status DEPipeline::ParseManifestOp(const py::dict &args, std::shared_ptr op; RETURN_IF_NOT_OK(builder->Build(&op)); - *ptr = op; + *top = op; return Status::OK(); } -Status DEPipeline::ParseVOCOp(const py::dict &args, std::shared_ptr *ptr) { +Status DEPipeline::ParseVOCOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { if (args["dataset_dir"].is_none()) { std::string err_msg = "Error: No dataset path specified"; RETURN_STATUS_UNEXPECTED(err_msg); @@ -924,11 +1011,13 @@ Status DEPipeline::ParseVOCOp(const py::dict &args, std::shared_ptr * } std::shared_ptr op; RETURN_IF_NOT_OK(builder->Build(&op)); - *ptr = op; + *top = op; + return Status::OK(); } -Status DEPipeline::ParseCocoOp(const py::dict &args, std::shared_ptr *ptr) { +Status DEPipeline::ParseCocoOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { if (args["dataset_dir"].is_none()) { std::string err_msg = "Error: No dataset path specified"; RETURN_STATUS_UNEXPECTED(err_msg); @@ -965,11 +1054,12 @@ Status DEPipeline::ParseCocoOp(const py::dict &args, std::shared_ptr } std::shared_ptr op; RETURN_IF_NOT_OK(builder->Build(&op)); - *ptr = op; + *top = op; return Status::OK(); } -Status DEPipeline::ParseCifar10Op(const py::dict &args, std::shared_ptr *ptr) { +Status DEPipeline::ParseCifar10Op(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { // Required arguments if (args["dataset_dir"].is_none()) { std::string err_msg = "Error: No dataset path specified"; @@ -998,11 +1088,12 @@ Status DEPipeline::ParseCifar10Op(const py::dict &args, std::shared_ptr op; RETURN_IF_NOT_OK(builder->Build(&op)); - *ptr = op; + *top = op; return Status::OK(); } -Status DEPipeline::ParseCifar100Op(const py::dict &args, std::shared_ptr *ptr) { +Status DEPipeline::ParseCifar100Op(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { // Required arguments if (args["dataset_dir"].is_none()) { std::string err_msg = "Error: No dataset path specified"; @@ -1031,11 +1122,12 @@ Status DEPipeline::ParseCifar100Op(const py::dict &args, std::shared_ptr op; RETURN_IF_NOT_OK(builder->Build(&op)); - *ptr = op; + *top = op; return Status::OK(); } -Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr *ptr) { +Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { // Required arguments RandomDataOp::Builder builder; @@ -1072,13 +1164,14 @@ Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr op; RETURN_IF_NOT_OK(builder.Build(&op)); - *ptr = op; + *top = op; return Status::OK(); } int32_t DEPipeline::GetNumClasses() const { return num_classes_; } -Status DEPipeline::ParseMnistOp(const py::dict &args, std::shared_ptr *ptr) { +Status DEPipeline::ParseMnistOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { // Required arguments if (args["dataset_dir"].is_none()) { std::string err_msg = "Error: No dataset path specified"; @@ -1104,11 +1197,12 @@ Status DEPipeline::ParseMnistOp(const py::dict &args, std::shared_ptr } std::shared_ptr op; RETURN_IF_NOT_OK(builder->Build(&op)); - *ptr = op; + *top = op; return Status::OK(); } -Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptr *ptr) { +Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptr 
*top, + std::shared_ptr *bottom) { // Required arguments if (args["dataset_dir"].is_none()) { std::string err_msg = "Error: No dataset path specified"; @@ -1143,19 +1237,24 @@ Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptr op; RETURN_IF_NOT_OK(builder->Build(&op)); - *ptr = op; + *top = op; return Status::OK(); } -Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr *ptr) { +Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { // Required arguments + std::vector files_list; std::shared_ptr builder = std::make_shared(); if (!args["dataset_files"].is_none()) { - (void)builder->SetTextFilesList(ToStringVector(args["dataset_files"])); + files_list = ToStringVector(args["dataset_files"]); + (void)builder->SetTextFilesList(files_list); } else { RETURN_STATUS_UNEXPECTED("Error: dataset_files is missing"); } // Optional arguments + bool shuffle_required = false; + int64_t num_devices = 0; for (auto arg : args) { std::string key = py::str(arg.first); py::handle value = arg.second; @@ -1165,19 +1264,38 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptrSetShuffleFiles(ToBool(value)); } else if (key == "shuffle_global") { - (void)builder->SetShuffleGlobal(ToBool(value)); + shuffle_required = ToBool(value); } else if (key == "num_samples") { (void)builder->SetTotalRows(ToInt(value)); } else if (key == "num_shards") { - (void)builder->SetNumDevices(ToInt(value)); + num_devices = ToInt(value); + (void)builder->SetNumDevices(num_devices); } else if (key == "shard_id") { (void)builder->SetDeviceId(ToInt(value)); } } } - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *ptr = op; + + std::shared_ptr txt_op; + RETURN_IF_NOT_OK(builder->Build(&txt_op)); + RETURN_IF_NOT_OK(tree_->AssociateNode(txt_op)); + *top = txt_op; + + if (shuffle_required) { + std::shared_ptr shuffle_op = nullptr; + int64_t shuffle_size = 0; + int64_t num_rows = 0; + + // First, get the number of rows in the dataset and then compute the shuffle size + RETURN_IF_NOT_OK(TextFileOp::CountAllFileRows(files_list, &num_rows)); + RETURN_IF_NOT_OK(ComputeShuffleSize(files_list.size(), num_devices, num_rows, 0, &shuffle_size)); + + // Add the shuffle op over top of this op and return the subtree (top/bottom) to caller + RETURN_IF_NOT_OK(AddShuffleOp(shuffle_size, txt_op, &shuffle_op)); + *top = shuffle_op; + *bottom = txt_op; + } + return Status::OK(); } @@ -1208,7 +1326,8 @@ Status DEPipeline::ParsePadInfo(py::handle value, PadInfo *pad_info) { return Status::OK(); } -Status DEPipeline::ParseBuildVocabOp(const py::dict &args, std::shared_ptr *ptr) { +Status DEPipeline::ParseBuildVocabOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { std::shared_ptr builder = std::make_shared(); for (auto arg : args) { std::string key = py::str(arg.first); @@ -1235,18 +1354,23 @@ Status DEPipeline::ParseBuildVocabOp(const py::dict &args, std::shared_ptr op; RETURN_IF_NOT_OK(builder->Build(&op)); - *ptr = op; + *top = op; return Status::OK(); } -Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr *ptr) { +Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom) { + std::vector files_list; std::shared_ptr builder = std::make_shared(); if (!args["dataset_files"].is_none()) { - (void)builder->SetClueFilesList(ToStringVector(args["dataset_files"])); + files_list = ToStringVector(args["dataset_files"]); + 
(void)builder->SetClueFilesList(files_list); } else { RETURN_STATUS_UNEXPECTED("Error: dataset_files is missing"); } // Optional arguments + bool shuffle_required = false; + int64_t num_devices = 0; for (auto arg : args) { std::string key = py::str(arg.first); py::handle value = arg.second; @@ -1256,11 +1380,12 @@ Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr } else if (key == "shuffle_files") { (void)builder->SetShuffleFiles(ToBool(value)); } else if (key == "shuffle_global") { - (void)builder->SetShuffleGlobal(ToBool(value)); + shuffle_required = ToBool(value); } else if (key == "num_samples") { (void)builder->SetNumSamples(ToInt(value)); } else if (key == "num_shards") { - (void)builder->SetNumDevices(ToInt(value)); + num_devices = ToInt(value); + (void)builder->SetNumDevices(num_devices); } else if (key == "shard_id") { (void)builder->SetDeviceId(ToInt(value)); } else if (key == "cols_to_keyword") { @@ -1276,9 +1401,76 @@ Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr } } } - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - *ptr = op; + + std::shared_ptr clue_op; + RETURN_IF_NOT_OK(builder->Build(&clue_op)); + RETURN_IF_NOT_OK(tree_->AssociateNode(clue_op)); + *top = clue_op; + + if (shuffle_required) { + std::shared_ptr shuffle_op = nullptr; + int64_t shuffle_size = 0; + int64_t num_rows = 0; + + // First, get the number of rows in the dataset and then compute the shuffle size + RETURN_IF_NOT_OK(ClueOp::CountAllFileRows(files_list, &num_rows)); + RETURN_IF_NOT_OK(ComputeShuffleSize(files_list.size(), num_devices, num_rows, 0, &shuffle_size)); + + // Add the shuffle op over top of this op and return the subtree (top/bottom) to caller + RETURN_IF_NOT_OK(AddShuffleOp(shuffle_size, clue_op, &shuffle_op)); + *top = shuffle_op; + *bottom = clue_op; + } + + return Status::OK(); +} + +// Helper function to inject a shuffle operator over top of the current operation being built. +Status DEPipeline::AddShuffleOp(int64_t shuffle_size, std::shared_ptr input_op, + std::shared_ptr *shuffle_op) { + std::shared_ptr new_shuffle_op = nullptr; + ShuffleOp::Builder shuffle_builder; + + (void)shuffle_builder.SetShuffleSize(shuffle_size); + RETURN_IF_NOT_OK(shuffle_builder.Build(&new_shuffle_op)); + RETURN_IF_NOT_OK(tree_->AssociateNode(new_shuffle_op)); + RETURN_IF_NOT_OK(new_shuffle_op->AddChild(input_op)); + // We have now created: + // + // ShuffleOp + // | + // input_op + // + *shuffle_op = new_shuffle_op; + + return Status::OK(); +} + +// Common code for computing a default shuffle size +Status DEPipeline::ComputeShuffleSize(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows, + int64_t *shuffle_size) { + const int64_t average_files_multiplier = 4; + const int64_t shuffle_max = 10000; + int64_t avg_rows_per_file = 0; + + // Adjust the num rows per shard if sharding was given + if (num_devices > 0) { + if (num_rows % num_devices == 0) { + num_rows = num_rows / num_devices; + } else { + num_rows = (num_rows / num_devices) + 1; + } + } + + // Cap based on total rows directive. Some ops do not have this and give value of 0. 
+ if (total_rows > 0) { + num_rows = std::min(num_rows, total_rows); + } + + // get the average per file + avg_rows_per_file = num_rows / num_files; + + *shuffle_size = std::max(avg_rows_per_file * average_files_multiplier, shuffle_max); return Status::OK(); } } // namespace dataset diff --git a/mindspore/ccsrc/dataset/api/de_pipeline.h b/mindspore/ccsrc/dataset/api/de_pipeline.h index d6127d5d44..7cfc73307c 100644 --- a/mindspore/ccsrc/dataset/api/de_pipeline.h +++ b/mindspore/ccsrc/dataset/api/de_pipeline.h @@ -77,7 +77,7 @@ class DEPipeline { ~DEPipeline(); // Function to add a Node to the Execution Tree. - Status AddNodeToTree(const OpName &op_name, const py::dict &args, DsOpPtr *out); + Status AddNodeToTree(const OpName &op_name, const py::dict &args, py::dict *output); // Function to add a child and parent relationship. static Status AddChildToParentNode(const DsOpPtr &child_op, const DsOpPtr &parent_op); @@ -104,73 +104,74 @@ class DEPipeline { int GetRepeatCount() const; - Status ParseShuffleOp(const py::dict &args, std::shared_ptr *ptr); + Status ParseShuffleOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - Status ParseMindRecordOp(const py::dict &args, std::shared_ptr *ptr); + Status ParseMindRecordOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); Status BuildMindrecordSamplerChain(const py::handle &handle, std::vector> *operators, int num_padded); - Status ParseMapOp(const py::dict &args, std::shared_ptr *ptr); + Status ParseMapOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - Status ParseFilterOp(const py::dict &args, std::shared_ptr *ptr); + Status ParseFilterOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - Status ParseRepeatOp(const py::dict &args, std::shared_ptr *ptr); + Status ParseRepeatOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - Status ParseSkipOp(const py::dict &args, std::shared_ptr *ptr); + Status ParseSkipOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - Status ParseBatchOp(const py::dict &args, std::shared_ptr *ptr); + Status ParseBatchOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - Status ParseBucketBatchByLengthOp(const py::dict &args, std::shared_ptr *ptr); + Status ParseBucketBatchByLengthOp(const py::dict &args, std::shared_ptr *top, + std::shared_ptr *bottom); - Status ParseBarrierOp(const py::dict &args, std::shared_ptr *ptr); + Status ParseBarrierOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - Status ParseGeneratorOp(const py::dict &args, std::shared_ptr *ptr); + Status ParseGeneratorOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - Status ParseRenameOp(const py::dict &args, std::shared_ptr *ptr); + Status ParseRenameOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - Status ParseTakeOp(const py::dict &args, std::shared_ptr *ptr); + Status ParseTakeOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - Status ParseZipOp(const py::dict &args, std::shared_ptr *ptr); + Status ParseZipOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - Status ParseConcatOp(const py::dict &args, std::shared_ptr *ptr); + Status ParseConcatOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - Status ParseDeviceQueueOp(const py::dict &args, std::shared_ptr *ptr); + Status ParseDeviceQueueOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - Status 
ParseTFReaderOp(const py::dict &args, std::shared_ptr *ptr); + Status ParseTFReaderOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - Status ParseProjectOp(const py::dict &args, std::shared_ptr *ptr); + Status ParseProjectOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - Status ParseImageFolderOp(const py::dict &args, std::shared_ptr *ptr); + Status ParseImageFolderOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - Status ParseManifestOp(const py::dict &args, std::shared_ptr *ptr); + Status ParseManifestOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - Status ParseVOCOp(const py::dict &args, std::shared_ptr *ptr); + Status ParseVOCOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - Status ParseCocoOp(const py::dict &args, std::shared_ptr *ptr); + Status ParseCocoOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - Status ParseCifar10Op(const py::dict &args, std::shared_ptr *ptr); + Status ParseCifar10Op(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - Status ParseCifar100Op(const py::dict &args, std::shared_ptr *ptr); + Status ParseCifar100Op(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - Status ParseRandomDataOp(const py::dict &args, std::shared_ptr *ptr); + Status ParseRandomDataOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); void PrintTree(); int32_t GetNumClasses() const; - Status ParseMnistOp(const py::dict &args, std::shared_ptr *ptr); + Status ParseMnistOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); Status SetBatchParameters(const py::dict &args); - Status ParseCelebAOp(const py::dict &args, std::shared_ptr *ptr); + Status ParseCelebAOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - Status ParseTextFileOp(const py::dict &args, std::shared_ptr *ptr); + Status ParseTextFileOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - Status ParseBuildVocabOp(const py::dict &args, std::shared_ptr *ptr); + Status ParseBuildVocabOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); - Status ParseClueOp(const py::dict &args, std::shared_ptr *ptr); + Status ParseClueOp(const py::dict &args, std::shared_ptr *top, std::shared_ptr *bottom); private: // Execution tree that links the dataset operators. @@ -180,6 +181,25 @@ class DEPipeline { static Status ParsePadInfo(py::handle value, PadInfo *pad_info); + /// \brief Helper function to inject a shuffle operator over top of the current operation being built. + /// \param[in] shuffle_size The size to use in the shuffle buffer + /// \param[in] input_op The operator to build shuffle on top of + /// \param[out] shuffle_op The top node of the created subtree (subtree contains two nodes). 
In this case it will be + /// the shuffle operator + /// \return Status return code + Status AddShuffleOp(int64_t shuffle_size, std::shared_ptr input_op, + std::shared_ptr *shuffle_op); + + /// \brief Helper function to compute the shuffle size + /// \param[in] num_files The number of files in the dataset + /// \param[in] num_devices The number of devices in the dataset + /// \param[in] num_rows The number of rows in the dataset + /// \param[in] total_rows An upper bound on the total rows in the dataset + /// \param[out] shuffle_size The resultant computed shuffle size + /// \return Status return code + Status ComputeShuffleSize(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows, + int64_t *shuffle_size); + int batch_size_; int repeat_num_; int num_rows_; diff --git a/mindspore/ccsrc/dataset/api/python_bindings.cc b/mindspore/ccsrc/dataset/api/python_bindings.cc index 51f2be49d5..7bed870f1a 100644 --- a/mindspore/ccsrc/dataset/api/python_bindings.cc +++ b/mindspore/ccsrc/dataset/api/python_bindings.cc @@ -116,9 +116,9 @@ void bindDEPipeline(py::module *m) { .def( "AddNodeToTree", [](DEPipeline &de, const OpName &op_name, const py::dict &args) { - DsOpPtr op; - THROW_IF_ERROR(de.AddNodeToTree(op_name, args, &op)); - return op; + py::dict out; + THROW_IF_ERROR(de.AddNodeToTree(op_name, args, &out)); + return out; }, py::return_value_policy::reference) .def_static("AddChildToParentNode", diff --git a/mindspore/ccsrc/dataset/engine/datasetops/map_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/map_op.cc index 053559f88b..fcb2e357e8 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/map_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/map_op.cc @@ -54,20 +54,19 @@ Status MapOp::Builder::sanityCheck() const { Status MapOp::Builder::Build(std::shared_ptr *ptr) { RETURN_IF_NOT_OK(sanityCheck()); *ptr = std::make_shared(std::move(build_in_col_names_), std::move(build_out_col_names_), - std::move(build_tensor_funcs_), std::move(build_col_order_), build_num_workers_, - build_op_connector_size_, build_perf_mode_); + std::move(build_tensor_funcs_), build_num_workers_, build_op_connector_size_, + build_perf_mode_); return Status::OK(); } // Constructor of MapOp MapOp::MapOp(const std::vector &in_col_names, const std::vector &out_col_names, - std::vector> tensor_funcs, const std::vector &columns_order, - int32_t num_workers, int32_t op_connector_size, bool perf_mode) + std::vector> tensor_funcs, int32_t num_workers, int32_t op_connector_size, + bool perf_mode) : ParallelOp(num_workers, op_connector_size), tfuncs_(std::move(tensor_funcs)), in_columns_(in_col_names), out_columns_(out_col_names), - columns_order_(columns_order), perf_mode_(perf_mode) { // If caller didn't specify the out_col_names, assume they are same as the in_columns. if (out_columns_.empty() || out_columns_[0].empty()) { diff --git a/mindspore/ccsrc/dataset/engine/datasetops/map_op.h b/mindspore/ccsrc/dataset/engine/datasetops/map_op.h index 94569bd41f..371d865196 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/map_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/map_op.h @@ -93,13 +93,6 @@ class MapOp : public ParallelOp { return *this; } - // Setter method. - // @return Builder setter method returns reference to the builder. - Builder &SetColOrder(const std::vector &col_order_) { - build_col_order_ = col_order_; - return *this; - } - // Setter method. // @return Builder setter method returns reference to the builder. 
Builder &SetNumWorkers(int32_t num_workers) { @@ -130,7 +123,6 @@ class MapOp : public ParallelOp { std::vector build_in_col_names_; std::vector build_out_col_names_; std::vector> build_tensor_funcs_; - std::vector build_col_order_; int32_t build_num_workers_; int32_t build_op_connector_size_; bool build_perf_mode_; // Default true. @@ -145,12 +137,11 @@ class MapOp : public ParallelOp { // @param in_col_names A list of input column names (should match the input/output \p tensorFuncs). // @param out_col_names A list of output column names (should match the input/output \p tensorFuncs). // @param tensor_funcs A list of TensorOp pointers for MapOp to apply to each data. - // @param columns_order names A full list of column names (should match the whole dataset view post \p tensorFuncs). // @param num_workers The number of worker threads. // @param op_connector_size The size of each queue in the connector. MapOp(const std::vector &in_col_names, const std::vector &out_col_names, - std::vector> tensor_funcs, const std::vector &columns_order, - int32_t num_workers, int32_t op_connector_size, bool perf_mode); + std::vector> tensor_funcs, int32_t num_workers, int32_t op_connector_size, + bool perf_mode); // Destructor ~MapOp() = default; @@ -190,10 +181,6 @@ class MapOp : public ParallelOp { // @return Name of the current Op std::string Name() const override { return "MapOp"; } - // Columns order getter - // @return The post map columns order - std::vector const &ColumnsOrder() const { return columns_order_; } - private: // Local queues where worker threads can pop from. // Popping directly from the Connector can block if the previous designated threads haven't pop. @@ -215,9 +202,6 @@ class MapOp : public ParallelOp { // Indices of the columns to process. std::vector to_process_indices_; - // Variable to store the column_order of all columns post tensorOps - std::vector columns_order_; - // Performance mode is when the main thread creates local queues, pulls databuffers from the previous // op's Connector and distributes them to the local queues. Workers pull from the local queues. // If this flag is false, each worker pulls directly from the Connector. 
This use less resources diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/clue_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/clue_op.cc index d863de15ad..9fceb6f333 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/clue_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/clue_op.cc @@ -31,11 +31,7 @@ namespace mindspore { namespace dataset { ClueOp::Builder::Builder() - : builder_device_id_(0), - builder_num_devices_(1), - builder_num_samples_(0), - builder_shuffle_files_(false), - builder_shuffle_global_(false) { + : builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) { std::shared_ptr config_manager = GlobalContext::config_manager(); builder_num_workers_ = config_manager->num_parallel_workers(); builder_op_connector_size_ = config_manager->op_connector_size(); @@ -66,8 +62,8 @@ Status ClueOp::Builder::Build(std::shared_ptr *op) { std::shared_ptr clue_op = std::make_shared( builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_, ck_map, - builder_clue_files_list_, builder_op_connector_size_, builder_shuffle_files_, builder_shuffle_global_, - builder_num_devices_, builder_device_id_); + builder_clue_files_list_, builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_, + builder_device_id_); RETURN_IF_NOT_OK(clue_op->Init()); *op = std::move(clue_op); @@ -87,7 +83,7 @@ std::vector ClueOp::Builder::split(const std::string &s, char delim ClueOp::ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, ColKeyMap cols_to_keyword, std::vector clue_files_list, int32_t op_connector_size, - bool shuffle_files, bool shuffle_global, int32_t num_device, int32_t device_id) + bool shuffle_files, int32_t num_device, int32_t device_id) : ParallelOp(num_workers, op_connector_size), rows_per_buffer_(rows_per_buffer), num_rows_per_shard_(0), @@ -98,7 +94,6 @@ ClueOp::ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples load_jagged_connector_(true), cols_to_keyword_(cols_to_keyword), shuffle_files_(shuffle_files), - shuffle_global_(shuffle_global), finished_reading_dataset_(false), num_devices_(num_device), device_id_(device_id), diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/clue_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/clue_op.h index f41abd020c..487ed0d47f 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/clue_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/clue_op.h @@ -104,13 +104,6 @@ class ClueOp : public ParallelOp { return *this; } - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetShuffleGlobal(bool shuffle_global) { - builder_shuffle_global_ = shuffle_global; - return *this; - } - // Setter method. // @return Builder - setter method returns reference to the builder. Builder &SetNumSamples(int64_t num_samples) { @@ -139,15 +132,13 @@ class ClueOp : public ParallelOp { int32_t builder_worker_connector_size_; std::vector builder_clue_files_list_; bool builder_shuffle_files_; - bool builder_shuffle_global_; std::map builder_cols_to_keyword_; }; // Constructor of ClueOp - // @param shuffle_global - whether or not to shuffle the entire dataset. 
ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, ColKeyMap cols_to_keyword, std::vector clue_files_list, int32_t op_connector_size, - bool shuffle_files, bool shuffle_global, int32_t num_devices, int32_t device_id); + bool shuffle_files, int32_t num_devices, int32_t device_id); // Default destructor ~ClueOp() = default; @@ -182,10 +173,6 @@ class ClueOp : public ParallelOp { // @return Vector of the input file names std::vector FileNames() { return clue_files_list_; } - // Global shuffle flag getter - // @return Bool - whether this Op requires global shuffle - bool RequireGlobalShuffle() { return shuffle_global_; } - private: // The entry point for when workers are launched. // @param worker_id - the id of the worker that is executing this function. @@ -269,7 +256,6 @@ class ClueOp : public ParallelOp { int32_t device_id_; bool shuffle_files_; - bool shuffle_global_; bool finished_reading_dataset_; int32_t num_devices_; int64_t rows_per_buffer_; diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.cc index 5ae950b803..fbba73de21 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.cc @@ -33,11 +33,7 @@ namespace mindspore { namespace dataset { TextFileOp::Builder::Builder() - : builder_device_id_(0), - builder_num_devices_(1), - builder_total_rows_(0), - builder_shuffle_files_(false), - builder_shuffle_global_(false) { + : builder_device_id_(0), builder_num_devices_(1), builder_total_rows_(0), builder_shuffle_files_(false) { std::shared_ptr config_manager = GlobalContext::config_manager(); builder_num_workers_ = config_manager->num_parallel_workers(); builder_op_connector_size_ = config_manager->op_connector_size(); @@ -68,7 +64,7 @@ Status TextFileOp::Builder::Build(std::shared_ptr *op) { std::shared_ptr text_file_op = std::make_shared( builder_num_workers_, builder_rows_per_buffer_, builder_total_rows_, builder_worker_connector_size_, std::move(builder_schema_), builder_text_files_list_, builder_op_connector_size_, builder_shuffle_files_, - builder_shuffle_global_, builder_num_devices_, builder_device_id_); + builder_num_devices_, builder_device_id_); RETURN_IF_NOT_OK(text_file_op->Init()); *op = std::move(text_file_op); @@ -77,8 +73,7 @@ Status TextFileOp::Builder::Build(std::shared_ptr *op) { TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size, std::unique_ptr schema, std::vector text_files_list, - int32_t op_connector_size, bool shuffle_files, bool shuffle_global, int32_t num_device, - int32_t device_id) + int32_t op_connector_size, bool shuffle_files, int32_t num_device, int32_t device_id) : ParallelOp(num_workers, op_connector_size), device_id_(device_id), num_devices_(num_device), @@ -86,7 +81,6 @@ TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t tot total_rows_(total_rows), text_files_list_(std::move(text_files_list)), shuffle_files_(shuffle_files), - shuffle_global_(shuffle_global), data_schema_(std::move(schema)), all_num_rows_(0), num_rows_per_shard_(0), diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.h index 31224cb299..5379263979 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.h +++ 
b/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.h @@ -105,13 +105,6 @@ class TextFileOp : public ParallelOp { return *this; } - // Setter method. - // @return Builder - setter method returns reference to the builder. - Builder &SetShuffleGlobal(bool shuffle_global) { - builder_shuffle_global_ = shuffle_global; - return *this; - } - // Setter method. // @return Builder - setter method returns reference to the builder. Builder &SetTotalRows(int64_t total_rows) { @@ -129,7 +122,6 @@ class TextFileOp : public ParallelOp { int32_t builder_worker_connector_size_; std::vector builder_text_files_list_; bool builder_shuffle_files_; - bool builder_shuffle_global_; std::unique_ptr builder_schema_; }; @@ -143,11 +135,10 @@ class TextFileOp : public ParallelOp { // @param op_connector_size - size of each queue in the connector that the child operator pulls from. // @param columns_to_load - the names of the columns to load data from. // @param shuffle_files - whether or not to shuffle the files before reading data. - // @param shuffle_global - whether or not to shuffle the entire dataset. // @param equal_rows_per_shard - whether or not to get equal rows for each process. TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size, std::unique_ptr, std::vector text_files_list, int32_t op_connector_size, - bool shuffle_files, bool shuffle_global, int32_t num_devices, int32_t device_id); + bool shuffle_files, int32_t num_devices, int32_t device_id); // Default destructor ~TextFileOp() = default; @@ -186,10 +177,6 @@ class TextFileOp : public ParallelOp { // @return Vector of the input file names std::vector FileNames() { return text_files_list_; } - // Global shuffle flag getter - // @return Bool - whether this Op requires global shuffle - bool RequireGlobalShuffle() { return shuffle_global_; } - private: // The entry point for when workers are launched. // @param worker_id - the id of the worker that is executing this function. 
@@ -274,7 +261,6 @@ class TextFileOp : public ParallelOp { int64_t total_rows_; std::vector text_files_list_; bool shuffle_files_; - bool shuffle_global_; std::unique_ptr data_schema_; int64_t all_num_rows_; int64_t num_rows_per_shard_; diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc index 8b92d19249..b05fa54978 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc @@ -55,7 +55,6 @@ TFReaderOp::Builder::Builder() builder_op_connector_size_ = config_manager->op_connector_size(); builder_rows_per_buffer_ = config_manager->rows_per_buffer(); builder_shuffle_files_ = false; - builder_shuffle_global_ = false; builder_data_schema_ = std::make_unique(); } @@ -126,8 +125,7 @@ Status TFReaderOp::Builder::Build(std::shared_ptr *out_tf_reader_op) std::shared_ptr new_tf_reader_op = std::make_shared( builder_num_workers_, builder_worker_connector_size_, builder_rows_per_buffer_, builder_total_rows_, builder_dataset_files_list_, std::move(builder_data_schema_), builder_op_connector_size_, builder_columns_to_load_, - builder_shuffle_files_, builder_shuffle_global_, builder_num_devices_, builder_device_id_, - builder_equal_rows_per_shard_); + builder_shuffle_files_, builder_num_devices_, builder_device_id_, builder_equal_rows_per_shard_); RETURN_IF_NOT_OK(new_tf_reader_op->Init()); *out_tf_reader_op = std::move(new_tf_reader_op); @@ -137,8 +135,8 @@ Status TFReaderOp::Builder::Build(std::shared_ptr *out_tf_reader_op) TFReaderOp::TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64_t rows_per_buffer, int64_t total_num_rows, std::vector dataset_files_list, std::unique_ptr data_schema, int32_t op_connector_size, - std::vector columns_to_load, bool shuffle_files, bool shuffle_global, - int32_t num_device, int32_t device_id, bool equal_rows_per_shard) + std::vector columns_to_load, bool shuffle_files, int32_t num_device, + int32_t device_id, bool equal_rows_per_shard) : ParallelOp(num_workers, op_connector_size), device_id_(device_id), num_devices_(num_device), @@ -148,7 +146,6 @@ TFReaderOp::TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64 columns_to_load_(std::move(columns_to_load)), finished_reading_dataset_(false), shuffle_files_(shuffle_files), - shuffle_global_(shuffle_global), data_schema_(std::move(data_schema)), filename_index_(std::make_unique()), load_io_block_queue_(true), @@ -174,7 +171,6 @@ void TFReaderOp::Print(std::ostream &out, bool show_all) const { // Then show any custom derived-internal stuff out << "\nRows per buffer: " << rows_per_buffer_ << "\nTotal rows: " << total_rows_ << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_ << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") - << "\nShuffle global: " << ((shuffle_global_) ? "yes" : "no") << "\nDataset files list: Size: " << dataset_files_list_.size() << "\n"; for (int i = 0; i < dataset_files_list_.size(); ++i) { out << " " << dataset_files_list_[i]; diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.h index 417cd8bef0..9d2e38ec6b 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.h @@ -146,13 +146,6 @@ class TFReaderOp : public ParallelOp { return *this; } - // Setter method. 
- // @return Builder - setter method returns reference to the builder. - Builder &SetShuffleGlobal(bool shuffle_global) { - builder_shuffle_global_ = shuffle_global; - return *this; - } - // Setter method. // @return Builder - setter method returns reference to the builder. Builder &SetShardEqualRows(bool shard_equal_rows) { @@ -172,7 +165,6 @@ class TFReaderOp : public ParallelOp { std::vector builder_dataset_files_list_; std::vector builder_columns_to_load_; bool builder_shuffle_files_; - bool builder_shuffle_global_; bool builder_equal_rows_per_shard_; }; @@ -187,12 +179,11 @@ class TFReaderOp : public ParallelOp { // @param op_connector_size - size of each queue in the connector that the child operator pulls from. // @param columns_to_load - the names of the columns to load data from. // @param shuffle_files - whether or not to shuffle the files before reading data. - // @param shuffle_global - whether or not to shuffle the entire dataset. // @param equal_rows_per_shard - whether or not to get equal rows for each process. TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64_t rows_per_buffer, int64_t total_num_rows, std::vector dataset_files_list, std::unique_ptr data_schema, int32_t op_connector_size, std::vector columns_to_load, bool shuffle_files, - bool shuffle_global, int32_t num_devices, int32_t device_id, bool equal_rows_per_shard); + int32_t num_devices, int32_t device_id, bool equal_rows_per_shard); // Default destructor ~TFReaderOp() = default; @@ -245,10 +236,6 @@ class TFReaderOp : public ParallelOp { // @return Vector of the input file names std::vector FileNames() { return dataset_files_list_; } - // Global shuffle flag getter - // @return Bool - whether this Op requires global shuffle - bool RequireGlobalShuffle() { return shuffle_global_; } - private: // The entry point for when workers are launched. // @param worker_id - the id of the worker that is executing this function. @@ -393,7 +380,6 @@ class TFReaderOp : public ParallelOp { std::vector columns_to_load_; bool finished_reading_dataset_; bool shuffle_files_; - bool shuffle_global_; std::unique_ptr data_schema_; std::unique_ptr filename_index_; bool load_io_block_queue_; diff --git a/mindspore/ccsrc/dataset/engine/execution_tree.cc b/mindspore/ccsrc/dataset/engine/execution_tree.cc index 80a11aca02..8dd622912b 100644 --- a/mindspore/ccsrc/dataset/engine/execution_tree.cc +++ b/mindspore/ccsrc/dataset/engine/execution_tree.cc @@ -19,8 +19,7 @@ #include "dataset/engine/datasetops/dataset_op.h" #include "dataset/engine/datasetops/shuffle_op.h" #include "dataset/util/task_manager.h" -#include "dataset/engine/opt/pre/map_column_reorder.h" -#include "dataset/engine/opt/pre/global_shuffle.h" +#include "dataset/engine/opt/pass.h" #include "dataset/engine/perf/profiling.h" #include "dataset/engine/perf/monitor.h" @@ -42,6 +41,10 @@ ExecutionTree::~ExecutionTree() { (void)tg_->ServiceStop(); } // provides it with a link to the tree. A node cannot form any relationships (parent/child) with // other nodes unless they are associated with the same tree. Status ExecutionTree::AssociateNode(const std::shared_ptr &op) { + // If we are already a part of the tree, no-op + if (op->tree_ == this) { + return Status::OK(); + } if (tree_state_ != kDeTStateInit && tree_state_ != kDeTStateBuilding) { std::string err_msg = "Invalid tree state for adding a node. 
Current state: " + std::to_string(static_cast(tree_state_)) + @@ -211,8 +214,7 @@ Status ExecutionTree::PrepareTreePreAction() { bool modified = false; std::vector> pre_actions; // Construct pre actions - pre_actions.push_back(std::make_unique()); - pre_actions.push_back(std::make_unique()); + // example: pre_actions.push_back(new SomePass()); // Apply pre action passes for (auto &pass : pre_actions) { RETURN_IF_NOT_OK(pass->Run(this, &modified)); diff --git a/mindspore/ccsrc/dataset/engine/opt/CMakeLists.txt b/mindspore/ccsrc/dataset/engine/opt/CMakeLists.txt index 170cbb55e5..af0a8918db 100644 --- a/mindspore/ccsrc/dataset/engine/opt/CMakeLists.txt +++ b/mindspore/ccsrc/dataset/engine/opt/CMakeLists.txt @@ -2,7 +2,5 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc" set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) add_library(engine-opt OBJECT pass.cc - pre/map_column_reorder.cc - pre/global_shuffle.cc util/printer_pass.cc - ) \ No newline at end of file + ) diff --git a/mindspore/ccsrc/dataset/engine/opt/pre/global_shuffle.cc b/mindspore/ccsrc/dataset/engine/opt/pre/global_shuffle.cc deleted file mode 100644 index 2adf734a6c..0000000000 --- a/mindspore/ccsrc/dataset/engine/opt/pre/global_shuffle.cc +++ /dev/null @@ -1,98 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include "dataset/engine/opt/pre/global_shuffle.h" -#include "dataset/engine/execution_tree.h" -#include "dataset/engine/datasetops/shuffle_op.h" -#include "dataset/engine/datasetops/source/tf_reader_op.h" -#include "dataset/engine/datasetops/source/text_file_op.h" -#include "dataset/engine/datasetops/source/clue_op.h" - -namespace mindspore { -namespace dataset { - -Status GlobalShufflePass::RunOnTree(ExecutionTree *tree, bool *modified) { - std::vector> tf_readers; - std::vector> text_files; - std::vector> clues; - - // Pass 1, search for all sources which requires global shuffle - for (auto &op : *tree) { - if (auto ptr = std::dynamic_pointer_cast(op.shared_from_this())) { - if (ptr->RequireGlobalShuffle()) { - tf_readers.push_back(ptr); - continue; - } - } - if (auto ptr = std::dynamic_pointer_cast(op.shared_from_this())) { - if (ptr->RequireGlobalShuffle()) { - text_files.push_back(ptr); - continue; - } - } - if (auto ptr = std::dynamic_pointer_cast(op.shared_from_this())) { - if (ptr->RequireGlobalShuffle()) { - clues.push_back(ptr); - continue; - } - } - } - - // Pass 2, insert shuffle nodes - // The following blocks can be implemented with template if we unify the CountTotalRows across all source nodes . 
- for (auto node : tf_readers) { - std::shared_ptr builder = std::make_shared(); - int64_t total_rows = 0; - TFReaderOp::CountTotalRows(&total_rows, node->FileNames(), 8, true); - int32_t avg_file_size = total_rows / (node->FileNames().size()); - builder->SetShuffleSize(std::max(avg_file_size * 4, 10000)); - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - RETURN_IF_NOT_OK(tree->AssociateNode(op)); - RETURN_IF_NOT_OK(node->InsertAsParent(op)); - } - - for (auto node : text_files) { - std::shared_ptr builder = std::make_shared(); - int64_t total_rows = 0; - TextFileOp::CountAllFileRows(node->FileNames(), &total_rows); - int32_t avg_file_size = total_rows / (node->FileNames().size()); - builder->SetShuffleSize(std::max(avg_file_size * 4, 10000)); - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - RETURN_IF_NOT_OK(tree->AssociateNode(op)); - RETURN_IF_NOT_OK(node->InsertAsParent(op)); - } - - for (auto node : clues) { - std::shared_ptr builder = std::make_shared(); - int64_t total_rows = 0; - ClueOp::CountAllFileRows(node->FileNames(), &total_rows); - int32_t avg_file_size = total_rows / (node->FileNames().size()); - builder->SetShuffleSize(std::max(avg_file_size * 4, 10000)); - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - RETURN_IF_NOT_OK(tree->AssociateNode(op)); - RETURN_IF_NOT_OK(node->InsertAsParent(op)); - } - - return Status::OK(); -} - -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/opt/pre/global_shuffle.h b/mindspore/ccsrc/dataset/engine/opt/pre/global_shuffle.h deleted file mode 100644 index 6865ac9391..0000000000 --- a/mindspore/ccsrc/dataset/engine/opt/pre/global_shuffle.h +++ /dev/null @@ -1,35 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef DATASET_ENGINE_OPT_PASS_PRE_GLOBALSHUFFLE_H -#define DATASET_ENGINE_OPT_PASS_PRE_GLOBALSHUFFLE_H - -#include -#include "dataset/engine/opt/pass.h" - -namespace mindspore { -namespace dataset { -// Global Shuffle Pass will insert ShuffleOp when the leaf nodes requires global shuffle. -// Example: -// Input Tree: TFReader(GLOBAL_SHUFFLE) -> Batch -// Output Tree: TFReader -> Shuffle -> Batch -class GlobalShufflePass : public TreePass { - Status RunOnTree(ExecutionTree *tree, bool *modified) override; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_OPT_PASS_PRE_GLOBALSHUFFLE_H diff --git a/mindspore/ccsrc/dataset/engine/opt/pre/map_column_reorder.cc b/mindspore/ccsrc/dataset/engine/opt/pre/map_column_reorder.cc deleted file mode 100644 index a3dbbfcc54..0000000000 --- a/mindspore/ccsrc/dataset/engine/opt/pre/map_column_reorder.cc +++ /dev/null @@ -1,51 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include "dataset/engine/opt/pre/map_column_reorder.h" -#include "dataset/engine/execution_tree.h" -#include "dataset/engine/datasetops/map_op.h" -#include "dataset/engine/datasetops/project_op.h" - -namespace mindspore { -namespace dataset { - -Status MapColumnReorder::RunOnTree(ExecutionTree *tree, bool *modified) { - std::vector> to_process; - - // Pass 1, search for all MapOp with column orders - for (auto &op : *tree) { - if (auto mapOp = std::dynamic_pointer_cast(op.shared_from_this())) { - if (mapOp->ColumnsOrder().size() != 0) { - to_process.push_back(mapOp); - } - } - } - - // Pass 2, insert nodes for all MapOp - for (auto node : to_process) { - std::shared_ptr builder = std::make_shared(node->ColumnsOrder()); - std::shared_ptr op; - RETURN_IF_NOT_OK(builder->Build(&op)); - RETURN_IF_NOT_OK(tree->AssociateNode(op)); - RETURN_IF_NOT_OK(node->InsertAsParent(op)); - } - return Status::OK(); -} - -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/opt/pre/map_column_reorder.h b/mindspore/ccsrc/dataset/engine/opt/pre/map_column_reorder.h deleted file mode 100644 index 84274db3d5..0000000000 --- a/mindspore/ccsrc/dataset/engine/opt/pre/map_column_reorder.h +++ /dev/null @@ -1,35 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef DATASET_ENGINE_OPT_PASS_PRE_MAPCOLREORDER_H -#define DATASET_ENGINE_OPT_PASS_PRE_MAPCOLREORDER_H - -#include -#include "dataset/engine/opt/pass.h" - -namespace mindspore { -namespace dataset { -// Map Column Recorder Pass will insert ProjectOp when MapOp requires a full output columns reorder. -// Example: -// Input Tree: TFReader -> MapOp(with col_order) -> Batch -// Output Tree: TFReader -> MapOp -> ProjectOp(col_order) -> Batch -class MapColumnReorder : public TreePass { - Status RunOnTree(ExecutionTree *tree, bool *modified) override; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_ENGINE_OPT_PASS_PRE_MAPCOLREORDER_H diff --git a/mindspore/dataset/engine/iterators.py b/mindspore/dataset/engine/iterators.py index 4946fb3252..1d2d28c1c0 100644 --- a/mindspore/dataset/engine/iterators.py +++ b/mindspore/dataset/engine/iterators.py @@ -172,13 +172,13 @@ class Iterator: # Convert python node into C node and add to C layer execution tree in postorder traversal. 
def __convert_node_postorder(self, node): op_type = self.__get_dataset_type(node) - c_node = self.depipeline.AddNodeToTree(op_type, node.get_args()) + c_nodes = self.depipeline.AddNodeToTree(op_type, node.get_args()) for py_child in node.children: c_child = self.__convert_node_postorder(py_child) - self.depipeline.AddChildToParentNode(c_child, c_node) + self.depipeline.AddChildToParentNode(c_child, c_nodes["bottom"]) - return c_node + return c_nodes["top"] def __batch_node(self, dataset, level): """Recursively get batch node in the dataset tree.""" diff --git a/tests/ut/cpp/dataset/map_op_test.cc b/tests/ut/cpp/dataset/map_op_test.cc index b01b4a6df6..8b6a152488 100644 --- a/tests/ut/cpp/dataset/map_op_test.cc +++ b/tests/ut/cpp/dataset/map_op_test.cc @@ -130,75 +130,6 @@ std::shared_ptr ImageFolder(int64_t num_works, int64_t rows, int6 std::shared_ptr Build(std::vector> ops); -// TestByPosition scenario: -// TFReaderOp reads a dataset that have column ordering |image|label|A|B|. -// A TensorOp that does nothing picks the label column and output a column also named label. -// Thus, based on the new MapOp behaviour, the column ordering will be |image|label|A|B|. -// Verify the column ordering based on the Tensor properties matching to that of in the schema file. -TEST_F(MindDataTestMapOp, TestByPosition) { - Status rc; - MS_LOG(INFO) << "Doing TestByPosition."; - - // Note: The above TFReader config yields 5 buffers, each with 2 rows, for a total - // of 10 rows. - auto my_tfreader_op = this->CreateTFReaderOp(); - rc = my_tree_->AssociateNode(my_tfreader_op); - EXPECT_TRUE(rc.IsOk()); - auto my_no_op = std::make_shared(); - std::vector> my_func_list; - my_func_list.push_back(my_no_op); - std::shared_ptr my_map_op; - MapOp::Builder builder; - builder.SetInColNames({"label"}) - .SetOutColNames({}) - .SetColOrder({"image", "label", "A", "B"}) - .SetTensorFuncs(std::move(my_func_list)) - .SetNumWorkers(100); - rc = builder.Build(&my_map_op); - EXPECT_TRUE(rc.IsOk()); - rc = my_tree_->AssociateNode(my_map_op); - EXPECT_TRUE(rc.IsOk()); - rc = my_map_op->AddChild(my_tfreader_op); - EXPECT_TRUE(rc.IsOk()); - rc = my_tree_->AssignRoot(my_map_op); - EXPECT_TRUE(rc.IsOk()); - rc = my_tree_->Prepare(); - EXPECT_TRUE(rc.IsOk()); - rc = my_tree_->Launch(); - EXPECT_TRUE(rc.IsOk()); - - - // Based on the schema file, create the golden result to compare with. - std::vector golden_types({ - DataType::Type::DE_UINT8, - DataType::Type::DE_INT64, - DataType::Type::DE_FLOAT32, - DataType::Type::DE_INT64} - ); - - std::vector golden_ranks({3, 1, 4, 1}); - - std::vector golden_shapes({ - TensorShape({3, 4, 2}), - TensorShape({7}), - TensorShape({1, 13, 14, 12}), - TensorShape({9})} - ); - - // Start the loop of reading tensors from our pipeline - DatasetIterator di(my_tree_); - TensorRow tensor_list; - rc = di.FetchNextTensorRow(&tensor_list); - EXPECT_TRUE(rc.IsOk()); - EXPECT_EQ(tensor_list.size(), 4); - for (uint32_t i = 0; i < tensor_list.size(); i++) { - EXPECT_EQ(tensor_list[i]->type(), golden_types[i]); - EXPECT_EQ(tensor_list[i]->Rank(), golden_ranks[i]); - EXPECT_EQ(tensor_list[i]->shape(), golden_shapes[i]); - EXPECT_NE(tensor_list[i]->GetBuffer(), nullptr); - } -} - // TestAsMap scenario: // TFReaderOp reads a dataset that have column ordering |image|label|A|B|. // A TensorOp that does nothing picks the "image" column and produces a column named "X". 
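With the MapColumnReorder pre-pass and the TestByPosition unit test removed above, the column-reorder behaviour has to be produced when the dataset node is parsed rather than by a tree pass. The finished parse-side code is not part of this excerpt, so the sketch below is only a plausible shape, assuming the columns_order argument now causes the parse function to build a ProjectOp above the MapOp and to report the pair through the new top/bottom out-parameters; the function body and names such as project_columns are illustrative, not the actual implementation in this change.

Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
                              std::shared_ptr<DatasetOp> *bottom) {
  MapOp::Builder map_builder;
  std::vector<std::string> project_columns;  // assumed to be filled from args["columns_order"]
  // ... argument parsing elided ...

  std::shared_ptr<MapOp> map_op;
  RETURN_IF_NOT_OK(map_builder.Build(&map_op));
  RETURN_IF_NOT_OK(tree_->AssociateNode(map_op));
  *top = map_op;
  *bottom = map_op;  // child-most node; the python side attaches children here

  if (!project_columns.empty()) {
    // Same node the deleted MapColumnReorder pass used to insert as the MapOp's parent.
    auto project_builder = std::make_shared<ProjectOp::Builder>(project_columns);
    std::shared_ptr<ProjectOp> project_op;
    RETURN_IF_NOT_OK(project_builder->Build(&project_op));
    RETURN_IF_NOT_OK(tree_->AssociateNode(project_op));
    RETURN_IF_NOT_OK(project_op->AddChild(map_op));
    *top = project_op;  // parent-most node of the subtree, returned to python as "top"
  }
  return Status::OK();
}

This is the shape that the updated __convert_node_postorder above relies on: children are hooked to the returned "bottom" node (the MapOp), while the parent is hooked to the returned "top" node (the ProjectOp when a column reorder was requested).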
diff --git a/tests/ut/python/dataset/test_opt_pass.py b/tests/ut/python/dataset/test_opt_pass.py index bab881e283..480bfcbeab 100644 --- a/tests/ut/python/dataset/test_opt_pass.py +++ b/tests/ut/python/dataset/test_opt_pass.py @@ -16,8 +16,10 @@ import numpy as np import mindspore.dataset as ds - -def test_map_reorder_pass_0(): +# Tests the construction of multiple ops from a single dataset. +# A map dataset with a columns_order argument should produce a ProjectOp above the MapOp. +# This test does not exercise the optimization passes at this time. +def test_map_reorder0(): def generator_mc(maxid=1): for _ in range(maxid): yield (np.array([0]), np.array([1])) @@ -31,8 +33,10 @@ def test_map_reorder_pass_0(): for item in data0.create_tuple_iterator(): # each data is a dictionary assert item == [np.array(1), np.array(0)] - -def test_map_reorder_pass_1(): +# Tests the construction of multiple ops from a single dataset. +# A map dataset with a columns_order argument should produce a ProjectOp above the MapOp. +# This test does not exercise the optimization passes at this time. +def test_map_reorder1(): def generator_mc(maxid=1): for _ in range(maxid): yield (np.array([0]), np.array([1]), np.array([2])) @@ -48,8 +52,10 @@ def test_map_reorder_pass_1(): for item in data2.create_tuple_iterator(): assert item == [np.array(2), np.array(2), np.array(1), np.array(1), np.array(0), np.array(0)] - -def test_global_shuffle_pass(): +# Tests the construction of multiple ops from a single dataset. +# A TFRecordDataset with global shuffle should produce a ShuffleOp above the TFReaderOp. +# This test does not exercise the optimization passes at this time. +def test_shuffle(): FILES = ["../data/dataset/testTFTestAllTypes/test.data"] SCHEMA_FILE = "../data/dataset/testTFTestAllTypes/datasetSchema.json" @@ -85,6 +91,6 @@ def test_global_shuffle_pass(): if __name__ == "__main__": - test_map_reorder_pass_0() - test_map_reorder_pass_1() - test_global_shuffle_pass() + test_map_reorder0() + test_map_reorder1() + test_shuffle()
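The renamed test_shuffle above still expects a ShuffleOp directly above the TFReaderOp even though the GlobalShufflePass that used to insert it is deleted in this change. Where that insertion now happens is not shown in this excerpt; the sketch below is one plausible shape, assuming the TFReader parse path builds the ShuffleOp itself and reuses the sizing heuristic from the deleted pass (four times the average rows per file, with a floor of 10000). Identifiers such as shuffle_required and dataset_files are placeholders, not names taken from this change.

Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<DatasetOp> *top,
                                   std::shared_ptr<DatasetOp> *bottom) {
  auto builder = std::make_shared<TFReaderOp::Builder>();
  // ... argument parsing elided; assume it yields dataset_files and shuffle_required ...

  std::shared_ptr<TFReaderOp> tf_op;
  RETURN_IF_NOT_OK(builder->Build(&tf_op));
  RETURN_IF_NOT_OK(tree_->AssociateNode(tf_op));
  *top = tf_op;
  *bottom = tf_op;  // the leaf stays the child-most node of the subtree

  if (shuffle_required) {
    // Sizing rule carried over from the deleted GlobalShufflePass: 4 x average rows per file, minimum 10000.
    int64_t total_rows = 0;
    RETURN_IF_NOT_OK(TFReaderOp::CountTotalRows(&total_rows, dataset_files, 8, true));
    int64_t avg_rows_per_file = total_rows / static_cast<int64_t>(dataset_files.size());
    auto shuffle_builder = std::make_shared<ShuffleOp::Builder>();
    (void)shuffle_builder->SetShuffleSize(std::max(avg_rows_per_file * 4, static_cast<int64_t>(10000)));
    std::shared_ptr<ShuffleOp> shuffle_op;
    RETURN_IF_NOT_OK(shuffle_builder->Build(&shuffle_op));
    RETURN_IF_NOT_OK(tree_->AssociateNode(shuffle_op));
    RETURN_IF_NOT_OK(shuffle_op->AddChild(tf_op));
    *top = shuffle_op;  // the python side hooks the parent to "top" and children to "bottom"
  }
  return Status::OK();
}

With both operators already associated inside the parse function, the early-return added to ExecutionTree::AssociateNode earlier in this diff makes any later association of the same node a harmless no-op.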