|
|
|
@ -17,6 +17,42 @@ __all__ = ['ParamAttr', 'ExtraAttr', 'ParameterAttribute',
|
|
|
|
|
'ExtraLayerAttribute']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def convert_and_compare(x, Type):
|
|
|
|
|
"""
|
|
|
|
|
Convert x to be the same type as Type and then convert back to
|
|
|
|
|
check whether there is a loss of information
|
|
|
|
|
:param x: object to be checked
|
|
|
|
|
:param Type: target type to check x over
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
return type(x)(Type(x))==x
|
|
|
|
|
|
|
|
|
|
def is_compatible_with(x, Type):
|
|
|
|
|
"""
|
|
|
|
|
Check if x has a type compatible with Type
|
|
|
|
|
:param x: object to be checked
|
|
|
|
|
:param Type: target type to check x over
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
if type(x) == Type:
|
|
|
|
|
return True
|
|
|
|
|
try:
|
|
|
|
|
if float == Type or int == Type:
|
|
|
|
|
# avoid those types that can be converted to float/int but not very
|
|
|
|
|
# meaningful and could potentially lead to error
|
|
|
|
|
# i.e., str and bool typed value should not be used for initializing float/int variable
|
|
|
|
|
if not isinstance(x, str) and not isinstance(x, bool):
|
|
|
|
|
return convert_and_compare(x, Type)
|
|
|
|
|
elif bool == Type:
|
|
|
|
|
# should not use string type to initialize bool variable
|
|
|
|
|
if not isinstance(x, str):
|
|
|
|
|
return convert_and_compare(x, Type)
|
|
|
|
|
else:
|
|
|
|
|
return False
|
|
|
|
|
except:
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ParameterAttribute(object):
|
|
|
|
|
"""
|
|
|
|
|
Parameter Attributes object. To fine-tuning network training process, user
|
|
|
|
@ -65,14 +101,18 @@ class ParameterAttribute(object):
|
|
|
|
|
elif initial_std is None and initial_mean is None and initial_max \
|
|
|
|
|
is None and initial_min is None:
|
|
|
|
|
self.attr = {'initial_smart': True}
|
|
|
|
|
elif isinstance(initial_std, float) or isinstance(initial_mean, float):
|
|
|
|
|
elif is_compatible_with(initial_std, float) or \
|
|
|
|
|
is_compatible_with(initial_mean, float):
|
|
|
|
|
self.attr = dict()
|
|
|
|
|
if initial_std is not None:
|
|
|
|
|
self.attr['initial_std'] = initial_std
|
|
|
|
|
if initial_mean is not None:
|
|
|
|
|
self.attr['initial_mean'] = initial_mean
|
|
|
|
|
self.attr['initial_strategy'] = 0 # Gauss Random
|
|
|
|
|
elif isinstance(initial_max, float) and isinstance(initial_min, float):
|
|
|
|
|
elif is_compatible_with(initial_max, float) and \
|
|
|
|
|
is_compatible_with(initial_min, float):
|
|
|
|
|
initial_max = initial_max
|
|
|
|
|
initial_min = initial_min
|
|
|
|
|
assert initial_min < initial_max
|
|
|
|
|
initial_mean = (initial_max + initial_min) / 2
|
|
|
|
|
initial_std = initial_mean - initial_min
|
|
|
|
@ -83,16 +123,16 @@ class ParameterAttribute(object):
|
|
|
|
|
else:
|
|
|
|
|
raise RuntimeError("Unexpected branch.")
|
|
|
|
|
|
|
|
|
|
if not is_static and isinstance(l1_rate, float):
|
|
|
|
|
if not is_static and is_compatible_with(l1_rate, float):
|
|
|
|
|
self.attr['decay_rate_l1'] = l1_rate
|
|
|
|
|
|
|
|
|
|
if not is_static and isinstance(l2_rate, float):
|
|
|
|
|
if not is_static and is_compatible_with(l2_rate, float):
|
|
|
|
|
self.attr['decay_rate'] = l2_rate
|
|
|
|
|
|
|
|
|
|
if not is_static and isinstance(learning_rate, float):
|
|
|
|
|
if not is_static and is_compatible_with(learning_rate, float):
|
|
|
|
|
self.attr['learning_rate'] = learning_rate
|
|
|
|
|
|
|
|
|
|
if not is_static and isinstance(momentum, float):
|
|
|
|
|
if not is_static and is_compatible_with(momentum, float):
|
|
|
|
|
self.attr['momentum'] = momentum
|
|
|
|
|
|
|
|
|
|
if name is not None:
|
|
|
|
|