return status from sampler IR functions

pull/12363/head
mohammad 4 years ago
parent ced5575387
commit 9e6fcc7f23

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -51,10 +51,12 @@ Status DatasetCacheImpl::CreateCacheLookupOp(int32_t num_workers, std::shared_pt
std::shared_ptr<SamplerObj> sampler) {
CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "Cache client has not been created yet.");
std::shared_ptr<CacheLookupOp> lookup_op = nullptr;
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler->SamplerBuild(&sampler_rt));
RETURN_IF_NOT_OK(CacheLookupOp::Builder()
.SetNumWorkers(num_workers)
.SetClient(cache_client_)
.SetSampler(sampler->SamplerBuild())
.SetSampler(sampler_rt)
.Build(&lookup_op));
*ds = lookup_op;

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -53,10 +53,13 @@ Status PreBuiltDatasetCache::CreateCacheLookupOp(int32_t num_workers, std::share
std::shared_ptr<SamplerObj> sampler) {
CHECK_FAIL_RETURN_UNEXPECTED(cache_client_ != nullptr, "Cache client has not been created yet.");
std::shared_ptr<CacheLookupOp> lookup_op = nullptr;
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler->SamplerBuild(&sampler_rt));
RETURN_IF_NOT_OK(CacheLookupOp::Builder()
.SetNumWorkers(num_workers)
.SetClient(cache_client_)
.SetSampler(sampler->SamplerBuild())
.SetSampler(sampler_rt)
.Build(&lookup_op));
*ds = lookup_op;

@ -74,10 +74,11 @@ std::shared_ptr<SamplerObj> CacheLookupNode::SamplerCopy() {
return std::static_pointer_cast<SamplerObj>(lookup_node_copy_);
}
std::shared_ptr<SamplerRT> CacheLookupNode::SamplerBuild() {
Status CacheLookupNode::SamplerBuild(std::shared_ptr<SamplerRT> *out) {
// Runtime cache lookup op should already been built, so we just return it here
auto lookup_op = std::dynamic_pointer_cast<CacheLookupOp>(lookup_op_);
return std::shared_ptr<SamplerRT>(lookup_op);
*out = std::shared_ptr<SamplerRT>(lookup_op);
return Status::OK();
}
} // namespace dataset

@ -48,8 +48,9 @@ class CacheLookupNode : public DatasetNode, public SamplerObj {
std::shared_ptr<DatasetNode> Copy() override;
/// \brief a base class override function to convert a SamplerObj class into a runtime sampler object
/// \return Shared pointers to the newly created Sampler
std::shared_ptr<SamplerRT> SamplerBuild() override;
/// \param[out] out Shared pointer to the newly created Sampler
/// \return The Status code of the function. It returns OK status if sampler is created successfully.
Status SamplerBuild(std::shared_ptr<SamplerRT> *out) override;
/// \brief a base class override function to copy a SamplerObj class
/// \return Shared pointers to the newly copied SamplerObj

@ -52,7 +52,9 @@ Status CacheNode::Build(std::vector<std::shared_ptr<DatasetOp>> *node_ops) {
RETURN_IF_NOT_OK(cache_->Build());
std::shared_ptr<DatasetOp> cache_op = nullptr;
RETURN_IF_NOT_OK(cache_->CreateCacheOp(num_workers_, &cache_op));
cache_op->SetSampler(sampler_->SamplerBuild());
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
cache_op->SetSampler(sampler_rt);
cache_op->set_total_repeats(GetTotalRepeats());
cache_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(cache_op);

@ -95,8 +95,10 @@ Status ConcatNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size
// calculate the size of the shard
int64_t shard_dataset_size = 0;
std::shared_ptr<SamplerRT> sampler_rt_base = nullptr;
if (sampler_) RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt_base));
std::shared_ptr<DistributedSamplerRT> sampler_rt =
sampler_ ? std::dynamic_pointer_cast<DistributedSamplerRT>(sampler_->SamplerBuild()) : nullptr;
sampler_ ? std::dynamic_pointer_cast<DistributedSamplerRT>(sampler_rt_base) : nullptr;
if (sampler_rt != nullptr) {
sampler_rt->SetNumRowsInDataset(total_dataset_size);
sampler_rt->InitSampler();
@ -123,8 +125,10 @@ Status ConcatNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops
if (children_flag_and_nums_.empty() || children_start_end_index_.empty()) {
op = std::make_shared<ConcatOp>(connector_que_size_);
} else {
op = std::make_shared<ConcatOp>(connector_que_size_, sampler_->SamplerBuild(), children_flag_and_nums_,
children_start_end_index_);
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
op =
std::make_shared<ConcatOp>(connector_que_size_, sampler_rt, children_flag_and_nums_, children_start_end_index_);
}
op->set_total_repeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());

