|
|
|
@ -205,6 +205,7 @@ class GraphWrapper(object):
|
|
|
|
|
super(GraphWrapper, self).__init__()
|
|
|
|
|
self.program = Program() if program is None else program
|
|
|
|
|
self.persistables = {}
|
|
|
|
|
self.teacher_persistables = {}
|
|
|
|
|
for var in self.program.list_vars():
|
|
|
|
|
if var.persistable:
|
|
|
|
|
self.persistables[var.name] = var
|
|
|
|
@ -306,6 +307,8 @@ class GraphWrapper(object):
|
|
|
|
|
graph(GraphWrapper): The graph to be merged by current graph.
|
|
|
|
|
"""
|
|
|
|
|
for var in graph.program.list_vars():
|
|
|
|
|
if var.persistable:
|
|
|
|
|
self.teacher_persistables[var.name] = var
|
|
|
|
|
new_var = self.program.global_block()._clone_variable(
|
|
|
|
|
var, force_persistable=False)
|
|
|
|
|
new_var.stop_gradient = var.stop_gradient
|
|
|
|
@ -479,7 +482,7 @@ class GraphWrapper(object):
|
|
|
|
|
self.persistables[var.name] = var
|
|
|
|
|
persistables = []
|
|
|
|
|
for var in self.persistables:
|
|
|
|
|
if 'reader' not in var and 'double_buffer' not in var:
|
|
|
|
|
if 'reader' not in var and 'double_buffer' not in var and var not in self.teacher_persistables:
|
|
|
|
|
persistables.append(self.persistables[var])
|
|
|
|
|
|
|
|
|
|
io.save_vars(exe.exe, path, vars=persistables)
|
|
|
|
|