@@ -49,10 +49,13 @@ def mask(batch_tokens, total_token_num, vocab_size, CLS=1, SEP=2, MASK=3):
     max_len = max([len(sent) for sent in batch_tokens])
     mask_label = []
     mask_pos = []
-    np.random.seed(SEED)
-    prob_mask = np.random.rand(total_token_num)
+    # NOTE: numpy's global RNG is not thread-safe; with an async
+    # DataLoader, calling np.random.seed() directly is risky, so a
+    # private RandomState instance is the safer choice.
+    self_random = np.random.RandomState(SEED)
+    prob_mask = self_random.rand(total_token_num)
     # Note: the first token is [CLS], so [low=1]
-    replace_ids = np.random.randint(1, high=vocab_size, size=total_token_num)
+    replace_ids = self_random.randint(1, high=vocab_size, size=total_token_num)
     pre_sent_len = 0
     prob_index = 0
     for sent_index, sent in enumerate(batch_tokens):
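
For context, prob_mask and replace_ids feed the per-token masking decision in the
elided middle of this function. A minimal sketch of one common BERT-style
formulation follows; the 15%/80%/10% thresholds and the exact branch layout are
assumptions for illustration, not code taken from this hunk.

    # Hypothetical per-token decision driven by prob_mask/replace_ids
    # (BERT-style 15%/80%/10% split assumed, not shown in this diff).
    prob = prob_mask[prob_index + token_index]
    if prob > 0.15:
        pass                                    # 85%: leave the token unchanged
    elif prob > 0.15 * 0.20:
        sent[token_index] = MASK                # 80% of the masked 15%: [MASK]
    elif prob > 0.15 * 0.10:
        # 10% of the masked 15%: substitute a random vocabulary id
        sent[token_index] = replace_ids[prob_index + token_index]
    # remaining 10% of the masked 15%: keep the original token
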
@@ -85,7 +88,9 @@ def mask(batch_tokens, total_token_num, vocab_size, CLS=1, SEP=2, MASK=3):
         # ensure that at least one word in each sentence is masked
         while not mask_flag:
-            token_index = int(np.random.randint(1, high=len(sent) - 1, size=1))
+            token_index = int(
+                self_random.randint(
+                    1, high=len(sent) - 1, size=1))
             if sent[token_index] != SEP and sent[token_index] != CLS:
                 mask_label.append(sent[token_index])
                 sent[token_index] = MASK
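
Why the RandomState swap matters, shown as a small runnable sketch (illustration
only, not part of the patch): a private RandomState carries its own state, so any
other code, or another async DataLoader worker, that reseeds or draws from the
global numpy RNG can neither perturb nor be perturbed by this stream.

    import numpy as np

    self_random = np.random.RandomState(42)        # private, seeded stream
    first = self_random.rand(3)

    np.random.seed(0)                              # unrelated code reseeds the
    _ = np.random.rand(3)                          # global RNG and draws from it

    replay = np.random.RandomState(42).rand(3)     # the private stream still
    assert np.allclose(first, replay)              # replays exactly
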
@@ -244,13 +249,16 @@ class DataReader(object):
     def build_fake_data(self):
         for _ in range(1000000):
-            random.seed(SEED)
-            sent0_len = random.randint(50, 100)
-            sent1_len = random.randint(50, 100)
+            # NOTE: the stdlib random module is buggy under python2,
+            # so we should avoid it here;
+            # use numpy.random instead.
+            self_random = np.random.RandomState(SEED)
+            sent0_len = self_random.randint(50, 100)
+            sent1_len = self_random.randint(50, 100)
 
             token_ids = [1] \
-                + [random.randint(0, 10000) for i in range(sent0_len-1)] \
-                + [random.randint(0, 10000) for i in range(sent1_len-1)] \
+                + [self_random.randint(0, 10000) for i in range(sent0_len-1)] \
+                + [self_random.randint(0, 10000) for i in range(sent1_len-1)] \
                 + [2]
 
             sent_ids = [0 for i in range(sent0_len)
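
One subtlety of this swap that the diff does not call out (an observation, not a
claim from the patch): the stdlib randint is inclusive on both ends, while
numpy's randint samples the half-open interval [low, high), so the fake token
ids shift from the range [0, 10000] to [0, 10000).

    import random
    import numpy as np

    rng = np.random.RandomState(0)
    stdlib_draws = {random.randint(0, 2) for _ in range(1000)}   # {0, 1, 2}
    numpy_draws = {int(rng.randint(0, 2)) for _ in range(1000)}  # {0, 1}
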
@@ -260,48 +268,51 @@ class DataReader(object):
             yield token_ids, sent_ids, pos_ids, label
 
     def data_generator(self):
-        def reader():
-            for epoch in range(self.epoch):
-                self.current_epoch = epoch + 1
-                sample_generator = self.build_fake_data()
-                for sample in sample_generator:
-                    if sample is None:
-                        continue
-                    yield sample
-
-        def batch_reader(reader, batch_size, in_tokens):
-            batch, total_token_num, max_len = [], 0, 0
-            for parsed_line in reader():
-                token_ids, sent_ids, pos_ids, label = parsed_line
-                max_len = max(max_len, len(token_ids))
-                if in_tokens:
-                    to_append = (len(batch) + 1) * max_len <= batch_size
-                else:
-                    to_append = len(batch) < batch_size
-                if to_append:
-                    batch.append(parsed_line)
-                    total_token_num += len(token_ids)
-                else:
+        def wrapper():
+            def reader():
+                for epoch in range(self.epoch):
+                    self.current_epoch = epoch + 1
+                    sample_generator = self.build_fake_data()
+                    for sample in sample_generator:
+                        if sample is None:
+                            continue
+                        yield sample
+
+            def batch_reader(reader, batch_size, in_tokens):
+                batch, total_token_num, max_len = [], 0, 0
+                for parsed_line in reader():
+                    token_ids, sent_ids, pos_ids, label = parsed_line
+                    max_len = max(max_len, len(token_ids))
+                    if in_tokens:
+                        to_append = (len(batch) + 1) * max_len <= batch_size
+                    else:
+                        to_append = len(batch) < batch_size
+                    if to_append:
+                        batch.append(parsed_line)
+                        total_token_num += len(token_ids)
+                    else:
+                        yield batch, total_token_num
+                        batch, total_token_num, max_len = [parsed_line], len(
+                            token_ids), len(token_ids)
+                if len(batch) > 0:
                     yield batch, total_token_num
-                    batch, total_token_num, max_len = [parsed_line], len(
-                        token_ids), len(token_ids)
-            if len(batch) > 0:
-                yield batch, total_token_num
 
-        for batch_data, total_token_num in batch_reader(reader, self.batch_size,
-                                                        self.in_tokens):
-            yield prepare_batch_data(
-                batch_data,
-                total_token_num,
-                voc_size=self.voc_size,
-                pad_id=self.pad_id,
-                cls_id=self.cls_id,
-                sep_id=self.sep_id,
-                mask_id=self.mask_id,
-                return_input_mask=True,
-                return_max_len=False,
-                return_num_token=False)
+            for batch_data, total_token_num in batch_reader(
+                    reader, self.batch_size, self.in_tokens):
+                yield prepare_batch_data(
+                    batch_data,
+                    total_token_num,
+                    voc_size=self.voc_size,
+                    pad_id=self.pad_id,
+                    cls_id=self.cls_id,
+                    sep_id=self.sep_id,
+                    mask_id=self.mask_id,
+                    return_input_mask=True,
+                    return_max_len=False,
+                    return_num_token=False)
+
+        return wrapper
 
 
 class ModelHyperParams(object):
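
A usage sketch for the new shape of data_generator (the construction below is
hypothetical and the names are illustrative): returning the wrapper callable
instead of yielding directly matches the reader convention in which framework
code expects a function it can call to obtain a fresh iterator for every pass.

    data_reader = DataReader(...)                 # hypothetical construction
    train_reader = data_reader.data_generator()   # returns wrapper; no work yet

    for pass_id in range(2):
        for batch_data in train_reader():         # each call restarts the pipeline
            pass                                  # feed batch_data to the executor
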