|
|
|
@ -35,7 +35,7 @@ import paddle
|
|
|
|
|
import paddle.fluid as fluid
|
|
|
|
|
import paddle.fluid.layers as layers
|
|
|
|
|
from paddle.fluid import core
|
|
|
|
|
from test_dist_base import TestDistRunnerBase, runtime_main
|
|
|
|
|
from test_dist_base import TestDistRunnerBase, runtime_main, RUN_STEP
|
|
|
|
|
import paddle.compat as cpt
|
|
|
|
|
from paddle.compat import long_type
|
|
|
|
|
|
|
|
|
@ -562,18 +562,12 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
|
|
|
|
|
for pass_id in six.moves.xrange(TrainTaskConfig.pass_num):
|
|
|
|
|
pass_start_time = time.time()
|
|
|
|
|
for batch_id, data in enumerate(train_data()):
|
|
|
|
|
if batch_id >= 5:
|
|
|
|
|
if batch_id >= RUN_STEP:
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
feed_list = []
|
|
|
|
|
total_num_token = 0
|
|
|
|
|
|
|
|
|
|
#if TrainTaskConfig.local:
|
|
|
|
|
# lr_rate = lr_scheduler.update_learning_rate()
|
|
|
|
|
#for place_id, data_buffer in enumerate(
|
|
|
|
|
# split_data(
|
|
|
|
|
# data, num_part=dev_count)):
|
|
|
|
|
|
|
|
|
|
if TrainTaskConfig.local:
|
|
|
|
|
lr_rate = lr_scheduler.update_learning_rate()
|
|
|
|
|
|
|
|
|
@ -619,12 +613,11 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
|
|
|
|
|
init = True
|
|
|
|
|
|
|
|
|
|
# Validate and save the model for inference.
|
|
|
|
|
if batch_id == 0 or batch_id == 4:
|
|
|
|
|
if TrainTaskConfig.val_file_pattern is not None:
|
|
|
|
|
val_avg_cost, val_ppl = test()
|
|
|
|
|
print("[%f]" % val_avg_cost)
|
|
|
|
|
else:
|
|
|
|
|
assert (False)
|
|
|
|
|
if TrainTaskConfig.val_file_pattern is not None:
|
|
|
|
|
val_avg_cost, val_ppl = test()
|
|
|
|
|
print("[%f]" % val_avg_cost)
|
|
|
|
|
else:
|
|
|
|
|
assert (False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#import transformer_reader as reader
|
|
|
|
@ -1701,7 +1694,7 @@ class DistTransformer2x2(TestDistRunnerBase):
|
|
|
|
|
|
|
|
|
|
def run_trainer(self, args):
|
|
|
|
|
TrainTaskConfig.use_gpu = args.use_cuda
|
|
|
|
|
sum_cost, avg_cost, predict, token_num, local_lr_scheduler = get_model(
|
|
|
|
|
sum_cost, avg_cost, predict, token_num, local_lr_scheduler, test_program = get_model(
|
|
|
|
|
args.is_dist, not args.sync_mode)
|
|
|
|
|
|
|
|
|
|
if args.is_dist:
|
|
|
|
|