|
|
|
@ -200,13 +200,11 @@ int64_t Dataset::GetDatasetSize() {
|
|
|
|
|
MS_LOG(ERROR) << "GetDatasetSize: Initializing RuntimeContext failed.";
|
|
|
|
|
return -1;
|
|
|
|
|
}
|
|
|
|
|
if (!tree_getters_->isInitialized()) {
|
|
|
|
|
rc = tree_getters_->Init(this->IRNode());
|
|
|
|
|
if (rc.IsError()) {
|
|
|
|
|
MS_LOG(ERROR) << "GetDatasetSize: Initializing TreeGetters failed.";
|
|
|
|
|
return -1;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
rc = tree_getters_->GetDatasetSize(&dataset_size);
|
|
|
|
|
return rc.IsError() ? -1 : dataset_size;
|
|
|
|
|
}
|
|
|
|
@ -218,17 +216,13 @@ std::vector<DataType> Dataset::GetOutputTypes() {
|
|
|
|
|
rc = runtime_context->Init();
|
|
|
|
|
if (rc.IsError()) {
|
|
|
|
|
MS_LOG(ERROR) << "GetOutputTypes: Initializing RuntimeContext failed.";
|
|
|
|
|
types.clear();
|
|
|
|
|
return types;
|
|
|
|
|
}
|
|
|
|
|
if (!tree_getters_->isInitialized()) {
|
|
|
|
|
rc = tree_getters_->Init(this->IRNode());
|
|
|
|
|
if (rc.IsError()) {
|
|
|
|
|
MS_LOG(ERROR) << "GetOutputTypes: Initializing TreeGetters failed.";
|
|
|
|
|
types.clear();
|
|
|
|
|
return types;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
rc = tree_getters_->GetOutputTypes(&types);
|
|
|
|
|
if (rc.IsError()) {
|
|
|
|
|
MS_LOG(ERROR) << "GetOutputTypes: Get Output Types failed.";
|
|
|
|
@ -245,17 +239,13 @@ std::vector<TensorShape> Dataset::GetOutputShapes() {
|
|
|
|
|
rc = runtime_context->Init();
|
|
|
|
|
if (rc.IsError()) {
|
|
|
|
|
MS_LOG(ERROR) << "GetOutputShapes: Initializing RuntimeContext failed.";
|
|
|
|
|
shapes.clear();
|
|
|
|
|
return shapes;
|
|
|
|
|
}
|
|
|
|
|
if (!tree_getters_->isInitialized()) {
|
|
|
|
|
rc = tree_getters_->Init(this->IRNode());
|
|
|
|
|
if (rc.IsError()) {
|
|
|
|
|
MS_LOG(ERROR) << "GetOutputShapes: Initializing TreeGetters failed.";
|
|
|
|
|
shapes.clear();
|
|
|
|
|
return shapes;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
rc = tree_getters_->GetOutputShapes(&shapes);
|
|
|
|
|
if (rc.IsError()) {
|
|
|
|
|
MS_LOG(ERROR) << "GetOutputShapes: Get Output Shapes failed.";
|
|
|
|
@ -275,17 +265,39 @@ int64_t Dataset::GetNumClasses() {
|
|
|
|
|
MS_LOG(ERROR) << "GetNumClasses: Initializing RuntimeContext failed.";
|
|
|
|
|
return -1;
|
|
|
|
|
}
|
|
|
|
|
if (!tree_getters_->isInitialized()) {
|
|
|
|
|
rc = tree_getters_->Init(ds->IRNode());
|
|
|
|
|
if (rc.IsError()) {
|
|
|
|
|
MS_LOG(ERROR) << "GetNumClasses: Initializing TreeGetters failed.";
|
|
|
|
|
return -1;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
rc = tree_getters_->GetNumClasses(&num_classes);
|
|
|
|
|
return rc.IsError() ? -1 : num_classes;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<std::pair<std::string, std::vector<int32_t>>> Dataset::GetClassIndexing() {
|
|
|
|
|
std::vector<std::pair<std::string, std::vector<int32_t>>> output_class_indexing;
|
|
|
|
|
auto ds = shared_from_this();
|
|
|
|
|
Status rc;
|
|
|
|
|
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
|
|
|
|
rc = runtime_context->Init();
|
|
|
|
|
if (rc.IsError()) {
|
|
|
|
|
MS_LOG(ERROR) << "GetClassIndexing: Initializing RuntimeContext failed.";
|
|
|
|
|
return output_class_indexing;
|
|
|
|
|
}
|
|
|
|
|
rc = tree_getters_->Init(ds->IRNode());
|
|
|
|
|
if (rc.IsError()) {
|
|
|
|
|
MS_LOG(ERROR) << "GetClassIndexing: Initializing TreeGetters failed.";
|
|
|
|
|
return output_class_indexing;
|
|
|
|
|
}
|
|
|
|
|
rc = tree_getters_->GetClassIndexing(&output_class_indexing);
|
|
|
|
|
if (rc.IsError()) {
|
|
|
|
|
MS_LOG(ERROR) << "GetClassIndexing: Get Class Index failed.";
|
|
|
|
|
output_class_indexing.clear();
|
|
|
|
|
return output_class_indexing;
|
|
|
|
|
}
|
|
|
|
|
return output_class_indexing;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// \brief Function to create a SchemaObj
|
|
|
|
|
/// \param[in] schema_file Path of schema file
|
|
|
|
|
/// \return Shared pointer to the current schema
|
|
|
|
@ -580,13 +592,11 @@ int64_t Dataset::GetBatchSize() {
|
|
|
|
|
MS_LOG(ERROR) << "GetBatchSize: Initializing RuntimeContext failed.";
|
|
|
|
|
return -1;
|
|
|
|
|
}
|
|
|
|
|
if (!tree_getters_->isInitialized()) {
|
|
|
|
|
rc = tree_getters_->Init(ds->IRNode());
|
|
|
|
|
if (rc.IsError()) {
|
|
|
|
|
MS_LOG(ERROR) << "GetBatchSize: Initializing TreeGetters failed.";
|
|
|
|
|
return -1;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
rc = tree_getters_->GetBatchSize(&batch_size);
|
|
|
|
|
return rc.IsError() ? -1 : batch_size;
|
|
|
|
|
}
|
|
|
|
@ -601,22 +611,22 @@ int64_t Dataset::GetRepeatCount() {
|
|
|
|
|
MS_LOG(ERROR) << "GetRepeatCount: Initializing RuntimeContext failed.";
|
|
|
|
|
return -1;
|
|
|
|
|
}
|
|
|
|
|
if (!tree_getters_->isInitialized()) {
|
|
|
|
|
rc = tree_getters_->Init(ds->IRNode());
|
|
|
|
|
if (rc.IsError()) {
|
|
|
|
|
MS_LOG(ERROR) << "GetRepeatCount: Initializing TreeGetters failed.";
|
|
|
|
|
return -1;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
rc = tree_getters_->GetRepeatCount(&repeat_count);
|
|
|
|
|
return rc.IsError() ? 0 : repeat_count;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<Dataset> Dataset::SetNumWorkers(int32_t num_workers) {
|
|
|
|
|
if (ir_node_ == nullptr || ir_node_->SetNumWorkers(num_workers) == nullptr) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
return shared_from_this();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#ifndef ENABLE_ANDROID
|
|
|
|
|
std::shared_ptr<SentencePieceVocab> Dataset::BuildSentencePieceVocab(
|
|
|
|
|
const std::vector<std::string> &col_names, uint32_t vocab_size, float character_coverage,
|
|
|
|
|