Remove api namespace

pull/8139/head
hesham 4 years ago
parent 11028fb31c
commit 5169fb4c42
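In short: this commit flattens the C++ API out of the nested mindspore::dataset::api namespace into mindspore::dataset, and renames the runtime sampler classes with an RT suffix (Sampler -> SamplerRT, SequentialSampler -> SequentialSamplerRT, and so on) so the plain names stay free for the API-layer sampler objects. A minimal before/after sketch of the effect on calling code; the Cifar10 factory and RandomSampler arguments are illustrative assumptions, not taken from this diff:

// Before this commit: API types sat one namespace deeper.
// auto ds = mindspore::dataset::api::Cifar10("/path/to/cifar10", "all",
//                                            mindspore::dataset::api::RandomSampler(false, 10));

// After this commit: the api layer is folded into mindspore::dataset.
using namespace mindspore::dataset;
auto ds = Cifar10("/path/to/cifar10", "all", RandomSampler(false, 10));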

@@ -21,7 +21,6 @@
 namespace mindspore {
 namespace dataset {
-namespace api {
 // Config operations for setting and getting the configuration.
 namespace config {
@@ -104,6 +103,5 @@ bool load(std::string file) {
 }
 }  // namespace config
-}  // namespace api
 }  // namespace dataset
 }  // namespace mindspore

File diff suppressed because it is too large

@@ -26,7 +26,6 @@
 namespace mindspore {
 namespace dataset {
-namespace api {
 Execute::Execute(std::shared_ptr<TensorOperation> op) : op_(std::move(op)) {}
@@ -54,6 +53,5 @@ std::shared_ptr<tensor::MSTensor> Execute::operator()(std::shared_ptr<tensor::MS
   return std::make_shared<tensor::DETensor>(std::move(de_output));
 }
-}  // namespace api
 }  // namespace dataset
 }  // namespace mindspore

@@ -20,7 +20,6 @@
 namespace mindspore {
 namespace dataset {
-namespace api {
 // Get the next row from the data pipeline.
 bool Iterator::GetNextRow(TensorMap *row) {
@@ -45,19 +44,18 @@ bool Iterator::GetNextRow(TensorVec *row) {
 }
 // Shut down the data pipeline.
-void Iterator::Stop() { runtime_context->Terminate(); }
+void Iterator::Stop() { runtime_context_->Terminate(); }
 // Function to build and launch the execution tree.
 Status Iterator::BuildAndLaunchTree(std::shared_ptr<Dataset> ds) {
-  runtime_context = std::make_unique<RuntimeContext>();
-  RETURN_IF_NOT_OK(runtime_context->Init());
+  runtime_context_ = std::make_unique<RuntimeContext>();
+  RETURN_IF_NOT_OK(runtime_context_->Init());
   auto consumer = std::make_unique<IteratorConsumer>();
   consumer_ = consumer.get();
   RETURN_IF_NOT_OK(consumer->Init(ds->IRNode()));
-  runtime_context->AssignConsumer(std::move(consumer));
+  runtime_context_->AssignConsumer(std::move(consumer));
   return Status::OK();
 }
-}  // namespace api
 }  // namespace dataset
 }  // namespace mindspore
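(The runtime_context member also gains a trailing underscore here, per the member-naming convention.) For orientation, a rough sketch of how this Iterator is driven from user code; the pipeline object ds and the CreateIterator factory are assumptions from the surrounding API, not part of this hunk:

// Sketch, assuming a built pipeline `ds` that exposes CreateIterator().
std::shared_ptr<Iterator> itr = ds->CreateIterator();  // internally calls BuildAndLaunchTree
TensorMap row;
while (itr->GetNextRow(&row)) {  // false once the pipeline is drained
  // consume row["image"], row["label"], ...
}
itr->Stop();  // tears down the RuntimeContext created in BuildAndLaunchTree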

@@ -27,59 +27,59 @@
 namespace mindspore {
 namespace dataset {
-PYBIND_REGISTER(Sampler, 0, ([](const py::module *m) {
-                  (void)py::class_<Sampler, std::shared_ptr<Sampler>>(*m, "Sampler")
+PYBIND_REGISTER(SamplerRT, 0, ([](const py::module *m) {
+                  (void)py::class_<SamplerRT, std::shared_ptr<SamplerRT>>(*m, "Sampler")
                     .def("set_num_rows",
-                         [](Sampler &self, int64_t rows) { THROW_IF_ERROR(self.SetNumRowsInDataset(rows)); })
+                         [](SamplerRT &self, int64_t rows) { THROW_IF_ERROR(self.SetNumRowsInDataset(rows)); })
                     .def("set_num_samples",
-                         [](Sampler &self, int64_t samples) { THROW_IF_ERROR(self.SetNumSamples(samples)); })
-                    .def("initialize", [](Sampler &self) { THROW_IF_ERROR(self.InitSampler()); })
+                         [](SamplerRT &self, int64_t samples) { THROW_IF_ERROR(self.SetNumSamples(samples)); })
+                    .def("initialize", [](SamplerRT &self) { THROW_IF_ERROR(self.InitSampler()); })
                     .def("get_indices",
-                         [](Sampler &self) {
+                         [](SamplerRT &self) {
                            py::array ret;
                            THROW_IF_ERROR(self.GetAllIdsThenReset(&ret));
                            return ret;
                          })
-                    .def("add_child", [](std::shared_ptr<Sampler> self, std::shared_ptr<Sampler> child) {
+                    .def("add_child", [](std::shared_ptr<SamplerRT> self, std::shared_ptr<SamplerRT> child) {
                       THROW_IF_ERROR(self->AddChild(child));
                     });
                 }));
-PYBIND_REGISTER(DistributedSampler, 1, ([](const py::module *m) {
-                  (void)py::class_<DistributedSampler, Sampler, std::shared_ptr<DistributedSampler>>(
+PYBIND_REGISTER(DistributedSamplerRT, 1, ([](const py::module *m) {
+                  (void)py::class_<DistributedSamplerRT, SamplerRT, std::shared_ptr<DistributedSamplerRT>>(
                     *m, "DistributedSampler")
                     .def(py::init<int64_t, int64_t, int64_t, bool, uint32_t, int64_t>());
                 }));
-PYBIND_REGISTER(PKSampler, 1, ([](const py::module *m) {
-                  (void)py::class_<PKSampler, Sampler, std::shared_ptr<PKSampler>>(*m, "PKSampler")
+PYBIND_REGISTER(PKSamplerRT, 1, ([](const py::module *m) {
+                  (void)py::class_<PKSamplerRT, SamplerRT, std::shared_ptr<PKSamplerRT>>(*m, "PKSampler")
                     .def(py::init<int64_t, int64_t, bool>());
                 }));
-PYBIND_REGISTER(PythonSampler, 1, ([](const py::module *m) {
-                  (void)py::class_<PythonSampler, Sampler, std::shared_ptr<PythonSampler>>(*m, "PythonSampler")
+PYBIND_REGISTER(PythonSamplerRT, 1, ([](const py::module *m) {
+                  (void)py::class_<PythonSamplerRT, SamplerRT, std::shared_ptr<PythonSamplerRT>>(*m, "PythonSampler")
                     .def(py::init<int64_t, py::object>());
                 }));
-PYBIND_REGISTER(RandomSampler, 1, ([](const py::module *m) {
-                  (void)py::class_<RandomSampler, Sampler, std::shared_ptr<RandomSampler>>(*m, "RandomSampler")
+PYBIND_REGISTER(RandomSamplerRT, 1, ([](const py::module *m) {
+                  (void)py::class_<RandomSamplerRT, SamplerRT, std::shared_ptr<RandomSamplerRT>>(*m, "RandomSampler")
                     .def(py::init<int64_t, bool, bool>());
                 }));
-PYBIND_REGISTER(SequentialSampler, 1, ([](const py::module *m) {
-                  (void)py::class_<SequentialSampler, Sampler, std::shared_ptr<SequentialSampler>>(*m,
-                                                                                                   "SequentialSampler")
+PYBIND_REGISTER(SequentialSamplerRT, 1, ([](const py::module *m) {
+                  (void)py::class_<SequentialSamplerRT, SamplerRT, std::shared_ptr<SequentialSamplerRT>>(
+                    *m, "SequentialSampler")
                     .def(py::init<int64_t, int64_t>());
                 }));
-PYBIND_REGISTER(SubsetRandomSampler, 1, ([](const py::module *m) {
-                  (void)py::class_<SubsetRandomSampler, Sampler, std::shared_ptr<SubsetRandomSampler>>(
+PYBIND_REGISTER(SubsetRandomSamplerRT, 1, ([](const py::module *m) {
+                  (void)py::class_<SubsetRandomSamplerRT, SamplerRT, std::shared_ptr<SubsetRandomSamplerRT>>(
                     *m, "SubsetRandomSampler")
                     .def(py::init<int64_t, std::vector<int64_t>>());
                 }));
-PYBIND_REGISTER(WeightedRandomSampler, 1, ([](const py::module *m) {
-                  (void)py::class_<WeightedRandomSampler, Sampler, std::shared_ptr<WeightedRandomSampler>>(
+PYBIND_REGISTER(WeightedRandomSamplerRT, 1, ([](const py::module *m) {
+                  (void)py::class_<WeightedRandomSamplerRT, SamplerRT, std::shared_ptr<WeightedRandomSamplerRT>>(
                     *m, "WeightedRandomSampler")
                     .def(py::init<int64_t, std::vector<double>, bool>());
                 }));
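Note that only the C++ identifiers change in these bindings; the strings handed to py::class_ ("Sampler", "PKSampler", ...) are what Python sees, so Python-side code is untouched by the rename. A self-contained pybind11 sketch of that pattern (toy module, not MindSpore code):

#include <memory>
#include <pybind11/pybind11.h>
namespace py = pybind11;

// The C++ type carries the new RT suffix ...
class SamplerRT {
 public:
  void InitSampler() {}
};

PYBIND11_MODULE(example, m) {
  // ... but the name exposed to Python is still "Sampler".
  py::class_<SamplerRT, std::shared_ptr<SamplerRT>>(m, "Sampler")
      .def(py::init<>())
      .def("initialize", [](SamplerRT &self) { self.InitSampler(); });
}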

@@ -1140,7 +1140,7 @@ Status DEPipeline::ParseConcatOp(const py::dict &args, std::shared_ptr<DatasetOp
     if (!value.is_none()) {
       if (key == "sampler") {
         auto create = py::reinterpret_borrow<py::object>(value).attr("create");
-        std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
+        std::shared_ptr<SamplerRT> sampler = create().cast<std::shared_ptr<SamplerRT>>();
         (void)builder->SetSampler(std::move(sampler));
       }
       if (key == "children_flag_and_nums") {
@@ -1164,7 +1164,7 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<Dataset
   // Required arguments
   std::vector<std::string> files_list;
   std::shared_ptr<CacheClient> cache_client = nullptr;
-  std::shared_ptr<Sampler> sampler = nullptr;
+  std::shared_ptr<SamplerRT> sampler = nullptr;
   int num_workers = 0;
   std::shared_ptr<TFReaderOp::Builder> builder = std::make_shared<TFReaderOp::Builder>();
   if (!args["dataset_files"].is_none()) {
@@ -1210,7 +1210,7 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<Dataset
         cache_client = value.cast<std::shared_ptr<CacheClient>>();
       } else if (key == "sampler") {
         auto create = py::reinterpret_borrow<py::object>(value).attr("create");
-        sampler = create().cast<std::shared_ptr<Sampler>>();
+        sampler = create().cast<std::shared_ptr<SamplerRT>>();
       }
     }
   }
@@ -1234,7 +1234,7 @@ Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr<Dataset
   } else if (cache_client) {
     const int64_t num_samples = 0;
     const int64_t start_index = 0;
-    sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
+    sampler = std::make_shared<SequentialSamplerRT>(num_samples, start_index);
     (void)builder->SetSampler(std::move(sampler));
   }
@@ -1308,7 +1308,7 @@ Status DEPipeline::ParseImageFolderOp(const py::dict &args, std::shared_ptr<Data
       (void)builder->SetNumWorkers(num_workers);
     } else if (key == "sampler") {
       auto create = py::reinterpret_borrow<py::object>(value).attr("create");
-      std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
+      std::shared_ptr<SamplerRT> sampler = create().cast<std::shared_ptr<SamplerRT>>();
       (void)builder->SetSampler(std::move(sampler));
     } else if (key == "extensions") {
       (void)builder->SetExtensions(ToStringSet(value));
@@ -1363,7 +1363,7 @@ Status DEPipeline::ParseManifestOp(const py::dict &args, std::shared_ptr<Dataset
       (void)builder->SetNumWorkers(num_workers);
     } else if (key == "sampler") {
       auto create = py::reinterpret_borrow<py::object>(value).attr("create");
-      std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
+      std::shared_ptr<SamplerRT> sampler = create().cast<std::shared_ptr<SamplerRT>>();
       (void)builder->SetSampler(std::move(sampler));
     } else if (key == "class_indexing") {
       (void)builder->SetClassIndex(ToStringMap(value));
@@ -1416,7 +1416,7 @@ Status DEPipeline::ParseVOCOp(const py::dict &args, std::shared_ptr<DatasetOp> *
       (void)builder->SetNumWorkers(num_workers);
     } else if (key == "sampler") {
       auto create = py::reinterpret_borrow<py::object>(value).attr("create");
-      std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
+      std::shared_ptr<SamplerRT> sampler = create().cast<std::shared_ptr<SamplerRT>>();
       (void)builder->SetSampler(std::move(sampler));
     } else if (key == "decode") {
       (void)builder->SetDecode(ToBool(value));
@@ -1478,7 +1478,7 @@ Status DEPipeline::ParseCocoOp(const py::dict &args, std::shared_ptr<DatasetOp>
       (void)builder->SetNumWorkers(num_workers);
     } else if (key == "sampler") {
       auto create = py::reinterpret_borrow<py::object>(value).attr("create");
-      std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
+      std::shared_ptr<SamplerRT> sampler = create().cast<std::shared_ptr<SamplerRT>>();
       (void)builder->SetSampler(std::move(sampler));
     } else if (key == "decode") {
       (void)builder->SetDecode(ToBool(value));
@@ -1529,7 +1529,7 @@ Status DEPipeline::ParseCifar10Op(const py::dict &args, std::shared_ptr<DatasetO
       (void)builder->SetNumWorkers(num_workers);
     } else if (key == "sampler") {
       auto create = py::reinterpret_borrow<py::object>(value).attr("create");
-      std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
+      std::shared_ptr<SamplerRT> sampler = create().cast<std::shared_ptr<SamplerRT>>();
       (void)builder->SetSampler(std::move(sampler));
     } else if (key == "usage") {
       (void)builder->SetUsage(ToString(value));
@@ -1583,7 +1583,7 @@ Status DEPipeline::ParseCifar100Op(const py::dict &args, std::shared_ptr<Dataset
       (void)builder->SetNumWorkers(num_workers);
     } else if (key == "sampler") {
       auto create = py::reinterpret_borrow<py::object>(value).attr("create");
-      std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
+      std::shared_ptr<SamplerRT> sampler = create().cast<std::shared_ptr<SamplerRT>>();
       (void)builder->SetSampler(std::move(sampler));
     } else if (key == "usage") {
       (void)builder->SetUsage(ToString(value));
@@ -1618,7 +1618,7 @@ Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr<Datas
   // Required arguments
   RandomDataOp::Builder builder;
   std::shared_ptr<CacheClient> cache_client = nullptr;
-  std::shared_ptr<Sampler> sampler = nullptr;
+  std::shared_ptr<SamplerRT> sampler = nullptr;
   int num_workers = 0;
   if (args["total_rows"].is_none()) {
@@ -1646,7 +1646,7 @@ Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr<Datas
         cache_client = value.cast<std::shared_ptr<CacheClient>>();
       } else if (key == "sampler") {
         auto create = py::reinterpret_borrow<py::object>(value).attr("create");
-        sampler = create().cast<std::shared_ptr<Sampler>>();
+        sampler = create().cast<std::shared_ptr<SamplerRT>>();
       }
     }
   }
@@ -1670,7 +1670,7 @@ Status DEPipeline::ParseRandomDataOp(const py::dict &args, std::shared_ptr<Datas
   } else if (cache_client) {
     const int64_t num_samples = 0;
     const int64_t start_index = 0;
-    sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
+    sampler = std::make_shared<SequentialSamplerRT>(num_samples, start_index);
     (void)builder.SetSampler(std::move(sampler));
   }
@@ -1715,7 +1715,7 @@ Status DEPipeline::ParseMnistOp(const py::dict &args, std::shared_ptr<DatasetOp>
       (void)builder->SetNumWorkers(num_workers);
     } else if (key == "sampler") {
       auto create = py::reinterpret_borrow<py::object>(value).attr("create");
-      std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
+      std::shared_ptr<SamplerRT> sampler = create().cast<std::shared_ptr<SamplerRT>>();
       (void)builder->SetSampler(std::move(sampler));
     } else if (key == "usage") {
       (void)builder->SetUsage(ToString(value));
@@ -1768,7 +1768,7 @@ Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptr<DatasetOp
       (void)builder->SetNumWorkers(num_workers);
     } else if (key == "sampler") {
       auto create = py::reinterpret_borrow<py::object>(value).attr("create");
-      std::shared_ptr<Sampler> sampler = create().cast<std::shared_ptr<Sampler>>();
+      std::shared_ptr<SamplerRT> sampler = create().cast<std::shared_ptr<SamplerRT>>();
       (void)builder->SetSampler(std::move(sampler));
     } else if (key == "decode") {
       (void)builder->SetDecode(ToBool(value));
@@ -1806,7 +1806,7 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<Dataset
   // Required arguments
   std::vector<std::string> files_list;
   std::shared_ptr<CacheClient> cache_client = nullptr;
-  std::shared_ptr<Sampler> sampler = nullptr;
+  std::shared_ptr<SamplerRT> sampler = nullptr;
   int num_workers = 0;
   std::shared_ptr<TextFileOp::Builder> builder = std::make_shared<TextFileOp::Builder>();
   if (!args["dataset_files"].is_none()) {
@@ -1840,7 +1840,7 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<Dataset
         cache_client = value.cast<std::shared_ptr<CacheClient>>();
       } else if (key == "sampler") {
         auto create = py::reinterpret_borrow<py::object>(value).attr("create");
-        sampler = create().cast<std::shared_ptr<Sampler>>();
+        sampler = create().cast<std::shared_ptr<SamplerRT>>();
       }
     }
   }
@@ -1855,7 +1855,7 @@ Status DEPipeline::ParseTextFileOp(const py::dict &args, std::shared_ptr<Dataset
   } else if (cache_client) {
     int64_t num_samples = 0;
     int64_t start_index = 0;
-    sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
+    sampler = std::make_shared<SequentialSamplerRT>(num_samples, start_index);
     (void)builder->SetSampler(std::move(sampler));
   }
@@ -1991,7 +1991,7 @@ Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp>
                                std::shared_ptr<DatasetOp> *bottom) {
   std::vector<std::string> files_list;
   std::shared_ptr<CacheClient> cache_client = nullptr;
-  std::shared_ptr<Sampler> sampler = nullptr;
+  std::shared_ptr<SamplerRT> sampler = nullptr;
   int num_workers = 0;
   std::shared_ptr<ClueOp::Builder> builder = std::make_shared<ClueOp::Builder>();
@@ -2036,7 +2036,7 @@ Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp>
         cache_client = value.cast<std::shared_ptr<CacheClient>>();
       } else if (key == "sampler") {
         auto create = py::reinterpret_borrow<py::object>(value).attr("create");
-        sampler = create().cast<std::shared_ptr<Sampler>>();
+        sampler = create().cast<std::shared_ptr<SamplerRT>>();
       }
     }
   }
@@ -2051,7 +2051,7 @@ Status DEPipeline::ParseClueOp(const py::dict &args, std::shared_ptr<DatasetOp>
   } else if (cache_client) {
     int64_t num_samples = 0;
     int64_t start_index = 0;
-    sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
+    sampler = std::make_shared<SequentialSamplerRT>(num_samples, start_index);
     (void)builder->SetSampler(std::move(sampler));
   }
@@ -2116,7 +2116,7 @@ Status DEPipeline::ParseCsvOp(const py::dict &args, std::shared_ptr<DatasetOp> *
                               std::shared_ptr<DatasetOp> *bottom) {
   std::vector<std::string> files_list;
   std::shared_ptr<CacheClient> cache_client = nullptr;
-  std::shared_ptr<Sampler> sampler = nullptr;
+  std::shared_ptr<SamplerRT> sampler = nullptr;
   int num_workers = 0;
   std::shared_ptr<CsvOp::Builder> builder = std::make_shared<CsvOp::Builder>();
   if (!args["dataset_files"].is_none()) {
@@ -2173,7 +2173,7 @@ Status DEPipeline::ParseCsvOp(const py::dict &args, std::shared_ptr<DatasetOp> *
        cache_client = value.cast<std::shared_ptr<CacheClient>>();
       } else if (key == "sampler") {
         auto create = py::reinterpret_borrow<py::object>(value).attr("create");
-        sampler = create().cast<std::shared_ptr<Sampler>>();
+        sampler = create().cast<std::shared_ptr<SamplerRT>>();
       }
     }
   }
@@ -2188,7 +2188,7 @@ Status DEPipeline::ParseCsvOp(const py::dict &args, std::shared_ptr<DatasetOp> *
   } else if (cache_client) {
     int64_t num_samples = 0;
     int64_t start_index = 0;
-    sampler = std::make_shared<SequentialSampler>(num_samples, start_index);
+    sampler = std::make_shared<SequentialSamplerRT>(num_samples, start_index);
     (void)builder->SetSampler(std::move(sampler));
   }

@@ -35,7 +35,6 @@
 namespace mindspore {
 namespace dataset {
-namespace api {
 #define RETURN_NULL_IF_ERROR(_s) \
   do {                           \
@@ -151,9 +150,9 @@ bool DistributedSamplerObj::ValidateParams() {
   return true;
 }
-std::shared_ptr<Sampler> DistributedSamplerObj::Build() {
+std::shared_ptr<SamplerRT> DistributedSamplerObj::Build() {
   // runtime sampler object
-  auto sampler = std::make_shared<dataset::DistributedSampler>(num_samples_, num_shards_, shard_id_, shuffle_, seed_,
+  auto sampler = std::make_shared<dataset::DistributedSamplerRT>(num_samples_, num_shards_, shard_id_, shuffle_, seed_,
                                                                offset_, even_dist_);
   return sampler;
 }
@@ -184,9 +183,9 @@ bool PKSamplerObj::ValidateParams() {
   return true;
 }
-std::shared_ptr<Sampler> PKSamplerObj::Build() {
+std::shared_ptr<SamplerRT> PKSamplerObj::Build() {
   // runtime sampler object
-  auto sampler = std::make_shared<dataset::PKSampler>(num_samples_, num_val_, shuffle_);
+  auto sampler = std::make_shared<dataset::PKSamplerRT>(num_samples_, num_val_, shuffle_);
   return sampler;
 }
@@ -218,10 +217,10 @@ bool RandomSamplerObj::ValidateParams() {
   return true;
 }
-std::shared_ptr<Sampler> RandomSamplerObj::Build() {
+std::shared_ptr<SamplerRT> RandomSamplerObj::Build() {
   // runtime sampler object
   bool reshuffle_each_epoch = true;
-  auto sampler = std::make_shared<dataset::RandomSampler>(num_samples_, replacement_, reshuffle_each_epoch);
+  auto sampler = std::make_shared<dataset::RandomSamplerRT>(num_samples_, replacement_, reshuffle_each_epoch);
   return sampler;
 }
@@ -255,9 +254,9 @@ bool SequentialSamplerObj::ValidateParams() {
   return true;
 }
-std::shared_ptr<Sampler> SequentialSamplerObj::Build() {
+std::shared_ptr<SamplerRT> SequentialSamplerObj::Build() {
   // runtime sampler object
-  auto sampler = std::make_shared<dataset::SequentialSampler>(num_samples_, start_index_);
+  auto sampler = std::make_shared<dataset::SequentialSamplerRT>(num_samples_, start_index_);
   return sampler;
 }
@@ -284,9 +283,9 @@ bool SubsetRandomSamplerObj::ValidateParams() {
   return true;
 }
-std::shared_ptr<Sampler> SubsetRandomSamplerObj::Build() {
+std::shared_ptr<SamplerRT> SubsetRandomSamplerObj::Build() {
   // runtime sampler object
-  auto sampler = std::make_shared<dataset::SubsetRandomSampler>(num_samples_, indices_);
+  auto sampler = std::make_shared<dataset::SubsetRandomSamplerRT>(num_samples_, indices_);
   return sampler;
 }
@@ -330,11 +329,10 @@ bool WeightedRandomSamplerObj::ValidateParams() {
   return true;
 }
-std::shared_ptr<Sampler> WeightedRandomSamplerObj::Build() {
-  auto sampler = std::make_shared<dataset::WeightedRandomSampler>(num_samples_, weights_, replacement_);
+std::shared_ptr<SamplerRT> WeightedRandomSamplerObj::Build() {
+  auto sampler = std::make_shared<dataset::WeightedRandomSamplerRT>(num_samples_, weights_, replacement_);
   return sampler;
 }
-}  // namespace api
 }  // namespace dataset
 }  // namespace mindspore
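These Build() methods are the seam between the two layers: the API-side *SamplerObj classes keep the plain names, and Build() materializes the matching runtime *SamplerRT. A stripped-down sketch of the pattern under simplified types (not the actual MindSpore declarations):

#include <cstdint>
#include <memory>

class SamplerRT {  // runtime-side base
 public:
  virtual ~SamplerRT() = default;
};

class SequentialSamplerRT : public SamplerRT {
 public:
  SequentialSamplerRT(int64_t /*num_samples*/, int64_t /*start_index*/) {}
};

class SamplerObj {  // API-side description object
 public:
  virtual ~SamplerObj() = default;
  virtual bool ValidateParams() = 0;
  virtual std::shared_ptr<SamplerRT> Build() = 0;  // returns the RT type after this commit
};

class SequentialSamplerObj : public SamplerObj {
 public:
  SequentialSamplerObj(int64_t start_index, int64_t num_samples)
      : start_index_(start_index), num_samples_(num_samples) {}
  bool ValidateParams() override { return start_index_ >= 0 && num_samples_ >= 0; }
  std::shared_ptr<SamplerRT> Build() override {
    return std::make_shared<SequentialSamplerRT>(num_samples_, start_index_);
  }

 private:
  int64_t start_index_;
  int64_t num_samples_;
};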

@@ -22,7 +22,6 @@
 namespace mindspore {
 namespace dataset {
-namespace api {
 // Transform operations for text.
 namespace text {
@@ -130,6 +129,5 @@ std::shared_ptr<TensorOp> SentencePieceTokenizerOperation::Build() {
 }
 }  // namespace text
-}  // namespace api
 }  // namespace dataset
 }  // namespace mindspore

@@ -22,7 +22,6 @@
 namespace mindspore {
 namespace dataset {
-namespace api {
 TensorOperation::TensorOperation() {}
@@ -94,6 +93,5 @@ Status TypeCastOperation::ValidateParams() {
 std::shared_ptr<TensorOp> TypeCastOperation::Build() { return std::make_shared<TypeCastOp>(data_type_); }
 }  // namespace transforms
-}  // namespace api
 }  // namespace dataset
 }  // namespace mindspore

@@ -65,7 +65,6 @@
 namespace mindspore {
 namespace dataset {
-namespace api {
 // Transform operations for computer vision.
 namespace vision {
@@ -1702,6 +1701,5 @@ std::shared_ptr<TensorOp> UniformAugOperation::Build() {
 #endif
 }  // namespace vision
-}  // namespace api
 }  // namespace dataset
 }  // namespace mindspore

@@ -34,11 +34,11 @@ namespace mindspore::dataset {
 // TreeConsumer
 TreeConsumer::TreeConsumer() { tree_adapter_ = std::make_unique<TreeAdapter>(); }
-Status TreeConsumer::Init(std::shared_ptr<api::DatasetNode> d) { return tree_adapter_->BuildAndPrepare(std::move(d)); }
+Status TreeConsumer::Init(std::shared_ptr<DatasetNode> d) { return tree_adapter_->BuildAndPrepare(std::move(d)); }
 Status TreeConsumer::Terminate() { return tree_adapter_->AllTasks()->DoServiceStop(); }
 // IteratorConsumer
-Status IteratorConsumer::Init(std::shared_ptr<api::DatasetNode> d) {
+Status IteratorConsumer::Init(std::shared_ptr<DatasetNode> d) {
   return tree_adapter_->BuildAndPrepare(std::move(d), num_epochs_);
 }
@@ -74,7 +74,7 @@ Status IteratorConsumer::GetNextAsMap(std::unordered_map<std::string, TensorPtr>
 }
 // ToDevice
-Status ToDevice::Init(std::shared_ptr<api::DatasetNode> d) {
+Status ToDevice::Init(std::shared_ptr<DatasetNode> d) {
   return tree_adapter_->BuildAndPrepare(std::move(d), num_epochs_);
 }
@@ -385,8 +385,8 @@ TreeGetters::TreeGetters() : dataset_size_(-1), init_flag_(false), row_flag_(fal
   tree_adapter_ = std::make_unique<TreeAdapter>();
 }
-Status TreeGetters::Init(std::shared_ptr<api::DatasetNode> d) {
-  Status s = tree_adapter_->BuildAndPrepare(std::move(d));
+Status TreeGetters::Init(std::shared_ptr<DatasetNode> d) {
+  Status s = tree_adapter_->BuildAndPrepare(std::move(d), 1);
   if (!s.IsError()) {
     init_flag_ = true;
   }
@@ -464,7 +464,7 @@ Status TreeGetters::GetNumClasses(int64_t *num_classes) {
   RETURN_IF_NOT_OK(root->GetNumClasses(num_classes));
   return Status::OK();
 }
-Status BuildVocabConsumer::Init(std::shared_ptr<api::DatasetNode> d) {
+Status BuildVocabConsumer::Init(std::shared_ptr<DatasetNode> d) {
   return tree_adapter_->BuildAndPrepare(std::move(d), 1);
 }
 Status BuildVocabConsumer::Start() {

@@ -29,10 +29,7 @@
 namespace mindspore::dataset {
 // Forward declare
 class TreeAdapter;
-namespace api {
-class DatasetNode;
-}
 /// A base class for tree consumers which would fetch rows from the tree pipeline
 class TreeConsumer {
@@ -42,7 +39,7 @@ class TreeConsumer {
   /// Initializes the consumer, this involves constructing and preparing the tree.
   /// \param d The dataset node that represent the root of the IR tree.
   /// \return Status error code.
-  virtual Status Init(std::shared_ptr<api::DatasetNode> d);
+  virtual Status Init(std::shared_ptr<DatasetNode> d);
   Status Terminate();
@@ -61,7 +58,7 @@ class IteratorConsumer : public TreeConsumer {
   /// \param num_epochs number of epochs. Default to -1 (infinite epochs).
   explicit IteratorConsumer(int32_t num_epochs = -1) : TreeConsumer(), num_epochs_(num_epochs) {}
-  Status Init(std::shared_ptr<api::DatasetNode> d) override;
+  Status Init(std::shared_ptr<DatasetNode> d) override;
   /// Returns the next row in a vector format
   /// \param[out] out std::vector of Tensors
@@ -133,7 +130,7 @@ class ToDevice : public TreeConsumer {
   explicit ToDevice(bool send_epoch_end, int32_t num_epochs = -1)
       : TreeConsumer(), send_epoch_end_(send_epoch_end), num_epochs_(num_epochs) {}
-  Status Init(std::shared_ptr<api::DatasetNode> d) override;
+  Status Init(std::shared_ptr<DatasetNode> d) override;
   /// Send the data to device
   /// \return Status error code
@@ -162,7 +159,7 @@ class ToDevice : public TreeConsumer {
 class TreeGetters : public TreeConsumer {
  public:
   TreeGetters();
-  Status Init(std::shared_ptr<api::DatasetNode> d) override;
+  Status Init(std::shared_ptr<DatasetNode> d) override;
   Status GetDatasetSize(int64_t *size);
   Status GetOutputTypes(std::vector<DataType> *types);
   Status GetOutputShapes(std::vector<TensorShape> *shapes);
@@ -185,10 +182,9 @@ class BuildVocabConsumer : public TreeConsumer {
   /// BuildVocabConsumer Constructor which will call the base class default constructor.
   BuildVocabConsumer() = default;
-  Status Init(std::shared_ptr<api::DatasetNode> d) override;
+  Status Init(std::shared_ptr<DatasetNode> d) override;
-  /// Save the given dataset to MindRecord format on disk. This is a blocking method (i.e., after returning, all rows
-  /// would be written to disk)
+  /// Start consuming
   /// \return Status error code
   Status Start();

@@ -46,7 +46,7 @@ Status CacheBase::Reset() {
   return Status::OK();
 }
 CacheBase::CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
-                     std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler)
+                     std::shared_ptr<CacheClient> cache_client, std::shared_ptr<SamplerRT> sampler)
     : ParallelOp(num_workers, op_connector_size, std::move(sampler)),
       row_cnt_(0),
       num_cache_miss_(0),

@@ -46,7 +46,7 @@ class CacheBase : public ParallelOp {
   /// \param cache_client CacheClient for communication to the CacheServer
   /// \param sampler Sampler which is mandatory
   CacheBase(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
-            std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler);
+            std::shared_ptr<CacheClient> cache_client, std::shared_ptr<SamplerRT> sampler);
   /// \brief Destructor
   ~CacheBase();

@@ -87,7 +87,7 @@ Status CacheLookupOp::HandshakeRandomAccessOp(const RandomAccessOp *op) {
   leaf_op_wp_.Set();
   return Status::OK();
 }
-Status CacheLookupOp::InitSampler() { return Sampler::InitSampler(); }
+Status CacheLookupOp::InitSampler() { return SamplerRT::InitSampler(); }
 void CacheLookupOp::Print(std::ostream &out, bool show_all) const { CacheBase::Print(out, show_all); }
 Status CacheLookupOp::GetNextSample(std::unique_ptr<DataBuffer> *out_buffer) {
   std::vector<row_id_type> cache_miss;

@@ -28,7 +28,7 @@ namespace dataset {
 /// \brief provides a memory/disk cache that acts as a save-point within a mappable dataset.
 /// \note For non-mappable dataset, please see CacheOp
 /// \see CacheOp
-class CacheLookupOp : public CacheBase, public Sampler {
+class CacheLookupOp : public CacheBase, public SamplerRT {
  public:
   class Builder {
    public:
@@ -62,7 +62,7 @@ class CacheLookupOp : public CacheBase, public Sampler {
     /// \brief Setter method.
     /// \return Builder setter method returns reference to the builder.
-    Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
+    Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
       build_sampler_ = std::move(sampler);
       return *this;
     }
@@ -77,7 +77,7 @@ class CacheLookupOp : public CacheBase, public Sampler {
     int32_t rows_per_buffer_;
     int32_t build_op_connector_size_;
     std::shared_ptr<CacheClient> build_cache_client_;
-    std::shared_ptr<Sampler> build_sampler_;
+    std::shared_ptr<SamplerRT> build_sampler_;
     // Check if the required parameters are set by the builder.
     // \return Status The error code return
@@ -87,8 +87,8 @@ class CacheLookupOp : public CacheBase, public Sampler {
   /// \note It takes the same argument as the base class.
   /// \see CacheBase
   CacheLookupOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
-                std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler)
-      : CacheBase(num_workers, op_connector_size, rows_per_buf, cache_client, sampler), Sampler(*(sampler.get())) {}
+                std::shared_ptr<CacheClient> cache_client, std::shared_ptr<SamplerRT> sampler)
+      : CacheBase(num_workers, op_connector_size, rows_per_buf, cache_client, sampler), SamplerRT(*(sampler.get())) {}
   ~CacheLookupOp() = default;
   // As a parallel op, we override these two functions
   Status operator()() override;
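Note the dual inheritance: CacheLookupOp is both an operator (via CacheBase) and a sampler (via SamplerRT), and the constructor seeds its SamplerRT base by copy-constructing from the sampler argument, so the op carries an independent copy of the caller's sampler state. A reduced sketch of that idiom with illustrative types:

#include <cstdint>
#include <memory>

struct SamplerRT {
  explicit SamplerRT(int64_t num_samples) : num_samples_(num_samples) {}
  int64_t num_samples_;
};

struct CacheBase { /* operator-side state */ };

struct LookupLikeOp : CacheBase, SamplerRT {
  // Copy the caller's sampler state into our own SamplerRT base subobject.
  explicit LookupLikeOp(const std::shared_ptr<SamplerRT> &sampler) : SamplerRT(*sampler) {}
};

int main() {
  auto s = std::make_shared<SamplerRT>(8);
  LookupLikeOp op(s);  // op.num_samples_ == 8, independent of *s from here on
  return 0;
}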

@@ -46,7 +46,7 @@ void CacheMergeOp::Print(std::ostream &out, bool show_all) const {
 }
 CacheMergeOp::CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners,
-                           std::shared_ptr<CacheClient> cache_client, const std::shared_ptr<Sampler> &sampler)
+                           std::shared_ptr<CacheClient> cache_client, const std::shared_ptr<SamplerRT> &sampler)
     : ParallelOp(numWorkers, opConnectorSize, sampler),
       num_cleaners_(numCleaners),
       cache_client_(std::move(cache_client)),

@@ -110,7 +110,7 @@ class CacheMergeOp : public ParallelOp {
     /// \brief Setter method
     /// \param sampler
     /// \return Builder setter method returns reference to the builder.
-    Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
+    Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
       build_sampler_ = std::move(sampler);
       return *this;
     }
@@ -133,7 +133,7 @@ class CacheMergeOp : public ParallelOp {
     int32_t build_op_connector_size_;
     int32_t build_num_cleaners_;
     std::shared_ptr<CacheClient> build_cache_client_;
-    std::shared_ptr<Sampler> build_sampler_;
+    std::shared_ptr<SamplerRT> build_sampler_;
     /// Check if the required parameters are set by the builder.
     /// \return Status The error code return
@@ -147,7 +147,7 @@ class CacheMergeOp : public ParallelOp {
   /// \param cache_client CacheClient to commmunicate with the Cache server
   /// \param sampler as a derived class of ParallelOp
   CacheMergeOp(int32_t numWorkers, int32_t opConnectorSize, int32_t numCleaners,
-               std::shared_ptr<CacheClient> cache_client, const std::shared_ptr<Sampler> &sampler);
+               std::shared_ptr<CacheClient> cache_client, const std::shared_ptr<SamplerRT> &sampler);
   ~CacheMergeOp();
   void Print(std::ostream &out, bool show_all) const override;
   std::string Name() const override { return kCacheMergeOp; }

@@ -68,7 +68,7 @@ Status CacheOp::Builder::Build(std::shared_ptr<CacheOp> *ptr) {
 // Constructor of CacheOp
 CacheOp::CacheOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
-                 std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler)
+                 std::shared_ptr<CacheClient> cache_client, std::shared_ptr<SamplerRT> sampler)
     : CacheBase(num_workers, op_connector_size, rows_per_buf, std::move(cache_client), std::move(sampler)),
      num_guys_in_(0),
      phase_(Phase::kBuildPhase) {}

@@ -81,7 +81,7 @@ class CacheOp : public CacheBase, public RandomAccessOp {
     /// \brief Setter method
     /// \param sampler
     /// \return Builder setter method returns reference to the builder.
-    Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
+    Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
       build_sampler_ = std::move(sampler);
       return *this;
     }
@@ -96,7 +96,7 @@ class CacheOp : public CacheBase, public RandomAccessOp {
     int32_t rows_per_buffer_;
     int32_t build_op_connector_size_;
     std::shared_ptr<CacheClient> build_cache_client_;
-    std::shared_ptr<Sampler> build_sampler_;
+    std::shared_ptr<SamplerRT> build_sampler_;
     /// \brief Check if the required parameters are set by the builder.
     /// \return Status The error code return
@@ -108,7 +108,7 @@ class CacheOp : public CacheBase, public RandomAccessOp {
   /// \param num_workers The number of worker threads.
   /// \param op_connector_size The size of each queue in the connector.
   CacheOp(int32_t num_workers, int32_t op_connector_size, int32_t rows_per_buf,
-          std::shared_ptr<CacheClient> cache_client, std::shared_ptr<Sampler> sampler);
+          std::shared_ptr<CacheClient> cache_client, std::shared_ptr<SamplerRT> sampler);
   // Destructor
   ~CacheOp();

@@ -36,7 +36,7 @@ ConcatOp::Builder::Builder() {
 // The builder "build" method creates the final object.
 Status ConcatOp::Builder::Build(std::shared_ptr<ConcatOp> *ptr) {
   if (builder_sampler_ == nullptr) {
-    builder_sampler_ = std::make_shared<DistributedSampler>(0, 1, 0, false);
+    builder_sampler_ = std::make_shared<DistributedSamplerRT>(0, 1, 0, false);
   }
   *ptr = std::make_shared<ConcatOp>(builder_op_connector_size_, builder_sampler_, children_flag_and_nums_,
                                     children_start_end_index_);
@@ -44,7 +44,7 @@ Status ConcatOp::Builder::Build(std::shared_ptr<ConcatOp> *ptr) {
 }
 // Constructor of the ConcatOp.
-ConcatOp::ConcatOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler,
+ConcatOp::ConcatOp(int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler,
                    std::vector<std::pair<int, int>> children_flag_and_nums,
                    std::vector<std::pair<int, int>> children_start_end_index)
     : PipelineOp(op_connector_size),
@@ -80,7 +80,7 @@ Status ConcatOp::operator()() {
   bool is_not_mappable = true;
   int num_shard = 1;
   int shard_index = 0;
-  std::shared_ptr<DistributedSampler> distribute_sampler = std::dynamic_pointer_cast<DistributedSampler>(sampler_);
+  std::shared_ptr<DistributedSamplerRT> distribute_sampler = std::dynamic_pointer_cast<DistributedSamplerRT>(sampler_);
   if (distribute_sampler != nullptr) {
     num_shard = distribute_sampler->GetDeviceNum();
     shard_index = distribute_sampler->GetDeviceID();
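The cast above is the interesting part: only a DistributedSamplerRT knows the shard layout, so ConcatOp probes with std::dynamic_pointer_cast and keeps the single-shard defaults when the cast returns nullptr. A self-contained sketch of the same check with simplified stand-in types:

#include <cstdint>
#include <memory>

struct SamplerRT {
  virtual ~SamplerRT() = default;  // polymorphic base, required for the downcast
};

struct DistributedSamplerRT : SamplerRT {
  DistributedSamplerRT(int32_t num_dev, int32_t dev_id) : num_dev_(num_dev), dev_id_(dev_id) {}
  int32_t GetDeviceNum() const { return num_dev_; }
  int32_t GetDeviceID() const { return dev_id_; }
  int32_t num_dev_;
  int32_t dev_id_;
};

int main() {
  std::shared_ptr<SamplerRT> sampler = std::make_shared<DistributedSamplerRT>(8, 3);
  int num_shard = 1;
  int shard_index = 0;  // defaults, used when the sampler is not distributed
  if (auto d = std::dynamic_pointer_cast<DistributedSamplerRT>(sampler)) {
    num_shard = d->GetDeviceNum();   // 8
    shard_index = d->GetDeviceID();  // 3
  }
  (void)num_shard;
  (void)shard_index;
  return 0;
}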

@@ -44,7 +44,7 @@ class ConcatOp : public PipelineOp {
     // The builder "build" method creates the final object.
     // @return shared_ptr to the new ConcatOp object
     Status Build(std::shared_ptr<ConcatOp> *);
-    Builder &SetSampler(std::shared_ptr<Sampler> sampler) {
+    Builder &SetSampler(std::shared_ptr<SamplerRT> sampler) {
       builder_sampler_ = std::move(sampler);
       return *this;
     }
@@ -61,7 +61,7 @@ class ConcatOp : public PipelineOp {
    private:
     int32_t builder_op_connector_size_;
-    std::shared_ptr<Sampler> builder_sampler_;
+    std::shared_ptr<SamplerRT> builder_sampler_;
     std::vector<std::pair<int, int>> children_flag_and_nums_;
     std::vector<std::pair<int, int>> children_start_end_index_;
   };
@@ -70,7 +70,7 @@ class ConcatOp : public PipelineOp {
   // @note The builder class should be used to call it
   // @param op_connector_size - connector size
   explicit ConcatOp(int32_t op_connector_size);
-  explicit ConcatOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler,
+  explicit ConcatOp(int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler,
                     std::vector<std::pair<int, int>> children_flag_and_nums,
                     std::vector<std::pair<int, int>> children_start_end_index);
@@ -123,7 +123,7 @@ class ConcatOp : public PipelineOp {
   std::unordered_map<std::string, int32_t> column_name_id_;  // Mapping between col index and col name
   std::vector<DataType> data_type_;
   std::vector<dsize_t> data_rank_;
-  std::shared_ptr<Sampler> sampler_;
+  std::shared_ptr<SamplerRT> sampler_;
   std::vector<std::pair<int, int>> children_flag_and_nums_;
   std::vector<std::pair<int, int>> children_start_end_index_;
 };

@@ -40,7 +40,7 @@
 namespace mindspore {
 namespace dataset {
 // Constructor
-DatasetOp::DatasetOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler)
+DatasetOp::DatasetOp(int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler)
     : oc_queue_size_(op_connector_size),
       sampler_(sampler),
       operator_id_(kInvalidOperatorId),
@@ -409,7 +409,7 @@ Status DatasetOp::Accept(NodePass *p, bool *modified) {
 }
 // Getter for the sampler, and it also removes the sampler from the op
-Status DatasetOp::FetchRemoveSampler(std::shared_ptr<Sampler> *sampler) {
+Status DatasetOp::FetchRemoveSampler(std::shared_ptr<SamplerRT> *sampler) {
   *sampler = sampler_;  // It's okay if it sampler_ points to nullptr
   sampler_.reset();     // clear our member-copy of this pointer. We no longer have this sampler
   return Status::OK();

@@ -62,7 +62,7 @@ class DataBuffer;
 class NodePass;
-class Sampler;
+class SamplerRT;
 /// \brief The base class DatasetOp is the main tree node. It is an abstract class, so
 /// the actual implementation of the operators will be derived from here.
@@ -80,7 +80,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
   /// Constructor
   /// \param op_connector_size - The size for the output connector of this operator.
   /// \param sampler - The sampler for the op
-  explicit DatasetOp(int32_t op_connector_size, std::shared_ptr<Sampler> sampler);
+  explicit DatasetOp(int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler);
   /// Destructor
   virtual ~DatasetOp() { tree_ = nullptr; }
@@ -347,12 +347,12 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
   /// Getter for the sampler
   /// \return Shared pointer to the sampler (may return nullptr)
-  std::shared_ptr<Sampler> sampler() { return sampler_; }
+  std::shared_ptr<SamplerRT> sampler() { return sampler_; }
   /// \brief Getter for the sampler, and it also removes the sampler from the op
   /// \param[out] sampler A pointer to the output sampler that was removed
   /// \return Status error code
-  Status FetchRemoveSampler(std::shared_ptr<Sampler> *sampler);
+  Status FetchRemoveSampler(std::shared_ptr<SamplerRT> *sampler);
 #ifndef ENABLE_ANDROID
   // Computes a CRC value for the operator
@@ -368,7 +368,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
   }
   /// \brief Setter for the sampler. Allows you to overwrite a previous sampler with a new one.
-  void SetSampler(std::shared_ptr<Sampler> sampler) { sampler_ = sampler; }
+  void SetSampler(std::shared_ptr<SamplerRT> sampler) { sampler_ = sampler; }
   /// \brief Checks if this is a leaf node (0 children)
   /// \return boolean returns true if it's a leaf
@@ -409,7 +409,7 @@ class DatasetOp : public std::enable_shared_from_this<DatasetOp> {
   std::vector<std::shared_ptr<DatasetOp>> child_;  // Child nodes
   std::vector<DatasetOp *> parent_;                // Parent nodes. No ownership
-  std::shared_ptr<Sampler> sampler_;               // Some leaf ops might have a sampler
+  std::shared_ptr<SamplerRT> sampler_;             // Some leaf ops might have a sampler
   int32_t oc_queue_size_;                          // Capacity for each out_connector_
   int32_t operator_id_;                            // Generated id for the node
   ExecutionTree *tree_;                            // Back pointer to our tree.

@@ -26,7 +26,7 @@
 namespace mindspore {
 namespace dataset {
 // Constructor
-ParallelOp::ParallelOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr<Sampler> sampler)
+ParallelOp::ParallelOp(int32_t num_workers, int32_t op_connector_size, std::shared_ptr<SamplerRT> sampler)
     : DatasetOp(op_connector_size, sampler),
       num_workers_(num_workers),
       num_producers_(num_workers),

Some files were not shown because too many files have changed in this diff.
