You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							126 lines
						
					
					
						
							3.8 KiB
						
					
					
				
			
		
		
	
	
							126 lines
						
					
					
						
							3.8 KiB
						
					
					
				| /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
 | |
| 
 | |
| Licensed under the Apache License, Version 2.0 (the "License");
 | |
| you may not use this file except in compliance with the License.
 | |
| You may obtain a copy of the License at
 | |
| 
 | |
|     http://www.apache.org/licenses/LICENSE-2.0
 | |
| 
 | |
| Unless required by applicable law or agreed to in writing, software
 | |
| distributed under the License is distributed on an "AS IS" BASIS,
 | |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| See the License for the specific language governing permissions and
 | |
| limitations under the License. */
 | |
| 
 | |
| #pragma once
 | |
| 
 | |
| #include "paddle/utils/Util.h"
 | |
| 
 | |
| #include <stdio.h>
 | |
| 
 | |
| #include "hl_gpu.h"
 | |
| #include "paddle/gserver/dataproviders/DataProvider.h"
 | |
| #include "paddle/gserver/gradientmachines/GradientMachine.h"
 | |
| 
 | |
| #include <stdlib.h>
 | |
| #include <fstream>
 | |
| #include "ParameterUpdater.h"
 | |
| #include "TrainerConfig.pb.h"
 | |
| #include "TrainerConfigHelper.h"
 | |
| 
 | |
| namespace paddle {
 | |
| 
 | |
| /**
 | |
|  * Configuration for parameter utils.
 | |
|  */
 | |
| struct ParameterUtilConfig {
 | |
|   DISABLE_COPY(ParameterUtilConfig);
 | |
| 
 | |
|   ParameterUtilConfig(bool save_only_one,
 | |
|                       int saving_period,
 | |
|                       bool load_save_parameters_in_pserver,
 | |
|                       std::string config)
 | |
|       : save_only_one_(save_only_one),
 | |
|         saving_period_(saving_period),
 | |
|         load_save_param_pserver_(load_save_parameters_in_pserver),
 | |
|         config_(config) {}
 | |
| 
 | |
|   bool save_only_one_;
 | |
|   int saving_period_;
 | |
|   bool load_save_param_pserver_;
 | |
|   std::string config_;
 | |
| };
 | |
| 
 | |
| /**
 | |
|  * ParameterUtil
 | |
|  * Utility class for loading and saving parameters
 | |
|  */
 | |
| class ParameterUtil {
 | |
| public:
 | |
|   /**
 | |
|    * Ctor.
 | |
|    *
 | |
|    * @param config
 | |
|    * @param intconfig
 | |
|    * @param gradientMachine
 | |
|    * @param parameterUpdater
 | |
|    * @return
 | |
|    */
 | |
|   ParameterUtil(const std::shared_ptr<TrainerConfigHelper> &config,
 | |
|                 std::unique_ptr<ParameterUtilConfig> &&intconfig,
 | |
|                 const GradientMachinePtr &gradientMachine,
 | |
|                 const std::shared_ptr<ParameterUpdater> ¶meterUpdater);
 | |
| 
 | |
|   /// Load parameter from the saved parameter file as pass passId
 | |
|   /// if loadsave_parameters_in_pserver is set, some parameters MUST
 | |
|   /// load in pserver, which is "remote".
 | |
|   /// loadParameters can choose to load local/remote parameter, or both.
 | |
|   bool loadParameters(int passId, bool local = true, bool remote = false);
 | |
| 
 | |
|   /// load parameters given path info
 | |
|   void loadParametersWithPath(const std::string &dir,
 | |
|                               bool local = true,
 | |
|                               bool remote = false);
 | |
| 
 | |
|   /// Save parameter to dist for pass passId
 | |
|   /// passInnerId means saving times in one pass, some users want to
 | |
|   /// save parameters when have processed some batches in one pass
 | |
|   /// passInnerId = 0 means do not need to save in one inner pass
 | |
|   void saveParameters(int passId, int passInnerId = 0);
 | |
| 
 | |
|   /// save parameters for one pass, when passInnerId > 0 means saving
 | |
|   /// the passInnerId times in one pass
 | |
|   void saveParametersOnePass(int passId, int passInnerId = 0);
 | |
| 
 | |
|   /// delete parameter from disk via passId
 | |
|   void deleteParameters(int passId, int passInnerId = 0);
 | |
| 
 | |
|   /// save config given path info
 | |
|   void saveConfigWithPath(const std::string &path);
 | |
| 
 | |
|   /**
 | |
|    * Try to load parameter from config.
 | |
|    * @return true if can load from trainer config.
 | |
|    */
 | |
|   inline bool tryLoadParametersFromConfig() {
 | |
|     auto &c = config_->getConfig();
 | |
|     if (!c.init_model_path().empty()) {
 | |
|       loadParametersWithPath(c.init_model_path());
 | |
|       return true;
 | |
|     } else if (c.start_pass() > 0) {
 | |
|       CHECK(loadParameters(c.start_pass() - 1));
 | |
|       return true;
 | |
|     } else {
 | |
|       return false;
 | |
|     }
 | |
|   }
 | |
| 
 | |
| private:
 | |
|   std::shared_ptr<TrainerConfigHelper> config_;
 | |
|   std::unique_ptr<ParameterUtilConfig> intConfig_;
 | |
|   GradientMachinePtr gserver_;
 | |
|   std::shared_ptr<ParameterUpdater> pUpdater_;
 | |
| };
 | |
| 
 | |
| }  //  namespace paddle
 |