!7540 Extract ImageFolderNode and BatchNode into its own files from datasets.cc/.h
Merge pull request !7540 from ZiruiWu/refactor_ir_demopull/7540/MERGE
commit
de7ea8e68c
@ -0,0 +1,3 @@
|
||||
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
|
||||
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
|
||||
add_subdirectory(datasetops)
|
@ -0,0 +1,5 @@
|
||||
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
|
||||
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
|
||||
add_subdirectory(source)
|
||||
add_library(engine-ir-datasetops OBJECT
|
||||
batch_node.cc)
|
@ -0,0 +1,76 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/batch_op.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api {
|
||||
|
||||
BatchNode::BatchNode(std::shared_ptr<Dataset> child, int32_t batch_size, bool drop_remainder, bool pad,
|
||||
std::vector<std::string> cols_to_map,
|
||||
std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map)
|
||||
: batch_size_(batch_size),
|
||||
drop_remainder_(drop_remainder),
|
||||
pad_(pad),
|
||||
cols_to_map_(cols_to_map),
|
||||
pad_map_(pad_map) {
|
||||
this->children.push_back(child);
|
||||
}
|
||||
|
||||
Status BatchNode::ValidateParams() {
|
||||
if (batch_size_ <= 0) {
|
||||
std::string err_msg = "Batch: batch_size should be positive integer, but got: " + std::to_string(batch_size_);
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
if (!cols_to_map_.empty()) {
|
||||
std::string err_msg = "cols_to_map functionality is not implemented in C++; this should be left empty.";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<DatasetOp>> BatchNode::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
#ifdef ENABLE_PYTHON
|
||||
py::function noop;
|
||||
node_ops.push_back(std::make_shared<BatchOp>(batch_size_, drop_remainder_, pad_, connector_que_size_, num_workers_,
|
||||
cols_to_map_, cols_to_map_, noop, noop, pad_map_));
|
||||
#else
|
||||
node_ops.push_back(std::make_shared<BatchOp>(batch_size_, drop_remainder_, pad_, connector_que_size_, num_workers_,
|
||||
cols_to_map_, pad_map_));
|
||||
#endif
|
||||
|
||||
// Until py::function is implemented for C++ API, there is no need for a project op to be inserted after batch
|
||||
// because project is only needed when batch op performs per_batch_map. This per_batch_map is a pyfunc
|
||||
return node_ops;
|
||||
}
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
@ -0,0 +1,61 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_BATCH_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_BATCH_NODE_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api {
|
||||
|
||||
class BatchNode : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
BatchNode(std::shared_ptr<Dataset> child, int32_t batch_size, bool drop_remainder, bool pad,
|
||||
std::vector<std::string> cols_to_map,
|
||||
std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map);
|
||||
|
||||
/// \brief Destructor
|
||||
~BatchNode() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
int32_t batch_size_;
|
||||
bool drop_remainder_;
|
||||
bool pad_;
|
||||
std::vector<std::string> cols_to_map_;
|
||||
std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map_;
|
||||
};
|
||||
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_BATCH_NODE_H_
|
@ -0,0 +1,4 @@
|
||||
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
|
||||
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
|
||||
add_library(engine-ir-datasetops-source OBJECT
|
||||
image_folder_node.cc)
|
@ -0,0 +1,70 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/image_folder_node.h"
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
|
||||
|
||||
#include "minddata/dataset/util/status.h"
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api {
|
||||
|
||||
ImageFolderNode::ImageFolderNode(std::string dataset_dir, bool decode, std::shared_ptr<SamplerObj> sampler,
|
||||
bool recursive, std::set<std::string> extensions,
|
||||
std::map<std::string, int32_t> class_indexing)
|
||||
: dataset_dir_(dataset_dir),
|
||||
decode_(decode),
|
||||
sampler_(sampler),
|
||||
recursive_(recursive),
|
||||
class_indexing_(class_indexing),
|
||||
exts_(extensions) {}
|
||||
|
||||
Status ImageFolderNode::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(ValidateDatasetDirParam("ImageFolderNode", dataset_dir_));
|
||||
|
||||
RETURN_IF_NOT_OK(ValidateDatasetSampler("ImageFolderNode", sampler_));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<DatasetOp>> ImageFolderNode::Build() {
|
||||
// A vector containing shared pointer to the Dataset Ops that this object will create
|
||||
std::vector<std::shared_ptr<DatasetOp>> node_ops;
|
||||
|
||||
// Do internal Schema generation.
|
||||
// This arg is exist in ImageFolderOp, but not externalized (in Python API).
|
||||
std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
|
||||
TensorShape scalar = TensorShape::CreateScalar();
|
||||
RETURN_EMPTY_IF_ERROR(
|
||||
schema->AddColumn(ColDescriptor("image", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1)));
|
||||
RETURN_EMPTY_IF_ERROR(
|
||||
schema->AddColumn(ColDescriptor("label", DataType(DataType::DE_INT32), TensorImpl::kFlexible, 0, &scalar)));
|
||||
node_ops.push_back(std::make_shared<ImageFolderOp>(num_workers_, rows_per_buffer_, dataset_dir_, connector_que_size_,
|
||||
recursive_, decode_, exts_, class_indexing_, std::move(schema),
|
||||
std::move(sampler_->Build())));
|
||||
return node_ops;
|
||||
}
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
@ -0,0 +1,63 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_IMAGE_FOLDER_NODE_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_IMAGE_FOLDER_NODE_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
namespace api {
|
||||
|
||||
/// \class ImageFolderNode
|
||||
/// \brief A Dataset derived class to represent ImageFolder dataset
|
||||
class ImageFolderNode : public Dataset {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
ImageFolderNode(std::string dataset_dir, bool decode, std::shared_ptr<SamplerObj> sampler, bool recursive,
|
||||
std::set<std::string> extensions, std::map<std::string, int32_t> class_indexing);
|
||||
|
||||
/// \brief Destructor
|
||||
~ImageFolderNode() = default;
|
||||
|
||||
/// \brief a base class override function to create the required runtime dataset op objects for this class
|
||||
/// \return The list of shared pointers to the newly created DatasetOps
|
||||
std::vector<std::shared_ptr<DatasetOp>> Build() override;
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
Status ValidateParams() override;
|
||||
|
||||
private:
|
||||
std::string dataset_dir_;
|
||||
bool decode_;
|
||||
bool recursive_;
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
std::map<std::string, int32_t> class_indexing_;
|
||||
std::set<std::string> exts_;
|
||||
};
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_IMAGE_FOLDER_NODE_H_
|
Loading…
Reference in new issue