From d3fdbceb613d2cf1004b0b1dc4ecb13cffda4d62 Mon Sep 17 00:00:00 2001 From: TinaMengtingZhang Date: Tue, 12 Jan 2021 18:06:47 -0500 Subject: [PATCH] [Part II] Push down json save logic to IR and add getter to each IR node --- .../ccsrc/minddata/dataset/api/vision.cc | 75 +++++++ .../engine/ir/cache/dataset_cache_impl.cc | 13 -- .../engine/ir/cache/dataset_cache_impl.h | 2 - .../ir/cache/pre_built_dataset_cache.cc | 10 + .../engine/ir/cache/pre_built_dataset_cache.h | 2 + .../datasetops/bucket_batch_by_length_node.h | 9 + .../build_sentence_piece_vocab_node.h | 8 + .../engine/ir/datasetops/build_vocab_node.h | 8 + .../engine/ir/datasetops/concat_node.h | 4 + .../engine/ir/datasetops/filter_node.cc | 8 + .../engine/ir/datasetops/filter_node.h | 9 + .../dataset/engine/ir/datasetops/map_node.cc | 2 +- .../engine/ir/datasetops/project_node.cc | 6 + .../engine/ir/datasetops/project_node.h | 8 + .../dataset/engine/ir/datasetops/root_node.h | 2 +- .../dataset/engine/ir/datasetops/skip_node.cc | 7 + .../dataset/engine/ir/datasetops/skip_node.h | 8 + .../engine/ir/datasetops/source/album_node.h | 6 + .../ir/datasetops/source/celeba_node.cc | 17 ++ .../engine/ir/datasetops/source/celeba_node.h | 11 + .../ir/datasetops/source/cifar100_node.cc | 15 ++ .../ir/datasetops/source/cifar100_node.h | 9 + .../ir/datasetops/source/cifar10_node.cc | 15 ++ .../ir/datasetops/source/cifar10_node.h | 9 + .../engine/ir/datasetops/source/clue_node.cc | 18 ++ .../engine/ir/datasetops/source/clue_node.h | 14 ++ .../engine/ir/datasetops/source/coco_node.cc | 17 ++ .../engine/ir/datasetops/source/coco_node.h | 11 + .../engine/ir/datasetops/source/csv_node.cc | 18 ++ .../engine/ir/datasetops/source/csv_node.h | 15 ++ .../ir/datasetops/source/generator_node.h | 6 + .../ir/datasetops/source/manifest_node.cc | 18 ++ .../ir/datasetops/source/manifest_node.h | 11 + .../engine/ir/datasetops/source/random_node.h | 8 + .../ir/datasetops/source/text_file_node.cc | 16 ++ .../ir/datasetops/source/text_file_node.h | 12 ++ .../engine/ir/datasetops/source/voc_node.cc | 18 ++ .../engine/ir/datasetops/source/voc_node.h | 12 ++ .../engine/ir/datasetops/sync_wait_node.h | 4 + .../dataset/engine/ir/datasetops/take_node.cc | 7 + .../dataset/engine/ir/datasetops/take_node.h | 8 + .../engine/ir/datasetops/transfer_node.cc | 9 + .../engine/ir/datasetops/transfer_node.h | 14 ++ .../ccsrc/minddata/dataset/include/vision.h | 10 + .../minddata/dataset/include/vision_lite.h | 8 + .../dataset/kernels/data/type_cast_op.cc | 7 + .../dataset/kernels/data/type_cast_op.h | 2 + .../dataset/kernels/image/decode_op.cc | 6 - .../dataset/kernels/image/decode_op.h | 2 - .../dataset/kernels/image/random_crop_op.cc | 11 - .../dataset/kernels/image/random_crop_op.h | 2 - .../dataset/kernels/image/rescale_op.cc | 8 - .../dataset/kernels/image/rescale_op.h | 2 - .../dataset/kernels/image/resize_op.cc | 8 - .../dataset/kernels/image/resize_op.h | 2 - .../dataset/engine/serializer_deserializer.py | 151 ++++++++++++-- .../ut/python/dataset/test_serdes_dataset.py | 196 +++++++++++++++++- 57 files changed, 826 insertions(+), 88 deletions(-) diff --git a/mindspore/ccsrc/minddata/dataset/api/vision.cc b/mindspore/ccsrc/minddata/dataset/api/vision.cc index ea7a6dda7d..2400eb5344 100644 --- a/mindspore/ccsrc/minddata/dataset/api/vision.cc +++ b/mindspore/ccsrc/minddata/dataset/api/vision.cc @@ -526,6 +526,11 @@ std::shared_ptr CenterCropOperation::Build() { return tensor_op; } +Status CenterCropOperation::to_json(nlohmann::json *out_json) { + (*out_json)["size"] = size_; + 
return Status::OK(); +} + // CropOperation. CropOperation::CropOperation(std::vector coordinates, std::vector size) : coordinates_(coordinates), size_(size) {} @@ -638,6 +643,11 @@ Status DecodeOperation::ValidateParams() { return Status::OK(); } std::shared_ptr DecodeOperation::Build() { return std::make_shared(rgb_); } +Status DecodeOperation::to_json(nlohmann::json *out_json) { + (*out_json)["rgb"] = rgb_; + return Status::OK(); +} + // EqualizeOperation Status EqualizeOperation::ValidateParams() { return Status::OK(); } @@ -801,6 +811,14 @@ std::shared_ptr NormalizeOperation::Build() { return std::make_shared(mean_[0], mean_[1], mean_[2], std_[0], std_[1], std_[2]); } +Status NormalizeOperation::to_json(nlohmann::json *out_json) { + nlohmann::json args; + args["mean"] = mean_; + args["std"] = std_; + *out_json = args; + return Status::OK(); +} + #ifndef ENABLE_ANDROID // NormalizePadOperation NormalizePadOperation::NormalizePadOperation(const std::vector &mean, const std::vector &std, @@ -893,6 +911,15 @@ std::shared_ptr PadOperation::Build() { return tensor_op; } +Status PadOperation::to_json(nlohmann::json *out_json) { + nlohmann::json args; + args["padding"] = padding_; + args["fill_value"] = fill_value_; + args["padding_mode"] = padding_mode_; + *out_json = args; + return Status::OK(); +} + // RandomAffineOperation RandomAffineOperation::RandomAffineOperation(const std::vector °rees, const std::vector &translate_range, @@ -1188,6 +1215,16 @@ std::shared_ptr RandomColorAdjustOperation::Build() { return tensor_op; } +Status RandomColorAdjustOperation::to_json(nlohmann::json *out_json) { + nlohmann::json args; + args["brightness"] = brightness_; + args["contrast"] = contrast_; + args["saturation"] = saturation_; + args["hue"] = hue_; + *out_json = args; + return Status::OK(); +} + // RandomCropOperation RandomCropOperation::RandomCropOperation(std::vector size, std::vector padding, bool pad_if_needed, std::vector fill_value, BorderType padding_mode) @@ -1261,6 +1298,17 @@ std::shared_ptr RandomCropOperation::Build() { return tensor_op; } +Status RandomCropOperation::to_json(nlohmann::json *out_json) { + nlohmann::json args; + args["size"] = size_; + args["padding"] = padding_; + args["pad_if_needed"] = pad_if_needed_; + args["fill_value"] = fill_value_; + args["padding_mode"] = padding_mode_; + *out_json = args; + return Status::OK(); +} + // RandomCropDecodeResizeOperation RandomCropDecodeResizeOperation::RandomCropDecodeResizeOperation(std::vector size, std::vector scale, std::vector ratio, @@ -1735,6 +1783,17 @@ std::shared_ptr RandomRotationOperation::Build() { return tensor_op; } +Status RandomRotationOperation::to_json(nlohmann::json *out_json) { + nlohmann::json args; + args["degrees"] = degrees_; + args["interpolation_mode"] = interpolation_mode_; + args["expand"] = expand_; + args["center"] = center_; + args["fill_value"] = fill_value_; + *out_json = args; + return Status::OK(); +} + // RandomSelectSubpolicyOperation. 
RandomSelectSubpolicyOperation::RandomSelectSubpolicyOperation( std::vector, double>>> policy) @@ -1889,6 +1948,14 @@ std::shared_ptr RescaleOperation::Build() { return tensor_op; } +Status RescaleOperation::to_json(nlohmann::json *out_json) { + nlohmann::json args; + args["rescale"] = rescale_; + args["shift"] = shift_; + *out_json = args; + return Status::OK(); +} + #endif // ResizeOperation ResizeOperation::ResizeOperation(std::vector size, InterpolationMode interpolation) @@ -1920,6 +1987,14 @@ std::shared_ptr ResizeOperation::Build() { return std::make_shared(height, width, interpolation_); } +Status ResizeOperation::to_json(nlohmann::json *out_json) { + nlohmann::json args; + args["size"] = size_; + args["interpolation"] = interpolation_; + *out_json = args; + return Status::OK(); +} + // RotateOperation RotateOperation::RotateOperation() { rotate_op = std::make_shared(0); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.cc index c7d90365f8..648744eb5f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/cache/dataset_cache_impl.cc @@ -44,18 +44,5 @@ Status DatasetCacheImpl::CreateCacheOp(int32_t num_workers, std::shared_ptr cache_client_; session_id_type session_id_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/cache/pre_built_dataset_cache.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/cache/pre_built_dataset_cache.cc index b99c06fbd2..aee0c0b1ab 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/cache/pre_built_dataset_cache.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/cache/pre_built_dataset_cache.cc @@ -36,5 +36,15 @@ Status PreBuiltDatasetCache::CreateCacheOp(int32_t num_workers, std::shared_ptr< return Status::OK(); } +Status PreBuiltDatasetCache::to_json(nlohmann::json *out_json) { + nlohmann::json args; + args["session_id"] = cache_client_->session_id(); + args["cache_memory_size"] = cache_client_->GetCacheMemSz(); + args["spill"] = cache_client_->isSpill(); + args["num_connections"] = cache_client_->GetNumConnections(); + args["prefetch_size"] = cache_client_->GetPrefetchSize(); + *out_json = args; + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/cache/pre_built_dataset_cache.h b/mindspore/ccsrc/minddata/dataset/engine/ir/cache/pre_built_dataset_cache.h index 1a5c733dda..fa11d9ee87 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/cache/pre_built_dataset_cache.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/cache/pre_built_dataset_cache.h @@ -42,6 +42,8 @@ class PreBuiltDatasetCache : public DatasetCache { Status ValidateParams() override { return Status::OK(); } + Status to_json(nlohmann::json *out_json) override; + private: std::shared_ptr cache_client_; }; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h index 1afe5a77c6..77ff1b3442 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/bucket_batch_by_length_node.h @@ -63,6 +63,15 @@ class BucketBatchByLengthNode : public DatasetNode { bool IsSizeDefined() override { return false; }; + /// \brief Getter functions + const std::vector &ColumnNames() const { return column_names_; } + const 
std::vector &BucketBoundaries() const { return bucket_boundaries_; } + const std::vector &BucketBatchSizes() const { return bucket_batch_sizes_; } + const std::shared_ptr &ElementLengthFunction() const { return element_length_function_; } + const std::map>> &PadInfo() const { return pad_info_; } + bool PadToBucketBoundary() const { return pad_to_bucket_boundary_; } + bool DropRemainder() const { return drop_remainder_; } + private: std::vector column_names_; std::vector bucket_boundaries_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.h index 7a073a1e24..4d668d08db 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_sentence_piece_vocab_node.h @@ -72,6 +72,14 @@ class BuildSentenceVocabNode : public DatasetNode { /// \return Status of the node visit Status AcceptAfter(IRNodePass *const p, bool *const modified) override; + /// \brief Getter functions + const std::shared_ptr &GetVocab() const { return vocab_; } + const std::vector &ColNames() const { return col_names_; } + int32_t VocabSize() const { return vocab_size_; } + float CharacterCoverage() const { return character_coverage_; } + SentencePieceModel ModelType() const { return model_type_; } + const std::unordered_map &Params() const { return params_; } + private: std::shared_ptr vocab_; std::vector col_names_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.h index d39ce4c84b..02b57ba29a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/build_vocab_node.h @@ -70,6 +70,14 @@ class BuildVocabNode : public DatasetNode { /// \return Status of the node visit Status AcceptAfter(IRNodePass *const p, bool *const modified) override; + /// \brief Getter functions + const std::shared_ptr &GetVocab() const { return vocab_; } + const std::vector &Columns() const { return columns_; } + const std::pair &FreqRange() const { return freq_range_; } + int64_t TopK() const { return top_k_; } + const std::vector &SpecialTokens() const { return special_tokens_; } + bool SpecialFirst() const { return special_first_; } + private: std::shared_ptr vocab_; std::vector columns_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.h index 45d238ad11..e2dc960754 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/concat_node.h @@ -61,6 +61,10 @@ class ConcatNode : public DatasetNode { bool IsSizeDefined() override { return false; } + /// \brief Getter functions + const std::vector> &ChildrenFlagAndNums() const { return children_flag_and_nums_; } + const std::vector> &ChildrenStartEndIndex() const { return children_start_end_index_; } + private: std::shared_ptr sampler_; std::vector> children_flag_and_nums_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/filter_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/filter_node.cc index ea996a003c..24be067e73 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/filter_node.cc +++ 
b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/filter_node.cc @@ -73,5 +73,13 @@ Status FilterNode::AcceptAfter(IRNodePass *const p, bool *const modified) { return p->VisitAfter(shared_from_base(), modified); } +Status FilterNode::to_json(nlohmann::json *out_json) { + nlohmann::json args; + args["input_columns"] = input_columns_; + args["num_parallel_workers"] = num_workers_; + args["predicate"] = "pyfunc"; + *out_json = args; + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/filter_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/filter_node.h index 0eae36668f..e7fdc959f5 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/filter_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/filter_node.h @@ -70,6 +70,15 @@ class FilterNode : public DatasetNode { /// \return Status of the node visit Status AcceptAfter(IRNodePass *const p, bool *const modified) override; + /// \brief Getter functions + const std::shared_ptr &Predicate() const { return predicate_; } + const std::vector &InputColumns() const { return input_columns_; } + + /// \brief Get the arguments of node + /// \param[out] out_json JSON string of all attributes + /// \return Status of the function + Status to_json(nlohmann::json *out_json) override; + private: std::shared_ptr predicate_; std::vector input_columns_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.cc index 9da5f4de57..90acb95463 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/map_node.cc @@ -138,8 +138,8 @@ Status MapNode::to_json(nlohmann::json *out_json) { std::vector ops; std::vector cbs; - nlohmann::json op_args; for (auto op : operations_) { + nlohmann::json op_args; RETURN_IF_NOT_OK(op->to_json(&op_args)); op_args["tensor_op_name"] = op->Name(); ops.push_back(op_args); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.cc index 50174d87bf..e0fb878579 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.cc @@ -57,5 +57,11 @@ Status ProjectNode::Build(std::vector> *const node_op return Status::OK(); } +Status ProjectNode::to_json(nlohmann::json *out_json) { + nlohmann::json args; + args["columns"] = columns_; + *out_json = args; + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.h index de77bd997c..791bf8f865 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/project_node.h @@ -55,6 +55,14 @@ class ProjectNode : public DatasetNode { /// \return Status Status::OK() if all the parameters are valid Status ValidateParams() override; + /// \brief Getter functions + const std::vector &Columns() const { return columns_; } + + /// \brief Get the arguments of node + /// \param[out] out_json JSON string of all attributes + /// \return Status of the function + Status to_json(nlohmann::json *out_json) override; + private: std::vector columns_; }; diff --git 
a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/root_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/root_node.h index 66cc26b9d8..283a389686 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/root_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/root_node.h @@ -29,7 +29,7 @@ namespace dataset { class RootNode : public DatasetNode { public: /// \brief Constructor - RootNode() : DatasetNode() {} + RootNode() : DatasetNode(), num_epochs_(0) {} /// \brief Constructor explicit RootNode(std::shared_ptr child); diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.cc index 9abcbeb9e6..eb29c30391 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.cc @@ -83,5 +83,12 @@ Status SkipNode::AcceptAfter(IRNodePass *const p, bool *const modified) { // Downcast shared pointer then call visitor return p->VisitAfter(shared_from_base(), modified); } + +Status SkipNode::to_json(nlohmann::json *out_json) { + nlohmann::json args; + args["count"] = skip_count_; + *out_json = args; + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.h index 323b50275d..e98e49036c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/skip_node.h @@ -80,6 +80,14 @@ class SkipNode : public DatasetNode { /// \return Status of the node visit Status AcceptAfter(IRNodePass *const p, bool *const modified) override; + /// \brief Getter functions + int32_t SkipCount() const { return skip_count_; } + + /// \brief Get the arguments of node + /// \param[out] out_json JSON string of all attributes + /// \return Status of the function + Status to_json(nlohmann::json *out_json) override; + private: int32_t skip_count_; }; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.h index 258d1dde9c..3cba8ca785 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/album_node.h @@ -61,6 +61,12 @@ class AlbumNode : public MappableSourceNode { /// \return Status Status::OK() if get shard id successfully Status GetShardId(int32_t *shard_id) override; + /// \brief Getter functions + const std::string &DatasetDir() const { return dataset_dir_; } + const std::string &SchemaPath() const { return schema_path_; } + const std::vector &ColumnNames() const { return column_names_; } + bool Decode() const { return decode_; } + private: std::string dataset_dir_; std::string schema_path_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc index 4ec79a4c0c..408f877b55 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.cc @@ -144,5 +144,22 @@ Status CelebANode::GetDatasetSize(const std::shared_ptr &size return Status::OK(); } +Status CelebANode::to_json(nlohmann::json *out_json) { + nlohmann::json args, sampler_args; + 
RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args)); + args["sampler"] = sampler_args; + args["num_parallel_workers"] = num_workers_; + args["dataset_dir"] = dataset_dir_; + args["decode"] = decode_; + args["extensions"] = extensions_; + args["usage"] = usage_; + if (cache_ != nullptr) { + nlohmann::json cache_args; + RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); + args["cache"] = cache_args; + } + *out_json = args; + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.h index 79cbdeabfd..413584f307 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/celeba_node.h @@ -71,6 +71,17 @@ class CelebANode : public MappableSourceNode { Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, int64_t *dataset_size) override; + /// \brief Getter functions + const std::string &DatasetDir() const { return dataset_dir_; } + const std::string &Usage() const { return usage_; } + bool Decode() const { return decode_; } + const std::set &Extensions() const { return extensions_; } + + /// \brief Get the arguments of node + /// \param[out] out_json JSON string of all attributes + /// \return Status of the function + Status to_json(nlohmann::json *out_json) override; + private: std::string dataset_dir_; std::string usage_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc index 8ea425a913..d48cc28a28 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.cc @@ -95,5 +95,20 @@ Status Cifar100Node::GetDatasetSize(const std::shared_ptr &si return Status::OK(); } +Status Cifar100Node::to_json(nlohmann::json *out_json) { + nlohmann::json args, sampler_args; + RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args)); + args["sampler"] = sampler_args; + args["num_parallel_workers"] = num_workers_; + args["dataset_dir"] = dataset_dir_; + args["usage"] = usage_; + if (cache_ != nullptr) { + nlohmann::json cache_args; + RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); + args["cache"] = cache_args; + } + *out_json = args; + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.h index a18fa827d8..705e3dd5c4 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar100_node.h @@ -69,6 +69,15 @@ class Cifar100Node : public MappableSourceNode { Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, int64_t *dataset_size) override; + /// \brief Getter functions + const std::string &DatasetDir() const { return dataset_dir_; } + const std::string &Usage() const { return usage_; } + + /// \brief Get the arguments of node + /// \param[out] out_json JSON string of all attributes + /// \return Status of the function + Status to_json(nlohmann::json *out_json) override; + private: std::string dataset_dir_; std::string usage_; diff --git 
a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc index d5afadcdb3..7b0e103c99 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.cc @@ -93,5 +93,20 @@ Status Cifar10Node::GetDatasetSize(const std::shared_ptr &siz return Status::OK(); } +Status Cifar10Node::to_json(nlohmann::json *out_json) { + nlohmann::json args, sampler_args; + RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args)); + args["sampler"] = sampler_args; + args["num_parallel_workers"] = num_workers_; + args["dataset_dir"] = dataset_dir_; + args["usage"] = usage_; + if (cache_ != nullptr) { + nlohmann::json cache_args; + RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); + args["cache"] = cache_args; + } + *out_json = args; + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.h index 2ea179856a..7b4c4161da 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/cifar10_node.h @@ -69,6 +69,15 @@ class Cifar10Node : public MappableSourceNode { Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, int64_t *dataset_size) override; + /// \brief Getter functions + const std::string &DatasetDir() const { return dataset_dir_; } + const std::string &Usage() const { return usage_; } + + /// \brief Get the arguments of node + /// \param[out] out_json JSON string of all attributes + /// \return Status of the function + Status to_json(nlohmann::json *out_json) override; + private: std::string dataset_dir_; std::string usage_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc index 6320f39cf0..d07ac6984a 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.cc @@ -252,5 +252,23 @@ Status CLUENode::GetDatasetSize(const std::shared_ptr &size_g return Status::OK(); } +Status CLUENode::to_json(nlohmann::json *out_json) { + nlohmann::json args; + args["num_parallel_workers"] = num_workers_; + args["dataset_dir"] = dataset_files_; + args["task"] = task_; + args["usage"] = usage_; + args["num_samples"] = num_samples_; + args["shuffle"] = shuffle_; + args["num_shards"] = num_shards_; + args["shard_id"] = shard_id_; + if (cache_ != nullptr) { + nlohmann::json cache_args; + RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); + args["cache"] = cache_args; + } + *out_json = args; + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.h index 1d78504b7f..cffe8d026c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/clue_node.h @@ -71,6 +71,20 @@ class CLUENode : public NonMappableSourceNode { Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, int64_t *dataset_size) override; + /// \brief Getter functions + const std::vector 
&DatasetFiles() const { return dataset_files_; } + const std::string &Task() const { return task_; } + const std::string &Usage() const { return usage_; } + int64_t NumSamples() const { return num_samples_; } + ShuffleMode Shuffle() const { return shuffle_; } + int32_t NumShards() const { return num_shards_; } + int32_t ShardId() const { return shard_id_; } + + /// \brief Get the arguments of node + /// \param[out] out_json JSON string of all attributes + /// \return Status of the function + Status to_json(nlohmann::json *out_json) override; + private: /// \brief Split string based on a character delimiter /// \return A string vector diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc index cae79ab334..bb03ce14a6 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.cc @@ -151,5 +151,22 @@ Status CocoNode::GetDatasetSize(const std::shared_ptr &size_g return Status::OK(); } +Status CocoNode::to_json(nlohmann::json *out_json) { + nlohmann::json args, sampler_args; + RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args)); + args["sampler"] = sampler_args; + args["num_parallel_workers"] = num_workers_; + args["dataset_dir"] = dataset_dir_; + args["annotation_file"] = annotation_file_; + args["task"] = task_; + args["decode"] = decode_; + if (cache_ != nullptr) { + nlohmann::json cache_args; + RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); + args["cache"] = cache_args; + } + *out_json = args; + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.h index b431f6394b..f0660c9d8c 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/coco_node.h @@ -69,6 +69,17 @@ class CocoNode : public MappableSourceNode { Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, int64_t *dataset_size) override; + /// \brief Getter functions + const std::string &DatasetDir() const { return dataset_dir_; } + const std::string &AnnotationFile() const { return annotation_file_; } + const std::string &Task() const { return task_; } + bool Decode() const { return decode_; } + + /// \brief Get the arguments of node + /// \param[out] out_json JSON string of all attributes + /// \return Status of the function + Status to_json(nlohmann::json *out_json) override; + private: std::string dataset_dir_; std::string annotation_file_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc index 8d3c0f0f30..da1d3e66b0 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.cc @@ -170,5 +170,23 @@ Status CSVNode::GetDatasetSize(const std::shared_ptr &size_ge return Status::OK(); } +Status CSVNode::to_json(nlohmann::json *out_json) { + nlohmann::json args; + args["num_parallel_workers"] = num_workers_; + args["dataset_files"] = dataset_files_; + args["field_delim"] = std::string(1, field_delim_); + args["column_names"] = column_names_; + args["num_samples"] = num_samples_; + args["shuffle"] = shuffle_; + args["num_shards"] 
= num_shards_; + args["shard_id"] = shard_id_; + if (cache_ != nullptr) { + nlohmann::json cache_args; + RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); + args["cache"] = cache_args; + } + *out_json = args; + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.h index 490f842ce1..03d4539366 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/csv_node.h @@ -92,6 +92,21 @@ class CSVNode : public NonMappableSourceNode { Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, int64_t *dataset_size) override; + /// \brief Getter functions + const std::vector &DatasetFiles() const { return dataset_files_; } + char FieldDelim() const { return field_delim_; } + const std::vector> &ColumnDefaults() const { return column_defaults_; } + const std::vector &ColumnNames() const { return column_names_; } + int64_t NumSamples() const { return num_samples_; } + ShuffleMode Shuffle() const { return shuffle_; } + int32_t NumShards() const { return num_shards_; } + int32_t ShardId() const { return shard_id_; } + + /// \brief Get the arguments of node + /// \param[out] out_json JSON string of all attributes + /// \return Status of the function + Status to_json(nlohmann::json *out_json) override; + private: std::vector dataset_files_; char field_delim_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.h index 57e39bee83..cd74569a05 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/generator_node.h @@ -83,6 +83,12 @@ class GeneratorNode : public MappableSourceNode { return Status::OK(); } + /// \brief Getter functions + const py::function &GeneratorFunction() const { return generator_function_; } + const std::vector &ColumnNames() const { return column_names_; } + const std::vector &ColumnTypes() const { return column_types_; } + const std::shared_ptr &Schema() const { return schema_; } + private: py::function generator_function_; std::vector column_names_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc index 768e83707d..413cb4dc37 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.cc @@ -124,5 +124,23 @@ Status ManifestNode::GetDatasetSize(const std::shared_ptr &si return Status::OK(); } +Status ManifestNode::to_json(nlohmann::json *out_json) { + nlohmann::json args, sampler_args; + RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args)); + args["sampler"] = sampler_args; + args["num_parallel_workers"] = num_workers_; + args["dataset_file"] = dataset_file_; + args["usage"] = usage_; + args["class_indexing"] = class_index_; + args["decode"] = decode_; + + if (cache_ != nullptr) { + nlohmann::json cache_args; + RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); + args["cache"] = cache_args; + } + *out_json = args; + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git 
a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.h index 4ebbff9cc2..0f4cb9ecdc 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/manifest_node.h @@ -70,6 +70,17 @@ class ManifestNode : public MappableSourceNode { Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, int64_t *dataset_size) override; + /// \brief Getter functions + const std::string &DatasetFile() const { return dataset_file_; } + const std::string &Usage() const { return usage_; } + bool Decode() const { return decode_; } + const std::map &ClassIndex() const { return class_index_; } + + /// \brief Get the arguments of node + /// \param[out] out_json JSON string of all attributes + /// \return Status of the function + Status to_json(nlohmann::json *out_json) override; + private: std::string dataset_file_; std::string usage_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.h index 8a0bae512f..9088139235 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/random_node.h @@ -91,6 +91,14 @@ class RandomNode : public NonMappableSourceNode { Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, int64_t *dataset_size) override; + /// \brief Getter functions + int32_t TotalRows() const { return total_rows_; } + const std::string &SchemaPath() const { return schema_path_; } + const std::shared_ptr &GetSchema() const { return schema_; } + const std::vector &ColumnsList() const { return columns_list_; } + const std::mt19937 &RandGen() const { return rand_gen_; } + const std::unique_ptr &GetDataSchema() const { return data_schema_; } + private: /// \brief A quick inline for producing a random number between (and including) min/max /// \param[in] min minimum number that can be generated. 
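Every getter block and to_json() override added in this patch follows the same recipe: expose each constructor argument through a const accessor, write those same members into an nlohmann::json object under the key the Python deserializer expects, and let the sampler and cache serialize themselves as nested objects. A minimal sketch of that pattern, using a hypothetical ExampleNode (the class and its members are illustrative, not part of the patch):

// Sketch only: "ExampleNode" is hypothetical; CelebANode, Cifar10Node, etc.
// in this patch follow exactly this shape.
class ExampleNode : public MappableSourceNode {
 public:
  /// \brief Getter functions
  const std::string &DatasetDir() const { return dataset_dir_; }
  bool Decode() const { return decode_; }

  /// \brief Get the arguments of node
  Status to_json(nlohmann::json *out_json) override {
    nlohmann::json args, sampler_args;
    // Samplers and caches emit their own JSON, nested as sub-objects.
    RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args));
    args["sampler"] = sampler_args;
    args["num_parallel_workers"] = num_workers_;
    args["dataset_dir"] = dataset_dir_;
    args["decode"] = decode_;
    if (cache_ != nullptr) {
      nlohmann::json cache_args;
      RETURN_IF_NOT_OK(cache_->to_json(&cache_args));
      args["cache"] = cache_args;
    }
    *out_json = args;
    return Status::OK();
  }

 private:
  std::string dataset_dir_;
  bool decode_;
};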
diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc index 136dbb23a4..6fcb4ebab5 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.cc @@ -136,5 +136,21 @@ Status TextFileNode::GetDatasetSize(const std::shared_ptr &si return Status::OK(); } +Status TextFileNode::to_json(nlohmann::json *out_json) { + nlohmann::json args; + args["num_parallel_workers"] = num_workers_; + args["dataset_files"] = dataset_files_; + args["num_samples"] = num_samples_; + args["shuffle"] = shuffle_; + args["num_shards"] = num_shards_; + args["shard_id"] = shard_id_; + if (cache_ != nullptr) { + nlohmann::json cache_args; + RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); + args["cache"] = cache_args; + } + *out_json = args; + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.h index 94c0afe863..300251c3f9 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/text_file_node.h @@ -71,6 +71,18 @@ class TextFileNode : public NonMappableSourceNode { Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, int64_t *dataset_size) override; + /// \brief Getter functions + const std::vector &DatasetFiles() const { return dataset_files_; } + int32_t NumSamples() const { return num_samples_; } + int32_t NumShards() const { return num_shards_; } + int32_t ShardId() const { return shard_id_; } + ShuffleMode Shuffle() const { return shuffle_; } + + /// \brief Get the arguments of node + /// \param[out] out_json JSON string of all attributes + /// \return Status of the function + Status to_json(nlohmann::json *out_json) override; + private: std::vector dataset_files_; int32_t num_samples_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc index 81ee00311f..25f9623640 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.cc @@ -140,5 +140,23 @@ Status VOCNode::GetDatasetSize(const std::shared_ptr &size_ge return Status::OK(); } +Status VOCNode::to_json(nlohmann::json *out_json) { + nlohmann::json args, sampler_args; + RETURN_IF_NOT_OK(sampler_->to_json(&sampler_args)); + args["sampler"] = sampler_args; + args["num_parallel_workers"] = num_workers_; + args["dataset_dir"] = dataset_dir_; + args["task"] = task_; + args["usage"] = usage_; + args["class_indexing"] = class_index_; + args["decode"] = decode_; + if (cache_ != nullptr) { + nlohmann::json cache_args; + RETURN_IF_NOT_OK(cache_->to_json(&cache_args)); + args["cache"] = cache_args; + } + *out_json = args; + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.h index ba4df831fa..195ba34ddf 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.h +++ 
b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/source/voc_node.h @@ -71,6 +71,18 @@ class VOCNode : public MappableSourceNode { Status GetDatasetSize(const std::shared_ptr &size_getter, bool estimate, int64_t *dataset_size) override; + /// \brief Getter functions + const std::string &DatasetDir() const { return dataset_dir_; } + const std::string &Task() const { return task_; } + const std::string &Usage() const { return usage_; } + const std::map &ClassIndex() const { return class_index_; } + bool Decode() const { return decode_; } + + /// \brief Get the arguments of node + /// \param[out] out_json JSON string of all attributes + /// \return Status of the function + Status to_json(nlohmann::json *out_json) override; + private: const std::string kColumnImage = "image"; const std::string kColumnTarget = "target"; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/sync_wait_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/sync_wait_node.h index cd1a07304b..7c68390d67 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/sync_wait_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/sync_wait_node.h @@ -57,6 +57,10 @@ class SyncWaitNode : public DatasetNode { /// \return Status Status::OK() if all the parameters are valid Status ValidateParams() override; + /// \brief Getter functions + const std::string &ConditionName() const { return condition_name_; } + const py::function &Callback() const { return callback_; } + private: std::string condition_name_; py::function callback_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.cc index efb14b1fda..94efe21b5d 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.cc @@ -81,5 +81,12 @@ Status TakeNode::AcceptAfter(IRNodePass *const p, bool *const modified) { // Downcast shared pointer then call visitor return p->VisitAfter(shared_from_base(), modified); } + +Status TakeNode::to_json(nlohmann::json *out_json) { + nlohmann::json args; + args["count"] = take_count_; + *out_json = args; + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.h index 3446b73337..598ba44598 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/take_node.h @@ -80,6 +80,14 @@ class TakeNode : public DatasetNode { /// \return Status of the node visit Status AcceptAfter(IRNodePass *const p, bool *const modified) override; + /// \brief Getter functions + int32_t TakeCount() const { return take_count_; } + + /// \brief Get the arguments of node + /// \param[out] out_json JSON string of all attributes + /// \return Status of the function + Status to_json(nlohmann::json *out_json) override; + private: int32_t take_count_; }; diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.cc b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.cc index a9098a9a92..79732c516f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.cc @@ -116,5 +116,14 @@ Status TransferNode::AcceptAfter(IRNodePass *const p, bool *const modified) { // Downcast shared 
pointer then call visitor return p->VisitAfter(shared_from_base(), modified); } + +Status TransferNode::to_json(nlohmann::json *out_json) { + nlohmann::json args; + args["send_epoch_end"] = send_epoch_end_; + args["total_batch"] = total_batch_; + args["create_data_info_queue"] = create_data_info_queue_; + *out_json = args; + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.h b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.h index e2154c98a5..9d5617f2a5 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/ir/datasetops/transfer_node.h @@ -70,6 +70,20 @@ class TransferNode : public DatasetNode { /// \return Status of the node visit Status AcceptAfter(IRNodePass *const p, bool *const modified) override; + /// \brief Getter functions + const std::string &QueueName() const { return queue_name_; } + int32_t DeviceId() const { return device_id_; } + const std::string &DeviceType() const { return device_type_; } + int32_t PrefetchSize() const { return prefetch_size_; } + bool SendEpochEnd() const { return send_epoch_end_; } + int32_t TotalBatch() const { return total_batch_; } + bool CreateDataInfoQueue() const { return create_data_info_queue_; } + + /// \brief Get the arguments of node + /// \param[out] out_json JSON string of all attributes + /// \return Status of the function + Status to_json(nlohmann::json *out_json) override; + private: std::string queue_name_; int32_t device_id_; diff --git a/mindspore/ccsrc/minddata/dataset/include/vision.h b/mindspore/ccsrc/minddata/dataset/include/vision.h index d6aba2d625..7615ce0a27 100644 --- a/mindspore/ccsrc/minddata/dataset/include/vision.h +++ b/mindspore/ccsrc/minddata/dataset/include/vision.h @@ -668,6 +668,8 @@ class PadOperation : public TensorOperation { std::string Name() const override { return kPadOperation; } + Status to_json(nlohmann::json *out_json) override; + private: std::vector padding_; std::vector fill_value_; @@ -729,6 +731,8 @@ class RandomColorAdjustOperation : public TensorOperation { std::string Name() const override { return kRandomColorAdjustOperation; } + Status to_json(nlohmann::json *out_json) override; + private: std::vector brightness_; std::vector contrast_; @@ -750,6 +754,8 @@ class RandomCropOperation : public TensorOperation { std::string Name() const override { return kRandomCropOperation; } + Status to_json(nlohmann::json *out_json) override; + private: std::vector size_; std::vector padding_; @@ -936,6 +942,8 @@ class RandomRotationOperation : public TensorOperation { std::string Name() const override { return kRandomRotationOperation; } + Status to_json(nlohmann::json *out_json) override; + private: std::vector degrees_; InterpolationMode interpolation_mode_; @@ -1037,6 +1045,8 @@ class RescaleOperation : public TensorOperation { std::string Name() const override { return kRescaleOperation; } + Status to_json(nlohmann::json *out_json) override; + private: float rescale_; float shift_; diff --git a/mindspore/ccsrc/minddata/dataset/include/vision_lite.h b/mindspore/ccsrc/minddata/dataset/include/vision_lite.h index aa7d56b031..d431240667 100644 --- a/mindspore/ccsrc/minddata/dataset/include/vision_lite.h +++ b/mindspore/ccsrc/minddata/dataset/include/vision_lite.h @@ -105,6 +105,8 @@ class CenterCropOperation : public TensorOperation { std::string Name() const override { return kCenterCropOperation; } + Status 
to_json(nlohmann::json *out_json) override; + private: std::vector size_; }; @@ -137,6 +139,8 @@ class DecodeOperation : public TensorOperation { std::string Name() const override { return kDecodeOperation; } + Status to_json(nlohmann::json *out_json) override; + private: bool rgb_; }; @@ -153,6 +157,8 @@ class NormalizeOperation : public TensorOperation { std::string Name() const override { return kNormalizeOperation; } + Status to_json(nlohmann::json *out_json) override; + private: std::vector mean_; std::vector std_; @@ -171,6 +177,8 @@ class ResizeOperation : public TensorOperation { std::string Name() const override { return kResizeOperation; } + Status to_json(nlohmann::json *out_json) override; + private: std::vector size_; InterpolationMode interpolation_; diff --git a/mindspore/ccsrc/minddata/dataset/kernels/data/type_cast_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/data/type_cast_op.cc index 5a58745293..d420140295 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/data/type_cast_op.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/data/type_cast_op.cc @@ -33,5 +33,12 @@ Status TypeCastOp::OutputType(const std::vector &inputs, std::vector &inputs, std::vector &inputs, std::ve if (!outputs.empty()) return Status::OK(); return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); } - -Status RandomCropOp::to_json(nlohmann::json *out_json) { - nlohmann::json args; - args["size"] = std::vector{crop_height_, crop_width_}; - args["padding"] = std::vector{pad_top_, pad_bottom_, pad_left_, pad_right_}; - args["pad_if_needed"] = pad_if_needed_; - args["fill_value"] = std::tuple{fill_r_, fill_g_, fill_b_}; - args["padding_mode"] = border_type_; - *out_json = args; - return Status::OK(); -} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_op.h index 21adeab38c..20f3395697 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/random_crop_op.h @@ -79,8 +79,6 @@ class RandomCropOp : public TensorOp { std::string Name() const override { return kRandomCropOp; } - Status to_json(nlohmann::json *out_json) override; - protected: int32_t crop_height_ = 0; int32_t crop_width_ = 0; diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/rescale_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/rescale_op.cc index 94087cc3c9..2a500d6c34 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/rescale_op.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/rescale_op.cc @@ -29,13 +29,5 @@ Status RescaleOp::OutputType(const std::vector &inputs, std::vector &inputs, std::vector if (!outputs.empty()) return Status::OK(); return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); } - -Status ResizeOp::to_json(nlohmann::json *out_json) { - nlohmann::json args; - args["size"] = std::vector{size1_, size2_}; - args["interpolation"] = interpolation_; - *out_json = args; - return Status::OK(); -} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/resize_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/resize_op.h index 68cb696ad4..336c23f6ff 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/resize_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/resize_op.h @@ -61,8 +61,6 @@ class ResizeOp : public TensorOp { std::string Name() const override { return kResizeOp; } - Status 
to_json(nlohmann::json *out_json) override; - protected: int32_t size1_; int32_t size2_; diff --git a/mindspore/dataset/engine/serializer_deserializer.py b/mindspore/dataset/engine/serializer_deserializer.py index 843bd16d32..30c0263ab4 100644 --- a/mindspore/dataset/engine/serializer_deserializer.py +++ b/mindspore/dataset/engine/serializer_deserializer.py @@ -167,19 +167,82 @@ def create_node(node): pyobj = None # Find a matching Dataset class and call the constructor with the corresponding args. # When a new Dataset class is introduced, another if clause and parsing code needs to be added. - if dataset_op == 'ImageFolderDataset': + # Dataset Source Ops (in alphabetical order) + if dataset_op == 'CelebADataset': + sampler = construct_sampler(node.get('sampler')) + num_samples = check_and_replace_input(node.get('num_samples'), 0, None) + pyobj = pyclass(node['dataset_dir'], node.get('num_parallel_workers'), node.get('shuffle'), node.get('usage'), + sampler, node.get('decode'), node.get('extensions'), num_samples, node.get('num_shards'), + node.get('shard_id')) + + elif dataset_op == 'Cifar10Dataset': + sampler = construct_sampler(node.get('sampler')) + num_samples = check_and_replace_input(node.get('num_samples'), 0, None) + pyobj = pyclass(node['dataset_dir'], node['usage'], num_samples, node.get('num_parallel_workers'), + node.get('shuffle'), sampler, node.get('num_shards'), node.get('shard_id')) + + elif dataset_op == 'Cifar100Dataset': + sampler = construct_sampler(node.get('sampler')) + num_samples = check_and_replace_input(node.get('num_samples'), 0, None) + pyobj = pyclass(node['dataset_dir'], node['usage'], num_samples, node.get('num_parallel_workers'), + node.get('shuffle'), sampler, node.get('num_shards'), node.get('shard_id')) + + elif dataset_op == 'ClueDataset': + shuffle = to_shuffle_mode(node.get('shuffle')) + if shuffle is not None and isinstance(shuffle, str): + shuffle = de.Shuffle(shuffle) + num_samples = check_and_replace_input(node.get('num_samples'), 0, None) + pyobj = pyclass(node['dataset_files'], node.get('task'), + node.get('usage'), num_samples, node.get('num_parallel_workers'), shuffle, + node.get('num_shards'), node.get('shard_id')) + + elif dataset_op == 'CocoDataset': + sampler = construct_sampler(node.get('sampler')) + num_samples = check_and_replace_input(node.get('num_samples'), 0, None) + pyobj = pyclass(node['dataset_dir'], node.get('annotation_file'), node.get('task'), num_samples, + node.get('num_parallel_workers'), node.get('shuffle'), node.get('decode'), sampler, + node.get('num_shards'), node.get('shard_id')) + + elif dataset_op == 'CSVDataset': + shuffle = to_shuffle_mode(node.get('shuffle')) + if shuffle is not None and isinstance(shuffle, str): + shuffle = de.Shuffle(shuffle) + num_samples = check_and_replace_input(node.get('num_samples'), 0, None) + pyobj = pyclass(node['dataset_files'], node.get('field_delim'), + node.get('column_defaults'), node.get('column_names'), num_samples, + node.get('num_parallel_workers'), shuffle, + node.get('num_shards'), node.get('shard_id')) + + elif dataset_op == 'ImageFolderDataset': sampler = construct_sampler(node.get('sampler')) num_samples = check_and_replace_input(node.get('num_samples'), 0, None) pyobj = pyclass(node['dataset_dir'], num_samples, node.get('num_parallel_workers'), node.get('shuffle'), sampler, node.get('extensions'), node.get('class_indexing'), node.get('decode'), node.get('num_shards'), - node.get('shard_id'), node.get('cache')) + node.get('shard_id')) + + elif dataset_op == 
'ManifestDataset': + sampler = construct_sampler(node.get('sampler')) + num_samples = check_and_replace_input(node.get('num_samples'), 0, None) + pyobj = pyclass(node['dataset_file'], node['usage'], num_samples, + node.get('num_parallel_workers'), node.get('shuffle'), sampler, + node.get('class_indexing'), node.get('decode'), node.get('num_shards'), + node.get('shard_id')) elif dataset_op == 'MnistDataset': sampler = construct_sampler(node.get('sampler')) num_samples = check_and_replace_input(node.get('num_samples'), 0, None) pyobj = pyclass(node['dataset_dir'], node['usage'], num_samples, node.get('num_parallel_workers'), - node.get('shuffle'), sampler, node.get('num_shards'), node.get('shard_id'), node.get('cache')) + node.get('shuffle'), sampler, node.get('num_shards'), node.get('shard_id')) + + elif dataset_op == 'TextFileDataset': + shuffle = to_shuffle_mode(node.get('shuffle')) + if shuffle is not None and isinstance(shuffle, str): + shuffle = de.Shuffle(shuffle) + num_samples = check_and_replace_input(node.get('num_samples'), 0, None) + pyobj = pyclass(node['dataset_files'], num_samples, + node.get('num_parallel_workers'), shuffle, + node.get('num_shards'), node.get('shard_id')) elif dataset_op == 'TFRecordDataset': shuffle = to_shuffle_mode(node.get('shuffle')) @@ -188,30 +251,50 @@ def create_node(node): num_samples = check_and_replace_input(node.get('num_samples'), 0, None) pyobj = pyclass(node['dataset_files'], node.get('schema'), node.get('columns_list'), num_samples, node.get('num_parallel_workers'), - shuffle, node.get('num_shards'), node.get('shard_id'), node.get('cache')) + shuffle, node.get('num_shards'), node.get('shard_id')) - elif dataset_op == 'Repeat': - pyobj = de.Dataset().repeat(node.get('count')) + elif dataset_op == 'VOCDataset': + sampler = construct_sampler(node.get('sampler')) + num_samples = check_and_replace_input(node.get('num_samples'), 0, None) + pyobj = pyclass(node['dataset_dir'], node.get('task'), node.get('usage'), node.get('class_indexing'), + num_samples, node.get('num_parallel_workers'), node.get('shuffle'), + node.get('decode'), sampler, node.get('num_shards'), node.get('shard_id')) + + # Dataset Ops (in alphabetical order) + elif dataset_op == 'Batch': + pyobj = de.Dataset().batch(node['batch_size'], node.get('drop_remainder')) elif dataset_op == 'Map': tensor_ops = construct_tensor_ops(node.get('operations')) pyobj = de.Dataset().map(tensor_ops, node.get('input_columns'), node.get('output_columns'), node.get('column_order'), node.get('num_parallel_workers'), - True, node.get('cache'), node.get('callbacks')) + True, node.get('callbacks')) + + elif dataset_op == 'Project': + pyobj = de.Dataset().project(node['columns']) + + elif dataset_op == 'Rename': + pyobj = de.Dataset().rename(node['input_columns'], node['output_columns']) + + elif dataset_op == 'Repeat': + pyobj = de.Dataset().repeat(node.get('count')) elif dataset_op == 'Shuffle': pyobj = de.Dataset().shuffle(node.get('buffer_size')) - elif dataset_op == 'Batch': - pyobj = de.Dataset().batch(node['batch_size'], node.get('drop_remainder')) + elif dataset_op == 'Skip': + pyobj = de.Dataset().skip(node.get('count')) + + elif dataset_op == 'Take': + pyobj = de.Dataset().take(node.get('count')) + + elif dataset_op == 'Transfer': + pyobj = de.Dataset().to_device(node.get('send_epoch_end'), node.get('create_data_info_queue')) elif dataset_op == 'Zip': # Create ZipDataset instance, giving dummy input dataset that will be overrided in the caller. 
         pyobj = de.ZipDataset((de.Dataset(), de.Dataset()))
 
-    elif dataset_op == 'Rename':
-        pyobj = de.Dataset().rename(node['input_columns'], node['output_columns'])
-
     else:
         raise RuntimeError(dataset_op + " is not yet supported by ds.engine.deserialize().")
 
@@ -252,35 +335,59 @@ def construct_tensor_ops(operations):
     """Instantiate tensor op object(s) based on the information from dictionary['operations']"""
     result = []
     for op in operations:
-        op_name = op['tensor_op_name'][:-2]  # to remove "Op" from the back of the name
+        op_name = op['tensor_op_name']
         op_module_vis = sys.modules["mindspore.dataset.vision.c_transforms"]
         op_module_trans = sys.modules["mindspore.dataset.transforms.c_transforms"]
 
+        if op_name == "HwcToChw":
+            op_name = "HWC2CHW"
         if hasattr(op_module_vis, op_name):
             op_class = getattr(op_module_vis, op_name)
-        elif hasattr(op_module_trans, op_name):
+        elif hasattr(op_module_trans, op_name[:-2]):
+            op_name = op_name[:-2]  # to remove "Op" from the back of the name
            op_class = getattr(op_module_trans, op_name)
         else:
             raise RuntimeError(op_name + " is not yet supported by deserialize().")
 
-        if op_name == 'Decode':
+        # Transforms Ops (in alphabetical order)
+        if op_name == 'OneHot':
+            result.append(op_class(op['num_classes']))
+
+        elif op_name == 'TypeCast':
+            result.append(op_class(to_mstype(op['data_type'])))
+
+        # Vision Ops (in alphabetical order)
+        elif op_name == 'CenterCrop':
+            result.append(op_class(op['size']))
+
+        elif op_name == 'Decode':
             result.append(op_class(op.get('rgb')))
 
+        elif op_name == 'HWC2CHW':
+            result.append(op_class())
+
+        elif op_name == 'Normalize':
+            result.append(op_class(op['mean'], op['std']))
+
+        elif op_name == 'Pad':
+            result.append(op_class(op['padding'], tuple(op['fill_value']), Border(to_border_mode(op['padding_mode']))))
+
+        elif op_name == 'RandomColorAdjust':
+            result.append(op_class(op.get('brightness'), op.get('contrast'), op.get('saturation'),
+                                   op.get('hue')))
+
         elif op_name == 'RandomCrop':
             result.append(op_class(op['size'], op.get('padding'), op.get('pad_if_needed'),
                                    tuple(op.get('fill_value')), Border(to_border_mode(op.get('padding_mode')))))
 
-        elif op_name == 'Resize':
-            result.append(op_class(op['size'], Inter(to_interpolation_mode(op.get('interpolation')))))
+        elif op_name == 'RandomRotation':
+            result.append(op_class(op['degrees'], to_interpolation_mode(op.get('interpolation_mode')), op.get('expand'),
+                                   tuple(op.get('center')), tuple(op.get('fill_value'))))
 
         elif op_name == 'Rescale':
             result.append(op_class(op['rescale'], op['shift']))
 
-        elif op_name == 'HWC2CHW':
-            result.append(op_class())
-
-        elif op_name == 'OneHot':
-            result.append(op_class(op['num_classes']))
+        elif op_name == 'Resize':
+            result.append(op_class(op['size'], to_interpolation_mode(op.get('interpolation'))))
 
         else:
             raise ValueError("Tensor op name is unknown: {}.".format(op_name))
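
The two dispatch tables above are the read side of the to_json getters this patch adds to the IR nodes: every node serializes its constructor arguments under fixed key names, and create_node() / construct_tensor_ops() look the same keys up again to rebuild the Python objects. A minimal sketch of the intended round trip (the dataset path is a placeholder, not part of this patch):

    import mindspore.dataset as ds
    import mindspore.dataset.vision.c_transforms as vision

    # Build a small pipeline, serialize it to a dict, and rebuild it from that dict.
    data = ds.Cifar10Dataset("/path/to/cifar10", num_samples=4, shuffle=False)  # placeholder path
    data = data.map(operations=[vision.Resize((224, 224))], input_columns="image")
    serialized = ds.serialize(data)                  # dict mirroring the IR's to_json output
    rebuilt = ds.deserialize(input_dict=serialized)  # dispatched through create_node()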
diff --git a/tests/ut/python/dataset/test_serdes_dataset.py b/tests/ut/python/dataset/test_serdes_dataset.py
index 38fd73e9da..725b4f7eed 100644
--- a/tests/ut/python/dataset/test_serdes_dataset.py
+++ b/tests/ut/python/dataset/test_serdes_dataset.py
@@ -19,11 +19,13 @@
 import filecmp
 import glob
 import json
 import os
+import pytest
 import numpy as np
 
 from test_minddataset_sampler import add_and_remove_cv_file, get_data, CV_DIR_NAME, CV_FILE_NAME
 from util import config_get_set_num_parallel_workers, config_get_set_seed
 
+import mindspore.common.dtype as mstype
 import mindspore.dataset as ds
 import mindspore.dataset.transforms.c_transforms as c
 import mindspore.dataset.vision.c_transforms as vision
@@ -31,7 +33,7 @@ from mindspore import log as logger
 from mindspore.dataset.vision import Inter
 
 
-def skip_test_imagefolder(remove_json_files=True):
+def test_serdes_imagefolder_dataset(remove_json_files=True):
     """
     Test simulating resnet50 dataset pipeline.
     """
@@ -100,7 +102,7 @@ def skip_test_imagefolder(remove_json_files=True):
         delete_json_files()
 
 
-def test_mnist_dataset(remove_json_files=True):
+def test_serdes_mnist_dataset(remove_json_files=True):
     """
     Test serdes on mnist dataset pipeline.
     """
@@ -141,7 +143,7 @@ def test_mnist_dataset(remove_json_files=True):
         delete_json_files()
 
 
-def test_zip_dataset(remove_json_files=True):
+def test_serdes_zip_dataset(remove_json_files=True):
     """
     Test serdes on zip dataset pipeline.
     """
@@ -185,7 +187,7 @@ def test_zip_dataset(remove_json_files=True):
         delete_json_files()
 
 
-def skip_test_random_crop():
+def test_serdes_random_crop():
     """
     Test serdes on RandomCrop pipeline.
     """
@@ -225,6 +227,179 @@ def skip_test_random_crop():
     ds.config.set_num_parallel_workers(original_num_parallel_workers)
 
 
+def test_serdes_cifar10_dataset(remove_json_files=True):
+    """
+    Test serdes on Cifar10 dataset pipeline.
+    """
+    data_dir = "../data/dataset/testCifar10Data"
+    original_seed = config_get_set_seed(1)
+    original_num_parallel_workers = config_get_set_num_parallel_workers(1)
+
+    data1 = ds.Cifar10Dataset(data_dir, num_samples=10, shuffle=False)
+    data1 = data1.take(6)
+
+    trans = [
+        vision.RandomCrop((32, 32), (4, 4, 4, 4)),
+        vision.Resize((224, 224)),
+        vision.Rescale(1.0 / 255.0, 0.0),
+        vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
+        vision.HWC2CHW()
+    ]
+
+    type_cast_op = c.TypeCast(mstype.int32)
+    data1 = data1.map(operations=type_cast_op, input_columns="label")
+    data1 = data1.map(operations=trans, input_columns="image")
+    data1 = data1.batch(3, drop_remainder=True)
+    data1 = data1.repeat(1)
+    data2 = util_check_serialize_deserialize_file(data1, "cifar10_dataset_pipeline", remove_json_files)
+
+    num_samples = 0
+    # Iterate and compare the data in the original pipeline (data1) against the deserialized pipeline (data2)
+    for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True),
+                            data2.create_dict_iterator(num_epochs=1, output_numpy=True)):
+        np.testing.assert_array_equal(item1['image'], item2['image'])
+        num_samples += 1
+
+    assert num_samples == 2
+
+    # Restore configuration
+    ds.config.set_seed(original_seed)
+    ds.config.set_num_parallel_workers(original_num_parallel_workers)
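
The expected count in test_serdes_cifar10_dataset comes from the pipeline shape rather than the data: take(6) keeps six rows and batch(3, drop_remainder=True) yields floor(6 / 3) = 2 batches, which both pipelines must agree on. A standalone check of that arithmetic:

    # take(6) then batch(3, drop_remainder=True) over 10 rows -> 2 batches per epoch
    num_rows, take_n, batch_size = 10, 6, 3
    assert min(num_rows, take_n) // batch_size == 2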
+ """ + DATA_DIR = "../data/dataset/testCelebAData/" + data1 = ds.CelebADataset(DATA_DIR, decode=True, num_shards=1, shard_id=0) + # define map operations + data1 = data1.repeat(2) + center_crop = vision.CenterCrop((80, 80)) + pad_op = vision.Pad(20, fill_value=(20, 20, 20)) + data1 = data1.map(operations=[center_crop, pad_op], input_columns=["image"], num_parallel_workers=8) + data2 = util_check_serialize_deserialize_file(data1, "celeba_dataset_pipeline", remove_json_files) + + num_samples = 0 + # Iterate and compare the data in the original pipeline (data1) against the deserialized pipeline (data2) + for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True), + data2.create_dict_iterator(num_epochs=1, output_numpy=True)): + np.testing.assert_array_equal(item1['image'], item2['image']) + num_samples += 1 + + assert num_samples == 8 + + +def test_serdes_csv_dataset(remove_json_files=True): + """ + Test serdes on Csvdataset pipeline. + """ + DATA_DIR = "../data/dataset/testCSV/1.csv" + data1 = ds.CSVDataset( + DATA_DIR, + column_defaults=["1", "2", "3", "4"], + column_names=['col1', 'col2', 'col3', 'col4'], + shuffle=False) + columns = ["col1", "col4", "col2"] + data1 = data1.project(columns=columns) + data2 = util_check_serialize_deserialize_file(data1, "csv_dataset_pipeline", remove_json_files) + + num_samples = 0 + # Iterate and compare the data in the original pipeline (data1) against the deserialized pipeline (data2) + for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True), + data2.create_dict_iterator(num_epochs=1, output_numpy=True)): + np.testing.assert_array_equal(item1['col1'], item2['col1']) + np.testing.assert_array_equal(item1['col2'], item2['col2']) + np.testing.assert_array_equal(item1['col4'], item2['col4']) + num_samples += 1 + + assert num_samples == 3 + + +def test_serdes_voc_dataset(remove_json_files=True): + """ + Test serdes on VOC dataset pipeline. + """ + data_dir = "../data/dataset/testVOC2012" + original_seed = config_get_set_seed(1) + original_num_parallel_workers = config_get_set_num_parallel_workers(1) + + # define map operations + random_color_adjust_op = vision.RandomColorAdjust(brightness=(0.5, 0.5)) + random_rotation_op = vision.RandomRotation((0, 90), expand=True, resample=Inter.BILINEAR, center=(50, 50), + fill_value=150) + + data1 = ds.VOCDataset(data_dir, task="Detection", usage="train", decode=True) + data1 = data1.map(operations=random_color_adjust_op, input_columns=["image"]) + data1 = data1.map(operations=random_rotation_op, input_columns=["image"]) + data1 = data1.skip(2) + data2 = util_check_serialize_deserialize_file(data1, "voc_dataset_pipeline", remove_json_files) + + num_samples = 0 + # Iterate and compare the data in the original pipeline (data1) against the deserialized pipeline (data2) + for item1, item2 in zip(data1.create_dict_iterator(num_epochs=1, output_numpy=True), + data2.create_dict_iterator(num_epochs=1, output_numpy=True)): + np.testing.assert_array_equal(item1['image'], item2['image']) + num_samples += 1 + + assert num_samples == 7 + + # Restore configuration num_parallel_workers + ds.config.set_seed(original_seed) + ds.config.set_num_parallel_workers(original_num_parallel_workers) + + +def test_serdes_to_device(remove_json_files=True): + """ + Test serdes on VOC dataset pipeline. 
+ """ + data_dir = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] + schema_file = "../data/dataset/test_tf_file_3_images/datasetSchema.json" + data1 = ds.TFRecordDataset(data_dir, schema_file, columns_list=["image", "label"], shuffle=False) + data1 = data1.to_device() + util_check_serialize_deserialize_file(data1, "transfer_dataset_pipeline", remove_json_files) + + +def test_serdes_exception(): + """ + Test exception case in serdes + """ + data_dir = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] + schema_file = "../data/dataset/test_tf_file_3_images/datasetSchema.json" + data1 = ds.TFRecordDataset(data_dir, schema_file, columns_list=["image", "label"], shuffle=False) + data1 = data1.filter(input_columns=["image", "label"], predicate=lambda data: data < 11, num_parallel_workers=4) + data1_json = ds.serialize(data1) + with pytest.raises(RuntimeError) as msg: + ds.deserialize(input_dict=data1_json) + assert "Filter is not yet supported by ds.engine.deserialize" in str(msg) + + +def util_check_serialize_deserialize_file(data_orig, filename, remove_json_files): + """ + Utility function for testing serdes files. It is to check if a json file is indeed created with correct name + after serializing and if it remains the same after repeatly saving and loading. + :param data_orig: original data pipeline to be serialized + :param filename: filename to be saved as json format + :param remove_json_files: whether to remove the json file after testing + :return: The data pipeline after serializing and deserializing using the original pipeline + """ + file1 = filename + ".json" + file2 = filename + "_1.json" + ds.serialize(data_orig, file1) + assert validate_jsonfile(file1) is True + assert validate_jsonfile("wrong_name.json") is False + + data_changed = ds.deserialize(json_filepath=file1) + ds.serialize(data_changed, file2) + assert validate_jsonfile(file2) is True + assert filecmp.cmp(file1, file2) + + # Remove the generated json file + if remove_json_files: + delete_json_files() + return data_changed + + def validate_jsonfile(filepath): try: file_exist = os.path.exists(filepath) @@ -276,7 +451,12 @@ def skip_test_minddataset(add_and_remove_cv_file): if __name__ == '__main__': - test_imagefolder() - test_zip_dataset() - test_mnist_dataset() - test_random_crop() + test_serdes_imagefolder_dataset() + test_serdes_mnist_dataset() + test_serdes_cifar10_dataset() + test_serdes_celeba_dataset() + test_serdes_csv_dataset() + test_serdes_voc_dataset() + test_serdes_zip_dataset() + test_serdes_random_crop() + test_serdes_exception()