|
|
|
@ -27,9 +27,6 @@ from .. import backward
|
|
|
|
|
from .base import switch_to_static_graph
|
|
|
|
|
from ... import compat as cpt
|
|
|
|
|
|
|
|
|
|
# Set Log level
|
|
|
|
|
logging.getLogger().setLevel(logging.WARNING)
|
|
|
|
|
|
|
|
|
|
# DESIGN IDEA: Add an special operator, execute static program inside operator.
|
|
|
|
|
#
|
|
|
|
|
# Op's Inputs:
|
|
|
|
@ -397,7 +394,6 @@ class StaticModelRunner(layers.Layer):
|
|
|
|
|
if params_filename is None:
|
|
|
|
|
if not self._is_parameter(each_var):
|
|
|
|
|
continue
|
|
|
|
|
# logging.info("persis var name %s" % each_var.name())
|
|
|
|
|
framework._dygraph_tracer().trace_op(
|
|
|
|
|
type='load',
|
|
|
|
|
inputs={},
|
|
|
|
@ -442,7 +438,6 @@ class StaticModelRunner(layers.Layer):
|
|
|
|
|
for param_name in self._parameters:
|
|
|
|
|
param_grad_name = param_name + core.grad_var_suffix()
|
|
|
|
|
if param_grad_name not in all_var_names:
|
|
|
|
|
logging.info("set %s stop gradient = True" % param_grad_name)
|
|
|
|
|
self._parameters[param_name].stop_gradient = True
|
|
|
|
|
|
|
|
|
|
def _get_all_var_names(self, program_desc):
|
|
|
|
@ -450,7 +445,6 @@ class StaticModelRunner(layers.Layer):
|
|
|
|
|
for i in six.moves.range(program_desc.num_blocks()):
|
|
|
|
|
block = program_desc.block(i)
|
|
|
|
|
for var in block.all_vars():
|
|
|
|
|
logging.info(var.name())
|
|
|
|
|
all_var_names.add(var.name())
|
|
|
|
|
return all_var_names
|
|
|
|
|
|
|
|
|
|