@ -19,27 +19,44 @@ START = "<s>"
END = " <e> "
def hook ( settings , src_dict , trg_dict , file_list , * * kwargs ) :
def hook ( settings , src_dict_path , trg_dict_path , is_generating , file_list ,
* * kwargs ) :
# job_mode = 1: training mode
# job_mode = 0: generating mode
settings . job_mode = trg_dict is not None
settings . src_dict = src_dict
settings . job_mode = not is_generating
settings . src_dict = dict ( )
with open ( src_dict_path , " r " ) as fin :
settings . src_dict = {
line . strip ( ) : line_count
for line_count , line in enumerate ( fin )
}
settings . trg_dict = dict ( )
with open ( trg_dict_path , " r " ) as fin :
settings . trg_dict = {
line . strip ( ) : line_count
for line_count , line in enumerate ( fin )
}
settings . logger . info ( " src dict len : %d " % ( len ( settings . src_dict ) ) )
settings . sample_count = 0
if settings . job_mode :
settings . trg_dict = trg_dict
settings . slots = [
settings . slots = {
' source_language_word ' :
integer_value_sequence ( len ( settings . src_dict ) ) ,
' target_language_word ' :
integer_value_sequence ( len ( settings . trg_dict ) ) ,
' target_language_next_word ' :
integer_value_sequence ( len ( settings . trg_dict ) )
]
}
settings . logger . info ( " trg dict len : %d " % ( len ( settings . trg_dict ) ) )
else :
settings . slots = [
settings . slots = {
' source_language_word ' :
integer_value_sequence ( len ( settings . src_dict ) ) ,
' sent_id ' :
integer_value_sequence ( len ( open ( file_list [ 0 ] , " r " ) . readlines ( ) ) )
]
}
def _get_ids ( s , dictionary ) :
@ -69,6 +86,10 @@ def process(settings, file_name):
continue
trg_ids_next = trg_ids + [ settings . trg_dict [ END ] ]
trg_ids = [ settings . trg_dict [ START ] ] + trg_ids
yield src_ids , trg_ids , trg_ids_next
yield {
' source_language_word ' : src_ids ,
' target_language_word ' : trg_ids ,
' target_language_next_word ' : trg_ids_next
}
else :
yield src_ids , [ line_count ]
yield { ' source_language_word ' : src_ids , ' sent_id ' : [ line_count ] }