| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				 | 
				
					@ -266,7 +266,10 @@ def train(avg_loss, infer_prog, optimizer, train_reader, test_reader, batch_acc,
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					            # FIXME(wuyi): For use_reader_op, if the current
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					            # pass is not the last, the last batch of this pass
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					            # is also equal to args.batch_size.
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					            num_samples += len(args.batch_size)
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					            if args.use_reader_op:
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                num_samples += args.batch_size
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					            else:
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                num_samples += len(data)
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					            train_losses.append(loss)
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					            print("Pass: %d, Iter: %d, Loss: %f\n" %
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                  (pass_id, iters, np.mean(train_losses)))
 | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
				 | 
				 | 
				
					@ -350,9 +353,6 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader,
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					            if iters == args.skip_batch_num:
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                start_time = time.time()
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                num_samples = 0
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					            # NOTE: if use reader ops, the input data is not splited to multiple cards
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					            if args.use_reader_op and iters >= args.iterations / args.gpus:
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                break
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					            if args.use_fake_data or args.use_reader_op:
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                try:
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                    loss, = exe.run([avg_loss.name])
 | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				 | 
				
					@ -362,7 +362,10 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader,
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                loss, = exe.run([avg_loss.name], feed=feeder.feed(data))
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					            if args.update_method == "pserver":
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                exe.bcast_params()
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					            num_samples += len(data)
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					            if args.use_reader_op:
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                num_samples += args.batch_size
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					            else:
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                num_samples += len(data)
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					            iters += 1
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					            if batch_id % 1 == 0:
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                print("Pass %d, batch %d, loss %s" %
 | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
				
				 | 
				 | 
				
					
 
 |