!14633 Remove rows_per_buffer from MindData

From: @hfarahat
Reviewed-by: @robingrosman,@pandoublefeng
Signed-off-by: @pandoublefeng
pull/14633/MERGE
Committed 4 years ago by mindspore-ci-bot via Gitee
commit 98307c10db
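In short, the commit drops the rows-per-buffer knob everywhere: the ConfigManager getter/setter and their pybind/JSON exposure go away, operator builders lose SetRowsPerBuffer, and the sampling loops stop packing row ids into fixed-size blocks. The sketch below is illustrative only (a stand-in class, not the MindSpore headers); it shows the caller-visible effect on configuration code.

```cpp
// Illustrative stand-in, not MindSpore source: only knobs kept by this commit are shown.
#include <cstdint>
#include <memory>

struct ConfigManagerSketch {
  void set_num_parallel_workers(int32_t n) { num_parallel_workers_ = n; }
  void set_op_connector_size(int32_t n) { op_connector_size_ = n; }
  int32_t num_parallel_workers_ = 8;
  int32_t op_connector_size_ = 16;
};

int main() {
  auto cfg = std::make_shared<ConfigManagerSketch>();
  // cfg->set_rows_per_buffer(64);  // gone after this commit; the pybind get/set bindings
  //                                // and the "rowsPerBuffer" JSON key are removed as well
  cfg->set_num_parallel_workers(8);
  cfg->set_op_connector_size(16);
  return 0;
}
```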

@@ -42,7 +42,6 @@ PYBIND_REGISTER(ConfigManager, 0, ([](const py::module *m) {
   .def("get_numa_enable", &ConfigManager::numa_enable)
   .def("set_numa_enable", &ConfigManager::set_numa_enable)
   .def("get_op_connector_size", &ConfigManager::op_connector_size)
-  .def("get_rows_per_buffer", &ConfigManager::rows_per_buffer)
   .def("get_seed", &ConfigManager::seed)
   .def("set_rank_id", &ConfigManager::set_rank_id)
   .def("get_worker_connector_size", &ConfigManager::worker_connector_size)
@@ -54,7 +53,6 @@ PYBIND_REGISTER(ConfigManager, 0, ([](const py::module *m) {
   .def("get_profiler_file_status", &ConfigManager::get_profiler_file_status)
   .def("set_num_parallel_workers", &ConfigManager::set_num_parallel_workers)
   .def("set_op_connector_size", &ConfigManager::set_op_connector_size)
-  .def("set_rows_per_buffer", &ConfigManager::set_rows_per_buffer)
   .def("set_seed", &ConfigManager::set_seed)
   .def("set_worker_connector_size", &ConfigManager::set_worker_connector_size)
   .def("load", [](ConfigManager &c, std::string s) { THROW_IF_ERROR(c.LoadFile(s)); });

@@ -31,8 +31,7 @@
 namespace mindspore {
 namespace dataset {
 ConfigManager::ConfigManager()
-    : rows_per_buffer_(kCfgRowsPerBuffer),
-      num_parallel_workers_(kCfgParallelWorkers),
+    : num_parallel_workers_(kCfgParallelWorkers),
       worker_connector_size_(kCfgWorkerConnectorSize),
       op_connector_size_(kCfgOpConnectorSize),
       rank_id_(kCfgDefaultRankId),
@@ -70,7 +69,6 @@ void ConfigManager::Print(std::ostream &out) const {
   // Don't show the test/internal ones. Only display the main ones here.
   // fyi, boolalpha tells the output stream to write "true" and "false" for bools
   out << "\nClient config settings :"
-      << "\nDataCache Rows per buffer : " << rows_per_buffer_
       << "\nParallelOp workers : " << num_parallel_workers_
       << "\nParallelOp worker connector size : " << worker_connector_size_
       << "\nSize of each Connector : " << op_connector_size_ << std::endl;
@@ -78,7 +76,6 @@ void ConfigManager::Print(std::ostream &out) const {
 // Private helper function that takes a nlohmann json format and populates the settings
 Status ConfigManager::FromJson(const nlohmann::json &j) {
-  set_rows_per_buffer(j.value("rowsPerBuffer", rows_per_buffer_));
   set_num_parallel_workers(j.value("numParallelWorkers", num_parallel_workers_));
   set_worker_connector_size(j.value("workerConnectorSize", worker_connector_size_));
   set_op_connector_size(j.value("opConnectorSize", op_connector_size_));
@@ -115,9 +112,6 @@ Status ConfigManager::LoadFile(const std::string &settingsFile) {
   return rc;
 }
-// Setter function
-void ConfigManager::set_rows_per_buffer(int32_t rows_per_buffer) { rows_per_buffer_ = rows_per_buffer; }
 // Setter function
 void ConfigManager::set_num_parallel_workers(int32_t num_parallel_workers) {
   num_parallel_workers_ = num_parallel_workers;

@@ -74,10 +74,6 @@ class ConfigManager {
   // @return Status error code
   Status LoadFile(const std::string &settingsFile);
-  // getter function
-  // @return The rows per buffer setting
-  int32_t rows_per_buffer() const { return rows_per_buffer_; }
   // getter function
   // @return The number of workers setting
   int32_t num_parallel_workers() const { return num_parallel_workers_; }
@@ -112,10 +108,6 @@ class ConfigManager {
   /// \return auto_num_workers_
   bool auto_num_workers() const { return auto_num_workers_; }
-  // setter function
-  // @param rows_per_buffer - The setting to apply to the config
-  void set_rows_per_buffer(int32_t rows_per_buffer);
   // setter function
   // @param num_parallel_workers - The setting to apply to the config
   void set_num_parallel_workers(int32_t num_parallel_workers);
@@ -230,7 +222,6 @@ class ConfigManager {
   void set_auto_worker_config_(uint8_t cfg) { auto_worker_config_ = cfg; }
  private:
-  int32_t rows_per_buffer_;
   int32_t num_parallel_workers_;
   int32_t worker_connector_size_;
   int32_t op_connector_size_;

@@ -35,7 +35,7 @@ TensorRow::TensorRow(row_id_type id, const std::initializer_list<value_type> &ls
 TensorRow::TensorRow(const TensorRow &tr)
     : id_(tr.id_), path_(tr.path_), row_(tr.row_), tensor_row_flag_(tr.tensor_row_flag_) {}
-TensorRow::TensorRow(TensorRow::TensorRowFlags flag) : tensor_row_flag_(flag) {}
+TensorRow::TensorRow(TensorRow::TensorRowFlags flag) : id_(kDefaultRowId), path_({}), tensor_row_flag_(flag) {}
 TensorRow &TensorRow::operator=(const TensorRow &tr) {
   if (this == &tr) {

@@ -540,8 +540,7 @@ Status CachePerfRun::Run() {
   int64_t elapse_time = std::chrono::duration_cast<std::chrono::seconds>(end_tick - start_tick).count();
   std::cout << "Epoch one (build phase) elapsed time " << elapse_time << " seconds" << std::endl;
-  std::cout << "Epoch one (build phase) per pipeline per worker summary. Buffer size = " << cfg_.rows_per_buffer()
-            << std::endl;
+  std::cout << "Epoch one (build phase) per pipeline per worker summary." << std::endl;
   PrintEpochSummary();
   // Get some stat but we need to connect. The server will thinks it is the (n+1) pipeline

@@ -228,17 +228,14 @@ Status CachePipelineRun::RunFirstEpoch() {
   }
   std::vector<row_id_type> keys;
-  auto rows_per_buffer = cfg_.rows_per_buffer();
-  keys.reserve(rows_per_buffer);
+  keys.reserve(1);
   int32_t worker_id = 0;
   for (auto i = start_row_; i <= end_row_; ++i) {
     keys.push_back(i);
-    if (keys.size() == rows_per_buffer) {
       auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone));
       RETURN_IF_NOT_OK(io_block_queues_[worker_id++ % num_workers]->Add(std::move(blk)));
       keys.clear();
     }
-  }
   if (!keys.empty()) {
     auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone));
     RETURN_IF_NOT_OK(io_block_queues_[worker_id++ % num_workers]->Add(std::move(blk)));
@@ -355,9 +352,8 @@ Status CachePipelineRun::WriterWorkerEntry(int32_t worker_id) {
 Status CachePipelineRun::RunReadEpoch() {
   std::vector<row_id_type> keys;
-  auto rows_per_buffer = cc_->GetPrefetchSize();  // We will use prefetch size to read.
   auto num_workers = cfg_.num_parallel_workers();
-  keys.reserve(rows_per_buffer);
+  keys.reserve(1);
   // Spawn workers
   auto f = std::bind(&CachePipelineRun::ReaderWorkerEntry, this, std::placeholders::_1);
   std::vector<Task *> worker_threads;
@@ -381,12 +377,10 @@ Status CachePipelineRun::RunReadEpoch() {
   int32_t worker_id = 0;
   for (auto id : all_keys) {
     keys.push_back(id);
-    if (keys.size() == rows_per_buffer) {
       auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone));
       RETURN_IF_NOT_OK(io_block_queues_[worker_id++ % num_workers]->Add(std::move(blk)));
       keys.clear();
     }
-  }
   if (!keys.empty()) {
     auto blk = std::make_unique<IOBlock>(IOBlock(keys, IOBlock::kDeIoBlockNone));
     RETURN_IF_NOT_OK(io_block_queues_[worker_id++ % num_workers]->Add(std::move(blk)));

@@ -31,7 +31,6 @@ BarrierOp::Builder::Builder() {
   // using the various builder set methods.
   std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
-  builder_rows_per_buffer_ = cfg->rows_per_buffer();
   builder_op_connector_size_ = cfg->op_connector_size();
 }
@@ -39,17 +38,13 @@ Status BarrierOp::Builder::SanityCheck() const { return Status::OK(); }
 Status BarrierOp::Builder::Build(std::shared_ptr<BarrierOp> *ptr) {
   RETURN_IF_NOT_OK(SanityCheck());
-  *ptr = std::make_shared<BarrierOp>(builder_rows_per_buffer_, builder_op_connector_size_, builder_condition_name_,
-                                     builder_condition_func_);
+  *ptr = std::make_shared<BarrierOp>(builder_op_connector_size_, builder_condition_name_, builder_condition_func_);
   return Status::OK();
 }
 // Construct BarrierOp here, local variables initialized in operator due to tree construction restrictions
-BarrierOp::BarrierOp(int32_t rows_per_buffer, int32_t op_connector_size, const std::string &condition_name,
-                     py::function condition_func)
+BarrierOp::BarrierOp(int32_t op_connector_size, const std::string &condition_name, py::function condition_func)
     : PipelineOp(op_connector_size),
-      rows_per_buffer_(rows_per_buffer),
-      buffer_id_(0),
       clean_up_(false),
       eof_(false),
       condition_name_(condition_name),

@@ -98,16 +98,13 @@ class BarrierOp : public PipelineOp {
   };
   // Constructor for BarrierOp
-  // @param rows_per_buffer - number of rows in output buffer
   // @param op_connector_size - connector size
   // @param condition_name - the condition name associated with this operator
   // @param condition_func - the blocking condition check per row
-  // @note - currently rows_per_buffer should = 1 for barrier.
   // The reason for this is having other values would complicate how the pipeline behaves with other operators
   // One example of such case is having batch after barrier. Batch would be waiting for data and having
   // rows per buffer in this case can result in hanging
-  BarrierOp(int32_t rows_per_buffer, int32_t op_connector_size, const std::string &condition_name,
-            py::function condition_func);
+  BarrierOp(int32_t op_connector_size, const std::string &condition_name, py::function condition_func);
   // Destructor
   ~BarrierOp();
@@ -156,10 +153,6 @@ class BarrierOp : public PipelineOp {
   bool clean_up_;
   // end of file state, we stop reading data and shut down
   bool eof_;
-  // rows per buffer
-  int32_t rows_per_buffer_;
-  // buffer_id
-  int32_t buffer_id_;
   // iterator to pull new rows, we only have one child
   std::unique_ptr<ChildIterator> child_iterator_;
   // condition name, to support multiple barriers

@@ -248,7 +248,7 @@ Status BatchOp::WorkerEntry(int32_t workerId) {
     RETURN_IF_NOT_OK(out_connector_->SendEOF(workerId));
   } else if (table_pair.second.ctrl_ == batchCtrl::kNoCtrl) {
     TensorRow new_row;
-    RETURN_IF_NOT_OK(MakeBatchedBuffer(std::move(table_pair), &new_row));
+    RETURN_IF_NOT_OK(MakeBatchedRow(std::move(table_pair), &new_row));
     RETURN_IF_NOT_OK(out_connector_->Add(std::move(new_row), workerId));
   }
   RETURN_IF_NOT_OK(worker_queues_[workerId]->PopFront(&table_pair));
@@ -256,7 +256,7 @@ Status BatchOp::WorkerEntry(int32_t workerId) {
   return Status::OK();
 }
-Status BatchOp::MakeBatchedBuffer(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> table_pair, TensorRow *new_row) {
+Status BatchOp::MakeBatchedRow(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> table_pair, TensorRow *new_row) {
   RETURN_UNEXPECTED_IF_NULL(table_pair.first);
 #ifdef ENABLE_PYTHON
   if (!in_col_names_.empty()) RETURN_IF_NOT_OK(MapColumns(&table_pair));  // pass it through pyfunc

@@ -225,7 +225,7 @@ class BatchOp : public ParallelOp {
   // Generate buffer with batched tensors
   // @return Status The status code returned
-  Status MakeBatchedBuffer(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> table_pair, TensorRow *new_row);
+  Status MakeBatchedRow(std::pair<std::unique_ptr<TensorQTable>, CBatchInfo> table_pair, TensorRow *new_row);
 #ifdef ENABLE_PYTHON
   // Function that calls pyfunc to perform map on batch

@@ -45,14 +45,13 @@ Status CacheBase::Reset() {
   MS_LOG(DEBUG) << Name() << " performing a self-reset.";
   return Status::OK();
 }
-CacheBase::CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
-                     std::shared_ptr<CacheClient> cache_client, std::shared_ptr<SamplerRT> sampler)
+CacheBase::CacheBase(int32_t num_workers, int32_t op_connector_size, std::shared_ptr<CacheClient> cache_client,
+                     std::shared_ptr<SamplerRT> sampler)
     : ParallelOp(num_workers, op_connector_size, std::move(sampler)),
       row_cnt_(0),
       num_cache_miss_(0),
       cache_client_(std::move(cache_client)),
-      rows_per_buffer_(rows_per_buf),
-      prefetch_size_(rows_per_buffer_),
+      prefetch_size_(1),
       num_prefetchers_(num_workers_) {
   // Adjust the prefetch size based on the number of workers.
   auto prefetch_sz_per_thread = cache_client_->GetPrefetchSize() / num_prefetchers_;
@@ -92,7 +91,7 @@ Status CacheBase::FetchSamplesToWorkers() {
   row_cnt_ = 0;
   ++wait_cnt;
   std::vector<row_id_type> keys;
-  keys.reserve(rows_per_buffer_);
+  keys.reserve(1);
   std::vector<row_id_type> prefetch_keys;
   prefetch_keys.reserve(prefetch_size_);
   std::unique_ptr<DataBuffer> sampler_buffer;
@@ -107,16 +106,12 @@ Status CacheBase::FetchSamplesToWorkers() {
       // Batch enough rows for performance reason.
       if (row_cnt_ % prefetch_size_ == 0) {
         RETURN_IF_NOT_OK(send_to_que(prefetch_queues_, prefetch_cnt++ % num_prefetchers_, prefetch_keys));
-        // Now we tell the WorkerEntry to wait for them to come back. If prefetch_size_ is a multiple
-        // of rows_per_buffer_, the keys vector will always be empty. But it can be partially filled.
-        // The only requirement we set up is rows_per_buffer_ is less than or equal to prefetch_size_.
+        // Now we tell the WorkerEntry to wait for them to come back.
         for (auto row_id : prefetch_keys) {
           keys.push_back(row_id);
-          if (keys.size() == rows_per_buffer_) {
             RETURN_IF_NOT_OK(send_to_que(io_block_queues_, buf_cnt++ % num_workers_, keys));
             keys.clear();
           }
-        }
         prefetch_keys.clear();
       }
     }
@@ -127,12 +122,10 @@ Status CacheBase::FetchSamplesToWorkers() {
     RETURN_IF_NOT_OK(send_to_que(prefetch_queues_, prefetch_cnt++ % num_prefetchers_, prefetch_keys));
     for (auto row_id : prefetch_keys) {
       keys.push_back(row_id);
-      if (keys.size() == rows_per_buffer_) {
         RETURN_IF_NOT_OK(send_to_que(io_block_queues_, buf_cnt++ % num_workers_, keys));
         keys.clear();
       }
-    }
     if (!keys.empty()) {
       RETURN_IF_NOT_OK(send_to_que(io_block_queues_, buf_cnt++ % num_workers_, keys));
     }
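The two hunks above show the recurring pattern of this change: instead of accumulating row ids until keys.size() reaches rows_per_buffer_ and then flushing one block, every sampled row id is now dispatched on its own. Below is a simplified, self-contained sketch of that before/after behavior; Dispatch and Send are placeholder names, not the MindData queue API.

```cpp
// Simplified sketch of the loop change applied in the sampling code above.
#include <cstdint>
#include <functional>
#include <iostream>
#include <vector>

using RowId = int64_t;

void Dispatch(const std::vector<RowId> &all_keys, const std::function<void(std::vector<RowId>)> &send) {
  std::vector<RowId> keys;
  keys.reserve(1);
  for (RowId id : all_keys) {
    keys.push_back(id);
    // Before this commit: flush only once keys.size() == rows_per_buffer.
    // After: every row id becomes its own block immediately.
    send(keys);
    keys.clear();
  }
  if (!keys.empty()) {  // kept for parity with the original code; never taken here
    send(keys);
  }
}

int main() {
  Dispatch({1, 2, 3}, [](std::vector<RowId> block) { std::cout << "block of " << block.size() << " row(s)\n"; });
  return 0;
}
```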

@@ -42,11 +42,10 @@ class CacheBase : public ParallelOp {
   /// \brief Base class constructor
   /// \param num_workers Number of parallel workers
   /// \param op_connector_size Connector size
-  /// \param rows_per_buf Number of rows per buffer
   /// \param cache_client CacheClient for communication to the CacheServer
   /// \param sampler Sampler which is mandatory
-  CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
-            std::shared_ptr<CacheClient> cache_client, std::shared_ptr<SamplerRT> sampler);
+  CacheBase(int32_t num_workers, int32_t op_connector_size, std::shared_ptr<CacheClient> cache_client,
+            std::shared_ptr<SamplerRT> sampler);
   /// \brief Destructor
   ~CacheBase();
@@ -87,7 +86,6 @@ class CacheBase : public ParallelOp {
   int64_t row_cnt_;
   std::atomic<int64_t> num_cache_miss_;
   std::shared_ptr<CacheClient> cache_client_;
-  int32_t rows_per_buffer_;
   std::unique_ptr<Connector<std::vector<row_id_type>>> keys_miss_;
   /// \brief Common function to register resources for interrupt

@@ -31,7 +31,6 @@ namespace dataset {
 CacheLookupOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullptr) {
   std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
   build_num_workers_ = cfg->num_parallel_workers();
-  rows_per_buffer_ = cfg->rows_per_buffer();
   build_op_connector_size_ = cfg->op_connector_size();
 }
@@ -52,8 +51,8 @@ Status CacheLookupOp::Builder::SanityCheck() const {
 // The builder "build" method creates the final object and does some init on it
 Status CacheLookupOp::Builder::Build(std::shared_ptr<CacheLookupOp> *ptr) {
   RETURN_IF_NOT_OK(SanityCheck());
-  *ptr = std::make_shared<CacheLookupOp>(build_num_workers_, build_op_connector_size_, rows_per_buffer_,
-                                         build_cache_client_, build_sampler_);
+  *ptr =
+    std::make_shared<CacheLookupOp>(build_num_workers_, build_op_connector_size_, build_cache_client_, build_sampler_);
   return Status::OK();
 }
 Status CacheLookupOp::operator()() {

@@ -74,7 +74,6 @@ class CacheLookupOp : public CacheBase, public SamplerRT {
  private:
   int32_t build_num_workers_;
-  int32_t rows_per_buffer_;
   int32_t build_op_connector_size_;
   std::shared_ptr<CacheClient> build_cache_client_;
   std::shared_ptr<SamplerRT> build_sampler_;
@@ -86,9 +85,9 @@ class CacheLookupOp : public CacheBase, public SamplerRT {
   /// \brief Constructor
   /// \note It takes the same argument as the base class.
   /// \see CacheBase
-  CacheLookupOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
-                std::shared_ptr<CacheClient> cache_client, std::shared_ptr<SamplerRT> sampler)
-      : CacheBase(num_workers, op_connector_size, rows_per_buf, cache_client, sampler), SamplerRT(*(sampler.get())) {}
+  CacheLookupOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr<CacheClient> cache_client,
+                std::shared_ptr<SamplerRT> sampler)
+      : CacheBase(num_workers, op_connector_size, cache_client, sampler), SamplerRT(*(sampler.get())) {}
   ~CacheLookupOp() = default;
   // As a parallel op, we override these two functions
   Status operator()() override;

@@ -33,7 +33,6 @@ namespace dataset {
 CacheOp::Builder::Builder() : build_cache_client_(nullptr), build_sampler_(nullptr) {
   std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
   build_num_workers_ = cfg->num_parallel_workers();
-  rows_per_buffer_ = cfg->rows_per_buffer();
   build_op_connector_size_ = cfg->op_connector_size();
 }
@@ -54,17 +53,16 @@ Status CacheOp::Builder::SanityCheck() const {
 // The builder "build" method creates the final object and does some init on it
 Status CacheOp::Builder::Build(std::shared_ptr<CacheOp> *ptr) {
   RETURN_IF_NOT_OK(SanityCheck());
-  *ptr = std::make_shared<CacheOp>(build_num_workers_, build_op_connector_size_, rows_per_buffer_, build_cache_client_,
-                                   build_sampler_);
+  *ptr = std::make_shared<CacheOp>(build_num_workers_, build_op_connector_size_, build_cache_client_, build_sampler_);
   RETURN_IF_NOT_OK((*ptr)->InitCache());
   return Status::OK();
 }
 // Constructor of CacheOp
-CacheOp::CacheOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
-                 std::shared_ptr<CacheClient> cache_client, std::shared_ptr<SamplerRT> sampler)
-    : CacheBase(num_workers, op_connector_size, rows_per_buf, std::move(cache_client), std::move(sampler)),
+CacheOp::CacheOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr<CacheClient> cache_client,
+                 std::shared_ptr<SamplerRT> sampler)
+    : CacheBase(num_workers, op_connector_size, std::move(cache_client), std::move(sampler)),
       num_guys_in_(0),
       phase_(Phase::kBuildPhase) {}

@@ -70,14 +70,6 @@ class CacheOp : public CacheBase, public RandomAccessOp {
     return *this;
   }
-  /// \brief Setter method
-  /// \param rows_per_buffer
-  /// \return Builder setter method returns reference to the builder.
-  Builder &SetRowsPerBuffer(int32_t rows_per_buffer) {
-    rows_per_buffer_ = rows_per_buffer;
-    return *this;
-  }
   /// \brief Setter method
   /// \param sampler
   /// \return Builder setter method returns reference to the builder.
@@ -93,7 +85,6 @@ class CacheOp : public CacheBase, public RandomAccessOp {
  private:
   int32_t build_num_workers_;
-  int32_t rows_per_buffer_;
   int32_t build_op_connector_size_;
   std::shared_ptr<CacheClient> build_cache_client_;
   std::shared_ptr<SamplerRT> build_sampler_;
@@ -107,8 +98,8 @@ class CacheOp : public CacheBase, public RandomAccessOp {
   /// \note The builder class should be used to call it.
   /// \param num_workers The number of worker threads.
   /// \param op_connector_size The size of each queue in the connector.
-  CacheOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
-          std::shared_ptr<CacheClient> cache_client, std::shared_ptr<SamplerRT> sampler);
+  CacheOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr<CacheClient> cache_client,
+          std::shared_ptr<SamplerRT> sampler);
   // Destructor
   ~CacheOp();

@@ -41,7 +41,6 @@ constexpr int32_t ShuffleOp::kShuffleStateDrain;
 ShuffleOp::Builder::Builder() : build_shuffle_size_(0), build_reshuffle_each_epoch_(true) {
   std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
   build_op_connector_size_ = cfg->op_connector_size();
-  build_rows_per_buffer_ = cfg->rows_per_buffer();
   build_shuffle_seed_ = GetSeed();
 }
@@ -56,20 +55,17 @@ Status ShuffleOp::Builder::SanityCheck() const {
 Status ShuffleOp::Builder::Build(std::shared_ptr<ShuffleOp> *ptr) {
   RETURN_IF_NOT_OK(SanityCheck());
   *ptr = std::make_shared<ShuffleOp>(build_shuffle_size_, build_shuffle_seed_, build_op_connector_size_,
-                                     build_reshuffle_each_epoch_, build_rows_per_buffer_);
+                                     build_reshuffle_each_epoch_);
   return Status::OK();
 }
 // Constructor of the ShuffleOp
-ShuffleOp::ShuffleOp(int32_t shuffle_size, uint32_t shuffle_seed, int32_t op_connector_size, bool reset_every_epoch,
-                     int32_t rows_per_buffer)
+ShuffleOp::ShuffleOp(int32_t shuffle_size, uint32_t shuffle_seed, int32_t op_connector_size, bool reset_every_epoch)
     : PipelineOp(op_connector_size),
       shuffle_size_(shuffle_size),
       shuffle_seed_(shuffle_seed),
       reshuffle_each_epoch_(reset_every_epoch),
       rng_(shuffle_seed),
-      buffer_counter_(0),
-      rows_per_buffer_(rows_per_buffer),
       shuffle_buffer_(std::make_unique<TensorTable>()),
       shuffle_last_row_idx_(0),
       shuffle_buffer_state_(kShuffleStateInit) {}
@@ -87,7 +83,6 @@ Status ShuffleOp::SelfReset() {
   }
   shuffle_buffer_ = std::make_unique<TensorTable>();
-  buffer_counter_ = 0;
   shuffle_last_row_idx_ = 0;
   shuffle_buffer_state_ = kShuffleStateInit;
   return Status::OK();
@@ -104,8 +99,8 @@ void ShuffleOp::Print(std::ostream &out, bool show_all) const {
   // Call the super class for displaying any common detailed info
   PipelineOp::Print(out, show_all);
   // Then show any custom derived-internal stuff
-  out << "\nShuffle size: " << shuffle_size_ << "\nRows per buffer: " << rows_per_buffer_
-      << "\nShuffle buffer state: " << shuffle_buffer_state_ << "\nShuffle seed: " << shuffle_seed_ << "\n\n";
+  out << "\nShuffle size: " << shuffle_size_ << "\nShuffle buffer state: " << shuffle_buffer_state_
+      << "\nShuffle seed: " << shuffle_seed_ << "\n\n";
 }
 }

@@ -121,9 +121,7 @@ class ShuffleOp : public PipelineOp {
   // @param shuffle_size - The size for the shuffle buffer
   // @param shuffle_seed - The seed to use for random number generation
   // @param op_connector_size - The output connector queue size
-  // @param rows_per_buffer - The requested number of rows per buffer
-  ShuffleOp(int32_t shuffle_size, uint32_t shuffle_seed, int32_t op_connector_size, bool reset_every_epoch,
-            int32_t rows_per_buffer);
+  ShuffleOp(int32_t shuffle_size, uint32_t shuffle_seed, int32_t op_connector_size, bool reset_every_epoch);
   // Destructor
   ~ShuffleOp() = default;
@@ -183,8 +181,6 @@ class ShuffleOp : public PipelineOp {
   // (ie uniform_int_distribution) because we will need to create up to |dataset| instances
   // of the distribution object in the common case of a perfect shuffle
   std::mt19937_64 rng_;
-  int32_t buffer_counter_;   // For creating new buffer id's
-  int32_t rows_per_buffer_;  // Number of rows to pack into output buffer
   // A single (potentially large) buffer of tensor rows for performing shuffling.
   std::unique_ptr<TensorTable> shuffle_buffer_;
   int32_t shuffle_last_row_idx_;  // Internal tracking of the last slot of our shuffle buffer

@@ -32,7 +32,6 @@ namespace dataset {
 AlbumOp::Builder::Builder() : builder_decode_(false), builder_sampler_(nullptr) {
   std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
   builder_num_workers_ = cfg->num_parallel_workers();
-  builder_rows_per_buffer_ = cfg->rows_per_buffer();
   builder_op_connector_size_ = cfg->op_connector_size();
 }
@@ -52,9 +51,8 @@ Status AlbumOp::Builder::Build(std::shared_ptr<AlbumOp> *ptr) {
     MS_LOG(INFO) << "Schema file provided: " << builder_schema_file_ << ".";
     builder_schema_->LoadSchemaFile(builder_schema_file_, builder_columns_to_load_);
   }
-  *ptr = std::make_shared<AlbumOp>(builder_num_workers_, builder_rows_per_buffer_, builder_dir_,
-                                   builder_op_connector_size_, builder_decode_, builder_extensions_,
-                                   std::move(builder_schema_), std::move(builder_sampler_));
+  *ptr = std::make_shared<AlbumOp>(builder_num_workers_, builder_dir_, builder_op_connector_size_, builder_decode_,
+                                   builder_extensions_, std::move(builder_schema_), std::move(builder_sampler_));
   return Status::OK();
 }
@@ -69,10 +67,10 @@ Status AlbumOp::Builder::SanityCheck() {
   return err_msg.empty() ? Status::OK() : Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, err_msg);
 }
-AlbumOp::AlbumOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size, bool do_decode,
+AlbumOp::AlbumOp(int32_t num_wkrs, std::string file_dir, int32_t queue_size, bool do_decode,
                  const std::set<std::string> &exts, std::unique_ptr<DataSchema> data_schema,
                  std::shared_ptr<SamplerRT> sampler)
-    : MappableLeafOp(num_wkrs, queue_size, std::move(sampler), rows_per_buffer),
+    : MappableLeafOp(num_wkrs, queue_size, std::move(sampler)),
       folder_path_(file_dir),
       decode_(do_decode),
       extensions_(exts),

@@ -58,14 +58,6 @@ class AlbumOp : public MappableLeafOp {
   /// \brief Destructor.
   ~Builder() = default;
-  /// \brief Setter method
-  /// \param[in] rows_per_buffer
-  /// \return Builder setter method returns reference to the builder
-  Builder &SetRowsPerBuffer(int32_t rows_per_buffer) {
-    builder_rows_per_buffer_ = rows_per_buffer;
-    return *this;
-  }
   /// \brief Setter method
   /// \param[in] size
   /// \return Builder setter method returns reference to the builder
@@ -154,16 +146,14 @@ class AlbumOp : public MappableLeafOp {
   /// \brief Constructor
   /// \param[in] num_wkrs - Num of workers reading images in parallel
-  /// \param[in] rows_per_buffer Number of images (rows) in each buffer
   /// \param[in] file_dir - directory of Album
   /// \param[in] queue_size - connector size
   /// \param[in] do_decode - decode image files
   /// \param[in] exts - set of file extensions to read, if empty, read everything under the dir
   /// \param[in] data_schema - schema of dataset
   /// \param[in] sampler - sampler tells AlbumOp what to read
-  AlbumOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir, int32_t queue_size, bool do_decode,
-          const std::set<std::string> &exts, std::unique_ptr<DataSchema> data_schema,
-          std::shared_ptr<SamplerRT> sampler);
+  AlbumOp(int32_t num_wkrs, std::string file_dir, int32_t queue_size, bool do_decode, const std::set<std::string> &exts,
+          std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler);
   /// \brief Destructor.
   ~AlbumOp() = default;
@@ -273,7 +263,6 @@ class AlbumOp : public MappableLeafOp {
   /// \return Status The status code returned
   Status ComputeColMap() override;
-  int32_t rows_per_buffer_;
   std::string folder_path_;  // directory of image folder
   bool decode_;
   std::set<std::string> extensions_;  // extensions allowed

@@ -34,7 +34,6 @@ namespace dataset {
 CelebAOp::Builder::Builder() : builder_decode_(false), builder_sampler_(nullptr) {
   std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
   builder_num_workers_ = cfg->num_parallel_workers();
-  builder_rows_per_buffer_ = cfg->rows_per_buffer();
   builder_op_connector_size_ = cfg->op_connector_size();
 }
@@ -54,9 +53,9 @@ Status CelebAOp::Builder::Build(std::shared_ptr<CelebAOp> *op) {
   // label is like this:0 1 0 0 1......
   RETURN_IF_NOT_OK(
     builder_schema_->AddColumn(ColDescriptor("attr", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
-  *op = std::make_shared<CelebAOp>(builder_num_workers_, builder_rows_per_buffer_, builder_dir_,
-                                   builder_op_connector_size_, builder_decode_, builder_usage_, builder_extensions_,
-                                   std::move(builder_schema_), std::move(builder_sampler_));
+  *op = std::make_shared<CelebAOp>(builder_num_workers_, builder_dir_, builder_op_connector_size_, builder_decode_,
+                                   builder_usage_, builder_extensions_, std::move(builder_schema_),
+                                   std::move(builder_sampler_));
   if (*op == nullptr) {
     return Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, "CelebAOp init failed.");
   }
@@ -76,10 +75,10 @@ Status CelebAOp::Builder::SanityCheck() {
   return err_msg.empty() ? Status::OK() : Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, err_msg);
 }
-CelebAOp::CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::string &dir, int32_t queue_size,
-                   bool decode, const std::string &usage, const std::set<std::string> &exts,
-                   std::unique_ptr<DataSchema> schema, std::shared_ptr<SamplerRT> sampler)
-    : MappableLeafOp(num_workers, queue_size, std::move(sampler), rows_per_buffer),
+CelebAOp::CelebAOp(int32_t num_workers, const std::string &dir, int32_t queue_size, bool decode,
+                   const std::string &usage, const std::set<std::string> &exts, std::unique_ptr<DataSchema> schema,
+                   std::shared_ptr<SamplerRT> sampler)
+    : MappableLeafOp(num_workers, queue_size, std::move(sampler)),
       folder_path_(dir),
       decode_(decode),
       extensions_(exts),

@@ -53,14 +53,6 @@ class CelebAOp : public MappableLeafOp {
   // Destructor.
   ~Builder() = default;
-  // Setter method
-  // @param int32_t rows_per_buffer
-  // @return Builder setter method returns reference to the builder.
-  Builder &SetRowsPerBuffer(int32_t rows_per_buffer) {
-    builder_rows_per_buffer_ = rows_per_buffer;
-    return *this;
-  }
   // Setter method
   // @param int32_t size
   // @return Builder setter method returns reference to the builder.
@@ -139,13 +131,11 @@ class CelebAOp : public MappableLeafOp {
   // Constructor
   // @param int32_t - num_workers - Num of workers reading images in parallel
-  // @param int32_t - rows_per_buffer Number of images (rows) in each buffer
   // @param std::string - dir directory of celeba dataset
   // @param int32_t queueSize - connector queue size
   // @param std::unique_ptr<Sampler> sampler - sampler tells CelebAOp what to read
-  CelebAOp(int32_t num_workers, int32_t rows_per_buffer, const std::string &dir, int32_t queue_size, bool decode,
-           const std::string &usage, const std::set<std::string> &exts, std::unique_ptr<DataSchema> schema,
-           std::shared_ptr<SamplerRT> sampler);
+  CelebAOp(int32_t num_workers, const std::string &dir, int32_t queue_size, bool decode, const std::string &usage,
+           const std::set<std::string> &exts, std::unique_ptr<DataSchema> schema, std::shared_ptr<SamplerRT> sampler);
   ~CelebAOp() override = default;

@@ -39,7 +39,6 @@ constexpr uint32_t kCifarImageSize = kCifarImageHeight * kCifarImageWidth * kCif
 CifarOp::Builder::Builder() : sampler_(nullptr), usage_("") {
   std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager();
   num_workers_ = cfg->num_parallel_workers();
-  rows_per_buffer_ = cfg->rows_per_buffer();
   op_connect_size_ = cfg->op_connector_size();
   cifar_type_ = kCifar10;
 }
@@ -65,8 +64,8 @@ Status CifarOp::Builder::Build(std::shared_ptr<CifarOp> *ptr) {
       ColDescriptor("fine_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &another_scalar)));
   }
-  *ptr = std::make_shared<CifarOp>(cifar_type_, usage_, num_workers_, rows_per_buffer_, dir_, op_connect_size_,
-                                   std::move(schema_), std::move(sampler_));
+  *ptr = std::make_shared<CifarOp>(cifar_type_, usage_, num_workers_, dir_, op_connect_size_, std::move(schema_),
+                                   std::move(sampler_));
   return Status::OK();
 }
@@ -85,10 +84,9 @@ Status CifarOp::Builder::SanityCheck() {
   return err_msg.empty() ? Status::OK() : Status(StatusCode::kMDUnexpectedError, __LINE__, __FILE__, err_msg);
 }
-CifarOp::CifarOp(CifarType type, const std::string &usage, int32_t num_works, int32_t rows_per_buf,
-                 const std::string &file_dir, int32_t queue_size, std::unique_ptr<DataSchema> data_schema,
-                 std::shared_ptr<SamplerRT> sampler)
-    : MappableLeafOp(num_works, queue_size, std::move(sampler), rows_per_buf),
+CifarOp::CifarOp(CifarType type, const std::string &usage, int32_t num_works, const std::string &file_dir,
+                 int32_t queue_size, std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler)
+    : MappableLeafOp(num_works, queue_size, std::move(sampler)),
       cifar_type_(type),
       usage_(usage),
       folder_path_(file_dir),

@@ -49,14 +49,6 @@ class CifarOp : public MappableLeafOp {
   // Destructor.
   ~Builder() = default;
-  // Setter method
-  // @param uint32_t rows_per_buffer
-  // @return Builder setter method returns reference to the builder.
-  Builder &SetRowsPerBuffer(int32_t rows_per_buffer) {
-    rows_per_buffer_ = rows_per_buffer;
-    return *this;
-  }
   // Setter method
   // @param uint32_t size
   // @return Builder setter method returns reference to the builder.
@@ -122,7 +114,6 @@ class CifarOp : public MappableLeafOp {
   std::string dir_;
   std::string usage_;
   int32_t num_workers_;
-  int32_t rows_per_buffer_;
   int32_t op_connect_size_;
   std::shared_ptr<SamplerRT> sampler_;
   std::unique_ptr<DataSchema> schema_;
@@ -133,13 +124,11 @@ class CifarOp : public MappableLeafOp {
   // @param CifarType type - Cifar10 or Cifar100
   // @param const std::string &usage - Usage of this dataset, can be 'train', 'test' or 'all'
   // @param uint32_t numWorks - Num of workers reading images in parallel
-  // @param uint32_t - rowsPerBuffer Number of images (rows) in each buffer
   // @param std::string - dir directory of cifar dataset
   // @param uint32_t - queueSize - connector queue size
   // @param std::unique_ptr<Sampler> sampler - sampler tells ImageFolderOp what to read
-  CifarOp(CifarType type, const std::string &usage, int32_t num_works, int32_t rows_per_buf,
-          const std::string &file_dir, int32_t queue_size, std::unique_ptr<DataSchema> data_schema,
-          std::shared_ptr<SamplerRT> sampler);
+  CifarOp(CifarType type, const std::string &usage, int32_t num_works, const std::string &file_dir, int32_t queue_size,
+          std::unique_ptr<DataSchema> data_schema, std::shared_ptr<SamplerRT> sampler);
   // Destructor.
   ~CifarOp() = default;

@@ -36,7 +36,6 @@ ClueOp::Builder::Builder()
   std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
   builder_num_workers_ = config_manager->num_parallel_workers();
   builder_op_connector_size_ = config_manager->op_connector_size();
-  builder_rows_per_buffer_ = config_manager->rows_per_buffer();
   builder_worker_connector_size_ = config_manager->worker_connector_size();
 }
@@ -67,9 +66,8 @@ Status ClueOp::Builder::Build(std::shared_ptr<ClueOp> *op) {
   }
   std::shared_ptr<ClueOp> clue_op = std::make_shared<ClueOp>(
-    builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_, ck_map,
-    builder_clue_files_list_, builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_,
-    builder_device_id_);
+    builder_num_workers_, builder_num_samples_, builder_worker_connector_size_, ck_map, builder_clue_files_list_,
+    builder_op_connector_size_, builder_shuffle_files_, builder_num_devices_, builder_device_id_);
   RETURN_IF_NOT_OK(clue_op->Init());
   *op = std::move(clue_op);
@@ -87,11 +85,11 @@ std::vector<std::string> ClueOp::Builder::split(const std::string &s, char delim
   return res;
 }
-ClueOp::ClueOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size,
-               ColKeyMap cols_to_keyword, std::vector<std::string> clue_files_list, int32_t op_connector_size,
-               bool shuffle_files, int32_t num_devices, int32_t device_id)
-    : NonMappableLeafOp(num_workers, worker_connector_size, rows_per_buffer, num_samples, op_connector_size,
-                        shuffle_files, num_devices, device_id),
+ClueOp::ClueOp(int32_t num_workers, int64_t num_samples, int32_t worker_connector_size, ColKeyMap cols_to_keyword,
+               std::vector<std::string> clue_files_list, int32_t op_connector_size, bool shuffle_files,
+               int32_t num_devices, int32_t device_id)
+    : NonMappableLeafOp(num_workers, worker_connector_size, num_samples, op_connector_size, shuffle_files, num_devices,
+                        device_id),
       clue_files_list_(std::move(clue_files_list)),
       cols_to_keyword_(cols_to_keyword) {}
@@ -200,8 +198,7 @@ void ClueOp::Print(std::ostream &out, bool show_all) const {
   // Call the super class for displaying any common detailed info
   ParallelOp::Print(out, show_all);
   // Then show any custom derived-internal stuff
-  out << "\nRows per buffer: " << rows_per_buffer_ << "\nSample count: " << total_rows_
-      << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_
+  out << "\nSample count: " << total_rows_ << "\nDevice id: " << device_id_ << "\nNumber of devices: " << num_devices_
       << "\nShuffle files: " << ((shuffle_files_) ? "yes" : "no") << "\nClue files list:\n";
   for (int i = 0; i < clue_files_list_.size(); ++i) {
     out << "  " << clue_files_list_[i];

Some files were not shown because too many files have changed in this diff.
