Add NonMappableLeafOp and unify the TFReader, TextFile, CSV and Clue ops

pull/13291/head
hesham 4 years ago
parent 1edbbe56ba
commit c877ac255b

@ -15,6 +15,7 @@ set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
csv_op.cc csv_op.cc
album_op.cc album_op.cc
mappable_leaf_op.cc mappable_leaf_op.cc
nonmappable_leaf_op.cc
) )
set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES

@ -26,6 +26,8 @@
#include "minddata/dataset/util/auto_index.h" #include "minddata/dataset/util/auto_index.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h" #include "minddata/dataset/engine/datasetops/parallel_op.h"
#include "minddata/dataset/engine/datasetops/source/nonmappable_leaf_op.h"
#include "minddata/dataset/engine/jagged_connector.h"
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
@ -34,7 +36,7 @@ using ColKeyMap = std::map<std::string, std::vector<std::string>>;
class JaggedConnector; class JaggedConnector;
class ClueOp : public ParallelOp { class ClueOp : public NonMappableLeafOp {
public: public:
class Builder { class Builder {
public: public:
@ -150,18 +152,7 @@ class ClueOp : public ParallelOp {
// Instantiates the internal queues and connectors // Instantiates the internal queues and connectors
// @return Status - the error code returned // @return Status - the error code returned
Status Init(); Status Init() override;
// Class functor operator () override.
// All dataset operators operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work
// @return Status - the error code returned.
Status operator()() override;
// Overrides base class reset method. Cleans up any state info from it's previous execution
// reinitializes itself so that it can be executed again, as if it was just created.
// @return Status - the error code returned.
Status Reset() override;
// Get total rows in files. // Get total rows in files.
// @param files - all clue files. // @param files - all clue files.
@ -178,72 +169,28 @@ class ClueOp : public ParallelOp {
std::string Name() const override { return "ClueOp"; } std::string Name() const override { return "ClueOp"; }
private: private:
// The entry point for when workers are launched.
// @param worker_id - the id of the worker that is executing this function.
// @return Status - the error code returned.
Status WorkerEntry(int32_t worker_id) override;
// Reads a clue file and loads the data into multiple buffers. // Reads a clue file and loads the data into multiple buffers.
// @param file - the file to read. // @param file - the file to read.
// @param start_offset - the start offset of file. // @param start_offset - the start offset of file.
// @param end_offset - the end offset of file. // @param end_offset - the end offset of file.
// @param worker_id - the id of the worker that is executing this function. // @param worker_id - the id of the worker that is executing this function.
// @return Status - the error code returned. // @return Status - the error code returned.
Status LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset, Status LoadFile(const std::string &file, int64_t start_offset, int64_t end_offset, int32_t worker_id) override;
const int32_t worker_id);
// Pops an element from a queue in IOBlockQueue.
// @param index - the index of the queue to pop from.
// @param out_block - the popped element.
// @return Status - the error code returned.
Status PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block);
// Pushes an element to a queue in IOBlockQueue.
// @param index - the index of the queue to push to.
// @param io_block - the element to push onto the queue.
// @return Status - the error code returned.
Status PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block);
// Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue.
// @return Status - the error code returned.
Status WaitToFillIOBlockQueue();
// Fill the IOBlockQueue. // Fill the IOBlockQueue.
// @para i_keys - keys of file to fill to the IOBlockQueue // @para i_keys - keys of file to fill to the IOBlockQueue
// @return Status - the error code returned. // @return Status - the error code returned.
Status FillIOBlockQueue(const std::vector<int64_t> &i_keys); Status FillIOBlockQueue(const std::vector<int64_t> &i_keys) override;
// Notifies the thread which called FillIoBlockQueue to resume execution
void NotifyToFillIOBlockQueue();
// Select file and push it to the block queue.
// @param file_name - File name.
// @param start_file - If file contains the first sample of data.
// @param end_file - If file contains the end sample of data.
// @param pre_count - Total rows of previous files.
// @return Status - the error code returned.
bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
const int64_t &pre_count);
// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
// pops this control indicator, it will wait until the next epoch starts and then resume execution.
// @return Status - the error code returned.
Status PostEndOfEpoch(int32_t queue_index);
// Calculate number of rows in each shard. // Calculate number of rows in each shard.
// @return Status - the error code returned. // @return Status - the error code returned.
Status CalculateNumRowsPerShard(); Status CalculateNumRowsPerShard() override;
// Count number of rows in each file. // Count number of rows in each file.
// @param filename - clue file name. // @param filename - clue file name.
// @return int64_t - the total number of rows in file. // @return int64_t - the total number of rows in file.
int64_t CountTotalRows(const std::string &file); int64_t CountTotalRows(const std::string &file);
// Pushes a control indicator onto the IOBlockQueue for each worker to consume.
// When the worker pops this control indicator, it will shut itself down gracefully.
// @return Status - the error code returned.
Status PostEndOfData();
// @return Status - the error code returned. // @return Status - the error code returned.
Status GetValue(const nlohmann::json &js, std::vector<std::string> key_chain, std::shared_ptr<Tensor> *t); Status GetValue(const nlohmann::json &js, std::vector<std::string> key_chain, std::shared_ptr<Tensor> *t);
@ -251,22 +198,7 @@ class ClueOp : public ParallelOp {
// @return - Status // @return - Status
Status ComputeColMap() override; Status ComputeColMap() override;
int32_t device_id_;
bool shuffle_files_;
bool finished_reading_dataset_;
int32_t num_devices_;
int64_t rows_per_buffer_;
bool load_io_block_queue_;
int64_t num_rows_per_shard_;
int64_t all_num_rows_;
int64_t num_samples_;
std::map<std::string, int64_t> filename_numrows_;
std::unique_ptr<StringIndex> filename_index_;
std::vector<std::string> clue_files_list_; std::vector<std::string> clue_files_list_;
WaitPost io_block_queue_wait_post_;
std::shared_ptr<JaggedConnector> jagged_buffer_connector_;
QueueList<std::unique_ptr<FilenameBlock>> io_block_queues_;
bool load_jagged_connector_;
ColKeyMap cols_to_keyword_; ColKeyMap cols_to_keyword_;
}; };
} // namespace dataset } // namespace dataset

