|
|
|
@ -301,8 +301,9 @@ std::shared_ptr<SchemaObj> Schema(const std::string &schema_file) {
|
|
|
|
|
// Function to create a AlbumNode.
|
|
|
|
|
std::shared_ptr<AlbumNode> Album(const std::string &dataset_dir, const std::string &data_schema,
|
|
|
|
|
const std::vector<std::string> &column_names, bool decode,
|
|
|
|
|
const std::shared_ptr<SamplerObj> &sampler) {
|
|
|
|
|
auto ds = std::make_shared<AlbumNode>(dataset_dir, data_schema, column_names, decode, sampler);
|
|
|
|
|
const std::shared_ptr<SamplerObj> &sampler,
|
|
|
|
|
const std::shared_ptr<DatasetCache> &cache) {
|
|
|
|
|
auto ds = std::make_shared<AlbumNode>(dataset_dir, data_schema, column_names, decode, sampler, cache);
|
|
|
|
|
|
|
|
|
|
return ds->ValidateParams() ? ds : nullptr;
|
|
|
|
|
}
|
|
|
|
@ -1021,9 +1022,25 @@ std::shared_ptr<DatasetCache> CreateDatasetCache(session_id_type id, uint64_t me
|
|
|
|
|
auto cache = std::make_shared<DatasetCacheImpl>(id, mem_sz, spill, hostname, port, num_connections, prefetch_sz);
|
|
|
|
|
return cache->ValidateParams() ? cache : nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<SamplerObj> SelectSampler(int64_t num_samples, bool shuffle, int32_t num_shards, int32_t shard_id) {
|
|
|
|
|
if (shuffle) {
|
|
|
|
|
if (num_shards > 1) {
|
|
|
|
|
// If shuffle enabled, sharding enabled, use distributed random sampler
|
|
|
|
|
return DistributedSampler(num_shards, shard_id, shuffle, num_samples);
|
|
|
|
|
}
|
|
|
|
|
// If shuffle enabled, sharding disabled, use random sampler
|
|
|
|
|
return RandomSampler(num_samples >= 0, num_samples);
|
|
|
|
|
}
|
|
|
|
|
if (num_shards > 1) {
|
|
|
|
|
// If shuffle disabled, sharding enabled, use distributed sequential sampler
|
|
|
|
|
return DistributedSampler(num_shards, shard_id, shuffle, num_samples);
|
|
|
|
|
}
|
|
|
|
|
// If shuffle disabled, sharding disabled, use sequential sampler
|
|
|
|
|
return SequentialSampler(0, num_samples);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace api
|
|
|
|
|
} // namespace dataset
|
|
|
|
|
} // namespace mindspore
|
|
|
|
|