|
|
|
@ -18,27 +18,33 @@
|
|
|
|
|
import os
|
|
|
|
|
import sys
|
|
|
|
|
|
|
|
|
|
from paddle.trainer.PyDataProviderWrapper import *
|
|
|
|
|
from paddle.trainer.PyDataProvider2 import *
|
|
|
|
|
|
|
|
|
|
@init_hook_wrapper
def hook(obj, dict_file, **kwargs):
    """Initialise the provider object before any data is read.

    Stores ``dict_file`` on ``obj`` as ``word_dict``, declares two index
    slots (one sized by the dictionary length, one of size 3) and logs the
    dictionary length.

    Args:
        obj: Provider object; receives ``word_dict`` and ``slots`` and must
            expose a ``logger``.
        dict_file: Word dictionary; only ``len()`` is taken here.
        **kwargs: Accepted and ignored.
    """
    obj.word_dict = dict_file
    vocab_size = len(obj.word_dict)
    word_slot = IndexSlot(vocab_size)
    # NOTE(review): 3 is presumably the number of label classes -- confirm.
    label_slot = IndexSlot(3)
    obj.slots = [word_slot, label_slot]
    obj.logger.info('dict len : %d' % vocab_size)
|
|
|
|
|
def hook(settings, dict_file, **kwargs):
    """Initialise ``settings`` before any data is read.

    Stores ``dict_file`` on ``settings`` as ``word_dict``, declares two
    ``integer_value_sequence`` input types (one sized by the dictionary
    length, one of size 3) and logs the dictionary length.

    Args:
        settings: Provider settings object; receives ``word_dict`` and
            ``input_types`` and must expose a ``logger``.
        dict_file: Word dictionary; only ``len()`` is taken here.
        **kwargs: Accepted and ignored.
    """
    settings.word_dict = dict_file
    vocab_size = len(settings.word_dict)
    word_type = integer_value_sequence(vocab_size)
    # NOTE(review): 3 is presumably the number of label classes -- confirm.
    label_type = integer_value_sequence(3)
    settings.input_types = [word_type, label_type]
    settings.logger.info('dict len : %d' % vocab_size)
|
|
|
|
|
|
|
|
|
|
@provider(use_seq=True, init_hook=hook)
|
|
|
|
|
def process(obj, file_name):
|
|
|
|
|
@provider(init_hook=hook)
def process(settings, file_name):
    """Yield one ``(word ids, [label])`` sample per line of ``file_name``.

    Each line is expected to hold a label and a comment separated by a
    single tab. Whitespace inside the label field is removed before it is
    converted to ``int``. The comment is split on whitespace; every word
    found in ``settings.word_dict`` is replaced by its mapped value and
    words missing from the dictionary are dropped.

    Args:
        settings: Object carrying ``word_dict`` (assigned by ``hook``).
        file_name: Path of the text file to read.

    Yields:
        A pair ``(word_slot, [label])`` where ``word_slot`` is the list of
        dictionary values for the known words of the comment.

    Raises:
        ValueError: If a line does not split into exactly two tab-separated
            fields, or if the label is not an integer.
    """
    # The dictionary does not change while the file is read; look it up once.
    word_dict = settings.word_dict
    with open(file_name, 'r') as fdata:
        for line in fdata:
            label, comment = line.strip().split('\t')
            # Labels may contain embedded whitespace; squeeze it out first.
            label = int(''.join(label.split()))
            words = comment.split()
            # Out-of-vocabulary words are silently skipped.
            word_slot = [word_dict[w] for w in words if w in word_dict]
            yield word_slot, [label]
|
|
|
|
|
|
|
|
|
|
## for hierarchical sequence network
|
|
|
|
|
@provider(use_seq=True, init_hook=hook)
|
|
|
|
|
def process2(obj, file_name):
|
|
|
|
|
def hook2(settings, dict_file, **kwargs):
    """Initialise ``settings`` for the sub-sequence data provider.

    Stores ``dict_file`` on ``settings`` as ``word_dict``, declares two
    ``integer_value_sub_sequence`` input types (one sized by the dictionary
    length, one of size 3) and logs the dictionary length.

    Args:
        settings: Provider settings object; receives ``word_dict`` and
            ``input_types`` and must expose a ``logger``.
        dict_file: Word dictionary; only ``len()`` is taken here.
        **kwargs: Accepted and ignored.
    """
    settings.word_dict = dict_file
    vocab_size = len(settings.word_dict)
    word_type = integer_value_sub_sequence(vocab_size)
    # NOTE(review): 3 is presumably the number of label classes -- confirm.
    label_type = integer_value_sub_sequence(3)
    settings.input_types = [word_type, label_type]
    settings.logger.info('dict len : %d' % vocab_size)
|
|
|
|
|
|
|
|
|
|
@provider(init_hook=hook2)
|
|
|
|
|
def process2(settings, file_name):
|
|
|
|
|
with open(file_name) as fdata:
|
|
|
|
|
label_list = []
|
|
|
|
|
word_slot_list = []
|
|
|
|
@ -47,7 +53,7 @@ def process2(obj, file_name):
|
|
|
|
|
label,comment = line.strip().split('\t')
|
|
|
|
|
label = int(''.join(label.split()))
|
|
|
|
|
words = comment.split()
|
|
|
|
|
word_slot = [obj.word_dict[w] for w in words if w in obj.word_dict]
|
|
|
|
|
word_slot = [settings.word_dict[w] for w in words if w in settings.word_dict]
|
|
|
|
|
label_list.append([label])
|
|
|
|
|
word_slot_list.append(word_slot)
|
|
|
|
|
else:
|
|
|
|
|