|  |  | @ -28,6 +28,11 @@ URL = 'http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz' | 
			
		
	
		
		
			
				
					
					|  |  |  | MD5 = '30177ea32e27c525793142b6bf2c8e2d' |  |  |  | MD5 = '30177ea32e27c525793142b6bf2c8e2d' | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | class DataType(object): | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |     NGRAM = 1 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |     SEQ = 2 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  | def word_count(f, word_freq=None): |  |  |  | def word_count(f, word_freq=None): | 
			
		
	
		
		
			
				
					
					|  |  |  |     if word_freq is None: |  |  |  |     if word_freq is None: | 
			
		
	
		
		
			
				
					
					|  |  |  |         word_freq = collections.defaultdict(int) |  |  |  |         word_freq = collections.defaultdict(int) | 
			
		
	
	
		
		
			
				
					|  |  | @ -41,7 +46,7 @@ def word_count(f, word_freq=None): | 
			
		
	
		
		
			
				
					
					|  |  |  |     return word_freq |  |  |  |     return word_freq | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  | def build_dict(typo_freq=50): |  |  |  | def build_dict(min_word_freq=50): | 
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					|  |  |  |     """ |  |  |  |     """ | 
			
		
	
		
		
			
				
					
					|  |  |  |     Build a word dictionary from the corpus,  Keys of the dictionary are words, |  |  |  |     Build a word dictionary from the corpus,  Keys of the dictionary are words, | 
			
		
	
		
		
			
				
					
					|  |  |  |     and values are zero-based IDs of these words. |  |  |  |     and values are zero-based IDs of these words. | 
			
		
	
	
		
		
			
				
					|  |  | @ -59,7 +64,7 @@ def build_dict(typo_freq=50): | 
			
		
	
		
		
			
				
					
					|  |  |  |             # remove <unk> for now, since we will set it as last index |  |  |  |             # remove <unk> for now, since we will set it as last index | 
			
		
	
		
		
			
				
					
					|  |  |  |             del word_freq['<unk>'] |  |  |  |             del word_freq['<unk>'] | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |         word_freq = filter(lambda x: x[1] > typo_freq, word_freq.items()) |  |  |  |         word_freq = filter(lambda x: x[1] > min_word_freq, word_freq.items()) | 
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |         word_freq_sorted = sorted(word_freq, key=lambda x: (-x[1], x[0])) |  |  |  |         word_freq_sorted = sorted(word_freq, key=lambda x: (-x[1], x[0])) | 
			
		
	
		
		
			
				
					
					|  |  |  |         words, _ = list(zip(*word_freq_sorted)) |  |  |  |         words, _ = list(zip(*word_freq_sorted)) | 
			
		
	
	
		
		
			
				
					|  |  | @ -69,7 +74,7 @@ def build_dict(typo_freq=50): | 
			
		
	
		
		
			
				
					
					|  |  |  |     return word_idx |  |  |  |     return word_idx | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  | def reader_creator(filename, word_idx, n): |  |  |  | def reader_creator(filename, word_idx, n, data_type): | 
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					|  |  |  |     def reader(): |  |  |  |     def reader(): | 
			
		
	
		
		
			
				
					
					|  |  |  |         with tarfile.open( |  |  |  |         with tarfile.open( | 
			
		
	
		
		
			
				
					
					|  |  |  |                 paddle.v2.dataset.common.download( |  |  |  |                 paddle.v2.dataset.common.download( | 
			
		
	
	
		
		
			
				
					|  |  | @ -79,16 +84,27 @@ def reader_creator(filename, word_idx, n): | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |             UNK = word_idx['<unk>'] |  |  |  |             UNK = word_idx['<unk>'] | 
			
		
	
		
		
			
				
					
					|  |  |  |             for l in f: |  |  |  |             for l in f: | 
			
		
	
		
		
			
				
					
					|  |  |  |                 l = ['<s>'] + l.strip().split() + ['<e>'] |  |  |  |                 if DataType.NGRAM == data_type: | 
			
				
				
			
		
	
		
		
			
				
					
					|  |  |  |                 if len(l) >= n: |  |  |  |                     assert n > -1, 'Invalid gram length' | 
			
				
				
			
		
	
		
		
	
		
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |                     l = ['<s>'] + l.strip().split() + ['<e>'] | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |                     if len(l) >= n: | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |                         l = [word_idx.get(w, UNK) for w in l] | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |                         for i in range(n, len(l) + 1): | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |                             yield tuple(l[i - n:i]) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |                 elif DataType.SEQ == data_type: | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |                     l = l.strip().split() | 
			
		
	
		
		
			
				
					
					|  |  |  |                     l = [word_idx.get(w, UNK) for w in l] |  |  |  |                     l = [word_idx.get(w, UNK) for w in l] | 
			
		
	
		
		
			
				
					
					|  |  |  |                     for i in range(n, len(l) + 1): |  |  |  |                     src_seq = [word_idx['<s>']] + l | 
			
				
				
			
		
	
		
		
			
				
					
					|  |  |  |                         yield tuple(l[i - n:i]) |  |  |  |                     trg_seq = l + [word_idx['<e>']] | 
			
				
				
			
		
	
		
		
	
		
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |                     if n > 0 and len(src_seq) > n: continue | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |                     yield src_seq, trg_seq | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |                 else: | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |                     assert False, 'Unknow data type' | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |     return reader |  |  |  |     return reader | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  | def train(word_idx, n): |  |  |  | def train(word_idx, n, data_type=DataType.NGRAM): | 
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					|  |  |  |     """ |  |  |  |     """ | 
			
		
	
		
		
			
				
					
					|  |  |  |     imikolov training set creator. |  |  |  |     imikolov training set creator. | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
	
		
		
			
				
					|  |  | @ -97,15 +113,18 @@ def train(word_idx, n): | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |     :param word_idx: word dictionary |  |  |  |     :param word_idx: word dictionary | 
			
		
	
		
		
			
				
					
					|  |  |  |     :type word_idx: dict |  |  |  |     :type word_idx: dict | 
			
		
	
		
		
			
				
					
					|  |  |  |     :param n: sliding window size |  |  |  |     :param n: sliding window size if type is ngram, otherwise max length of sequence | 
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					|  |  |  |     :type n: int |  |  |  |     :type n: int | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |     :param data_type: data type (ngram or sequence) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |     :type data_type: member variable of DataType (NGRAM or SEQ) | 
			
		
	
		
		
			
				
					
					|  |  |  |     :return: Training reader creator |  |  |  |     :return: Training reader creator | 
			
		
	
		
		
			
				
					
					|  |  |  |     :rtype: callable |  |  |  |     :rtype: callable | 
			
		
	
		
		
			
				
					
					|  |  |  |     """ |  |  |  |     """ | 
			
		
	
		
		
			
				
					
					|  |  |  |     return reader_creator('./simple-examples/data/ptb.train.txt', word_idx, n) |  |  |  |     return reader_creator('./simple-examples/data/ptb.train.txt', word_idx, n, | 
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |                           data_type) | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  | def test(word_idx, n): |  |  |  | def test(word_idx, n, data_type=DataType.NGRAM): | 
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					|  |  |  |     """ |  |  |  |     """ | 
			
		
	
		
		
			
				
					
					|  |  |  |     imikolov test set creator. |  |  |  |     imikolov test set creator. | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
	
		
		
			
				
					|  |  | @ -114,12 +133,15 @@ def test(word_idx, n): | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  |     :param word_idx: word dictionary |  |  |  |     :param word_idx: word dictionary | 
			
		
	
		
		
			
				
					
					|  |  |  |     :type word_idx: dict |  |  |  |     :type word_idx: dict | 
			
		
	
		
		
			
				
					
					|  |  |  |     :param n: sliding window size |  |  |  |     :param n: sliding window size if type is ngram, otherwise max length of sequence | 
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					|  |  |  |     :type n: int |  |  |  |     :type n: int | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |     :param data_type: data type (ngram or sequence) | 
			
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |     :type data_type: member variable of DataType (NGRAM or SEQ) | 
			
		
	
		
		
			
				
					
					|  |  |  |     :return: Test reader creator |  |  |  |     :return: Test reader creator | 
			
		
	
		
		
			
				
					
					|  |  |  |     :rtype: callable |  |  |  |     :rtype: callable | 
			
		
	
		
		
			
				
					
					|  |  |  |     """ |  |  |  |     """ | 
			
		
	
		
		
			
				
					
					|  |  |  |     return reader_creator('./simple-examples/data/ptb.valid.txt', word_idx, n) |  |  |  |     return reader_creator('./simple-examples/data/ptb.valid.txt', word_idx, n, | 
			
				
				
			
		
	
		
		
	
		
		
			
				
					
					|  |  |  |  |  |  |  |                           data_type) | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  | 
 |  |  |  | 
 | 
			
		
	
		
		
			
				
					
					|  |  |  | def fetch(): |  |  |  | def fetch(): | 
			
		
	
	
		
		
			
				
					|  |  | 
 |