Enable cache for other leaf ops

pull/6905/head
Lixia Chen 4 years ago
parent 983827ec5c
commit 477808e622

@@ -1075,7 +1075,7 @@ std::vector<std::shared_ptr<DatasetOp>> CLUEDataset::Build() {
   std::shared_ptr<ClueOp> clue_op =
     std::make_shared<ClueOp>(num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, ck_map,
-                             sorted_dataset_files, connector_que_size_, shuffle_files, num_shards_, shard_id_);
+                             sorted_dataset_files, connector_que_size_, shuffle_files, num_shards_, shard_id_, nullptr);
   RETURN_EMPTY_IF_ERROR(clue_op->Init());
   if (shuffle_ == ShuffleMode::kGlobal) {
     // Inject ShuffleOp
@@ -1256,7 +1256,7 @@ std::vector<std::shared_ptr<DatasetOp>> CSVDataset::Build() {
   std::shared_ptr<CsvOp> csv_op = std::make_shared<CsvOp>(
     sorted_dataset_files, field_delim_, column_default_list, column_names_, num_workers_, rows_per_buffer_,
-    num_samples_, worker_connector_size_, connector_que_size_, shuffle_files, num_shards_, shard_id_);
+    num_samples_, worker_connector_size_, connector_que_size_, shuffle_files, num_shards_, shard_id_, nullptr);
   RETURN_EMPTY_IF_ERROR(csv_op->Init());
   if (shuffle_ == ShuffleMode::kGlobal) {
     // Inject ShuffleOp
@@ -1502,7 +1502,7 @@ std::vector<std::shared_ptr<DatasetOp>> TextFileDataset::Build() {
   // Create and initalize TextFileOp
   std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>(
     num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, std::move(schema), sorted_dataset_files,
-    connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(nullptr));
+    connector_que_size_, shuffle_files, num_shards_, shard_id_, nullptr);
   RETURN_EMPTY_IF_ERROR(text_file_op->Init());
   if (shuffle_ == ShuffleMode::kGlobal) {
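Note on the three Build() hunks above: each non-mappable leaf op (ClueOp, CsvOp, TextFileOp) now takes a sampler parameter, and these call sites pass an explicit nullptr because no cache is attached at this point; a sampler only becomes relevant when a cache sits above the leaf. From the Python API this surfaces as the `cache` argument. A rough usage sketch (the session_id is hypothetical; a real one comes from `cache_admin -g` with the cache server running):

    import mindspore.dataset as ds

    # Hypothetical session id; obtain a real one from `cache_admin -g`.
    some_cache = ds.DatasetCache(session_id=1, size=0, spilling=False)

    # The cache above the leaf supplies the sampler; the leaf op itself is
    # reset to a "simple producer" that reads every row once into the cache.
    data = ds.TextFileDataset("/path/to/1.txt", cache=some_cache)
    for row in data.create_dict_iterator(num_epochs=1, output_numpy=True):
        print(row["text"])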

(File diff suppressed because it is too large.)

@@ -70,13 +70,12 @@ Status AlbumOp::Builder::SanityCheck() {
 AlbumOp::AlbumOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size, bool do_decode,
                  const std::set<std::string> &exts, std::unique_ptr<DataSchema> data_schema,
                  std::shared_ptr<Sampler> sampler)
-    : ParallelOp(num_wkrs, queue_size),
+    : ParallelOp(num_wkrs, queue_size, std::move(sampler)),
       rows_per_buffer_(rows_per_buffer),
       folder_path_(file_dir),
       decode_(do_decode),
       extensions_(exts),
       data_schema_(std::move(data_schema)),
-      sampler_(std::move(sampler)),
       row_cnt_(0),
       buf_cnt_(0),
       sampler_ind_(0),

@@ -284,7 +284,6 @@ class AlbumOp : public ParallelOp, public RandomAccessOp {
   std::set<std::string> extensions_;  // extensions allowed
   std::unordered_map<std::string, int32_t> col_name_map_;
   std::unique_ptr<DataSchema> data_schema_;
-  std::shared_ptr<Sampler> sampler_;
   int64_t row_cnt_;
   int64_t buf_cnt_;
   int64_t sampler_ind_;
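The AlbumOp hunks above show the refactor that repeats across the mappable leaf ops (see CocoOp below): the per-op sampler_ member is deleted and ownership moves into the ParallelOp base through a new constructor parameter. A minimal Python rendering of the pattern (simplified names, not MindSpore source):

    class ParallelOp:
        """Base op; now owns the sampler shared by all leaf ops."""
        def __init__(self, num_workers, queue_size, sampler=None):
            self.num_workers = num_workers
            self.queue_size = queue_size
            self.sampler = sampler  # None when no sampling is attached

    class AlbumOpLike(ParallelOp):
        def __init__(self, num_workers, queue_size, folder, sampler=None):
            # Forward the sampler instead of keeping a duplicate field.
            super().__init__(num_workers, queue_size, sampler)
            self.folder = folder

    op = AlbumOpLike(4, 16, "/data/album")
    assert op.sampler is None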

@@ -25,13 +25,18 @@
 #include "minddata/dataset/util/task_manager.h"
 #include "minddata/dataset/engine/jagged_connector.h"
 #include "minddata/dataset/engine/execution_tree.h"
+#include "minddata/dataset/engine/opt/pass.h"
 #include "minddata/dataset/engine/datasetops/source/io_block.h"
 #include "minddata/dataset/util/random.h"
 
 namespace mindspore {
 namespace dataset {
 ClueOp::Builder::Builder()
-    : builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) {
+    : builder_device_id_(0),
+      builder_num_devices_(1),
+      builder_num_samples_(0),
+      builder_shuffle_files_(false),
+      builder_sampler_(nullptr) {
   std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
   builder_num_workers_ = config_manager->num_parallel_workers();
   builder_op_connector_size_ = config_manager->op_connector_size();
@@ -68,7 +73,7 @@ Status ClueOp::Builder::Build(std::shared_ptr<ClueOp> *op) {
   std::shared_ptr<ClueOp> clue_op = std::make_shared<ClueOp>(
     builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_, ck_map,
     builder_clue_files_list_, builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_,
-    builder_device_id_);
+    builder_device_id_, std::move(builder_sampler_));
   RETURN_IF_NOT_OK(clue_op->Init());
   *op = std::move(clue_op);
@@ -88,8 +93,8 @@ std::vector<std::string> ClueOp::Builder::split(const std::string &s, char delim
 ClueOp::ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
                ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size,
-               bool shuffle_files, int32_t num_device, int32_t device_id)
-    : ParallelOp(num_workers, op_connector_size),
+               bool shuffle_files, int32_t num_device, int32_t device_id, std::shared_ptr<Sampler> sampler)
+    : ParallelOp(num_workers, op_connector_size, std::move(sampler)),
       rows_per_buffer_(rows_per_buffer),
       num_rows_per_shard_(0),
       all_num_rows_(0),
@@ -539,5 +544,21 @@ Status ClueOp::ComputeColMap() {
   }
   return Status::OK();
 }
+
+// Brief If a cache has been added into the ascendant tree over this clue op, then the cache will be executing
+// a sampler for fetching the data. As such, any options in the clue op need to be reset to its defaults so
+// that this clue op will produce the full set of data into the cache.
+void ClueOp::MakeSimpleProducer() {
+  device_id_ = 0;
+  num_devices_ = 1;
+  shuffle_files_ = false;
+  num_samples_ = 0;
+}
+
+// Visitor accept method for NodePass
+Status ClueOp::Accept(NodePass *p, bool *modified) {
+  // Downcast shared pointer then call visitor
+  return p->RunOnNode(shared_from_base<ClueOp>(), modified);
+}
 }  // namespace dataset
 }  // namespace mindspore
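The reset in MakeSimpleProducer is the crux of non-mappable cache support: sharding (device_id_/num_devices_), file shuffling, and the row limit all move up to the cache's sampler, so the leaf must degenerate to a plain reader of the full dataset. In sketch form (not MindSpore source; note that num_samples == 0 means "no per-op row limit" in these ops):

    def make_simple_producer(op):
        op.device_id = 0          # stop sharding by device...
        op.num_devices = 1        # ...the cache serves every shard
        op.shuffle_files = False  # ordering comes from the cache's sampler
        op.num_samples = 0        # 0 == unlimited: produce every row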

@@ -20,6 +20,7 @@
 #include <map>
 #include <mutex>
 #include <string>
+#include <utility>
 #include <vector>
 #include <nlohmann/json.hpp>
 
@@ -122,6 +123,14 @@ class ClueOp : public ParallelOp {
     // @return - the a string vector
     std::vector<std::string> split(const std::string &s, char delim);
 
+    // Setter method
+    // @param std::shared_ptr<Sampler> sampler
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
+      builder_sampler_ = std::move(sampler);
+      return *this;
+    }
+
    private:
     int32_t builder_device_id_;
     int32_t builder_num_devices_;
@@ -133,12 +142,13 @@ class ClueOp : public ParallelOp {
     std::vector<std::string> builder_clue_files_list_;
     bool builder_shuffle_files_;
     std::map<std::string, std::string> builder_cols_to_keyword_;
+    std::shared_ptr<Sampler> builder_sampler_;
   };
 
   // Constructor of ClueOp
   ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
          ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size,
-         bool shuffle_files, int32_t num_devices, int32_t device_id);
+         bool shuffle_files, int32_t num_devices, int32_t device_id, std::shared_ptr<Sampler> sampler);
 
   // Default destructor
   ~ClueOp() = default;
@@ -173,6 +183,17 @@ class ClueOp : public ParallelOp {
   // @return Vector of the input file names
   std::vector<std::string> FileNames() { return clue_files_list_; }
 
+  /// \Brief If a cache has been added into the ascendant tree over this clue op, then the cache will be executing
+  /// a sampler for fetching the data. As such, any options in the clue op need to be reset to its defaults so
+  /// that this clue op will produce the full set of data into the cache.
+  void MakeSimpleProducer();
+
+  // Base-class override for NodePass visitor acceptor.
+  // @param p - Pointer to the NodePass to be accepted.
+  // @param modified - Whether this node visit modified the pipeline.
+  // @return - Status of the node visit.
+  Status Accept(NodePass *p, bool *modified) override;
+
  private:
   // The entry point for when workers are launched.
   // @param worker_id - the id of the worker that is executing this function.

@@ -124,7 +124,7 @@ Status CocoOp::Builder::SanityCheck() {
 CocoOp::CocoOp(const TaskType &task_type, const std::string &image_folder_path, const std::string &annotation_path,
                int32_t num_workers, int32_t rows_per_buffer, int32_t queue_size, bool decode,
                std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler)
-    : ParallelOp(num_workers, queue_size),
+    : ParallelOp(num_workers, queue_size, std::move(sampler)),
       decode_(decode),
       row_cnt_(0),
       buf_cnt_(0),
@@ -132,7 +132,6 @@ CocoOp::CocoOp(const TaskType &task_type, const std::string &image_folder_path,
       image_folder_path_(image_folder_path),
       annotation_path_(annotation_path),
       rows_per_buffer_(rows_per_buffer),
-      sampler_(std::move(sampler)),
       data_schema_(std::move(data_schema)) {
   io_block_queues_.Init(num_workers_, queue_size);
 }

@@ -206,6 +206,10 @@ class CocoOp : public ParallelOp, public RandomAccessOp {
   /// \return Status of the node visit
   Status Accept(NodePass *p, bool *modified) override;
 
+  // Op name getter
+  // @return Name of the current Op
+  std::string Name() const override { return "CocoOp"; }
+
  private:
   // Initialize Sampler, calls sampler->Init() within
   // @return Status - The error code return
@@ -324,7 +328,6 @@ class CocoOp : public ParallelOp, public RandomAccessOp {
   std::string annotation_path_;
   TaskType task_type_;
   int32_t rows_per_buffer_;
-  std::shared_ptr<Sampler> sampler_;
   std::unique_ptr<DataSchema> data_schema_;
   WaitPost wp_;

@@ -22,12 +22,17 @@
 #include "minddata/dataset/core/config_manager.h"
 #include "minddata/dataset/engine/jagged_connector.h"
 #include "minddata/dataset/engine/execution_tree.h"
+#include "minddata/dataset/engine/opt/pass.h"
 #include "minddata/dataset/util/random.h"
 
 namespace mindspore {
 namespace dataset {
 CsvOp::Builder::Builder()
-    : builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) {
+    : builder_device_id_(0),
+      builder_num_devices_(1),
+      builder_num_samples_(0),
+      builder_shuffle_files_(false),
+      builder_sampler_(nullptr) {
   std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
   builder_num_workers_ = config_manager->num_parallel_workers();
   builder_op_connector_size_ = config_manager->op_connector_size();
@@ -59,7 +64,8 @@ Status CsvOp::Builder::Build(std::shared_ptr<CsvOp> *op) {
   std::shared_ptr<CsvOp> csv_op = std::make_shared<CsvOp>(
     builder_csv_files_list_, builder_field_delim_, builder_column_default_list_, builder_column_name_list_,
     builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_,
-    builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_, builder_device_id_);
+    builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_, builder_device_id_,
+    std::move(builder_sampler_));
   RETURN_IF_NOT_OK(csv_op->Init());
   *op = std::move(csv_op);
@@ -70,8 +76,8 @@ CsvOp::CsvOp(const std::vector<std::string> &csv_files_list, char field_delim,
              const std::vector<std::shared_ptr<BaseRecord>> &column_default,
              const std::vector<std::string> &column_name, int32_t num_workers, int64_t rows_per_buffer,
              int64_t num_samples, int32_t worker_connector_size, int32_t op_connector_size, bool shuffle_files,
-             int32_t num_device, int32_t device_id)
-    : ParallelOp(num_workers, op_connector_size),
+             int32_t num_device, int32_t device_id, std::shared_ptr<Sampler> sampler)
+    : ParallelOp(num_workers, op_connector_size, std::move(sampler)),
       csv_files_list_(std::move(csv_files_list)),
       field_delim_(field_delim),
       column_default_list_(column_default),
@@ -889,5 +895,21 @@ Status CsvOp::ComputeColMap() {
   }
   return Status::OK();
 }
+
+// Brief If a cache has been added into the ascendant tree over this csv op, then the cache will be executing
+// a sampler for fetching the data. As such, any options in the csv op need to be reset to its defaults so
+// that this csv op will produce the full set of data into the cache.
+void CsvOp::MakeSimpleProducer() {
+  device_id_ = 0;
+  num_devices_ = 1;
+  shuffle_files_ = false;
+  num_samples_ = 0;
+}
+
+// Visitor accept method for NodePass
+Status CsvOp::Accept(NodePass *p, bool *modified) {
+  // Downcast shared pointer then call visitor
+  return p->RunOnNode(shared_from_base<CsvOp>(), modified);
+}
 }  // namespace dataset
 }  // namespace mindspore

@@ -240,6 +240,14 @@ class CsvOp : public ParallelOp {
       return *this;
     }
 
+    // Setter method
+    // @param std::shared_ptr<Sampler> sampler
+    // @return Builder setter method returns reference to the builder.
+    Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
+      builder_sampler_ = std::move(sampler);
+      return *this;
+    }
+
    private:
     int32_t builder_device_id_;
     int32_t builder_num_devices_;
@@ -253,6 +261,7 @@ class CsvOp : public ParallelOp {
     char builder_field_delim_;
     std::vector<std::shared_ptr<CsvOp::BaseRecord>> builder_column_default_list_;
     std::vector<std::string> builder_column_name_list_;
+    std::shared_ptr<Sampler> builder_sampler_;
   };
 
   // Constructor of CsvOp
@@ -261,7 +270,8 @@ class CsvOp : public ParallelOp {
   CsvOp(const std::vector<std::string> &csv_files_list, char field_delim,
         const std::vector<std::shared_ptr<BaseRecord>> &column_default, const std::vector<std::string> &column_name,
         int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
-        int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id);
+        int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id,
+        std::shared_ptr<Sampler> sampler);
 
   // Default destructor
   ~CsvOp() = default;
@@ -297,6 +307,17 @@ class CsvOp : public ParallelOp {
   // @return Vector of the input file names
   std::vector<std::string> FileNames() { return csv_files_list_; }
 
+  /// \Brief If a cache has been added into the ascendant tree over this csv op, then the cache will be executing
+  /// a sampler for fetching the data. As such, any options in the csv op need to be reset to its defaults so
+  /// that this csv op will produce the full set of data into the cache.
+  void MakeSimpleProducer();
+
+  // Base-class override for NodePass visitor acceptor.
+  // @param p - Pointer to the NodePass to be accepted.
+  // @param modified - Whether this node visit modified the pipeline.
+  // @return - Status of the node visit.
+  Status Accept(NodePass *p, bool *modified) override;
+
  private:
   // The entry point for when workers are launched.
   // @param worker_id - the id of the worker that is executing this function.

@@ -29,6 +29,7 @@
 #include "minddata/dataset/util/random.h"
 #include "minddata/dataset/engine/datasetops/source/io_block.h"
 #include "minddata/dataset/engine/execution_tree.h"
+#include "minddata/dataset/engine/opt/pass.h"
 
 namespace mindspore {
 namespace dataset {
@@ -499,5 +500,21 @@ Status TextFileOp::ComputeColMap() {
   }
   return Status::OK();
 }
+
+// Brief If a cache has been added into the ascendant tree over this text file op, then the cache will be executing
+// a sampler for fetching the data. As such, any options in the text file op need to be reset to its defaults so
+// that this text file op will produce the full set of data into the cache.
+void TextFileOp::MakeSimpleProducer() {
+  device_id_ = 0;
+  num_devices_ = 1;
+  shuffle_files_ = false;
+  total_rows_ = 0;
+}
+
+// Visitor accept method for NodePass
+Status TextFileOp::Accept(NodePass *p, bool *modified) {
+  // Downcast shared pointer then call visitor
+  return p->RunOnNode(shared_from_base<TextFileOp>(), modified);
+}
 }  // namespace dataset
 }  // namespace mindspore

@@ -188,6 +188,17 @@ class TextFileOp : public ParallelOp {
   // @return Vector of the input file names
   std::vector<std::string> FileNames() { return text_files_list_; }
 
+  /// \Brief If a cache has been added into the ascendant tree over this text file op, then the cache will be executing
+  /// a sampler for fetching the data. As such, any options in the text file op need to be reset to its defaults so
+  /// that this text file op will produce the full set of data into the cache.
+  void MakeSimpleProducer();
+
+  // Base-class override for NodePass visitor acceptor.
+  // @param p - Pointer to the NodePass to be accepted.
+  // @param modified - Whether this node visit modified the pipeline.
+  // @return - Status of the node visit.
+  Status Accept(NodePass *p, bool *modified) override;
+
  private:
   // The entry point for when workers are launched.
   // @param worker_id - the id of the worker that is executing this function.

@@ -212,6 +212,7 @@ Status VOCOp::LoadTensorRow(row_id_type row_id, const std::string &image_id, Ten
     folder_path_ + std::string(kAnnotationsFolder) + image_id + std::string(kAnnotationExtension);
   RETURN_IF_NOT_OK(ReadImageToTensor(kImageFile, data_schema_->column(0), &image));
   RETURN_IF_NOT_OK(ReadAnnotationToTensor(kAnnotationFile, &annotation));
+  trow->setId(row_id);
   trow->push_back(std::move(image));
   trow->insert(trow->end(), annotation.begin(), annotation.end());
 }

@@ -45,6 +45,9 @@
 #include "minddata/dataset/engine/datasetops/source/random_data_op.h"
 #ifndef ENABLE_ANDROID
 #include "minddata/dataset/engine/datasetops/source/tf_reader_op.h"
+#include "minddata/dataset/engine/datasetops/source/clue_op.h"
+#include "minddata/dataset/engine/datasetops/source/csv_op.h"
+#include "minddata/dataset/engine/datasetops/source/text_file_op.h"
 #endif
 #include "minddata/dataset/engine/datasetops/source/voc_op.h"
 #ifdef ENABLE_PYTHON
@@ -260,6 +263,21 @@ Status NodePass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified)
   return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
 }
 
+Status NodePass::RunOnNode(std::shared_ptr<ClueOp> node, bool *modified) {
+  // Fallback to base class visitor by default
+  return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
+}
+
+Status NodePass::RunOnNode(std::shared_ptr<CsvOp> node, bool *modified) {
+  // Fallback to base class visitor by default
+  return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
+}
+
+Status NodePass::RunOnNode(std::shared_ptr<TextFileOp> node, bool *modified) {
+  // Fallback to base class visitor by default
+  return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
+}
+
 Status NodePass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
   // Fallback to base class visitor by default
   return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
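These overloads complete the double dispatch: ClueOp::Accept (and its Csv/TextFile counterparts above) calls the RunOnNode overload for its concrete type, and any pass that does not override that overload falls through to the generic DatasetOp handler. A minimal Python sketch of the mechanism (names simplified, not MindSpore source):

    class NodePass:
        def run_on_dataset_op(self, node):       # generic fallback
            return "default"

        def run_on_clue_op(self, node):          # new overload added here
            return self.run_on_dataset_op(node)  # fallback by default

    class DatasetOp:
        def accept(self, visitor):
            return visitor.run_on_dataset_op(self)

    class ClueOp(DatasetOp):
        def accept(self, visitor):               # mirrors ClueOp::Accept
            return visitor.run_on_clue_op(self)

    class CachePass(NodePass):
        def run_on_clue_op(self, node):          # pass that handles ClueOp
            return "make simple producer"

    assert ClueOp().accept(CachePass()) == "make simple producer"
    assert ClueOp().accept(NodePass()) == "default"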

@ -81,6 +81,12 @@ class CacheMergeOp;
class CacheLookupOp; class CacheLookupOp;
class BuildSentencePieceVocabOp; class BuildSentencePieceVocabOp;
class ClueOp;
class CsvOp;
class TextFileOp;
#endif #endif
#ifdef ENABLE_PYTHON #ifdef ENABLE_PYTHON
@@ -211,6 +217,12 @@ class NodePass : public Pass {
   virtual Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified);
 
+  virtual Status RunOnNode(std::shared_ptr<ClueOp> node, bool *modified);
+
+  virtual Status RunOnNode(std::shared_ptr<CsvOp> node, bool *modified);
+
+  virtual Status RunOnNode(std::shared_ptr<TextFileOp> node, bool *modified);
+
   virtual Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified);
   virtual Status PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified);

@@ -36,6 +36,9 @@
 #ifndef ENABLE_ANDROID
 #include "minddata/dataset/engine/datasetops/source/tf_reader_op.h"
+#include "minddata/dataset/engine/datasetops/source/clue_op.h"
+#include "minddata/dataset/engine/datasetops/source/csv_op.h"
+#include "minddata/dataset/engine/datasetops/source/text_file_op.h"
 #endif
 
 #ifdef ENABLE_PYTHON
@@ -141,6 +144,36 @@ Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<TFReaderOp> node
   }
   return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
 }
+
+// Perform leaf node cache transform identification
+Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<ClueOp> node, bool *modified) {
+  if (is_caching_) {
+    // If we are a ClueOp in a caching tree, then change our config so that it becomes a basic
+    // ClueOp that parses all files. Selection of data will come from the sampler on the cache instead.
+    node->MakeSimpleProducer();
+  }
+  return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
+}
+
+// Perform leaf node cache transform identification
+Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CsvOp> node, bool *modified) {
+  if (is_caching_) {
+    // If we are a CsvOp in a caching tree, then change our config so that it becomes a basic
+    // CsvOp that parses all files. Selection of data will come from the sampler on the cache instead.
+    node->MakeSimpleProducer();
+  }
+  return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
+}
+
+// Perform leaf node cache transform identification
+Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<TextFileOp> node, bool *modified) {
+  if (is_caching_) {
+    // If we are a TextFileOp in a caching tree, then change our config so that it becomes a basic
+    // TextFileOp that parses all files. Selection of data will come from the sampler on the cache instead.
+    node->MakeSimpleProducer();
+  }
+  return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
+}
 #endif
 
 // Perform leaf node cache transform identification
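All three new handlers share one shape: once the pass has seen a cache above this point (is_caching_), the leaf is demoted to a simple producer, and either way the common non-mappable setup runs. A rough Python sketch of that flow over a pipeline tree (the Node type here is hypothetical; the real pass is driven by the visitor dispatch shown earlier):

    class Node:
        def __init__(self, name, children=()):
            self.name, self.children = name, list(children)
        def make_simple_producer(self):
            print("reset", self.name)

    def cache_pass(node, is_caching=False):
        # Descend the tree; a CacheOp turns caching on for its whole subtree.
        if node.name == "CacheOp":
            is_caching = True
        if is_caching and node.name in ("ClueOp", "CsvOp", "TextFileOp"):
            node.make_simple_producer()  # cache's sampler takes over selection
        for child in node.children:
            cache_pass(child, is_caching)

    cache_pass(Node("CacheOp", [Node("ClueOp")]))  # prints: reset ClueOp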
@@ -163,34 +196,22 @@ Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<AlbumOp> node, b
 // Perform leaf node cache transform identification
 Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) {
-  if (is_caching_) {
-    RETURN_STATUS_UNEXPECTED("There is currently no support for MnistOp under cache.");
-  }
-  return Status::OK();
+  return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
 }
 
 // Perform leaf node cache transform identification
 Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) {
-  if (is_caching_) {
-    RETURN_STATUS_UNEXPECTED("There is currently no support for CifarOp under cache.");
-  }
-  return Status::OK();
+  return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
 }
 
 // Perform leaf node cache transform identification
 Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) {
-  if (is_caching_) {
-    RETURN_STATUS_UNEXPECTED("There is currently no support for CocoOp under cache.");
-  }
-  return Status::OK();
+  return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
 }
 
 // Perform leaf node cache transform identification
 Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) {
-  if (is_caching_) {
-    RETURN_STATUS_UNEXPECTED("There is currently no support for CelebAOp under cache.");
-  }
-  return Status::OK();
+  return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
 }
 #endif
@@ -214,18 +235,12 @@ Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<GeneratorOp> nod
 // Perform leaf node cache transform identification
 Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) {
-  if (is_caching_) {
-    RETURN_STATUS_UNEXPECTED("There is currently no support for ManifestOp under cache.");
-  }
-  return Status::OK();
+  return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
 }
 
 // Perform leaf node cache transform identification
 Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) {
-  if (is_caching_) {
-    RETURN_STATUS_UNEXPECTED("There is currently no support for VOCOp under cache.");
-  }
-  return Status::OK();
+  return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
 }
 #endif

@@ -65,6 +65,24 @@ class CacheTransformPass : public TreePass {
     /// \param[inout] modified Indicator if the node was changed at all
     /// \return Status The error code return
     Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) override;
+
+    /// \brief Perform leaf node cache tranform identifications
+    /// \param[in] node The node being visited
+    /// \param[inout] modified Indicator if the node was changed at all
+    /// \return Status The error code return
+    Status RunOnNode(std::shared_ptr<ClueOp> node, bool *modified) override;
+
+    /// \brief Perform leaf node cache tranform identifications
+    /// \param[in] node The node being visited
+    /// \param[inout] modified Indicator if the node was changed at all
+    /// \return Status The error code return
+    Status RunOnNode(std::shared_ptr<CsvOp> node, bool *modified) override;
+
+    /// \brief Perform leaf node cache tranform identifications
+    /// \param[in] node The node being visited
+    /// \param[inout] modified Indicator if the node was changed at all
+    /// \return Status The error code return
+    Status RunOnNode(std::shared_ptr<TextFileOp> node, bool *modified) override;
 #endif
 
     /// \brief Perform leaf node cache tranform identifications

(File diff suppressed because it is too large.)

@@ -83,6 +83,9 @@ def check_mnist_cifar_dataset(method):
         check_sampler_shuffle_shard_options(param_dict)
 
+        cache = param_dict.get('cache')
+        check_cache_option(cache)
+
         return method(self, *args, **kwargs)
 
     return new_method
 
@@ -110,6 +113,9 @@ def check_manifestdataset(method):
         check_sampler_shuffle_shard_options(param_dict)
 
+        cache = param_dict.get('cache')
+        check_cache_option(cache)
+
         return method(self, *args, **kwargs)
 
     return new_method
 
@@ -180,6 +186,9 @@ def check_vocdataset(method):
         validate_dataset_param_value(nreq_param_dict, param_dict, dict)
         check_sampler_shuffle_shard_options(param_dict)
 
+        cache = param_dict.get('cache')
+        check_cache_option(cache)
+
         return method(self, *args, **kwargs)
 
     return new_method
 
@@ -216,6 +225,9 @@ def check_cocodataset(method):
             raise ValueError("CocoDataset doesn't support PKSampler")
         check_sampler_shuffle_shard_options(param_dict)
 
+        cache = param_dict.get('cache')
+        check_cache_option(cache)
+
         return method(self, *args, **kwargs)
 
     return new_method
 
@@ -252,6 +264,9 @@ def check_celebadataset(method):
         if sampler is not None and isinstance(sampler, samplers.PKSampler):
             raise ValueError("CelebADataset does not support PKSampler.")
 
+        cache = param_dict.get('cache')
+        check_cache_option(cache)
+
         return method(self, *args, **kwargs)
 
     return new_method
 
@@ -842,6 +857,9 @@ def check_cluedataset(method):
         validate_dataset_param_value(nreq_param_int, param_dict, int)
         check_sampler_shuffle_shard_options(param_dict)
 
+        cache = param_dict.get('cache')
+        check_cache_option(cache)
+
         return method(self, *args, **kwargs)
 
     return new_method
 
@@ -885,6 +903,9 @@ def check_csvdataset(method):
         validate_dataset_param_value(nreq_param_int, param_dict, int)
         check_sampler_shuffle_shard_options(param_dict)
 
+        cache = param_dict.get('cache')
+        check_cache_option(cache)
+
         return method(self, *args, **kwargs)
 
     return new_method
 
@@ -904,6 +925,9 @@ def check_textfiledataset(method):
         validate_dataset_param_value(nreq_param_int, param_dict, int)
         check_sampler_shuffle_shard_options(param_dict)
 
+        cache = param_dict.get('cache')
+        check_cache_option(cache)
+
         return method(self, *args, **kwargs)
 
     return new_method
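Each decorator now pulls the new `cache` keyword out of the parsed arguments and runs it through check_cache_option. That helper's body is outside this diff; a plausible minimal version (an assumption for illustration, not the real implementation) would simply type-check the option:

    import mindspore.dataset as ds

    def check_cache_option(cache):
        # Hypothetical sketch: accept None (no cache) or a DatasetCache instance.
        if cache is not None and not isinstance(cache, ds.DatasetCache):
            raise TypeError("cache should be a DatasetCache object")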

@@ -103,6 +103,24 @@ HandleRcExit $? 0 0
 PytestCmd "test_cache_map.py" "test_cache_map_epoch_ctrl" 1
 HandleRcExit $? 0 0
 
+PytestCmd "test_cache_map.py" "test_cache_map_coco" 1
+HandleRcExit $? 0 0
+
+PytestCmd "test_cache_map.py" "test_cache_map_mnist" 1
+HandleRcExit $? 0 0
+
+PytestCmd "test_cache_map.py" "test_cache_map_celeba" 1
+HandleRcExit $? 0 0
+
+PytestCmd "test_cache_map.py" "test_cache_map_manifest" 1
+HandleRcExit $? 0 0
+
+PytestCmd "test_cache_map.py" "test_cache_map_cifar" 1
+HandleRcExit $? 0 0
+
+PytestCmd "test_cache_map.py" "test_cache_map_voc" 1
+HandleRcExit $? 0 0
+
 # Run two parallel pipelines (sharing cache)
 for i in $(seq 1 2)
 do
@@ -282,6 +300,15 @@ HandleRcExit $? 0 0
 PytestCmd "test_cache_nomap.py" "test_cache_nomap_epoch_ctrl" 1
 HandleRcExit $? 0 0
 
+PytestCmd "test_cache_nomap.py" "test_cache_nomap_clue" 1
+HandleRcExit $? 0 0
+
+PytestCmd "test_cache_nomap.py" "test_cache_nomap_csv" 1
+HandleRcExit $? 0 0
+
+PytestCmd "test_cache_nomap.py" "test_cache_nomap_textfile" 1
+HandleRcExit $? 0 0
+
 for i in $(seq 1 3)
 do
     test_name="test_cache_nomap_multiple_cache${i}"

(File diff suppressed because it is too large.)

(File diff suppressed because it is too large.)

@@ -40,8 +40,9 @@ def test_textline_dataset_all_file():
     assert count == 5
 
 
-def test_textline_dataset_num_samples_zero():
-    data = ds.TextFileDataset(DATA_FILE, num_samples=0)
+def test_textline_dataset_num_samples_none():
+    # Do not provide a num_samples argument, so it would be None by default
+    data = ds.TextFileDataset(DATA_FILE)
     count = 0
     for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
         logger.info("{}".format(i["text"]))
@@ -208,7 +209,7 @@ def test_textline_dataset_exceptions():
 if __name__ == "__main__":
     test_textline_dataset_one_file()
     test_textline_dataset_all_file()
-    test_textline_dataset_num_samples_zero()
+    test_textline_dataset_num_samples_none()
     test_textline_dataset_shuffle_false4()
     test_textline_dataset_shuffle_false1()
     test_textline_dataset_shuffle_files4()
