adapt to weight initializer modification

pull/1237/head
gengdongjie 6 years ago
parent 699d0c1082
commit ae9ce1629b

@ -64,11 +64,11 @@ if __name__ == '__main__':
if isinstance(cell, nn.Conv2d): if isinstance(cell, nn.Conv2d):
cell.weight.default_input = weight_init.initializer(weight_init.XavierUniform(), cell.weight.default_input = weight_init.initializer(weight_init.XavierUniform(),
cell.weight.default_input.shape(), cell.weight.default_input.shape(),
cell.weight.default_input.dtype()) cell.weight.default_input.dtype()).to_tensor()
if isinstance(cell, nn.Dense): if isinstance(cell, nn.Dense):
cell.weight.default_input = weight_init.initializer(weight_init.TruncatedNormal(), cell.weight.default_input = weight_init.initializer(weight_init.TruncatedNormal(),
cell.weight.default_input.shape(), cell.weight.default_input.shape(),
cell.weight.default_input.dtype()) cell.weight.default_input.dtype()).to_tensor()
if not config.label_smooth: if not config.label_smooth:
config.label_smooth_factor = 0.0 config.label_smooth_factor = 0.0
loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num) loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num)

@ -61,11 +61,11 @@ if __name__ == '__main__':
if isinstance(cell, nn.Conv2d): if isinstance(cell, nn.Conv2d):
cell.weight.default_input = weight_init.initializer(weight_init.XavierUniform(), cell.weight.default_input = weight_init.initializer(weight_init.XavierUniform(),
cell.weight.default_input.shape(), cell.weight.default_input.shape(),
cell.weight.default_input.dtype()) cell.weight.default_input.dtype()).to_tensor()
if isinstance(cell, nn.Dense): if isinstance(cell, nn.Dense):
cell.weight.default_input = weight_init.initializer(weight_init.TruncatedNormal(), cell.weight.default_input = weight_init.initializer(weight_init.TruncatedNormal(),
cell.weight.default_input.shape(), cell.weight.default_input.shape(),
cell.weight.default_input.dtype()) cell.weight.default_input.dtype()).to_tensor()
if not config.use_label_smooth: if not config.use_label_smooth:
config.label_smooth_factor = 0.0 config.label_smooth_factor = 0.0

Loading…
Cancel
Save