/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/utils/Util.h"

#include <stdio.h>

#include "hl_gpu.h"
#include "paddle/gserver/gradientmachines/GradientMachine.h"

#include "TrainerConfig.pb.h"

#include <stdlib.h>
#include <fstream>
#include <sstream>
#include "ParameterUpdater.h"

namespace paddle {
/**
 * @brief TrainerStats accumulates the number of samples processed and the
 * total cost during training.
 *
 * It keeps two statistics, 'AvgCost' and 'CurrentAvgCost'. 'AvgCost' is the
 * average cost over one pass (all mini-batches); 'CurrentAvgCost' is the
 * average cost over the most recent mini-batches, i.e. those seen since the
 * last call to resetCurrentStat().
 */
class TrainerStats {
public:
  /**
   * @brief reset all stats.
   *
   * Typically called before a new pass starts.
   */
  inline void reset() {
    numProcessed_ = 0;
    totalCost_ = .0;
    this->resetCurrentStat();
  }

  /**
   * @brief reset the current stat.
   *
   * 'current' means the most recent --log_period mini-batches.
   */
  inline void resetCurrentStat() {
    currentCost_ = .0;
    currentSamples_ = 0;
  }

  /**
   * @brief add the cost of one mini-batch to the stats.
   * @param numProcessed current mini-batch size
   * @param cost current mini-batch cost
   */
  inline void addCost(int64_t numProcessed, real cost) {
    this->numProcessed_ += numProcessed;
    this->totalCost_ += cost;
    this->currentSamples_ += numProcessed;
    this->currentCost_ += cost;
  }

  /**
   * @brief get the average cost over one pass (all processed mini-batches).
   * @return pass average cost
   */
  inline real getAvgCost() const {
    CHECK_NE(this->numProcessed_, 0);
    return this->totalCost_ / this->numProcessed_;
  }

  /**
   * @brief get the current average cost, computed over the mini-batches seen
   * since the last resetCurrentStat().
   * @return current average cost
   */
  inline real getCurrentAvgCost() const {
    CHECK_NE(this->currentSamples_, 0);
    return this->currentCost_ / this->currentSamples_;
  }
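  // Worked example (illustrative): after addCost(32, 16.0) and addCost(32, 8.0),
  // getAvgCost() returns (16.0 + 8.0) / 64 == 0.375, and getCurrentAvgCost()
  // returns the same value until resetCurrentStat() is called.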

  /**
   * @brief get the total number of processed samples.
   * @return number of processed samples
   */
  inline int64_t getNumProcessed() const { return this->numProcessed_; }

  /**
   * @brief same as addCost(), but simpler to invoke.
   * For example:
   *
   * @code{.cpp}
   * TrainerStats stat;
   * cost = neuralNetwork.forward(batchSize);
   * stat += {batchSize, cost};
   * @endcode
   *
   * @param p a pair of parameters; first is numProcessed, second is cost.
   * @return *this
   */
  inline TrainerStats& operator+=(const std::pair<int64_t, real>& p) {
    this->addCost(p.first, p.second);
    return *this;
  }

  /**
   * @brief TrainerStats constructor.
   *
   * Resets all stats when constructed.
   */
  inline TrainerStats() { this->reset(); }

  /**
   * @brief write the stats to an output stream.
   *
   * If there is no need to print the current cost, set withCurrentCost to
   * false.
   *
   * @param os output stream.
   * @param withCurrentCost whether to print the current cost.
   */
  void showStats(std::ostream& os, bool withCurrentCost = true) const {
    os << "samples=" << this->getNumProcessed()
       << " AvgCost=" << this->getAvgCost();
    if (withCurrentCost) {
      os << " CurrentCost=" << this->getCurrentAvgCost();
    }
  }

  /**
   * @brief get the stats as a std::string.
   * @param withCurrentCost whether to include the current cost.
   * @return stats string
   */
  std::string getStats(bool withCurrentCost = true) const {
    std::ostringstream os;
    this->showStats(os, withCurrentCost);
    return os.str();
  }
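  // Given the format written by showStats(), the returned string looks like
  // "samples=<N> AvgCost=<avg> CurrentCost=<cur>".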

private:
  int64_t numProcessed_;
  real totalCost_;
  real currentCost_;
  int64_t currentSamples_;
};

inline std::ostream& operator<<(std::ostream& os, const TrainerStats& stats) {
  stats.showStats(os);
  return os;
}

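// Illustrative sketch (not part of the original interface): a training loop
// can drive TrainerStats like this; `trainOneBatch`, `batchSize`, `numBatches`
// and `logPeriod` are hypothetical names used only for the example.
//
//   TrainerStats stats;
//   for (int64_t i = 0; i < numBatches; ++i) {
//     real cost = trainOneBatch(batchSize);   // hypothetical helper
//     stats += {batchSize, cost};             // same as stats.addCost(...)
//     if ((i + 1) % logPeriod == 0) {
//       LOG(INFO) << "Batch=" << (i + 1) << " " << stats;  // uses operator<<
//       stats.resetCurrentStat();
//     }
//   }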
/**
 * TrainerInternalConfig
 * general configuration for training
 */
struct TrainerInternalConfig {
  /**
   * @brief Create TrainerInternalConfig from GradientMachine::CreateMode and
   * command line arguments.
   * @param mode the GradientMachine create mode.
   * @return the created config.
   */
  static std::unique_ptr<TrainerInternalConfig> createFromMode(
      GradientMachine::CreateMode mode);

  /**
   * indicate whether the training is local.
   * if local, no parameter server is used.
   */
  bool local;

  /**
   * indicate whether training uses the GPU.
   */
  bool use_gpu;

  /**
   * number of trainers.
   */
  int trainer_count;

  /**
   * how frequently to show parameter stats.
   */
  int show_param_stats_period;

  /**
   * current trainer id.
   */
  int trainer_id;

  /**
   * how frequently to dump the log.
   */
  int log_period;

  /**
   * dot period.
   */
  int dot_period;

  /**
   * number of passes for training.
   */
  int num_passes;

  /**
   * use the old updater.
   */
  bool use_old_updater;

  /**
   * whether to load and save parameters in the pserver.
   */
  bool loadsave_parameters_in_pserver;

  /**
   * training mode.
   */
  GradientMachine::CreateMode mode;
};

}  //  namespace paddle
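// Illustrative sketch (not part of the original header): configuring local
// training; the CreateMode enumerator `kNormal` is an assumption and may
// differ in the actual GradientMachine definition.
//
//   auto config = paddle::TrainerInternalConfig::createFromMode(
//       paddle::GradientMachine::kNormal);
//   config->local = true;          // train without a parameter server
//   config->use_gpu = false;       // CPU training
//   config->trainer_count = 1;
//   config->log_period = 100;      // log stats every 100 mini-batches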