@ -260,51 +260,48 @@ class DataReader(object):
yield token_ids , sent_ids , pos_ids , label
def data_generator ( self ) :
def wrapper ( ) :
def reader ( ) :
for epoch in range ( self . epoch ) :
self . current_epoch = epoch + 1
sample_generator = self . build_fake_data ( )
for sample in sample_generator :
if sample is None :
continue
yield sample
def batch_reader ( reader , batch_size , in_tokens ) :
batch , total_token_num , max_len = [ ] , 0 , 0
for parsed_line in reader ( ) :
token_ids , sent_ids , pos_ids , label = parsed_line
max_len = max ( max_len , len ( token_ids ) )
if in_tokens :
to_append = ( len ( batch ) + 1 ) * max_len < = batch_size
else :
to_append = len ( batch ) < batch_size
if to_append :
batch . append ( parsed_line )
total_token_num + = len ( token_ids )
else :
yield batch , total_token_num
batch , total_token_num , max_len = [ parsed_line ] , len (
token_ids ) , len ( token_ids )
if len ( batch ) > 0 :
def reader ( ) :
for epoch in range ( self . epoch ) :
self . current_epoch = epoch + 1
sample_generator = self . build_fake_data ( )
for sample in sample_generator :
if sample is None :
continue
yield sample
def batch_reader ( reader , batch_size , in_tokens ) :
batch , total_token_num , max_len = [ ] , 0 , 0
for parsed_line in reader ( ) :
token_ids , sent_ids , pos_ids , label = parsed_line
max_len = max ( max_len , len ( token_ids ) )
if in_tokens :
to_append = ( len ( batch ) + 1 ) * max_len < = batch_size
else :
to_append = len ( batch ) < batch_size
if to_append :
batch . append ( parsed_line )
total_token_num + = len ( token_ids )
else :
yield batch , total_token_num
for batch_data , total_token_num in batch_reader (
reader , self . batch_size , self . in_tokens ) :
yield prepare_batch_data (
batch_data ,
total_token_num ,
voc_size = self . voc_size ,
pad_id = self . pad_id ,
cls_id = self . cls_id ,
sep_id = self . sep_id ,
mask_id = self . mask_id ,
return_input_mask = True ,
return_max_len = False ,
return_num_token = False )
return wrapper
batch , total_token_num , max_len = [ parsed_line ] , len (
token_ids ) , len ( token_ids )
if len ( batch ) > 0 :
yield batch , total_token_num
for batch_data , total_token_num in batch_reader ( reader , self . batch_size ,
self . in_tokens ) :
yield prepare_batch_data (
batch_data ,
total_token_num ,
voc_size = self . voc_size ,
pad_id = self . pad_id ,
cls_id = self . cls_id ,
sep_id = self . sep_id ,
mask_id = self . mask_id ,
return_input_mask = True ,
return_max_len = False ,
return_num_token = False )
class ModelHyperParams ( object ) :