From 7b1b457ccfda83a0a07479ae55a2180a0eb19fc0 Mon Sep 17 00:00:00 2001 From: Xiao Tianci Date: Wed, 21 Oct 2020 23:00:57 +0800 Subject: [PATCH] c++ api add DeviceQueue --- .../ccsrc/minddata/dataset/api/datasets.cc | 53 +++++++++++ .../dataset/engine/consumers/tree_consumer.cc | 26 +++++- .../dataset/engine/consumers/tree_consumer.h | 30 ++++--- .../engine/ir/datasetops/CMakeLists.txt | 1 + .../engine/ir/datasetops/source/album_node.cc | 7 ++ .../engine/ir/datasetops/source/album_node.h | 4 + .../ir/datasetops/source/celeba_node.cc | 7 ++ .../engine/ir/datasetops/source/celeba_node.h | 4 + .../ir/datasetops/source/cifar100_node.cc | 7 ++ .../ir/datasetops/source/cifar100_node.h | 4 + .../ir/datasetops/source/cifar10_node.cc | 7 ++ .../ir/datasetops/source/cifar10_node.h | 4 + .../engine/ir/datasetops/source/clue_node.cc | 7 ++ .../engine/ir/datasetops/source/clue_node.h | 4 + .../engine/ir/datasetops/source/coco_node.cc | 8 ++ .../engine/ir/datasetops/source/coco_node.h | 4 + .../engine/ir/datasetops/source/csv_node.cc | 8 ++ .../engine/ir/datasetops/source/csv_node.h | 4 + .../ir/datasetops/source/image_folder_node.cc | 8 ++ .../ir/datasetops/source/image_folder_node.h | 4 + .../ir/datasetops/source/manifest_node.cc | 8 ++ .../ir/datasetops/source/manifest_node.h | 4 + .../ir/datasetops/source/minddata_node.cc | 7 ++ .../ir/datasetops/source/minddata_node.h | 4 + .../engine/ir/datasetops/source/mnist_node.cc | 7 ++ .../engine/ir/datasetops/source/mnist_node.h | 4 + .../ir/datasetops/source/random_node.cc | 7 ++ .../engine/ir/datasetops/source/random_node.h | 4 + .../ir/datasetops/source/text_file_node.cc | 7 ++ .../ir/datasetops/source/text_file_node.h | 4 + .../ir/datasetops/source/tf_record_node.cc | 7 ++ .../ir/datasetops/source/tf_record_node.h | 4 + .../engine/ir/datasetops/source/voc_node.cc | 7 ++ .../engine/ir/datasetops/source/voc_node.h | 4 + .../engine/ir/datasetops/transfer_node.cc | 90 +++++++++++++++++++ .../engine/ir/datasetops/transfer_node.h | 62 
+++++++++++++ .../minddata/dataset/engine/tree_adapter.h | 4 + .../ccsrc/minddata/dataset/include/datasets.h | 15 ++++ .../ccsrc/minddata/dataset/include/samplers.h | 11 ++- 39 files changed, 443 insertions(+), 18 deletions(-) create mode 100644 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.cc create mode 100644 mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.h diff --git a/mindspore/ccsrc/minddata/dataset/api/datasets.cc b/mindspore/ccsrc/minddata/dataset/api/datasets.cc index 0672015029..c47f904a92 100644 --- a/mindspore/ccsrc/minddata/dataset/api/datasets.cc +++ b/mindspore/ccsrc/minddata/dataset/api/datasets.cc @@ -62,6 +62,7 @@ #include "minddata/dataset/engine/ir/datasetops/shuffle_node.h" #include "minddata/dataset/engine/ir/datasetops/skip_node.h" #include "minddata/dataset/engine/ir/datasetops/take_node.h" +#include "minddata/dataset/engine/ir/datasetops/transfer_node.h" #include "minddata/dataset/engine/ir/datasetops/zip_node.h" #ifndef ENABLE_ANDROID @@ -72,6 +73,7 @@ #include "minddata/dataset/core/config_manager.h" #include "minddata/dataset/util/path.h" #include "minddata/dataset/util/random.h" +#include "minddata/dataset/util/services.h" // IR leaf nodes #include "minddata/dataset/engine/ir/datasetops/source/album_node.h" @@ -125,6 +127,56 @@ std::shared_ptr Dataset::CreateIterator(std::vector colum return iter; } +// Function to return a transferred Node that transfers data through a device. +bool Dataset::DeviceQueue(bool send_epoch_end) { + Status rc; + + // Build and launch tree + std::unique_ptr runtime_context = std::make_unique(); + rc = runtime_context->Init(); + if (rc.IsError()) { + MS_LOG(ERROR) << "Failed to init runtime context. 
Error status: " << rc; + return false; + } + + // Get a uuid for queue name + std::string queue_name = Services::GetUniqueID(); + + // TODO(CRC): + // Get device type from ms context + std::string device_type = "CPU"; + + // Get device ID from children + int32_t device_id = 0; + rc = TransferNode::get_distribution(shared_from_this(), &device_id); + if (rc.IsError()) { + MS_LOG(ERROR) << "Failed to get shard id. Error status: " << rc; + return false; + } + + // Add TransferNode IR on top of dataset d + auto ds = std::make_shared(shared_from_this(), queue_name, device_id, device_type, send_epoch_end); + + // Get ToDevice consumer + auto consumer = std::make_unique(device_type, send_epoch_end, -1); + ToDevice *consumer_ = consumer.get(); + rc = consumer->Init(ds); + if (rc.IsError()) { + MS_LOG(ERROR) << "ToDevice: Failed to init. Error status: " << rc; + return false; + } + runtime_context->AssignConsumer(std::move(consumer)); + + // Send data to device + rc = consumer_->Send(); + if (rc.IsError()) { + MS_LOG(ERROR) << "ToDevice: Failed to send data to device. Error status: " << rc; + return false; + } + + return true; +} + #ifndef ENABLE_ANDROID // Function to create the saver, which will build and launch the execution tree and save data bool Dataset::Save(std::string dataset_path, int32_t num_files, std::string dataset_type) { @@ -931,6 +983,7 @@ std::shared_ptr CreateDatasetCache(session_id_type id, uint64_t me auto cache = std::make_shared(id, mem_sz, spill, hostname, port, num_connections, prefetch_sz); return cache->ValidateParams() ? 
cache : nullptr; } + #endif } // namespace api diff --git a/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc index bcedbb7f91..3aa4bf3bc3 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.cc @@ -74,13 +74,31 @@ Status IteratorConsumer::GetNextAsMap(std::unordered_map // ToDevice Status ToDevice::Init(std::shared_ptr d) { - // TODO(CRC): - // Get device ID from children look at get_distribution in python - // Add DeviceQue IR on top of dataset d - return tree_adapter_->BuildAndPrepare(std::move(d), num_epochs_); } +Status ToDevice::Send() { + std::unique_ptr db; + RETURN_IF_NOT_OK(tree_adapter_->Launch()); + RETURN_IF_NOT_OK(tree_adapter_->root()->GetNextBuffer(&db)); + return Status::OK(); +} + +Status ToDevice::Continue() { + // tree_.root() must be DeviceQueueOp + DeviceQueueOp *op = dynamic_cast(tree_adapter_->root().get()); + CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "ContinueSend only supported by DeviceQueueOp"); + op->ContinueSend(); + return Status::OK(); +} + +Status ToDevice::Stop() { + DeviceQueueOp *op = dynamic_cast(tree_adapter_->root().get()); + CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "StopSend only supported by DeviceQueueOp"); + op->StopSend(); + return Status::OK(); +} + #ifndef ENABLE_ANDROID // SaveToDisk Status SaveToDisk::ValidateParams() { diff --git a/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h index 30d289388c..e4376d14b0 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h +++ b/mindspore/ccsrc/minddata/dataset/engine/consumers/tree_consumer.h @@ -126,23 +126,27 @@ class SaveToDisk : public TreeConsumer { /// Consumer that iterates over the dataset and send it to a device class ToDevice : public TreeConsumer { public: - 
ToDevice(std::string device_type, bool send_epoch_end, int32_t num_epochs) + ToDevice(std::string device_type, bool send_epoch_end, int32_t num_epochs = -1) : TreeConsumer(), device_type_(device_type), send_epoch_end_(send_epoch_end), num_epochs_(num_epochs) {} Status Init(std::shared_ptr d) override; - Status Send() { - // TODO(CRC): launch the tree - return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); - } - Status Stop() { - // TODO(CRC): Get root + call StopSend - return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); - } - Status Continue() { - // TODO(CRC): Get root + call StopSend - return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); - } + /// Send the data to device + /// \return Status error code + Status Send(); + + /// Stop to send data to device + /// \return Status error code + Status Stop(); + + /// Continue to send data to device + /// \return Status error code + Status Continue(); + + protected: + /// Method to return the name of the consumer + /// \return string + std::string Name() override { return "ToDevice"; } private: std::string device_type_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/CMakeLists.txt index 3f50c8d5d6..b771b4cc84 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/CMakeLists.txt @@ -15,6 +15,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SRC_FILES skip_node.cc sync_wait_node.cc take_node.cc + transfer_node.cc zip_node.cc ) diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.cc index 51588eae5c..1baa78c2dc 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.cc +++ 
b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.cc @@ -68,6 +68,13 @@ std::vector> AlbumNode::Build() { return node_ops; } +// Get the shard id of node +Status AlbumNode::GetShardId(int32_t *shard_id) { + *shard_id = sampler_->ShardId(); + + return Status::OK(); +} + } // namespace api } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.h index 8cfeb9b9ad..498535296e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.h @@ -44,6 +44,10 @@ class AlbumNode : public Dataset { /// \return Status Status::OK() if all the parameters are valid Status ValidateParams() override; + /// \brief Get the shard id of node + /// \return Status Status::OK() if get shard id successfully + Status GetShardId(int32_t *shard_id) override; + private: std::string dataset_dir_; std::string schema_path_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc index e46fb580ea..e6f6182008 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc @@ -67,6 +67,13 @@ std::vector> CelebANode::Build() { return node_ops; } +// Get the shard id of node +Status CelebANode::GetShardId(int32_t *shard_id) { + *shard_id = sampler_->ShardId(); + + return Status::OK(); +} + } // namespace api } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.h index 30a539cbec..3829302cff 100644 --- 
a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.h @@ -46,6 +46,10 @@ class CelebANode : public Dataset { /// \return Status Status::OK() if all the parameters are valid Status ValidateParams() override; + /// \brief Get the shard id of node + /// \return Status Status::OK() if get shard id successfully + Status GetShardId(int32_t *shard_id) override; + private: std::string dataset_dir_; std::string usage_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc index 465cdcdc02..835967005d 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc @@ -66,6 +66,13 @@ std::vector> Cifar100Node::Build() { return node_ops; } +// Get the shard id of node +Status Cifar100Node::GetShardId(int32_t *shard_id) { + *shard_id = sampler_->ShardId(); + + return Status::OK(); +} + } // namespace api } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.h index ab336b1cd3..bbde01ba20 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.h @@ -44,6 +44,10 @@ class Cifar100Node : public Dataset { /// \return Status Status::OK() if all the parameters are valid Status ValidateParams() override; + /// \brief Get the shard id of node + /// \return Status Status::OK() if get shard id successfully + Status GetShardId(int32_t *shard_id) override; + private: std::string dataset_dir_; std::string usage_; diff --git 
a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc index 208a101e4e..578e2d9bfb 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc @@ -64,6 +64,13 @@ std::vector> Cifar10Node::Build() { return node_ops; } +// Get the shard id of node +Status Cifar10Node::GetShardId(int32_t *shard_id) { + *shard_id = sampler_->ShardId(); + + return Status::OK(); +} + } // namespace api } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.h index 5832c1446f..ff851c420f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.h @@ -44,6 +44,10 @@ class Cifar10Node : public Dataset { /// \return Status Status::OK() if all the parameters are valid Status ValidateParams() override; + /// \brief Get the shard id of node + /// \return Status Status::OK() if get shard id successfully + Status GetShardId(int32_t *shard_id) override; + private: std::string dataset_dir_; std::string usage_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc index c6cbb59544..af6a5d75b2 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc @@ -213,6 +213,13 @@ std::vector> CLUENode::Build() { return node_ops; } +// Get the shard id of node +Status CLUENode::GetShardId(int32_t *shard_id) { + *shard_id = shard_id_; + + return Status::OK(); +} + } // namespace api } // namespace dataset } // namespace 
mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.h index 42cbc1aa6e..8e2eb8aff8 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.h @@ -45,6 +45,10 @@ class CLUENode : public Dataset { /// \return Status Status::OK() if all the parameters are valid Status ValidateParams() override; + /// \brief Get the shard id of node + /// \return Status Status::OK() if get shard id successfully + Status GetShardId(int32_t *shard_id) override; + private: /// \brief Split string based on a character delimiter /// \return A string vector diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc index adb933bfa4..0b447d9960 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc @@ -117,6 +117,14 @@ std::vector> CocoNode::Build() { return node_ops; } + +// Get the shard id of node +Status CocoNode::GetShardId(int32_t *shard_id) { + *shard_id = sampler_->ShardId(); + + return Status::OK(); +} + } // namespace api } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.h index 3398d33707..50b621ef01 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.h @@ -43,6 +43,10 @@ class CocoNode : public Dataset { /// \return Status Status::OK() if all the parameters are valid Status ValidateParams() override; + /// \brief Get the shard id of node + /// \return Status Status::OK() if get shard id 
successfully + Status GetShardId(int32_t *shard_id) override; + private: std::string dataset_dir_; std::string annotation_file_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc index f44688a17a..222b9d6d74 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc @@ -122,6 +122,14 @@ std::vector> CSVNode::Build() { return node_ops; } + +// Get the shard id of node +Status CSVNode::GetShardId(int32_t *shard_id) { + *shard_id = shard_id_; + + return Status::OK(); +} + } // namespace api } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.h index 6e53985b49..ca673c0ee5 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.h @@ -66,6 +66,10 @@ class CSVNode : public Dataset { /// \return Status Status::OK() if all the parameters are valid Status ValidateParams() override; + /// \brief Get the shard id of node + /// \return Status Status::OK() if get shard id successfully + Status GetShardId(int32_t *shard_id) override; + private: std::vector dataset_files_; char field_delim_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc index 714d6f9799..338308a0c8 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.cc @@ -70,6 +70,14 @@ std::vector> ImageFolderNode::Build() { std::move(sampler_->Build()))); return node_ops; } + +// Get the shard id of node 
+Status ImageFolderNode::GetShardId(int32_t *shard_id) { + *shard_id = sampler_->ShardId(); + + return Status::OK(); +} + } // namespace api } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.h index 6f5345472d..fe4e4cd13d 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/image_folder_node.h @@ -51,6 +51,10 @@ class ImageFolderNode : public Dataset { /// \return Status Status::OK() if all the parameters are valid Status ValidateParams() override; + /// \brief Get the shard id of node + /// \return Status Status::OK() if get shard id successfully + Status GetShardId(int32_t *shard_id) override; + private: std::string dataset_dir_; bool decode_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc index d884f3fec6..41141b0ce5 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc @@ -85,6 +85,14 @@ std::vector> ManifestNode::Build() { return node_ops; } + +// Get the shard id of node +Status ManifestNode::GetShardId(int32_t *shard_id) { + *shard_id = sampler_->ShardId(); + + return Status::OK(); +} + } // namespace api } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.h index 449144fa24..99d7ef435b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.h @@ -44,6 +44,10 @@ class 
ManifestNode : public Dataset { /// \return Status Status::OK() if all the parameters are valid Status ValidateParams() override; + /// \brief Get the shard id of node + /// \return Status Status::OK() if get shard id successfully + Status GetShardId(int32_t *shard_id) override; + private: std::string dataset_file_; std::string usage_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.cc index 4093b2e032..6b2004a156 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.cc @@ -160,6 +160,13 @@ std::vector> MindDataNode::Build() { return node_ops; } +// Get the shard id of node +Status MindDataNode::GetShardId(int32_t *shard_id) { + *shard_id = sampler_->ShardId(); + + return Status::OK(); +} + } // namespace api } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.h index 16fed6f5f0..1663f14762 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/minddata_node.h @@ -48,6 +48,10 @@ class MindDataNode : public Dataset { /// \return Status Status::OK() if all the parameters are valid Status ValidateParams() override; + /// \brief Get the shard id of node + /// \return Status Status::OK() if get shard id successfully + Status GetShardId(int32_t *shard_id) override; + /// \brief Build sampler chain for minddata dataset /// \return Status Status::OK() if input sampler is valid Status BuildMindDatasetSamplerChain(const std::shared_ptr &sampler, diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc 
b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc index d5ee2b49a0..6d6e1fdee8 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.cc @@ -60,6 +60,13 @@ std::vector> MnistNode::Build() { return node_ops; } +// Get the shard id of node +Status MnistNode::GetShardId(int32_t *shard_id) { + *shard_id = sampler_->ShardId(); + + return Status::OK(); +} + } // namespace api } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.h index 97fb371ce1..713ba94fdf 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/mnist_node.h @@ -44,6 +44,10 @@ class MnistNode : public Dataset { /// \return Status Status::OK() if all the parameters are valid Status ValidateParams() override; + /// \brief Get the shard id of node + /// \return Status Status::OK() if get shard id successfully + Status GetShardId(int32_t *shard_id) override; + private: std::string dataset_dir_; std::string usage_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.cc index f02b4c5873..71839e43dc 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.cc @@ -99,6 +99,13 @@ std::vector> RandomNode::Build() { return node_ops; } +// Get the shard id of node +Status RandomNode::GetShardId(int32_t *shard_id) { + *shard_id = sampler_->ShardId(); + + return Status::OK(); +} + } // namespace api } // namespace dataset } // namespace mindspore diff --git 
a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.h index aa1c07fcc2..09e980a14c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.h @@ -65,6 +65,10 @@ class RandomNode : public Dataset { /// \return Status Status::OK() if all the parameters are valid Status ValidateParams() override; + /// \brief Get the shard id of node + /// \return Status Status::OK() if get shard id successfully + Status GetShardId(int32_t *shard_id) override; + private: /// \brief A quick inline for producing a random number between (and including) min/max /// \param[in] min minimum number that can be generated. diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc index d445ebcd67..519d2b6c30 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc @@ -95,6 +95,13 @@ std::vector> TextFileNode::Build() { return node_ops; } +// Get the shard id of node +Status TextFileNode::GetShardId(int32_t *shard_id) { + *shard_id = shard_id_; + + return Status::OK(); +} + } // namespace api } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.h index 2bfd384187..e5762f8b37 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.h @@ -45,6 +45,10 @@ class TextFileNode : public Dataset { /// \return Status Status::OK() if all the parameters are valid Status ValidateParams() override; 
+ /// \brief Get the shard id of node + /// \return Status Status::OK() if get shard id successfully + Status GetShardId(int32_t *shard_id) override; + private: std::vector dataset_files_; int32_t num_samples_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc index 5e8d5ef6fe..1dac924d60 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.cc @@ -80,6 +80,13 @@ std::vector> TFRecordNode::Build() { return node_ops; } +// Get the shard id of node +Status TFRecordNode::GetShardId(int32_t *shard_id) { + *shard_id = shard_id_; + + return Status::OK(); +} + } // namespace api } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.h index 5c97deebcb..ebc493d09a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/tf_record_node.h @@ -72,6 +72,10 @@ class TFRecordNode : public Dataset { /// \return Status Status::OK() if all the parameters are valid Status ValidateParams() override; + /// \brief Get the shard id of node + /// \return Status Status::OK() if get shard id successfully + Status GetShardId(int32_t *shard_id) override; + private: std::vector dataset_files_; std::string schema_path_; // schema_path_ path to schema file. 
It is set when type of schema parameter is string diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc index 0e556bde86..f263ceb096 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc @@ -112,6 +112,13 @@ std::vector> VOCNode::Build() { return node_ops; } +// Get the shard id of node +Status VOCNode::GetShardId(int32_t *shard_id) { + *shard_id = sampler_->ShardId(); + + return Status::OK(); +} + } // namespace api } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.h index 7bde316663..a61b758fd6 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.h @@ -45,6 +45,10 @@ class VOCNode : public Dataset { /// \return Status Status::OK() if all the parameters are valid Status ValidateParams() override; + /// \brief Get the shard id of node + /// \return Status Status::OK() if get shard id successfully + Status GetShardId(int32_t *shard_id) override; + private: const std::string kColumnImage = "image"; const std::string kColumnTarget = "target"; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.cc new file mode 100644 index 0000000000..787b48df99 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.cc @@ -0,0 +1,90 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "minddata/dataset/engine/ir/datasetops/transfer_node.h"
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "minddata/dataset/engine/datasetops/device_queue_op.h"
+#include "minddata/dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+namespace api {
+
+// Constructor for TransferNode
+TransferNode::TransferNode(std::shared_ptr<Dataset> child, const std::string &queue_name, int32_t device_id,
+                           const std::string &device_type, bool send_epoch_end)
+    : queue_name_(queue_name),
+      device_id_(device_id),
+      device_type_(device_type),
+      prefetch_size_(16),
+      send_epoch_end_(send_epoch_end),
+      total_batch_(0) {
+  this->children.push_back(child);
+}
+
+// Validator for TransferNode
+Status TransferNode::ValidateParams() {
+  // Check if device_type_ is in {"CPU", "GPU", "Ascend"}
+  RETURN_IF_NOT_OK(ValidateStringValue("TransferNode", device_type_, {"CPU", "GPU", "Ascend"}));
+  return Status::OK();
+}
+
+// Function to build TransferNode; returns an empty vector if device_type_ is not a supported value
+std::vector<std::shared_ptr<DatasetOp>> TransferNode::Build() {
+  // A vector containing shared pointer to the Dataset Ops that this object will create
+  std::vector<std::shared_ptr<DatasetOp>> node_ops;
+
+  // Convert device_type_ from string to DeviceType; `type` must never be read uninitialized
+  DeviceQueueOp::DeviceType type;
+  if (device_type_ == "CPU") {
+    type = DeviceQueueOp::DeviceType::CPU;
+  } else if (device_type_ == "GPU") {
+    type = DeviceQueueOp::DeviceType::GPU;
+  } else if (device_type_ == "Ascend") {
+    type = DeviceQueueOp::DeviceType::Ascend;
+  } else {
+    MS_LOG(ERROR) << "TransferNode: unsupported device type: " << device_type_;
+    return {};
+  }
+
+  node_ops.push_back(
+      std::make_shared<DeviceQueueOp>(queue_name_, type, device_id_, prefetch_size_, send_epoch_end_, total_batch_));
+  return node_ops;
+}
+
+// Function to get the device_id: the shard id of ds, or of the first node down its first-child chain that has one
+Status TransferNode::get_distribution(std::shared_ptr<Dataset> ds, int32_t *device_id) {
+  if (device_id == nullptr) {
+    RETURN_STATUS_UNEXPECTED("TransferNode: device_id output pointer is null.");
+  }
+  // Walk down the first-child chain; a null node ends the walk and reports the error below instead of crashing
+  for (; ds != nullptr; ds = ds->children.empty() ? nullptr : ds->children[0]) {
+    if (ds->GetShardId(device_id).IsOk()) {
+      return Status::OK();
+    }
+  }
+  std::string err_msg = "Unknown dataset type.";
+  MS_LOG(ERROR) << err_msg;
+  RETURN_STATUS_SYNTAX_ERROR(err_msg);
+}
+
+}  // namespace api
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.h
new file mode 100644
index 0000000000..000287155b
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.h
@@ -0,0 +1,62 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_TRANSFER_NODE_H_
+#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_TRANSFER_NODE_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "minddata/dataset/include/datasets.h"
+
+namespace mindspore {
+namespace dataset {
+namespace api {
+
+class TransferNode : public Dataset {
+ public:
+  /// \brief Constructor
+  TransferNode(std::shared_ptr<Dataset> child, const std::string &queue_name, int32_t device_id,
+               const std::string &device_type, bool send_epoch_end);
+
+  /// \brief Destructor
+  ~TransferNode() = default;
+
+  /// \brief a base class override function to create the required runtime dataset op objects for this class
+  /// \return shared pointer to the list of newly created DatasetOps
+  std::vector<std::shared_ptr<DatasetOp>> Build() override;
+
+  /// \brief Parameters validation
+  /// \return Status Status::OK() if all the parameters are valid
+  Status ValidateParams() override;
+
+  /// \brief Get the shard id of ds, or of the first node below it (following first children) that provides one
+  /// \return Status Status::OK() if a shard id was written to device_id, an error Status otherwise
+  static Status get_distribution(std::shared_ptr<Dataset> ds, int32_t *device_id);
+
+ private:
+  std::string queue_name_;
+  int32_t device_id_;
+  std::string device_type_;
+  int32_t prefetch_size_;
+  bool send_epoch_end_;
+  int32_t total_batch_;
+};
+}  // namespace api
+}  // namespace dataset
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_TRANSFER_NODE_H_
diff --git a/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.h b/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.h
index c95b72b44e..34b434e129 100644
--- a/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.h
+++ b/mindspore/ccsrc/minddata/dataset/engine/tree_adapter.h
@@ -57,6 +57,17 @@ class TreeAdapter {
   // to be able to launch a thread. BuildAndPrepare needs to be called before this function
   TaskGroup *AllTasks() const { return tree_ != nullptr ? tree_->AllTasks() : nullptr; }
 
+  // Returns the root op of the tree, or nullptr when tree_ is null (same guard as AllTasks)
+  std::shared_ptr<DatasetOp> root() { return tree_ != nullptr ? tree_->root() : nullptr; }
+
+  // Launches the tree; reports an unexpected-error Status instead of dereferencing a null tree_
+  Status Launch() const {
+    if (tree_ == nullptr) {
+      RETURN_STATUS_UNEXPECTED("TreeAdapter: execution tree is null.");
+    }
+    return tree_->Launch();
+  }
+
  private:
   // This RECURSIVE function converts IR nodes into DatasetOp in ExecutionTree. IR could build a vector of ops. In
   // such case, the first node is returned. Op is added as child when the current function returns.
diff --git a/mindspore/ccsrc/minddata/dataset/include/datasets.h b/mindspore/ccsrc/minddata/dataset/include/datasets.h
index 71c900ddfe..f3704b54c8 100644
--- a/mindspore/ccsrc/minddata/dataset/include/datasets.h
+++ b/mindspore/ccsrc/minddata/dataset/include/datasets.h
@@ -96,6 +96,7 @@ class RepeatNode;
 class ShuffleNode;
 class SkipNode;
 class TakeNode;
+class TransferNode;
 class ZipNode;
 
 #define RETURN_EMPTY_IF_ERROR(_s) \
@@ -559,6 +560,7 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
  public:
   // need friend class so they can access the children_ field
   friend class Iterator;
+  friend class TransferNode;
   friend class mindspore::dataset::TreeAdapter;
 
   /// \brief Constructor
@@ -579,6 +581,13 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
   /// \return Status Status::OK() if all the parameters are valid
   virtual Status ValidateParams() = 0;
 
+  /// \brief Virtual function for derived class to get the shard id of specific node
+  /// \param[out] shard_id Receives the shard id when the node implements this method
+  /// \return Status Status::OK() if get shard id successfully, kNotImplementedYet from this default implementation
+  virtual Status GetShardId(int32_t *shard_id) {
+    return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
+  }
+
   /// \brief Gets the dataset size
   /// \return status code
   int64_t GetDatasetSize();
@@ -617,6 +626,13 @@ class Dataset : public std::enable_shared_from_this<Dataset> {
   /// \return Shared pointer to the Iterator
   std::shared_ptr<Iterator> CreateIterator(std::vector<std::string> columns = {});
 
+  /// \brief Function to transfer data through a device.
+  /// \note If device is Ascend, features of data will be transferred one by one. The limitation
+  ///     of data transmission per time is 256M.
+  /// \param[in] send_epoch_end Whether to send end of sequence to device or not (default=true).
+  /// \return Returns true if no error encountered else false.
+  bool DeviceQueue(bool send_epoch_end = true);
+
 #ifndef ENABLE_ANDROID
   /// \brief Function to create a Saver to save the dynamic data processed by the dataset pipeline
   /// \note Usage restrictions:
diff --git a/mindspore/ccsrc/minddata/dataset/include/samplers.h b/mindspore/ccsrc/minddata/dataset/include/samplers.h
index e29ac61fc4..197729258e 100644
--- a/mindspore/ccsrc/minddata/dataset/include/samplers.h
+++ b/mindspore/ccsrc/minddata/dataset/include/samplers.h
@@ -17,8 +17,9 @@
 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_SAMPLERS_H_
 #define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_SAMPLERS_H_
 
-#include <vector>
 #include <memory>
+#include <string>
+#include <vector>
 
 #ifndef ENABLE_ANDROID
 #include "minddata/dataset/engine/datasetops/source/mindrecord_op.h"
@@ -48,6 +49,10 @@ class SamplerObj : public std::enable_shared_from_this<SamplerObj> {
   /// \return Shared pointers to the newly created Sampler
   virtual std::shared_ptr<Sampler> Build() = 0;
 
+  /// \brief Function for derived class to get the shard id of sampler; this base implementation returns 0
+  /// \return The shard id of the derived sampler
+  virtual int64_t ShardId() { return 0; }
+
 #ifndef ENABLE_ANDROID
   /// \brief Virtual function to convert a SamplerObj class into a runtime mindrecord sampler object,
   ///     only override by SubsetRandomSampler, PkSampler, RandomSampler, SequentialSampler, DistributedSampler
@@ -134,6 +139,10 @@ class DistributedSamplerObj : public SamplerObj {
 
   bool ValidateParams() override;
 
+  /// \brief Function to get the shard id of sampler
+  /// \return The shard id (shard_id_) this distributed sampler holds
+  int64_t ShardId() override { return shard_id_; }
+
  private:
   int64_t num_shards_;
   int64_t shard_id_;