increment resnet and ptbrnn's batch_num

test=develop
revert-16045-imperative_remove_desc
minqiyang 6 years ago committed by ceci3
parent 3723dcc301
commit a5dc2812e3

@ -243,7 +243,9 @@ class TestImperativePtbRnn(unittest.TestCase):
dy_loss = None
last_hidden = None
last_cell = None
for i in range(2):
batch_num = 200
for i in range(batch_num):
x_data = np.arange(12).reshape(4, 3).astype('int64')
y_data = np.arange(1, 13).reshape(4, 3).astype('int64')
x_data = x_data.reshape((-1, num_steps, 1))
@ -302,7 +304,7 @@ class TestImperativePtbRnn(unittest.TestCase):
static_loss_value = None
static_last_cell_value = None
static_last_hidden_value = None
for i in range(2):
for i in range(batch_num):
x_data = np.arange(12).reshape(4, 3).astype('int64')
y_data = np.arange(1, 13).reshape(4, 3).astype('int64')
x_data = x_data.reshape((-1, num_steps, 1))

@ -231,7 +231,7 @@ class TestImperativeResnet(unittest.TestCase):
seed = 90
batch_size = train_parameters["batch_size"]
batch_num = 2
batch_num = 50
with fluid.imperative.guard():
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed

Loading…
Cancel
Save