!11034 Migrate GeneratorDataset reset logic to IR optimizer
From: @nsyca Reviewed-by: Signed-off-by:pull/11034/MERGE
commit
9646953465
@ -1,19 +1,35 @@
|
||||
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
|
||||
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
|
||||
add_library(engine-opt OBJECT
|
||||
optional/tensor_op_fusion_pass.cc
|
||||
pass.cc
|
||||
post/auto_worker_pass.cc
|
||||
post/repeat_pass.cc
|
||||
pre/cache_error_pass.cc
|
||||
pre/cache_transform_pass.cc
|
||||
pre/cache_validation_pass.cc
|
||||
pre/deep_copy_pass.cc
|
||||
pre/epoch_ctrl_pass.cc
|
||||
pre/epoch_injection_pass.cc
|
||||
pre/getter_pass.cc
|
||||
pre/input_validation_pass.cc
|
||||
pre/node_removal_pass.cc
|
||||
pre/removal_pass.cc
|
||||
util/printer_pass.cc
|
||||
|
||||
# Source files that make up the engine-opt object library.
set(DATASET_ENGINE_OPT_SRC_FILES
    pass.cc
    post/auto_worker_pass.cc
    pre/cache_validation_pass.cc
    pre/deep_copy_pass.cc
    pre/getter_pass.cc
    pre/input_validation_pass.cc
    pre/epoch_ctrl_pass.cc
    pre/node_removal_pass.cc
    )

# This set of files is for ExecTree's optimizer. It is being migrated to IR's optimizer.
# When the migration is complete, we will remove these files.
list(APPEND DATASET_ENGINE_OPT_SRC_FILES
    optional/tensor_op_fusion_pass.cc
    pre/cache_error_pass.cc
    post/repeat_pass.cc
    pre/cache_transform_pass.cc
    pre/epoch_injection_pass.cc
    util/printer_pass.cc
    pre/removal_pass.cc
    )

# The Generator node pass is only compiled when ENABLE_PYTHON is set.
if(ENABLE_PYTHON)
    list(APPEND DATASET_ENGINE_OPT_SRC_FILES post/generator_node_pass.cc)
endif()

add_library(engine-opt OBJECT ${DATASET_ENGINE_OPT_SRC_FILES})
|
||||
|
@ -0,0 +1,108 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include "minddata/dataset/engine/opt/post/generator_node_pass.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/source/generator_node.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
// The pass starts with no Repeat/EpochCtrl ancestors recorded.
GeneratorNodePass::GeneratorNodePass() : repeat_ancestors_() {}
|
||||
/*
|
||||
* A diagram showing how the code works:
|
||||
* With the tree below as an input
|
||||
*
|
||||
* EpochCtrl(-1)
|
||||
* / \
|
||||
* Repeat1 \
|
||||
* / Repeat3
|
||||
* .. \
|
||||
* / Generator2
|
||||
* Repeat2 Add: Gen2-Rep3
|
||||
* /
|
||||
* Generator1
|
||||
* Add: Gen1-Rep2
|
||||
*
|
||||
* The sequence of the DFS walk of the tree looks like this:
|
||||
* 1) Visit(EpochCtrl): push EpochCtrl, repeat_ancestors_ = { EpochCtrl }
|
||||
* 2) Visit(Repeat1): push Repeat1, repeat_ancestors_ = { EpochCtrl, Repeat1 }
|
||||
* 3) Visit(Repeat2): push Repeat2, repeat_ancestors_ = { EpochCtrl, Repeat1, Repeat2 }
|
||||
* 4) Visit(Generator1): record Repeat2 as its ancestor
|
||||
* record Repeat1 as Repeat2's ancestor
|
||||
* record EpochCtrl as Repeat1's ancestor
|
||||
* 5) VisitAfter(Repeat2): pop Repeat2, repeat_ancestors_ = { EpochCtrl, Repeat1 }
|
||||
* 6) VisitAfter(Repeat1): pop Repeat1, repeat_ancestors_ = { EpochCtrl }
|
||||
* 7) Visit(Repeat3): push Repeat3, repeat_ancestors_ = { EpochCtrl, Repeat3 }
|
||||
* 8) Visit(Generator2): record Repeat3 as its ancestor
|
||||
* record EpochCtrl as Repeat3's ancestor
|
||||
* 9) VisitAfter(Repeat3): pop Repeat3, repeat_ancestors_ = { EpochCtrl }
|
||||
* 10) VisitAfter(EpochCtrl): don't care. We could pop EpochCtrl.
|
||||
*/
|
||||
|
||||
// Pre-order visit of an EpochCtrl node: remember it so that nodes visited underneath it
// can see it in the list of Repeat/EpochCtrl ancestors.
Status GeneratorNodePass::Visit(std::shared_ptr<EpochCtrlNode> node, bool *const modified) {
  // The back of the list is always the most recently entered ancestor.
  repeat_ancestors_.emplace_back(node);
  return Status::OK();
}
|
||||
|
||||
// Pre-order visit of a Repeat node: remember it so that nodes visited underneath it
// can see it in the list of Repeat/EpochCtrl ancestors.
Status GeneratorNodePass::Visit(std::shared_ptr<RepeatNode> node, bool *const modified) {
  // The back of the list is always the most recently entered ancestor.
  repeat_ancestors_.emplace_back(node);
  return Status::OK();
}
|
||||
|
||||
// Visit a leaf Generator node and, if any recorded Repeat/EpochCtrl ancestor is an infinite repeat
// (negative count), form the chain of reset relationships from this Generator up through every
// recorded ancestor.
Status GeneratorNodePass::Visit(std::shared_ptr<GeneratorNode> node, bool *const modified) {
  // No Repeat/EpochCtrl above this Generator: nothing to link.
  if (repeat_ancestors_.empty()) {
    return Status::OK();
  }

  // The reset relationships are only formed when at least one ancestor has a negative count.
  bool has_infinite_ancestor = false;
  for (const auto &candidate : repeat_ancestors_) {
    if (candidate->Count() < 0) {
      has_infinite_ancestor = true;
      break;
    }
  }
  if (!has_infinite_ancestor) {
    return Status::OK();
  }

  // Link adjacent ancestors pair-wise, walking from the innermost entry towards the outermost one.
  // For GeneratorNode -> Repeat1 -> Repeat2 -> EpochCtrl(-1) this yields
  // (Repeat1, Repeat2) and (Repeat2, EpochCtrl).
  auto child_idx = repeat_ancestors_.size() - 1;
  while (child_idx > 0) {
    auto parent = repeat_ancestors_[child_idx - 1];
    RETURN_IF_NOT_OK(repeat_ancestors_[child_idx]->AddResetAncestor(parent));
    --child_idx;
  }

  // Last, link the Generator itself to its immediate ancestor: (GeneratorNode, Repeat1) in the example above.
  RETURN_IF_NOT_OK(node->AddResetAncestor(repeat_ancestors_.back()));
  return Status::OK();
}
|
||||
|
||||
// Post-order visit of a Repeat node: the traversal is leaving this Repeat's subtree.
Status GeneratorNodePass::VisitAfter(std::shared_ptr<RepeatNode> node, bool *const modified) {
  // Drop the entry pushed by the matching Visit(RepeatNode) so that nodes visited afterwards
  // do not see this Repeat as one of their ancestors.
  repeat_ancestors_.pop_back();
  return Status::OK();
}
|
||||
|
||||
// Post-order visit of an EpochCtrl node. Intentionally a no-op.
Status GeneratorNodePass::VisitAfter(std::shared_ptr<EpochCtrlNode> node, bool *const modified) {
  // EpochCtrl is a terminal node, so the walk ends here and there is no need to pop it
  // back out of the list of ancestors. This override could be removed entirely.
  return Status::OK();
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
@ -0,0 +1,75 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_POST_GENERATOR_NODE_PASS_H
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_POST_GENERATOR_NODE_PASS_H
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "minddata/dataset/engine/opt/pass.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
/// \class GeneratorNodePass repeat_pass.h
|
||||
/// \brief This is a NodePass who's job is to perform setup actions for RepeatOps. A RepeatOp needs to have references
|
||||
/// to the eoe-producing (typically leaf) nodes underneath it.
|
||||
class GeneratorNodePass : public IRNodePass {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
GeneratorNodePass();
|
||||
|
||||
/// \brief Destructor
|
||||
~GeneratorNodePass() = default;
|
||||
|
||||
/// \brief Record the starting point to collect the Generator node
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status Visit(std::shared_ptr<RepeatNode> node, bool *const modified) override;
|
||||
|
||||
/// \brief Record the starting point to collect the Generator node
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status Visit(std::shared_ptr<EpochCtrlNode> node, bool *const modified) override;
|
||||
|
||||
/// \brief Add the Generator node to the set
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status Visit(std::shared_ptr<GeneratorNode> node, bool *const modified) override;
|
||||
|
||||
/// \brief Add the Generator node(s) from the set to this Repeat node for run-time processing
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status VisitAfter(std::shared_ptr<RepeatNode> node, bool *const modified) override;
|
||||
|
||||
/// \brief Add the Generator node(s) from the set to this EpochCtrl node for run-time processing
|
||||
/// \param[in] node The node being visited
|
||||
/// \param[inout] modified Indicator if the node was changed at all
|
||||
/// \return Status The status code returned
|
||||
Status VisitAfter(std::shared_ptr<EpochCtrlNode> node, bool *const modified) override;
|
||||
|
||||
private:
|
||||
std::vector<std::shared_ptr<RepeatNode>> repeat_ancestors_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_OPT_POST_GENERATOR_NODE_PASS_H
|
@ -0,0 +1,208 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import numpy as np
|
||||
|
||||
import mindspore.dataset as ds
|
||||
from mindspore import log as logger
|
||||
|
||||
# Generate 2 rows of data (1, 2)
|
||||
def generator_1to2():
    """Yield two single-column rows holding the values 1 and 2, each wrapped in a NumPy array."""
    yield from ((np.array(value),) for value in np.array([1, 2]))
|
||||
|
||||
# Generate 3 rows of data (10, 11, 12)
|
||||
def generator_10to12():
    """Yield three single-column rows holding the values 10, 11 and 12, each wrapped in a NumPy array."""
    yield from ((np.array(value),) for value in np.array([10, 11, 12]))
|
||||
|
||||
# Generate 3 rows of data (22, 23, 24)
|
||||
def generator_22to24():
    """Yield three single-column rows holding the values 22, 23 and 24, each wrapped in a NumPy array."""
    yield from ((np.array(value),) for value in np.array([22, 23, 24]))
|
||||
|
||||
def test_simple_repeat():
    """
    Test Generator -> Repeat -> Skip with a single epoch

    Since the number of epochs is 1, the GeneratorPass logic is not expected to add the reset logic.
    """
    logger.info("test_simple_repeat")
    # Rows (1, 2) repeated twice, with the very first row skipped
    pipeline = ds.GeneratorDataset(generator_1to2, ["data"]).repeat(2).skip(1)

    # The leading 0 is only a seed element for np.append; it is mirrored in the expected array.
    collected = np.array([0])
    for row in pipeline.create_dict_iterator(num_epochs=1, output_numpy=True):
        collected = np.append(collected, row["data"])

    np.testing.assert_array_equal(collected, np.array([0, 2, 1, 2]))
|
||||
|
||||
def test_generator_reset_1():
    """
    Test (Generator -> Repeat) + (Generator -> Repeat) + (Generator -> Repeat)
    """
    logger.info("test_generator_reset_1")
    # Branch 1: rows (1, 2) repeated four times
    first = ds.GeneratorDataset(generator_1to2, ["data"]).repeat(4)
    # Branch 2: rows (10, 11, 12) repeated twice; take(10) drops nothing, it only inserts an op above the Repeat
    second = ds.GeneratorDataset(generator_10to12, ["data"]).repeat(2).take(10)
    # Branch 3: rows (22, 23, 24) repeated three times, with the very first row skipped
    third = ds.GeneratorDataset(generator_22to24, ["data"]).repeat(3).skip(1)

    combined = (first + second) + third

    # The leading 0 is only a seed element for np.append; it is mirrored in the expected array.
    collected = np.array([0])
    for row in combined.create_dict_iterator(num_epochs=1, output_numpy=True):
        collected = np.append(collected, row["data"])

    expected = np.array([0, 1, 2, 1, 2, 1, 2, 1, 2, 10, 11, 12, 10, 11, 12, 23, 24, 22, 23, 24, 22, 23, 24])
    np.testing.assert_array_equal(collected, expected)
|
||||
|
||||
def test_generator_reset_2():
    """
    Test ((Generator -> Repeat) + (Generator -> Repeat) -> Repeat) + (Generator)
    """
    logger.info("test_generator_reset_2")
    # Branch 1: rows (1, 2) with the first row skipped, then repeated three times
    first = ds.GeneratorDataset(generator_1to2, ["data"]).skip(1).repeat(3)
    # Branch 2: rows (10, 11, 12) repeated twice; take(10) drops nothing, it only inserts an op above the Repeat
    second = ds.GeneratorDataset(generator_10to12, ["data"]).repeat(2).take(10)
    # Branch 3: rows (22, 23, 24) with the first two rows skipped
    third = ds.GeneratorDataset(generator_22to24, ["data"]).skip(2)

    inner = first + second
    combined = inner.repeat(2).take(11) + third

    # The leading 0 is only a seed element for np.append; it is mirrored in the expected array.
    collected = np.array([0])
    for row in combined.create_dict_iterator(num_epochs=1, output_numpy=True):
        collected = np.append(collected, row["data"])

    expected = np.array([0, 2, 2, 2, 10, 11, 12, 10, 11, 12, 2, 2, 24])
    np.testing.assert_array_equal(collected, expected)
|
||||
|
||||
def test_generator_reset_3():
    """
    Test (Generator -> Repeat -> Skip -> Take -> Repeat)
         + (((Generator -> Repeat) + (Generator -> Take)) -> Repeat -> Skip -> Take) -> EpochCtrl

    The iterator is created without num_epochs and is consumed for 5 epochs; every epoch must
    produce the same sequence of rows.
    """
    logger.info("test_generator_reset_3")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1to2, ["data"])
    branch1 = data1.repeat(2)
    branch1 = branch1.skip(1)
    branch1 = branch1.take(2)
    branch1 = branch1.repeat(2)
    data2 = ds.GeneratorDataset(generator_10to12, ["data"])
    branch2 = data2.repeat(2)
    data3 = ds.GeneratorDataset(generator_22to24, ["data"])
    branch3 = data3.take(2)

    concat1 = branch2 + branch3
    concat2 = branch1 + concat1.repeat(3).skip(5).take(15)

    itr = concat2.create_dict_iterator(output_numpy=True)

    num_epochs = 5
    # The leading 0 is only a seed element for np.append; it is present in both arrays.
    output = np.array([0])
    golden = np.array([0])
    # Rows expected from a single epoch
    expected = np.array([2, 1, 2, 1, 12, 22, 23, 10, 11, 12, 10, 11, 12, 22, 23, 10, 11, 12, 10])
    try:
        for _ in range(num_epochs):
            golden = np.append(golden, expected)
            for item in itr:
                output = np.append(output, item["data"])

        np.testing.assert_array_equal(output, golden)
    finally:
        # Stop the iterator even when the comparison above fails.
        itr.stop()
|
||||
|
||||
def test_generator_reset_4():
    """
    Test Generator -> Repeat -> Repeat
    """
    logger.info("test_generator_reset_4")
    # Rows (1, 2) repeated four times, and that whole sequence repeated twice
    pipeline = ds.GeneratorDataset(generator_1to2, ["data"]).repeat(4).repeat(2)

    # The leading 0 is only a seed element for np.append; it is mirrored in the expected array.
    collected = np.array([0])
    for row in pipeline.create_dict_iterator(num_epochs=1, output_numpy=True):
        collected = np.append(collected, row["data"])

    expected = np.array([0, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2])
    np.testing.assert_array_equal(collected, expected)
|
||||
|
||||
def test_generator_reset_5():
    """
    Test Generator -> Repeat -> Take -> Repeat -> EpochCtrl

    The iterator is created without num_epochs and is consumed for 2 epochs.
    """
    logger.info("test_generator_reset_5")
    # apply dataset operations
    data1 = ds.GeneratorDataset(generator_1to2, ["data"])
    branch1 = data1.repeat(3).take(3).repeat(2)

    num_epochs = 2
    # The leading 0 is only a seed element for np.append; it is mirrored in the expected array.
    output = np.array([0])
    itr = branch1.create_dict_iterator(output_numpy=True)

    try:
        for _ in range(num_epochs):
            for item in itr:
                output = np.append(output, item["data"])

        golden = np.array([0, 1, 2, 1, 1, 2, 1, 1, 2, 1, 1, 2, 1])

        np.testing.assert_array_equal(output, golden)
    finally:
        # Stop the iterator even when the comparison above fails.
        itr.stop()
|
||||
|
||||
def test_generator_reset_6():
    """
    Test Generator -> Repeat -> Take -> Repeat -> Skip -> EpochCtrl

    The iterator is created with num_epochs=3 but only 2 epochs are consumed.
    """
    logger.info("test_generator_reset_6")
    pipeline = ds.GeneratorDataset(generator_10to12, ["data"]).repeat(2).take(5).repeat(2).skip(2)
    iter1 = pipeline.create_dict_iterator(num_epochs=3, output_numpy=True)

    # The leading 0 is only a seed element for np.append; it is mirrored in the expected array.
    collected = np.array([0])
    epochs_consumed = 0
    while epochs_consumed < 2:
        for row in iter1:
            collected = np.append(collected, row["data"])
        epochs_consumed += 1

    expected = np.array([0, 12, 10, 11, 10, 11, 12, 10, 11, 12, 10, 11, 10, 11, 12, 10, 11])
    np.testing.assert_array_equal(collected, expected)

    # intentionally not calling iter1.stop() to trigger the self-termination when iter1 goes out of scope
|
||||
|
||||
|
||||
# Run every test in this file, in definition order, when the file is executed as a script.
if __name__ == '__main__':
    test_simple_repeat()
    test_generator_reset_1()
    test_generator_reset_2()
    test_generator_reset_3()
    test_generator_reset_4()
    test_generator_reset_5()
    test_generator_reset_6()
    logger.info('\n')
|
Loading…
Reference in new issue