/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "LearningRateScheduler.h"
#include "paddle/utils/StringUtil.h"

namespace paddle {

ClassRegistrar<LearningRateScheduler, OptimizationConfig>
    LearningRateScheduler::registrar_;

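// Factory method: instantiates the scheduler registered under
// config.learning_rate_schedule() via REGISTER_LEARNING_RATE_SCHEDULER below.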
LearningRateScheduler* LearningRateScheduler::create(
    const OptimizationConfig& config) {
  return registrar_.createByType(config.learning_rate_schedule(), config);
}

// LRS stands for LearningRateScheduler

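// Common base for all schedulers below: learningRate_ = learning_rate,
// a_ = learning_rate_decay_a, b_ = learning_rate_decay_b (from OptimizationConfig).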
class BaseLRS : public LearningRateScheduler {
public:
  explicit BaseLRS(const OptimizationConfig& config)
      : learningRate_(config.learning_rate()),
        a_(config.learning_rate_decay_a()),
        b_(config.learning_rate_decay_b()) {}

protected:
  real learningRate_;
  real a_;
  real b_;
};

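// "constant": lr is fixed at learningRate_.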
class ConstLRS : public BaseLRS {
public:
  explicit ConstLRS(const OptimizationConfig& config) : BaseLRS(config) {}
  virtual real calcLearningRate(int64_t numSamplesProcessed, int64_t pass) {
    return learningRate_;
  }
};
REGISTER_LEARNING_RATE_SCHEDULER(constant, ConstLRS);

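// "poly": polynomial decay on the number of processed samples,
//   lr = learningRate_ * (1 + a_ * numSamplesProcessed)^(-b_).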
class PolyLRS : public BaseLRS {
public:
  explicit PolyLRS(const OptimizationConfig& config) : BaseLRS(config) {}
  virtual real calcLearningRate(int64_t numSamplesProcessed, int64_t pass) {
    return learningRate_ * pow(1.0 + a_ * numSamplesProcessed, -b_);
  }
};
REGISTER_LEARNING_RATE_SCHEDULER(poly, PolyLRS);

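// "caffe_poly": Caffe-style polynomial decay,
//   lr = learningRate_ * (1 - numSamplesProcessed / a_)^b_,
// which drops to zero once numSamplesProcessed exceeds a_.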
class CaffePolyLRS : public BaseLRS {
public:
  explicit CaffePolyLRS(const OptimizationConfig& config) : BaseLRS(config) {}
  virtual real calcLearningRate(int64_t numSamplesProcessed, int64_t pass) {
    if (numSamplesProcessed > a_) {
      LOG_FIRST_N(WARNING, 1)
          << "Using caffe_poly learning rate schedule, "
          << "learning rate hits ZERO when "
          << "numSamplesProcessed > config.learning_rate_decay_a(), "
          << "training is over and you can stop it. "
          << "See common/LearningRateScheduler.cpp for more info.";
      return 0;
    } else {
      return learningRate_ * pow(1.0 - numSamplesProcessed / a_, b_);
    }
  }
};
REGISTER_LEARNING_RATE_SCHEDULER(caffe_poly, CaffePolyLRS);

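// "exp": exponential decay, lr = learningRate_ * a_^(numSamplesProcessed / b_).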
class ExpLRS : public BaseLRS {
public:
  explicit ExpLRS(const OptimizationConfig& config) : BaseLRS(config) {}
  virtual real calcLearningRate(int64_t numSamplesProcessed, int64_t pass) {
    double decayRatio = (double)numSamplesProcessed / b_;
    return learningRate_ * pow(a_, decayRatio);
  }
};
REGISTER_LEARNING_RATE_SCHEDULER(exp, ExpLRS);

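// "discexp": discrete exponential decay, lr is scaled by a_ after every b_
// samples: lr = learningRate_ * a_^floor(numSamplesProcessed / b_).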
class DiscreteExpLRS : public BaseLRS {
public:
  explicit DiscreteExpLRS(const OptimizationConfig& config) : BaseLRS(config) {}
  virtual real calcLearningRate(int64_t numSamplesProcessed, int64_t pass) {
    int numDecays = floor(numSamplesProcessed / b_);
    return learningRate_ * pow(a_, numDecays);
  }
};
REGISTER_LEARNING_RATE_SCHEDULER(discexp, DiscreteExpLRS);

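// "linear": linear decay clamped below,
//   lr = max(learningRate_ - a_ * numSamplesProcessed, b_).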
class LinearLRS : public BaseLRS {
public:
  explicit LinearLRS(const OptimizationConfig& config) : BaseLRS(config) {}
  virtual real calcLearningRate(int64_t numSamplesProcessed, int64_t pass) {
    return std::max(learningRate_ - a_ * numSamplesProcessed, b_);
  }
};
REGISTER_LEARNING_RATE_SCHEDULER(linear, LinearLRS);

/*
  Specify the learning rate through
  learning_rate_args = 'seg0:rate0,seg1:rate1,...,segK:rateK';
  if seg_{i-1} < numSamples <= seg_i,
  then learning_rate = learning_rate_base * rate_i.
*/
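// Example (illustrative values): learning_rate_args = '1000:1.0,2000:0.5,3000:0.1'
// keeps lr at 1.0 * learningRate_ for the first 1000 samples, 0.5 * learningRate_
// up to 2000 samples, and 0.1 * learningRate_ afterwards.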
class ManualLRS : public BaseLRS {
public:
  explicit ManualLRS(const OptimizationConfig& config)
      : BaseLRS(config), currentSegment_(0), lastNum_(0) {
    std::vector<std::string> pieces;
    str::split(config.learning_rate_args(), ',', &pieces);
    rates_.reserve(pieces.size());

    for (auto& piece : pieces) {
      auto pos = piece.find(':');
      CHECK(pos != std::string::npos) << "Wrong format for learning_rate_args: "
                                      << config.learning_rate_args();
      segments_.push_back(str::to<int64_t>(piece.substr(0, pos)));
      rates_.push_back(str::to<real>(piece.substr(pos + 1)));
    }
  }

  virtual real calcLearningRate(int64_t numSamplesProcessed, int64_t pass) {
    return calc(numSamplesProcessed);
  }

  real calc(int64_t num) {
    // We assume that num never decreases.
    CHECK_LE(lastNum_, num);
    lastNum_ = num;
    while (currentSegment_ < rates_.size()) {
      if (num <= segments_[currentSegment_]) {
        return learningRate_ * rates_[currentSegment_];
      }
      ++currentSegment_;
      if (currentSegment_ < rates_.size()) {
        LOG(INFO) << " learning_rate changes to "
                  << learningRate_ * rates_[currentSegment_];
      }
    }
    return learningRate_ * rates_.back();
  }

protected:
  std::vector<real> rates_;
  std::vector<int64_t> segments_;
  size_t currentSegment_;
  int64_t lastNum_;
};

REGISTER_LEARNING_RATE_SCHEDULER(manual, ManualLRS);

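// "pass_manual": same segment lookup as "manual", but keyed by the training
// pass number instead of numSamplesProcessed.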
class PassManualLRS : public ManualLRS {
public:
  explicit PassManualLRS(const OptimizationConfig& config)
      : ManualLRS(config) {}
  virtual real calcLearningRate(int64_t numSamplesProcessed, int64_t pass) {
    return calc(pass);
  }
};

REGISTER_LEARNING_RATE_SCHEDULER(pass_manual, PassManualLRS);
}  // namespace paddle