You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							72 lines
						
					
					
						
							2.3 KiB
						
					
					
				
			
		
		
	
	
							72 lines
						
					
					
						
							2.3 KiB
						
					
					
				| # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
 | |
| #
 | |
| # Licensed under the Apache License, Version 2.0 (the "License");
 | |
| # you may not use this file except in compliance with the License.
 | |
| # You may obtain a copy of the License at
 | |
| #
 | |
| #     http://www.apache.org/licenses/LICENSE-2.0
 | |
| #
 | |
| # Unless required by applicable law or agreed to in writing, software
 | |
| # distributed under the License is distributed on an "AS IS" BASIS,
 | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| # See the License for the specific language governing permissions and
 | |
| # limitations under the License.
 | |
| 
 | |
| import os
 | |
| import sys
 | |
| 
 | |
| from paddle.trainer.PyDataProvider2 import *
 | |
| 
 | |
| 
 | |
def hook(settings, dict_file, **kwargs):
    """Init hook for the ``process`` provider.

    Stores ``dict_file`` on ``settings`` as ``word_dict`` and declares two
    input slots on ``settings.input_types``:

    * ``integer_value_sequence`` sized by the length of the dictionary;
    * ``integer_value(3)`` (presumably a 3-way class label -- confirm
      against the network configuration).

    Despite its name, ``dict_file`` must at least support ``len()``. Any
    extra keyword arguments are accepted and ignored. The dictionary length
    is also written to ``settings.logger`` at INFO level.
    """
    settings.word_dict = dict_file
    vocabulary_size = len(settings.word_dict)

    word_slot = integer_value_sequence(vocabulary_size)
    label_slot = integer_value(3)
    settings.input_types = [word_slot, label_slot]

    settings.logger.info('dict len : %d' % vocabulary_size)
 | |
| 
 | |
| 
 | |
@provider(init_hook=hook, should_shuffle=False)
def process(settings, file_name):
    """Yield one ``(word_ids, label)`` sample per non-blank line of a file.

    Each non-blank line is expected to hold a label and a comment separated
    by a TAB character. The label is converted with ``int()`` after all
    internal whitespace is removed; the comment is split on whitespace and
    every word found in ``settings.word_dict`` is replaced by its mapped
    value. Words missing from the dictionary are dropped.

    Args:
        settings: provider settings object; only ``settings.word_dict`` is
            read here.
        file_name: path of the text file to read.

    Yields:
        ``(words, label)`` where ``words`` is the list of dictionary values
        for the known words of the comment and ``label`` is an ``int``.

    Raises:
        ValueError: if a non-blank line has no TAB separator or its label
            is not an integer.
    """
    with open(file_name, 'r') as fdata:
        for line in fdata:
            line = line.strip()
            if not line:
                # A blank line carries no sample; unpacking it below would
                # raise, so skip it instead of aborting the whole pass.
                continue
            # Split on the first TAB only so that a TAB inside the comment
            # does not break the two-way unpack; comment.split() below
            # treats it as ordinary whitespace anyway.
            label, comment = line.split('\t', 1)
            label = int(''.join(label.split()))
            words = [
                settings.word_dict[w] for w in comment.split()
                if w in settings.word_dict
            ]
            yield words, label
 | |
| 
 | |
| 
 | |
| ## for hierarchical sequence network
 | |
def hook2(settings, dict_file, **kwargs):
    """Init hook for the ``process2`` (hierarchical sequence) provider.

    Stores ``dict_file`` on ``settings`` as ``word_dict`` and declares two
    input slots on ``settings.input_types``:

    * ``integer_value_sub_sequence`` sized by the length of the dictionary;
    * ``integer_value_sequence(3)`` (presumably one 3-way label per inner
      sequence -- confirm against the network configuration).

    Despite its name, ``dict_file`` must at least support ``len()``. Any
    extra keyword arguments are accepted and ignored. The dictionary length
    is also written to ``settings.logger`` at INFO level.
    """
    settings.word_dict = dict_file
    vocabulary_size = len(settings.word_dict)

    sentence_slot = integer_value_sub_sequence(vocabulary_size)
    label_slot = integer_value_sequence(3)
    settings.input_types = [sentence_slot, label_slot]

    settings.logger.info('dict len : %d' % vocabulary_size)
 | |
| 
 | |
| 
 | |
@provider(init_hook=hook2, should_shuffle=False)
def process2(settings, file_name):
    """Yield one ``(sentences, labels)`` sample per blank-line-separated group.

    The file is read as groups of consecutive non-blank lines; blank (or
    whitespace-only) lines separate the groups. Each non-blank line holds a
    label and a comment separated by a TAB character. The label is converted
    with ``int()`` after all internal whitespace is removed; the comment is
    split on whitespace and every word found in ``settings.word_dict`` is
    replaced by its mapped value. Words missing from the dictionary are
    dropped.

    Compared with a naive "yield on every blank line" loop, this version:

    * also yields the final group when the file does not end with a blank
      line, so the last group is never silently lost;
    * never yields an empty group for repeated or leading blank lines;
    * treats whitespace-only lines as separators rather than trying to
      parse them as samples.

    Args:
        settings: provider settings object; only ``settings.word_dict`` is
            read here.
        file_name: path of the text file to read.

    Yields:
        ``(sentences, labels)`` where ``sentences`` is a list of word-value
        lists (one per line of the group) and ``labels`` is the parallel
        list of ``int`` labels.

    Raises:
        ValueError: if a non-blank line has no TAB separator or its label
            is not an integer.
    """
    with open(file_name) as fdata:
        labels = []
        sentences = []
        for line in fdata:
            line = line.strip()
            if line:
                # Split on the first TAB only so that a TAB inside the
                # comment does not break the two-way unpack.
                label, comment = line.split('\t', 1)
                label = int(''.join(label.split()))
                words = [
                    settings.word_dict[w] for w in comment.split()
                    if w in settings.word_dict
                ]
                labels.append(label)
                sentences.append(words)
            elif sentences:
                # Separator reached: emit the group collected so far and
                # start fresh lists (the yielded ones belong to the caller).
                yield sentences, labels
                labels = []
                sentences = []
        if sentences:
            # No trailing blank line: flush the last group.
            yield sentences, labels
 |