parent
2f565f4c20
commit
bd5a777f81
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,101 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "dataset/include/iterator.h"
|
||||
#include "dataset/core/client.h"
|
||||
#include "dataset/include/datasets.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api {
|
||||
|
||||
// Get the next row from the data pipeline.
|
||||
void Iterator::GetNextRow(TensorMap *row) {
|
||||
Status rc = iterator_->GetNextAsMap(row);
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "GetNextRow: Failed to get next row.";
|
||||
row->clear();
|
||||
}
|
||||
}
|
||||
|
||||
// Shut down the data pipeline.
|
||||
void Iterator::Stop() {
|
||||
// Releasing the iterator_ unique_ptre. This should trigger the destructor of iterator_.
|
||||
iterator_.reset();
|
||||
|
||||
// Release ownership of tree_ shared pointer. This will decrement the ref count.
|
||||
tree_.reset();
|
||||
}
|
||||
|
||||
// Function to build and launch the execution tree.
// Converts the client-side Dataset tree rooted at `ds` into a runtime ExecutionTree
// (BFS, one or more DatasetOps per Dataset node), assigns the root, prepares and
// launches the tree, then creates the DatasetIterator used by GetNextRow().
// Returns the first non-OK status encountered.
Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds) {
  // One time init
  Status rc;
  rc = GlobalInit();
  RETURN_IF_NOT_OK(rc);

  // Instantiate the execution tree
  tree_ = std::make_shared<ExecutionTree>();

  // Iterative BFS converting Dataset tree into runtime Execution tree.
  // Each queue entry pairs a Dataset node with the runtime op its children attach to.
  std::queue<std::pair<std::shared_ptr<Dataset>, std::shared_ptr<DatasetOp>>> q;

  if (ds != nullptr) {
    // Convert the current root node.
    // NOTE(review): ds->Build() is dereferenced (front()) before any null check;
    // RETURN_UNEXPECTED_IF_NULL only guards the extracted op -- confirm Build()
    // cannot return nullptr here.
    auto root_op = ds->Build()->front();
    RETURN_UNEXPECTED_IF_NULL(root_op);

    RETURN_IF_NOT_OK(tree_->AssociateNode(root_op));

    q.push(std::make_pair(ds, root_op));

    // Traverse down to the children and convert them to the corresponding DatasetOps (i.e. execution tree nodes)
    while (!q.empty()) {
      auto node_pair = q.front();
      q.pop();
      // Iterate through all the direct children of the first element in our BFS queue
      for (auto child : node_pair.first->children) {
        auto child_ops = child->Build();
        RETURN_UNEXPECTED_IF_NULL(child_ops);
        auto node_op = node_pair.second;
        // Iterate through all the DatasetOps returned by calling Build on the last Dataset object, associate them
        // with the execution tree and add the child and parent relationship between the nodes.
        // Note that some Dataset objects might return more than one DatasetOps
        // e.g. MapDataset will return MapOp and ProjectOp if project_columns is set for MapDataset.
        // node_op is re-pointed after each link so the ops chain parent-to-child in order.
        for (auto child_op : *child_ops) {
          RETURN_IF_NOT_OK(tree_->AssociateNode(child_op));
          RETURN_IF_NOT_OK(node_op->AddChild(child_op));
          node_op = child_op;
        }
        // Add the child and the last element of the returned DatasetOps (which is now the leaf node in our current
        // execution tree) to the BFS queue
        q.push(std::make_pair(child, child_ops->back()));
      }
    }
    RETURN_IF_NOT_OK(tree_->AssignRoot(root_op));
  }

  // Launch the execution tree.
  // NOTE(review): when ds is nullptr the tree is still prepared and launched with no
  // root assigned -- confirm this is intended rather than an error path.
  RETURN_IF_NOT_OK(tree_->Prepare());
  RETURN_IF_NOT_OK(tree_->Launch());
  iterator_ = std::make_unique<DatasetIterator>(tree_);
  RETURN_UNEXPECTED_IF_NULL(iterator_);

  return rc;
}
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
@ -0,0 +1,224 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "dataset/include/samplers.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/sampler.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/distributed_sampler.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/random_sampler.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/sequential_sampler.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h"
|
||||
#include "dataset/engine/datasetops/source/sampler/pk_sampler.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api {
|
||||
|
||||
SamplerObj::SamplerObj() {}
|
||||
|
||||
/// Function to create a Distributed Sampler.
|
||||
std::shared_ptr<DistributedSamplerObj> DistributedSampler(int64_t num_shards, int64_t shard_id, bool shuffle,
|
||||
int64_t num_samples, uint32_t seed) {
|
||||
auto sampler = std::make_shared<DistributedSamplerObj>(num_shards, shard_id, shuffle, num_samples, seed);
|
||||
// Input validation
|
||||
if (!sampler->ValidateParams()) {
|
||||
return nullptr;
|
||||
}
|
||||
return sampler;
|
||||
}
|
||||
|
||||
/// Function to create a PK Sampler.
|
||||
std::shared_ptr<PKSamplerObj> PKSampler(int64_t num_val, bool shuffle, int64_t num_samples) {
|
||||
auto sampler = std::make_shared<PKSamplerObj>(num_val, shuffle, num_samples);
|
||||
// Input validation
|
||||
if (!sampler->ValidateParams()) {
|
||||
return nullptr;
|
||||
}
|
||||
return sampler;
|
||||
}
|
||||
|
||||
/// Function to create a Random Sampler.
|
||||
std::shared_ptr<RandomSamplerObj> RandomSampler(bool replacement, int64_t num_samples) {
|
||||
auto sampler = std::make_shared<RandomSamplerObj>(replacement, num_samples);
|
||||
// Input validation
|
||||
if (!sampler->ValidateParams()) {
|
||||
return nullptr;
|
||||
}
|
||||
return sampler;
|
||||
}
|
||||
|
||||
/// Function to create a Sequential Sampler.
|
||||
std::shared_ptr<SequentialSamplerObj> SequentialSampler(int64_t start_index, int64_t num_samples) {
|
||||
auto sampler = std::make_shared<SequentialSamplerObj>(start_index, num_samples);
|
||||
// Input validation
|
||||
if (!sampler->ValidateParams()) {
|
||||
return nullptr;
|
||||
}
|
||||
return sampler;
|
||||
}
|
||||
|
||||
/// Function to create a Subset Random Sampler.
|
||||
std::shared_ptr<SubsetRandomSamplerObj> SubsetRandomSampler(const std::vector<int64_t> &indices, int64_t num_samples) {
|
||||
auto sampler = std::make_shared<SubsetRandomSamplerObj>(indices, num_samples);
|
||||
// Input validation
|
||||
if (!sampler->ValidateParams()) {
|
||||
return nullptr;
|
||||
}
|
||||
return sampler;
|
||||
}
|
||||
|
||||
/// Function to create a Weighted Random Sampler.
|
||||
std::shared_ptr<WeightedRandomSamplerObj> WeightedRandomSampler(const std::vector<double> &weights, int64_t num_samples,
|
||||
bool replacement) {
|
||||
auto sampler = std::make_shared<WeightedRandomSamplerObj>(weights, num_samples, replacement);
|
||||
// Input validation
|
||||
if (!sampler->ValidateParams()) {
|
||||
return nullptr;
|
||||
}
|
||||
return sampler;
|
||||
}
|
||||
|
||||
/* ####################################### Derived Sampler classes ################################# */
|
||||
|
||||
// DistributedSampler
// Constructor: stores the sharding configuration; checking is deferred to ValidateParams().
DistributedSamplerObj::DistributedSamplerObj(int64_t num_shards, int64_t shard_id, bool shuffle, int64_t num_samples,
                                             uint32_t seed)
    : num_shards_(num_shards), shard_id_(shard_id), shuffle_(shuffle), num_samples_(num_samples), seed_(seed) {}
