You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							150 lines
						
					
					
						
							4.1 KiB
						
					
					
				
			
		
		
	
	
							150 lines
						
					
					
						
							4.1 KiB
						
					
					
				| /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
 | |
| 
 | |
| Licensed under the Apache License, Version 2.0 (the "License");
 | |
| you may not use this file except in compliance with the License.
 | |
| You may obtain a copy of the License at
 | |
| 
 | |
|     http://www.apache.org/licenses/LICENSE-2.0
 | |
| 
 | |
| Unless required by applicable law or agreed to in writing, software
 | |
| distributed under the License is distributed on an "AS IS" BASIS,
 | |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| See the License for the specific language governing permissions and
 | |
| limitations under the License. */
 | |
| 
 | |
| #pragma once
 | |
| 
 | |
| #include "paddle/utils/Util.h"
 | |
| 
 | |
| #include <stdio.h>
 | |
| 
 | |
| #include "hl_gpu.h"
 | |
| #include "paddle/gserver/dataproviders/DataProvider.h"
 | |
| #include "paddle/gserver/gradientmachines/GradientMachine.h"
 | |
| 
 | |
| #include "TrainerConfig.pb.h"
 | |
| 
 | |
| #include <stdlib.h>
 | |
| #include <fstream>
 | |
| #include "ParamUtil.h"
 | |
| #include "ParameterUpdater.h"
 | |
| #include "TesterConfig.h"
 | |
| #include "TrainerInternalConfig.h"
 | |
| 
 | |
| namespace paddle {
 | |
| 
 | |
| /**
 | |
|  * Neural Network test logics code.
 | |
|  * It is a private class for Trainer.
 | |
|  */
 | |
| class Tester {
 | |
| public:
 | |
|   /**
 | |
|    * Ctor
 | |
|    * @param config Trainer Config.
 | |
|    * @param intconfig Tester Config.
 | |
|    * @param gradientMachine Gradient machine(neuralnetwork) that will be tested.
 | |
|    * @param parameterUpdater Parameter Updater. Not for updating parameter, just
 | |
|    *                         for getting parameter from parameter-server.
 | |
|    * @param testDataProvider Test data provider.
 | |
|    */
 | |
|   Tester(const std::shared_ptr<TrainerConfigHelper>& config,
 | |
|          std::unique_ptr<TesterConfig>&& intconfig,
 | |
|          const GradientMachinePtr& gradientMachine,
 | |
|          const std::shared_ptr<ParameterUpdater>& parameterUpdater,
 | |
|          std::shared_ptr<DataProvider> testDataProvider);
 | |
| 
 | |
|   /**
 | |
|    * test one period.
 | |
|    *
 | |
|    * One period means 2 things.
 | |
|    *   if test_period !=0 and not test_all_data_in_one_period, then
 | |
|    *      will test test_period * batch_size data.
 | |
|    *   else
 | |
|    *      will test whole test data.
 | |
|    *
 | |
|    * It is convenience to test small set of data when test data set is large and
 | |
|    * is training at same time.
 | |
|    */
 | |
|   void testOnePeriod();
 | |
|   void startTestPeriod();
 | |
|   void finishTestPeriod();
 | |
|   void testOneDataBatch(const DataBatch& dataBatch,
 | |
|                         std::vector<Argument>* outArgs);
 | |
| 
 | |
|   /**
 | |
|    * Test for given data batch.
 | |
|    * @param dataBatch Data batch.
 | |
|    * @param evaluator Evaluator
 | |
|    * @return cost
 | |
|    */
 | |
|   real forwardOneBatch(const DataBatch& dataBatch,
 | |
|                        Evaluator* evaluator,
 | |
|                        std::vector<Argument>* outArgs);
 | |
| 
 | |
|   /**
 | |
|    * performance the full pass of test given test data provider
 | |
|    */
 | |
|   void test();
 | |
| 
 | |
| protected:
 | |
|   std::shared_ptr<ParameterClient2> testParameterClient_;
 | |
|   std::shared_ptr<TrainerConfigHelper> config_;
 | |
|   std::unique_ptr<TesterConfig> intconfig_;
 | |
|   GradientMachinePtr gradientMachine_;
 | |
|   std::shared_ptr<ParameterUpdater> parameterUpdater_;
 | |
|   std::unique_ptr<Evaluator> testEvaluator_;
 | |
|   std::unique_ptr<ParameterUtil> paramUtil_;
 | |
|   DataProviderPtr testDataProvider_;
 | |
|   TrainerStats stats_;
 | |
| 
 | |
|   // Used for saving the values of output layers
 | |
|   std::ofstream os_;
 | |
|   std::vector<MatrixPtr> cpuMat_;
 | |
|   std::vector<IVectorPtr> cpuVec_;
 | |
|   struct {
 | |
|     int64_t numSamples;
 | |
|     real cost;
 | |
|   } testContext_;
 | |
| 
 | |
| private:
 | |
|   /**
 | |
|    * Test one batch by batchId. It is only used for testOnePass.
 | |
|    *
 | |
|    * Durning testOnePass, each log_period will print cost statistics.
 | |
|    *
 | |
|    * @param batchId current batch id (from 0)
 | |
|    * @return num of tested samples. Zero if end of pass.
 | |
|    */
 | |
|   int64_t testOneBatchById(int64_t batchId);
 | |
| 
 | |
|   /**
 | |
|    * Test whole pass in one batch.
 | |
|    *
 | |
|    *
 | |
|    * @param passId current pass id (from 0)
 | |
|    */
 | |
|   void testOnePassBatch(int passId);
 | |
| 
 | |
|   /**
 | |
|    * test for one pass in several mini-batches.
 | |
|    *
 | |
|    * Used for sgd method.
 | |
|    *
 | |
|    * @param passId current pass id (from 0)
 | |
|    */
 | |
|   void testOnePass(int passId);
 | |
| 
 | |
|   /**
 | |
|    * print the outArgs to a stream
 | |
|    *
 | |
|    * used for save feature file
 | |
|    *
 | |
|    * @param [in] outArgs output arguments for network.
 | |
|    * @param [in,out] os output stream.
 | |
|    */
 | |
|   void printOutput(const std::vector<Argument>& outArgs, std::ostream& os);
 | |
| };
 | |
| 
 | |
| }  //  namespace paddle
 |