|
|
|
@ -38,13 +38,13 @@ def _cleanup():
|
|
|
|
|
|
|
|
|
|
def alter_tree(node):
|
|
|
|
|
"""Traversing the python Dataset tree/graph to perform some alteration to some specific nodes."""
|
|
|
|
|
if not node.input:
|
|
|
|
|
if not node.children:
|
|
|
|
|
return _alter_node(node)
|
|
|
|
|
|
|
|
|
|
converted_children = []
|
|
|
|
|
for input_op in node.input:
|
|
|
|
|
for input_op in node.children:
|
|
|
|
|
converted_children.append(alter_tree(input_op))
|
|
|
|
|
node.input = converted_children
|
|
|
|
|
node.children = converted_children
|
|
|
|
|
return _alter_node(node)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -86,14 +86,14 @@ class Iterator:
|
|
|
|
|
|
|
|
|
|
def __is_tree_node(self, node):
|
|
|
|
|
"""Check if a node is tree node."""
|
|
|
|
|
if not node.input:
|
|
|
|
|
if len(node.output) > 1:
|
|
|
|
|
if not node.children:
|
|
|
|
|
if len(node.parent) > 1:
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
if len(node.output) > 1:
|
|
|
|
|
if len(node.parent) > 1:
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
for input_node in node.input:
|
|
|
|
|
for input_node in node.children:
|
|
|
|
|
cls = self.__is_tree_node(input_node)
|
|
|
|
|
if not cls:
|
|
|
|
|
return False
|
|
|
|
@ -174,7 +174,7 @@ class Iterator:
|
|
|
|
|
op_type = self.__get_dataset_type(node)
|
|
|
|
|
c_node = self.depipeline.AddNodeToTree(op_type, node.get_args())
|
|
|
|
|
|
|
|
|
|
for py_child in node.input:
|
|
|
|
|
for py_child in node.children:
|
|
|
|
|
c_child = self.__convert_node_postorder(py_child)
|
|
|
|
|
self.depipeline.AddChildToParentNode(c_child, c_node)
|
|
|
|
|
|
|
|
|
@ -184,7 +184,7 @@ class Iterator:
|
|
|
|
|
"""Recursively get batch node in the dataset tree."""
|
|
|
|
|
if isinstance(dataset, de.BatchDataset):
|
|
|
|
|
return
|
|
|
|
|
for input_op in dataset.input:
|
|
|
|
|
for input_op in dataset.children:
|
|
|
|
|
self.__batch_node(input_op, level + 1)
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
@ -194,11 +194,11 @@ class Iterator:
|
|
|
|
|
ptr = hex(id(dataset))
|
|
|
|
|
for _ in range(level):
|
|
|
|
|
logger.info("\t", end='')
|
|
|
|
|
if not dataset.input:
|
|
|
|
|
if not dataset.children:
|
|
|
|
|
logger.info("-%s (%s)", name, ptr)
|
|
|
|
|
else:
|
|
|
|
|
logger.info("+%s (%s)", name, ptr)
|
|
|
|
|
for input_op in dataset.input:
|
|
|
|
|
for input_op in dataset.children:
|
|
|
|
|
Iterator.__print_local(input_op, level + 1)
|
|
|
|
|
|
|
|
|
|
def print(self):
|
|
|
|
|