!10142 dataset CPP UT: Updates for ExecTree to IRTree Support
From: @cathwong Reviewed-by: Signed-off-by:pull/10142/MERGE
commit
a5869f2984
@ -0,0 +1,210 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "common/common.h"
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
using mindspore::dataset::Tensor;
|
||||
|
||||
class MindDataTestEpochCtrl : public UT::DatasetOpTesting {
|
||||
protected:
|
||||
};
|
||||
|
||||
TEST_F(MindDataTestEpochCtrl, TestAutoInjectEpoch) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestEpochCtrl-TestAutoInjectEpoch.";
|
||||
|
||||
int32_t img_class[4] = {0, 1, 2, 3};
|
||||
int32_t num_epochs = 2 + std::rand() % 3;
|
||||
int32_t sampler_size = 44;
|
||||
int32_t class_size = 11;
|
||||
MS_LOG(INFO) << "num_epochs: " << num_epochs;
|
||||
|
||||
// Create an ImageFolder Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, SequentialSampler(0, sampler_size));
|
||||
ds = ds->SetNumWorkers(2);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
// Expect a valid iterator
|
||||
ASSERT_NE(iter, nullptr);
|
||||
|
||||
uint64_t i = 0;
|
||||
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
|
||||
|
||||
for (int epoch = 0; epoch < num_epochs; epoch++) {
|
||||
// Iterate the dataset and get each row
|
||||
iter->GetNextRow(&row);
|
||||
|
||||
while (row.size() != 0) {
|
||||
auto label = row["label"];
|
||||
int32_t label_value;
|
||||
label->GetItemAt(&label_value, {0});
|
||||
EXPECT_TRUE(img_class[(i % sampler_size) / class_size] == label_value);
|
||||
|
||||
iter->GetNextRow(&row);
|
||||
i++;
|
||||
}
|
||||
}
|
||||
|
||||
EXPECT_EQ(i, sampler_size * num_epochs);
|
||||
|
||||
// Try to fetch data beyond the specified number of epochs.
|
||||
iter->GetNextRow(&row);
|
||||
EXPECT_EQ(row.size(), 2);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestEpochCtrl, TestEpoch) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestEpochCtrl-TestEpoch.";
|
||||
|
||||
int32_t num_epochs = 1 + std::rand() % 4;
|
||||
int32_t sampler_size = 7;
|
||||
MS_LOG(INFO) << "num_epochs: " << num_epochs;
|
||||
|
||||
// Create an ImageFolder Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(0, sampler_size));
|
||||
ds = ds->SetNumWorkers(3);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
// Expect a valid iterator
|
||||
ASSERT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
uint64_t i = 0;
|
||||
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
|
||||
|
||||
for (int epoch = 0; epoch < num_epochs; epoch++) {
|
||||
iter->GetNextRow(&row);
|
||||
while (row.size() != 0) {
|
||||
auto label = row["label"];
|
||||
int32_t label_value;
|
||||
label->GetItemAt(&label_value, {0});
|
||||
EXPECT_TRUE(label_value >= 0 && label_value <= 3);
|
||||
|
||||
iter->GetNextRow(&row);
|
||||
i++;
|
||||
}
|
||||
}
|
||||
|
||||
// Verify correct number of rows fetched
|
||||
EXPECT_EQ(i, sampler_size * num_epochs);
|
||||
|
||||
// Try to fetch data beyond the specified number of epochs.
|
||||
iter->GetNextRow(&row);
|
||||
EXPECT_EQ(row.size(), 2);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestEpochCtrl, TestRepeatEpoch) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestEpochCtrl-TestRepeatEpoch.";
|
||||
|
||||
int32_t num_epochs = 2 + std::rand() % 5;
|
||||
int32_t num_repeats = 3;
|
||||
int32_t sampler_size = 7;
|
||||
MS_LOG(INFO) << "num_epochs: " << num_epochs;
|
||||
|
||||
// Create an ImageFolder Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, RandomSampler(0, sampler_size));
|
||||
ds = ds->SetNumWorkers(3);
|
||||
ds = ds->Repeat(num_repeats);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
// Expect a valid iterator
|
||||
ASSERT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
uint64_t i = 0;
|
||||
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
|
||||
|
||||
for (int epoch = 0; epoch < num_epochs; epoch++) {
|
||||
iter->GetNextRow(&row);
|
||||
while (row.size() != 0) {
|
||||
auto label = row["label"];
|
||||
int32_t label_value;
|
||||
label->GetItemAt(&label_value, {0});
|
||||
EXPECT_TRUE(label_value >= 0 && label_value <= 3);
|
||||
|
||||
iter->GetNextRow(&row);
|
||||
i++;
|
||||
}
|
||||
}
|
||||
|
||||
// Verify correct number of rows fetched
|
||||
EXPECT_EQ(i, sampler_size * num_repeats * num_epochs);
|
||||
|
||||
// Try to fetch data beyond the specified number of epochs.
|
||||
iter->GetNextRow(&row);
|
||||
EXPECT_EQ(row.size(), 2);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestEpochCtrl, TestRepeatRepeatEpoch) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestEpochCtrl-TestRepeatRepeatEpoch.";
|
||||
|
||||
int32_t num_epochs = 1 + std::rand() % 5;
|
||||
int32_t num_repeats[2] = {2, 3};
|
||||
int32_t sampler_size = 11;
|
||||
MS_LOG(INFO) << "num_epochs: " << num_epochs;
|
||||
|
||||
// Create an ImageFolder Dataset
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, true, SequentialSampler(5, sampler_size));
|
||||
ds = ds->Repeat(num_repeats[0]);
|
||||
ds = ds->Repeat(num_repeats[1]);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
// Expect a valid iterator
|
||||
ASSERT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
uint64_t i = 0;
|
||||
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
|
||||
|
||||
for (int epoch = 0; epoch < num_epochs; epoch++) {
|
||||
iter->GetNextRow(&row);
|
||||
while (row.size() != 0) {
|
||||
auto label = row["label"];
|
||||
int32_t label_value;
|
||||
label->GetItemAt(&label_value, {0});
|
||||
EXPECT_TRUE(label_value >= 0 && label_value <= 3);
|
||||
|
||||
iter->GetNextRow(&row);
|
||||
i++;
|
||||
}
|
||||
}
|
||||
|
||||
// Verify correct number of rows fetched
|
||||
EXPECT_EQ(i, sampler_size * num_repeats[0] * num_repeats[1] * num_epochs);
|
||||
|
||||
// Try to fetch data beyond the specified number of epochs.
|
||||
iter->GetNextRow(&row);
|
||||
EXPECT_EQ(row.size(), 2);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
}
|
@ -0,0 +1,55 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "common/common.h"
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
using mindspore::dataset::Tensor;
|
||||
|
||||
class MindDataTestPipeline : public UT::DatasetOpTesting {
|
||||
protected:
|
||||
};
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestRepeatSetNumWorkers) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestRepeat-TestRepeatSetNumWorkers.";
|
||||
|
||||
std::string file_path = datasets_root_path_ + "/testTFTestAllTypes/test.data";
|
||||
std::shared_ptr<Dataset> ds = TFRecord({file_path});
|
||||
ds = ds->SetNumWorkers(16);
|
||||
ds = ds->Repeat(32);
|
||||
|
||||
// Create an iterator over the result of the above dataset
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
// Expect a valid iterator
|
||||
ASSERT_NE(iter, nullptr);
|
||||
|
||||
// Iterate the dataset and get each row
|
||||
std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
|
||||
iter->GetNextRow(&row);
|
||||
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
iter->GetNextRow(&row);
|
||||
}
|
||||
|
||||
// Verify correct number of rows fetched
|
||||
EXPECT_EQ(i, 12 * 32);
|
||||
|
||||
// Manually terminate the pipeline
|
||||
iter->Stop();
|
||||
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,101 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "common/common.h"
|
||||
#include "minddata/dataset/engine/execution_tree.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
|
||||
#include "minddata/dataset/engine/tree_adapter.h"
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
#include "minddata/dataset/include/transforms.h"
|
||||
#include "minddata/dataset/include/vision.h"
|
||||
#include "minddata/dataset/kernels/tensor_op.h"
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
|
||||
class MindDataTestTensorOpFusionPass : public UT::DatasetOpTesting {
|
||||
public:
|
||||
MindDataTestTensorOpFusionPass() = default;
|
||||
};
|
||||
|
||||
TEST_F(MindDataTestTensorOpFusionPass, RandomCropDecodeResizeDisabled) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestTensorOpFusionPass-RandomCropDecodeResizeDisabled";
|
||||
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, false, SequentialSampler(0, 11));
|
||||
ds = ds->SetNumWorkers(16);
|
||||
|
||||
// Create objects for the tensor ops
|
||||
std::shared_ptr<TensorOperation> decode = vision::Decode();
|
||||
std::shared_ptr<TensorOperation> random_resized_crop = vision::RandomResizedCrop({5});
|
||||
ds = ds->Map({decode, random_resized_crop}, {"image"});
|
||||
|
||||
std::shared_ptr<DatasetNode> node = ds->IRNode();
|
||||
auto ir_tree = std::make_shared<TreeAdapter>();
|
||||
// Disable IR optimization pass
|
||||
ir_tree->SetOptimize(false);
|
||||
Status rc;
|
||||
rc = ir_tree->Compile(node);
|
||||
EXPECT_TRUE(rc);
|
||||
auto root_op = ir_tree->GetRoot();
|
||||
|
||||
auto tree = std::make_shared<ExecutionTree>();
|
||||
auto it = tree->begin(static_cast<std::shared_ptr<DatasetOp>>(root_op));
|
||||
++it;
|
||||
auto *map_op = &(*it);
|
||||
auto tfuncs = static_cast<MapOp *>(map_op)->TFuncs();
|
||||
auto func_it = tfuncs.begin();
|
||||
EXPECT_EQ((*func_it)->Name(), kDecodeOp);
|
||||
++func_it;
|
||||
EXPECT_EQ((*func_it)->Name(), kRandomCropAndResizeOp);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestTensorOpFusionPass, RandomCropDecodeResizeEnabled) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestTensorOpFusionPass-RandomCropDecodeResizeEnabled";
|
||||
|
||||
std::string folder_path = datasets_root_path_ + "/testPK/data/";
|
||||
std::shared_ptr<Dataset> ds = ImageFolder(folder_path, false, SequentialSampler(0, 11));
|
||||
ds = ds->SetNumWorkers(16);
|
||||
|
||||
// Create objects for the tensor ops
|
||||
std::shared_ptr<TensorOperation> decode = vision::Decode();
|
||||
std::shared_ptr<TensorOperation> random_resized_crop = vision::RandomResizedCrop({5});
|
||||
ds = ds->Map({decode, random_resized_crop}, {"image"});
|
||||
|
||||
std::shared_ptr<DatasetNode> node = ds->IRNode();
|
||||
auto ir_tree = std::make_shared<TreeAdapter>();
|
||||
// Enable IR optimization pass
|
||||
ir_tree->SetOptimize(true);
|
||||
Status rc;
|
||||
rc = ir_tree->Compile(node);
|
||||
EXPECT_TRUE(rc);
|
||||
auto root_op = ir_tree->GetRoot();
|
||||
|
||||
auto tree = std::make_shared<ExecutionTree>();
|
||||
auto it = tree->begin(static_cast<std::shared_ptr<DatasetOp>>(root_op));
|
||||
++it;
|
||||
auto *map_op = &(*it);
|
||||
auto tfuncs = static_cast<MapOp *>(map_op)->TFuncs();
|
||||
auto func_it = tfuncs.begin();
|
||||
// FIXME: Currently the following 2 commented out verifications for this test will fail because this
|
||||
// optimization is still in ExecutionTree code, and not yet in IR optimization pass
|
||||
// However, use a bogus check for func_it, to avoid compile error for unused variable.
|
||||
EXPECT_EQ(func_it, func_it);
|
||||
// EXPECT_EQ((*func_it)->Name(), kRandomCropDecodeResizeOp);
|
||||
// EXPECT_EQ(++func_it, tfuncs.end());
|
||||
}
|
||||
|
@ -1,63 +0,0 @@
|
||||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/util/circular_pool.h"
|
||||
#include "minddata/dataset/core/client.h"
|
||||
#include "common/common.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
using mindspore::MsLogLevel::INFO;
|
||||
using mindspore::ExceptionType::NoExceptionType;
|
||||
using mindspore::LogStream;
|
||||
|
||||
class MindDataTestrepeat_op : public UT::DatasetOpTesting {
|
||||
|
||||
};
|
||||
|
||||
TEST_F(MindDataTestrepeat_op, Testrepeat_opFuntions) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestrepeat_op.";
|
||||
auto my_tree = std::make_shared<ExecutionTree>();
|
||||
|
||||
std::shared_ptr<DatasetOp> parent_op = std::make_shared<RepeatOp>(32);
|
||||
std::string dataset_path;
|
||||
dataset_path = datasets_root_path_ + "/testTFTestAllTypes/test.data";
|
||||
// TFReaderOp
|
||||
std::shared_ptr<TFReaderOp> my_tfreader_op;
|
||||
TFReaderOp::Builder builder;
|
||||
builder.SetDatasetFilesList({dataset_path})
|
||||
.SetRowsPerBuffer(16)
|
||||
.SetWorkerConnectorSize(16)
|
||||
.SetNumWorkers(16);
|
||||
Status rc= builder.Build(&my_tfreader_op);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
rc = my_tree->AssociateNode(my_tfreader_op);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
rc = my_tree->AssociateNode(parent_op);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
ASSERT_NE(parent_op, nullptr);
|
||||
ASSERT_NE(my_tfreader_op, nullptr);
|
||||
parent_op->AddChild(std::move(my_tfreader_op));
|
||||
MS_LOG(INFO) << parent_op;
|
||||
my_tree->AssignRoot(parent_op);
|
||||
my_tree->Prepare();
|
||||
|
||||
RepeatOp RepeatOpOp();
|
||||
|
||||
std::shared_ptr<RepeatOp> repeat_op;
|
||||
rc = RepeatOp::Builder(3).Build(&repeat_op);
|
||||
ASSERT_NE(repeat_op, nullptr);
|
||||
}
|
@ -1,105 +0,0 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "minddata/dataset/core/client.h"
|
||||
#include "common/common.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "minddata/dataset/kernels/image/random_crop_and_resize_op.h"
|
||||
#include "minddata/dataset/kernels/image/decode_op.h"
|
||||
#include "minddata/dataset/engine/datasetops/source/image_folder_op.h"
|
||||
#include "minddata/dataset/engine/execution_tree.h"
|
||||
|
||||
|
||||
using namespace mindspore::dataset;
|
||||
using mindspore::LogStream;
|
||||
using mindspore::MsLogLevel::INFO;
|
||||
|
||||
class MindDataTestTensorOpFusionPass : public UT::DatasetOpTesting {
|
||||
public:
|
||||
MindDataTestTensorOpFusionPass() = default;
|
||||
void SetUp() override { GlobalInit(); }
|
||||
};
|
||||
|
||||
TEST_F(MindDataTestTensorOpFusionPass, RandomCropDecodeResize_fusion_disabled) {
|
||||
MS_LOG(INFO) << "Doing RandomCropDecodeResize_fusion";
|
||||
std::shared_ptr<ImageFolderOp> ImageFolder(int64_t num_works, int64_t rows, int64_t conns, std::string path,
|
||||
bool shuf = false, std::shared_ptr<SamplerRT> sampler = nullptr,
|
||||
std::map<std::string, int32_t> map = {}, bool decode = false);
|
||||
std::shared_ptr<ExecutionTree> Build(std::vector<std::shared_ptr<DatasetOp>> ops);
|
||||
auto rcar_op = std::make_shared<RandomCropAndResizeOp>();
|
||||
auto decode_op = std::make_shared<DecodeOp>();
|
||||
Status rc;
|
||||
std::vector<std::shared_ptr<TensorOp>> func_list;
|
||||
func_list.push_back(decode_op);
|
||||
func_list.push_back(rcar_op);
|
||||
std::shared_ptr<MapOp> map_op;
|
||||
MapOp::Builder map_decode_builder;
|
||||
map_decode_builder.SetInColNames({}).SetOutColNames({}).SetTensorFuncs(func_list).SetNumWorkers(4);
|
||||
rc = map_decode_builder.Build(&map_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
auto tree = std::make_shared<ExecutionTree>();
|
||||
tree = Build({ImageFolder(16, 2, 32, "./", false), map_op});
|
||||
rc = tree->SetOptimize(false);
|
||||
EXPECT_TRUE(rc);
|
||||
rc = tree->Prepare();
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = tree->SetOptimize(false);
|
||||
EXPECT_TRUE(rc.IsError());
|
||||
auto it = tree->begin();
|
||||
++it;
|
||||
auto *m_op = &(*it);
|
||||
auto tfuncs = static_cast<MapOp *>(m_op)->TFuncs();
|
||||
auto func_it = tfuncs.begin();
|
||||
EXPECT_EQ((*func_it)->Name(), kDecodeOp);
|
||||
++func_it;
|
||||
EXPECT_EQ((*func_it)->Name(), kRandomCropAndResizeOp);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestTensorOpFusionPass, RandomCropDecodeResize_fusion_enabled) {
|
||||
MS_LOG(INFO) << "Doing RandomCropDecodeResize_fusion";
|
||||
std::shared_ptr<ImageFolderOp> ImageFolder(int64_t num_works, int64_t rows, int64_t conns, std::string path,
|
||||
bool shuf = false, std::shared_ptr<SamplerRT> sampler = nullptr,
|
||||
std::map<std::string, int32_t> map = {}, bool decode = false);
|
||||
std::shared_ptr<ExecutionTree> Build(std::vector<std::shared_ptr<DatasetOp>> ops);
|
||||
auto rcar_op = std::make_shared<RandomCropAndResizeOp>();
|
||||
auto decode_op = std::make_shared<DecodeOp>();
|
||||
Status rc;
|
||||
std::vector<std::shared_ptr<TensorOp>> func_list;
|
||||
func_list.push_back(decode_op);
|
||||
func_list.push_back(rcar_op);
|
||||
std::shared_ptr<MapOp> map_op;
|
||||
MapOp::Builder map_decode_builder;
|
||||
map_decode_builder.SetInColNames({}).SetOutColNames({}).SetTensorFuncs(func_list).SetNumWorkers(4);
|
||||
rc = map_decode_builder.Build(&map_op);
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
auto tree = std::make_shared<ExecutionTree>();
|
||||
tree = Build({ImageFolder(16, 2, 32, "./", false), map_op});
|
||||
rc = tree->SetOptimize(true);
|
||||
EXPECT_TRUE(rc);
|
||||
rc = tree->Prepare();
|
||||
EXPECT_TRUE(rc.IsOk());
|
||||
rc = tree->SetOptimize(false);
|
||||
EXPECT_TRUE(rc.IsError());
|
||||
auto it = tree->begin();
|
||||
++it;
|
||||
auto *m_op = &(*it);
|
||||
auto tfuncs = static_cast<MapOp *>(m_op)->TFuncs();
|
||||
auto func_it = tfuncs.begin();
|
||||
EXPECT_EQ((*func_it)->Name(), kRandomCropDecodeResizeOp);
|
||||
EXPECT_EQ(++func_it, tfuncs.end());
|
||||
}
|
Loading…
Reference in new issue