|
|
|
@ -810,11 +810,10 @@ std::shared_ptr<ConcatDataset> operator+(const std::shared_ptr<Dataset> &dataset
|
|
|
|
|
class RandomDataDataset : public Dataset {
|
|
|
|
|
public:
|
|
|
|
|
RandomDataDataset(const int32_t &total_rows, std::shared_ptr<SchemaObj> schema,
|
|
|
|
|
const std::vector<std::string> &columns_list, const std::shared_ptr<SamplerObj> &sampler,
|
|
|
|
|
std::shared_ptr<DatasetCache> cache);
|
|
|
|
|
const std::vector<std::string> &columns_list, std::shared_ptr<DatasetCache> cache);
|
|
|
|
|
|
|
|
|
|
RandomDataDataset(const int32_t &total_rows, std::string schema_path, const std::vector<std::string> &columns_list,
|
|
|
|
|
const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache);
|
|
|
|
|
std::shared_ptr<DatasetCache> cache);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
/// \brief Function to create a RandomDataset
|
|
|
|
@ -829,16 +828,13 @@ class RandomDataDataset : public Dataset {
|
|
|
|
|
template <typename T = std::shared_ptr<SchemaObj>>
|
|
|
|
|
std::shared_ptr<RandomDataDataset> RandomData(const int32_t &total_rows = 0, const T &schema = nullptr,
|
|
|
|
|
const std::vector<std::string> &columns_list = {},
|
|
|
|
|
const std::shared_ptr<SamplerObj> &sampler = RandomSampler(),
|
|
|
|
|
const std::shared_ptr<DatasetCache> &cache = nullptr) {
|
|
|
|
|
std::shared_ptr<RandomDataDataset> ds;
|
|
|
|
|
if constexpr (std::is_same<T, std::nullptr_t>::value || std::is_same<T, std::shared_ptr<SchemaObj>>::value) {
|
|
|
|
|
std::shared_ptr<SchemaObj> schema_obj = schema;
|
|
|
|
|
ds = std::make_shared<RandomDataDataset>(total_rows, std::move(schema_obj), std::move(columns_list),
|
|
|
|
|
std::move(sampler), cache);
|
|
|
|
|
ds = std::make_shared<RandomDataDataset>(total_rows, std::move(schema_obj), std::move(columns_list), cache);
|
|
|
|
|
} else {
|
|
|
|
|
ds = std::make_shared<RandomDataDataset>(total_rows, std::move(schema), std::move(columns_list), std::move(sampler),
|
|
|
|
|
cache);
|
|
|
|
|
ds = std::make_shared<RandomDataDataset>(total_rows, std::move(schema), std::move(columns_list), cache);
|
|
|
|
|
}
|
|
|
|
|
return ds;
|
|
|
|
|
}
|
|
|
|
|