|
|
|
@ -30,8 +30,8 @@
|
|
|
|
|
#include "minddata/mindrecord/include/shard_writer.h"
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
namespace mindspore::dataset {
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace dataset {
|
|
|
|
|
// TreeConsumer
|
|
|
|
|
TreeConsumer::TreeConsumer() { tree_adapter_ = std::make_unique<TreeAdapter>(); }
|
|
|
|
|
|
|
|
|
@ -440,7 +440,9 @@ Status SaveToDisk::TransformTensor(const unsigned char *src, const TensorShape &
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
TreeGetters::TreeGetters() : dataset_size_(-1), init_flag_(false) { tree_adapter_ = std::make_unique<TreeAdapter>(); }
|
|
|
|
|
TreeGetters::TreeGetters() : dataset_size_(-1), init_flag_(false), first_row_obtained_(false) {
|
|
|
|
|
tree_adapter_ = std::make_unique<TreeAdapter>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status TreeGetters::Init(std::shared_ptr<DatasetNode> d) {
|
|
|
|
|
root_ = std::move(d);
|
|
|
|
@ -473,20 +475,14 @@ Status TreeGetters::GetDatasetSize(int64_t *dataset_size) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status TreeGetters::GetOutputTypes(std::vector<DataType> *types) {
|
|
|
|
|
RETURN_IF_NOT_OK(InternalInit(static_cast<int8_t>(GetterPass::kOutputShapeAndType)));
|
|
|
|
|
if (first_row_.empty()) RETURN_IF_NOT_OK(GetRow(&first_row_));
|
|
|
|
|
|
|
|
|
|
std::transform(first_row_.begin(), first_row_.end(), std::back_inserter(*types),
|
|
|
|
|
[](const TensorPtr &t) { return t->type(); });
|
|
|
|
|
RETURN_IF_NOT_OK(GetFirstRowShapeAndType());
|
|
|
|
|
*types = first_row_type_;
|
|
|
|
|
return Status::OK();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status TreeGetters::GetOutputShapes(std::vector<TensorShape> *shapes) {
|
|
|
|
|
RETURN_IF_NOT_OK(InternalInit(static_cast<int8_t>(GetterPass::kOutputShapeAndType)));
|
|
|
|
|
if (first_row_.empty()) RETURN_IF_NOT_OK(GetRow(&first_row_));
|
|
|
|
|
|
|
|
|
|
std::transform(first_row_.begin(), first_row_.end(), std::back_inserter(*shapes),
|
|
|
|
|
[](const TensorPtr &t) { return t->shape(); });
|
|
|
|
|
RETURN_IF_NOT_OK(GetFirstRowShapeAndType());
|
|
|
|
|
*shapes = first_row_shape_;
|
|
|
|
|
return Status::OK();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -555,6 +551,18 @@ Status TreeGetters::InternalInit() {
|
|
|
|
|
if (!s.IsError()) init_flag_ = true;
|
|
|
|
|
return s;
|
|
|
|
|
}
|
|
|
|
|
Status TreeGetters::GetFirstRowShapeAndType() {
|
|
|
|
|
RETURN_OK_IF_TRUE(first_row_obtained_);
|
|
|
|
|
RETURN_IF_NOT_OK(InternalInit(static_cast<int8_t>(GetterPass::kOutputShapeAndType)));
|
|
|
|
|
TensorRow first_row;
|
|
|
|
|
RETURN_IF_NOT_OK(GetRow(&first_row));
|
|
|
|
|
std::transform(first_row.begin(), first_row.end(), std::back_inserter(first_row_type_),
|
|
|
|
|
[](const TensorPtr &t) { return t->type(); });
|
|
|
|
|
std::transform(first_row.begin(), first_row.end(), std::back_inserter(first_row_shape_),
|
|
|
|
|
[](const TensorPtr &t) { return t->shape(); });
|
|
|
|
|
first_row_obtained_ = true;
|
|
|
|
|
return Status::OK();
|
|
|
|
|
}
|
|
|
|
|
Status BuildVocabConsumer::Init(std::shared_ptr<DatasetNode> d) { return tree_adapter_->Compile(std::move(d), 1); }
|
|
|
|
|
|
|
|
|
|
Status BuildVocabConsumer::Start() {
|
|
|
|
@ -565,4 +573,5 @@ Status BuildVocabConsumer::Start() {
|
|
|
|
|
CHECK_FAIL_RETURN_UNEXPECTED(row.empty(), "The fetched row from BuildVocab should be an EOE.");
|
|
|
|
|
return Status::OK();
|
|
|
|
|
}
|
|
|
|
|
} // namespace mindspore::dataset
|
|
|
|
|
} // namespace dataset
|
|
|
|
|
} // namespace mindspore
|
|
|
|
|