|
|
@ -105,6 +105,7 @@ TFReaderOp::TFReaderOp(int32_t num_workers, int32_t worker_connector_size, int64
|
|
|
|
data_schema_(std::move(data_schema)),
|
|
|
|
data_schema_(std::move(data_schema)),
|
|
|
|
filename_index_(make_unique<StringIndex>()),
|
|
|
|
filename_index_(make_unique<StringIndex>()),
|
|
|
|
load_io_block_queue_(true),
|
|
|
|
load_io_block_queue_(true),
|
|
|
|
|
|
|
|
load_jagged_connector_(true),
|
|
|
|
num_rows_(0),
|
|
|
|
num_rows_(0),
|
|
|
|
num_rows_per_shard_(0),
|
|
|
|
num_rows_per_shard_(0),
|
|
|
|
equal_rows_per_shard_(equal_rows_per_shard) {
|
|
|
|
equal_rows_per_shard_(equal_rows_per_shard) {
|
|
|
@ -203,6 +204,25 @@ Status TFReaderOp::operator()() {
|
|
|
|
buffer_id++;
|
|
|
|
buffer_id++;
|
|
|
|
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(fetched_buffer)));
|
|
|
|
RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(fetched_buffer)));
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
|
|
|
|
// user specified number of rows they want, and we read enough rows
|
|
|
|
|
|
|
|
//
|
|
|
|
|
|
|
|
// IOBlockQueue thread needs to:
|
|
|
|
|
|
|
|
// -stop pushing stuff to IOBlockQueue
|
|
|
|
|
|
|
|
// -call PostEndOfEpoch (will send EOE)
|
|
|
|
|
|
|
|
// -wait for reset
|
|
|
|
|
|
|
|
//
|
|
|
|
|
|
|
|
// Worker threads need to:
|
|
|
|
|
|
|
|
// -stop reading the file they are currently reading and throw it away
|
|
|
|
|
|
|
|
// -keep pulling, but dont read other files (eventually skips all IOBlocks and will get EOE)
|
|
|
|
|
|
|
|
//
|
|
|
|
|
|
|
|
// Master thread needs to:
|
|
|
|
|
|
|
|
// -tell IOBlockQueue thread to stop pushing
|
|
|
|
|
|
|
|
// -tell worker threads to stop reading the file tey are currently reading
|
|
|
|
|
|
|
|
// -keep pulling until EOE
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// don't think we need a lock for now
|
|
|
|
|
|
|
|
load_jagged_connector_ = false;
|
|
|
|
|
|
|
|
|
|
|
|
std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_);
|
|
|
|
std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_);
|
|
|
|
load_io_block_queue_ = false;
|
|
|
|
load_io_block_queue_ = false;
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -245,12 +265,14 @@ Status TFReaderOp::WorkerEntry(int32_t worker_id) {
|
|
|
|
|
|
|
|
|
|
|
|
while (!io_block->eof()) {
|
|
|
|
while (!io_block->eof()) {
|
|
|
|
if (!io_block->eoe()) {
|
|
|
|
if (!io_block->eoe()) {
|
|
|
|
std::string filename;
|
|
|
|
if (load_jagged_connector_) {
|
|
|
|
RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_));
|
|
|
|
std::string filename;
|
|
|
|
int64_t start_offset = io_block->GetStartOffset();
|
|
|
|
RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_));
|
|
|
|
int64_t end_offset = io_block->GetEndOffset();
|
|
|
|
int64_t start_offset = io_block->GetStartOffset();
|
|
|
|
RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id));
|
|
|
|
int64_t end_offset = io_block->GetEndOffset();
|
|
|
|
MS_LOG(INFO) << "TFReader operator worker " << worker_id << " loaded file " << common::SafeCStr(filename) << ".";
|
|
|
|
RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id));
|
|
|
|
|
|
|
|
MS_LOG(INFO) << "TFReader operator worker " << worker_id << " loaded file " << filename << ".";
|
|
|
|
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
std::unique_ptr<DataBuffer> eoe_buffer = mindspore::make_unique<DataBuffer>(1, DataBuffer::kDeBFlagEOE);
|
|
|
|
std::unique_ptr<DataBuffer> eoe_buffer = mindspore::make_unique<DataBuffer>(1, DataBuffer::kDeBFlagEOE);
|
|
|
|
RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer)));
|
|
|
|
RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer)));
|
|
|
@ -478,6 +500,10 @@ Status TFReaderOp::LoadFile(const std::string &filename, const int64_t start_off
|
|
|
|
std::unique_ptr<TensorQTable> new_tensor_table = make_unique<TensorQTable>();
|
|
|
|
std::unique_ptr<TensorQTable> new_tensor_table = make_unique<TensorQTable>();
|
|
|
|
|
|
|
|
|
|
|
|
while (reader.peek() != EOF) {
|
|
|
|
while (reader.peek() != EOF) {
|
|
|
|
|
|
|
|
if (!load_jagged_connector_) {
|
|
|
|
|
|
|
|
break;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// read length
|
|
|
|
// read length
|
|
|
|
int64_t record_length = 0;
|
|
|
|
int64_t record_length = 0;
|
|
|
|
(void)reader.read(reinterpret_cast<char *>(&record_length), static_cast<std::streamsize>(sizeof(int64_t)));
|
|
|
|
(void)reader.read(reinterpret_cast<char *>(&record_length), static_cast<std::streamsize>(sizeof(int64_t)));
|
|
|
@ -599,6 +625,9 @@ Status TFReaderOp::LoadFeature(const std::unique_ptr<TensorQTable> *tensor_table
|
|
|
|
// Overrides base class reset method. Cleans up any state info from it's previous execution and
|
|
|
|
// Overrides base class reset method. Cleans up any state info from it's previous execution and
|
|
|
|
// reinitializes itself so that it can be executed again, as if it was just created.
|
|
|
|
// reinitializes itself so that it can be executed again, as if it was just created.
|
|
|
|
Status TFReaderOp::Reset() {
|
|
|
|
Status TFReaderOp::Reset() {
|
|
|
|
|
|
|
|
// start workers first, otherwise IOBlokcs will fall through if workers see it before this is set to true
|
|
|
|
|
|
|
|
load_jagged_connector_ = true;
|
|
|
|
|
|
|
|
|
|
|
|
{
|
|
|
|
{
|
|
|
|
std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_);
|
|
|
|
std::unique_lock<std::mutex> lock(load_io_block_queue_mutex_);
|
|
|
|
load_io_block_queue_ = true;
|
|
|
|
load_io_block_queue_ = true;
|
|
|
|