@ -71,9 +71,11 @@ Status AlbumNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
// Argument that is not exposed to user in the API.
std::set<std::string> extensions = {};
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
auto album_op = std::make_shared<AlbumOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, decode_,
extensions, std::move(schema), std::move(sampler_->SamplerBuild()));
extensions, std::move(schema), std::move(sampler_rt));
album_op->set_total_repeats(GetTotalRepeats());
album_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(album_op);

@ -66,10 +66,11 @@ Status CelebANode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
// label is like this:0 1 0 0 1......
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("attr", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
auto celeba_op =
std::make_shared<CelebAOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, decode_, usage_,
extensions_, std::move(schema), std::move(sampler_->SamplerBuild()));
auto celeba_op = std::make_shared<CelebAOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
decode_, usage_, extensions_, std::move(schema), std::move(sampler_rt));
celeba_op->set_total_repeats(GetTotalRepeats());
celeba_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(celeba_op);
@ -140,7 +141,9 @@ Status CelebANode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size
num_rows = std::min(num_rows, partition_num);
}
sample_size = sampler_->SamplerBuild()->CalculateNumSamples(num_rows);
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
sample_size = sampler_rt->CalculateNumSamples(num_rows);
*dataset_size = sample_size;
return Status::OK();
}

@ -63,10 +63,12 @@ Status Cifar100Node::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
schema->AddColumn(ColDescriptor("coarse_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
RETURN_IF_NOT_OK(
schema->AddColumn(ColDescriptor("fine_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
auto cifar_op =
std::make_shared<CifarOp>(CifarOp::CifarType::kCifar100, usage_, num_workers_, rows_per_buffer_, dataset_dir_,
connector_que_size_, std::move(schema), std::move(sampler_->SamplerBuild()));
connector_que_size_, std::move(schema), std::move(sampler_rt));
cifar_op->set_total_repeats(GetTotalRepeats());
cifar_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(cifar_op);
@ -90,7 +92,10 @@ Status Cifar100Node::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &si
}
int64_t num_rows, sample_size;
RETURN_IF_NOT_OK(CifarOp::CountTotalRows(dataset_dir_, usage_, false, &num_rows));
sample_size = sampler_->SamplerBuild()->CalculateNumSamples(num_rows);
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
sample_size = sampler_rt->CalculateNumSamples(num_rows);
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();

@ -61,10 +61,12 @@ Status Cifar10Node::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_op
TensorShape scalar = TensorShape::CreateScalar();
RETURN_IF_NOT_OK(
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
auto cifar_op =
std::make_shared<CifarOp>(CifarOp::CifarType::kCifar10, usage_, num_workers_, rows_per_buffer_, dataset_dir_,
connector_que_size_, std::move(schema), std::move(sampler_->SamplerBuild()));
connector_que_size_, std::move(schema), std::move(sampler_rt));
cifar_op->set_total_repeats(GetTotalRepeats());
cifar_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(cifar_op);
@ -88,7 +90,10 @@ Status Cifar10Node::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &siz
}
int64_t num_rows, sample_size;
RETURN_IF_NOT_OK(CifarOp::CountTotalRows(dataset_dir_, usage_, true, &num_rows));
sample_size = sampler_->SamplerBuild()->CalculateNumSamples(num_rows);
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
sample_size = sampler_rt->CalculateNumSamples(num_rows);
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();

@ -119,9 +119,12 @@ Status CocoNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_UNEXPECTED(err_msg);
}
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
std::shared_ptr<CocoOp> op =
std::make_shared<CocoOp>(task_type, dataset_dir_, annotation_file_, num_workers_, rows_per_buffer_,
connector_que_size_, decode_, std::move(schema), std::move(sampler_->SamplerBuild()));
connector_que_size_, decode_, std::move(schema), std::move(sampler_rt));
op->set_total_repeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op);
@ -145,7 +148,9 @@ Status CocoNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_g
}
int64_t num_rows = 0, sample_size;
RETURN_IF_NOT_OK(CocoOp::CountTotalRows(dataset_dir_, annotation_file_, task_, &num_rows));
sample_size = sampler_->SamplerBuild()->CalculateNumSamples(num_rows);
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
sample_size = sampler_rt->CalculateNumSamples(num_rows);
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();

