|
|
|
@ -275,8 +275,10 @@ int64_t Dataset::GetDatasetSize(bool estimate) {
|
|
|
|
|
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
|
|
|
|
RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1);
|
|
|
|
|
std::shared_ptr<DatasetSizeGetter> size_getter = std::make_shared<DatasetSizeGetter>();
|
|
|
|
|
RETURN_SECOND_IF_ERROR(size_getter->Init(this->IRNode()), -1);
|
|
|
|
|
RETURN_SECOND_IF_ERROR(size_getter->GetDatasetSize(&dataset_size, estimate), -1);
|
|
|
|
|
DatasetSizeGetter *consumer = size_getter.get();
|
|
|
|
|
runtime_context->AssignConsumer(size_getter);
|
|
|
|
|
RETURN_SECOND_IF_ERROR(consumer->Init(this->IRNode()), -1);
|
|
|
|
|
RETURN_SECOND_IF_ERROR(consumer->GetDatasetSize(&dataset_size, estimate), -1);
|
|
|
|
|
return dataset_size;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -284,8 +286,10 @@ std::vector<mindspore::DataType> Dataset::GetOutputTypes() {
|
|
|
|
|
std::vector<DataType> types;
|
|
|
|
|
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
|
|
|
|
RETURN_SECOND_IF_ERROR(runtime_context->Init(), {});
|
|
|
|
|
RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), {});
|
|
|
|
|
RETURN_SECOND_IF_ERROR(tree_getters_->GetOutputTypes(&types), {});
|
|
|
|
|
TreeGetters *consumer = tree_getters_.get();
|
|
|
|
|
runtime_context->AssignConsumer(tree_getters_);
|
|
|
|
|
RETURN_SECOND_IF_ERROR(consumer->Init(this->IRNode()), {});
|
|
|
|
|
RETURN_SECOND_IF_ERROR(consumer->GetOutputTypes(&types), {});
|
|
|
|
|
std::vector<mindspore::DataType> ret_types;
|
|
|
|
|
std::transform(
|
|
|
|
|
types.begin(), types.end(), std::back_inserter(ret_types),
|
|
|
|
@ -297,8 +301,10 @@ std::vector<std::vector<int64_t>> Dataset::GetOutputShapes() {
|
|
|
|
|
std::vector<TensorShape> shapes;
|
|
|
|
|
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
|
|
|
|
RETURN_SECOND_IF_ERROR(runtime_context->Init(), {});
|
|
|
|
|
RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), {});
|
|
|
|
|
RETURN_SECOND_IF_ERROR(tree_getters_->GetOutputShapes(&shapes), {});
|
|
|
|
|
TreeGetters *consumer = tree_getters_.get();
|
|
|
|
|
runtime_context->AssignConsumer(tree_getters_);
|
|
|
|
|
RETURN_SECOND_IF_ERROR(consumer->Init(this->IRNode()), {});
|
|
|
|
|
RETURN_SECOND_IF_ERROR(consumer->GetOutputShapes(&shapes), {});
|
|
|
|
|
std::vector<std::vector<int64_t>> ret_shapes;
|
|
|
|
|
std::transform(shapes.begin(), shapes.end(), std::back_inserter(ret_shapes),
|
|
|
|
|
[](const TensorShape &s) -> std::vector<int64_t> { return s.AsVector(); });
|
|
|
|
@ -309,8 +315,10 @@ int64_t Dataset::GetNumClasses() {
|
|
|
|
|
int64_t num_classes;
|
|
|
|
|
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
|
|
|
|
RETURN_SECOND_IF_ERROR(runtime_context->Init(), -1);
|
|
|
|
|
RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), -1);
|
|
|
|
|
RETURN_SECOND_IF_ERROR(tree_getters_->GetNumClasses(&num_classes), -1);
|
|
|
|
|
TreeGetters *consumer = tree_getters_.get();
|
|
|
|
|
runtime_context->AssignConsumer(tree_getters_);
|
|
|
|
|
RETURN_SECOND_IF_ERROR(consumer->Init(this->IRNode()), -1);
|
|
|
|
|
RETURN_SECOND_IF_ERROR(consumer->GetNumClasses(&num_classes), -1);
|
|
|
|
|
return num_classes;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -318,8 +326,10 @@ std::vector<std::vector<char>> Dataset::GetColumnNamesCharIF() {
|
|
|
|
|
std::vector<std::string> col_names;
|
|
|
|
|
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
|
|
|
|
RETURN_SECOND_IF_ERROR(runtime_context->Init(), {});
|
|
|
|
|
RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), {});
|
|
|
|
|
RETURN_SECOND_IF_ERROR(tree_getters_->GetColumnNames(&col_names), {});
|
|
|
|
|
TreeGetters *consumer = tree_getters_.get();
|
|
|
|
|
runtime_context->AssignConsumer(tree_getters_);
|
|
|
|
|
RETURN_SECOND_IF_ERROR(consumer->Init(this->IRNode()), {});
|
|
|
|
|
RETURN_SECOND_IF_ERROR(consumer->GetColumnNames(&col_names), {});
|
|
|
|
|
return VectorStringToChar(col_names);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -327,8 +337,10 @@ std::vector<std::pair<std::vector<char>, std::vector<int32_t>>> Dataset::GetClas
|
|
|
|
|
std::vector<std::pair<std::string, std::vector<int32_t>>> output_class_indexing;
|
|
|
|
|
std::unique_ptr<NativeRuntimeContext> runtime_context = std::make_unique<NativeRuntimeContext>();
|
|
|
|
|
RETURN_SECOND_IF_ERROR(runtime_context->Init(), {});
|
|
|
|
|
RETURN_SECOND_IF_ERROR(tree_getters_->Init(this->IRNode()), {});
|
|
|
|
|
RETURN_SECOND_IF_ERROR(tree_getters_->GetClassIndexing(&output_class_indexing), {});
|
|
|
|
|
TreeGetters *consumer = tree_getters_.get();
|
|
|
|
|
runtime_context->AssignConsumer(tree_getters_);
|
|
|
|
|
RETURN_SECOND_IF_ERROR(consumer->Init(this->IRNode()), {});
|
|
|
|
|
RETURN_SECOND_IF_ERROR(consumer->GetClassIndexing(&output_class_indexing), {});
|
|
|
|
|
return ClassIndexStringToChar(output_class_indexing);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|