|  |  |  | @ -27,19 +27,13 @@ class ITrainer(object): | 
			
		
	
		
			
				
					|  |  |  |  |     The interface of Trainer. The only exposed method is `train`. | 
			
		
	
		
			
				
					|  |  |  |  |     """ | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |     def train(self, | 
			
		
	
		
			
				
					|  |  |  |  |               train_data_reader, | 
			
		
	
		
			
				
					|  |  |  |  |               cost, | 
			
		
	
		
			
				
					|  |  |  |  |               parameters, | 
			
		
	
		
			
				
					|  |  |  |  |               test_data_reader=None, | 
			
		
	
		
			
				
					|  |  |  |  |               event_handler=None): | 
			
		
	
		
			
				
					|  |  |  |  |     def train(self, reader, topology, parameters, event_handler=None): | 
			
		
	
		
			
				
					|  |  |  |  |         """ | 
			
		
	
		
			
				
					|  |  |  |  |         train method. | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |         :param train_data_reader: | 
			
		
	
		
			
				
					|  |  |  |  |         :param cost: | 
			
		
	
		
			
				
					|  |  |  |  |         :param reader: | 
			
		
	
		
			
				
					|  |  |  |  |         :param topology: | 
			
		
	
		
			
				
					|  |  |  |  |         :param parameters: | 
			
		
	
		
			
				
					|  |  |  |  |         :param test_data_reader: | 
			
		
	
		
			
				
					|  |  |  |  |         :param event_handler: | 
			
		
	
		
			
				
					|  |  |  |  |         :return: | 
			
		
	
		
			
				
					|  |  |  |  |         """ | 
			
		
	
	
		
			
				
					|  |  |  | @ -61,26 +55,22 @@ class SGD(ITrainer): | 
			
		
	
		
			
				
					|  |  |  |  |         self.__optimizer__ = update_equation | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |     def train(self, | 
			
		
	
		
			
				
					|  |  |  |  |               train_data_reader, | 
			
		
	
		
			
				
					|  |  |  |  |               reader, | 
			
		
	
		
			
				
					|  |  |  |  |               cost, | 
			
		
	
		
			
				
					|  |  |  |  |               parameters, | 
			
		
	
		
			
				
					|  |  |  |  |               num_passes=1, | 
			
		
	
		
			
				
					|  |  |  |  |               test_data_reader=None, | 
			
		
	
		
			
				
					|  |  |  |  |               event_handler=None, | 
			
		
	
		
			
				
					|  |  |  |  |               batch_size=32, | 
			
		
	
		
			
				
					|  |  |  |  |               reader_dict=None): | 
			
		
	
		
			
				
					|  |  |  |  |         """ | 
			
		
	
		
			
				
					|  |  |  |  |         Training method. Will train num_passes of input data. | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |         :param train_data_reader: | 
			
		
	
		
			
				
					|  |  |  |  |         :param cost: cost layers, to be optimized. | 
			
		
	
		
			
				
					|  |  |  |  |         :param reader: | 
			
		
	
		
			
				
					|  |  |  |  |         :param topology: Network Topology, use one or more Layers to represent it. | 
			
		
	
		
			
				
					|  |  |  |  |         :param parameters: The parameter pools. | 
			
		
	
		
			
				
					|  |  |  |  |         :param num_passes: The total train passes. | 
			
		
	
		
			
				
					|  |  |  |  |         :param test_data_reader: | 
			
		
	
		
			
				
					|  |  |  |  |         :param event_handler: Event handler. A method will be invoked when event | 
			
		
	
		
			
				
					|  |  |  |  |                               occurred. | 
			
		
	
		
			
				
					|  |  |  |  |         :type event_handler: (BaseEvent) => None | 
			
		
	
		
			
				
					|  |  |  |  |         :param batch_size: Not important, will be removed after data refactor. | 
			
		
	
		
			
				
					|  |  |  |  |         :return: | 
			
		
	
		
			
				
					|  |  |  |  |         """ | 
			
		
	
		
			
				
					|  |  |  |  |         if event_handler is None: | 
			
		
	
	
		
			
				
					|  |  |  | @ -112,9 +102,9 @@ class SGD(ITrainer): | 
			
		
	
		
			
				
					|  |  |  |  |             event_handler(v2_event.BeginPass(pass_id)) | 
			
		
	
		
			
				
					|  |  |  |  |             pass_evaluator.start() | 
			
		
	
		
			
				
					|  |  |  |  |             updater.startPass() | 
			
		
	
		
			
				
					|  |  |  |  |             for batch_id, data_batch in enumerate( | 
			
		
	
		
			
				
					|  |  |  |  |                     __data_reader_to_batch__(train_data_reader, batch_size, | 
			
		
	
		
			
				
					|  |  |  |  |                                              topology)): | 
			
		
	
		
			
				
					|  |  |  |  |             for batch_id, data_batch in enumerate(reader()): | 
			
		
	
		
			
				
					|  |  |  |  |                 pass_type = updater.startBatch(len(data_batch)) | 
			
		
	
		
			
				
					|  |  |  |  |                 gm.forwardBackward(feeder(data_batch), out_args, pass_type) | 
			
		
	
		
			
				
					|  |  |  |  |                 batch_evaluator.start() | 
			
		
	
		
			
				
					|  |  |  |  |                 event_handler( | 
			
		
	
		
			
				
					|  |  |  |  |                     v2_event.BeginIteration( | 
			
		
	
	
		
			
				
					|  |  |  | @ -144,56 +134,19 @@ class SGD(ITrainer): | 
			
		
	
		
			
				
					|  |  |  |  |         gm.finish() | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | def __data_reader_to_batch__(reader, batch_size, topology): | 
			
		
	
		
			
				
					|  |  |  |  |     """ | 
			
		
	
		
			
				
					|  |  |  |  |     This function is not important, and will be removed when data refactored. | 
			
		
	
		
			
				
					|  |  |  |  |     """ | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |     def input_reorder(func): | 
			
		
	
		
			
				
					|  |  |  |  |         for item in func(): | 
			
		
	
		
			
				
					|  |  |  |  |             retv = [] | 
			
		
	
		
			
				
					|  |  |  |  |             for __layer_name__ in topology.proto().input_layer_names: | 
			
		
	
		
			
				
					|  |  |  |  |                 retv.append(item[__layer_name__]) | 
			
		
	
		
			
				
					|  |  |  |  |             yield retv | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |     return __generator_to_batch__(input_reorder(reader), batch_size=batch_size) | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | def __generator_to_batch__(generator, batch_size): | 
			
		
	
		
			
				
					|  |  |  |  |     """ | 
			
		
	
		
			
				
					|  |  |  |  |     This function is not important, and will be removed when data refactored. | 
			
		
	
		
			
				
					|  |  |  |  |     """ | 
			
		
	
		
			
				
					|  |  |  |  |     ret_val = list() | 
			
		
	
		
			
				
					|  |  |  |  |     for each_item in generator: | 
			
		
	
		
			
				
					|  |  |  |  |         ret_val.append(each_item) | 
			
		
	
		
			
				
					|  |  |  |  |         if len(ret_val) == batch_size: | 
			
		
	
		
			
				
					|  |  |  |  |             yield ret_val | 
			
		
	
		
			
				
					|  |  |  |  |             ret_val = list() | 
			
		
	
		
			
				
					|  |  |  |  |     if len(ret_val) != 0: | 
			
		
	
		
			
				
					|  |  |  |  |         yield ret_val | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | def __check_train_args__(train_data_reader, topology, parameters, | 
			
		
	
		
			
				
					|  |  |  |  |                          test_data_reader, event_handler, **kwargs): | 
			
		
	
		
			
				
					|  |  |  |  | def __check_train_args__(reader, topology, parameters, event_handler, **kwargs): | 
			
		
	
		
			
				
					|  |  |  |  |     """ | 
			
		
	
		
			
				
					|  |  |  |  |     Check train function's argument types | 
			
		
	
		
			
				
					|  |  |  |  |     """ | 
			
		
	
		
			
				
					|  |  |  |  |     if not callable(train_data_reader) or not isinstance(train_data_reader(), | 
			
		
	
		
			
				
					|  |  |  |  |                                                          collections.Iterator): | 
			
		
	
		
			
				
					|  |  |  |  |         raise ValueError('train_data_reader should be a function, ' | 
			
		
	
		
			
				
					|  |  |  |  |                          'which can return a iterator') | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |     if test_data_reader is not None: | 
			
		
	
		
			
				
					|  |  |  |  |         if not callable(test_data_reader) or not isinstance( | 
			
		
	
		
			
				
					|  |  |  |  |                 test_data_reader(), collections.Iterator): | 
			
		
	
		
			
				
					|  |  |  |  |             raise ValueError('test_data_reader should be a function, which can ' | 
			
		
	
		
			
				
					|  |  |  |  |                              'return a iterator') | 
			
		
	
		
			
				
					|  |  |  |  |     if not callable(reader) or not isinstance(reader(), collections.Iterator): | 
			
		
	
		
			
				
					|  |  |  |  |         raise TypeError('train_data_reader should be a function, ' | 
			
		
	
		
			
				
					|  |  |  |  |                         'which can return a iterator') | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |     if not isinstance(topology, Topology): | 
			
		
	
		
			
				
					|  |  |  |  |         raise ValueError('topology should be a model config') | 
			
		
	
		
			
				
					|  |  |  |  |         raise TypeError('topology should be a model config') | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |     if not isinstance(parameters, v2_parameters.Parameters): | 
			
		
	
		
			
				
					|  |  |  |  |         raise ValueError('parameters should be a parameter pool') | 
			
		
	
		
			
				
					|  |  |  |  |         raise TypeError('parameters should be a parameter pool') | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |     if not callable(event_handler): | 
			
		
	
		
			
				
					|  |  |  |  |         raise ValueError('event handler should be a function') | 
			
		
	
		
			
				
					|  |  |  |  |         raise TypeError('event handler should be a function') | 
			
		
	
	
		
			
				
					|  |  |  | 
 |