!11393 Migrating cache transform pass from execution tree to IR tree

From: @lixiachen
pull/11393/MERGE
Committed by mindspore-ci-bot via Gitee (5 years ago)
commit 4cd6588af0

@ -237,8 +237,11 @@ Status CacheOp::Accept(NodePass *p, bool *const modified) {
return p->RunOnNode(shared_from_base<CacheOp>(), modified);
}
// A public wrapper for creating the cache through the client
Status CacheOp::CreateCache(uint32_t cache_crc) {
Status CacheOp::PrepareNodePostAction() {
// Run any common code from the super class first before adding our own
RETURN_IF_NOT_OK(ParallelOp::PrepareNodePostAction());
// Get the computed checksum from all ops in our cache path below us and ask the cache op to create its cache
uint32_t cache_crc = DatasetOp::GenerateCRC(shared_from_this());
// This is a non-mappable cache op, so the IDs need to be generated.
// Construct the cache
const bool generate_ids = true;

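The hunk above folds the old public CreateCache(cache_crc) entry point into PrepareNodePostAction(): with the execution-tree cache transform pass gone, the op computes the CRC itself during tree prepare instead of receiving it from the pass. The hunk is truncated, so the following is only a minimal sketch of the full reshaped method, assuming the remainder reuses the old CreateCache body and that cache_client_->CreateCache(crc, generate_ids) is the client call it forwards to:

Status CacheOp::PrepareNodePostAction() {
  // Run any common code from the super class first before adding our own.
  RETURN_IF_NOT_OK(ParallelOp::PrepareNodePostAction());
  // Compute the checksum over all ops in the cache path below us.
  uint32_t cache_crc = DatasetOp::GenerateCRC(shared_from_this());
  // This is a non-mappable cache op, so row IDs need to be generated.
  const bool generate_ids = true;
  // Ask the cache client to construct the cache identified by this crc
  // (assumed call; the truncated hunk does not show it).
  RETURN_IF_NOT_OK(cache_client_->CreateCache(cache_crc, generate_ids));
  return Status::OK();
}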
@ -141,11 +141,7 @@ class CacheOp : public CacheBase, public RandomAccessOp {
bool AllowCacheMiss() override { return false; }
/// \brief Base-class override for the name of this operator
std::string Name() const override { return kCacheOp; }
/// \brief A public wrapper for creating the cache through the client
/// \param[in] cache_crc The crc that identifies the cache
/// \see cache_pass.cc
/// \return Status return code
Status CreateCache(uint32_t cache_crc);
Status PrepareNodePostAction() override;
private:
WaitPost rows_cache_done_;

@ -33,11 +33,7 @@
namespace mindspore {
namespace dataset {
ClueOp::Builder::Builder()
: builder_device_id_(0),
builder_num_devices_(1),
builder_num_samples_(0),
builder_shuffle_files_(false),
builder_sampler_(nullptr) {
: builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) {
std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
builder_num_workers_ = config_manager->num_parallel_workers();
builder_op_connector_size_ = config_manager->op_connector_size();
@ -74,7 +70,7 @@ Status ClueOp::Builder::Build(std::shared_ptr<ClueOp> *op) {
std::shared_ptr<ClueOp> clue_op = std::make_shared<ClueOp>(
builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_, ck_map,
builder_clue_files_list_, builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_,
builder_device_id_, std::move(builder_sampler_));
builder_device_id_);
RETURN_IF_NOT_OK(clue_op->Init());
*op = std::move(clue_op);
@ -94,8 +90,8 @@ std::vector<std::string> ClueOp::Builder::split(const std::string &s, char delim
ClueOp::ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size,
bool shuffle_files, int32_t num_device, int32_t device_id, std::shared_ptr<SamplerRT> sampler)
: ParallelOp(num_workers, op_connector_size, std::move(sampler)),
bool shuffle_files, int32_t num_device, int32_t device_id)
: ParallelOp(num_workers, op_connector_size),
rows_per_buffer_(rows_per_buffer),
num_rows_per_shard_(0),
all_num_rows_(0),
@ -552,16 +548,6 @@ Status ClueOp::ComputeColMap() {
return Status::OK();
}
// Brief If a cache has been added into the ascendant tree over this clue op, then the cache will be executing
// a sampler for fetching the data. As such, any options in the clue op need to be reset to its defaults so
// that this clue op will produce the full set of data into the cache.
void ClueOp::MakeSimpleProducer() {
device_id_ = 0;
num_devices_ = 1;
shuffle_files_ = false;
num_samples_ = 0;
}
// Visitor accept method for NodePass
Status ClueOp::Accept(NodePass *p, bool *const modified) {
// Downcast shared pointer then call visitor

@ -122,14 +122,6 @@ class ClueOp : public ParallelOp {
// @return - a string vector
std::vector<std::string> split(const std::string &s, char delim);
// Setter method
// @param std::shared_ptr<Sampler> sampler
// @return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
builder_sampler_ = std::move(sampler);
return *this;
}
private:
int32_t builder_device_id_;
int32_t builder_num_devices_;
@ -141,13 +133,12 @@ class ClueOp : public ParallelOp {
std::vector<std::string> builder_clue_files_list_;
bool builder_shuffle_files_;
std::map<std::string, std::string> builder_cols_to_keyword_;
std::shared_ptr<SamplerRT> builder_sampler_;
};
// Constructor of ClueOp
ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size,
bool shuffle_files, int32_t num_devices, int32_t device_id, std::shared_ptr<SamplerRT> sampler);
bool shuffle_files, int32_t num_devices, int32_t device_id);
// Default destructor
~ClueOp() = default;
@ -182,11 +173,6 @@ class ClueOp : public ParallelOp {
// @return Vector of the input file names
std::vector<std::string> FileNames() { return clue_files_list_; }
/// \Brief If a cache has been added into the ascendant tree over this clue op, then the cache will be executing
/// a sampler for fetching the data. As such, any options in the clue op need to be reset to its defaults so
/// that this clue op will produce the full set of data into the cache.
void MakeSimpleProducer();
// Op name getter
// @return Name of the current Op
std::string Name() const override { return "ClueOp"; }

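The ClueOp changes above set the pattern repeated below for CsvOp, RandomDataOp, TextFileOp, and TFReaderOp: the sampler constructor argument, the builder's SetSampler(), and the MakeSimpleProducer() reset hook are all deleted, because the cache rewrite now happens on the IR tree before any runtime op exists. After this diff, constructing the runtime leaf reduces to the plain signature from the header hunk (the argument names below are just the formals from that signature):

// Post-diff construction; no SamplerRT flows into the leaf op anymore.
auto op = std::make_shared<ClueOp>(num_workers, rows_per_buffer, num_samples,
                                   worker_connector_size, cols_to_keyword,
                                   clue_files_list, op_connector_size,
                                   shuffle_files, num_devices, device_id);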
@ -29,11 +29,7 @@
namespace mindspore {
namespace dataset {
CsvOp::Builder::Builder()
: builder_device_id_(0),
builder_num_devices_(1),
builder_num_samples_(0),
builder_shuffle_files_(false),
builder_sampler_(nullptr) {
: builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) {
std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
builder_num_workers_ = config_manager->num_parallel_workers();
builder_op_connector_size_ = config_manager->op_connector_size();
@ -65,8 +61,7 @@ Status CsvOp::Builder::Build(std::shared_ptr<CsvOp> *op) {
std::shared_ptr<CsvOp> csv_op = std::make_shared<CsvOp>(
builder_csv_files_list_, builder_field_delim_, builder_column_default_list_, builder_column_name_list_,
builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_,
builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_, builder_device_id_,
std::move(builder_sampler_));
builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_, builder_device_id_);
RETURN_IF_NOT_OK(csv_op->Init());
*op = std::move(csv_op);
@ -77,8 +72,8 @@ CsvOp::CsvOp(const std::vector<std::string> &csv_files_list, char field_delim,
const std::vector<std::shared_ptr<BaseRecord>> &column_default,
const std::vector<std::string> &column_name, int32_t num_workers, int64_t rows_per_buffer,
int64_t num_samples, int32_t worker_connector_size, int32_t op_connector_size, bool shuffle_files,
int32_t num_device, int32_t device_id, std::shared_ptr<SamplerRT> sampler)
: ParallelOp(num_workers, op_connector_size, std::move(sampler)),
int32_t num_device, int32_t device_id)
: ParallelOp(num_workers, op_connector_size),
csv_files_list_(std::move(csv_files_list)),
field_delim_(field_delim),
column_default_list_(column_default),
@ -920,16 +915,6 @@ Status CsvOp::ComputeColMap() {
return Status::OK();
}
// Brief If a cache has been added into the ascendant tree over this csv op, then the cache will be executing
// a sampler for fetching the data. As such, any options in the csv op need to be reset to its defaults so
// that this csv op will produce the full set of data into the cache.
void CsvOp::MakeSimpleProducer() {
device_id_ = 0;
num_devices_ = 1;
shuffle_files_ = false;
num_samples_ = 0;
}
// Visitor accept method for NodePass
Status CsvOp::Accept(NodePass *p, bool *const modified) {
// Downcast shared pointer then call visitor

@ -241,14 +241,6 @@ class CsvOp : public ParallelOp {
return *this;
}
// Setter method
// @param std::shared_ptr<Sampler> sampler
// @return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
builder_sampler_ = std::move(sampler);
return *this;
}
private:
int32_t builder_device_id_;
int32_t builder_num_devices_;
@ -262,7 +254,6 @@ class CsvOp : public ParallelOp {
char builder_field_delim_;
std::vector<std::shared_ptr<CsvOp::BaseRecord>> builder_column_default_list_;
std::vector<std::string> builder_column_name_list_;
std::shared_ptr<SamplerRT> builder_sampler_;
};
// Constructor of CsvOp
@ -271,8 +262,7 @@ class CsvOp : public ParallelOp {
CsvOp(const std::vector<std::string> &csv_files_list, char field_delim,
const std::vector<std::shared_ptr<BaseRecord>> &column_default, const std::vector<std::string> &column_name,
int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id,
std::shared_ptr<SamplerRT> sampler);
int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id);
// Default destructor
~CsvOp() = default;
@ -308,11 +298,6 @@ class CsvOp : public ParallelOp {
// @return Vector of the input file names
std::vector<std::string> FileNames() { return csv_files_list_; }
/// \Brief If a cache has been added into the ascendant tree over this csv op, then the cache will be executing
/// a sampler for fetching the data. As such, any options in the csv op need to be reset to its defaults so
/// that this csv op will produce the full set of data into the cache.
void MakeSimpleProducer();
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.

@ -34,8 +34,7 @@ RandomDataOp::Builder::Builder()
builder_num_workers_(0),
builder_op_connector_size_(0),
builder_rows_per_buffer_(0),
builder_total_rows_(0),
builder_sampler_(nullptr) {
builder_total_rows_(0) {
// Some arguments to the RandomDataOp have a default argument that is taken from the config.
// The user may override these defaults by using the builder set methods.
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
@ -48,9 +47,8 @@ RandomDataOp::Builder::Builder()
Status RandomDataOp::Builder::Build(std::shared_ptr<RandomDataOp> *out_op) {
RETURN_IF_NOT_OK(SanityCheck());
*out_op =
std::make_shared<RandomDataOp>(builder_num_workers_, builder_op_connector_size_, builder_rows_per_buffer_,
builder_total_rows_, std::move(builder_data_schema_), std::move(builder_sampler_));
*out_op = std::make_shared<RandomDataOp>(builder_num_workers_, builder_op_connector_size_, builder_rows_per_buffer_,
builder_total_rows_, std::move(builder_data_schema_));
return Status::OK();
}
@ -65,8 +63,8 @@ Status RandomDataOp::Builder::SanityCheck() const {
// Constructor for RandomDataOp
RandomDataOp::RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64_t rows_per_buffer, int64_t total_rows,
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler)
: ParallelOp(num_workers, op_connector_size, std::move(sampler)),
std::unique_ptr<DataSchema> data_schema)
: ParallelOp(num_workers, op_connector_size),
buffer_id_(0),
rows_per_buffer_(rows_per_buffer),
total_rows_(total_rows),
@ -80,8 +78,7 @@ RandomDataOp::RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64
if (total_rows_ == 0) {
total_rows_ = GenRandomInt(1, kMaxTotalRows);
}
// If the user did not provide a schema, then we will ask the op to generate a pseudo-random
// schema.
// If the user did not provide a schema, then we will ask the op to generate a pseudo-random schema.
// See details of generateSchema function to learn what type of schema it will create.
if (data_schema_ == nullptr) {
GenerateSchema();

@ -117,14 +117,6 @@ class RandomDataOp : public ParallelOp {
return *this;
}
// Setter method
// @param std::shared_ptr<Sampler> sampler
// @return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
builder_sampler_ = std::move(sampler);
return *this;
}
private:
/**
* Check if the required parameters are set by the builder.
@ -133,7 +125,6 @@ class RandomDataOp : public ParallelOp {
Status SanityCheck() const;
std::unique_ptr<DataSchema> builder_data_schema_;
std::shared_ptr<SamplerRT> builder_sampler_;
int32_t builder_num_workers_;
int32_t builder_op_connector_size_;
int64_t builder_rows_per_buffer_;
@ -148,11 +139,10 @@ class RandomDataOp : public ParallelOp {
* @param rows_per_buffer - The number of rows in each DataBuffer
* @param data_schema - A user-provided schema
* @param total_rows - The total number of rows in the dataset
* @param sampler - allow a sampler. Only valid if a cache exists in ascendent tree nodes
* @return Builder - The modified builder by reference
*/
RandomDataOp(int32_t num_workers, int32_t op_connector_size, int64_t rows_per_buffer, int64_t total_rows,
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler);
std::unique_ptr<DataSchema> data_schema);
/**
* Destructor

@ -34,11 +34,7 @@
namespace mindspore {
namespace dataset {
TextFileOp::Builder::Builder()
: builder_device_id_(0),
builder_num_devices_(1),
builder_total_rows_(0),
builder_shuffle_files_(false),
builder_sampler_(nullptr) {
: builder_device_id_(0), builder_num_devices_(1), builder_total_rows_(0), builder_shuffle_files_(false) {
std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
builder_num_workers_ = config_manager->num_parallel_workers();
builder_op_connector_size_ = config_manager->op_connector_size();
@ -74,7 +70,7 @@ Status TextFileOp::Builder::Build(std::shared_ptr<TextFileOp> *op) {
std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>(
builder_num_workers_, builder_rows_per_buffer_, builder_total_rows_, builder_worker_connector_size_,
std::move(builder_schema_), builder_text_files_list_, builder_op_connector_size_, builder_shuffle_files_,
builder_num_devices_, builder_device_id_, std::move(builder_sampler_));
builder_num_devices_, builder_device_id_);
RETURN_IF_NOT_OK(text_file_op->Init());
*op = std::move(text_file_op);
@ -83,9 +79,8 @@ Status TextFileOp::Builder::Build(std::shared_ptr<TextFileOp> *op) {
TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size,
std::unique_ptr<DataSchema> schema, std::vector<std::string> text_files_list,
int32_t op_connector_size, bool shuffle_files, int32_t num_device, int32_t device_id,
std::shared_ptr<SamplerRT> sampler)
: ParallelOp(num_workers, op_connector_size, std::move(sampler)),
int32_t op_connector_size, bool shuffle_files, int32_t num_device, int32_t device_id)
: ParallelOp(num_workers, op_connector_size),
device_id_(device_id),
num_devices_(num_device),
rows_per_buffer_(rows_per_buffer),
@ -504,16 +499,6 @@ Status TextFileOp::ComputeColMap() {
return Status::OK();
}
// Brief If a cache has been added into the ascendant tree over this text file op, then the cache will be executing
// a sampler for fetching the data. As such, any options in the text file op need to be reset to its defaults so
// that this text file op will produce the full set of data into the cache.
void TextFileOp::MakeSimpleProducer() {
device_id_ = 0;
num_devices_ = 1;
shuffle_files_ = false;
total_rows_ = 0;
}
// Visitor accept method for NodePass
Status TextFileOp::Accept(NodePass *p, bool *const modified) {
// Downcast shared pointer then call visitor

@ -112,14 +112,6 @@ class TextFileOp : public ParallelOp {
return *this;
}
// Setter method
// @param std::shared_ptr<Sampler> sampler
// @return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
builder_sampler_ = std::move(sampler);
return *this;
}
private:
int32_t builder_device_id_;
int32_t builder_num_devices_;
@ -131,7 +123,6 @@ class TextFileOp : public ParallelOp {
std::vector<std::string> builder_text_files_list_;
bool builder_shuffle_files_;
std::unique_ptr<DataSchema> builder_schema_;
std::shared_ptr<SamplerRT> builder_sampler_;
};
// Constructor of TextFileOp
@ -145,10 +136,9 @@ class TextFileOp : public ParallelOp {
// @param columns_to_load - the names of the columns to load data from.
// @param shuffle_files - whether or not to shuffle the files before reading data.
// @param equal_rows_per_shard - whether or not to get equal rows for each process.
// @param sampler - allow a sampler. Only valid if a cache exists in ascendent tree nodes
TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t total_rows, int32_t worker_connector_size,
std::unique_ptr<DataSchema>, std::vector<std::string> text_files_list, int32_t op_connector_size,
bool shuffle_files, int32_t num_devices, int32_t device_id, std::shared_ptr<SamplerRT> sampler);
bool shuffle_files, int32_t num_devices, int32_t device_id);
// Default destructor
~TextFileOp() = default;
@ -187,11 +177,6 @@ class TextFileOp : public ParallelOp {
// @return Vector of the input file names
std::vector<std::string> FileNames() { return text_files_list_; }
/// \Brief If a cache has been added into the ascendant tree over this text file op, then the cache will be executing
/// a sampler for fetching the data. As such, any options in the text file op need to be reset to its defaults so
/// that this text file op will produce the full set of data into the cache.
void MakeSimpleProducer();
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.

@ -44,11 +44,7 @@
namespace mindspore {
namespace dataset {
TFReaderOp::Builder::Builder()
: builder_device_id_(0),
builder_num_devices_(1),
builder_total_rows_(0),
builder_equal_rows_per_shard_(false),
builder_sampler_(nullptr) {
: builder_device_id_(0), builder_num_devices_(1), builder_total_rows_(0), builder_equal_rows_per_shard_(false) {
std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
builder_num_workers_ = config_manager->num_parallel_workers();
builder_worker_connector_size_ = config_manager->worker_connector_size();
@ -122,8 +118,7 @@ Status TFReaderOp::Builder::Build(std::shared_ptr<TFReaderOp> *out_tf_reader_op)
std::shared_ptr<TFReaderOp> new_tf_reader_op = std::make_shared<TFReaderOp>(
builder_num_workers_, builder_worker_connector_size_, builder_rows_per_buffer_, builder_total_rows_,
builder_dataset_files_list_, std::move(builder_data_schema_), builder_op_connector_size_, builder_columns_to_load_,
builder_shuffle_files_, builder_num_devices_, builder_device_id_, builder_equal_rows_per_shard_,
std::move(builder_sampler_));
builder_shuffle_files_, builder_num_devices_, builder_device_id_, builder_equal_rows_per_shard_);
RETURN_IF_NOT_OK(new_tf_reader_op->Init());
*out_tf_reader_op = std::move(new_tf_reader_op);
@ -134,8 +129,8 @@ TFReaderOp::TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64
int64_t total_num_rows, std::vector<std::string> dataset_files_list,
std::unique_ptr<DataSchema> data_schema, int32_t op_connector_size,
std::vector<std::string> columns_to_load, bool shuffle_files, int32_t num_device,
int32_t device_id, bool equal_rows_per_shard, std::shared_ptr<SamplerRT> sampler)
: ParallelOp(num_workers, op_connector_size, std::move(sampler)),
int32_t device_id, bool equal_rows_per_shard)
: ParallelOp(num_workers, op_connector_size),
device_id_(device_id),
num_devices_(num_device),
rows_per_buffer_(rows_per_buffer),
@ -1043,17 +1038,6 @@ Status TFReaderOp::ComputeColMap() {
return Status::OK();
}
// Brief If a cache has been added into the ascendant tree over this tf reader, then the cache will be executing
// a sampler for fetching the data. As such, any options in the tf reader need to be reset to its defaults so
// that this tf reader will produce the full set of data into the cache.
void TFReaderOp::MakeSimpleProducer() {
device_id_ = 0;
num_devices_ = 1;
total_rows_ = 0;
shuffle_files_ = false;
equal_rows_per_shard_ = false;
}
// During tree prepare phase, operators may have specific post-operations to perform depending on
// their role.
Status TFReaderOp::PrepareNodePostAction() {

@ -153,17 +153,8 @@ class TFReaderOp : public ParallelOp {
return *this;
}
// Setter method
// @param std::shared_ptr<Sampler> sampler
// @return Builder setter method returns reference to the builder.
Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
builder_sampler_ = std::move(sampler);
return *this;
}
private:
std::unique_ptr<DataSchema> builder_data_schema_;
std::shared_ptr<SamplerRT> builder_sampler_;
int32_t builder_device_id_;
int32_t builder_num_devices_;
int32_t builder_num_workers_;
@ -189,11 +180,10 @@ class TFReaderOp : public ParallelOp {
// @param columns_to_load - the names of the columns to load data from.
// @param shuffle_files - whether or not to shuffle the files before reading data.
// @param equal_rows_per_shard - whether or not to get equal rows for each process.
// @param sampler - allow a sampler. Only valid if a cache exists in ascendent tree nodes
TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64_t rows_per_buffer, int64_t total_num_rows,
std::vector<std::string> dataset_files_list, std::unique_ptr<DataSchema> data_schema,
int32_t op_connector_size, std::vector<std::string> columns_to_load, bool shuffle_files,
int32_t num_devices, int32_t device_id, bool equal_rows_per_shard, std::shared_ptr<SamplerRT> sampler);
int32_t num_devices, int32_t device_id, bool equal_rows_per_shard);
// Default destructor
~TFReaderOp() = default;
@ -246,11 +236,6 @@ class TFReaderOp : public ParallelOp {
// @return Vector of the input file names
std::vector<std::string> FileNames() { return dataset_files_list_; }
/// \Brief If a cache has been added into the ascendant tree over this tf reader, then the cache will be executing
/// a sampler for fetching the data. As such, any options in the tf reader need to be reset to its defaults so
/// that this tf reader will produce the full set of data into the cache.
void MakeSimpleProducer();
// During tree prepare phase, operators may have specific post-operations to perform depending on
// their role.
// @notes Derived versions of this function should always call its superclass version first
@ -387,7 +372,7 @@ class TFReaderOp : public ParallelOp {
bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
const int64_t &pre_count);
// Caculate number of rows in each shard.
// Calculate number of rows in each shard.
// @return Status - the error code returned.
Status CalculateNumRowsPerShard();

@ -320,7 +320,6 @@ Status ExecutionTree::PostAction() {
// The IR version cannot detect an invalid case of a cache on Map with random tensor operation from Python API.
// This is because Python API binding to TensorOperation is still in progress.
post_actions.push_back(std::make_unique<CacheErrorPass>());
post_actions.push_back(std::make_unique<CacheTransformPass>());
post_actions.push_back(std::make_unique<RepeatPass>());
#endif

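With the hunk above, ExecutionTree::PostAction() keeps CacheErrorPass (the retained comment explains why: the Python binding to TensorOperation is still in progress, so the invalid case of a cache over a Map with a random tensor operation can only be caught on the execution tree) but drops CacheTransformPass. The transform's new home is the IR tree; its registration is not part of the hunks shown here, so the line below is only a sketch of what that presumably looks like, with the owning pass list (e.g. a TreeAdapter pre-pass) assumed:

// Assumed IR-side registration; not shown in this diff.
actions.push_back(std::make_unique<CacheTransformPass>());  // IR version of the pass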
@ -19,6 +19,7 @@
#include <memory>
#include "minddata/dataset/engine/datasetops/dataset_op.h"
#include "minddata/dataset/include/samplers.h"
#include "minddata/dataset/util/status.h"
namespace mindspore::dataset {
@ -29,6 +30,9 @@ class DatasetCache {
virtual Status ValidateParams() = 0;
virtual Status CreateCacheOp(int num_workers, std::shared_ptr<DatasetOp> *ds_op) = 0;
virtual Status to_json(nlohmann::json *out_json) { return Status::OK(); }
virtual Status CreateCacheLookupOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds,
std::shared_ptr<SamplerObj> sampler) = 0;
virtual Status CreateCacheMergeOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) = 0;
};
} // namespace mindspore::dataset

@ -16,6 +16,8 @@
#include <memory>
#include "minddata/dataset/engine/ir/cache/dataset_cache_impl.h"
#include "minddata/dataset/engine/datasetops/cache_lookup_op.h"
#include "minddata/dataset/engine/datasetops/cache_merge_op.h"
#include "minddata/dataset/engine/datasetops/cache_op.h"
namespace mindspore {
@ -44,5 +46,28 @@ Status DatasetCacheImpl::CreateCacheOp(int32_t num_workers, std::shared_ptr<Data
return Status::OK();
}
Status DatasetCacheImpl::CreateCacheLookupOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds,
std::shared_ptr<SamplerObj> sampler) {
CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "Cache client has not been created yet.");
std::shared_ptr<CacheLookupOp> lookup_op = nullptr;
RETURN_IF_NOT_OK(CacheLookupOp::Builder()
.SetNumWorkers(num_workers)
.SetClient(cache_client_)
.SetSampler(sampler->SamplerBuild())
.Build(&lookup_op));
*ds = lookup_op;
return Status::OK();
}
Status DatasetCacheImpl::CreateCacheMergeOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) {
CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "Cache client has not been created yet.");
std::shared_ptr<CacheMergeOp> merge_op = nullptr;
RETURN_IF_NOT_OK(CacheMergeOp::Builder().SetNumWorkers(num_workers).SetClient(cache_client_).Build(&merge_op));
*ds = merge_op;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

@ -56,6 +56,11 @@ class DatasetCacheImpl : public DatasetCache {
Status CreateCacheOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) override;
Status CreateCacheLookupOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds,
std::shared_ptr<SamplerObj> sampler) override;
Status CreateCacheMergeOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) override;
Status ValidateParams() override { return Status::OK(); }
~DatasetCacheImpl() = default;

@ -16,6 +16,8 @@
#include <memory>
#include "minddata/dataset/engine/ir/cache/pre_built_dataset_cache.h"
#include "minddata/dataset/engine/datasetops/cache_lookup_op.h"
#include "minddata/dataset/engine/datasetops/cache_merge_op.h"
#include "minddata/dataset/engine/datasetops/cache_op.h"
namespace mindspore {
@ -46,5 +48,29 @@ Status PreBuiltDatasetCache::to_json(nlohmann::json *out_json) {
*out_json = args;
return Status::OK();
}
Status PreBuiltDatasetCache::CreateCacheLookupOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds,
std::shared_ptr<SamplerObj> sampler) {
CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "Cache client has not been created yet.");
std::shared_ptr<CacheLookupOp> lookup_op = nullptr;
RETURN_IF_NOT_OK(CacheLookupOp::Builder()
.SetNumWorkers(num_workers)
.SetClient(cache_client_)
.SetSampler(sampler->SamplerBuild())
.Build(&lookup_op));
*ds = lookup_op;
return Status::OK();
}
Status PreBuiltDatasetCache::CreateCacheMergeOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) {
CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "Cache client has not been created yet.");
std::shared_ptr<CacheMergeOp> merge_op = nullptr;
RETURN_IF_NOT_OK(CacheMergeOp::Builder().SetNumWorkers(num_workers).SetClient(cache_client_).Build(&merge_op));
*ds = merge_op;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

@ -40,6 +40,11 @@ class PreBuiltDatasetCache : public DatasetCache {
Status CreateCacheOp(int32_t num_workers, std::shared_ptr<DatasetOp> *const ds) override;
Status CreateCacheLookupOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds,
std::shared_ptr<SamplerObj> sampler) override;
Status CreateCacheMergeOp(int32_t num_workers, std::shared_ptr<DatasetOp> *ds) override;
Status ValidateParams() override { return Status::OK(); }
Status to_json(nlohmann::json *out_json) override;

@ -8,6 +8,9 @@ set(DATASET_ENGINE_IR_DATASETOPS_SRC_FILES
bucket_batch_by_length_node.cc
build_sentence_piece_vocab_node.cc
build_vocab_node.cc
cache_lookup_node.cc
cache_merge_node.cc
cache_node.cc
concat_node.cc
epoch_ctrl_node.cc
filter_node.cc

@ -0,0 +1,70 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/ir/datasetops/cache_lookup_node.h"
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
CacheLookupNode::CacheLookupNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<SamplerObj> sampler,
std::shared_ptr<DatasetCache> cache)
: DatasetNode(std::move(cache)), sampler_(sampler), lookup_op_(nullptr), lookup_node_copy_(nullptr) {
this->AddChild(child);
}
void CacheLookupNode::Print(std::ostream &out) const { out << Name(); }
std::shared_ptr<DatasetNode> CacheLookupNode::Copy() {
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
auto node = std::make_shared<CacheLookupNode>(nullptr, sampler, cache_);
lookup_node_copy_ = node;
return node;
}
Status CacheLookupNode::ValidateParams() {
RETURN_IF_NOT_OK(ValidateDatasetSampler("CacheNode", sampler_));
return Status::OK();
}
Status CacheLookupNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
CHECK_FAIL_RETURN_UNEXPECTED(cache_ != nullptr,
"Internal error. Attempt to create a cache lookup node without cache client.");
RETURN_IF_NOT_OK(cache_->Build());
RETURN_IF_NOT_OK(cache_->CreateCacheLookupOp(num_workers_, &lookup_op_, sampler_));
node_ops->push_back(lookup_op_);
return Status::OK();
}
std::shared_ptr<SamplerObj> CacheLookupNode::SamplerCopy() {
// CacheLookupNode should already have been copied, so we just return it here
return std::static_pointer_cast<SamplerObj>(lookup_node_copy_);
}
std::shared_ptr<SamplerRT> CacheLookupNode::SamplerBuild() {
// Runtime cache lookup op should already have been built, so we just return it here
auto lookup_op = std::dynamic_pointer_cast<CacheLookupOp>(lookup_op_);
return std::shared_ptr<SamplerRT>(lookup_op);
}
} // namespace dataset
} // namespace mindspore

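A design point worth calling out: CacheLookupNode inherits from both DatasetNode and SamplerObj, so one object can sit in the tree as the lookup stage and simultaneously serve as the source node's sampler; SamplerBuild() simply hands back the already-built CacheLookupOp, which plays the same dual role (op and SamplerRT) at runtime. Below is a minimal sketch of the rewrite for a non-mappable source, assuming the IR CacheTransformPass (not in these hunks) wires it this way; source, leaf_sampler, and cache are hypothetical locals:

// merge sits on top; lookup serves cache hits, the source op fills misses.
auto lookup = std::make_shared<CacheLookupNode>(nullptr, leaf_sampler, cache);
auto merge = std::make_shared<CacheMergeNode>(lookup, cache);
merge->AddChild(source);  // CacheMergeNode sets nary_op_ = true, so a second child is allowed
// The source is then pointed at `lookup` as its sampler; since CacheLookupNode
// is itself a SamplerObj, no separate sampler object is needed.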
@ -0,0 +1,75 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CACHE_LOOKUP_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CACHE_LOOKUP_NODE_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/engine/datasetops/cache_lookup_op.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
namespace mindspore {
namespace dataset {
class CacheLookupNode : public DatasetNode, public SamplerObj {
public:
/// \brief Constructor
CacheLookupNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<SamplerObj> sampler,
std::shared_ptr<DatasetCache> cache);
/// \brief Destructor
~CacheLookupNode() = default;
/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kCacheLookupNode; }
/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to convert a SamplerObj class into a runtime sampler object
/// \return Shared pointers to the newly created Sampler
std::shared_ptr<SamplerRT> SamplerBuild() override;
/// \brief a base class override function to copy a SamplerObj class
/// \return Shared pointers to the newly copied SamplerObj
std::shared_ptr<SamplerObj> SamplerCopy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \param node_ops - A vector containing shared pointer to the Dataset Ops that this object will create
/// \return Status Status::OK() if build successfully
Status Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) override;
/// \brief Parameters validation
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
private:
std::shared_ptr<SamplerObj> sampler_;
std::shared_ptr<DatasetOp> lookup_op_;
std::shared_ptr<CacheLookupNode> lookup_node_copy_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CACHE_LOOKUP_NODE_H_

@ -0,0 +1,56 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/ir/datasetops/cache_merge_node.h"
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/engine/datasetops/cache_merge_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
CacheMergeNode::CacheMergeNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<DatasetCache> cache)
: DatasetNode(std::move(cache)) {
nary_op_ = true;
this->AddChild(child);
}
void CacheMergeNode::Print(std::ostream &out) const { out << Name(); }
std::shared_ptr<DatasetNode> CacheMergeNode::Copy() {
auto node = std::make_shared<CacheMergeNode>(nullptr, cache_);
return node;
}
Status CacheMergeNode::ValidateParams() { return Status::OK(); }
Status CacheMergeNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
CHECK_FAIL_RETURN_UNEXPECTED(cache_ != nullptr,
"Internal error. Attempt to create a cache merge node without cache client.");
RETURN_IF_NOT_OK(cache_->Build());
std::shared_ptr<DatasetOp> merge_op = nullptr;
RETURN_IF_NOT_OK(cache_->CreateCacheMergeOp(num_workers_, &merge_op));
node_ops->push_back(merge_op);
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

@ -0,0 +1,60 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CACHE_MERGE_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CACHE_MERGE_NODE_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
namespace mindspore {
namespace dataset {
class CacheMergeNode : public DatasetNode {
public:
/// \brief Constructor
CacheMergeNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<DatasetCache> cache);
/// \brief Destructor
~CacheMergeNode() = default;
/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kCacheMergeNode; }
/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \param node_ops - A vector containing shared pointer to the Dataset Ops that this object will create
/// \return Status Status::OK() if build successfully
Status Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) override;
/// \brief Parameters validation
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CACHE_MERGE_NODE_H_

@ -0,0 +1,61 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/ir/datasetops/cache_node.h"
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/engine/datasetops/cache_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
CacheNode::CacheNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<SamplerObj> sampler,
std::shared_ptr<DatasetCache> cache)
: DatasetNode(std::move(cache)), sampler_(sampler) {
this->AddChild(child);
}
void CacheNode::Print(std::ostream &out) const { out << Name(); }
std::shared_ptr<DatasetNode> CacheNode::Copy() {
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
auto node = std::make_shared<CacheNode>(nullptr, sampler, cache_);
return node;
}
Status CacheNode::ValidateParams() {
RETURN_IF_NOT_OK(ValidateDatasetSampler("CacheNode", sampler_));
return Status::OK();
}
Status CacheNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
CHECK_FAIL_RETURN_UNEXPECTED(cache_ != nullptr,
"Internal error. Attempt to create a cache node without cache client.");
RETURN_IF_NOT_OK(cache_->Build());
std::shared_ptr<DatasetOp> cache_op = nullptr;
RETURN_IF_NOT_OK(cache_->CreateCacheOp(num_workers_, &cache_op));
cache_op->SetSampler(sampler_->SamplerBuild());
node_ops->push_back(cache_op);
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

@ -0,0 +1,64 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CACHE_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CACHE_NODE_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
namespace mindspore {
namespace dataset {
class CacheNode : public DatasetNode {
public:
/// \brief Constructor
CacheNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<SamplerObj> sampler,
std::shared_ptr<DatasetCache> cache);
/// \brief Destructor
~CacheNode() = default;
/// \brief Node name getter
/// \return Name of the current node
std::string Name() const override { return kCacheNode; }
/// \brief Print the description
/// \param out - The output stream to write output to
void Print(std::ostream &out) const override;
/// \brief Copy the node to a new object
/// \return A shared pointer to the new copy
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to create the required runtime dataset op objects for this class
/// \param node_ops - A vector containing shared pointer to the Dataset Ops that this object will create
/// \return Status Status::OK() if build successfully
Status Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) override;
/// \brief Parameters validation
/// \return Status Status::OK() if all the parameters are valid
Status ValidateParams() override;
private:
std::shared_ptr<SamplerObj> sampler_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CACHE_NODE_H_

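Taken together, the three new IR nodes mirror on DatasetNode what the execution-tree pass used to build directly on DatasetOp. A minimal usage sketch for the plain cache node, assuming source, sampler, and cache were created elsewhere (e.g. through DatasetCacheImpl):

auto cache_node = std::make_shared<CacheNode>(source, sampler, cache);
RETURN_IF_NOT_OK(cache_node->ValidateParams());  // rejects a missing sampler
std::vector<std::shared_ptr<DatasetOp>> node_ops;
RETURN_IF_NOT_OK(cache_node->Build(&node_ops));  // builds the CacheOp and attaches the built sampler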
Some files were not shown because too many files have changed in this diff.