@ -78,7 +78,8 @@ Status GeneratorNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_
column_types_.push_back((col.type()));
}
}
std::shared_ptr<SamplerRT> sampler_rt = sampler_ ? sampler_->SamplerBuild() : nullptr;
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
if (sampler_) RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
// GeneratorOp's constructor takes in a prefetch_size, which isn't being set by user nor is it being used by
// GeneratorOp internally. Here it is given a zero which is the default in generator builder
@ -140,7 +141,9 @@ Status GeneratorNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &s
int64_t sample_size;
int64_t num_rows;
num_rows = source_len_;
sample_size = sampler_ ? sampler_->SamplerBuild()->CalculateNumSamples(num_rows) : num_rows;
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
if (sampler_) RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
sample_size = sampler_ ? sampler_rt->CalculateNumSamples(num_rows) : num_rows;
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();

@ -69,10 +69,12 @@ Status ImageFolderNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const nod
RETURN_IF_NOT_OK(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
RETURN_IF_NOT_OK(
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &scalar)));
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
auto op = std::make_shared<ImageFolderOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
recursive_, decode_, exts_, class_indexing_, std::move(schema),
std::move(sampler_->SamplerBuild()));
auto op =
std::make_shared<ImageFolderOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_, recursive_,
decode_, exts_, class_indexing_, std::move(schema), std::move(sampler_rt));
op->set_total_repeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op);
@ -95,7 +97,9 @@ Status ImageFolderNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter>
}
int64_t sample_size, num_rows;
RETURN_IF_NOT_OK(ImageFolderOp::CountRowsAndClasses(dataset_dir_, exts_, &num_rows, nullptr, {}));
sample_size = sampler_->SamplerBuild()->CalculateNumSamples(num_rows);
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
sample_size = sampler_rt->CalculateNumSamples(num_rows);
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();

@ -91,9 +91,11 @@ Status ManifestNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_o
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
std::shared_ptr<ManifestOp> manifest_op;
manifest_op =
std::make_shared<ManifestOp>(num_workers_, rows_per_buffer_, dataset_file_, connector_que_size_, decode_,
class_index_, std::move(schema), std::move(sampler_->SamplerBuild()), usage_);
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
manifest_op = std::make_shared<ManifestOp>(num_workers_, rows_per_buffer_, dataset_file_, connector_que_size_,
decode_, class_index_, std::move(schema), std::move(sampler_rt), usage_);
manifest_op->set_total_repeats(GetTotalRepeats());
manifest_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(manifest_op);
@ -118,7 +120,9 @@ Status ManifestNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &si
int64_t num_rows, sample_size;
int64_t num_classes; // dummy variable
RETURN_IF_NOT_OK(ManifestOp::CountTotalRows(dataset_file_, class_index_, usage_, &num_rows, &num_classes));
sample_size = sampler_->SamplerBuild()->CalculateNumSamples(num_rows);
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
sample_size = sampler_rt->CalculateNumSamples(num_rows);
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();

