!7555 Add TreeConsumer and RuntimeContext to support multiple front ends
Merge pull request !7555 from h.farahat/consumerspull/7555/MERGE
commit
938f377886
@ -0,0 +1,49 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMER_PYTHON_TREE_CONSUMER_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMER_PYTHON_TREE_CONSUMER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "minddata/dataset/engine/consumers/tree_consumer.h"
|
||||
|
||||
namespace mindspore::dataset {
|
||||
|
||||
/// Consumer that iterates over the dataset and returns the rows one by one as a python list or a dict
|
||||
class PythonIterator : public IteratorConsumer {
|
||||
/// Constructor
|
||||
/// \param num_epochs number of epochs. Default to -1 (infinite epochs).
|
||||
explicit PythonIterator(int32_t num_epochs = -1) : IteratorConsumer(num_epochs) {}
|
||||
|
||||
/// Get the next row as a python dict
|
||||
/// \param[out] output python dict
|
||||
/// \return Status error code
|
||||
Status GetNextAsMap(py::dict *output) {
|
||||
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
|
||||
}
|
||||
/// Get the next row as a python dict
|
||||
/// \param[out] output python dict
|
||||
/// \return Status error code
|
||||
Status GetNextAsList(py::list *output) {
|
||||
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mindspore::dataset
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMER_PYTHON_TREE_CONSUMER_H_
|
@ -0,0 +1,72 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "minddata/dataset/engine/consumers/tree_consumer.h"
|
||||
#include "minddata/dataset/engine/tree_adapter.h"
|
||||
|
||||
namespace mindspore::dataset {
|
||||
|
||||
/// Pulls the next row from the tree adapter and appends its tensors, in row order, to the caller's vector.
/// \param[out] out cleared first, then filled with the tensors of the fetched row; stays empty for an empty row
/// \return Status error code
Status IteratorConsumer::GetNextAsVector(std::vector<TensorPtr> *out) {
  RETURN_UNEXPECTED_IF_NULL(out);
  out->clear();

  TensorRow row;
  RETURN_IF_NOT_OK(tree_adapter_->GetNext(&row));

  // Nothing was fetched: hand back the (already cleared) vector as is
  RETURN_OK_IF_TRUE(row.empty());

  out->insert(out->end(), row.begin(), row.end());
  return Status::OK();
}
|
||||
/// Pulls the next row from the tree adapter and exposes it as a column-name -> tensor map.
/// \param[out] out_map cleared first, then filled with one entry per column name; stays empty for an empty row
/// \return Status error code
/// NOTE(review): every tensor is moved out of the row, so two names mapping to the same index would leave the
/// later entry with a moved-from pointer - confirm the column name map is one-to-one.
Status IteratorConsumer::GetNextAsMap(std::unordered_map<std::string, TensorPtr> *out_map) {
  RETURN_UNEXPECTED_IF_NULL(out_map);
  out_map->clear();

  TensorRow row;
  RETURN_IF_NOT_OK(tree_adapter_->GetNext(&row));

  // Nothing was fetched: hand back the (already cleared) map as is
  RETURN_OK_IF_TRUE(row.empty());

  // Each column name maps to the position of its tensor inside the row
  for (const auto &[column_name, column_index] : tree_adapter_->GetColumnNameMap()) {
    (*out_map)[column_name] = std::move(row[column_index]);
  }
  return Status::OK();
}
|
||||
|
||||
/// Creates the consumer together with the default-constructed TreeAdapter it owns.
TreeConsumer::TreeConsumer() : tree_adapter_(std::make_unique<TreeAdapter>()) {}
|
||||
|
||||
/// Hands the dataset over to the tree adapter, together with the configured number of epochs.
/// \param d root of the dataset IR tree; ownership of this reference is moved into the adapter call
/// \return Status error code reported by the tree adapter
Status IteratorConsumer::Init(std::shared_ptr<api::Dataset> d) {
  Status rc = tree_adapter_->BuildAndPrepare(std::move(d), num_epochs_);
  return rc;
}
|
||||
/// Hands the dataset over to the tree adapter without passing an epoch count.
/// \param d root of the dataset IR tree; ownership of this reference is moved into the adapter call
/// \return Status error code reported by the tree adapter
Status TreeConsumer::Init(std::shared_ptr<api::Dataset> d) {
  TreeAdapter &adapter = *tree_adapter_;
  return adapter.BuildAndPrepare(std::move(d));
}
|
||||
|
||||
/// Hands the dataset over to the tree adapter, together with the configured number of epochs.
/// \param d root of the dataset IR tree; ownership of this reference is moved into the adapter call
/// \return Status error code reported by the tree adapter
Status ToDevice::Init(std::shared_ptr<api::Dataset> d) {
  // TODO(CRC):
  // Get device ID from children look at get_distribution in python
  // Add DeviceQue IR on top of dataset d

  Status rc = tree_adapter_->BuildAndPrepare(std::move(d), num_epochs_);
  return rc;
}
|
||||
} // namespace mindspore::dataset
|
@ -0,0 +1,154 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMER_TREE_CONSUMER_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMER_TREE_CONSUMER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "minddata/dataset/engine/tree_adapter.h"
|
||||
|
||||
namespace mindspore::dataset {
|
||||
// Forward declare
|
||||
class TreeAdapter;
|
||||
|
||||
namespace api {
|
||||
class Dataset;
|
||||
}
|
||||
|
||||
/// A base class for tree consumers which would fetch rows from the tree pipeline
|
||||
class TreeConsumer {
|
||||
public:
|
||||
/// Constructor that prepares an empty tree_adapter
|
||||
TreeConsumer();
|
||||
/// Initializes the consumer, this involves constructing and preparing the tree.
|
||||
/// \param d The dataset node that represent the root of the IR tree.
|
||||
/// \return Status error code.
|
||||
virtual Status Init(std::shared_ptr<api::Dataset> d);
|
||||
|
||||
protected:
|
||||
/// The class owns the tree_adapter that handles execution tree operations.
|
||||
std::unique_ptr<TreeAdapter> tree_adapter_;
|
||||
/// Method to return the name of the consumer
|
||||
/// \return string
|
||||
virtual std::string Name() = 0;
|
||||
};
|
||||
|
||||
/// Consumer that iterates over the dataset and returns the rows one by one as a vector or a map
|
||||
class IteratorConsumer : public TreeConsumer {
|
||||
public:
|
||||
/// Constructor which will call the base class default constructor.
|
||||
/// \param num_epochs number of epochs. Default to -1 (infinite epochs).
|
||||
explicit IteratorConsumer(int32_t num_epochs = -1) : TreeConsumer(), num_epochs_(num_epochs) {}
|
||||
|
||||
Status Init(std::shared_ptr<api::Dataset> d) override;
|
||||
|
||||
/// Returns the next row in a vector format
|
||||
/// \param[out] out std::vector of Tensors
|
||||
/// \return Status error code
|
||||
Status GetNextAsVector(std::vector<TensorPtr> *out);
|
||||
|
||||
/// Returns the next row in as a map
|
||||
/// \param[out] out std::map of string to Tensor
|
||||
/// \return Status error code
|
||||
Status GetNextAsMap(std::unordered_map<std::string, TensorPtr> *out);
|
||||
|
||||
protected:
|
||||
/// Method to return the name of the consumer
|
||||
/// \return string
|
||||
std::string Name() override { return "IteratorConsumer"; }
|
||||
|
||||
private:
|
||||
int32_t num_epochs_;
|
||||
};
|
||||
|
||||
/// Consumer that iterates over the dataset and writes it to desk
|
||||
class SaveToDesk : public TreeConsumer {
|
||||
public:
|
||||
/// Constructor which will call the base class default constructor.
|
||||
/// \param dataset_path path the the dataset
|
||||
/// \param num_files number of files. Default to 1
|
||||
/// \param dataset_type The format of the dataset. Default to "mindrecod".
|
||||
explicit SaveToDesk(std::string dataset_path, int32_t num_files = 1, std::string dataset_type = "mindrecord")
|
||||
: TreeConsumer(), dataset_path_(dataset_path), num_files_(num_files), dataset_type_(dataset_type) {}
|
||||
|
||||
/// Save the given dataset to MindRecord format on desk. This is a blocking method (i.e., after returning, all rows
|
||||
/// would be written to desk)
|
||||
/// \return Status error code
|
||||
Status Save() { return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet."); }
|
||||
|
||||
private:
|
||||
std::string dataset_path_;
|
||||
int32_t num_files_;
|
||||
std::string dataset_type_;
|
||||
};
|
||||
|
||||
/// Consumer that iterates over the dataset and send it to a device
|
||||
class ToDevice : public TreeConsumer {
|
||||
public:
|
||||
ToDevice(std::string device_type, bool send_epoch_end, int32_t num_epochs)
|
||||
: TreeConsumer(), device_type_(device_type), send_epoch_end_(send_epoch_end), num_epochs_(num_epochs) {}
|
||||
|
||||
Status Init(std::shared_ptr<api::Dataset> d) override;
|
||||
|
||||
Status Send() {
|
||||
// TODO(CRC): launch the tree
|
||||
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
|
||||
}
|
||||
Status Stop() {
|
||||
// TODO(CRC): Get root + call StopSend
|
||||
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
|
||||
}
|
||||
Status Continue() {
|
||||
// TODO(CRC): Get root + call StopSend
|
||||
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
|
||||
}
|
||||
|
||||
private:
|
||||
std::string device_type_;
|
||||
bool send_epoch_end_;
|
||||
int32_t num_epochs_;
|
||||
};
|
||||
|
||||
/// Consumer that is used to get some pipeline information
|
||||
class TreeGetters : public TreeConsumer {
|
||||
Status GetDatasetSize(int32_t *size) {
|
||||
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
|
||||
}
|
||||
Status GetBatchSize(int32_t *batch_size) {
|
||||
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
|
||||
}
|
||||
Status GetRepeatCount(int32_t *repeat_count) {
|
||||
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
|
||||
}
|
||||
Status GetNumClasses(int32_t *num_classes) {
|
||||
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
|
||||
}
|
||||
Status GetOutputShapes(std::vector<TensorShape> *shapes) {
|
||||
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
|
||||
}
|
||||
Status GetOutputTypes(std::vector<DataType> *types) {
|
||||
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
|
||||
}
|
||||
Status GetOutputNames(std::vector<std::string> *names) {
|
||||
return Status(StatusCode::kNotImplementedYet, __LINE__, __FILE__, "Method is not implemented yet.");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mindspore::dataset
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_CONSUMER_TREE_CONSUMER_H_
|
@ -0,0 +1,25 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/engine/runtime_context.h"
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
namespace mindspore::dataset {
|
||||
|
||||
void RuntimeContext::AssignConsumer(std::unique_ptr<TreeConsumer> tree_consumer) {
|
||||
tree_consumer_ = std::move(tree_consumer);
|
||||
}
|
||||
} // namespace mindspore::dataset
|
@ -0,0 +1,54 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_RUNTIME_CONTEXT_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_RUNTIME_CONTEXT_H_
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include "minddata/dataset/core/client.h"
|
||||
#include "minddata/dataset/engine/consumers/tree_consumer.h"
|
||||
|
||||
namespace mindspore::dataset {
|
||||
class TreeConsumer;
|
||||
|
||||
/// Class the represents single runtime instance which can consume data from a data pipeline
|
||||
class RuntimeContext {
|
||||
public:
|
||||
/// Default constructor
|
||||
RuntimeContext() = default;
|
||||
|
||||
/// Initialize the runtime, for now we just call the global init
|
||||
/// \return Status error code
|
||||
Status Init() { return GlobalInit(); }
|
||||
|
||||
/// Method to terminate the runtime, this will not release the resources
|
||||
/// \return Status error code
|
||||
virtual Status Terminate() { return Status::OK(); }
|
||||
|
||||
/// Set the tree consumer
|
||||
/// \param tree_consumer to be assigned
|
||||
void AssignConsumer(std::unique_ptr<TreeConsumer> tree_consumer);
|
||||
|
||||
/// Get the tree consumer
|
||||
/// \return Raw pointer to the tree consumer.
|
||||
TreeConsumer *GetConsumer() { return tree_consumer_.get(); }
|
||||
|
||||
private:
|
||||
std::unique_ptr<TreeConsumer> tree_consumer_;
|
||||
};
|
||||
|
||||
} // namespace mindspore::dataset
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_RUNTIME_CONTEXT_H_
|
Loading…
Reference in new issue