|
|
@ -27,18 +27,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
#include "proto/example.pb.h"
|
|
|
|
#include "proto/example.pb.h"
|
|
|
|
#include "./securec.h"
|
|
|
|
#include "./securec.h"
|
|
|
|
#include "utils/ms_utils.h"
|
|
|
|
|
|
|
|
#include "minddata/dataset/core/config_manager.h"
|
|
|
|
#include "minddata/dataset/core/config_manager.h"
|
|
|
|
#include "minddata/dataset/core/global_context.h"
|
|
|
|
#include "minddata/dataset/core/global_context.h"
|
|
|
|
#include "minddata/dataset/engine/connector.h"
|
|
|
|
|
|
|
|
#include "minddata/dataset/engine/data_schema.h"
|
|
|
|
#include "minddata/dataset/engine/data_schema.h"
|
|
|
|
#include "minddata/dataset/engine/datasetops/source/io_block.h"
|
|
|
|
#include "minddata/dataset/engine/datasetops/source/io_block.h"
|
|
|
|
#include "minddata/dataset/engine/db_connector.h"
|
|
|
|
#include "minddata/dataset/engine/db_connector.h"
|
|
|
|
#include "minddata/dataset/engine/execution_tree.h"
|
|
|
|
#include "minddata/dataset/engine/execution_tree.h"
|
|
|
|
#include "minddata/dataset/engine/jagged_connector.h"
|
|
|
|
#include "minddata/dataset/engine/jagged_connector.h"
|
|
|
|
#include "minddata/dataset/engine/opt/pass.h"
|
|
|
|
#include "minddata/dataset/engine/opt/pass.h"
|
|
|
|
#include "minddata/dataset/util/path.h"
|
|
|
|
|
|
|
|
#include "minddata/dataset/util/queue.h"
|
|
|
|
|
|
|
|
#include "minddata/dataset/util/random.h"
|
|
|
|
#include "minddata/dataset/util/random.h"
|
|
|
|
#include "minddata/dataset/util/status.h"
|
|
|
|
#include "minddata/dataset/util/status.h"
|
|
|
|
#include "minddata/dataset/util/task_manager.h"
|
|
|
|
#include "minddata/dataset/util/task_manager.h"
|
|
|
@ -387,14 +383,14 @@ Status TFReaderOp::PostEndOfEpoch(int32_t queue_index) {
|
|
|
|
return Status::OK();
|
|
|
|
return Status::OK();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
bool TFReaderOp::NeedPushFileToblockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
|
|
|
|
bool TFReaderOp::NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset,
|
|
|
|
const int64_t &pre_count) {
|
|
|
|
const int64_t &pre_count) {
|
|
|
|
*start_offset = 0;
|
|
|
|
*start_offset = 0;
|
|
|
|
*end_offset = 0;
|
|
|
|
*end_offset = 0;
|
|
|
|
bool push = false;
|
|
|
|
bool push = false;
|
|
|
|
int64_t start_index = device_id_ * num_rows_per_shard_;
|
|
|
|
int64_t start_index = device_id_ * num_rows_per_shard_;
|
|
|
|
if (device_id_ + 1 < 0) {
|
|
|
|
if (device_id_ + 1 < 0) {
|
|
|
|
MS_LOG(ERROR) << "Device id is invalid";
|
|
|
|
MS_LOG(ERROR) << "Device id is invalid.";
|
|
|
|
return false;
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
int64_t end_index = (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_;
|
|
|
|
int64_t end_index = (static_cast<int64_t>(device_id_) + 1) * num_rows_per_shard_;
|
|
|
@ -448,7 +444,7 @@ Status TFReaderOp::FillIOBlockShuffle(const std::vector<int64_t> &i_keys) {
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
// Do an index lookup using that key to get the filename.
|
|
|
|
// Do an index lookup using that key to get the filename.
|
|
|
|
std::string file_name = (*filename_index_)[*it];
|
|
|
|
std::string file_name = (*filename_index_)[*it];
|
|
|
|
if (NeedPushFileToblockQueue(file_name, &start_offset, &end_offset, pre_count)) {
|
|
|
|
if (NeedPushFileToBlockQueue(file_name, &start_offset, &end_offset, pre_count)) {
|
|
|
|
auto ioBlock = std::make_unique<FilenameBlock>(*it, start_offset, end_offset, IOBlock::kDeIoBlockNone);
|
|
|
|
auto ioBlock = std::make_unique<FilenameBlock>(*it, start_offset, end_offset, IOBlock::kDeIoBlockNone);
|
|
|
|
RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock)));
|
|
|
|
RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock)));
|
|
|
|
MS_LOG(DEBUG) << "File name " << *it << " start offset " << start_offset << " end_offset " << end_offset;
|
|
|
|
MS_LOG(DEBUG) << "File name " << *it << " start offset " << start_offset << " end_offset " << end_offset;
|
|
|
@ -496,7 +492,7 @@ Status TFReaderOp::FillIOBlockNoShuffle() {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
std::string file_name = it.value();
|
|
|
|
std::string file_name = it.value();
|
|
|
|
if (NeedPushFileToblockQueue(file_name, &start_offset, &end_offset, pre_count)) {
|
|
|
|
if (NeedPushFileToBlockQueue(file_name, &start_offset, &end_offset, pre_count)) {
|
|
|
|
auto ioBlock = std::make_unique<FilenameBlock>(it.key(), start_offset, end_offset, IOBlock::kDeIoBlockNone);
|
|
|
|
auto ioBlock = std::make_unique<FilenameBlock>(it.key(), start_offset, end_offset, IOBlock::kDeIoBlockNone);
|
|
|
|
RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock)));
|
|
|
|
RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock)));
|
|
|
|
queue_index = (queue_index + 1) % num_workers_;
|
|
|
|
queue_index = (queue_index + 1) % num_workers_;
|
|
|
@ -711,7 +707,7 @@ Status TFReaderOp::LoadFeature(const std::unique_ptr<TensorQTable> *tensor_table
|
|
|
|
// reinitializes itself so that it can be executed again, as if it was just created.
|
|
|
|
// reinitializes itself so that it can be executed again, as if it was just created.
|
|
|
|
Status TFReaderOp::Reset() {
|
|
|
|
Status TFReaderOp::Reset() {
|
|
|
|
MS_LOG(DEBUG) << Name() << " performing a self-reset.";
|
|
|
|
MS_LOG(DEBUG) << Name() << " performing a self-reset.";
|
|
|
|
// start workers first, otherwise IOBlokcs will fall through if workers see it before this is set to true
|
|
|
|
// start workers first, otherwise IOBlocks will fall through if workers see it before this is set to true
|
|
|
|
load_jagged_connector_ = true;
|
|
|
|
load_jagged_connector_ = true;
|
|
|
|
|
|
|
|
|
|
|
|
{
|
|
|
|
{
|
|
|
@ -767,6 +763,14 @@ Status TFReaderOp::LoadBytesList(const ColDescriptor ¤t_col, const dataeng
|
|
|
|
new_pad_size *= cur_shape[i];
|
|
|
|
new_pad_size *= cur_shape[i];
|
|
|
|
}
|
|
|
|
}
|
|
|
|
pad_size = new_pad_size;
|
|
|
|
pad_size = new_pad_size;
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
if (cur_shape.known() && cur_shape.NumOfElements() != max_size) {
|
|
|
|
|
|
|
|
std::string err_msg = "Shape in schema's column '" + current_col.name() + "' is incorrect." +
|
|
|
|
|
|
|
|
"\nshape received: " + cur_shape.ToString() +
|
|
|
|
|
|
|
|
"\ntotal elements in shape received: " + std::to_string(cur_shape.NumOfElements()) +
|
|
|
|
|
|
|
|
"\nexpected total elements in shape: " + std::to_string(max_size);
|
|
|
|
|
|
|
|
RETURN_STATUS_UNEXPECTED(err_msg);
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|