@ -57,9 +57,11 @@ Status MnistNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops)
TensorShape scalar = TensorShape::CreateScalar();
RETURN_IF_NOT_OK(
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
auto op = std::make_shared<MnistOp>(usage_, num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
std::move(schema), std::move(sampler_->SamplerBuild()));
std::move(schema), std::move(sampler_rt));
op->set_total_repeats(GetTotalRepeats());
op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(op);
@ -83,7 +85,9 @@ Status MnistNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_
}
int64_t num_rows, sample_size;
RETURN_IF_NOT_OK(MnistOp::CountTotalRows(dataset_dir_, usage_, &num_rows));
sample_size = sampler_->SamplerBuild()->CalculateNumSamples(num_rows);
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
sample_size = sampler_rt->CalculateNumSamples(num_rows);
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -42,11 +42,13 @@ namespace dataset {
// Constructor
SamplerObj::SamplerObj() {}
void SamplerObj::BuildChildren(std::shared_ptr<SamplerRT> sampler) {
Status SamplerObj::BuildChildren(std::shared_ptr<SamplerRT> *sampler) {
for (auto child : children_) {
auto sampler_rt = child->SamplerBuild();
sampler->AddChild(sampler_rt);
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(child->SamplerBuild(&sampler_rt));
RETURN_IF_NOT_OK((*sampler)->AddChild(sampler_rt));
}
return Status::OK();
}
Status SamplerObj::AddChildSampler(std::shared_ptr<SamplerObj> child) {
@ -113,12 +115,13 @@ Status DistributedSamplerObj::ValidateParams() {
return Status::OK();
}
std::shared_ptr<SamplerRT> DistributedSamplerObj::SamplerBuild() {
Status DistributedSamplerObj::SamplerBuild(std::shared_ptr<SamplerRT> *sampler) {
// runtime sampler object
auto sampler = std::make_shared<dataset::DistributedSamplerRT>(num_samples_, num_shards_, shard_id_, shuffle_, seed_,
offset_, even_dist_);
BuildChildren(sampler);
return sampler;
*sampler = std::make_shared<dataset::DistributedSamplerRT>(num_samples_, num_shards_, shard_id_, shuffle_, seed_,
offset_, even_dist_);
Status s = BuildChildren(sampler);
sampler = s.IsOk() ? sampler : nullptr;
return s;
}
#ifndef ENABLE_ANDROID
@ -186,11 +189,12 @@ Status PKSamplerObj::to_json(nlohmann::json *out_json) {
return Status::OK();
}
std::shared_ptr<SamplerRT> PKSamplerObj::SamplerBuild() {
Status PKSamplerObj::SamplerBuild(std::shared_ptr<SamplerRT> *sampler) {
// runtime sampler object
auto sampler = std::make_shared<dataset::PKSamplerRT>(num_samples_, num_val_, shuffle_);
BuildChildren(sampler);
return sampler;
*sampler = std::make_shared<dataset::PKSamplerRT>(num_samples_, num_val_, shuffle_);
Status s = BuildChildren(sampler);
sampler = s.IsOk() ? sampler : nullptr;
return s;
}
#ifndef ENABLE_ANDROID
@ -218,9 +222,14 @@ PreBuiltSamplerObj::PreBuiltSamplerObj(std::shared_ptr<mindrecord::ShardOperator
Status PreBuiltSamplerObj::ValidateParams() { return Status::OK(); }
std::shared_ptr<SamplerRT> PreBuiltSamplerObj::SamplerBuild() {
BuildChildren(sp_);
return sp_;
Status PreBuiltSamplerObj::SamplerBuild(std::shared_ptr<SamplerRT> *sampler) {
Status s = BuildChildren(&sp_);
if (s.IsOk())
*sampler = sp_;
else
*sampler = nullptr;
// FIXME: what to do with sp_ if status is not OK?
return s;
}
#ifndef ENABLE_ANDROID
@ -280,11 +289,12 @@ Status RandomSamplerObj::to_json(nlohmann::json *out_json) {
return Status::OK();
}
std::shared_ptr<SamplerRT> RandomSamplerObj::SamplerBuild() {
Status RandomSamplerObj::SamplerBuild(std::shared_ptr<SamplerRT> *sampler) {
// runtime sampler object
auto sampler = std::make_shared<dataset::RandomSamplerRT>(num_samples_, replacement_, reshuffle_each_epoch_);
BuildChildren(sampler);
return sampler;
*sampler = std::make_shared<dataset::RandomSamplerRT>(num_samples_, replacement_, reshuffle_each_epoch_);
Status s = BuildChildren(sampler);
sampler = s.IsOk() ? sampler : nullptr;
return s;
}
#ifndef ENABLE_ANDROID
@ -333,11 +343,12 @@ Status SequentialSamplerObj::to_json(nlohmann::json *out_json) {
return Status::OK();
}
std::shared_ptr<SamplerRT> SequentialSamplerObj::SamplerBuild() {
Status SequentialSamplerObj::SamplerBuild(std::shared_ptr<SamplerRT> *sampler) {
// runtime sampler object
auto sampler = std::make_shared<dataset::SequentialSamplerRT>(num_samples_, start_index_);
BuildChildren(sampler);
return sampler;
*sampler = std::make_shared<dataset::SequentialSamplerRT>(num_samples_, start_index_);
Status s = BuildChildren(sampler);
sampler = s.IsOk() ? sampler : nullptr;
return s;
}
#ifndef ENABLE_ANDROID
@ -362,11 +373,12 @@ Status SubsetSamplerObj::ValidateParams() {
return Status::OK();
}
std::shared_ptr<SamplerRT> SubsetSamplerObj::SamplerBuild() {
Status SubsetSamplerObj::SamplerBuild(std::shared_ptr<SamplerRT> *sampler) {
// runtime sampler object
auto sampler = std::make_shared<dataset::SubsetSamplerRT>(num_samples_, indices_);
BuildChildren(sampler);
return sampler;
*sampler = std::make_shared<dataset::SubsetSamplerRT>(num_samples_, indices_);
Status s = BuildChildren(sampler);
sampler = s.IsOk() ? sampler : nullptr;
return s;
}
#ifndef ENABLE_ANDROID
@ -399,11 +411,12 @@ Status SubsetSamplerObj::to_json(nlohmann::json *out_json) {
SubsetRandomSamplerObj::SubsetRandomSamplerObj(std::vector<int64_t> indices, int64_t num_samples)
: SubsetSamplerObj(std::move(indices), num_samples) {}
std::shared_ptr<SamplerRT> SubsetRandomSamplerObj::SamplerBuild() {
Status SubsetRandomSamplerObj::SamplerBuild(std::shared_ptr<SamplerRT> *sampler) {
// runtime sampler object
auto sampler = std::make_shared<dataset::SubsetRandomSamplerRT>(num_samples_, indices_);
BuildChildren(sampler);
return sampler;
*sampler = std::make_shared<dataset::SubsetRandomSamplerRT>(num_samples_, indices_);
Status s = BuildChildren(sampler);
sampler = s.IsOk() ? sampler : nullptr;
return s;
}
#ifndef ENABLE_ANDROID
@ -480,10 +493,11 @@ Status WeightedRandomSamplerObj::to_json(nlohmann::json *out_json) {
return Status::OK();
}
std::shared_ptr<SamplerRT> WeightedRandomSamplerObj::SamplerBuild() {
auto sampler = std::make_shared<dataset::WeightedRandomSamplerRT>(num_samples_, weights_, replacement_);
BuildChildren(sampler);
return sampler;
Status WeightedRandomSamplerObj::SamplerBuild(std::shared_ptr<SamplerRT> *sampler) {
*sampler = std::make_shared<dataset::WeightedRandomSamplerRT>(num_samples_, weights_, replacement_);
Status s = BuildChildren(sampler);
sampler = s.IsOk() ? sampler : nullptr;
return s;
}
} // namespace dataset

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -48,8 +48,9 @@ class SamplerObj {
virtual Status ValidateParams() = 0;
/// \brief Pure virtual function to convert a SamplerObj class into a runtime sampler object
/// \return Shared pointers to the newly created Sampler
virtual std::shared_ptr<SamplerRT> SamplerBuild() = 0;
/// \param[out] sampler Shared pointers to the newly created Sampler
/// \return The Status code of the function. It returns OK status if sampler is created successfully.
virtual Status SamplerBuild(std::shared_ptr<SamplerRT> *sampler) = 0;
/// \brief Pure virtual function to copy a SamplerObj class
/// \return Shared pointers to the newly copied SamplerObj
@ -78,7 +79,8 @@ class SamplerObj {
protected:
/// \brief A function that calls build on the children of this sampler
/// \param[in] sampler The samplerRT object built from this sampler
void BuildChildren(std::shared_ptr<SamplerRT> sampler);
/// \return the Status code returned
Status BuildChildren(std::shared_ptr<SamplerRT> *sampler);
std::vector<std::shared_ptr<SamplerObj>> children_;
};
@ -91,7 +93,7 @@ class DistributedSamplerObj : public SamplerObj {
~DistributedSamplerObj() = default;
std::shared_ptr<SamplerRT> SamplerBuild() override;
Status SamplerBuild(std::shared_ptr<SamplerRT> *sampler) override;
std::shared_ptr<SamplerObj> SamplerCopy() override {
auto sampler = std::make_shared<DistributedSamplerObj>(num_shards_, shard_id_, shuffle_, num_samples_, seed_,
@ -133,7 +135,7 @@ class PKSamplerObj : public SamplerObj {
~PKSamplerObj() = default;
std::shared_ptr<SamplerRT> SamplerBuild() override;
Status SamplerBuild(std::shared_ptr<SamplerRT> *sampler) override;
std::shared_ptr<SamplerObj> SamplerCopy() override {
auto sampler = std::make_shared<PKSamplerObj>(num_val_, shuffle_, num_samples_);
@ -169,7 +171,7 @@ class PreBuiltSamplerObj : public SamplerObj {
~PreBuiltSamplerObj() = default;
std::shared_ptr<SamplerRT> SamplerBuild() override;
Status SamplerBuild(std::shared_ptr<SamplerRT> *sampler) override;
#ifndef ENABLE_ANDROID
std::shared_ptr<mindrecord::ShardOperator> BuildForMindDataset() override;
@ -194,7 +196,7 @@ class RandomSamplerObj : public SamplerObj {
~RandomSamplerObj() = default;
std::shared_ptr<SamplerRT> SamplerBuild() override;
Status SamplerBuild(std::shared_ptr<SamplerRT> *sampler) override;
std::shared_ptr<SamplerObj> SamplerCopy() override {
auto sampler = std::make_shared<RandomSamplerObj>(replacement_, num_samples_, reshuffle_each_epoch_);
@ -227,7 +229,7 @@ class SequentialSamplerObj : public SamplerObj {
~SequentialSamplerObj() = default;
std::shared_ptr<SamplerRT> SamplerBuild() override;
Status SamplerBuild(std::shared_ptr<SamplerRT> *sampler) override;
std::shared_ptr<SamplerObj> SamplerCopy() override {
auto sampler = std::make_shared<SequentialSamplerObj>(start_index_, num_samples_);
@ -259,7 +261,7 @@ class SubsetSamplerObj : public SamplerObj {
~SubsetSamplerObj() = default;
std::shared_ptr<SamplerRT> SamplerBuild() override;
Status SamplerBuild(std::shared_ptr<SamplerRT> *sampler) override;
std::shared_ptr<SamplerObj> SamplerCopy() override {
auto sampler = std::make_shared<SubsetSamplerObj>(indices_, num_samples_);
@ -293,7 +295,7 @@ class SubsetRandomSamplerObj : public SubsetSamplerObj {
Status to_json(nlohmann::json *out_json) override;
std::shared_ptr<SamplerRT> SamplerBuild() override;
Status SamplerBuild(std::shared_ptr<SamplerRT> *sampler) override;
std::shared_ptr<SamplerObj> SamplerCopy() override {
auto sampler = std::make_shared<SubsetRandomSamplerObj>(indices_, num_samples_);
@ -316,7 +318,7 @@ class WeightedRandomSamplerObj : public SamplerObj {
~WeightedRandomSamplerObj() = default;
std::shared_ptr<SamplerRT> SamplerBuild() override;
Status SamplerBuild(std::shared_ptr<SamplerRT> *sampler) override;
std::shared_ptr<SamplerObj> SamplerCopy() override {
auto sampler = std::make_shared<WeightedRandomSamplerObj>(weights_, num_samples_, replacement_);

@ -108,11 +108,12 @@ Status VOCNode::Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) {
RETURN_IF_NOT_OK(schema->AddColumn(
ColDescriptor(std::string(kColumnTruncate), DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));
}
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
std::shared_ptr<VOCOp> voc_op;
voc_op =
std::make_shared<VOCOp>(task_type_, usage_, dataset_dir_, class_index_, num_workers_, rows_per_buffer_,
connector_que_size_, decode_, std::move(schema), std::move(sampler_->SamplerBuild()));
voc_op = std::make_shared<VOCOp>(task_type_, usage_, dataset_dir_, class_index_, num_workers_, rows_per_buffer_,
connector_que_size_, decode_, std::move(schema), std::move(sampler_rt));
voc_op->set_total_repeats(GetTotalRepeats());
voc_op->set_num_repeats_per_epoch(GetNumRepeatsPerEpoch());
node_ops->push_back(voc_op);
@ -135,7 +136,9 @@ Status VOCNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_ge
}
int64_t num_rows = 0, sample_size;
RETURN_IF_NOT_OK(VOCOp::CountTotalRows(dataset_dir_, task_, usage_, class_index_, &num_rows));
sample_size = sampler_->SamplerBuild()->CalculateNumSamples(num_rows);
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
sample_size = sampler_rt->CalculateNumSamples(num_rows);
*dataset_size = sample_size;
dataset_size_ = *dataset_size;
return Status::OK();

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -90,67 +90,74 @@ TEST_F(MindDataTestPipeline, TestCalculateNumSamples) {
int64_t num_rows = 30; // dummy variable for number of rows in the dataset
std::shared_ptr<SamplerObj> sampl = DistributedSampler(2, 1, false, 6);
EXPECT_NE(sampl, nullptr);
std::shared_ptr<SamplerRT> sampler_rt = sampl->SamplerBuild();
std::shared_ptr<SamplerRT> sampler_rt;
sampl->SamplerBuild(&sampler_rt);
EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 6);
sampl = PKSampler(3, false);
EXPECT_NE(sampl, nullptr);
sampler_rt = sampl->SamplerBuild();
sampl->SamplerBuild(&sampler_rt);
EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 30);
sampl = RandomSampler(false, 12);
EXPECT_NE(sampl, nullptr);
sampler_rt = sampl->SamplerBuild();
sampl->SamplerBuild(&sampler_rt);
EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 12);
sampl = SequentialSampler(0, 10);
EXPECT_NE(sampl, nullptr);
sampler_rt = sampl->SamplerBuild();
sampl->SamplerBuild(&sampler_rt);
EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 10);
std::vector<double> weights = {0.9, 0.8, 0.68, 0.7, 0.71, 0.6, 0.5, 0.4, 0.3, 0.5, 0.2, 0.1};
sampl = WeightedRandomSampler(weights, 12);
EXPECT_NE(sampl, nullptr);
sampler_rt = sampl->SamplerBuild();
sampl->SamplerBuild(&sampler_rt);
EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 12);
std::vector<int64_t> indices = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21};
sampl = SubsetRandomSampler(indices, 11);
EXPECT_NE(sampl, nullptr);
sampler_rt = sampl->SamplerBuild();
sampl->SamplerBuild(&sampler_rt);
EXPECT_EQ(sampler_rt->CalculateNumSamples(num_rows), 11);
// Testing chains
// Parent and child have num_samples
std::shared_ptr<SamplerObj> sampl1 = WeightedRandomSampler(weights, 12);
EXPECT_NE(sampl1, nullptr);
std::shared_ptr<SamplerRT> sampler_rt1 = sampl1->SamplerBuild();
std::shared_ptr<SamplerRT> sampler_rt1;
sampl1->SamplerBuild(&sampler_rt1);
std::shared_ptr<SamplerObj> sampl2 = SequentialSampler(0, 10);
EXPECT_NE(sampl2, nullptr);
std::shared_ptr<SamplerRT> sampler_rt2 = sampl2->SamplerBuild();
std::shared_ptr<SamplerRT> sampler_rt2;
sampl2->SamplerBuild(&sampler_rt2);
sampler_rt2->AddChild(sampler_rt1);
EXPECT_EQ(sampler_rt2->CalculateNumSamples(num_rows), 10);
// Parent doesn't have num_samples
std::shared_ptr<SamplerObj> sampl3 = WeightedRandomSampler(weights, 12);
EXPECT_NE(sampl3, nullptr);
std::shared_ptr<SamplerRT> sampler_rt3 = sampl3->SamplerBuild();
std::shared_ptr<SamplerRT> sampler_rt3;
sampl3->SamplerBuild(&sampler_rt3);
std::shared_ptr<SamplerObj> sampl4 = SubsetRandomSampler(indices);
EXPECT_NE(sampl4, nullptr);
std::shared_ptr<SamplerRT> sampler_rt4 = sampl4->SamplerBuild();
std::shared_ptr<SamplerRT> sampler_rt4;
sampl4->SamplerBuild(&sampler_rt4);
sampler_rt4->AddChild(sampler_rt3);
EXPECT_EQ(sampler_rt4->CalculateNumSamples(num_rows), 11);
// Child doesn't have num_samples
std::shared_ptr<SamplerObj> sampl5 = RandomSampler(false);
EXPECT_NE(sampl5, nullptr);
std::shared_ptr<SamplerRT> sampler_rt5 = sampl5->SamplerBuild();
std::shared_ptr<SamplerRT> sampler_rt5;
sampl5->SamplerBuild(&sampler_rt5);
std::shared_ptr<SamplerObj> sampl6 = PKSampler(3, false, 7);
EXPECT_NE(sampl6, nullptr);
std::shared_ptr<SamplerRT> sampler_rt6 = sampl6->SamplerBuild();
std::shared_ptr<SamplerRT> sampler_rt6;
sampl6->SamplerBuild(&sampler_rt6);
sampler_rt6->AddChild(sampler_rt5);
EXPECT_EQ(sampler_rt6->CalculateNumSamples(num_rows), 7);
}
@ -159,10 +166,14 @@ TEST_F(MindDataTestPipeline, TestSamplersMoveParameters) {
std::vector<int64_t> indices = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23};
std::shared_ptr<SamplerObj> sampl1 = SubsetRandomSampler(indices);
EXPECT_FALSE(indices.empty());
EXPECT_NE(sampl1->SamplerBuild(), nullptr);
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
sampl1->SamplerBuild(&sampler_rt);
EXPECT_NE(sampler_rt, nullptr);
std::shared_ptr<SamplerObj> sampl2 = SubsetRandomSampler(std::move(indices));
EXPECT_TRUE(indices.empty());
EXPECT_NE(sampl2->SamplerBuild(), nullptr);
std::shared_ptr<SamplerRT> sampler_rt2 = nullptr;
// Verify both the build status and the sampler produced from the moved-in indices; the expectation must look at
// sampler_rt2, the object this second build actually wrote to.
EXPECT_TRUE(sampl2->SamplerBuild(&sampler_rt2).IsOk());
EXPECT_NE(sampler_rt2, nullptr);
}
TEST_F(MindDataTestPipeline, TestWeightedRandomSamplerFail) {

Loading…
Cancel
Save