!8096 Add DatasetNode as a base class for IR nodes

Merge pull request !8096 from h.farahat/datasetNode
pull/8096/MERGE
Committed by mindspore-ci-bot via Gitee, 5 years ago
commit 6e2241d64f

File diff suppressed because it is too large.

@ -53,7 +53,7 @@ Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds) {
RETURN_IF_NOT_OK(runtime_context->Init());
auto consumer = std::make_unique<IteratorConsumer>();
consumer_ = consumer.get();
-RETURN_IF_NOT_OK(consumer->Init(ds));
+RETURN_IF_NOT_OK(consumer->Init(ds->IRNode()));
runtime_context->AssignConsumer(std::move(consumer));
return Status::OK();
}

@ -11,7 +11,7 @@ endif ()
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
-add_library(engine OBJECT
+set(SRC_FILES_LIST
execution_tree.cc
data_buffer.cc
data_schema.cc
@ -20,10 +20,19 @@ add_library(engine OBJECT
runtime_context.cc
consumers/tree_consumer.cc
)
+if (ENABLE_PYTHON)
+set(SRC_FILES_LIST
+${SRC_FILES_LIST}
+python_runtime_context.cc
+consumers/python_tree_consumer.cc
+)
+endif ()
+add_library(engine OBJECT ${SRC_FILES_LIST})
if (ENABLE_PYTHON)
-target_include_directories(engine PRIVATE ${pybind11_INCLUDE_DIRS})
-endif()
+target_include_directories(engine PRIVATE ${pybind11_INCLUDE_DIRS})
+endif ()
add_dependencies(engine engine-datasetops engine-datasetops-source engine-opt engine-gnn engine-perf engine-cache-client engine-datasetops-mapop)

@ -0,0 +1,46 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string>
#include <unordered_map>
#include <vector>
#include "minddata/dataset/engine/consumers/python_tree_consumer.h"
namespace mindspore::dataset {
Status PythonIteratorConsumer::GetNextAsList(py::list *out) {
std::vector<TensorPtr> row;
{
py::gil_scoped_release gil_release;
RETURN_IF_NOT_OK(GetNextAsVector(&row));
}
for (auto el : row) {
(*out).append(el);
}
return Status::OK();
}
Status PythonIteratorConsumer::GetNextAsDict(py::dict *out) {
std::unordered_map<std::string, TensorPtr> row;
{
py::gil_scoped_release gil_release;
RETURN_IF_NOT_OK(GetNextAsMap(&row));
}
for (auto el : row) {
(*out)[common::SafeCStr(el.first)] = el.second;
}
return Status::OK();
}
} // namespace mindspore::dataset
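For context, a minimal usage sketch of the new consumer; note the GIL is released only around the blocking fetch and re-acquired before the Python objects are built. The surrounding setup and the name root_ir are assumptions, not part of this diff.

// Sketch, assuming an initialized runtime context and root_ir of type std::shared_ptr<api::DatasetNode>.
auto consumer = std::make_unique<mindspore::dataset::PythonIteratorConsumer>(1);  // one epoch
RETURN_IF_NOT_OK(consumer->Init(root_ir));
py::list row;
RETURN_IF_NOT_OK(consumer->GetNextAsList(&row));  // GIL dropped during the fetch, held while appending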

@ -26,24 +26,21 @@
namespace mindspore::dataset {
/// Consumer that iterates over the dataset and returns the rows one by one as a python list or a dict
-class PythonIterator : public IteratorConsumer {
-  /// Constructor
+class PythonIteratorConsumer : public IteratorConsumer {
public:
+  /// Constructor which will call the base class default constructor.
+  /// \param num_epochs number of epochs. Default to -1 (infinite epochs).
-  explicit PythonIterator(int32_t num_epochs = -1) : IteratorConsumer(num_epochs) {}
+  explicit PythonIteratorConsumer(int32_t num_epochs = -1) : IteratorConsumer(num_epochs) {}
+  /// Returns the next row as a python list
+  /// \param[out] out python list of Tensors
+  /// \return Status error code
+  Status GetNextAsList(py::list *out);
-  /// Get the next row as a python dict
-  /// \param[out] output python dict
-  /// \return Status error code
-  Status GetNextAsMap(py::dict *output) {
-    return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
-  }
-  /// Get the next row as a python dict
-  /// \param[out] output python dict
-  /// \return Status error code
-  Status GetNextAsList(py::list *output) {
-    return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
-  }
+  /// Returns the next row as a python dict
+  /// \param[out] out python dict of column name to Tensor
+  /// \return Status error code
+  Status GetNextAsDict(py::dict *out);
};
} // namespace mindspore::dataset
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMERS_PYTHON_TREE_CONSUMER_H_

@ -34,10 +34,11 @@ namespace mindspore::dataset {
// TreeConsumer
TreeConsumer::TreeConsumer() { tree_adapter_ = std::make_unique<TreeAdapter>(); }
-Status TreeConsumer::Init(std::shared_ptr<api::Dataset> d) { return tree_adapter_->BuildAndPrepare(std::move(d)); }
+Status TreeConsumer::Init(std::shared_ptr<api::DatasetNode> d) { return tree_adapter_->BuildAndPrepare(std::move(d)); }
Status TreeConsumer::Terminate() { return tree_adapter_->AllTasks()->DoServiceStop(); }
// IteratorConsumer
-Status IteratorConsumer::Init(std::shared_ptr<api::Dataset> d) {
+Status IteratorConsumer::Init(std::shared_ptr<api::DatasetNode> d) {
return tree_adapter_->BuildAndPrepare(std::move(d), num_epochs_);
}
@ -73,7 +74,7 @@ Status IteratorConsumer::GetNextAsMap(std::unordered_map<std::string, TensorPtr>
}
// ToDevice
-Status ToDevice::Init(std::shared_ptr<api::Dataset> d) {
+Status ToDevice::Init(std::shared_ptr<api::DatasetNode> d) {
return tree_adapter_->BuildAndPrepare(std::move(d), num_epochs_);
}
@ -384,7 +385,7 @@ TreeGetters::TreeGetters() : dataset_size_(-1), init_flag_(false), row_flag_(fal
tree_adapter_ = std::make_unique<TreeAdapter>();
}
-Status TreeGetters::Init(std::shared_ptr<api::Dataset> d) {
+Status TreeGetters::Init(std::shared_ptr<api::DatasetNode> d) {
Status s = tree_adapter_->BuildAndPrepare(std::move(d));
if (!s.IsError()) {
init_flag_ = true;
@ -463,4 +464,15 @@ Status TreeGetters::GetNumClasses(int64_t *num_classes) {
RETURN_IF_NOT_OK(root->GetNumClasses(num_classes));
return Status::OK();
}
+Status BuildVocabConsumer::Init(std::shared_ptr<api::DatasetNode> d) {
+return tree_adapter_->BuildAndPrepare(std::move(d), 1);
+}
+Status BuildVocabConsumer::Start() {
+// Getting one row would trigger building the vocab
+TensorRow row;
+RETURN_IF_NOT_OK(tree_adapter_->GetNext(&row));
+// The returned row would be an EOE, which is an empty row
+CHECK_FAIL_RETURN_UNEXPECTED(row.empty(), "The fetched row from BuildVocab should be an EOE.");
+return Status::OK();
+}
} // namespace mindspore::dataset
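For context, a sketch of driving vocab construction with the new consumer; build_vocab_ir is an assumed name for an IR root containing a BuildVocabNode, not from this diff.

// Sketch: Init prepares the tree for exactly one epoch; Start pulls a single row.
auto vocab_consumer = std::make_unique<mindspore::dataset::BuildVocabConsumer>();
RETURN_IF_NOT_OK(vocab_consumer->Init(build_vocab_ir));
RETURN_IF_NOT_OK(vocab_consumer->Start());  // an empty EOE row signals the vocab has been built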

@ -22,14 +22,16 @@
#include <unordered_map>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/tree_adapter.h"
#include "minddata/dataset/text/vocab.h"
namespace mindspore::dataset {
// Forward declare
class TreeAdapter;
namespace api {
-class Dataset;
+class DatasetNode;
}
/// A base class for tree consumers which would fetch rows from the tree pipeline
@ -40,7 +42,9 @@ class TreeConsumer {
/// Initializes the consumer, this involves constructing and preparing the tree.
/// \param d The dataset node that represent the root of the IR tree.
/// \return Status error code.
-virtual Status Init(std::shared_ptr<api::Dataset> d);
+virtual Status Init(std::shared_ptr<api::DatasetNode> d);
Status Terminate();
protected:
/// The class owns the tree_adapter that handles execution tree operations.
@ -57,7 +61,7 @@ class IteratorConsumer : public TreeConsumer {
/// \param num_epochs number of epochs. Default to -1 (infinite epochs).
explicit IteratorConsumer(int32_t num_epochs = -1) : TreeConsumer(), num_epochs_(num_epochs) {}
-Status Init(std::shared_ptr<api::Dataset> d) override;
+Status Init(std::shared_ptr<api::DatasetNode> d) override;
/// Returns the next row in a vector format
/// \param[out] out std::vector of Tensors
@ -126,10 +130,10 @@ class SaveToDisk : public TreeConsumer {
/// Consumer that iterates over the dataset and send it to a device
class ToDevice : public TreeConsumer {
public:
-ToDevice(std::string device_type, bool send_epoch_end, int32_t num_epochs = -1)
-: TreeConsumer(), device_type_(device_type), send_epoch_end_(send_epoch_end), num_epochs_(num_epochs) {}
+explicit ToDevice(bool send_epoch_end, int32_t num_epochs = -1)
+: TreeConsumer(), send_epoch_end_(send_epoch_end), num_epochs_(num_epochs) {}
-Status Init(std::shared_ptr<api::Dataset> d) override;
+Status Init(std::shared_ptr<api::DatasetNode> d) override;
/// Send the data to device
/// \return Status error code
@ -158,7 +162,7 @@ class ToDevice : public TreeConsumer {
class TreeGetters : public TreeConsumer {
public:
TreeGetters();
-Status Init(std::shared_ptr<api::Dataset> d) override;
+Status Init(std::shared_ptr<api::DatasetNode> d) override;
Status GetDatasetSize(int64_t *size);
Status GetOutputTypes(std::vector<DataType> *types);
Status GetOutputShapes(std::vector<TensorShape> *shapes);
@ -176,5 +180,23 @@ class TreeGetters : public TreeConsumer {
bool row_flag_; // indicate whether the first row has been stored in row_
};
+class BuildVocabConsumer : public TreeConsumer {
+public:
+/// BuildVocabConsumer Constructor which will call the base class default constructor.
+BuildVocabConsumer() = default;
+Status Init(std::shared_ptr<api::DatasetNode> d) override;
+/// Start consuming: builds the vocab by pulling rows through the tree. This is a blocking method (i.e., after
+/// returning, the vocab is fully built)
+/// \return Status error code
+Status Start();
+protected:
+/// Method to return the name of the consumer
+/// \return string
+std::string Name() override { return "BuildVocab"; }
+};
} // namespace mindspore::dataset
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMERS_TREE_CONSUMER_H_

@ -3,6 +3,7 @@ set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE
add_subdirectory(source)
set(DATASET_ENGINE_IR_DATASETOPS_SRC_FILES
+dataset_node.cc
batch_node.cc
bucket_batch_by_length_node.cc
build_sentence_piece_vocab_node.cc

@ -28,7 +28,7 @@ namespace mindspore {
namespace dataset {
namespace api {
-BatchNode::BatchNode(std::shared_ptr<Dataset> child, int32_t batch_size, bool drop_remainder, bool pad,
+BatchNode::BatchNode(std::shared_ptr<DatasetNode> child, int32_t batch_size, bool drop_remainder, bool pad,
std::vector<std::string> cols_to_map,
std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map)
: batch_size_(batch_size),

@ -23,16 +23,16 @@
#include <utility>
#include <vector>
#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
namespace mindspore {
namespace dataset {
namespace api {
-class BatchNode : public Dataset {
+class BatchNode : public DatasetNode {
public:
/// \brief Constructor
-BatchNode(std::shared_ptr<Dataset> child, int32_t batch_size, bool drop_remainder, bool pad,
+BatchNode(std::shared_ptr<DatasetNode> child, int32_t batch_size, bool drop_remainder, bool pad,
std::vector<std::string> cols_to_map,
std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map);

@ -29,7 +29,7 @@ namespace mindspore {
namespace dataset {
namespace api {
BucketBatchByLengthNode::BucketBatchByLengthNode(
-std::shared_ptr<Dataset> child, const std::vector<std::string> &column_names,
+std::shared_ptr<DatasetNode> child, const std::vector<std::string> &column_names,
const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes,
std::function<TensorRow(TensorRow)> element_length_function,
const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info, bool pad_to_bucket_boundary,

@ -23,15 +23,15 @@
#include <utility>
#include <vector>
#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
namespace mindspore {
namespace dataset {
namespace api {
-class BucketBatchByLengthNode : public Dataset {
+class BucketBatchByLengthNode : public DatasetNode {
public:
/// \brief Constructor
-BucketBatchByLengthNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &column_names,
+BucketBatchByLengthNode(std::shared_ptr<DatasetNode> child, const std::vector<std::string> &column_names,
const std::vector<int32_t> &bucket_boundaries, const std::vector<int32_t> &bucket_batch_sizes,
std::function<TensorRow(TensorRow)> element_length_function = nullptr,
const std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> &pad_info = {},

@ -28,7 +28,7 @@ namespace mindspore {
namespace dataset {
namespace api {
-BuildSentenceVocabNode::BuildSentenceVocabNode(std::shared_ptr<Dataset> child,
+BuildSentenceVocabNode::BuildSentenceVocabNode(std::shared_ptr<DatasetNode> child,
std::shared_ptr<SentencePieceVocab> vocab,
const std::vector<std::string> &col_names, uint32_t vocab_size,
float character_coverage, SentencePieceModel model_type,

@ -29,10 +29,10 @@ namespace mindspore {
namespace dataset {
namespace api {
-class BuildSentenceVocabNode : public Dataset {
+class BuildSentenceVocabNode : public DatasetNode {
public:
/// \brief Constructor
-BuildSentenceVocabNode(std::shared_ptr<Dataset> child, std::shared_ptr<SentencePieceVocab> vocab,
+BuildSentenceVocabNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<SentencePieceVocab> vocab,
const std::vector<std::string> &col_names, uint32_t vocab_size, float character_coverage,
SentencePieceModel model_type, const std::unordered_map<std::string, std::string> &params);

@ -28,7 +28,7 @@ namespace mindspore {
namespace dataset {
namespace api {
-BuildVocabNode::BuildVocabNode(std::shared_ptr<Dataset> child, std::shared_ptr<Vocab> vocab,
+BuildVocabNode::BuildVocabNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<Vocab> vocab,
const std::vector<std::string> &columns, const std::pair<int64_t, int64_t> &freq_range,
int64_t top_k, const std::vector<std::string> &special_tokens, bool special_first)
: vocab_(vocab),

@ -22,17 +22,17 @@
#include <utility>
#include <vector>
#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
namespace mindspore {
namespace dataset {
namespace api {
-class BuildVocabNode : public Dataset {
+class BuildVocabNode : public DatasetNode {
public:
/// \brief Constructor
-BuildVocabNode(std::shared_ptr<Dataset> child, std::shared_ptr<Vocab> vocab, const std::vector<std::string> &columns,
-const std::pair<int64_t, int64_t> &freq_range, int64_t top_k,
+BuildVocabNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<Vocab> vocab,
+const std::vector<std::string> &columns, const std::pair<int64_t, int64_t> &freq_range, int64_t top_k,
const std::vector<std::string> &special_tokens, bool special_first);
/// \brief Destructor

@ -27,18 +27,16 @@ namespace mindspore {
namespace dataset {
namespace api {
// Function to build ConcatOp
-ConcatNode::ConcatNode(const std::vector<std::shared_ptr<Dataset>> &datasets) : datasets_(datasets) {
-this->children = datasets_;
-}
+ConcatNode::ConcatNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets) { this->children = datasets; }
Status ConcatNode::ValidateParams() {
-if (datasets_.empty()) {
+if (children.size() < 2) {
std::string err_msg = "ConcatNode: concatenated datasets are not specified.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
-if (find(datasets_.begin(), datasets_.end(), nullptr) != datasets_.end()) {
+if (find(children.begin(), children.end(), nullptr) != children.end()) {
std::string err_msg = "ConcatNode: concatenated datasets should not be null.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);

@ -21,16 +21,16 @@
#include <string>
#include <vector>
#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
namespace mindspore {
namespace dataset {
namespace api {
-class ConcatNode : public Dataset {
+class ConcatNode : public DatasetNode {
public:
/// \brief Constructor
-explicit ConcatNode(const std::vector<std::shared_ptr<Dataset>> &datasets);
+explicit ConcatNode(const std::vector<std::shared_ptr<DatasetNode>> &datasets);
/// \brief Destructor
~ConcatNode() = default;
@ -42,9 +42,6 @@ class ConcatNode : public Dataset {
/// \brief Parameters validation
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
-private:
-std::vector<std::shared_ptr<Dataset>> datasets_;
};
} // namespace api
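For context, a sketch of how ConcatNode is now built and validated against the base-class children vector; ds1_ir and ds2_ir are assumed IR handles, not from this diff.

// Sketch: the inputs are stored directly as children; no private datasets_ copy remains.
std::vector<std::shared_ptr<api::DatasetNode>> inputs = {ds1_ir, ds2_ir};
auto concat = std::make_shared<api::ConcatNode>(inputs);
RETURN_IF_NOT_OK(concat->ValidateParams());  // fails on fewer than two inputs or any null child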

@ -0,0 +1,65 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
#include <memory>
namespace mindspore {
namespace dataset {
namespace api {
Status DatasetNode::AddCacheOp(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
if (cache_ != nullptr) {
RETURN_IF_NOT_OK(cache_->Build());
std::shared_ptr<DatasetOp> cache_op;
RETURN_IF_NOT_OK(cache_->CreateCacheOp(num_workers_, &cache_op));
node_ops->push_back(cache_op);
}
return Status::OK();
}
// Constructor to initialize the cache
DatasetNode::DatasetNode(const std::shared_ptr<DatasetCache> &dataset_cache) : DatasetNode() { cache_ = dataset_cache; }
std::shared_ptr<DatasetNode> DatasetNode::SetNumWorkers(int32_t num_workers) {
#if !defined(_WIN32) && !defined(_WIN64)
#ifndef ENABLE_ANDROID
int32_t cpu_count = sysconf(_SC_NPROCESSORS_CONF);
if (cpu_count < 0 || cpu_count > INT32_MAX) {
MS_LOG(ERROR) << "Error determining current CPU: " << cpu_count;
return nullptr;
}
if (num_workers < 1 || num_workers > cpu_count) {
MS_LOG(ERROR) << "num_workers exceeds the boundary between 1 and " << cpu_count;
return nullptr;
}
#endif
#endif
num_workers_ = num_workers;
return shared_from_this();
}
DatasetNode::DatasetNode() {
// Fetch some default value from config manager
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
num_workers_ = cfg->num_parallel_workers();
rows_per_buffer_ = cfg->rows_per_buffer();
connector_que_size_ = cfg->op_connector_size();
worker_connector_size_ = cfg->worker_connector_size();
}
} // namespace api
} // namespace dataset
} // namespace mindspore
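For context, the chaining contract of SetNumWorkers in a short sketch; node is any std::shared_ptr<api::DatasetNode>, assumed in scope.

// Sketch: on success the setter returns shared_from_this(), so the result is the same object.
if (auto configured = node->SetNumWorkers(4)) {
  node = configured;
} else {
  MS_LOG(ERROR) << "invalid num_workers";  // nullptr is returned for values outside [1, cpu_count] off Windows/Android
}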

@ -0,0 +1,126 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_DATASET_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_DATASET_NODE_H_
#include <map>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "minddata/dataset/include/datasets.h"
namespace mindspore {
namespace dataset {
namespace api {
class Dataset;
class SamplerObj;
#define RETURN_EMPTY_IF_ERROR(_s) \
do { \
Status __rc = (_s); \
if (__rc.IsError()) { \
MS_LOG(ERROR) << __rc; \
return {}; \
} \
} while (false)
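// Usage sketch (illustrative, not in the original file): the macro is meant for Build(), which returns a
// std::vector<std::shared_ptr<DatasetOp>> rather than a Status, e.g.
//   RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));
// On error it logs the status and makes Build() return an empty vector.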
Status AddShuffleOp(int64_t num_files, int64_t num_devices, int64_t num_rows, int64_t total_rows,
int32_t connector_que_size, int32_t rows_per_buffer, std::shared_ptr<DatasetOp> *shuffle_op);
// Helper function to validate dataset files parameter
Status ValidateDatasetFilesParam(const std::string &dataset_name, const std::vector<std::string> &dataset_files);
// Helper function to validate dataset num_shards and shard_id parameters
Status ValidateDatasetShardParams(const std::string &dataset_name, int32_t num_shards, int32_t shard_id);
// Helper function to validate dataset sampler parameter
Status ValidateDatasetSampler(const std::string &dataset_name, const std::shared_ptr<SamplerObj> &sampler);
Status ValidateStringValue(const std::string &dataset_name, const std::string &str,
const std::unordered_set<std::string> &valid_strings);
// Helper function to validate dataset input/output column parameter
Status ValidateDatasetColumnParam(const std::string &dataset_name, const std::string &column_param,
const std::vector<std::string> &columns);
// Helper function to validate dataset directory parameter
Status ValidateDatasetDirParam(const std::string &dataset_name, std::string dataset_dir);
/// \brief Function to create a sampler for non-mappable dataset (to be used by cache op later).
/// \notes Non-mappable dataset does not directly support a sampler. It has provided sampling arguments (shuffle,
/// num_samples, num_shards, shard_id) and it DOES support sampling if somewhere above it in the pipeline contains
/// a cache. If there is no cache above it, then the sampler is not used.
/// \param[in] num_samples The number of samples to be included in the dataset.
/// \param[in] shuffle If true, the indices are shuffled.
/// \param[in] num_shards Number of shards to divide the dataset into.
/// \param[in] shard_id Shard ID of the current shard within num_shards.
/// \return Shared pointer to the current Sampler.
std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards, int32_t shard_id);
class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
public:
/// \brief Constructor
DatasetNode();
/// \brief Constructor that initializes the cache
/// \param dataset_cache DatasetCache
explicit DatasetNode(const std::shared_ptr<DatasetCache> &dataset_cache);
/// \brief Destructor
~DatasetNode() = default;
/// \brief Pure virtual function to convert a DatasetNode class into a runtime dataset object
/// \return The list of shared pointers to the newly created DatasetOps
virtual std::vector<std::shared_ptr<DatasetOp>> Build() = 0;
/// \brief Pure virtual function for derived class to implement parameters validation
/// \return Status Status::OK() if all the parameters are valid
virtual Status ValidateParams() = 0;
const std::vector<std::shared_ptr<DatasetNode>> Children() const { return children; }
/// \brief Virtual function for derived classes to get the shard id of a specific node (the base implementation
/// below returns kNotImplementedYet)
/// \return Status Status::OK() if get shard id successfully
virtual Status GetShardId(int32_t *shard_id) {
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
}
/// \brief Setter function for runtime number of workers
/// \param[in] num_workers The number of threads in this operator
/// \return Shared pointer to the original object
std::shared_ptr<DatasetNode> SetNumWorkers(int32_t num_workers);
protected:
std::vector<std::shared_ptr<DatasetNode>> children;
std::shared_ptr<DatasetNode> parent;
std::shared_ptr<DatasetCache> cache_;
Status AddCacheOp(std::vector<std::shared_ptr<DatasetOp>> *node_ops);
int32_t num_workers_;
int32_t rows_per_buffer_;
int32_t connector_que_size_;
int32_t worker_connector_size_;
};
} // namespace api
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_DATASET_NODE_H_
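To make the new contract concrete, a minimal hypothetical subclass; the class and the commented-out op are illustrative only, not part of this PR.

// Hypothetical example of a node implementing the DatasetNode interface.
class ExampleNode : public DatasetNode {
 public:
  explicit ExampleNode(std::shared_ptr<DatasetNode> child) { this->children.push_back(child); }
  ~ExampleNode() = default;
  std::vector<std::shared_ptr<DatasetOp>> Build() override {
    std::vector<std::shared_ptr<DatasetOp>> node_ops;
    RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));  // appends a CacheOp when a cache was attached
    // node_ops.push_back(std::make_shared<SomeOp>(num_workers_, connector_que_size_));  // hypothetical op
    return node_ops;
  }
  Status ValidateParams() override {
    CHECK_FAIL_RETURN_UNEXPECTED(!children.empty(), "ExampleNode: a child node is required.");
    return Status::OK();
  }
};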

@ -28,14 +28,14 @@ namespace mindspore {
namespace dataset {
namespace api {
-MapNode::MapNode(std::shared_ptr<Dataset> child, std::vector<std::shared_ptr<TensorOperation>> operations,
+MapNode::MapNode(std::shared_ptr<DatasetNode> child, std::vector<std::shared_ptr<TensorOperation>> operations,
std::vector<std::string> input_columns, std::vector<std::string> output_columns,
const std::vector<std::string> &project_columns, std::shared_ptr<DatasetCache> cache)
: operations_(operations),
input_columns_(input_columns),
output_columns_(output_columns),
project_columns_(project_columns),
-Dataset(std::move(cache)) {
+DatasetNode(std::move(cache)) {
this->children.push_back(child);
}

@ -21,15 +21,15 @@
#include <string>
#include <vector>
#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
namespace mindspore {
namespace dataset {
namespace api {
-class MapNode : public Dataset {
+class MapNode : public DatasetNode {
public:
/// \brief Constructor
-MapNode(std::shared_ptr<Dataset> child, std::vector<std::shared_ptr<TensorOperation>> operations,
+MapNode(std::shared_ptr<DatasetNode> child, std::vector<std::shared_ptr<TensorOperation>> operations,
std::vector<std::string> input_columns = {}, std::vector<std::string> output_columns = {},
const std::vector<std::string> &columns = {}, std::shared_ptr<DatasetCache> cache = nullptr);

@ -28,7 +28,8 @@ namespace dataset {
namespace api {
// Function to build ProjectOp
-ProjectNode::ProjectNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &columns) : columns_(columns) {
+ProjectNode::ProjectNode(std::shared_ptr<DatasetNode> child, const std::vector<std::string> &columns)
+: columns_(columns) {
this->children.push_back(child);
}

@ -21,17 +21,17 @@
#include <string>
#include <vector>
#include "minddata/dataset/include/datasets.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
namespace mindspore {
namespace dataset {
namespace api {
-class ProjectNode : public Dataset {
+class ProjectNode : public DatasetNode {
public:
/// \brief Constructor
-explicit ProjectNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &columns);
+explicit ProjectNode(std::shared_ptr<DatasetNode> child, const std::vector<std::string> &columns);
/// \brief Destructor
~ProjectNode() = default;

@ -27,7 +27,7 @@ namespace mindspore {
namespace dataset {
namespace api {
// Function to build RenameOp
-RenameNode::RenameNode(std::shared_ptr<Dataset> child, const std::vector<std::string> &input_columns,
+RenameNode::RenameNode(std::shared_ptr<DatasetNode> child, const std::vector<std::string> &input_columns,
const std::vector<std::string> &output_columns)
: input_columns_(input_columns), output_columns_(output_columns) {
this->children.push_back(child);

Some files were not shown because too many files have changed in this diff.