c++ api add DeviceQueue

pull/7593/head
Xiao Tianci 4 years ago
parent c962ccbe07
commit 7b1b457ccf

@ -62,6 +62,7 @@
#include "minddata/dataset/engine/ir/datasetops/shuffle_node.h"
#include "minddata/dataset/engine/ir/datasetops/skip_node.h"
#include "minddata/dataset/engine/ir/datasetops/take_node.h"
#include "minddata/dataset/engine/ir/datasetops/transfer_node.h"
#include "minddata/dataset/engine/ir/datasetops/zip_node.h"
#ifndef ENABLE_ANDROID
@ -72,6 +73,7 @@
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/util/path.h"
#include "minddata/dataset/util/random.h"
#include "minddata/dataset/util/services.h"
// IR leaf nodes
#include "minddata/dataset/engine/ir/datasetops/source/album_node.h"
@ -125,6 +127,56 @@ std::shared_ptr<Iterator> Dataset::CreateIterator(std::vector<std::string> colum
return iter;
}
// Function to return a transferred Node that transfers data through a device.
bool Dataset::DeviceQueue(bool send_epoch_end) {
Status rc;
// Build and launch tree
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
rc = runtime_context->Init();
if (rc.IsError()) {
MS_LOG(ERROR) << "Failed to init runtime context. Error status: " << rc;
return false;
}
// Get a uuid for queue name
std::string queue_name = Services::GetUniqueID();
// TODO(CRC):
// Get device type from ms context
std::string device_type = "CPU";
// Get device ID from children
int32_t device_id = 0;
rc = TransferNode::get_distribution(shared_from_this(), &device_id);
if (rc.IsError()) {
MS_LOG(ERROR) << "Failed to get shard id. Error status: " << rc;
return false;
}
// Add TransferNode IR on top of dataset d
auto ds = std::make_shared<TransferNode>(shared_from_this(), queue_name, device_id, device_type, send_epoch_end);
// Get ToDevice consumer
auto consumer = std::make_unique<ToDevice>(device_type, send_epoch_end, -1);
ToDevice *consumer_ = consumer.get();
rc = consumer->Init(ds);
if (rc.IsError()) {
MS_LOG(ERROR) << "ToDevice: Failed to init. Error status: " << rc;
return false;
}
runtime_context->AssignConsumer(std::move(consumer));
// Send data to device
rc = consumer_->Send();
if (rc.IsError()) {
MS_LOG(ERROR) << "ToDevice: Failed to send data to device. Error status: " << rc;
return false;
}
return true;
}
#ifndef ENABLE_ANDROID
// Function to create the saver, which will build and launch the execution tree and save data
bool Dataset::Save(std::string dataset_path, int32_t num_files, std::string dataset_type) {
@ -931,6 +983,7 @@ std::shared_ptr<DatasetCache> CreateDatasetCache(session_id_type id, uint64_t me
auto cache = std::make_shared<DatasetCacheImpl>(id, mem_sz, spill, hostname, port, num_connections, prefetch_sz);
return cache->ValidateParams() ? cache : nullptr;
}
#endif
} // namespace api

