|
|
|
@ -147,7 +147,7 @@ class BottleneckBlock(fluid.imperative.Layer):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ResNet(fluid.imperative.Layer):
|
|
|
|
|
def __init__(self, layers=50, class_dim=1000):
|
|
|
|
|
def __init__(self, layers=50, class_dim=102):
|
|
|
|
|
super(ResNet, self).__init__()
|
|
|
|
|
|
|
|
|
|
self.layers = layers
|
|
|
|
@ -208,6 +208,7 @@ class TestImperativeResnet(unittest.TestCase):
|
|
|
|
|
seed = 90
|
|
|
|
|
|
|
|
|
|
batch_size = train_parameters["batch_size"]
|
|
|
|
|
batch_num = 1
|
|
|
|
|
with fluid.imperative.guard():
|
|
|
|
|
fluid.default_startup_program().random_seed = seed
|
|
|
|
|
fluid.default_main_program().random_seed = seed
|
|
|
|
@ -227,7 +228,7 @@ class TestImperativeResnet(unittest.TestCase):
|
|
|
|
|
dy_param_init_value[param.name] = param._numpy()
|
|
|
|
|
|
|
|
|
|
for batch_id, data in enumerate(train_reader()):
|
|
|
|
|
if batch_id >= 1:
|
|
|
|
|
if batch_id >= batch_num:
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
dy_x_data = np.array(
|
|
|
|
@ -313,7 +314,7 @@ class TestImperativeResnet(unittest.TestCase):
|
|
|
|
|
static_param_init_value[static_param_name_list[i]] = out[i]
|
|
|
|
|
|
|
|
|
|
for batch_id, data in enumerate(train_reader()):
|
|
|
|
|
if batch_id >= 1:
|
|
|
|
|
if batch_id >= batch_num:
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
static_x_data = np.array(
|
|
|
|
@ -368,6 +369,7 @@ class TestImperativeResnet(unittest.TestCase):
|
|
|
|
|
seed = 90
|
|
|
|
|
|
|
|
|
|
batch_size = train_parameters["batch_size"]
|
|
|
|
|
batch_num = 1
|
|
|
|
|
with fluid.imperative.guard(device=None):
|
|
|
|
|
fluid.default_startup_program().random_seed = seed
|
|
|
|
|
fluid.default_main_program().random_seed = seed
|
|
|
|
@ -387,7 +389,7 @@ class TestImperativeResnet(unittest.TestCase):
|
|
|
|
|
dy_param_init_value[param.name] = param._numpy()
|
|
|
|
|
|
|
|
|
|
for batch_id, data in enumerate(train_reader()):
|
|
|
|
|
if batch_id >= 1:
|
|
|
|
|
if batch_id >= batch_num:
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
dy_x_data = np.array(
|
|
|
|
@ -473,7 +475,7 @@ class TestImperativeResnet(unittest.TestCase):
|
|
|
|
|
static_param_init_value[static_param_name_list[i]] = out[i]
|
|
|
|
|
|
|
|
|
|
for batch_id, data in enumerate(train_reader()):
|
|
|
|
|
if batch_id >= 1:
|
|
|
|
|
if batch_id >= batch_num:
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
static_x_data = np.array(
|
|
|
|
|