fix a bug in multi_precision_fp16 unittest. (#29756)

revert-31562-mean
huangxu96 4 years ago committed by GitHub
parent 2e5b4a216c
commit 97e29411eb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -155,9 +155,10 @@ def train(use_pure_fp16=True, use_nesterov=False):
loss, = exe.run(compiled_program,
feed=feeder.feed(data),
fetch_list=[sum_cost])
loss_v = loss[0] if isinstance(loss, np.ndarray) else loss
print('PassID {0:1}, Train Batch ID {1:04}, train loss {2:2.4}'.
format(pass_id, batch_id + 1, float(loss)))
train_loss_list.append(float(loss))
format(pass_id, batch_id + 1, float(loss_v)))
train_loss_list.append(float(loss_v))
if batch_id >= 4: # For speeding up CI
test_loss_list = []

Loading…
Cancel
Save