[Dy2stat] Support buffers and register_buffer in Layer (#24888)

* support saving VarBase created in __init__ test=develop

* polish code test=develop

* refine to_static_var test=develop

* refine warning test=develop

* add unittest for to_static_var test=develop

* fix logger test=develop

* polish buffers en doc test=develop

* fix param_guard test=develop

* refine en doc test=develop
Aurelius84 committed 5 years ago via GitHub
parent 4c964abdf7
commit 02adf68dcc

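At a high level, this change lets a Layer hold non-parameter tensors, either registered explicitly via register_buffer or assigned as a plain VarBase attribute in __init__, and have dy2stat carry them into the static program. A minimal usage sketch based on the tests added below (the layer name and shapes are invented for illustration):

    import numpy as np
    import paddle.fluid as fluid
    from paddle.fluid.dygraph import to_variable

    class MaskNet(fluid.Layer):  # hypothetical layer, for illustration only
        def __init__(self):
            super(MaskNet, self).__init__()
            # A non-parameter tensor: tracked by buffers()/named_buffers()
            # and saved in state_dict() when persistable.
            mask = to_variable(np.ones([2, 2], dtype='float32'))
            self.register_buffer("mask", mask)

        def forward(self, x):
            return x * self.mask

    with fluid.dygraph.guard():
        net = MaskNet()
        print([name for name, _ in net.named_buffers()])  # ['mask']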
@@ -64,17 +64,26 @@ _functional_dygraph_context_manager = None

 @signature_safe_contextmanager
 def param_guard(parameters):
-    # Note: parameters is a reference of self._parameters
+    # Note: parameters is a reference of self._parameters or self._buffers
     if not framework.in_dygraph_mode() and parameters:
         origin_parameters = parameters.copy()
         for name, var_base in parameters.items():
             if isinstance(var_base, core.VarBase):
-                new_var = framework.Parameter(
-                    var_base.block,
-                    var_base.shape,
-                    var_base.dtype,
-                    var_base.type,
-                    name=var_base.name)
+                # Convert ParamBase into Parameter with the same attributes in dy2stat.
+                if isinstance(var_base, framework.ParamBase):
+                    new_var = var_base._to_static_var(to_parameter=True)
+                else:
+                    # Check whether the variable has been created before.
+                    if var_base.name in var_base.block.vars:
+                        new_var = var_base.block.vars[var_base.name]
+                    # Note(Aurelius84): Convert VarBase in self._buffers into a Variable
+                    # with the same attributes and set persistable=True to allow saving
+                    # this var. Users can create a VarBase in `__init__`, such as a `mask`
+                    # Tensor or `hidden_0` in RNN layers, which is equivalent to a
+                    # Parameter and is needed for inference. It will be pruned if
+                    # inference does not need it.
+                    else:
+                        new_var = var_base._to_static_var(
+                            to_parameter=False, persistable=True)
                 parameters[name] = new_var
         yield
         parameters.update(origin_parameters)

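In essence, param_guard temporarily swaps every dygraph VarBase in the dict for a static-graph counterpart while the decorated function is traced, then restores the originals. A self-contained sketch of that swap-and-restore pattern in plain Python (the toy convert stands in for _to_static_var; the original restores after yield without a try/finally):

    from contextlib import contextmanager

    @contextmanager
    def swap_guard(entries, convert):
        # `entries` is mutated in place, like self._parameters / self._buffers.
        originals = entries.copy()
        for name, value in entries.items():
            entries[name] = convert(value)  # static stand-in for the dygraph value
        yield
        entries.update(originals)  # restore the dygraph values after tracing

    store = {"w": 1.0}
    with swap_guard(store, convert=str):
        assert store["w"] == "1.0"  # converted inside the guard
    assert store["w"] == 1.0        # restored afterwards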
@@ -272,18 +272,19 @@ class PartialProgramLayer(layers.Layer):
                 "Type of self._params in PartialProgramLayer should be list or tuple, but received %s."
                 % type(self._params))

-        params_name_set = set()
-        for i, param in enumerate(self._params):
-            if not isinstance(param, framework.ParamBase):
+        param_and_buffer_names_set = set()
+        for i, var in enumerate(self._params):
+            # self._params contains parameters and buffers with persistable=True.
+            if not isinstance(var, core.VarBase):
                 raise TypeError(
-                    'Type of self._params[{}] in PartialProgramLayer should be framework.ParamBase, but received {}.'.
-                    format(i, type(param)))
-            params_name_set.add(param.name)
+                    'Type of self._params[{}] in PartialProgramLayer should be Parameter or Variable, but received {}.'.
+                    format(i, type(var)))
+            param_and_buffer_names_set.add(var.name)

         for block in main_program.blocks:
             for name, var in block.vars.items():
                 if isinstance(var, framework.Parameter):
-                    if name not in params_name_set:
+                    if name not in param_and_buffer_names_set:
                         raise ValueError(
                             "\n\tWe don't support to define layer with parameters in the function "
                             "decorated by `@declarative`.\n\tBecause that will re-defined parameters "

@@ -15,7 +15,7 @@
 from __future__ import print_function
 import gast
 import inspect
-import logging
+import warnings
 import textwrap
 import threading
 import collections
@@ -39,8 +39,6 @@ from paddle.fluid.dygraph.dygraph_to_static.partial_program import partial_progr
 __all__ = ['ProgramTranslator', 'convert_to_static']

-logger = logging.getLogger("fluid")
-

 class FunctionCache(object):
     """
@@ -131,16 +129,37 @@ class FunctionSpec(object):
         return self._args and isinstance(self._args[0], layers.Layer)

     def parameters(self, include_sublayer=True):
+        """
+        Returns the parameters of the decorated layer. If `include_sublayer` is True,
+        the parameters created in sublayers are included as well.
+        """
         params = collections.OrderedDict()
         if self.is_method():
+            layer_instance = self._args[0]
             if include_sublayer:
-                params = self._args[0].parameters()
+                params = layer_instance.parameters()
                 names = [p.name for p in params]
                 params = collections.OrderedDict(zip(names, params))
             else:
-                params = self._args[0]._parameters
+                params = layer_instance._parameters
         return params

+    def buffers(self, include_sublayer=True):
+        """
+        Returns the Variable buffers of the decorated layer. If `include_sublayer` is
+        True, the Variable buffers created in sublayers are included as well.
+        """
+        buffers = collections.OrderedDict()
+        if self.is_method():
+            layer_instance = self._args[0]
+            if include_sublayer:
+                buffers = layer_instance.buffers()
+                names = [buffer.name for buffer in buffers]
+                buffers = collections.OrderedDict(zip(names, buffers))
+            else:
+                buffers = layer_instance._buffers
+        return buffers
+
     @switch_to_static_graph
     def to_static_inputs(self, main_program):
         inputs = []
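Both helpers normalize a list of variables into a collections.OrderedDict keyed by each variable's unique name, which is the shape param_guard expects. The idiom in isolation, with a toy stand-in class rather than real Paddle types:

    import collections

    class Var(object):
        # Stand-in for a Parameter or buffer VarBase; only `name` matters here.
        def __init__(self, name):
            self.name = name

    def to_named_dict(variables):
        names = [v.name for v in variables]
        return collections.OrderedDict(zip(names, variables))

    print(list(to_named_dict([Var("w_0"), Var("mask_0")])))  # ['w_0', 'mask_0']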
@@ -251,11 +270,13 @@ class ConcreteProgram(object):
             # 1. Adds `fluid.data` layers for input if needed
             inputs = func_spec.to_static_inputs(main_program)

-            # 2. Gets all ParamBases in the function
-            all_parameters = list(func_spec.parameters().values())
+            # 2. Gets all ParamBases and buffered VarBases in the function
+            all_parameters_and_buffers = list(func_spec.parameters().values(
+            )) + list(func_spec.buffers().values())

             # 3. Builds program only once and returns the output Variables.
-            with param_guard(func_spec.parameters(False)):
+            with param_guard(func_spec.parameters(False)), param_guard(
+                    func_spec.buffers(False)):
                 outputs = static_func(*inputs)
             if not isinstance(outputs, (tuple, list)):
                 outputs = [outputs] if outputs else []
@@ -263,7 +284,7 @@ class ConcreteProgram(object):
         return ConcreteProgram(
             inputs=inputs,
             outputs=outputs,
-            parameters=all_parameters,
+            parameters=all_parameters_and_buffers,
             func=dygraph_function,
             main_program=main_program,
             startup_program=startup_program)
@@ -439,7 +460,7 @@ class ProgramTranslator(object):
             dygraph_func
         ), "Input dygraph_func is not a callable in ProgramTranslator.get_output"
         if not self.enable_declarative:
-            logger.info(
+            warnings.warn(
                 "The ProgramTranslator.get_output doesn't work when setting ProgramTranslator.enable = False. "
                 "We will just return dygraph output.")
             return dygraph_func(*args, **kwargs)
@@ -490,7 +511,7 @@ class ProgramTranslator(object):
             dygraph_func
         ), "Input dygraph_func is not a callable in ProgramTranslator.get_func"
         if not self.enable_declarative:
-            logger.info(
+            warnings.warn(
                 "The ProgramTranslator.get_func doesn't work when setting ProgramTranslator.enable=False. We will "
                 "just return dygraph output.")
             return dygraph_func
@@ -543,7 +564,7 @@ class ProgramTranslator(object):
             dygraph_func
         ), "Input dygraph_func is not a callable in ProgramTranslator.get_program"
         if not self.enable_declarative:
-            logger.info(
+            warnings.warn(
                 "The ProgramTranslator.get_program doesn't work when setting ProgramTranslator.enable=False."
                 "We will just return dygraph output.")
             return dygraph_func(*args, **kwargs)

@@ -16,7 +16,7 @@ from __future__ import print_function

 __all__ = ['TracedLayer', 'declarative', 'dygraph_to_static_func']

-import logging
+import warnings

 from paddle.fluid import core
 from paddle.fluid.compiler import CompiledProgram
 from paddle.fluid.dygraph.base import program_desc_tracing_guard, switch_to_static_graph
@@ -26,8 +26,6 @@ from paddle.fluid.executor import Executor, scope_guard
 from paddle.fluid.framework import Program, Block, Variable, _dygraph_tracer, dygraph_only, _dygraph_guard, _current_expected_place, in_dygraph_mode
 from paddle.fluid.wrapped_decorator import wrap_decorator

-logger = logging.getLogger("fluid")
-

 def create_program_from_desc(program_desc):
     program = Program()
@@ -104,7 +102,7 @@ def _dygraph_to_static_func_(dygraph_func):
     def __impl__(*args, **kwargs):
         program_translator = ProgramTranslator()
         if in_dygraph_mode() or not program_translator.enable_declarative:
-            logger.info(
+            warnings.warn(
                 "The decorator 'dygraph_to_static_func' doesn't work in "
                 "dygraph mode or set ProgramTranslator.enable to False. "
                 "We will just return dygraph output.")
@@ -156,7 +154,7 @@ def _declarative_(dygraph_func):
     def __impl__(*args, **kwargs):
         program_translator = ProgramTranslator()
         if not program_translator.enable_declarative:
-            logger.info(
+            warnings.warn(
                 "The decorator 'declarative' doesn't work when setting ProgramTranslator.enable=False. "
                 "We will just return dygraph output.")
             return dygraph_func(*args, **kwargs)

(Diff of one file suppressed because it is too large.)

@@ -12,16 +12,67 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import inspect
 from .. import framework
 from .. import core
 from . import BackwardStrategy
-from ..framework import Variable, _getitem_impl_
-from .. import unique_name
+from ..framework import Variable, Parameter, ParamBase
+from .base import switch_to_static_graph
 import numpy as np
 from .math_op_patch import monkey_patch_math_varbase


 def monkey_patch_varbase():
+    @switch_to_static_graph
+    def _to_static_var(self, to_parameter=False, **kwargs):
+        """
+        **Notes**:
+            **This API is ONLY available in Dygraph mode.**
+
+        Transforms a VarBase into a static Variable with the same attributes. It is a
+        low-level interface used in dy2static and shall not be called directly.
+
+        Args:
+            to_parameter (bool): It takes effect only if the input is a VarBase. If set
+                True, the VarBase will be converted into framework.Parameter. Otherwise,
+                it will be converted into framework.Variable. Default False.
+
+        Examples:
+            .. code-block:: python
+
+                import paddle.fluid as fluid
+                from paddle.fluid.dygraph.base import to_variable
+                import numpy as np
+
+                data = np.ones([3, 1024], dtype='float32')
+                with fluid.dygraph.guard():
+                    var_base = to_variable(data)
+                    static_var = var_base._to_static_var()
+        """
+        if isinstance(self, ParamBase):
+            attr_kwargs = self.__dict__.copy()
+        else:
+            attr_names = [
+                name for name in dir(self)
+                if not (inspect.ismethod(getattr(self, name)) or
+                        name.startswith('_'))
+            ]
+            attr_kwargs = {name: getattr(self, name) for name in attr_names}
+
+        attr_keys = ['block', 'shape', 'dtype', 'type', 'name', 'persistable']
+        for attr in attr_keys:
+            attr_kwargs[attr] = getattr(self, attr, None)
+
+        attr_kwargs.update(kwargs)
+
+        if to_parameter or isinstance(self, ParamBase):
+            del attr_kwargs['persistable']
+            static_var = Parameter(**attr_kwargs)
+        else:
+            static_var = Variable(**attr_kwargs)
+        return static_var
+
     # TODO(jiabin): move this to cplusplus end if we find some performance issue on it
     @framework.dygraph_only
     def set_value(self, value):
@@ -214,8 +265,9 @@ def monkey_patch_varbase():

     for method_name, method in (
             ("__bool__", __bool__), ("__nonzero__", __nonzero__),
-            ("set_value", set_value), ("block", block), ("backward", backward),
-            ("gradient", gradient), ("__str__", __str__), ("to_string", to_string)):
+            ("_to_static_var", _to_static_var), ("set_value", set_value),
+            ("block", block), ("backward", backward), ("gradient", gradient),
+            ("__str__", __str__), ("to_string", to_string)):
         setattr(core.VarBase, method_name, method)

     # patch math methods for varbase

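_to_static_var harvests every public, non-method attribute plus a fixed whitelist of core attributes, and lets explicit keyword overrides win. The harvesting idiom on its own, in pure Python with a toy class instead of a real VarBase:

    import inspect

    class Source(object):
        def __init__(self):
            self.shape = [2, 4]
            self.persistable = False
            self._private = "skipped"   # underscore names are filtered out

        def numpy(self):                # bound methods are filtered out too
            return None

    def harvest_attrs(obj, **overrides):
        names = [
            name for name in dir(obj)
            if not (inspect.ismethod(getattr(obj, name)) or
                    name.startswith('_'))
        ]
        kwargs = {name: getattr(obj, name) for name in names}
        kwargs.update(overrides)  # e.g. persistable=True passed by param_guard
        return kwargs

    print(harvest_attrs(Source(), persistable=True))
    # {'persistable': True, 'shape': [2, 4]}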
@@ -186,11 +186,11 @@ class BMN(fluid.dygraph.Layer):
             act="relu")

         # init to speed up
-        self.sample_mask = get_interp1d_mask(
-            self.tscale, self.dscale, self.prop_boundary_ratio, self.num_sample,
-            self.num_sample_perbin)
-        # self.sample_mask = fluid.dygraph.base.to_variable(sample_mask)
-        # self.sample_mask.stop_gradient = True
+        sample_mask = get_interp1d_mask(self.tscale, self.dscale,
+                                        self.prop_boundary_ratio,
+                                        self.num_sample, self.num_sample_perbin)
+        self.sample_mask = fluid.dygraph.base.to_variable(sample_mask)
+        self.sample_mask.stop_gradient = True

         self.p_conv3d1 = fluid.dygraph.Conv3D(
             num_channels=128,
@@ -241,12 +241,6 @@ class BMN(fluid.dygraph.Layer):

     @declarative
     def forward(self, x):
-        # TODO(Aurelius84): sample_mask is created in `__init__`,
-        # but currently we don't support that. The two lines code
-        # will be removed when support creating var outside of forward.
-        sample_mask = to_variable(self.sample_mask)
-        sample_mask.stop_gradient = True
-
         # Base Module
         x = self.b_conv1(x)
         x = self.b_conv2(x)
@@ -262,7 +256,7 @@ class BMN(fluid.dygraph.Layer):
         # PEM
         xp = self.p_conv1(x)
         # BM layer
-        xp = fluid.layers.matmul(xp, sample_mask)
+        xp = fluid.layers.matmul(xp, self.sample_mask)
         xp = fluid.layers.reshape(
             xp, shape=[0, 0, -1, self.dscale, self.tscale])

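The BMN change above is the motivating case: the interpolation mask can now be built once in __init__ and used directly in the declarative forward, instead of being re-wrapped with to_variable on every call. The same pattern reduced to a minimal, hypothetical layer:

    import numpy as np
    import paddle.fluid as fluid

    class BMLayer(fluid.dygraph.Layer):  # hypothetical, mirrors the BMN change
        def __init__(self):
            super(BMLayer, self).__init__()
            # Created once in __init__; dy2stat converts it into a persistable
            # static Variable, so forward no longer needs to_variable().
            mask = fluid.dygraph.base.to_variable(
                np.ones([8, 8], dtype='float32'))
            mask.stop_gradient = True
            self.sample_mask = mask

        def forward(self, x):
            return fluid.layers.matmul(x, self.sample_mask)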
@@ -16,6 +16,8 @@ import unittest

 import numpy as np
 import paddle.fluid as fluid
+from paddle.fluid.dygraph import to_variable
+from paddle.fluid.framework import ParamBase


 class L1(fluid.Layer):
@@ -85,5 +87,181 @@ class TestBaseLayer(unittest.TestCase):
         self.assertTrue(np.allclose(ret.numpy(), 0.8 * np.ones([2, 2])))


+class BufferLayer(fluid.Layer):
+    def __init__(self):
+        super(BufferLayer, self).__init__()
+        buffer_var = to_variable(np.zeros([2, 4]).astype('int32'))
+        self.register_buffer("layer_buffer", buffer_var)
+
+    def forward(self):
+        pass
+
+
+class BufferNet(fluid.Layer):
+    def __init__(self):
+        super(BufferNet, self).__init__()
+        self.buffer_layer = BufferLayer()
+        self.w1 = self.create_parameter(
+            shape=[2, 2], dtype='float32', is_bias=False)
+        buffer_var = to_variable(np.ones([2, 4]).astype('int32'))
+        self.register_buffer("net_buffer", buffer_var)
+        # Assigning a VarBase attribute also registers it as a buffer.
+        self.new_buffer = to_variable(np.ones([4, 2]).astype('int32'))
+
+    def forward(self):
+        pass
+
+
+class TestBuffer(unittest.TestCase):
+    def test_buffers_and_named_buffers(self):
+        def names(named_buffers):
+            return [name for name, _ in named_buffers]
+
+        with fluid.dygraph.guard():
+            layer = BufferLayer()
+            net = BufferNet()
+
+            self.assertEqual(len(layer.buffers()), 1)
+            self.assertEqual(names(layer.named_buffers()), ['layer_buffer'])
+
+            self.assertEqual(len(net.buffers()), 3)
+            self.assertEqual(
+                names(net.named_buffers()),
+                ['net_buffer', 'new_buffer', 'buffer_layer.layer_buffer'])
+
+            self.assertEqual(len(net.buffers(include_sublayers=False)), 2)
+            self.assertEqual(
+                names(net.named_buffers(include_sublayers=False)),
+                ['net_buffer', 'new_buffer'])
+
+    def test_register_buffer_with_error(self):
+        with fluid.dygraph.guard():
+            net = fluid.Layer()
+            var = to_variable(np.zeros([1]))
+
+            with self.assertRaisesRegexp(TypeError,
+                                         "name of buffer should be a string"):
+                net.register_buffer(12, var)
+
+            with self.assertRaisesRegexp(TypeError,
+                                         "buffer should be a core.VarBase"):
+                net.register_buffer("buffer_name", ParamBase([2, 2], 'float32'))
+
+            with self.assertRaisesRegexp(KeyError,
+                                         "name of buffer can not contain"):
+                net.register_buffer("buffer.name", var)
+
+            with self.assertRaisesRegexp(KeyError,
+                                         "name of buffer can not be empty"):
+                net.register_buffer("", var)
+
+            net.attr_name = 10
+            with self.assertRaisesRegexp(KeyError, "already exists"):
+                net.register_buffer("attr_name", var)
+
+            del net.attr_name
+            net.attr_name = ParamBase([2, 2], 'float32')
+            with self.assertRaisesRegexp(KeyError, "already exists"):
+                net.register_buffer("attr_name", var)
+
+    def test_register_buffer_same_name(self):
+        with fluid.dygraph.guard():
+            net = fluid.Layer()
+            var1 = to_variable(np.zeros([1]))
+            var2 = to_variable(np.zeros([2]))
+            var3 = to_variable(np.zeros([3]))
+
+            net.register_buffer("buffer_name", var1)
+            self.assert_var_base_equal(net.buffer_name, var1)
+            net.register_buffer("buffer_name", var2)
+            self.assert_var_base_equal(net.buffer_name, var2)
+            net.register_buffer("buffer_name", var3)
+            self.assert_var_base_equal(net.buffer_name, var3)
+
+    def test_buffer_not_persistable(self):
+        with fluid.dygraph.guard():
+            net = fluid.Layer()
+            var1 = to_variable(np.zeros([1]))
+            net.register_buffer("buffer_name", var1, persistable=False)
+            self.assertEqual(len(net.buffers()), 1)
+            self.assertEqual(len(net.state_dict()), 0)
+
+    def test_buffer_not_persistable_del(self):
+        with fluid.dygraph.guard():
+            net = fluid.Layer()
+            var1 = to_variable(np.zeros([1]))
+            net.register_buffer("buffer_name", var1, persistable=False)
+            del net.buffer_name
+            self.assertEqual(len(net.buffers()), 0)
+
+    def test_buffer_not_persistable_overwrite(self):
+        with fluid.dygraph.guard():
+            net = fluid.Layer()
+            var1 = to_variable(np.zeros([1]))
+            var2 = to_variable(np.zeros([2]))
+            net.register_buffer("buffer_name", var1, persistable=False)
+            net.register_buffer("buffer_name", var2)
+
+            # A non-persistable buffer can be overwritten with a persistable var.
+            self.assertEqual(len(net.buffers()), 1)
+            self.assertEqual(len(net.state_dict()), 1)
+
+            net.register_buffer("buffer_name", var1, persistable=False)
+            self.assertEqual(len(net.buffers()), 1)
+            self.assertEqual(len(net.state_dict()), 0)
+
+    def test_buffer_not_persistable_assign(self):
+        with fluid.dygraph.guard():
+            net = fluid.Layer()
+            var1 = to_variable(np.zeros([1]))
+            net.register_buffer("buffer_name", var1, persistable=False)
+
+            # Assigning None removes the buffer, but re-assigning a VarBase
+            # marks it as a buffer again.
+            net.buffer_name = None
+            self.assertEqual(len(net.buffers()), 0)
+            self.assertEqual(len(net.state_dict()), 0)
+            net.buffer_name = var1
+            self.assertEqual(len(net.buffers()), 1)
+            self.assertEqual(len(net.state_dict()), 0)
+
+            # Re-assigning a ParamBase removes the buffer.
+            net.buffer_name = ParamBase([2, 2], 'float32')
+            self.assertEqual(len(net.buffers()), 0)
+            self.assertEqual(len(net.state_dict()), 1)
+
+    def test_buffer_not_persistable_load(self):
+        with fluid.dygraph.guard():
+            net = fluid.Layer()
+            var1 = to_variable(np.zeros([1]))
+            net.register_buffer("buffer_name", var1, persistable=False)
+            # A non-persistable buffer is absent from state_dict, so loading
+            # an empty dict should succeed without complaints.
+            net.load_dict({})
+
+    def test_buffer_state_dict(self):
+        with fluid.dygraph.guard():
+            net = fluid.Layer()
+            var1 = to_variable(np.zeros([2, 3]))
+            var2 = to_variable(np.zeros([3, 2]))
+            net.register_buffer("buffer_var1", var1)
+            net.register_buffer("buffer_var2", var2, persistable=False)
+
+            self.assertEqual(len(net.state_dict()), 1)
+            self.assertEqual([name for name, _ in net.state_dict().items()],
+                             ["buffer_var1"])
+
+            # load state_dict
+            net_load = fluid.Layer()
+            var = to_variable(np.ones([2, 3]))
+            net_load.register_buffer("buffer_var1", var)
+            net_load.load_dict(net.state_dict())
+
+            self.assert_var_base_equal(net_load.buffer_var1, var1)
+
+    def assert_var_base_equal(self, var1, var2):
+        self.assertTrue(np.array_equal(var1.numpy(), var2.numpy()))
+
+
 if __name__ == '__main__':
     unittest.main()

@@ -233,6 +233,52 @@ class TestVarBase(unittest.TestCase):
         assert bool(var1) == False, "bool(var1) is False"
         assert bool(var2) == True, "bool(var2) is True"

+    def test_to_static_var(self):
+        with fluid.dygraph.guard():
+            # Convert VarBase into Variable or Parameter
+            var_base = fluid.dygraph.to_variable(self.array, name="var_base_1")
+            static_var = var_base._to_static_var()
+            self._assert_to_static(var_base, static_var)
+
+            var_base = fluid.dygraph.to_variable(self.array, name="var_base_2")
+            static_param = var_base._to_static_var(to_parameter=True)
+            self._assert_to_static(var_base, static_param, True)
+
+            # Convert ParamBase into Parameter
+            fc = fluid.dygraph.Linear(
+                10,
+                20,
+                param_attr=fluid.ParamAttr(
+                    learning_rate=0.001,
+                    do_model_average=True,
+                    regularizer=fluid.regularizer.L1Decay()))
+            weight = fc.parameters()[0]
+            static_param = weight._to_static_var()
+            self._assert_to_static(weight, static_param, True)
+
+    def _assert_to_static(self, var_base, static_var, is_param=False):
+        if is_param:
+            self.assertTrue(isinstance(static_var, fluid.framework.Parameter))
+            self.assertTrue(static_var.persistable, True)
+            if isinstance(var_base, fluid.framework.ParamBase):
+                for attr in ['trainable', 'is_distributed', 'do_model_average']:
+                    self.assertEqual(
+                        getattr(var_base, attr), getattr(static_var, attr))
+                self.assertEqual(static_var.optimize_attr['learning_rate'],
+                                 0.001)
+                self.assertTrue(
+                    isinstance(static_var.regularizer,
+                               fluid.regularizer.L1Decay))
+        else:
+            self.assertTrue(isinstance(static_var, fluid.framework.Variable))
+
+        attr_keys = ['block', 'dtype', 'type', 'name']
+        for attr in attr_keys:
+            self.assertEqual(getattr(var_base, attr), getattr(static_var, attr))
+
+        self.assertListEqual(list(var_base.shape), list(static_var.shape))
+
+
 if __name__ == '__main__':
     unittest.main()