|
||||
|
||||
bool DistributedSamplerObj::ValidateParams() {
|
||||
if (num_shards_ <= 0) {
|
||||
MS_LOG(ERROR) << "DistributedSampler: invalid num_shards: " << num_shards_;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (shard_id_ < 0 || shard_id_ >= num_shards_) {
|
||||
MS_LOG(ERROR) << "DistributedSampler: invalid input, shard_id: " << shard_id_ << ", num_shards: " << num_shards_;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (num_samples_ < 0) {
|
||||
MS_LOG(ERROR) << "DistributedSampler: invalid num_samples: " << num_samples_;
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::shared_ptr<Sampler> DistributedSamplerObj::Build() {
|
||||
return std::make_shared<dataset::DistributedSampler>(num_samples_, num_shards_, shard_id_, shuffle_, seed_);
|
||||
}
|
||||
|
||||
// PKSampler
// Constructor: stores the configuration; checking is deferred to ValidateParams().
PKSamplerObj::PKSamplerObj(int64_t num_val, bool shuffle, int64_t num_samples)
    : num_val_(num_val), shuffle_(shuffle), num_samples_(num_samples) {}
|
||||
|
||||
bool PKSamplerObj::ValidateParams() {
|
||||
if (num_val_ <= 0) {
|
||||
MS_LOG(ERROR) << "PKSampler: invalid num_val: " << num_val_;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (num_samples_ < 0) {
|
||||
MS_LOG(ERROR) << "PKSampler: invalid num_samples: " << num_samples_;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::shared_ptr<Sampler> PKSamplerObj::Build() {
|
||||
return std::make_shared<dataset::PKSampler>(num_samples_, num_val_, shuffle_);
|
||||
}
|
||||
|
||||
// RandomSampler
// Constructor: stores the configuration; checking is deferred to ValidateParams().
RandomSamplerObj::RandomSamplerObj(bool replacement, int64_t num_samples)
    : replacement_(replacement), num_samples_(num_samples) {}
|
||||
|
||||
bool RandomSamplerObj::ValidateParams() {
|
||||
if (num_samples_ < 0) {
|
||||
MS_LOG(ERROR) << "RandomSampler: invalid num_samples: " << num_samples_;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::shared_ptr<Sampler> RandomSamplerObj::Build() {
|
||||
bool reshuffle_each_epoch = true;
|
||||
auto sampler = std::make_shared<dataset::RandomSampler>(num_samples_, replacement_, reshuffle_each_epoch);
|
||||
return sampler;
|
||||
}
|
||||
|
||||
// SequentialSampler
// Constructor: stores the configuration; checking is deferred to ValidateParams().
SequentialSamplerObj::SequentialSamplerObj(int64_t start_index, int64_t num_samples)
    : start_index_(start_index), num_samples_(num_samples) {}
|
||||
|
||||
bool SequentialSamplerObj::ValidateParams() {
|
||||
if (num_samples_ < 0) {
|
||||
MS_LOG(ERROR) << "SequentialSampler: invalid num_samples: " << num_samples_;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (start_index_ < 0) {
|
||||
MS_LOG(ERROR) << "SequentialSampler: invalid start_index: " << start_index_;
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::shared_ptr<Sampler> SequentialSamplerObj::Build() {
|
||||
auto sampler = std::make_shared<dataset::SequentialSampler>(num_samples_, start_index_);
|
||||
return sampler;
|
||||
}
|
||||
|
||||
// SubsetRandomSampler
// Constructor: copies the index list and stores the configuration; checking is
// deferred to ValidateParams().
SubsetRandomSamplerObj::SubsetRandomSamplerObj(const std::vector<int64_t> &indices, int64_t num_samples)
    : indices_(indices), num_samples_(num_samples) {}
|
||||
|
||||
bool SubsetRandomSamplerObj::ValidateParams() {
|
||||
if (num_samples_ < 0) {
|
||||
MS_LOG(ERROR) << "SubsetRandomSampler: invalid num_samples: " << num_samples_;
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::shared_ptr<Sampler> SubsetRandomSamplerObj::Build() {
|
||||
auto sampler = std::make_shared<dataset::SubsetRandomSampler>(num_samples_, indices_);
|
||||
return sampler;
|
||||
}
|
||||
|
||||
// WeightedRandomSampler
// Constructor: copies the weight list and stores the configuration; checking is
// deferred to ValidateParams().
WeightedRandomSamplerObj::WeightedRandomSamplerObj(const std::vector<double> &weights, int64_t num_samples,
                                                   bool replacement)
    : weights_(weights), num_samples_(num_samples), replacement_(replacement) {}
|
||||
|
||||
bool WeightedRandomSamplerObj::ValidateParams() {
|
||||
if (num_samples_ < 0) {
|
||||
MS_LOG(ERROR) << "WeightedRandomSampler: invalid num_samples: " << num_samples_;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::shared_ptr<Sampler> WeightedRandomSamplerObj::Build() {
|
||||
auto sampler = std::make_shared<dataset::WeightedRandomSampler>(num_samples_, weights_, replacement_);
|
||||
return sampler;
|
||||
}
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
File diff suppressed because it is too large
Load Diff
@ -1,19 +1,32 @@
|
||||
add_subdirectory(sampler)
|
||||
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
|
||||
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
|
||||
add_library(engine-datasetops-source OBJECT
|
||||
generator_op.cc
|
||||
|
||||
set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
|
||||
io_block.cc
|
||||
mindrecord_op.cc
|
||||
tf_reader_op.cc
|
||||
image_folder_op.cc
|
||||
mnist_op.cc
|
||||
voc_op.cc
|
||||
coco_op.cc
|
||||
manifest_op.cc
|
||||
cifar_op.cc
|
||||
random_data_op.cc
|
||||
celeba_op.cc
|
||||
text_file_op.cc
|
||||
clue_op.cc
|
||||
)
|
||||
)
|
||||
|
||||
set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
|
||||
${DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES}
|
||||
mindrecord_op.cc
|
||||
tf_reader_op.cc
|
||||
)
|
||||
|
||||
if (ENABLE_PYTHON)
|
||||
set(DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES
|
||||
${DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES}
|
||||
generator_op.cc
|
||||
voc_op.cc
|
||||
manifest_op.cc
|
||||
)
|
||||
endif()
|
||||
|
||||
add_library(engine-datasetops-source OBJECT ${DATASET_ENGINE_DATASETOPS_SOURCE_SRC_FILES})
|
@ -1,12 +1,21 @@
|
||||
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
|
||||
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
|
||||
add_library(engine-datasetops-source-sampler OBJECT
|
||||
|
||||
set(DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SRC_FILES
|
||||
distributed_sampler.cc
|
||||
pk_sampler.cc
|
||||
python_sampler.cc
|
||||
random_sampler.cc
|
||||
sampler.cc
|
||||
sequential_sampler.cc
|
||||
subset_random_sampler.cc
|
||||
weighted_random_sampler.cc
|
||||
)
|
||||
|
||||
if (ENABLE_PYTHON)
|
||||
set(DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SRC_FILES
|
||||
${DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SRC_FILES}
|
||||
python_sampler.cc
|
||||
)
|
||||
endif()
|
||||
|
||||
add_library(engine-datasetops-source-sampler OBJECT ${DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_SRC_FILES})
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue