fix android build failure: fix include in datasets.cc; attempt to fix android build; fix compile err on x86 (pull/7657/head)
parent d50736df2c
commit cfc50e231a
File diff suppressed because it is too large
@@ -0,0 +1,91 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "minddata/dataset/engine/ir/datasetops/map_node.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "minddata/dataset/engine/datasetops/map_op/map_op.h"
#include "minddata/dataset/include/transforms.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
namespace api {

MapNode::MapNode(std::shared_ptr<Dataset> child, std::vector<std::shared_ptr<TensorOperation>> operations,
                 std::vector<std::string> input_columns, std::vector<std::string> output_columns,
                 const std::vector<std::string> &project_columns, std::shared_ptr<DatasetCache> cache)
    : operations_(operations),
      input_columns_(input_columns),
      output_columns_(output_columns),
      project_columns_(project_columns),
      Dataset(std::move(cache)) {
  this->children.push_back(child);
}

std::vector<std::shared_ptr<DatasetOp>> MapNode::Build() {
  // A vector containing shared pointer to the Dataset Ops that this object will create
  std::vector<std::shared_ptr<DatasetOp>> node_ops;

  std::vector<std::shared_ptr<TensorOp>> tensor_ops;

  // Build tensorOp from tensorOperation vector
  // This is to ensure each iterator hold its own copy of the tensorOp objects.
  (void)std::transform(
    operations_.begin(), operations_.end(), std::back_inserter(tensor_ops),
    [](std::shared_ptr<TensorOperation> operation) -> std::shared_ptr<TensorOp> { return operation->Build(); });

  // This parameter will be removed with next rebase
  std::vector<std::string> col_orders;
  auto map_op = std::make_shared<MapOp>(input_columns_, output_columns_, tensor_ops, num_workers_, connector_que_size_);
  if (!project_columns_.empty()) {
    auto project_op = std::make_shared<ProjectOp>(project_columns_);
    node_ops.push_back(project_op);
  }
  RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));

  node_ops.push_back(map_op);
  return node_ops;
}

Status MapNode::ValidateParams() {
  if (operations_.empty()) {
    std::string err_msg = "MapNode: No operation is specified.";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  if (!input_columns_.empty()) {
    RETURN_IF_NOT_OK(ValidateDatasetColumnParam("MapNode", "input_columns", input_columns_));
  }

  if (!output_columns_.empty()) {
    RETURN_IF_NOT_OK(ValidateDatasetColumnParam("MapNode", "output_columns", output_columns_));
  }

  if (!project_columns_.empty()) {
    RETURN_IF_NOT_OK(ValidateDatasetColumnParam("MapNode", "project_columns", project_columns_));
  }

  return Status::OK();
}

} // namespace api
} // namespace dataset
} // namespace mindspore
@@ -0,0 +1,57 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_MAP_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_MAP_NODE_H_

#include <memory>
#include <string>
#include <vector>

#include "minddata/dataset/include/datasets.h"

namespace mindspore {
namespace dataset {
namespace api {
class MapNode : public Dataset {
 public:
  /// \brief Constructor
  MapNode(std::shared_ptr<Dataset> child, std::vector<std::shared_ptr<TensorOperation>> operations,
          std::vector<std::string> input_columns = {}, std::vector<std::string> output_columns = {},
          const std::vector<std::string> &columns = {}, std::shared_ptr<DatasetCache> cache = nullptr);

  /// \brief Destructor
  ~MapNode() = default;

  /// \brief a base class override function to create the required runtime dataset op objects for this class
  /// \return The list of shared pointers to the newly created DatasetOps
  std::vector<std::shared_ptr<DatasetOp>> Build() override;

  /// \brief Parameters validation
  /// \return Status Status::OK() if all the parameters are valid
  Status ValidateParams() override;

 private:
  std::vector<std::shared_ptr<TensorOperation>> operations_;
  std::vector<std::string> input_columns_;
  std::vector<std::string> output_columns_;
  std::vector<std::string> project_columns_;
};

} // namespace api
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_MAP_NODE_H_
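As an aside (not part of the commit itself): a rough sketch of how this new IR node is exercised, assuming `child_ds` is an existing std::shared_ptr<Dataset> and `decode_op` is a std::shared_ptr<TensorOperation> obtained from the api layer; both names are placeholders.

// Illustrative sketch only, not part of the diff; child_ds and decode_op are hypothetical.
auto map_node = std::make_shared<MapNode>(child_ds,
                                          std::vector<std::shared_ptr<TensorOperation>>{decode_op},
                                          std::vector<std::string>{"image"},   // input_columns
                                          std::vector<std::string>{"image"});  // output_columns
Status rc = map_node->ValidateParams();  // rejects an empty operations list or malformed column names
if (rc.IsOk()) {
  // Build() returns the MapOp, plus a ProjectOp and/or cache op when those are configured.
  std::vector<std::shared_ptr<DatasetOp>> ops = map_node->Build();
}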
@@ -0,0 +1,57 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "minddata/dataset/engine/ir/datasetops/skip_node.h"

#include <memory>
#include <string>
#include <vector>

#include "minddata/dataset/engine/datasetops/skip_op.h"
#include "minddata/dataset/util/status.h"

namespace mindspore {
namespace dataset {
namespace api {

// Constructor for SkipNode
SkipNode::SkipNode(std::shared_ptr<Dataset> child, int32_t count) : skip_count_(count) {
  this->children.push_back(child);
}

// Function to build the SkipOp
std::vector<std::shared_ptr<DatasetOp>> SkipNode::Build() {
  // A vector containing shared pointer to the Dataset Ops that this object will create
  std::vector<std::shared_ptr<DatasetOp>> node_ops;

  node_ops.push_back(std::make_shared<SkipOp>(skip_count_, connector_que_size_));
  return node_ops;
}

// Function to validate the parameters for SkipNode
Status SkipNode::ValidateParams() {
  if (skip_count_ <= -1) {
    std::string err_msg = "SkipNode: skip_count should not be negative, skip_count: " + std::to_string(skip_count_);
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  return Status::OK();
}

} // namespace api
} // namespace dataset
} // namespace mindspore
@@ -0,0 +1,52 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SKIP_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SKIP_NODE_H_

#include <memory>
#include <string>
#include <vector>

#include "minddata/dataset/include/datasets.h"

namespace mindspore {
namespace dataset {

namespace api {
class SkipNode : public Dataset {
 public:
  /// \brief Constructor
  explicit SkipNode(std::shared_ptr<Dataset> child, int32_t count);

  /// \brief Destructor
  ~SkipNode() = default;

  /// \brief a base class override function to create the required runtime dataset op objects for this class
  /// \return The list of shared pointers to the newly created DatasetOps
  std::vector<std::shared_ptr<DatasetOp>> Build() override;

  /// \brief Parameters validation
  /// \return Status Status::OK() if all the parameters are valid
  Status ValidateParams() override;

 private:
  int32_t skip_count_;
};
} // namespace api
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SKIP_NODE_H_
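A quick, hypothetical check of the validation path added here (not part of the diff); `child_ds` is a placeholder for any existing std::shared_ptr<Dataset>.

// Illustrative sketch only, not part of the diff.
auto bad_skip = std::make_shared<SkipNode>(child_ds, -1);
Status rc = bad_skip->ValidateParams();  // fails: "SkipNode: skip_count should not be negative, skip_count: -1"
auto ok_skip = std::make_shared<SkipNode>(child_ds, 10);  // a non-negative count passes; Build() emits one SkipOp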
@@ -0,0 +1,73 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "minddata/dataset/engine/ir/datasetops/source/album_node.h"

#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "minddata/dataset/engine/datasetops/source/album_op.h"

#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
namespace api {
// Constructor for AlbumNode
AlbumNode::AlbumNode(const std::string &dataset_dir, const std::string &data_schema,
                     const std::vector<std::string> &column_names, bool decode,
                     const std::shared_ptr<SamplerObj> &sampler)
    : dataset_dir_(dataset_dir),
      schema_path_(data_schema),
      column_names_(column_names),
      decode_(decode),
      sampler_(sampler) {}

Status AlbumNode::ValidateParams() {
  RETURN_IF_NOT_OK(ValidateDatasetDirParam("AlbumNode", dataset_dir_));

  RETURN_IF_NOT_OK(ValidateDatasetFilesParam("AlbumNode", {schema_path_}));

  RETURN_IF_NOT_OK(ValidateDatasetSampler("AlbumNode", sampler_));

  if (!column_names_.empty()) {
    RETURN_IF_NOT_OK(ValidateDatasetColumnParam("AlbumNode", "column_names", column_names_));
  }

  return Status::OK();
}

// Function to build AlbumNode
std::vector<std::shared_ptr<DatasetOp>> AlbumNode::Build() {
  // A vector containing shared pointer to the Dataset Ops that this object will create
  std::vector<std::shared_ptr<DatasetOp>> node_ops;

  auto schema = std::make_unique<DataSchema>();
  RETURN_EMPTY_IF_ERROR(schema->LoadSchemaFile(schema_path_, column_names_));

  // Argument that is not exposed to user in the API.
  std::set<std::string> extensions = {};

  node_ops.push_back(std::make_shared<AlbumOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
                                               decode_, extensions, std::move(schema), std::move(sampler_->Build())));
  return node_ops;
}

} // namespace api
} // namespace dataset
} // namespace mindspore
@@ -0,0 +1,58 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_ALBUM_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_ALBUM_NODE_H_

#include <memory>
#include <string>
#include <vector>

#include "minddata/dataset/include/datasets.h"

namespace mindspore {
namespace dataset {
namespace api {

class AlbumNode : public Dataset {
 public:
  /// \brief Constructor
  AlbumNode(const std::string &dataset_dir, const std::string &data_schema,
            const std::vector<std::string> &column_names, bool decode, const std::shared_ptr<SamplerObj> &sampler);

  /// \brief Destructor
  ~AlbumNode() = default;

  /// \brief a base class override function to create a runtime dataset op object from this class
  /// \return shared pointer to the newly created DatasetOp
  std::vector<std::shared_ptr<DatasetOp>> Build() override;

  /// \brief Parameters validation
  /// \return Status Status::OK() if all the parameters are valid
  Status ValidateParams() override;

 private:
  std::string dataset_dir_;
  std::string schema_path_;
  std::vector<std::string> column_names_;
  bool decode_;
  std::shared_ptr<SamplerObj> sampler_;
};

} // namespace api
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_ALBUM_NODE_H_
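For reference (illustration only, not part of the commit): constructing an AlbumNode directly, assuming the paths and a RandomSampler() factory from the api layer exist in the caller's environment; all of these names are placeholders.

// Illustrative sketch only, not part of the diff.
auto album = std::make_shared<AlbumNode>("/data/album",               // dataset_dir
                                         "/data/album/schema.json",   // data_schema file loaded in Build()
                                         std::vector<std::string>{"image", "label"},
                                         /*decode=*/true, RandomSampler());
Status rc = album->ValidateParams();  // checks the directory, the schema file, the sampler, and the column names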
@@ -0,0 +1,72 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "minddata/dataset/engine/ir/datasetops/source/celeba_node.h"

#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "minddata/dataset/engine/datasetops/source/celeba_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
namespace api {
// Constructor for CelebANode
CelebANode::CelebANode(const std::string &dataset_dir, const std::string &usage,
                       const std::shared_ptr<SamplerObj> &sampler, const bool &decode,
                       const std::set<std::string> &extensions, const std::shared_ptr<DatasetCache> &cache)
    : Dataset(cache),
      dataset_dir_(dataset_dir),
      usage_(usage),
      sampler_(sampler),
      decode_(decode),
      extensions_(extensions) {}

Status CelebANode::ValidateParams() {
  RETURN_IF_NOT_OK(ValidateDatasetDirParam("CelebANode", dataset_dir_));

  RETURN_IF_NOT_OK(ValidateDatasetSampler("CelebANode", sampler_));

  RETURN_IF_NOT_OK(ValidateStringValue("CelebANode", usage_, {"all", "train", "valid", "test"}));

  return Status::OK();
}

// Function to build CelebANode
std::vector<std::shared_ptr<DatasetOp>> CelebANode::Build() {
  // A vector containing shared pointer to the Dataset Ops that this object will create
  std::vector<std::shared_ptr<DatasetOp>> node_ops;

  std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
  RETURN_EMPTY_IF_ERROR(
    schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
  // label is like this:0 1 0 0 1......
  RETURN_EMPTY_IF_ERROR(
    schema->AddColumn(ColDescriptor("attr", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 1)));

  node_ops.push_back(std::make_shared<CelebAOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
                                                decode_, usage_, extensions_, std::move(schema),
                                                std::move(sampler_->Build())));

  return node_ops;
}

} // namespace api
} // namespace dataset
} // namespace mindspore
@@ -0,0 +1,60 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CELEBA_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CELEBA_NODE_H_

#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "minddata/dataset/include/datasets.h"

namespace mindspore {
namespace dataset {
namespace api {

class CelebANode : public Dataset {
 public:
  /// \brief Constructor
  CelebANode(const std::string &dataset_dir, const std::string &usage, const std::shared_ptr<SamplerObj> &sampler,
             const bool &decode, const std::set<std::string> &extensions, const std::shared_ptr<DatasetCache> &cache);

  /// \brief Destructor
  ~CelebANode() = default;

  /// \brief a base class override function to create the required runtime dataset op objects for this class
  /// \return shared pointer to the list of newly created DatasetOps
  std::vector<std::shared_ptr<DatasetOp>> Build() override;

  /// \brief Parameters validation
  /// \return Status Status::OK() if all the parameters are valid
  Status ValidateParams() override;

 private:
  std::string dataset_dir_;
  std::string usage_;
  bool decode_;
  std::set<std::string> extensions_;
  std::shared_ptr<SamplerObj> sampler_;
};

} // namespace api
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CELEBA_NODE_H_
@@ -0,0 +1,71 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "minddata/dataset/engine/ir/datasetops/source/cifar100_node.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "minddata/dataset/engine/datasetops/source/cifar_op.h"

#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
namespace api {

// Constructor for Cifar100Node
Cifar100Node::Cifar100Node(const std::string &dataset_dir, const std::string &usage,
                           std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetCache> cache)
    : Dataset(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}

Status Cifar100Node::ValidateParams() {
  RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar100Node", dataset_dir_));

  RETURN_IF_NOT_OK(ValidateDatasetSampler("Cifar100Node", sampler_));

  RETURN_IF_NOT_OK(ValidateStringValue("Cifar100Node", usage_, {"train", "test", "all"}));

  return Status::OK();
}

// Function to build CifarOp for Cifar100
std::vector<std::shared_ptr<DatasetOp>> Cifar100Node::Build() {
  // A vector containing shared pointer to the Dataset Ops that this object will create
  std::vector<std::shared_ptr<DatasetOp>> node_ops;

  // Do internal Schema generation.
  auto schema = std::make_unique<DataSchema>();
  RETURN_EMPTY_IF_ERROR(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
  TensorShape scalar = TensorShape::CreateScalar();
  RETURN_EMPTY_IF_ERROR(
    schema->AddColumn(ColDescriptor("coarse_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));
  RETURN_EMPTY_IF_ERROR(
    schema->AddColumn(ColDescriptor("fine_label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));

  RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));

  node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar100, usage_, num_workers_, rows_per_buffer_,
                                               dataset_dir_, connector_que_size_, std::move(schema),
                                               std::move(sampler_->Build())));

  return node_ops;
}

} // namespace api
} // namespace dataset
} // namespace mindspore
@@ -0,0 +1,56 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CIFAR100_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CIFAR100_NODE_H_

#include <memory>
#include <string>
#include <vector>

#include "minddata/dataset/include/datasets.h"

namespace mindspore {
namespace dataset {
namespace api {

class Cifar100Node : public Dataset {
 public:
  /// \brief Constructor
  Cifar100Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler,
               std::shared_ptr<DatasetCache> cache);

  /// \brief Destructor
  ~Cifar100Node() = default;

  /// \brief a base class override function to create the required runtime dataset op objects for this class
  /// \return The list of shared pointers to the newly created DatasetOps
  std::vector<std::shared_ptr<DatasetOp>> Build() override;

  /// \brief Parameters validation
  /// \return Status Status::OK() if all the parameters are valid
  Status ValidateParams() override;

 private:
  std::string dataset_dir_;
  std::string usage_;
  std::shared_ptr<SamplerObj> sampler_;
};

} // namespace api
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CIFAR100_NODE_H_
@@ -0,0 +1,69 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "minddata/dataset/engine/ir/datasetops/source/cifar10_node.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "minddata/dataset/engine/datasetops/source/cifar_op.h"

#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
namespace api {

// Constructor for Cifar10Node
Cifar10Node::Cifar10Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler,
                         std::shared_ptr<DatasetCache> cache)
    : Dataset(std::move(cache)), dataset_dir_(dataset_dir), usage_(usage), sampler_(sampler) {}

Status Cifar10Node::ValidateParams() {
  RETURN_IF_NOT_OK(ValidateDatasetDirParam("Cifar10Node", dataset_dir_));

  RETURN_IF_NOT_OK(ValidateDatasetSampler("Cifar10Node", sampler_));

  RETURN_IF_NOT_OK(ValidateStringValue("Cifar10Node", usage_, {"train", "test", "all"}));

  return Status::OK();
}

// Function to build CifarOp for Cifar10
std::vector<std::shared_ptr<DatasetOp>> Cifar10Node::Build() {
  // A vector containing shared pointer to the Dataset Ops that this object will create
  std::vector<std::shared_ptr<DatasetOp>> node_ops;

  // Do internal Schema generation.
  auto schema = std::make_unique<DataSchema>();
  RETURN_EMPTY_IF_ERROR(schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kCv, 1)));
  TensorShape scalar = TensorShape::CreateScalar();
  RETURN_EMPTY_IF_ERROR(
    schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_UINT32), TensorImpl::kFlexible, 0, &scalar)));

  RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));

  node_ops.push_back(std::make_shared<CifarOp>(CifarOp::CifarType::kCifar10, usage_, num_workers_, rows_per_buffer_,
                                               dataset_dir_, connector_que_size_, std::move(schema),
                                               std::move(sampler_->Build())));

  return node_ops;
}

} // namespace api
} // namespace dataset
} // namespace mindspore
@@ -0,0 +1,56 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CIFAR10_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CIFAR10_NODE_H_

#include <memory>
#include <string>
#include <vector>

#include "minddata/dataset/include/datasets.h"

namespace mindspore {
namespace dataset {
namespace api {

class Cifar10Node : public Dataset {
 public:
  /// \brief Constructor
  Cifar10Node(const std::string &dataset_dir, const std::string &usage, std::shared_ptr<SamplerObj> sampler,
              std::shared_ptr<DatasetCache> cache);

  /// \brief Destructor
  ~Cifar10Node() = default;

  /// \brief a base class override function to create the required runtime dataset op objects for this class
  /// \return The list of shared pointers to the newly created DatasetOps
  std::vector<std::shared_ptr<DatasetOp>> Build() override;

  /// \brief Parameters validation
  /// \return Status Status::OK() if all the parameters are valid
  Status ValidateParams() override;

 private:
  std::string dataset_dir_;
  std::string usage_;
  std::shared_ptr<SamplerObj> sampler_;
};

} // namespace api
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CIFAR10_NODE_H_
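For reference (illustration only, not part of the commit): a minimal sketch of the CIFAR-10 node, assuming a local extract of the binary dataset and a RandomSampler() factory from the api layer; the path and factory name are placeholders.

// Illustrative sketch only, not part of the diff.
auto cifar10 = std::make_shared<Cifar10Node>("/data/cifar-10-batches-bin", "train", RandomSampler(), nullptr);
// usage must be one of "train", "test" or "all", and the sampler must be non-null.
if (cifar10->ValidateParams().IsOk()) {
  auto ops = cifar10->Build();  // the CifarOp plus an optional cache op
}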
@@ -0,0 +1,218 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "minddata/dataset/engine/ir/datasetops/source/clue_node.h"

#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "minddata/dataset/engine/datasetops/source/clue_op.h"

#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
namespace api {

// Constructor for CLUENode
CLUENode::CLUENode(const std::vector<std::string> clue_files, std::string task, std::string usage, int64_t num_samples,
                   ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache)
    : Dataset(std::move(cache)),
      dataset_files_(clue_files),
      task_(task),
      usage_(usage),
      num_samples_(num_samples),
      shuffle_(shuffle),
      num_shards_(num_shards),
      shard_id_(shard_id) {}

Status CLUENode::ValidateParams() {
  RETURN_IF_NOT_OK(ValidateDatasetFilesParam("CLUENode", dataset_files_));

  RETURN_IF_NOT_OK(ValidateStringValue("CLUENode", task_, {"AFQMC", "TNEWS", "IFLYTEK", "CMNLI", "WSC", "CSL"}));

  RETURN_IF_NOT_OK(ValidateStringValue("CLUENode", usage_, {"train", "test", "eval"}));

  if (num_samples_ < 0) {
    std::string err_msg = "CLUENode: Invalid number of samples: " + std::to_string(num_samples_);
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  RETURN_IF_NOT_OK(ValidateDatasetShardParams("CLUENode", num_shards_, shard_id_));

  return Status::OK();
}

// Function to split string based on a character delimiter
std::vector<std::string> CLUENode::split(const std::string &s, char delim) {
  std::vector<std::string> res;
  std::stringstream ss(s);
  std::string item;

  while (getline(ss, item, delim)) {
    res.push_back(item);
  }
  return res;
}

// Function to build CLUENode
std::vector<std::shared_ptr<DatasetOp>> CLUENode::Build() {
  // A vector containing shared pointer to the Dataset Ops that this object will create
  std::vector<std::shared_ptr<DatasetOp>> node_ops;

  std::map<std::string, std::string> key_map;
  if (task_ == "AFQMC") {
    if (usage_ == "train") {
      key_map["sentence1"] = "sentence1";
      key_map["sentence2"] = "sentence2";
      key_map["label"] = "label";
    } else if (usage_ == "test") {
      key_map["id"] = "id";
      key_map["sentence1"] = "sentence1";
      key_map["sentence2"] = "sentence2";
    } else if (usage_ == "eval") {
      key_map["sentence1"] = "sentence1";
      key_map["sentence2"] = "sentence2";
      key_map["label"] = "label";
    }
  } else if (task_ == "CMNLI") {
    if (usage_ == "train") {
      key_map["sentence1"] = "sentence1";
      key_map["sentence2"] = "sentence2";
      key_map["label"] = "label";
    } else if (usage_ == "test") {
      key_map["id"] = "id";
      key_map["sentence1"] = "sentence1";
      key_map["sentence2"] = "sentence2";
    } else if (usage_ == "eval") {
      key_map["sentence1"] = "sentence1";
      key_map["sentence2"] = "sentence2";
      key_map["label"] = "label";
    }
  } else if (task_ == "CSL") {
    if (usage_ == "train") {
      key_map["id"] = "id";
      key_map["abst"] = "abst";
      key_map["keyword"] = "keyword";
      key_map["label"] = "label";
    } else if (usage_ == "test") {
      key_map["id"] = "id";
      key_map["abst"] = "abst";
      key_map["keyword"] = "keyword";
    } else if (usage_ == "eval") {
      key_map["id"] = "id";
      key_map["abst"] = "abst";
      key_map["keyword"] = "keyword";
      key_map["label"] = "label";
    }
  } else if (task_ == "IFLYTEK") {
    if (usage_ == "train") {
      key_map["label"] = "label";
      key_map["label_des"] = "label_des";
      key_map["sentence"] = "sentence";
    } else if (usage_ == "test") {
      key_map["id"] = "id";
      key_map["sentence"] = "sentence";
    } else if (usage_ == "eval") {
      key_map["label"] = "label";
      key_map["label_des"] = "label_des";
      key_map["sentence"] = "sentence";
    }
  } else if (task_ == "TNEWS") {
    if (usage_ == "train") {
      key_map["label"] = "label";
      key_map["label_desc"] = "label_desc";
      key_map["sentence"] = "sentence";
      key_map["keywords"] = "keywords";
    } else if (usage_ == "test") {
      key_map["id"] = "id";
      key_map["sentence"] = "sentence";
      key_map["keywords"] = "keywords";
    } else if (usage_ == "eval") {
      key_map["label"] = "label";
      key_map["label_desc"] = "label_desc";
      key_map["sentence"] = "sentence";
      key_map["keywords"] = "keywords";
    }
  } else if (task_ == "WSC") {
    if (usage_ == "train") {
      key_map["span1_index"] = "target/span1_index";
      key_map["span2_index"] = "target/span2_index";
      key_map["span1_text"] = "target/span1_text";
      key_map["span2_text"] = "target/span2_text";
      key_map["idx"] = "idx";
      key_map["label"] = "label";
      key_map["text"] = "text";
    } else if (usage_ == "test") {
      key_map["span1_index"] = "target/span1_index";
      key_map["span2_index"] = "target/span2_index";
      key_map["span1_text"] = "target/span1_text";
      key_map["span2_text"] = "target/span2_text";
      key_map["idx"] = "idx";
      key_map["text"] = "text";
    } else if (usage_ == "eval") {
      key_map["span1_index"] = "target/span1_index";
      key_map["span2_index"] = "target/span2_index";
      key_map["span1_text"] = "target/span1_text";
      key_map["span2_text"] = "target/span2_text";
      key_map["idx"] = "idx";
      key_map["label"] = "label";
      key_map["text"] = "text";
    }
  }

  ColKeyMap ck_map;
  for (auto &p : key_map) {
    ck_map.insert({p.first, split(p.second, '/')});
  }

  bool shuffle_files = (shuffle_ == ShuffleMode::kGlobal || shuffle_ == ShuffleMode::kFiles);

  // Sort the dataset files in a lexicographical order
  std::vector<std::string> sorted_dataset_files = dataset_files_;
  std::sort(sorted_dataset_files.begin(), sorted_dataset_files.end());

  std::shared_ptr<ClueOp> clue_op =
    std::make_shared<ClueOp>(num_workers_, rows_per_buffer_, num_samples_, worker_connector_size_, ck_map,
                             sorted_dataset_files, connector_que_size_, shuffle_files, num_shards_, shard_id_, nullptr);
  RETURN_EMPTY_IF_ERROR(clue_op->Init());
  if (shuffle_ == ShuffleMode::kGlobal) {
    // Inject ShuffleOp
    std::shared_ptr<DatasetOp> shuffle_op = nullptr;
    int64_t num_rows = 0;

    // First, get the number of rows in the dataset
    RETURN_EMPTY_IF_ERROR(ClueOp::CountAllFileRows(sorted_dataset_files, &num_rows));

    // Add the shuffle op after this op
    RETURN_EMPTY_IF_ERROR(AddShuffleOp(sorted_dataset_files.size(), num_shards_, num_rows, 0, connector_que_size_,
                                       rows_per_buffer_, &shuffle_op));
    node_ops.push_back(shuffle_op);
  }
  RETURN_EMPTY_IF_ERROR(AddCacheOp(&node_ops));

  node_ops.push_back(clue_op);

  return node_ops;
}

} // namespace api
} // namespace dataset
} // namespace mindspore
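One detail worth calling out (illustration only, not part of the commit): for the WSC task the mapped values contain '/', and CLUENode::split turns each value into a key path, presumably so ClueOp can descend into the nested "target" object of a WSC record. A standalone equivalent of that private helper, under that assumption:

#include <sstream>
#include <string>
#include <vector>

// Same splitting behaviour as the private CLUENode::split helper shown above.
std::vector<std::string> SplitKey(const std::string &s, char delim) {
  std::vector<std::string> parts;
  std::stringstream ss(s);
  std::string item;
  while (std::getline(ss, item, delim)) {
    parts.push_back(item);
  }
  return parts;
}

// SplitKey("target/span1_index", '/') yields {"target", "span1_index"}, which is what
// the ColKeyMap entry stores for the "span1_index" output column.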
@@ -0,0 +1,65 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CLUE_NODE_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CLUE_NODE_H_

#include <memory>
#include <string>
#include <vector>

#include "minddata/dataset/include/datasets.h"

namespace mindspore {
namespace dataset {
namespace api {
/// \class CLUENode
/// \brief A Dataset derived class to represent CLUE dataset
class CLUENode : public Dataset {
 public:
  /// \brief Constructor
  CLUENode(const std::vector<std::string> dataset_files, std::string task, std::string usage, int64_t num_samples,
           ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache);

  /// \brief Destructor
  ~CLUENode() = default;

  /// \brief a base class override function to create the required runtime dataset op objects for this class
  /// \return The list of shared pointers to the newly created DatasetOps
  std::vector<std::shared_ptr<DatasetOp>> Build() override;

  /// \brief Parameters validation
  /// \return Status Status::OK() if all the parameters are valid
  Status ValidateParams() override;

 private:
  /// \brief Split string based on a character delimiter
  /// \return A string vector
  std::vector<std::string> split(const std::string &s, char delim);

  std::vector<std::string> dataset_files_;
  std::string task_;
  std::string usage_;
  int64_t num_samples_;
  ShuffleMode shuffle_;
  int32_t num_shards_;
  int32_t shard_id_;
};

} // namespace api
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CLUE_NODE_H_
Some files were not shown because too many files have changed in this diff.