|
|
|
@ -28,6 +28,11 @@ URL = 'http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz'
|
|
|
|
|
MD5 = '30177ea32e27c525793142b6bf2c8e2d'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DataType(object):
|
|
|
|
|
NGRAM = 1
|
|
|
|
|
SEQ = 2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def word_count(f, word_freq=None):
|
|
|
|
|
if word_freq is None:
|
|
|
|
|
word_freq = collections.defaultdict(int)
|
|
|
|
@ -41,7 +46,7 @@ def word_count(f, word_freq=None):
|
|
|
|
|
return word_freq
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_dict(typo_freq=50):
|
|
|
|
|
def build_dict(min_word_freq=50):
|
|
|
|
|
"""
|
|
|
|
|
Build a word dictionary from the corpus, Keys of the dictionary are words,
|
|
|
|
|
and values are zero-based IDs of these words.
|
|
|
|
@ -59,7 +64,7 @@ def build_dict(typo_freq=50):
|
|
|
|
|
# remove <unk> for now, since we will set it as last index
|
|
|
|
|
del word_freq['<unk>']
|
|
|
|
|
|
|
|
|
|
word_freq = filter(lambda x: x[1] > typo_freq, word_freq.items())
|
|
|
|
|
word_freq = filter(lambda x: x[1] > min_word_freq, word_freq.items())
|
|
|
|
|
|
|
|
|
|
word_freq_sorted = sorted(word_freq, key=lambda x: (-x[1], x[0]))
|
|
|
|
|
words, _ = list(zip(*word_freq_sorted))
|
|
|
|
@ -69,7 +74,7 @@ def build_dict(typo_freq=50):
|
|
|
|
|
return word_idx
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def reader_creator(filename, word_idx, n):
|
|
|
|
|
def reader_creator(filename, word_idx, n, data_type):
|
|
|
|
|
def reader():
|
|
|
|
|
with tarfile.open(
|
|
|
|
|
paddle.v2.dataset.common.download(
|
|
|
|
@ -79,16 +84,27 @@ def reader_creator(filename, word_idx, n):
|
|
|
|
|
|
|
|
|
|
UNK = word_idx['<unk>']
|
|
|
|
|
for l in f:
|
|
|
|
|
if DataType.NGRAM == data_type:
|
|
|
|
|
assert n > -1, 'Invalid gram length'
|
|
|
|
|
l = ['<s>'] + l.strip().split() + ['<e>']
|
|
|
|
|
if len(l) >= n:
|
|
|
|
|
l = [word_idx.get(w, UNK) for w in l]
|
|
|
|
|
for i in range(n, len(l) + 1):
|
|
|
|
|
yield tuple(l[i - n:i])
|
|
|
|
|
elif DataType.SEQ == data_type:
|
|
|
|
|
l = l.strip().split()
|
|
|
|
|
l = [word_idx.get(w, UNK) for w in l]
|
|
|
|
|
src_seq = [word_idx['<s>']] + l
|
|
|
|
|
trg_seq = l + [word_idx['<e>']]
|
|
|
|
|
if n > 0 and len(src_seq) > n: continue
|
|
|
|
|
yield src_seq, trg_seq
|
|
|
|
|
else:
|
|
|
|
|
assert False, 'Unknow data type'
|
|
|
|
|
|
|
|
|
|
return reader
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train(word_idx, n):
|
|
|
|
|
def train(word_idx, n, data_type=DataType.NGRAM):
|
|
|
|
|
"""
|
|
|
|
|
imikolov training set creator.
|
|
|
|
|
|
|
|
|
@ -97,15 +113,18 @@ def train(word_idx, n):
|
|
|
|
|
|
|
|
|
|
:param word_idx: word dictionary
|
|
|
|
|
:type word_idx: dict
|
|
|
|
|
:param n: sliding window size
|
|
|
|
|
:param n: sliding window size if type is ngram, otherwise max length of sequence
|
|
|
|
|
:type n: int
|
|
|
|
|
:param data_type: data type (ngram or sequence)
|
|
|
|
|
:type data_type: member variable of DataType (NGRAM or SEQ)
|
|
|
|
|
:return: Training reader creator
|
|
|
|
|
:rtype: callable
|
|
|
|
|
"""
|
|
|
|
|
return reader_creator('./simple-examples/data/ptb.train.txt', word_idx, n)
|
|
|
|
|
return reader_creator('./simple-examples/data/ptb.train.txt', word_idx, n,
|
|
|
|
|
data_type)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test(word_idx, n):
|
|
|
|
|
def test(word_idx, n, data_type=DataType.NGRAM):
|
|
|
|
|
"""
|
|
|
|
|
imikolov test set creator.
|
|
|
|
|
|
|
|
|
@ -114,12 +133,15 @@ def test(word_idx, n):
|
|
|
|
|
|
|
|
|
|
:param word_idx: word dictionary
|
|
|
|
|
:type word_idx: dict
|
|
|
|
|
:param n: sliding window size
|
|
|
|
|
:param n: sliding window size if type is ngram, otherwise max length of sequence
|
|
|
|
|
:type n: int
|
|
|
|
|
:param data_type: data type (ngram or sequence)
|
|
|
|
|
:type data_type: member variable of DataType (NGRAM or SEQ)
|
|
|
|
|
:return: Test reader creator
|
|
|
|
|
:rtype: callable
|
|
|
|
|
"""
|
|
|
|
|
return reader_creator('./simple-examples/data/ptb.valid.txt', word_idx, n)
|
|
|
|
|
return reader_creator('./simple-examples/data/ptb.valid.txt', word_idx, n,
|
|
|
|
|
data_type)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def fetch():
|
|
|
|
|