You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							150 lines
						
					
					
						
							4.5 KiB
						
					
					
				
			
		
		
	
	
							150 lines
						
					
					
						
							4.5 KiB
						
					
					
				| /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
 | |
| 
 | |
| Licensed under the Apache License, Version 2.0 (the "License");
 | |
| you may not use this file except in compliance with the License.
 | |
| You may obtain a copy of the License at
 | |
| 
 | |
|     http://www.apache.org/licenses/LICENSE-2.0
 | |
| 
 | |
| Unless required by applicable law or agreed to in writing, software
 | |
| distributed under the License is distributed on an "AS IS" BASIS,
 | |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| See the License for the specific language governing permissions and
 | |
| limitations under the License. */
 | |
| 
 | |
| #pragma once
 | |
#include <cstdint>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>

#include "GradientMachine.h"
 | |
| 
 | |
| namespace paddle {
 | |
| 
 | |
| class IGradientMachineMode {
 | |
| public:
 | |
|   virtual ~IGradientMachineMode() {}
 | |
| 
 | |
| public:  // interfaces
 | |
|          /**
 | |
|           * @brief create current mode's gradient machine by model config.
 | |
|           * @param config model config
 | |
|           */
 | |
|   virtual GradientMachine* create(const ModelConfig& config) = 0;
 | |
| 
 | |
|   /**
 | |
|    * @brief shouldBeMe the current mode of GradientMachine should be this mode.
 | |
|    * @param algo training algorithm name.
 | |
|    * @param trainerCount trainer count.
 | |
|    * @param isLocal is local mode (without pserver)
 | |
|    * @param isGpu is using gpu.
 | |
|    * @return true if mode should be this mode.
 | |
|    */
 | |
|   virtual bool shouldBeMe(const std::string& algo,
 | |
|                           size_t trainerCount,
 | |
|                           bool isLocal,
 | |
|                           bool isGpu) const = 0;
 | |
| 
 | |
|   /**
 | |
|    * @brief Is data must be in cpu even if using gpu mode.
 | |
|    * @param trainerCount trainer count
 | |
|    * @return true if data must be gpu.
 | |
|    */
 | |
|   virtual bool isDataMustInCpu(size_t trainerCount) const = 0;
 | |
| 
 | |
|   /**
 | |
|    * @brief Need not to use mini-batch method, and should train all data in one
 | |
|    * batch in one pass.
 | |
|    */
 | |
|   virtual bool needTrainWholeDataInOneBatch() const = 0;
 | |
| 
 | |
| public:  // static methods.
 | |
|          /**
 | |
|           * @brief register a custom gradient machine mode.
 | |
|           * @note For user to register a custom gradient machine mode, id should >=
 | |
|           * kCustom.
 | |
|           * @param mode mode id.
 | |
|           * @param ptr mode description object.
 | |
|           */
 | |
|   static void regGradientMachineMode(
 | |
|       int32_t mode, std::unique_ptr<IGradientMachineMode>&& ptr) {
 | |
|     modes_.insert(std::make_pair(mode, std::move(ptr)));
 | |
|   }
 | |
| 
 | |
|   /**
 | |
|    * @brief get custom mode from mode id.
 | |
|    * @param mode mode id
 | |
|    * @return mode description object.
 | |
|    */
 | |
|   static IGradientMachineMode* mode(int32_t mode) {
 | |
|     if (modes_.find(mode) != modes_.end()) {
 | |
|       return modes_[mode].get();
 | |
|     } else {
 | |
|       return nullptr;
 | |
|     }
 | |
|   }
 | |
| 
 | |
|   /**
 | |
|    * @brief helper function to test trainWholeDataInOneBatch or not for mode
 | |
|    */
 | |
|   static bool trainWholeDataInOneBatch(int32_t mode) {
 | |
|     if (modes_.find(mode) != modes_.end()) {
 | |
|       return modes_[mode]->needTrainWholeDataInOneBatch();
 | |
|     } else {
 | |
|       return false;
 | |
|     }
 | |
|   }
 | |
| 
 | |
|   /**
 | |
|    * @brief Try to get custom mode if we can.
 | |
|    * @param [out] mode the custom mode id.
 | |
|    * @param [in] algo algorithm name
 | |
|    * @param [in] trainerCount trainer count.
 | |
|    * @param [in] isLocal is local or not
 | |
|    * @param [in] isGpu using gpu or not.
 | |
|    * @return true if there is a custom mode fit these conditions.
 | |
|    */
 | |
|   static bool tryGetMode(int* mode,
 | |
|                          const std::string& algo,
 | |
|                          int32_t trainerCount,
 | |
|                          bool isLocal,
 | |
|                          bool isGpu) {
 | |
|     for (auto it = modes_.begin(); it != modes_.end(); ++it) {
 | |
|       if (it->second->shouldBeMe(algo, trainerCount, isLocal, isGpu)) {
 | |
|         *mode = it->first;
 | |
|         return true;
 | |
|       }
 | |
|     }
 | |
|     return false;
 | |
|   }
 | |
| 
 | |
|   /**
 | |
|    * @brief helper function for data must in cpu
 | |
|    */
 | |
|   static bool dataMustInCpu(int32_t mode, size_t trainerCount) {
 | |
|     if (modes_.find(mode) != modes_.end()) {
 | |
|       return modes_[mode]->isDataMustInCpu(trainerCount);
 | |
|     } else {
 | |
|       // provide data to cpu if using synchronized multi-gpu gradient machine.
 | |
|       return trainerCount > 1;
 | |
|     }
 | |
|   }
 | |
| 
 | |
|   /**
 | |
|    * @brief try to create gradient machine by mode & config.
 | |
|    * @return nullptr if we cannot create a gradient machine by such mode.
 | |
|    */
 | |
|   static GradientMachine* tryCreateGradientMachine(int32_t mode,
 | |
|                                                    const ModelConfig& config) {
 | |
|     auto m = IGradientMachineMode::mode(mode);
 | |
|     if (m) {
 | |
|       return m->create(config);
 | |
|     } else {
 | |
|       return nullptr;
 | |
|     }
 | |
|   }
 | |
| 
 | |
| private:
 | |
|   static std::unordered_map<int32_t, std::unique_ptr<IGradientMachineMode>>
 | |
|       modes_;
 | |
| };
 | |
| 
 | |
| }  // namespace paddle
 |