|
|
|
@ -14,7 +14,6 @@
|
|
|
|
|
# ============================================================================
|
|
|
|
|
"""train_imagenet."""
|
|
|
|
|
import os
|
|
|
|
|
import math
|
|
|
|
|
import argparse
|
|
|
|
|
import random
|
|
|
|
|
import numpy as np
|
|
|
|
@ -64,7 +63,6 @@ if __name__ == '__main__':
|
|
|
|
|
epoch_size = config.epoch_size
|
|
|
|
|
net = resnet101(class_num=config.class_num)
|
|
|
|
|
# weight init
|
|
|
|
|
default_recurisive_init(net)
|
|
|
|
|
for _, cell in net.cells_and_names():
|
|
|
|
|
if isinstance(cell, nn.Conv2d):
|
|
|
|
|
cell.weight.default_input = weight_init.initializer(weight_init.XavierUniform(),
|
|
|
|
|