Added GetDatasetSize

pull/7662/head
Mahdi 5 years ago
parent 7276198580
commit 0e03f5b0dd

@ -179,6 +179,26 @@ Dataset::Dataset() {
rows_per_buffer_ = cfg->rows_per_buffer();
connector_que_size_ = cfg->op_connector_size();
worker_connector_size_ = cfg->worker_connector_size();
tree_getters_ = std::make_shared<TreeGetters>();
}
int64_t Dataset::GetDatasetSize() {
int64_t dataset_size;
auto ds = shared_from_this();
Status rc;
std::unique_ptr<RuntimeContext> runtime_context = std::make_unique<RuntimeContext>();
rc = runtime_context->Init();
if (rc.IsError()) {
MS_LOG(ERROR) << "GetDatasetSize: Initializing RuntimeContext failed.";
return -1;
}
rc = tree_getters_->Init(ds);
if (rc.IsError()) {
MS_LOG(ERROR) << "GetDatasetSize: Initializing TreeGetters failed.";
return -1;
}
rc = tree_getters_->GetDatasetSize(&dataset_size);
return rc.IsError() ? -1 : dataset_size;
}
// Constructor to initialize the cache

@ -351,4 +351,32 @@ Status SaveToDisk::TransfromTensor(const unsigned char *src, const TensorShape &
}
#endif
// Constructor: dataset_size_ starts at -1, meaning "not yet computed".
// dataset_size_ is this class's own member, so it is set in the initializer
// list; tree_adapter_ comes from the base class and is assigned in the body.
TreeGetters::TreeGetters() : dataset_size_(-1) { tree_adapter_ = std::make_unique<TreeAdapter>(); }
// Builds and prepares the execution tree for the given dataset (one epoch).
Status TreeGetters::Init(std::shared_ptr<api::Dataset> d) { return tree_adapter_->BuildAndPrepare(std::move(d), 1); }
// Returns the dataset size, computing and caching it on the first call.
// Fast path: ask the tree's root op. If the ops report -1 (size not
// analytically computable, e.g. a filter in the tree), fall back to pulling
// every row through the pipeline and counting.
Status TreeGetters::GetDatasetSize(int64_t *dataset_size) {
  if (dataset_size_ == -1) {  // -1 means "not yet computed"
    // NOTE(review): wrapping GetRoot() in a fresh shared_ptr takes ownership of
    // the returned pointer; if the tree also owns the root op this risks a
    // double free — confirm GetRoot()'s ownership contract.
    std::shared_ptr<DatasetOp> root = std::shared_ptr<DatasetOp>(tree_adapter_->GetRoot());
    CHECK_FAIL_RETURN_UNEXPECTED(root != nullptr, "Root is a nullptr.");
    RETURN_IF_NOT_OK(root->GetDatasetSize(dataset_size));
    dataset_size_ = *dataset_size;
    TensorRow row;
    if (*dataset_size == -1) {  // ops could not compute it; count by iteration
      int64_t num_rows = 0;
      RETURN_IF_NOT_OK(tree_adapter_->GetNext(&row));
      while (row.size() != 0) {  // an empty row signals end of data
        num_rows++;
        RETURN_IF_NOT_OK(tree_adapter_->GetNext(&row));
      }
      dataset_size_ = num_rows;
    }
  }
  *dataset_size = dataset_size_;
  return Status::OK();
}
} // namespace mindspore::dataset

