|
|
@ -155,9 +155,10 @@ def train(use_pure_fp16=True, use_nesterov=False):
|
|
|
|
loss, = exe.run(compiled_program,
|
|
|
|
loss, = exe.run(compiled_program,
|
|
|
|
feed=feeder.feed(data),
|
|
|
|
feed=feeder.feed(data),
|
|
|
|
fetch_list=[sum_cost])
|
|
|
|
fetch_list=[sum_cost])
|
|
|
|
|
|
|
|
loss_v = loss[0] if isinstance(loss, np.ndarray) else loss
|
|
|
|
print('PassID {0:1}, Train Batch ID {1:04}, train loss {2:2.4}'.
|
|
|
|
print('PassID {0:1}, Train Batch ID {1:04}, train loss {2:2.4}'.
|
|
|
|
format(pass_id, batch_id + 1, float(loss)))
|
|
|
|
format(pass_id, batch_id + 1, float(loss_v)))
|
|
|
|
train_loss_list.append(float(loss))
|
|
|
|
train_loss_list.append(float(loss_v))
|
|
|
|
|
|
|
|
|
|
|
|
if batch_id >= 4: # For speeding up CI
|
|
|
|
if batch_id >= 4: # For speeding up CI
|
|
|
|
test_loss_list = []
|
|
|
|
test_loss_list = []
|
|
|
|