@ -74,13 +74,31 @@ Status IteratorConsumer::GetNextAsMap(std::unordered_map<std::string, TensorPtr>
// ToDevice
Status ToDevice::Init(std::shared_ptr<api::Dataset> d) {
// TODO(CRC):
// Get device ID from children look at get_distribution in python
// Add DeviceQue IR on top of dataset d
return tree_adapter_->BuildAndPrepare(std::move(d), num_epochs_);
}
Status ToDevice::Send() {
std::unique_ptr<DataBuffer> db;
RETURN_IF_NOT_OK(tree_adapter_->Launch());
RETURN_IF_NOT_OK(tree_adapter_->root()->GetNextBuffer(&db));
return Status::OK();
}
Status ToDevice::Continue() {
// tree_.root() must be DeviceQueueOp
DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(tree_adapter_->root().get());
CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "ContinueSend only supported by DeviceQueueOp");
op->ContinueSend();
return Status::OK();
}
Status ToDevice::Stop() {
DeviceQueueOp *op = dynamic_cast<DeviceQueueOp *>(tree_adapter_->root().get());
CHECK_FAIL_RETURN_UNEXPECTED(op != nullptr, "StopSend only supported by DeviceQueueOp");
op->StopSend();
return Status::OK();
}
#ifndef ENABLE_ANDROID
// SaveToDisk
Status SaveToDisk::ValidateParams() {

@ -126,23 +126,27 @@ class SaveToDisk : public TreeConsumer {
/// Consumer that iterates over the dataset and send it to a device
class ToDevice : public TreeConsumer {
public:
ToDevice(std::string device_type, bool send_epoch_end, int32_t num_epochs)
ToDevice(std::string device_type, bool send_epoch_end, int32_t num_epochs = -1)
: TreeConsumer(), device_type_(device_type), send_epoch_end_(send_epoch_end), num_epochs_(num_epochs) {}
Status Init(std::shared_ptr<api::Dataset> d) override;
Status Send() {
// TODO(CRC): launch the tree
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
}
Status Stop() {
// TODO(CRC): Get root + call StopSend
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
}
Status Continue() {
// TODO(CRC): Get root + call StopSend
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
}
/// Send the data to device
/// \return Status error code
Status Send();
/// Stop to send data to device
/// \return Status error code
Status Stop();
/// Continue to send data to device
/// \return Status error code
Status Continue();
protected:
/// Method to return the name of the consumer
/// \return string
std::string Name() override { return "ToDevice"; }
private:
std::string device_type_;

@ -15,6 +15,7 @@ set(DATASET_ENGINE_IR_DATASETOPS_SRC_FILES
skip_node.cc
sync_wait_node.cc
take_node.cc
transfer_node.cc
zip_node.cc
)

@ -68,6 +68,13 @@ std::vector<std::shared_ptr<DatasetOp>> AlbumNode::Build() {
return node_ops;
}
// Get the shard id of node
Status AlbumNode::GetShardId(int32_t *shard_id) {
*shard_id = sampler_->ShardId();
return Status::OK();
}
} // namespace api
} // namespace dataset
} // namespace mindspore

@ -44,6 +44,10 @@ class AlbumNode : public Dataset {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Get the shard id of node
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;
private:
std::string dataset_dir_;
std::string schema_path_;

@ -67,6 +67,13 @@ std::vector<std::shared_ptr<DatasetOp>> CelebANode::Build() {
return node_ops;
}
// Get the shard id of node
Status CelebANode::GetShardId(int32_t *shard_id) {
*shard_id = sampler_->ShardId();
return Status::OK();
}
} // namespace api
} // namespace dataset
} // namespace mindspore

@ -46,6 +46,10 @@ class CelebANode : public Dataset {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Get the shard id of node
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;
private:
std::string dataset_dir_;
std::string usage_;

@ -66,6 +66,13 @@ std::vector<std::shared_ptr<DatasetOp>> Cifar100Node::Build() {
return node_ops;
}
// Get the shard id of node
Status Cifar100Node::GetShardId(int32_t *shard_id) {
*shard_id = sampler_->ShardId();
return Status::OK();
}
} // namespace api
} // namespace dataset
} // namespace mindspore

@ -44,6 +44,10 @@ class Cifar100Node : public Dataset {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Get the shard id of node
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;
private:
std::string dataset_dir_;
std::string usage_;

@ -64,6 +64,13 @@ std::vector<std::shared_ptr<DatasetOp>> Cifar10Node::Build() {
return node_ops;
}
// Get the shard id of node
Status Cifar10Node::GetShardId(int32_t *shard_id) {
*shard_id = sampler_->ShardId();
return Status::OK();
}
} // namespace api
} // namespace dataset
} // namespace mindspore

