|
|
|
@ -267,7 +267,7 @@ def train(avg_loss, infer_prog, optimizer, train_reader, test_reader, batch_acc,
|
|
|
|
|
# pass is not the last, the last batch of this pass
|
|
|
|
|
# is also equal to args.batch_size.
|
|
|
|
|
if args.use_reader_op:
|
|
|
|
|
num_samples += args.batch_size
|
|
|
|
|
num_samples += args.batch_size * args.gpus
|
|
|
|
|
else:
|
|
|
|
|
num_samples += len(data)
|
|
|
|
|
train_losses.append(loss)
|
|
|
|
@ -363,7 +363,7 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader,
|
|
|
|
|
if args.update_method == "pserver":
|
|
|
|
|
exe.bcast_params()
|
|
|
|
|
if args.use_reader_op:
|
|
|
|
|
num_samples += args.batch_size
|
|
|
|
|
num_samples += args.batch_size * args.gpus
|
|
|
|
|
else:
|
|
|
|
|
num_samples += len(data)
|
|
|
|
|
iters += 1
|
|
|
|
|