@ -152,9 +152,10 @@ class ToDevice : public TreeConsumer {
/// Consumer that is used to get some pipeline information
class TreeGetters : public TreeConsumer {
Status GetDatasetSize(int32_t *size) {
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
}
public:
TreeGetters();
Status Init(std::shared_ptr<api::Dataset> d) override;
Status GetDatasetSize(int64_t *size);
Status GetBatchSize(int32_t *batch_size) {
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
}
@ -173,6 +174,11 @@ class TreeGetters : public TreeConsumer {
Status GetOutputNames(std::vector<std::string> *names) {
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
}
std::string Name() override { return "TreeGetters"; }
private:
int64_t dataset_size_;
};
} // namespace mindspore::dataset

@ -531,5 +531,30 @@ Status BatchOp::ComputeColMap() {
return Status::OK();
}
// Computes the number of batches this op will produce.
// Serves the cached value when available. Returns -1 (unknown) when the batch
// size comes from a Python callable and cannot be determined statically.
Status BatchOp::GetDatasetSize(int64_t *dataset_size) {
  if (dataset_size_ > 0) {
    *dataset_size = dataset_size_;
    return Status::OK();
  }
#ifdef ENABLE_PYTHON
  if (batch_size_func_) {
    // Per-batch size is decided by a callback at runtime; size is unknowable.
    *dataset_size = -1;
    return Status::OK();
  }
#endif
  int64_t num_rows;
  RETURN_IF_NOT_OK(child_[0]->GetDatasetSize(&num_rows));
  if (num_rows > 0 && start_batch_size_ > 0) {
    if (drop_) {
      // Incomplete final batch is dropped: plain integer division (floor).
      num_rows = num_rows / start_batch_size_;
    } else {
      // Round up in integer arithmetic. The previous
      // ceil(num_rows / start_batch_size_) truncated in integer division
      // BEFORE applying ceil, undercounting whenever the last batch was
      // partial.
      num_rows = (num_rows + start_batch_size_ - 1) / start_batch_size_;
    }
  }
  *dataset_size = num_rows;
  dataset_size_ = num_rows;
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

@ -219,6 +219,11 @@ class BatchOp : public ParallelOp {
static Status PadColumns(std::unique_ptr<TensorQTable> *table, const PadInfo &pad_info,
const std::unordered_map<std::string, int32_t> &column_name_id_map);
/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
protected:
Status ComputeColMap() override;

@ -231,5 +231,13 @@ Status BucketBatchByLengthOp::ComputeColMap() {
}
return Status::OK();
}
// Get Dataset size
Status BucketBatchByLengthOp::GetDatasetSize(int64_t *dataset_size) {
  // Bucketed batching makes the output size impossible to compute analytically,
  // so the cached value (-1 until known) is returned; -1 tells TreeGetters to
  // fall back to iterating over the dataset and counting the rows.
  *dataset_size = dataset_size_;
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

@ -112,6 +112,11 @@ class BucketBatchByLengthOp : public PipelineOp {
std::string Name() const override { return kBucketBatchByLengthOp; }
/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
// << Stream output operator overload
// @notes This allows you to write the debug print info using stream operators
// @param out - reference to the output stream being overloaded

@ -195,5 +195,13 @@ Status ConcatOp::PreAccept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->PreRunOnNode(shared_from_base<ConcatOp>(), modified);
}
// Get Dataset size
Status ConcatOp::GetDatasetSize(int64_t *dataset_size) {
  // The concatenated size is not computed here; the cached value (-1 until
  // known) is returned, and -1 tells TreeGetters to fall back to iterating
  // over the dataset and counting the rows.
  *dataset_size = dataset_size_;
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

@ -111,6 +111,11 @@ class ConcatOp : public PipelineOp {
/// \return Status of the node visit
Status PreAccept(NodePass *p, bool *modified) override;
/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
private:
Status Verify(int32_t id, const std::unique_ptr<DataBuffer> &buf);

@ -50,7 +50,8 @@ DatasetOp::DatasetOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler
op_num_repeats_per_epoch_(kInfiniteRepeat),
op_current_repeats_(0),
op_current_epochs_(0),
out_connector_(nullptr) {
out_connector_(nullptr),
dataset_size_(-1) {
// The operator starts out with an invalid operator id. The only way to
// get it out of invalid state is to assign the operator to an execution tree.
}
@ -290,6 +291,17 @@ Status DatasetOp::GetNextInput(std::unique_ptr<DataBuffer> *p_buffer, int32_t wo
return Status::OK();
}
// Gets the dataset size.
// Base-class default: serve the cached value when known, otherwise delegate to
// the single child. The child's answer is now cached when positive, matching
// the caching behavior of the derived overrides so repeated calls stay cheap.
Status DatasetOp::GetDatasetSize(int64_t *dataset_size) {
  if (dataset_size_ > 0) {
    *dataset_size = dataset_size_;
    return Status::OK();
  }
  CHECK_FAIL_RETURN_UNEXPECTED(child_.size() == 1, "Can't get the dataset size for the current tree.");
  RETURN_IF_NOT_OK(child_[0]->GetDatasetSize(dataset_size));
  if (*dataset_size > 0) {
    dataset_size_ = *dataset_size;  // cache only a known size; -1 stays "unknown"
  }
  return Status::OK();
}
// Performs handling for when an eoe message is received.
// The base class implementation simply flows the eoe message to output. Derived classes
// may override if they need to perform special eoe handling.

@ -179,6 +179,10 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
/// \return Status - The error code return
Status GetNextInput(std::unique_ptr<DataBuffer> *p_buffer, int32_t worker_id = 0, int32_t child_index = 0);
/// \brief Gets the dataset size
/// \return Status - The status code return
virtual Status GetDatasetSize(int64_t *dataset_size);
/// \brief Performs handling for when an eoe message is received.
/// The base class implementation simply flows the eoe message to output. Derived classes
/// may override if they need to perform special eoe handling.
@ -406,6 +410,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
std::unordered_map<std::string, int32_t> column_name_id_map_; // Mapping between col index and col name
std::mutex column_name_map_mutex_; // For protecting shared access to the column map
CallbackManager callback_manager_; // Manages callbacks associated with a DatasetOp
int64_t dataset_size_; // Size of the dataset
private:
/// Sets the operator id.

@ -278,5 +278,14 @@ Status FilterOp::PreAccept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->PreRunOnNode(shared_from_base<FilterOp>(), modified);
}
// Get Dataset size
Status FilterOp::GetDatasetSize(int64_t *dataset_size) {
  // The number of rows surviving the predicate cannot be known without running
  // it, so the cached value (-1 until known) is returned; -1 tells TreeGetters
  // to fall back to iterating over the dataset and counting the rows.
  *dataset_size = dataset_size_;
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

@ -137,6 +137,11 @@ class FilterOp : public ParallelOp {
// @return Name of the current Op
std::string Name() const override { return kFilterOp; }
/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
private:
// predicate_func python callable which returns a boolean value.
py::function predicate_func_;

@ -191,5 +191,21 @@ Status RepeatOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(shared_from_base<RepeatOp>(), modified);
}
// Get Dataset size
Status RepeatOp::GetDatasetSize(int64_t *dataset_size) {
  // Serve the cached size. An infinite repeat (num_repeats_ == -1) also takes
  // this path, yielding the cached value (-1, i.e. unknown).
  if (dataset_size_ > 0 || num_repeats_ == -1) {
    *dataset_size = dataset_size_;
    return Status::OK();
  }
  int64_t child_rows;
  RETURN_IF_NOT_OK(child_[0]->GetDatasetSize(&child_rows));
  int64_t total = child_rows;
  if (child_rows > 0 && num_repeats_ > 0) {
    // Each repeat replays every child row once.
    total = child_rows * num_repeats_;
  }
  dataset_size_ = total;
  *dataset_size = total;
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

@ -133,6 +133,11 @@ class RepeatOp : public PipelineOp {
/// \@return Status - The error code return
Status Reset() override;
/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
// \brief Adds an operator to the repeat ops list of tracked leaf/eoe nodes
// \param[in] eoe_op The input leaf/eoe operator to add to the list
void AddToEoeList(std::shared_ptr<DatasetOp> eoe_op) { eoe_ops_.push_back(std::move(eoe_op)); }

@ -134,5 +134,21 @@ Status SkipOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(shared_from_base<SkipOp>(), modified);
}
// Get Dataset size
// Size after skipping: child size minus max_skips_, floored at 0. If the child
// size is unknown (-1) the -1 is now propagated so TreeGetters falls back to
// counting by iteration; previously an unknown child size was reported as 0.
Status SkipOp::GetDatasetSize(int64_t *dataset_size) {
  if (dataset_size_ > 0) {
    *dataset_size = dataset_size_;
    return Status::OK();
  }
  int64_t num_rows;
  RETURN_IF_NOT_OK(child_[0]->GetDatasetSize(&num_rows));
  if (num_rows < 0) {
    // Unknown child size: do not pretend the dataset is empty, and do not cache.
    *dataset_size = num_rows;
    return Status::OK();
  }
  *dataset_size = 0;
  if (max_skips_ >= 0 && max_skips_ < num_rows) {
    *dataset_size = num_rows - max_skips_;
  }
  dataset_size_ = *dataset_size;
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

@ -80,6 +80,11 @@ class SkipOp : public PipelineOp {
// @return - Status of the node visit.
Status Accept(NodePass *p, bool *modified) override;
/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
// Op name getter
// @return Name of the current Op
std::string Name() const override { return kSkipOp; }

@ -15,6 +15,7 @@
*/
#include "minddata/dataset/engine/datasetops/source/celeba_op.h"
#include <algorithm>
#include <fstream>
#include <iomanip>
#include "minddata/dataset/core/config_manager.h"
@ -445,5 +446,64 @@ Status CelebAOp::ComputeColMap() {
}
return Status::OK();
}
// Get Dataset size
// Reads the CelebA attr file header for the total row count, intersects with
// the requested usage partition ("train"/"valid"/"test"), then clamps by the
// sampler's num_samples. The result is cached in dataset_size_ so the files
// are not re-read on subsequent calls (matching the other source ops).
Status CelebAOp::GetDatasetSize(int64_t *dataset_size) {
  if (dataset_size_ > 0) {
    *dataset_size = dataset_size_;
    return Status::OK();
  }
  Path folder_path(folder_path_);
  std::string attr_file_name = (folder_path / "list_attr_celeba.txt").toString();
  std::ifstream attr_file(attr_file_name);
  if (!attr_file.is_open()) {
    RETURN_STATUS_UNEXPECTED("Invalid file, failed to open Celeba attr file: " + attr_file_name);
  }
  // First line of the attr file is the total number of rows.
  std::string rows_num;
  (void)getline(attr_file, rows_num);
  int64_t num_rows = 0;
  try {
    num_rows = static_cast<int64_t>(std::stoul(rows_num));
  } catch (std::invalid_argument &e) {
    RETURN_STATUS_UNEXPECTED(
      "Invalid data, failed to convert rows_num from attr_file to unsigned long, invalid argument: " + rows_num);
  } catch (std::out_of_range &e) {
    RETURN_STATUS_UNEXPECTED(
      "Invalid data, failed to convert rows_num from attr_file to unsigned long, out of range: " + rows_num);
  }
  if (usage_ != "all") {
    // Map the usage string to the partition label used in list_eval_partition.txt.
    char usage_type;
    if (usage_ == "train") {
      usage_type = '0';
    } else if (usage_ == "valid") {
      usage_type = '1';
    } else if (usage_ == "test") {
      usage_type = '2';
    } else {
      RETURN_STATUS_UNEXPECTED("Invalid usage.");
    }
    if (!partition_file_.is_open()) {
      partition_file_.open((folder_path / "list_eval_partition.txt").toString());
    }
    if (!partition_file_.is_open()) {
      std::string partition_file_name = "list_eval_partition.txt";
      RETURN_STATUS_UNEXPECTED("Invalid file, failed to open Celeba partition file: " + partition_file_name);
    }
    int64_t partition_num = 0;
    std::string line;
    while (getline(partition_file_, line)) {
      // Each line is "<image name> <partition id>". Guard against malformed or
      // short lines: the previous code stored find()'s result in an int and
      // called line.at(start + 1) unconditionally, which could throw.
      std::string::size_type split = line.find(' ');
      if (split == std::string::npos || split + 1 >= line.size()) {
        continue;
      }
      if (line.at(split + 1) == usage_type) {
        partition_num++;
      }
    }
    num_rows = std::min(num_rows, partition_num);
  }
  int64_t sample_size = sampler_->GetNumSamples();
  *dataset_size = sample_size > 0 ? std::min(num_rows, sample_size) : num_rows;
  dataset_size_ = *dataset_size;
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

@ -179,6 +179,11 @@ class CelebAOp : public ParallelOp, RandomAccessOp {
// @return Name of the current Op
std::string Name() const override { return "CelebAOp"; }
/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
private:
// Called first when function is called
// @return

@ -507,5 +507,21 @@ Status CifarOp::ComputeColMap() {
}
return Status::OK();
}
// Get Dataset size
Status CifarOp::GetDatasetSize(int64_t *dataset_size) {
  // Serve from cache when the size was already computed.
  if (dataset_size_ > 0) {
    *dataset_size = dataset_size_;
    return Status::OK();
  }
  int64_t total_rows = num_rows_;
  if (num_rows_ <= 0) {
    // Rows not loaded yet; count them directly from the files on disk.
    RETURN_IF_NOT_OK(CountTotalRows(folder_path_, usage_, cifar_type_ == CifarType::kCifar10, &total_rows));
  }
  const int64_t requested = sampler_->GetNumSamples();
  *dataset_size = requested > 0 ? std::min(total_rows, requested) : total_rows;
  dataset_size_ = *dataset_size;
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

@ -175,6 +175,11 @@ class CifarOp : public ParallelOp, public RandomAccessOp {
// @return Name of the current Op
std::string Name() const override { return "CifarOp"; }
/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
private:
// Initialize Sampler, calls sampler->Init() within
// @return Status - The error code return

@ -15,6 +15,7 @@
*/
#include "minddata/dataset/engine/datasetops/source/clue_op.h"
#include <algorithm>
#include <string>
#include <vector>
#include <fstream>
@ -563,5 +564,20 @@ Status ClueOp::Accept(NodePass *p, bool *modified) {
// Downcast shared pointer then call visitor
return p->RunOnNode(shared_from_base<ClueOp>(), modified);
}
// Get Dataset size
Status ClueOp::GetDatasetSize(int64_t *dataset_size) {
  // Already computed once? Hand back the cached answer.
  if (dataset_size_ > 0) {
    *dataset_size = dataset_size_;
    return Status::OK();
  }
  if (num_rows_per_shard_ <= 0) {
    RETURN_IF_NOT_OK(CalculateNumRowsPerShard());
  }
  const int64_t rows_in_shard = num_rows_per_shard_;
  const int64_t requested = num_samples_;
  *dataset_size = requested > 0 ? std::min(rows_in_shard, requested) : rows_in_shard;
  dataset_size_ = *dataset_size;
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

@ -193,6 +193,11 @@ class ClueOp : public ParallelOp {
// @return - Status of the node visit.
Status Accept(NodePass *p, bool *modified) override;
/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
private:
// The entry point for when workers are launched.
// @param worker_id - the id of the worker that is executing this function.

@ -679,5 +679,36 @@ Status CocoOp::ComputeColMap() {
}
return Status::OK();
}
// Get Dataset size
// Number of rows is the count of annotated images for the configured task,
// clamped by the sampler's num_samples; the result is cached in dataset_size_.
Status CocoOp::GetDatasetSize(int64_t *dataset_size) {
  if (dataset_size_ > 0) {
    *dataset_size = dataset_size_;
    return Status::OK();
  }
  int64_t num_rows = 0, sample_size;
  std::string task_type;
  switch (task_type_) {
    case TaskType::Detection:
      task_type = "Detection";
      break;
    case TaskType::Keypoint:
      task_type = "Keypoint";
      break;
    case TaskType::Panoptic:
      task_type = "Panoptic";
      break;
    case TaskType::Stuff:
      task_type = "Stuff";
      break;
  }
  if (image_ids_.size() == 0) {
    // Annotations not parsed yet; count rows straight from the files.
    RETURN_IF_NOT_OK(CountTotalRows(image_folder_path_, annotation_path_, task_type, &num_rows));
  } else {
    // Fix: previously this path left num_rows at 0 even though image ids were
    // already loaded, so a populated dataset was reported as empty.
    num_rows = static_cast<int64_t>(image_ids_.size());
  }
  sample_size = sampler_->GetNumSamples();
  *dataset_size = sample_size != 0 ? std::min(num_rows, sample_size) : num_rows;
  dataset_size_ = *dataset_size;
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

@ -209,6 +209,11 @@ class CocoOp : public ParallelOp, public RandomAccessOp {
// @return Name of the current Op
std::string Name() const override { return "CocoOp"; }
/// \brief Base-class override for GetDatasetSize
/// \param[out] dataset_size the size of the dataset
/// \return Status of the function
Status GetDatasetSize(int64_t *dataset_size) override;
private:
// Initialize Sampler, calls sampler->Init() within
// @return Status - The error code return

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save