Enable cache for other leaf ops

pull/6905/head
Lixia Chen 4 years ago
parent 983827ec5c
commit 477808e622

@@ -1075,7 +1075,7 @@ std::vector<std::shared_ptr<DatasetOp>> CLUEDataset::Build() {
std::shared_ptr<ClueOp> clue_op =
std::make_shared<ClueOp>(num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, ck_map,
- sorted_dataset_files, connector_que_size_, shuffle_files, num_shards_, shard_id_);
+ sorted_dataset_files, connector_que_size_, shuffle_files, num_shards_, shard_id_, nullptr);
RETURN_EMPTY_IF_ERROR(clue_op->Init());
if (shuffle_ == ShuffleMode::kGlobal) {
// Inject ShuffleOp
@@ -1256,7 +1256,7 @@ std::vector<std::shared_ptr<DatasetOp>> CSVDataset::Build() {
std::shared_ptr<CsvOp> csv_op = std::make_shared<CsvOp>(
sorted_dataset_files, field_delim_, column_default_list, column_names_, num_workers_, rows_per_buffer_,
- num_samples_, worker_connector_size_, connector_que_size_, shuffle_files, num_shards_, shard_id_);
+ num_samples_, worker_connector_size_, connector_que_size_, shuffle_files, num_shards_, shard_id_, nullptr);
RETURN_EMPTY_IF_ERROR(csv_op->Init());
if (shuffle_ == ShuffleMode::kGlobal) {
// Inject ShuffleOp
@@ -1502,7 +1502,7 @@ std::vector<std::shared_ptr<DatasetOp>> TextFileDataset::Build() {
// Create and initialize TextFileOp
std::shared_ptr<TextFileOp> text_file_op = std::make_shared<TextFileOp>(
num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, std::move(schema), sorted_dataset_files,
- connector_que_size_, shuffle_files, num_shards_, shard_id_, std::move(nullptr));
+ connector_que_size_, shuffle_files, num_shards_, shard_id_, nullptr);
RETURN_EMPTY_IF_ERROR(text_file_op->Init());
if (shuffle_ == ShuffleMode::kGlobal) {

File diff suppressed because it is too large

@@ -70,13 +70,12 @@ Status AlbumOp::Builder::SanityCheck() {
AlbumOp::AlbumOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size, bool do_decode,
const std::set<std::string> &exts, std::unique_ptr<DataSchema> data_schema,
std::shared_ptr<Sampler> sampler)
- : ParallelOp(num_wkrs, queue_size),
+ : ParallelOp(num_wkrs, queue_size, std::move(sampler)),
rows_per_buffer_(rows_per_buffer),
folder_path_(file_dir),
decode_(do_decode),
extensions_(exts),
data_schema_(std::move(data_schema)),
- sampler_(std::move(sampler)),
row_cnt_(0),
buf_cnt_(0),
sampler_ind_(0),

@@ -284,7 +284,6 @@ class AlbumOp : public ParallelOp, public RandomAccessOp {
std::set<std::string> extensions_; // extensions allowed
std::unordered_map<std::string, int32_t> col_name_map_;
std::unique_ptr<DataSchema> data_schema_;
- std::shared_ptr<Sampler> sampler_;
int64_t row_cnt_;
int64_t buf_cnt_;
int64_t sampler_ind_;
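The same pattern repeats for every mappable leaf touched by this commit: the op-local sampler_ member is deleted and the sampler is forwarded up to ParallelOp, so the base class owns it once. A minimal self-contained sketch of the assumed shape (toy names, not the real MindSpore classes):

#include <cstdint>
#include <memory>
#include <utility>

struct Sampler {};  // toy stand-in for mindspore::dataset::Sampler

// Assumed post-commit shape: the ParallelOp base owns the (optional) sampler
// once, instead of each leaf op (AlbumOp, CocoOp, ...) carrying its own copy.
class ParallelOpSketch {
 public:
  ParallelOpSketch(int32_t num_workers, int32_t queue_size, std::shared_ptr<Sampler> sampler)
      : num_workers_(num_workers), queue_size_(queue_size), sampler_(std::move(sampler)) {}
  virtual ~ParallelOpSketch() = default;

 protected:
  int32_t num_workers_;
  int32_t queue_size_;
  std::shared_ptr<Sampler> sampler_;  // nullptr is legal for non-mappable leaves
};

class AlbumLikeOp : public ParallelOpSketch {
 public:
  AlbumLikeOp(int32_t num_workers, int32_t queue_size, std::shared_ptr<Sampler> sampler)
      : ParallelOpSketch(num_workers, queue_size, std::move(sampler)) {}
};

int main() {
  AlbumLikeOp op(4, 16, nullptr);  // same call shape as the Build() changes above
  return 0;
}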

@@ -25,13 +25,18 @@
#include "minddata/dataset/util/task_manager.h"
#include "minddata/dataset/engine/jagged_connector.h"
#include "minddata/dataset/engine/execution_tree.h"
+ #include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/engine/datasetops/source/io_block.h"
#include "minddata/dataset/util/random.h"
namespace mindspore {
namespace dataset {
ClueOp::Builder::Builder()
- : builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) {
+ : builder_device_id_(0),
+ builder_num_devices_(1),
+ builder_num_samples_(0),
+ builder_shuffle_files_(false),
+ builder_sampler_(nullptr) {
std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
builder_num_workers_ = config_manager->num_parallel_workers();
builder_op_connector_size_ = config_manager->op_connector_size();
@@ -68,7 +73,7 @@ Status ClueOp::Builder::Build(std::shared_ptr<ClueOp> *op) {
std::shared_ptr<ClueOp> clue_op = std::make_shared<ClueOp>(
builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_, ck_map,
builder_clue_files_list_, builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_,
- builder_device_id_);
+ builder_device_id_, std::move(builder_sampler_));
RETURN_IF_NOT_OK(clue_op->Init());
*op = std::move(clue_op);
@@ -88,8 +93,8 @@ std::vector<std::string> ClueOp::Builder::split(const std::string &s, char delim
ClueOp::ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size,
- bool shuffle_files, int32_t num_device, int32_t device_id)
- : ParallelOp(num_workers, op_connector_size),
+ bool shuffle_files, int32_t num_device, int32_t device_id, std::shared_ptr<Sampler> sampler)
+ : ParallelOp(num_workers, op_connector_size, std::move(sampler)),
rows_per_buffer_(rows_per_buffer),
num_rows_per_shard_(0),
all_num_rows_(0),
@@ -539,5 +544,21 @@ Status ClueOp::ComputeColMap() {
}
return Status::OK();
}
+ // Brief If a cache has been added into the ascendant tree over this clue op, then the cache will be executing
+ // a sampler for fetching the data. As such, any options in the clue op need to be reset to its defaults so
+ // that this clue op will produce the full set of data into the cache.
+ void ClueOp::MakeSimpleProducer() {
+ device_id_ = 0;
+ num_devices_ = 1;
+ shuffle_files_ = false;
+ num_samples_ = 0;
+ }
+ // Visitor accept method for NodePass
+ Status ClueOp::Accept(NodePass *p, bool *modified) {
+ // Downcast shared pointer then call visitor
+ return p->RunOnNode(shared_from_base<ClueOp>(), modified);
+ }
} // namespace dataset
} // namespace mindspore
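Accept is the visitor double-dispatch hook: the op recovers its own concrete type (via shared_from_base in MindSpore) and hands the typed pointer back to the pass, which can then act per op kind. A self-contained toy of the mechanism (illustrative names, not the real API):

#include <iostream>
#include <memory>

class ClueLikeOp;  // toy stand-in for ClueOp

// Toy stand-in for NodePass: one virtual hook per concrete op type.
class ToyPass {
 public:
  virtual ~ToyPass() = default;
  virtual void RunOnNode(std::shared_ptr<ClueLikeOp> node) = 0;
};

class ToyOp : public std::enable_shared_from_this<ToyOp> {
 public:
  virtual ~ToyOp() = default;
  virtual void Accept(ToyPass *p) = 0;
};

class ClueLikeOp : public ToyOp {
 public:
  void Accept(ToyPass *p) override {
    // Double dispatch: recover the derived type, then let the pass choose a hook.
    p->RunOnNode(std::static_pointer_cast<ClueLikeOp>(shared_from_this()));
  }
};

class ToyCachePass : public ToyPass {
 public:
  void RunOnNode(std::shared_ptr<ClueLikeOp> node) override {
    std::cout << "visited a ClueLikeOp leaf" << std::endl;
  }
};

int main() {
  std::shared_ptr<ToyOp> op = std::make_shared<ClueLikeOp>();
  ToyCachePass pass;
  op->Accept(&pass);  // prints: visited a ClueLikeOp leaf
  return 0;
}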

@@ -20,6 +20,7 @@
#include <map>
#include <mutex>
#include <string>
+ #include <utility>
#include <vector>
#include <nlohmann/json.hpp>
@@ -122,6 +123,14 @@ class ClueOp : public ParallelOp {
// @return - a string vector
std::vector<std::string> split(const std::string &s, char delim);
+ // Setter method
+ // @param std::shared_ptr<Sampler> sampler
+ // @return Builder setter method returns reference to the builder.
+ Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
+ builder_sampler_ = std::move(sampler);
+ return *this;
+ }
private:
int32_t builder_device_id_;
int32_t builder_num_devices_;
@@ -133,12 +142,13 @@ class ClueOp : public ParallelOp {
std::vector<std::string> builder_clue_files_list_;
bool builder_shuffle_files_;
std::map<std::string, std::string> builder_cols_to_keyword_;
+ std::shared_ptr<Sampler> builder_sampler_;
};
// Constructor of ClueOp
ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size,
- bool shuffle_files, int32_t num_devices, int32_t device_id);
+ bool shuffle_files, int32_t num_devices, int32_t device_id, std::shared_ptr<Sampler> sampler);
// Default destructor
~ClueOp() = default;
@@ -173,6 +183,17 @@ class ClueOp : public ParallelOp {
// @return Vector of the input file names
std::vector<std::string> FileNames() { return clue_files_list_; }
+ /// \brief If a cache has been added into the ascendant tree over this clue op, then the cache will be executing
+ /// a sampler for fetching the data. As such, any options in the clue op need to be reset to its defaults so
+ /// that this clue op will produce the full set of data into the cache.
+ void MakeSimpleProducer();
+ // Base-class override for NodePass visitor acceptor.
+ // @param p - Pointer to the NodePass to be accepted.
+ // @param modified - Whether this node visit modified the pipeline.
+ // @return - Status of the node visit.
+ Status Accept(NodePass *p, bool *modified) override;
private:
// The entry point for when workers are launched.
// @param worker_id - the id of the worker that is executing this function.
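SetSampler follows the fluent-builder convention the other setters already use: move the argument into a builder field and return *this so calls chain. A hypothetical call site under that convention (ToyBuilder and SetNumWorkers are illustrative; only the SetSampler shape comes from the diff):

#include <cstdint>
#include <memory>
#include <utility>

struct Sampler {};  // toy

class ToyBuilder {
 public:
  ToyBuilder &SetNumWorkers(int32_t n) {
    num_workers_ = n;
    return *this;
  }
  ToyBuilder &SetSampler(std::shared_ptr<Sampler> sampler) {
    sampler_ = std::move(sampler);
    return *this;
  }

 private:
  int32_t num_workers_ = 1;
  std::shared_ptr<Sampler> sampler_;
};

int main() {
  // A null sampler is a valid default; a cache injected above the op can
  // supply the real sampling behaviour later.
  ToyBuilder builder;
  builder.SetNumWorkers(4).SetSampler(nullptr);
  return 0;
}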

@@ -124,7 +124,7 @@ Status CocoOp::Builder::SanityCheck() {
CocoOp::CocoOp(const TaskType &task_type, const std::string &image_folder_path, const std::string &annotation_path,
int32_t num_workers, int32_t rows_per_buffer, int32_t queue_size, bool decode,
std::unique_ptr<DataSchema> data_schema, std::shared_ptr<Sampler> sampler)
- : ParallelOp(num_workers, queue_size),
+ : ParallelOp(num_workers, queue_size, std::move(sampler)),
decode_(decode),
row_cnt_(0),
buf_cnt_(0),
@@ -132,7 +132,6 @@ CocoOp::CocoOp(const TaskType &task_type, const std::string &image_folder_path,
image_folder_path_(image_folder_path),
annotation_path_(annotation_path),
rows_per_buffer_(rows_per_buffer),
- sampler_(std::move(sampler)),
data_schema_(std::move(data_schema)) {
io_block_queues_.Init(num_workers_, queue_size);
}

@@ -206,6 +206,10 @@ class CocoOp : public ParallelOp, public RandomAccessOp {
/// \return Status of the node visit
Status Accept(NodePass *p, bool *modified) override;
+ // Op name getter
+ // @return Name of the current Op
+ std::string Name() const override { return "CocoOp"; }
private:
// Initialize Sampler, calls sampler->Init() within
// @return Status - The error code return
@@ -324,7 +328,6 @@ class CocoOp : public ParallelOp, public RandomAccessOp {
std::string annotation_path_;
TaskType task_type_;
int32_t rows_per_buffer_;
- std::shared_ptr<Sampler> sampler_;
std::unique_ptr<DataSchema> data_schema_;
WaitPost wp_;

@@ -22,12 +22,17 @@
#include "minddata/dataset/core/config_manager.h"
#include "minddata/dataset/engine/jagged_connector.h"
#include "minddata/dataset/engine/execution_tree.h"
+ #include "minddata/dataset/engine/opt/pass.h"
#include "minddata/dataset/util/random.h"
namespace mindspore {
namespace dataset {
CsvOp::Builder::Builder()
- : builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) {
+ : builder_device_id_(0),
+ builder_num_devices_(1),
+ builder_num_samples_(0),
+ builder_shuffle_files_(false),
+ builder_sampler_(nullptr) {
std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
builder_num_workers_ = config_manager->num_parallel_workers();
builder_op_connector_size_ = config_manager->op_connector_size();
@@ -59,7 +64,8 @@ Status CsvOp::Builder::Build(std::shared_ptr<CsvOp> *op) {
std::shared_ptr<CsvOp> csv_op = std::make_shared<CsvOp>(
builder_csv_files_list_, builder_field_delim_, builder_column_default_list_, builder_column_name_list_,
builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_,
- builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_, builder_device_id_);
+ builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_, builder_device_id_,
+ std::move(builder_sampler_));
RETURN_IF_NOT_OK(csv_op->Init());
*op = std::move(csv_op);
@@ -70,8 +76,8 @@ CsvOp::CsvOp(const std::vector<std::string> &csv_files_list, char field_delim,
const std::vector<std::shared_ptr<BaseRecord>> &column_default,
const std::vector<std::string> &column_name, int32_t num_workers, int64_t rows_per_buffer,
int64_t num_samples, int32_t worker_connector_size, int32_t op_connector_size, bool shuffle_files,
- int32_t num_device, int32_t device_id)
- : ParallelOp(num_workers, op_connector_size),
+ int32_t num_device, int32_t device_id, std::shared_ptr<Sampler> sampler)
+ : ParallelOp(num_workers, op_connector_size, std::move(sampler)),
csv_files_list_(std::move(csv_files_list)),
field_delim_(field_delim),
column_default_list_(column_default),
@@ -889,5 +895,21 @@ Status CsvOp::ComputeColMap() {
}
return Status::OK();
}
+ // Brief If a cache has been added into the ascendant tree over this csv op, then the cache will be executing
+ // a sampler for fetching the data. As such, any options in the csv op need to be reset to its defaults so
+ // that this csv op will produce the full set of data into the cache.
+ void CsvOp::MakeSimpleProducer() {
+ device_id_ = 0;
+ num_devices_ = 1;
+ shuffle_files_ = false;
+ num_samples_ = 0;
+ }
+ // Visitor accept method for NodePass
+ Status CsvOp::Accept(NodePass *p, bool *modified) {
+ // Downcast shared pointer then call visitor
+ return p->RunOnNode(shared_from_base<CsvOp>(), modified);
+ }
} // namespace dataset
} // namespace mindspore

@@ -240,6 +240,14 @@ class CsvOp : public ParallelOp {
return *this;
}
+ // Setter method
+ // @param std::shared_ptr<Sampler> sampler
+ // @return Builder setter method returns reference to the builder.
+ Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
+ builder_sampler_ = std::move(sampler);
+ return *this;
+ }
private:
int32_t builder_device_id_;
int32_t builder_num_devices_;
@@ -253,6 +261,7 @@ class CsvOp : public ParallelOp {
char builder_field_delim_;
std::vector<std::shared_ptr<CsvOp::BaseRecord>> builder_column_default_list_;
std::vector<std::string> builder_column_name_list_;
+ std::shared_ptr<Sampler> builder_sampler_;
};
// Constructor of CsvOp
@@ -261,7 +270,8 @@ class CsvOp : public ParallelOp {
CsvOp(const std::vector<std::string> &csv_files_list, char field_delim,
const std::vector<std::shared_ptr<BaseRecord>> &column_default, const std::vector<std::string> &column_name,
int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
- int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id);
+ int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id,
+ std::shared_ptr<Sampler> sampler);
// Default destructor
~CsvOp() = default;
@@ -297,6 +307,17 @@ class CsvOp : public ParallelOp {
// @return Vector of the input file names
std::vector<std::string> FileNames() { return csv_files_list_; }
+ /// \brief If a cache has been added into the ascendant tree over this csv op, then the cache will be executing
+ /// a sampler for fetching the data. As such, any options in the csv op need to be reset to its defaults so
+ /// that this csv op will produce the full set of data into the cache.
+ void MakeSimpleProducer();
+ // Base-class override for NodePass visitor acceptor.
+ // @param p - Pointer to the NodePass to be accepted.
+ // @param modified - Whether this node visit modified the pipeline.
+ // @return - Status of the node visit.
+ Status Accept(NodePass *p, bool *modified) override;
private:
// The entry point for when workers are launched.
// @param worker_id - the id of the worker that is executing this function.

@@ -29,6 +29,7 @@
#include "minddata/dataset/util/random.h"
#include "minddata/dataset/engine/datasetops/source/io_block.h"
#include "minddata/dataset/engine/execution_tree.h"
+ #include "minddata/dataset/engine/opt/pass.h"
namespace mindspore {
namespace dataset {
@@ -499,5 +500,21 @@ Status TextFileOp::ComputeColMap() {
}
return Status::OK();
}
+ // Brief If a cache has been added into the ascendant tree over this text file op, then the cache will be executing
+ // a sampler for fetching the data. As such, any options in the text file op need to be reset to its defaults so
+ // that this text file op will produce the full set of data into the cache.
+ void TextFileOp::MakeSimpleProducer() {
+ device_id_ = 0;
+ num_devices_ = 1;
+ shuffle_files_ = false;
+ total_rows_ = 0;
+ }
+ // Visitor accept method for NodePass
+ Status TextFileOp::Accept(NodePass *p, bool *modified) {
+ // Downcast shared pointer then call visitor
+ return p->RunOnNode(shared_from_base<TextFileOp>(), modified);
+ }
} // namespace dataset
} // namespace mindspore

@@ -188,6 +188,17 @@ class TextFileOp : public ParallelOp {
// @return Vector of the input file names
std::vector<std::string> FileNames() { return text_files_list_; }
+ /// \brief If a cache has been added into the ascendant tree over this text file op, then the cache will be executing
+ /// a sampler for fetching the data. As such, any options in the text file op need to be reset to its defaults so
+ /// that this text file op will produce the full set of data into the cache.
+ void MakeSimpleProducer();
+ // Base-class override for NodePass visitor acceptor.
+ // @param p - Pointer to the NodePass to be accepted.
+ // @param modified - Whether this node visit modified the pipeline.
+ // @return - Status of the node visit.
+ Status Accept(NodePass *p, bool *modified) override;
private:
// The entry point for when workers are launched.
// @param worker_id - the id of the worker that is executing this function.

@@ -212,6 +212,7 @@ Status VOCOp::LoadTensorRow(row_id_type row_id, const std::string &image_id, Ten
folder_path_ + std::string(kAnnotationsFolder) + image_id + std::string(kAnnotationExtension);
RETURN_IF_NOT_OK(ReadImageToTensor(kImageFile, data_schema_->column(0), &image));
RETURN_IF_NOT_OK(ReadAnnotationToTensor(kAnnotationFile, &annotation));
+ trow->setId(row_id);
trow->push_back(std::move(image));
trow->insert(trow->end(), annotation.begin(), annotation.end());
}

@@ -45,6 +45,9 @@
#include "minddata/dataset/engine/datasetops/source/random_data_op.h"
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h"
+ #include "minddata/dataset/engine/datasetops/source/clue_op.h"
+ #include "minddata/dataset/engine/datasetops/source/csv_op.h"
+ #include "minddata/dataset/engine/datasetops/source/text_file_op.h"
#endif
#include "minddata/dataset/engine/datasetops/source/voc_op.h"
#ifdef ENABLE_PYTHON
@@ -260,6 +263,21 @@ Status NodePass::RunOnNode(std::shared_ptr<CacheLookupOp> node, bool *modified)
return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
}
+ Status NodePass::RunOnNode(std::shared_ptr<ClueOp> node, bool *modified) {
+ // Fallback to base class visitor by default
+ return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
+ }
+ Status NodePass::RunOnNode(std::shared_ptr<CsvOp> node, bool *modified) {
+ // Fallback to base class visitor by default
+ return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
+ }
+ Status NodePass::RunOnNode(std::shared_ptr<TextFileOp> node, bool *modified) {
+ // Fallback to base class visitor by default
+ return RunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
+ }
Status NodePass::PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified) {
// Fallback to base class visitor by default
return PreRunOnNode(std::static_pointer_cast<DatasetOp>(node), modified);
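Each new default RunOnNode upcasts and re-dispatches, so a pass that does not override the ClueOp/CsvOp/TextFileOp hooks still receives the generic DatasetOp visit. The routing is ordinary overload resolution on the shared_ptr type, as in this toy:

#include <iostream>
#include <memory>

struct BaseOp {};
struct LeafOp : BaseOp {};

// Mirrors the NodePass fallbacks: the LeafOp overload forwards to the BaseOp
// overload by upcasting the shared_ptr.
void RunOnNode(std::shared_ptr<BaseOp> node) { std::cout << "generic visit" << std::endl; }
void RunOnNode(std::shared_ptr<LeafOp> node) {
  RunOnNode(std::static_pointer_cast<BaseOp>(node));  // fallback to base visitor
}

int main() {
  RunOnNode(std::make_shared<LeafOp>());  // prints: generic visit
  return 0;
}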

@@ -81,6 +81,12 @@ class CacheMergeOp;
class CacheLookupOp;
class BuildSentencePieceVocabOp;
+ class ClueOp;
+ class CsvOp;
+ class TextFileOp;
#endif
#ifdef ENABLE_PYTHON
@@ -211,6 +217,12 @@ class NodePass : public Pass {
virtual Status RunOnNode(std::shared_ptr<CacheOp> node, bool *modified);
+ virtual Status RunOnNode(std::shared_ptr<ClueOp> node, bool *modified);
+ virtual Status RunOnNode(std::shared_ptr<CsvOp> node, bool *modified);
+ virtual Status RunOnNode(std::shared_ptr<TextFileOp> node, bool *modified);
virtual Status PreRunOnNode(std::shared_ptr<CacheOp> node, bool *modified);
virtual Status PreRunOnNode(std::shared_ptr<CacheMergeOp> node, bool *modified);

@@ -36,6 +36,9 @@
#ifndef ENABLE_ANDROID
#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h"
+ #include "minddata/dataset/engine/datasetops/source/clue_op.h"
+ #include "minddata/dataset/engine/datasetops/source/csv_op.h"
+ #include "minddata/dataset/engine/datasetops/source/text_file_op.h"
#endif
#ifdef ENABLE_PYTHON
@@ -141,6 +144,36 @@ Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<TFReaderOp> node
}
return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
+ // Perform leaf node cache transform identification
+ Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<ClueOp> node, bool *modified) {
+ if (is_caching_) {
+ // If we are a ClueOp in a caching tree, then change our config so that it becomes a basic
+ // ClueOp that parses all files. Selection of data will come from the sampler on the cache instead.
+ node->MakeSimpleProducer();
+ }
+ return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
+ }
+ // Perform leaf node cache transform identification
+ Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CsvOp> node, bool *modified) {
+ if (is_caching_) {
+ // If we are a CsvOp in a caching tree, then change our config so that it becomes a basic
+ // CsvOp that parses all files. Selection of data will come from the sampler on the cache instead.
+ node->MakeSimpleProducer();
+ }
+ return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
+ }
+ // Perform leaf node cache transform identification
+ Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<TextFileOp> node, bool *modified) {
+ if (is_caching_) {
+ // If we are a TextFileOp in a caching tree, then change our config so that it becomes a basic
+ // TextFileOp that parses all files. Selection of data will come from the sampler on the cache instead.
+ node->MakeSimpleProducer();
+ }
+ return NonMappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
+ }
#endif
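When a cache sits above one of these non-mappable leaves, the leaf must feed the cache every row exactly once; sharding, shuffling, and row limits are applied afterwards by the cache's sampler. The toy below shows the reset the three new MakeSimpleProducer() bodies perform (field names mirror the diff; treating 0 as "no row limit" follows this codebase's convention):

#include <cassert>
#include <cstdint>

// Toy config holding the fields that MakeSimpleProducer() resets.
struct LeafProducerConfig {
  int32_t device_id = 3;       // shard assigned to this worker
  int32_t num_devices = 8;     // total shards
  bool shuffle_files = true;
  int64_t num_samples = 1000;  // row limit requested by the user

  // Same idea as ClueOp/CsvOp/TextFileOp::MakeSimpleProducer(): one logical
  // device, no shuffle, no limit, so the cache receives the full dataset.
  void MakeSimpleProducer() {
    device_id = 0;
    num_devices = 1;
    shuffle_files = false;
    num_samples = 0;  // 0 == read everything
  }
};

int main() {
  LeafProducerConfig cfg;
  cfg.MakeSimpleProducer();
  assert(cfg.num_devices == 1 && !cfg.shuffle_files && cfg.num_samples == 0);
  return 0;
}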
// Perform leaf node cache transform identification
@@ -163,34 +196,22 @@ Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<AlbumOp> node, b
// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<MnistOp> node, bool *modified) {
- if (is_caching_) {
- RETURN_STATUS_UNEXPECTED("There is currently no support for MnistOp under cache.");
- }
- return Status::OK();
+ return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CifarOp> node, bool *modified) {
- if (is_caching_) {
- RETURN_STATUS_UNEXPECTED("There is currently no support for CifarOp under cache.");
- }
- return Status::OK();
+ return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CocoOp> node, bool *modified) {
- if (is_caching_) {
- RETURN_STATUS_UNEXPECTED("There is currently no support for CocoOp under cache.");
- }
- return Status::OK();
+ return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<CelebAOp> node, bool *modified) {
- if (is_caching_) {
- RETURN_STATUS_UNEXPECTED("There is currently no support for CelebAOp under cache.");
- }
- return Status::OK();
+ return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
#ifndef ENABLE_ANDROID
@@ -214,18 +235,12 @@ Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<GeneratorOp> nod
// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<ManifestOp> node, bool *modified) {
- if (is_caching_) {
- RETURN_STATUS_UNEXPECTED("There is currently no support for ManifestOp under cache.");
- }
- return Status::OK();
+ return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
// Perform leaf node cache transform identification
Status CacheTransformPass::CachePass::RunOnNode(std::shared_ptr<VOCOp> node, bool *modified) {
- if (is_caching_) {
- RETURN_STATUS_UNEXPECTED("There is currently no support for VOCOp under cache.");
- }
- return Status::OK();
+ return MappableCacheLeafSetup(std::static_pointer_cast<DatasetOp>(node));
}
#endif
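The flip above, from RETURN_STATUS_UNEXPECTED to MappableCacheLeafSetup, is what actually enables caching over the mappable leaves. Sketched under one assumption, that the real MappableCacheLeafSetup lets the cache take over the leaf's existing sampler, while non-mappable leaves are instead simplified into full-scan producers:

#include <iostream>

struct ToyLeaf {
  bool mappable = false;  // MnistOp/CocoOp/VOCOp/... vs ClueOp/CsvOp/TextFileOp
};

// Toy version of the pass logic after this commit: both leaf kinds are
// accepted under cache, they just take different setup paths.
void CacheLeafSetup(const ToyLeaf &leaf, bool is_caching) {
  if (!is_caching) return;
  if (leaf.mappable) {
    std::cout << "mappable leaf: cache reuses the leaf's sampler" << std::endl;
  } else {
    std::cout << "non-mappable leaf: MakeSimpleProducer(), cache samples rows" << std::endl;
  }
}

int main() {
  CacheLeafSetup(ToyLeaf{true}, true);   // mappable path
  CacheLeafSetup(ToyLeaf{false}, true);  // non-mappable path
  return 0;
}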

@@ -65,6 +65,24 @@ class CacheTransformPass : public TreePass {
/// \param[inout] modified Indicator if the node was changed at all
/// \return Status The error code return
Status RunOnNode(std::shared_ptr<TFReaderOp> node, bool *modified) override;
+ /// \brief Perform leaf node cache transform identifications
+ /// \param[in] node The node being visited
+ /// \param[inout] modified Indicator if the node was changed at all
+ /// \return Status The error code return
+ Status RunOnNode(std::shared_ptr<ClueOp> node, bool *modified) override;
+ /// \brief Perform leaf node cache transform identifications
+ /// \param[in] node The node being visited
+ /// \param[inout] modified Indicator if the node was changed at all
+ /// \return Status The error code return
+ Status RunOnNode(std::shared_ptr<CsvOp> node, bool *modified) override;
+ /// \brief Perform leaf node cache transform identifications
+ /// \param[in] node The node being visited
+ /// \param[inout] modified Indicator if the node was changed at all
+ /// \return Status The error code return
+ Status RunOnNode(std::shared_ptr<TextFileOp> node, bool *modified) override;
#endif
/// \brief Perform leaf node cache transform identifications

File diff suppressed because it is too large

@@ -83,6 +83,9 @@ def check_mnist_cifar_dataset(method):
check_sampler_shuffle_shard_options(param_dict)
+ cache = param_dict.get('cache')
+ check_cache_option(cache)
return method(self, *args, **kwargs)
return new_method
@@ -110,6 +113,9 @@ def check_manifestdataset(method):
check_sampler_shuffle_shard_options(param_dict)
+ cache = param_dict.get('cache')
+ check_cache_option(cache)
return method(self, *args, **kwargs)
return new_method
@@ -180,6 +186,9 @@ def check_vocdataset(method):
validate_dataset_param_value(nreq_param_dict, param_dict, dict)
check_sampler_shuffle_shard_options(param_dict)
+ cache = param_dict.get('cache')
+ check_cache_option(cache)
return method(self, *args, **kwargs)
return new_method
@@ -216,6 +225,9 @@ def check_cocodataset(method):
raise ValueError("CocoDataset doesn't support PKSampler")
check_sampler_shuffle_shard_options(param_dict)
+ cache = param_dict.get('cache')
+ check_cache_option(cache)
return method(self, *args, **kwargs)
return new_method
@@ -252,6 +264,9 @@ def check_celebadataset(method):
if sampler is not None and isinstance(sampler, samplers.PKSampler):
raise ValueError("CelebADataset does not support PKSampler.")
+ cache = param_dict.get('cache')
+ check_cache_option(cache)
return method(self, *args, **kwargs)
return new_method
@@ -842,6 +857,9 @@ def check_cluedataset(method):
validate_dataset_param_value(nreq_param_int, param_dict, int)
check_sampler_shuffle_shard_options(param_dict)
+ cache = param_dict.get('cache')
+ check_cache_option(cache)
return method(self, *args, **kwargs)
return new_method
@@ -885,6 +903,9 @@ def check_csvdataset(method):
validate_dataset_param_value(nreq_param_int, param_dict, int)
check_sampler_shuffle_shard_options(param_dict)
+ cache = param_dict.get('cache')
+ check_cache_option(cache)
return method(self, *args, **kwargs)
return new_method
@@ -904,6 +925,9 @@ def check_textfiledataset(method):
validate_dataset_param_value(nreq_param_int, param_dict, int)
check_sampler_shuffle_shard_options(param_dict)
+ cache = param_dict.get('cache')
+ check_cache_option(cache)
return method(self, *args, **kwargs)
return new_method

@@ -103,6 +103,24 @@ HandleRcExit $? 0 0
PytestCmd "test_cache_map.py" "test_cache_map_epoch_ctrl" 1
HandleRcExit $? 0 0
+ PytestCmd "test_cache_map.py" "test_cache_map_coco" 1
+ HandleRcExit $? 0 0
+ PytestCmd "test_cache_map.py" "test_cache_map_mnist" 1
+ HandleRcExit $? 0 0
+ PytestCmd "test_cache_map.py" "test_cache_map_celeba" 1
+ HandleRcExit $? 0 0
+ PytestCmd "test_cache_map.py" "test_cache_map_manifest" 1
+ HandleRcExit $? 0 0
+ PytestCmd "test_cache_map.py" "test_cache_map_cifar" 1
+ HandleRcExit $? 0 0
+ PytestCmd "test_cache_map.py" "test_cache_map_voc" 1
+ HandleRcExit $? 0 0
# Run two parallel pipelines (sharing cache)
for i in $(seq 1 2)
do
@@ -282,6 +300,15 @@ HandleRcExit $? 0 0
PytestCmd "test_cache_nomap.py" "test_cache_nomap_epoch_ctrl" 1
HandleRcExit $? 0 0
+ PytestCmd "test_cache_nomap.py" "test_cache_nomap_clue" 1
+ HandleRcExit $? 0 0
+ PytestCmd "test_cache_nomap.py" "test_cache_nomap_csv" 1
+ HandleRcExit $? 0 0
+ PytestCmd "test_cache_nomap.py" "test_cache_nomap_textfile" 1
+ HandleRcExit $? 0 0
for i in $(seq 1 3)
do
test_name="test_cache_nomap_multiple_cache${i}"

File diff suppressed because it is too large

File diff suppressed because it is too large

@@ -40,8 +40,9 @@ def test_textline_dataset_all_file():
assert count == 5
- def test_textline_dataset_num_samples_zero():
- data = ds.TextFileDataset(DATA_FILE, num_samples=0)
+ def test_textline_dataset_num_samples_none():
+ # Do not provide a num_samples argument, so it would be None by default
+ data = ds.TextFileDataset(DATA_FILE)
count = 0
for i in data.create_dict_iterator(num_epochs=1, output_numpy=True):
logger.info("{}".format(i["text"]))
@@ -208,7 +209,7 @@ def test_textline_dataset_exceptions():
if __name__ == "__main__":
test_textline_dataset_one_file()
test_textline_dataset_all_file()
- test_textline_dataset_num_samples_zero()
+ test_textline_dataset_num_samples_none()
test_textline_dataset_shuffle_false4()
test_textline_dataset_shuffle_false1()
test_textline_dataset_shuffle_files4()