@ -26,6 +26,8 @@
#include "minddata/dataset/util/auto_index.h" #include "minddata/dataset/util/auto_index.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h" #include "minddata/dataset/engine/datasetops/parallel_op.h"
#include "minddata/dataset/engine/datasetops/source/io_block.h" #include "minddata/dataset/engine/datasetops/source/io_block.h"
#include "minddata/dataset/engine/datasetops/source/nonmappable_leaf_op.h"
#include "minddata/dataset/engine/jagged_connector.h"
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
@ -34,7 +36,7 @@ const size_t CSV_BUFFER_SIZE = 4096;
using StringIndex = AutoIndexObj<std::string>; using StringIndex = AutoIndexObj<std::string>;
class JaggedConnector; class JaggedConnector;
class CsvOp : public ParallelOp { class CsvOp : public NonMappableLeafOp {
public: public:
enum RecordType : uint8_t { INT = 0, FLOAT, STRING }; enum RecordType : uint8_t { INT = 0, FLOAT, STRING };
@ -63,7 +65,7 @@ class CsvOp : public ParallelOp {
public: public:
CsvParser() = delete; CsvParser() = delete;
CsvParser(int32_t worker_id, std::shared_ptr<JaggedConnector> connector, int64_t rows_per_buffer, char field_delim, CsvParser(int32_t worker_id, JaggedConnector *connector, int64_t rows_per_buffer, char field_delim,
std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default, std::string file_path); std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default, std::string file_path);
~CsvParser() = default; ~CsvParser() = default;
@ -125,7 +127,7 @@ class CsvOp : public ParallelOp {
int CatchException(int c); int CatchException(int c);
int32_t worker_id_; int32_t worker_id_;
std::shared_ptr<JaggedConnector> buffer_connector_; JaggedConnector *buffer_connector_;
int64_t csv_rows_per_buffer_; int64_t csv_rows_per_buffer_;
const char csv_field_delim_; const char csv_field_delim_;
std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default_; std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default_;
@ -274,18 +276,7 @@ class CsvOp : public ParallelOp {
// Instantiates the internal queues and connectors // Instantiates the internal queues and connectors
// @return Status - the error code returned // @return Status - the error code returned
Status Init(); Status Init() override;
// Class functor operator () override.
// All dataset operators operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work
// @return Status - the error code returned.
Status operator()() override;
// Overrides base class reset method. Cleans up any state info from it's previous execution
// reinitializes itself so that it can be executed again, as if it was just created.
// @return Status - the error code returned.
Status Reset() override;
// Get total rows in files. // Get total rows in files.
// @param files - all csv files. // @param files - all csv files.
@ -303,11 +294,6 @@ class CsvOp : public ParallelOp {
std::string Name() const override { return "CsvOp"; } std::string Name() const override { return "CsvOp"; }
private: private:
// The entry point for when workers are launched.
// @param worker_id - the id of the worker that is executing this function.
// @return Status - the error code returned.
Status WorkerEntry(int32_t worker_id) override;
// Parses a single row and puts the data into a tensor table. // Parses a single row and puts the data into a tensor table.
// @param line - the content of the row. // @param line - the content of the row.
// @param tensor_table - the tensor table to put the parsed data in. // @param tensor_table - the tensor table to put the parsed data in.
@ -321,61 +307,22 @@ class CsvOp : public ParallelOp {
// @param end_offset - the end offset of file. // @param end_offset - the end offset of file.
// @param worker_id - the id of the worker that is executing this function. // @param worker_id - the id of the worker that is executing this function.
// @return Status - the error code returned. // @return Status - the error code returned.
Status LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset, Status LoadFile(const std::string &file, int64_t start_offset, int64_t end_offset, int32_t worker_id) override;
const int32_t worker_id);
// Pops an element from a queue in IOBlockQueue.
// @param index - the index of the queue to pop from.
// @param out_block - the popped element.
// @return Status - the error code returned.
Status PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block);
// Pushes an element to a queue in IOBlockQueue.
// @param index - the index of the queue to push to.
// @param io_block - the element to push onto the queue.
// @return Status - the error code returned.
Status PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block);
// Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue.
// @return Status - the error code returned.
Status WaitToFillIOBlockQueue();
// Fill the IOBlockQueue. // Fill the IOBlockQueue.
// @para i_keys - keys of file to fill to the IOBlockQueue // @para i_keys - keys of file to fill to the IOBlockQueue
// @return Status - the error code returned. // @return Status - the error code returned.
Status FillIOBlockQueue(const std::vector<int64_t> &i_keys); Status FillIOBlockQueue(const std::vector<int64_t> &i_keys) override;
// Notifies the thread which called FillIoBlockQueue to resume execution
void NotifyToFillIOBlockQueue();
// Select file and push it to the block queue.
// @param file_name - File name.
// @param start_offset - If file contains the first sample of data.
// @param end_offset - If file contains the end sample of data.
// @param pre_count - Total rows of previous files.
// @return Status - the error code returned.
bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
const int64_t &pre_count);
// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
// pops this control indicator, it will wait until the next epoch starts and then resume execution.
// @return Status - the error code returned.
Status PostEndOfEpoch(int32_t queue_index);
// Calculate number of rows in each shard. // Calculate number of rows in each shard.
// @return Status - the error code returned. // @return Status - the error code returned.
Status CalculateNumRowsPerShard(); Status CalculateNumRowsPerShard() override;
// Count number of rows in each file. // Count number of rows in each file.
// @param filename - csv file name. // @param filename - csv file name.
// @return int64_t - the total number of rows in file. // @return int64_t - the total number of rows in file.
int64_t CountTotalRows(const std::string &file); int64_t CountTotalRows(const std::string &file);
// Pushes a control indicator onto the IOBlockQueue for each worker to consume.
// When the worker pops this control indicator, it will shut itself down gracefully.
// @return Status - the error code returned.
Status PostEndOfData();
// Private function for computing the assignment of the column name map. // Private function for computing the assignment of the column name map.
// @return - Status // @return - Status
Status ComputeColMap() override; Status ComputeColMap() override;
@ -394,22 +341,7 @@ class CsvOp : public ParallelOp {
// @return bool - whether column name identical in all CSV files // @return bool - whether column name identical in all CSV files
bool ColumnNameValidate(); bool ColumnNameValidate();
int32_t device_id_;
bool shuffle_files_;
bool finished_reading_dataset_;
int32_t num_devices_;
int64_t rows_per_buffer_;
bool load_io_block_queue_;
int64_t num_rows_per_shard_;
int64_t all_num_rows_;
int64_t num_samples_;
std::map<std::string, int64_t> filename_numrows_;
std::unique_ptr<StringIndex> filename_index_;
std::vector<std::string> csv_files_list_; std::vector<std::string> csv_files_list_;
WaitPost io_block_queue_wait_post_;
std::shared_ptr<JaggedConnector> jagged_buffer_connector_;
QueueList<std::unique_ptr<FilenameBlock>> io_block_queues_;
bool load_jagged_connector_;
char field_delim_; char field_delim_;
std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default_list_; std::vector<std::shared_ptr<CsvOp::BaseRecord>> column_default_list_;
std::vector<std::string> column_name_list_; std::vector<std::string> column_name_list_;

@ -0,0 +1,177 @@
/**
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_NONMAPPABLE_LEAF_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_NONMAPPABLE_LEAF_OP_H_
#include <algorithm>
#include <memory>
#include <mutex>
#include <string>
#include <vector>
#include <utility>
#include <map>
#include "minddata/dataset/util/wait_post.h"
#include "minddata/dataset/util/auto_index.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h"
// Forward declarations of types that live in the dataengine namespace.
// NOTE(review): none of these names are referenced by the NonMappableLeafOp declaration in this
// header; they look carried over from the TFReaderOp header - confirm whether they are needed here.
namespace dataengine {
class Example;
class Feature;
class BytesList;
} // namespace dataengine
namespace mindspore {
namespace dataset {
// Forward declarations of types that are defined in other headers.
template <typename T>
class Queue;
template <class T>
class Connector;
// JaggedConnector and FilenameBlock are only used through std::unique_ptr in this header.
// NOTE(review): code that destroys those std::unique_ptr members presumably needs the full
// definitions in scope - confirm that every includer also pulls in the defining headers.
class JaggedConnector;
class FilenameBlock;
// Shorthand for the AutoIndexObj specialization over std::string (the type behind filename_index_).
using StringIndex = AutoIndexObj<std::string>;
// Abstract base class, derived from ParallelOp, for leaf ops such as ClueOp, CsvOp, TextFileOp
// and TFReaderOp. It declares the master loop, the worker entry point and the IOBlockQueue
// helpers, and owns the related bookkeeping members. Derived ops must implement Init, LoadFile,
// CalculateNumRowsPerShard and FillIOBlockQueue.
class NonMappableLeafOp : public ParallelOp {
public:
// Constructor of NonMappableLeafOp
// @note The class has pure virtual methods, so this constructor can only be invoked from the
// constructor of a derived op.
// @param num_workers - number of worker threads reading data from the dataset files.
// @param worker_connector_size - size of each internal queue.
// @param rows_per_buffer - number of rows that a full buffer will contain.
// @param total_num_rows - Number of rows to read
// @param op_connector_size - size of each queue in the connector that the child operator pulls from.
// @param shuffle_files - whether or not to shuffle the files before reading data.
// @param num_devices - presumably the number of shards the dataset is split into (TODO confirm).
// @param device_id - presumably the index of the shard read by this op (TODO confirm).
NonMappableLeafOp(int32_t num_workers, int32_t worker_connector_size, int64_t rows_per_buffer, int64_t total_num_rows,
int32_t op_connector_size, bool shuffle_files, int32_t num_devices, int32_t device_id);
// Default destructor
~NonMappableLeafOp() = default;
// Instantiates the internal queues and connectors.
// Pure virtual: every derived op provides its own implementation.
// @return Status - the error code returned.
virtual Status Init() = 0;
// Class functor operator () override.
// All dataset operators operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work
// @return Status - the error code returned.
Status operator()() override;
// Overrides base class reset method. Cleans up any state info from its previous execution and
// reinitializes itself so that it can be executed again, as if it was just created.
// @return Status - the error code returned.
Status Reset() override;
// Getter method
// @return int64_t - the value of rows_per_buffer_.
int64_t rows_per_buffer() const { return rows_per_buffer_; }
// Op name getter
// @return Name of the current Op ("NonMappableLeafOp" unless a derived op overrides it).
std::string Name() const override { return "NonMappableLeafOp"; }
protected:
// The entry point for when workers are launched.
// @param worker_id - the id of the worker that is executing this function.
// @return Status - the error code returned.
Status WorkerEntry(int32_t worker_id) override;
// Pushes a control indicator onto the IOBlockQueue for each worker to consume.
// When the worker pops this control indicator, it will shut itself down gracefully.
// @return Status - the error code returned.
Status PostEndOfData();
// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
// pops this control indicator, it will wait until the next epoch starts and then resume execution.
// @param queue_index - presumably the index of the IOBlockQueue to start posting from (TODO confirm).
// @return Status - the error code returned.
Status PostEndOfEpoch(int32_t queue_index);
// Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue.
// @return Status - the error code returned.
Status WaitToFillIOBlockQueue();
// Notifies the thread which called WaitToFillIOBlockQueue to resume execution.
void NotifyToFillIOBlockQueue();
// Pops an element from a queue in IOBlockQueue.
// @param index - the index of the queue to pop from.
// @param out_block - the popped element.
// @return Status - the error code returned.
Status PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block);
// Pushes an element to a queue in IOBlockQueue.
// @param index - the index of the queue to push to.
// @param io_block - the element to push onto the queue.
// @return Status - the error code returned.
Status PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block);
// Reads a dataset file and loads the data into multiple buffers.
// Pure virtual: every derived op provides its own implementation.
// @param filename - the dataset file to read.
// @param start_offset - the start offset of file.
// @param end_offset - the end offset of file.
// @param worker_id - the id of the worker that is executing this function.
// @return Status - the error code returned.
virtual Status LoadFile(const std::string &filename, int64_t start_offset, int64_t end_offset, int32_t worker_id) = 0;
// Select file and push it to the block queue.
// @param file_name - File name.
// @param start_offset - output pointer; presumably receives the start offset to read from in this file (TODO confirm).
// @param end_offset - output pointer; presumably receives the end offset to read up to in this file (TODO confirm).
// @param pre_count - Total rows of previous files.
// @return bool - presumably true when the file should be pushed to the block queue (TODO confirm).
bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
const int64_t &pre_count);
// Calculate number of rows in each shard.
// Pure virtual: every derived op provides its own implementation.
// @return Status - the error code returned.
virtual Status CalculateNumRowsPerShard() = 0;
// Static helper operating on a list of keys.
// NOTE(review): presumably reorders *i_keys in place using a generator seeded with `seed`;
// the implementation is not visible in this header - confirm.
// @param i_keys - pointer to the keys to operate on.
// @param seed - seed value.
static void ShuffleKeys(std::vector<int64_t> *i_keys, uint32_t seed);
// Fill the IOBlockQueue.
// Pure virtual: every derived op provides its own implementation.
// @param i_keys - keys of file to fill to the IOBlockQueue
// @return Status - the error code returned.
virtual Status FillIOBlockQueue(const std::vector<int64_t> &i_keys) = 0;
// Bookkeeping members, accessible to derived ops.
// NOTE(review): the roles below are inferred from names and constructor parameters - confirm
// against the implementation file.
int32_t device_id_;
int32_t num_devices_;
bool load_jagged_connector_;
std::unique_ptr<StringIndex> filename_index_;
QueueList<std::unique_ptr<FilenameBlock>> io_block_queues_;
// Presumably maps a file name to the number of rows it contains (TODO confirm).
std::map<std::string, int64_t> filename_numrows_;
bool finished_reading_dataset_;
// NOTE(review): total_rows_ and num_rows_ are both declared; how they differ is not visible
// in this header - confirm.
int64_t total_rows_;
int64_t rows_per_buffer_;
WaitPost io_block_queue_wait_post_;
bool load_io_block_queue_;
// Presumably guards load_io_block_queue_ (TODO confirm).
std::mutex load_io_block_queue_mutex_;
std::unique_ptr<JaggedConnector> jagged_buffer_connector_;
bool shuffle_files_;
int64_t num_rows_per_shard_;
int64_t num_rows_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_DATASETOPS_SOURCE_NONMAPPABLE_LEAF_OP_H_

@ -27,6 +27,7 @@
#include "minddata/dataset/util/auto_index.h" #include "minddata/dataset/util/auto_index.h"
#include "minddata/dataset/engine/data_schema.h" #include "minddata/dataset/engine/data_schema.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h" #include "minddata/dataset/engine/datasetops/parallel_op.h"
#include "minddata/dataset/engine/datasetops/source/nonmappable_leaf_op.h"
#include "minddata/dataset/util/queue.h" #include "minddata/dataset/util/queue.h"
#include "minddata/dataset/util/wait_post.h" #include "minddata/dataset/util/wait_post.h"
#include "minddata/dataset/engine/jagged_connector.h" #include "minddata/dataset/engine/jagged_connector.h"
@ -35,7 +36,7 @@ namespace mindspore {
namespace dataset { namespace dataset {
using StringIndex = AutoIndexObj<std::string>; using StringIndex = AutoIndexObj<std::string>;
class TextFileOp : public ParallelOp { class TextFileOp : public NonMappableLeafOp {
public: public:
class Builder { class Builder {
public: public:
@ -150,18 +151,7 @@ class TextFileOp : public ParallelOp {
// Instantiates the internal queues and connectors // Instantiates the internal queues and connectors
// @return Status - the error code returned // @return Status - the error code returned
Status Init(); Status Init() override;
// Class functor operator () override.
// All dataset operators operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work
// @return Status - the error code returned.
Status operator()() override;
// Overrides base class reset method. Cleans up any state info from it's previous execution
// reinitializes itself so that it can be executed again, as if it was just created.
// @return Status - the error code returned.
Status Reset() override;
// Get total rows in files. // Get total rows in files.
// @param files - all text files. // @param files - all text files.
@ -178,11 +168,6 @@ class TextFileOp : public ParallelOp {
std::vector<std::string> FileNames() { return text_files_list_; } std::vector<std::string> FileNames() { return text_files_list_; }
private: private:
// The entry point for when workers are launched.
// @param worker_id - the id of the worker that is executing this function.
// @return Status - the error code returned.
Status WorkerEntry(int32_t worker_id) override;
// Parses a single row and puts the data into a tensor table. // Parses a single row and puts the data into a tensor table.
// @param line - the content of the row. // @param line - the content of the row.
// @param tensor_table - the tensor table to put the parsed data in. // @param tensor_table - the tensor table to put the parsed data in.
@ -196,82 +181,28 @@ class TextFileOp : public ParallelOp {
// @param end_offset - the end offset of file. // @param end_offset - the end offset of file.
// @param worker_id - the id of the worker that is executing this function. // @param worker_id - the id of the worker that is executing this function.
// @return Status - the error code returned. // @return Status - the error code returned.
Status LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset, Status LoadFile(const std::string &file, int64_t start_offset, int64_t end_offset, int32_t worker_id) override;
const int32_t worker_id);
// Calculate number of rows in each shard. // Calculate number of rows in each shard.
// @return Status - the error code returned. // @return Status - the error code returned.
Status CalculateNumRowsPerShard(); Status CalculateNumRowsPerShard() override;
// Count number of rows in each file. // Count number of rows in each file.
// @param filename - text file name. // @param filename - text file name.
// @return int64_t - the total number of rows in file. // @return int64_t - the total number of rows in file.
int64_t CountTotalRows(const std::string &file); int64_t CountTotalRows(const std::string &file);
// Notifies the thread which called FillIoBlockQueue to resume execution
void NotifyToFillIOBlockQueue();
// Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue.
// @return Status - the error code returned.
Status WaitToFillIOBlockQueue();
// Fill the IOBlockQueue. // Fill the IOBlockQueue.
// @para i_keys - keys of file to fill to the IOBlockQueue // @para i_keys - keys of file to fill to the IOBlockQueue
// @return Status - the error code returned. // @return Status - the error code returned.
Status FillIOBlockQueue(const std::vector<int64_t> &i_keys); Status FillIOBlockQueue(const std::vector<int64_t> &i_keys) override;
// Select file and push it to the block queue.
// @param file_name - File name.
// @param start_file - If file contains the first sample of data.
// @param end_file - If file contains the end sample of data.
// @param pre_count - Total rows of previous files.
// @return Status - the error code returned.
bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
const int64_t &pre_count);
// Pops an element from a queue in IOBlockQueue.
// @param index - the index of the queue to pop from.
// @param out_block - the popped element.
// @return Status - the error code returned.
Status PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block);
// Pushes an element to a queue in IOBlockQueue.
// @param index - the index of the queue to push to.
// @param io_block - the element to push onto the queue.
// @return Status - the error code returned.
Status PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block);
// Pushes a control indicator onto the IOBlockQueue for each worker to consume.
// When the worker pops this control indicator, it will shut itself down gracefully.
// @return Status - the error code returned.
Status PostEndOfData();
// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
// pops this control indicator, it will wait until the next epoch starts and then resume execution.
// @return Status - the error code returned.
Status PostEndOfEpoch(int32_t queue_index);
// Private function for computing the assignment of the column name map. // Private function for computing the assignment of the column name map.
// @return - Status // @return - Status
Status ComputeColMap() override; Status ComputeColMap() override;
int32_t device_id_;
int32_t num_devices_;
int64_t rows_per_buffer_;
int64_t total_rows_;
std::vector<std::string> text_files_list_; std::vector<std::string> text_files_list_;
bool shuffle_files_;
std::unique_ptr<DataSchema> data_schema_; std::unique_ptr<DataSchema> data_schema_;
int64_t all_num_rows_;
int64_t num_rows_per_shard_;
std::map<std::string, int64_t> filename_numrows_;
std::unique_ptr<StringIndex> filename_index_;
QueueList<std::unique_ptr<FilenameBlock>> io_block_queues_;
WaitPost io_block_queue_wait_post_;
bool finished_reading_dataset_;
bool load_io_block_queue_;
bool load_jagged_connector_;
std::unique_ptr<JaggedConnector> jagged_buffer_connector_;
}; };
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

@ -31,6 +31,7 @@
#include "minddata/dataset/core/tensor.h" #include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/engine/data_schema.h" #include "minddata/dataset/engine/data_schema.h"
#include "minddata/dataset/engine/datasetops/parallel_op.h" #include "minddata/dataset/engine/datasetops/parallel_op.h"
#include "minddata/dataset/engine/datasetops/source/nonmappable_leaf_op.h"
namespace dataengine { namespace dataengine {
class Example; class Example;
@ -51,7 +52,7 @@ class FilenameBlock;
using StringIndex = AutoIndexObj<std::string>; using StringIndex = AutoIndexObj<std::string>;
class TFReaderOp : public ParallelOp { class TFReaderOp : public NonMappableLeafOp {
public: public:
class Builder { class Builder {
public: public:
@ -195,21 +196,7 @@ class TFReaderOp : public ParallelOp {
// Instantiates the internal queues and connectors. // Instantiates the internal queues and connectors.
// @return Status - the error code returned. // @return Status - the error code returned.
Status Init(); Status Init() override;
// Class functor operator () override.
// All dataset operators operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work
// @return Status - the error code returned.
Status operator()() override;
// Overrides base class reset method. Cleans up any state info from it's previous execution and
// reinitializes itself so that it can be executed again, as if it was just created.
// @return Status - the error code returned.
Status Reset() override;
// Getter method
int64_t rows_per_buffer() const { return rows_per_buffer_; }
// Reads all the provided tf_file files and counts the total number of rows. filenames will // Reads all the provided tf_file files and counts the total number of rows. filenames will
// first be sectioned into equal parts, then sections are read in parallel. If threads is // first be sectioned into equal parts, then sections are read in parallel. If threads is
@ -233,48 +220,13 @@ class TFReaderOp : public ParallelOp {
static bool ValidateFirstRowCrc(const std::string &filename); static bool ValidateFirstRowCrc(const std::string &filename);
private: private:
// The entry point for when workers are launched.
// @param worker_id - the id of the worker that is executing this function.
// @return Status - the error code returned.
Status WorkerEntry(int32_t worker_id) override;
// Pushes a control indicator onto the IOBlockQueue for each worker to consume.
// When the worker pops this control indicator, it will shut itself down gracefully.
// @return Status - the error code returned.
Status PostEndOfData();
// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker
// pops this control indicator, it will wait until the next epoch starts and then resume execution.
// @return Status - the error code returned.
Status PostEndOfEpoch(int32_t queue_index);
// Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue.
// @return Status - the error code returned.
Status WaitToFillIOBlockQueue();
// Notifies the thread which called WaitToFillIOBlockQueue to resume execution.
void NotifyToFillIOBlockQueue();
// Pops an element from a queue in IOBlockQueue.
// @param index - the index of the queue to pop from.
// @param out_block - the popped element.
// @return Status - the error code returned.
Status PopIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> *out_block);
// Pushes an element to a queue in IOBlockQueue.
// @param index - the index of the queue to push to.
// @param io_block - the element to push onto the queue.
// @return Status - the error code returned.
Status PushIoBlockQueue(int32_t index, std::unique_ptr<FilenameBlock> &&io_block);
// Reads a tf_file file and loads the data into multiple buffers. // Reads a tf_file file and loads the data into multiple buffers.
// @param filename - the tf_file file to read. // @param filename - the tf_file file to read.
// @param start_offset - the start offset of file. // @param start_offset - the start offset of file.
// @param end_offset - the end offset of file. // @param end_offset - the end offset of file.
// @param worker_id - the id of the worker that is executing this function. // @param worker_id - the id of the worker that is executing this function.
// @return Status - the error code returned. // @return Status - the error code returned.
Status LoadFile(const std::string &filename, const int64_t start_offset, const int64_t end_offset, Status LoadFile(const std::string &filename, int64_t start_offset, int64_t end_offset, int32_t worker_id) override;
const int32_t &worker_id);
// Parses a single row and puts the data into a tensor table. // Parses a single row and puts the data into a tensor table.
// @param tf_file - the row to be parsed. // @param tf_file - the row to be parsed.
@ -339,6 +291,11 @@ class TFReaderOp : public ParallelOp {
// @return int64_t - the total number of rows of files read. // @return int64_t - the total number of rows of files read.
static int64_t CountTotalRowsSectioned(const std::vector<std::string> &filenames, const int64_t begin, static int64_t CountTotalRowsSectioned(const std::vector<std::string> &filenames, const int64_t begin,
const int64_t end); const int64_t end);
protected:
Status FillIOBlockQueue(const std::vector<int64_t> &i_keys) override;
private:
// Fill IO block queue if shuffle is true // Fill IO block queue if shuffle is true
// @param i_keys - shuffle keys. // @param i_keys - shuffle keys.
// @return Status - the error code returned. // @return Status - the error code returned.
@ -351,43 +308,18 @@ class TFReaderOp : public ParallelOp {
*/ */
Status FillIOBlockNoShuffle(); Status FillIOBlockNoShuffle();
// Select file and push it to the block queue.
// @param file_name - File name.
// @param start_file - If file contains the first sample of data.
// @param end_file - If file contains the end sample of data.
// @param pre_count - Total rows of previous files.
// @return Status - the error code returned.
bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
const int64_t &pre_count);
// Calculate number of rows in each shard. // Calculate number of rows in each shard.
// @return Status - the error code returned. // @return Status - the error code returned.
Status CalculateNumRowsPerShard(); Status CalculateNumRowsPerShard() override;
// Private function for computing the assignment of the column name map. // Private function for computing the assignment of the column name map.
// @return - Status // @return - Status
Status ComputeColMap() override; Status ComputeColMap() override;
int32_t device_id_;
int32_t num_devices_;
int64_t rows_per_buffer_;
int64_t total_rows_;
std::vector<std::string> dataset_files_list_; std::vector<std::string> dataset_files_list_;
std::vector<std::string> columns_to_load_; std::vector<std::string> columns_to_load_;
bool finished_reading_dataset_;
bool shuffle_files_;
std::unique_ptr<DataSchema> data_schema_; std::unique_ptr<DataSchema> data_schema_;
std::unique_ptr<StringIndex> filename_index_;
bool load_io_block_queue_;
bool load_jagged_connector_;
std::unique_ptr<JaggedConnector> jagged_buffer_connector_;
QueueList<std::unique_ptr<FilenameBlock>> io_block_queues_;
WaitPost io_block_queue_wait_post_;
std::mutex load_io_block_queue_mutex_;
std::map<std::string, int64_t> filename_numrows_;
int64_t num_rows_;
int64_t num_rows_per_shard_;
bool equal_rows_per_shard_; bool equal_rows_per_shard_;
}; };
} // namespace dataset } // namespace dataset

Loading…
Cancel
Save