@ -44,6 +44,10 @@ class Cifar10Node : public Dataset {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Get the shard id of node
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;
private:
std::string dataset_dir_;
std::string usage_;

@ -213,6 +213,13 @@ std::vector<std::shared_ptr<DatasetOp>> CLUENode::Build() {
return node_ops;
}
// Get the shard id of node
Status CLUENode::GetShardId(int32_t *shard_id) {
*shard_id = shard_id_;
return Status::OK();
}
} // namespace api
} // namespace dataset
} // namespace mindspore

@ -45,6 +45,10 @@ class CLUENode : public Dataset {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Get the shard id of node
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;
private:
/// \brief Split string based on a character delimiter
/// \return A string vector

@ -117,6 +117,14 @@ std::vector<std::shared_ptr<DatasetOp>> CocoNode::Build() {
return node_ops;
}
// Get the shard id of node
Status CocoNode::GetShardId(int32_t *shard_id) {
*shard_id = sampler_->ShardId();
return Status::OK();
}
} // namespace api
} // namespace dataset
} // namespace mindspore

@ -43,6 +43,10 @@ class CocoNode : public Dataset {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Get the shard id of node
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;
private:
std::string dataset_dir_;
std::string annotation_file_;

@ -122,6 +122,14 @@ std::vector<std::shared_ptr<DatasetOp>> CSVNode::Build() {
return node_ops;
}
// Get the shard id of node
Status CSVNode::GetShardId(int32_t *shard_id) {
*shard_id = shard_id_;
return Status::OK();
}
} // namespace api
} // namespace dataset
} // namespace mindspore

@ -66,6 +66,10 @@ class CSVNode : public Dataset {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Get the shard id of node
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;
private:
std::vector<std::string> dataset_files_;
char field_delim_;

@ -70,6 +70,14 @@ std::vector<std::shared_ptr<DatasetOp>> ImageFolderNode::Build() {
std::move(sampler_->Build())));
return node_ops;
}
// Get the shard id of node
Status ImageFolderNode::GetShardId(int32_t *shard_id) {
*shard_id = sampler_->ShardId();
return Status::OK();
}
} // namespace api
} // namespace dataset
} // namespace mindspore

@ -51,6 +51,10 @@ class ImageFolderNode : public Dataset {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Get the shard id of node
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;
private:
std::string dataset_dir_;
bool decode_;

@ -85,6 +85,14 @@ std::vector<std::shared_ptr<DatasetOp>> ManifestNode::Build() {
return node_ops;
}
// Get the shard id of node
Status ManifestNode::GetShardId(int32_t *shard_id) {
*shard_id = sampler_->ShardId();
return Status::OK();
}
} // namespace api
} // namespace dataset
} // namespace mindspore

@ -44,6 +44,10 @@ class ManifestNode : public Dataset {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Get the shard id of node
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;
private:
std::string dataset_file_;
std::string usage_;

@ -160,6 +160,13 @@ std::vector<std::shared_ptr<DatasetOp>> MindDataNode::Build() {
return node_ops;
}
// Get the shard id of node
Status MindDataNode::GetShardId(int32_t *shard_id) {
*shard_id = sampler_->ShardId();
return Status::OK();
}
} // namespace api
} // namespace dataset
} // namespace mindspore

@ -48,6 +48,10 @@ class MindDataNode : public Dataset {
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
/// \brief Get the shard id of node
/// \return Status Status::OK() if get shard id successfully
Status GetShardId(int32_t *shard_id) override;
/// \brief Build sampler chain for minddata dataset
/// \return Status Status::OK() if input sampler is valid
Status BuildMindDatasetSamplerChain(const std::shared_ptr<SamplerObj> &sampler,

@ -60,6 +60,13 @@ std::vector<std::shared_ptr<DatasetOp>> MnistNode::Build() {
return node_ops;
}
// Get the shard id of node
Status MnistNode::GetShardId(int32_t *shard_id) {
*shard_id = sampler_->ShardId();
return Status::OK();
}
} // namespace api
} // namespace dataset
} // namespace mindspore

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save