Make weight normalization adapt to the up-to-date code

emailweixu-patch-1
guosheng 7 years ago
parent a422f34607
commit 6b9f1d343e

@ -226,8 +226,8 @@ class LayerHelper(object):
scale = elementwise_div( scale = elementwise_div(
x=g, y=norm) # The shapes of g and norm are the same. x=g, y=norm) # The shapes of g and norm are the same.
# Currently, elementwise_mul only support broadcast when the shape # Currently, elementwise_mul only support broadcast when the shape
# of y is a subset of x. Thus, we should reshape y to squeeze to # of y is a subset of the shape of x. Thus, we reshape y to squeeze
# achive it. # to achive the subset.
w = elementwise_mul( w = elementwise_mul(
x=v, x=v,
y=scale if dim is None else reshape( y=scale if dim is None else reshape(

@ -15,7 +15,10 @@
from initializer import Initializer, Xavier, Constant from initializer import Initializer, Xavier, Constant
from regularizer import WeightDecayRegularizer from regularizer import WeightDecayRegularizer
__all__ = ['ParamAttr'] __all__ = [
'ParamAttr',
'WeightNormParamAttr',
]
class ParamAttr(object): class ParamAttr(object):
@ -92,7 +95,7 @@ class WeightNormParamAttr(ParamAttr):
""" """
# List to record the parameters reparameterized by weight normalization. # List to record the parameters reparameterized by weight normalization.
# If these parameters are treated as Variable rather than Parameter, # If these parameters are treated as Variable rather than Parameter,
# it can be used to discriminate these parameters and help to serialize # it can be used to discriminate these parameters and help to serialize
# these paramters for inference. # these paramters for inference.
params_with_weight_norm = [] params_with_weight_norm = []

@ -52,7 +52,7 @@ class TestWeightNormalization(unittest.TestCase):
def run_program(self): def run_program(self):
outputs = [] outputs = []
places = [core.CPUPlace()] places = [core.CPUPlace()]
if core.is_compile_gpu(): if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0)) places.append(core.CUDAPlace(0))
for place in places: for place in places:
self.set_inputs(place) self.set_inputs(place)

Loading…
Cancel
Save