parent e83950b0d2
commit e63f1e6952
File diff suppressed because it is too large
@@ -0,0 +1,226 @@
/**
 * TensorApply.h
 *
 * Author: hedaoyuan (hedaoyuan@baidu.com)
 * Created on: 2016-06-06
 *
 * Copyright (c) Baidu.com, Inc. All Rights Reserved
 *
 */

#pragma once

namespace paddle {

/**
 * \brief The tensor evaluator classes.
 */
template<typename Derived, class T>
class TensorApply {
public:
  explicit INLINE TensorApply(const Derived& p)
      : data_(p.data_), stride_(p.stride_),
        height_(p.height_), width_(p.width_), useGpu_(p.useGpu_) {}

  INLINE T apply(int i, int j) const {
    return data_[i * stride_ + j];
  }
  INLINE T apply(int index) const {
    return data_[index];
  }
  INLINE T& applyRef(int i, int j) {
    return data_[i * stride_ + j];
  }
  INLINE T& applyRef(int index) {
    return data_[index];
  }

  INLINE size_t getWidth() const { return width_; }
  INLINE size_t getHeight() const { return height_; }
  INLINE bool isContiguous() const { return stride_ == width_ || height_ == 1; }
  INLINE bool useGpu() const { return useGpu_; }

  T* data_;
  size_t stride_;
  size_t height_;
  size_t width_;
  bool useGpu_;
};

/**
 * \brief The tensor evaluator classes.
 *
 * evaluator for rvalues
 */
template<typename Derived, class T>
class TensorApply<const Derived, T> {
public:
  explicit INLINE TensorApply(const Derived& p)
      : data_(p.data_), stride_(p.stride_),
        height_(p.height_), width_(p.width_), useGpu_(p.useGpu_) {}

  INLINE T apply(int i, int j) const {
    return data_[i * stride_ + j];
  }
  INLINE T apply(int index) const {
    return data_[index];
  }

  INLINE size_t getWidth() const { return width_; }
  INLINE size_t getHeight() const { return height_; }
  INLINE bool isContiguous() const { return stride_ == width_ || height_ == 1; }
  INLINE bool useGpu() const { return useGpu_; }

  const T* data_;
  size_t stride_;
  size_t height_;
  size_t width_;
  bool useGpu_;
};

template<typename Derived, class T>
class TensorApply<const TensorExpression<Derived, T>, T> {
public:
  explicit TensorApply(const TensorExpression<Derived, T>& expr)
      : expr_(expr.derived()) {}

  INLINE T apply(int i, int j) const {
    return expr_.apply(i, j);
  }
  INLINE T apply(int index) const {
    return expr_.apply(index);
  }

  INLINE size_t getWidth() const { return expr_.getWidth(); }
  INLINE size_t getHeight() const { return expr_.getHeight(); }
  INLINE bool isContiguous() const { return expr_.isContiguous(); }
  INLINE bool useGpu() const { return expr_.useGpu(); }

  TensorApply<const Derived, T> expr_;
};

/**
 * \brief The unary expression evaluator classes.
 */
template<class OP, typename ArgType, class T>
class TensorApply<const TensorUnaryOp<OP, ArgType, T>, T> {
public:
  explicit INLINE TensorApply(const TensorUnaryOp<OP, ArgType, T>& expr)
      : op_(expr.op_), expr_(expr.expr_) {}

  INLINE T apply(int i, int j) const {
    return op_(expr_.apply(i, j));
  }
  INLINE T apply(int index) const {
    return op_(expr_.apply(index));
  }

  INLINE size_t getWidth() const { return expr_.getWidth(); }
  INLINE size_t getHeight() const { return expr_.getHeight(); }
  INLINE bool isContiguous() const { return expr_.isContiguous(); }
  INLINE bool useGpu() const { return expr_.useGpu(); }

  const OP op_;
  TensorApply<ArgType, T> expr_;
};

/**
 * \brief The binary expression evaluator classes.
 */
template<class OP, typename LhsType, typename RhsType, class T>
class TensorApply<const TensorBinaryOp<OP, LhsType, RhsType, T>, T> {
public:
  explicit INLINE TensorApply(
      const TensorBinaryOp<OP, LhsType, RhsType, T>& expr)
      : op_(expr.op_), lhs_(expr.lhs_), rhs_(expr.rhs_) {
#ifndef __CUDA_ARCH__
    CHECK_EQ(lhs_.getWidth(), rhs_.getWidth());
    CHECK_EQ(lhs_.getHeight(), rhs_.getHeight());
    CHECK_EQ(lhs_.useGpu(), rhs_.useGpu());
#endif
  }

  INLINE T apply(int i, int j) const {
    return op_(lhs_.apply(i, j), rhs_.apply(i, j));
  }
  INLINE T apply(int index) const {
    return op_(lhs_.apply(index), rhs_.apply(index));
  }

  INLINE size_t getWidth() const { return lhs_.getWidth(); }
  INLINE size_t getHeight() const { return rhs_.getHeight(); }
  INLINE bool isContiguous() const {
    return lhs_.isContiguous() && rhs_.isContiguous();
  }
  INLINE bool useGpu() const { return lhs_.useGpu(); }

  const OP op_;
  TensorApply<LhsType, T> lhs_;
  TensorApply<RhsType, T> rhs_;
};

/**
 * \brief The ternary expression evaluator classes.
 */
template<typename ArgType1, typename ArgType2, typename ArgType3, class T>
class TensorApply<const TensorTernaryOp<ArgType1, ArgType2, ArgType3, T>, T> {
public:
  explicit INLINE TensorApply(
      const TensorTernaryOp<ArgType1, ArgType2, ArgType3, T>& expr)
      : expr1_(expr.expr1_), expr2_(expr.expr2_), expr3_(expr.expr3_) {
#ifndef __CUDA_ARCH__
    CHECK_EQ(expr1_.getWidth(), expr2_.getWidth());
    CHECK_EQ(expr1_.getWidth(), expr3_.getWidth());
    CHECK_EQ(expr1_.getHeight(), expr2_.getHeight());
    CHECK_EQ(expr1_.getHeight(), expr3_.getHeight());
    CHECK_EQ(expr1_.useGpu(), expr2_.useGpu());
    CHECK_EQ(expr1_.useGpu(), expr3_.useGpu());
#endif
  }

  INLINE T apply(int i, int j) const {
    return expr1_.apply(i, j) ? expr2_.apply(i, j) : expr3_.apply(i, j);
  }
  INLINE T apply(int index) const {
    return expr1_.apply(index) ? expr2_.apply(index) : expr3_.apply(index);
  }

  INLINE size_t getWidth() const { return expr1_.getWidth(); }
  INLINE size_t getHeight() const { return expr1_.getHeight(); }
  INLINE bool isContiguous() const {
    return expr1_.isContiguous() &&
           expr2_.isContiguous() && expr3_.isContiguous();
  }
  INLINE bool useGpu() const { return expr1_.useGpu(); }

  TensorApply<ArgType1, T> expr1_;
  TensorApply<ArgType2, T> expr2_;
  TensorApply<ArgType3, T> expr3_;
};

/**
 * \brief The const expression evaluator classes.
 */
template<class OP, typename ArgType, class T>
class TensorApply<const TensorConstant<OP, ArgType, T>, T> {
public:
  explicit INLINE TensorApply(const TensorConstant<OP, ArgType, T>& expr)
      : op_(expr.op_), expr_(expr.expr_) {}

  INLINE T apply(int i, int j) const {
    return op_(i, j);
  }
  INLINE T apply(int index) const {
    return op_(index);
  }

  INLINE size_t getWidth() const { return expr_.getWidth(); }
  INLINE size_t getHeight() const { return expr_.getHeight(); }
  INLINE bool isContiguous() const { return true; }
  INLINE bool useGpu() const { return expr_.useGpu(); }

  const OP op_;
  TensorApply<ArgType, T> expr_;
};

} // namespace paddle
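Note: the specializations above form an expression-template evaluator. Each node stores the evaluators of its operands, and apply(i, j) walks the tree one element at a time, so no intermediate tensor is materialized for a sub-expression. The snippet below is a minimal standalone analogue of that pattern, for illustration only; ToyLeaf, ToyBinary, AddOp, and assign are hypothetical names and are not part of Paddle's TensorExpression API.

#include <cstddef>
#include <iostream>
#include <vector>

// Leaf evaluator: wraps raw storage, playing the role of the data_ field above.
struct ToyLeaf {
  const float* data;
  float apply(size_t i) const { return data[i]; }
};

// Binary evaluator: holds two sub-evaluators and combines them per element,
// the same shape as TensorApply<const TensorBinaryOp<...>>.
template <class Lhs, class Rhs, class Op>
struct ToyBinary {
  Lhs lhs;
  Rhs rhs;
  Op op;
  float apply(size_t i) const { return op(lhs.apply(i), rhs.apply(i)); }
};

struct AddOp {
  float operator()(float a, float b) const { return a + b; }
};

// Assignment loop: evaluates the expression tree once per destination element.
template <class Expr>
void assign(std::vector<float>* dst, const Expr& e) {
  for (size_t i = 0; i < dst->size(); ++i) (*dst)[i] = e.apply(i);
}

int main() {
  std::vector<float> a = {1, 2, 3}, b = {10, 20, 30}, out(3);
  ToyLeaf la{a.data()}, lb{b.data()};
  ToyBinary<ToyLeaf, ToyLeaf, AddOp> sum{la, lb, AddOp{}};
  assign(&out, sum);                           // out = a + b, element by element
  for (float v : out) std::cout << v << " ";   // prints: 11 22 33
  std::cout << "\n";
  return 0;
}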
@@ -0,0 +1,104 @@
/**
 * TensorEvaluate.h
 *
 * Author: hedaoyuan (hedaoyuan@baidu.com)
 * Created on: 2016-06-06
 *
 * Copyright (c) Baidu.com, Inc. All Rights Reserved
 *
 */

#pragma once

#include <algorithm>
#include "paddle/utils/Logging.h"
#include "hl_base.h"

namespace paddle {

/**
 * \brief The tensor CPU evaluate API.
 */
template<class T, typename LeftType, typename RightType>
inline void TensorCpuApply(LeftType& lhs, const RightType& rhs) {
  TensorApply<LeftType, T> lhs_(lhs);
  TensorApply<const RightType, T> rhs_(rhs);
  CHECK_EQ(lhs_.getWidth(), rhs_.getWidth());
  CHECK_EQ(lhs_.getHeight(), rhs_.getHeight());
  CHECK_EQ(lhs_.useGpu(), rhs_.useGpu());

  if (lhs_.isContiguous() && rhs_.isContiguous()) {
    int size = lhs_.getHeight() * lhs_.getWidth();
    for (int index = 0; index < size; index++) {
      lhs_.applyRef(index) = rhs_.apply(index);
    }
  } else {
    for (size_t i = 0; i < lhs_.getHeight(); i++) {
      for (size_t j = 0; j < lhs_.getWidth(); j++) {
        lhs_.applyRef(i, j) = rhs_.apply(i, j);
      }
    }
  }
}

#ifdef __NVCC__
template<typename LeftType, typename RightType>
__global__
void TensorElementWiseOp(LeftType lhs, RightType rhs, const int border) {
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < border) {
    lhs.applyRef(idx) = rhs.apply(idx);
  }
}

template<typename LeftType, typename RightType>
__global__ void TensorElementWiseOp(LeftType lhs, RightType rhs) {
  const int colIdx = blockIdx.x * blockDim.x + threadIdx.x;
  const int rowIdx = blockIdx.y * blockDim.y + threadIdx.y;
  for (int i = rowIdx; i < lhs.getHeight(); i += gridDim.y * blockDim.y) {
    for (int j = colIdx; j < lhs.getWidth(); j += gridDim.x * blockDim.x) {
      lhs.applyRef(i, j) = rhs.apply(i, j);
    }
  }
}

/**
 * \brief The tensor GPU evaluate API.
 */
template<class T, typename LeftType, typename RightType>
inline void TensorGpuApply(LeftType& lhs, const RightType& rhs) {
  TensorApply<LeftType, T> lhs_(lhs);
  TensorApply<const RightType, T> rhs_(rhs);
  CHECK_EQ(lhs_.getWidth(), rhs_.getWidth());
  CHECK_EQ(lhs_.getHeight(), rhs_.getHeight());
  CHECK_EQ(lhs_.useGpu(), rhs_.useGpu());

  int dimM = lhs_.getHeight();
  int dimN = lhs_.getWidth();

  if (lhs_.isContiguous() && rhs_.isContiguous()) {
    int size = dimM * dimN;
    int blockSize = size <= 1024 ? size : 1024;
    int gridSize = (size + 1024 - 1) / 1024;
    TensorElementWiseOp
        <<<gridSize, blockSize, 0, STREAM_DEFAULT>>>(lhs_, rhs_, size);
  } else {
    int blockSizeY = std::min(32, dimM);
    int blockSizeX = (32 / blockSizeY) * 32;
    int gridSizeX = std::min(32, (dimN + blockSizeX - 1) / blockSizeX);
    int gridSizeY = std::min(32, (dimM + blockSizeY - 1) / blockSizeY);
    dim3 threads(blockSizeX, blockSizeY);
    dim3 grid(gridSizeX, gridSizeY);
    TensorElementWiseOp
        <<<grid, threads, 0, STREAM_DEFAULT>>>(lhs_, rhs_);
  }

  CHECK_SYNC("TensorGpuApply failed");
}
#else
template<class T, typename LeftType, typename RightType>
inline void TensorGpuApply(LeftType& lhs, RightType& rhs) {
}
#endif

} // namespace paddle
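The launch configuration chosen by TensorGpuApply can be checked with ordinary host arithmetic. The helper below only restates the block/grid formulas from the function above for a couple of example shapes; chooseContiguousLaunch and choose2dLaunch are illustrative names, not Paddle APIs.

#include <algorithm>
#include <cstdio>

// Mirrors the contiguous branch: one thread per element, up to 1024 threads/block.
void chooseContiguousLaunch(int dimM, int dimN, int* blockSize, int* gridSize) {
  int size = dimM * dimN;
  *blockSize = size <= 1024 ? size : 1024;
  *gridSize = (size + 1024 - 1) / 1024;
}

// Mirrors the non-contiguous branch: a 2-D block of at most 1024 threads and a
// grid capped at 32x32 blocks; the kernel's grid-stride loops cover the rest.
void choose2dLaunch(int dimM, int dimN,
                    int* blockX, int* blockY, int* gridX, int* gridY) {
  *blockY = std::min(32, dimM);
  *blockX = (32 / *blockY) * 32;
  *gridX = std::min(32, (dimN + *blockX - 1) / *blockX);
  *gridY = std::min(32, (dimM + *blockY - 1) / *blockY);
}

int main() {
  int bs, gs;
  chooseContiguousLaunch(100, 100, &bs, &gs);              // 10000 elements
  std::printf("contiguous: block=%d grid=%d\n", bs, gs);   // block=1024 grid=10

  int bx, by, gx, gy;
  choose2dLaunch(3, 5000, &bx, &by, &gx, &gy);             // short, wide, strided
  std::printf("2d: block=(%d,%d) grid=(%d,%d)\n", bx, by, gx, gy);
  // block=(320,3) grid=(16,1): 320*3 = 960 threads per block, within the 1024 limit
  return 0;
}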
File diff suppressed because it is too large
@@ -0,0 +1,182 @@
/**
 * TrainingAlgorithmOp.cu
 *
 * Author: hedaoyuan (hedaoyuan@baidu.com)
 * Created on: 2016-06-29
 *
 * Copyright (c) Baidu.com, Inc. All Rights Reserved
 *
 */

#include "paddle/utils/Logging.h"
#include "BaseMatrix.h"
#include "TrainingAlgorithmOp.h"

namespace paddle {

void sparseMomentumApply(BaseMatrix& value,
                         BaseMatrix& grad,
                         BaseMatrix& momU,
                         BaseMatrix& momV,
                         real alpha,
                         real beta,
                         real gamma,
                         real tau,
                         real learningRate) {
  /**
   * \alpha_t = \alpha_{t-1} / k
   * \beta_t = \beta_{t-1} / (1 + \lambda\gamma_t)
   * u_t = u_{t-1} - \alpha_t \gamma_t g_t
   * v_t = v_{t-1} + \tau_{t-1} \alpha_t \gamma_t g_t
   * \tau_t = \tau_{t-1} + \beta_t / \alpha_t
   */
  momU -= (alpha * gamma * learningRate) * grad;
  momV += (tau * alpha * gamma * learningRate) * grad;
  value = (tau / beta + (real)1 / alpha) * momU + ((real)1 / beta) * momV;
}

void adadeltaApply(BaseMatrix& value,
                   BaseMatrix& grad,
                   BaseMatrix& mom,
                   BaseMatrix& accum,
                   BaseMatrix& accum_update,
                   BaseMatrix& lr,
                   real rou,
                   real epsilon,
                   real learningRate,
                   real momentum,
                   real decayRate) {
  // E(g_t^2) = \rou * E(g_{t-1}^2) + (1-\rou) * g^2
  accum = rou * accum + ((real)1 - rou) * grad.square();

  // learn_rate: sqrt(( E(dx_{t-1}^2) + epsilon ) / ( E(g_t^2) + epsilon ))
  lr = ((accum_update + epsilon) / (accum + epsilon)).sqrt();

  // E(dx_t^2) = \rou * E(dx_{t-1}^2) + (1-\rou) * (-g*learn_rate)^2
  accum_update = rou * accum_update + ((real)1 - rou) * (grad * lr).square();

  mom = mom * momentum - learningRate * lr * (grad + value * decayRate);
  value += mom;
}

void adagradApply(BaseMatrix& value,
                  BaseMatrix& grad,
                  BaseMatrix& mom,
                  BaseMatrix& accum_buffer,
                  BaseMatrix& accum,
                  BaseMatrix& lr,
                  real epsilon,
                  real learningRate,
                  real momentum,
                  real decayRate) {
  accum += grad.square();
  lr = (accum_buffer + accum + epsilon).sqrt().reciprocal();
  mom = mom * momentum - learningRate * lr * (grad + value * decayRate);
  value += mom;
}

void rmspropApply(BaseMatrix& value,
                  BaseMatrix& grad,
                  BaseMatrix& mom,
                  BaseMatrix& g,
                  BaseMatrix& f,
                  BaseMatrix& lr,
                  real accumulatedRou,
                  real rou,
                  real epsilon,
                  real learningRate,
                  real momentum,
                  real decayRate,
                  bool firstTime) {
  // E(g_t^2) = \rou * E(g_{t-1}^2) + (1-\rou) * g^2
  // For the first time update, make the sum be the current square
  // so that the initial estimation of E(g_t^2) will not be too small.
  if (firstTime) {
    g = accumulatedRou * g + grad.square();
  } else {
    g = accumulatedRou * g + ((real)1 - rou) * grad.square();
  }

  // E(f_t) = \rou * E(f_{t-1}) + (1-\rou) * g
  f = accumulatedRou * f + ((real)1 - rou) * grad;

  // learn_rate = 1/sqrt( E(g_t^2) - (E(f_t))^2 + epsilon )
  // Basically, if the sign of the gradient changes more often,
  // the learning rate will be decreased.
  lr = (g - f.square() + epsilon).sqrt().reciprocal();

  mom = mom * momentum - learningRate * lr * (grad + value * decayRate);
  value += mom;
}
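To make the rmspropApply formulas above concrete, here is one update written out for a single scalar parameter, assuming a dense update (accumulatedRou == rou) and momentum = decayRate = 0 for brevity. It only restates the arithmetic and does not use BaseMatrix; rmspropStepScalar is an illustrative name.

#include <cmath>
#include <cstdio>

// One RMSProp step for a single parameter, mirroring rmspropApply with
// momentum = 0 and decayRate = 0 (simplifying assumptions for this sketch).
void rmspropStepScalar(float* value, float grad, float* g, float* f,
                       float rou, float epsilon, float learningRate) {
  *g = rou * (*g) + (1.0f - rou) * grad * grad;    // running E(g_t^2)
  *f = rou * (*f) + (1.0f - rou) * grad;           // running E(g_t)
  float lr = 1.0f / std::sqrt(*g - (*f) * (*f) + epsilon);
  *value -= learningRate * lr * grad;
}

int main() {
  float value = 1.0f, g = 0.0f, f = 0.0f;
  for (int step = 0; step < 3; ++step) {
    rmspropStepScalar(&value, /*grad=*/0.5f, &g, &f,
                      /*rou=*/0.95f, /*epsilon=*/1e-6f, /*learningRate=*/0.01f);
    std::printf("step %d: value=%f\n", step, value);
  }
  return 0;
}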

void decayedAdagradApply(BaseMatrix& value,
                         BaseMatrix& grad,
                         BaseMatrix& mom,
                         BaseMatrix& accum,
                         BaseMatrix& lr,
                         real accumulatedRou,
                         real rou,
                         real epsilon,
                         real learningRate,
                         real momentum,
                         real decayRate,
                         bool firstTime) {
  // E(g_t^2) = \rou * E(g_{t-1}^2) + (1-\rou) * g^2
  // For the first time update, make the sum be the current square
  // so that the initial estimation of E(g_t^2) will not be too small.
  if (firstTime) {
    accum = accumulatedRou * accum + grad.square();
  } else {
    accum = accumulatedRou * accum + ((real)1 - rou) * grad.square();
  }

  // learn_rate = 1/sqrt( E(g_t^2) + epsilon )
  // Basically, the bigger the magnitude of the gradient is,
  // the smaller the learning rate will be.
  lr = (accum + epsilon).sqrt().reciprocal();

  mom = mom * momentum - learningRate * lr * (grad + value * decayRate);
  value += mom;
}

void adamApply(BaseMatrix& value,
               BaseMatrix& grad,
               BaseMatrix& mom,  // first moment
               BaseMatrix& v,    // second moment
               real beta1,
               real beta2,
               real beta1_power,
               real beta2_power,
               real epsilon,
               real learningRate) {
  real alpha = learningRate *
      std::sqrt((real)1 - beta2_power) / ((real)1 - beta1_power);

  // m_t = \beta_1 * m_{t-1} + (1-\beta_1)* g_t;
  mom = beta1 * mom + ((real)1 - beta1) * grad;

  // v_t = \beta_2 * v_{t-1} + (1-\beta_2)* g_t^2
  v = beta2 * v + ((real)1 - beta2) * grad.square();

  value -= (mom * alpha) / (v.sqrt() + epsilon);
}

void adamaxApply(BaseMatrix& value,
                 BaseMatrix& grad,
                 BaseMatrix& mom,  // first moment
                 BaseMatrix& u,    // weighted infinity norm
                 real beta1,
                 real beta2,
                 int64_t step,
                 real alpha) {
  // m_t = \beta_1 * m_{t-1} + (1-\beta_1)* g_t;
  mom = beta1 * mom + ((real)1 - beta1) * grad;

  // u_t = max(\beta_2*u_{t-1}, abs(g_t))
  u = (beta2 * u > grad.abs()).condition(beta2 * u, grad.abs());

  // \theta_t = \theta_{t-1} - (\alpha/(1-\beta_1^t))*m_t/u_t
  value -= (alpha / ((real)1 - (real)std::pow(beta1, step))) * (mom / u);
}

} // namespace paddle
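As a cross-check of adamApply, the same update for a single scalar parameter looks as follows; beta1_power and beta2_power stand for beta1^t and beta2^t, which the caller maintains. This is an illustrative restatement (adamStepScalar is a made-up name), not part of the Paddle API.

#include <cmath>
#include <cstdio>

// One Adam step for a single parameter, mirroring adamApply element-wise.
void adamStepScalar(float* value, float grad, float* m, float* v,
                    float beta1, float beta2,
                    float beta1_power, float beta2_power,
                    float epsilon, float learningRate) {
  float alpha = learningRate * std::sqrt(1.0f - beta2_power) / (1.0f - beta1_power);
  *m = beta1 * (*m) + (1.0f - beta1) * grad;           // first moment
  *v = beta2 * (*v) + (1.0f - beta2) * grad * grad;    // second moment
  *value -= (*m) * alpha / (std::sqrt(*v) + epsilon);  // bias-corrected step
}

int main() {
  float value = 1.0f, m = 0.0f, v = 0.0f;
  float beta1 = 0.9f, beta2 = 0.999f;
  float beta1_power = 1.0f, beta2_power = 1.0f;
  for (int t = 1; t <= 3; ++t) {
    beta1_power *= beta1;  // beta1^t, tracked by the caller in the real code
    beta2_power *= beta2;  // beta2^t
    adamStepScalar(&value, /*grad=*/0.1f, &m, &v, beta1, beta2,
                   beta1_power, beta2_power, /*epsilon=*/1e-8f,
                   /*learningRate=*/0.001f);
    std::printf("t=%d value=%f\n", t, value);
  }
  return 0;
}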
@@ -0,0 +1,119 @@
/**
 * TrainingAlgorithmOp.h
 *
 * Author: hedaoyuan (hedaoyuan@baidu.com)
 * Created on: 2016-06-29
 *
 * Copyright (c) Baidu.com, Inc. All Rights Reserved
 *
 */

#pragma once

#include "paddle/utils/Logging.h"
#include "BaseMatrix.h"

namespace paddle {

/**
 * \brief Sparse Momentum optimizer.
 */
extern void sparseMomentumApply(BaseMatrix& value,
                                BaseMatrix& grad,
                                BaseMatrix& momU,
                                BaseMatrix& momV,
                                real alpha,
                                real beta,
                                real gamma,
                                real tau,
                                real learningRate);

/**
 * \brief AdaDelta optimizer.
 */
extern void adadeltaApply(BaseMatrix& value,
                          BaseMatrix& grad,
                          BaseMatrix& sum,
                          BaseMatrix& sum1,
                          BaseMatrix& mom,
                          BaseMatrix& lr,
                          real rou,
                          real epsilon,
                          real learningRate,
                          real momentum,
                          real decayRate);

/**
 * \brief AdaGrad optimizer.
 */
extern void adagradApply(BaseMatrix& value,
                         BaseMatrix& grad,
                         BaseMatrix& sum,
                         BaseMatrix& sum1,
                         BaseMatrix& mom,
                         BaseMatrix& lr,
                         real epsilon,
                         real learningRate,
                         real momentum,
                         real decayRate);

/**
 * \brief RMSProp optimizer.
 */
extern void rmspropApply(BaseMatrix& value,
                         BaseMatrix& grad,
                         BaseMatrix& g,
                         BaseMatrix& f,
                         BaseMatrix& mom,
                         BaseMatrix& lr,
                         real accumulatedRou,
                         real rou,
                         real epsilon,
                         real learningRate,
                         real momentum,
                         real decayRate,
                         bool firstTime);

/**
 * \brief Decayed AdaGrad optimizer.
 */
extern void decayedAdagradApply(BaseMatrix& value,
                                BaseMatrix& grad,
                                BaseMatrix& mom,
                                BaseMatrix& accum,
                                BaseMatrix& lr,
                                real accumulatedRou,
                                real rou,
                                real epsilon,
                                real learningRate,
                                real momentum,
                                real decayRate,
                                bool firstTime);

/**
 * \brief Adam optimizer.
 */
extern void adamApply(BaseMatrix& value,
                      BaseMatrix& grad,
                      BaseMatrix& mom,
                      BaseMatrix& v,
                      real beta1,
                      real beta2,
                      real beta1_power,
                      real beta2_power,
                      real epsilon,
                      real learningRate);

/**
 * \brief AdaMax optimizer.
 */
extern void adamaxApply(BaseMatrix& value,
                        BaseMatrix& grad,
                        BaseMatrix& mom,  // first moment
                        BaseMatrix& u,    // weighted infinity norm
                        real beta1,
                        real beta2,
                        int64_t step,
                        real alpha);

} // namespace paddle
@@ -0,0 +1,190 @@
/**
 * OriginalOptimizerApi.h
 *
 * Author: hedaoyuan (hedaoyuan@baidu.com)
 * Created on: 2016-06-29
 *
 * Copyright (c) Baidu.com, Inc. All Rights Reserved
 */

#pragma once

#include "paddle/utils/GlobalConstants.h"
#include "paddle/math/Vector.h"

using namespace paddle;  // NOLINT

void SparseMomentumParameterOptimizer(const VectorPtr vecs[],
                                      real alpha,
                                      real beta,
                                      real gamma,
                                      real tau,
                                      real learningRate) {
  vecs[PARAMETER_MOMENTUM_UT]->add(*vecs[PARAMETER_GRADIENT],
                                   -alpha * gamma * learningRate);
  vecs[PARAMETER_MOMENTUM_VT]->add(*vecs[PARAMETER_GRADIENT],
                                   tau * alpha * gamma * learningRate);
  vecs[PARAMETER_VALUE]->add(*vecs[PARAMETER_MOMENTUM_UT],
                             tau / beta + 1.0 / alpha,
                             *vecs[PARAMETER_MOMENTUM_VT], 1.0 / beta);
}

void AdagradParameterOptimizer(const VectorPtr vecs[],
                               real epsilon,
                               real learningRate,
                               real momentum,
                               real decayRate) {
  vecs[PARAMETER_GRADIENT_SQURESUM1]->addSquare(*vecs[PARAMETER_GRADIENT],
                                                1.0f);
  vecs[PARAMETER_LEARNING_RATE]->add(*vecs[PARAMETER_GRADIENT_SQURESUM],
                                     *vecs[PARAMETER_GRADIENT_SQURESUM1]);
  vecs[PARAMETER_LEARNING_RATE]->add(epsilon);
  vecs[PARAMETER_LEARNING_RATE]->invSqrt(*vecs[PARAMETER_LEARNING_RATE]);

  vecs[PARAMETER_VALUE]->sgdUpdate(
      *vecs[PARAMETER_GRADIENT], *vecs[PARAMETER_MOMENTUM],
      *vecs[PARAMETER_LEARNING_RATE], learningRate,
      momentum, decayRate);
}

void AdaDeltaParameterOptimizer(const VectorPtr vecs[],
                                real rou,
                                real epsilon,
                                real learningRate,
                                real momentum,
                                real decayRate) {
  // E(g_t^2) = \rou * E(g_{t-1}^2) + (1-\rou) * g^2
  vecs[PARAMETER_GRADIENT_SQURESUM]->decayAddSquare(*vecs[PARAMETER_GRADIENT],
                                                    rou, 1.0f - rou);

  // learn_rate = sqrt( ( E(dx_{t-1}^2) + epsilon ) / ( E(g_t^2) + epsilon ) )
  vecs[PARAMETER_LEARNING_RATE]->dotDiv(*vecs[PARAMETER_GRADIENT_SQURESUM1],
                                        *vecs[PARAMETER_GRADIENT_SQURESUM],
                                        epsilon, epsilon);
  vecs[PARAMETER_LEARNING_RATE]->sqrt2();

  // E(dx_t^2) = \rou * E(dx_{t-1}^2) + (1-\rou) * (-g*learn_rate)^2
  vecs[PARAMETER_GRADIENT_SQURESUM1]->decayAddSquareMul(
      *vecs[PARAMETER_GRADIENT], *vecs[PARAMETER_LEARNING_RATE], rou,
      1.0f - rou);

  vecs[PARAMETER_VALUE]->sgdUpdate(
      *vecs[PARAMETER_GRADIENT], *vecs[PARAMETER_MOMENTUM],
      *vecs[PARAMETER_LEARNING_RATE], learningRate,
      momentum, decayRate);
}

void RMSPropParameterOptimizer(const VectorPtr vecs[],
                               real accumulatedRou,
                               real rou,
                               real epsilon,
                               real learningRate,
                               real momentum,
                               real decayRate,
                               bool firstTime) {
  // E(g_t^2) = \rou * E(g_{t-1}^2) + (1-\rou) * g^2
  // For the first time update, make the sum be the current square
  // so that the initial estimation of E(g_t^2) will not be too small.
  vecs[PARAMETER_GRADIENT_SQURESUM]->decayAddSquare(
      *vecs[PARAMETER_GRADIENT], accumulatedRou,
      firstTime ? 1.0f : 1.0f - rou);

  // E(g_t) = \rou * E(g_{t-1}) + (1-\rou) * g
  vecs[PARAMETER_GRADIENT_SQURESUM1]->add(*vecs[PARAMETER_GRADIENT],
                                          accumulatedRou, 1.0f - rou);

  // learn_rate = 1/sqrt( E(g_t^2) - (E(g_t))^2 + epsilon )
  // Basically, if the sign of the gradient changes more often,
  // the learning rate will be decreased.
  vecs[PARAMETER_LEARNING_RATE]->assign(*vecs[PARAMETER_GRADIENT_SQURESUM]);
  vecs[PARAMETER_LEARNING_RATE]->addSquare(*vecs[PARAMETER_GRADIENT_SQURESUM1],
                                           -1.0f);
  vecs[PARAMETER_LEARNING_RATE]->add(epsilon);
  vecs[PARAMETER_LEARNING_RATE]->invSqrt(*vecs[PARAMETER_LEARNING_RATE]);

  vecs[PARAMETER_VALUE]->sgdUpdate(
      *vecs[PARAMETER_GRADIENT], *vecs[PARAMETER_MOMENTUM],
      *vecs[PARAMETER_LEARNING_RATE], learningRate,
      momentum, decayRate);
}

void DecayedAdagradParameterOptimizer(const VectorPtr vecs[],
                                      real accumulatedRou,
                                      real rou,
                                      real epsilon,
                                      real learningRate,
                                      real momentum,
                                      real decayRate,
                                      bool firstTime) {
  // E(g_t^2) = \rou * E(g_{t-1}^2) + (1-\rou) * g^2
  // For the first time update, make the sum be the current square
  // so that the initial estimation of E(g_t^2) will not be too small.
  vecs[PARAMETER_GRADIENT_SQURESUM]->decayAddSquare(
      *vecs[PARAMETER_GRADIENT], accumulatedRou,
      firstTime ? 1.0f : 1.0f - rou);

  // learn_rate = 1/sqrt( E(g_t^2) + epsilon )
  // Basically, the bigger the magnitude of the gradient is,
  // the smaller the learning rate will be.
  vecs[PARAMETER_LEARNING_RATE]->assign(epsilon);
  vecs[PARAMETER_LEARNING_RATE]->add(*vecs[PARAMETER_GRADIENT_SQURESUM]);
  vecs[PARAMETER_LEARNING_RATE]->invSqrt(*vecs[PARAMETER_LEARNING_RATE]);

  vecs[PARAMETER_VALUE]->sgdUpdate(
      *vecs[PARAMETER_GRADIENT], *vecs[PARAMETER_MOMENTUM],
      *vecs[PARAMETER_LEARNING_RATE], learningRate,
      momentum, decayRate);
}

void AdamParameterOptimizer(const VectorPtr vecs[],
                            real beta1,
                            real beta2,
                            real beta1_power,
                            real beta2_power,
                            real epsilon,
                            real learningRate) {
  Vector* m = vecs[PARAMETER_MOMENTUM].get();
  Vector* g = vecs[PARAMETER_GRADIENT].get();
  Vector* v = vecs[PARAMETER_SECOND_MOMENTUM].get();
  Vector* theta = vecs[PARAMETER_VALUE].get();

  // m_t = \beta_1 * m_{t-1} + (1-\beta_1)* g_t;
  m->add(*g, beta1, 1 - beta1);

  // v_t = \beta_2 * v_{t-1} + (1-\beta_2)* g_t^2
  g->square2();
  v->add(*g, beta2, 1 - beta2);

  // tmp = m_t / ( \sqrt{v_t} + \epsilon )
  // \theta_t = \theta_{t-1} - \alpha * \sqrt(1-\beta_2^t) / (1-\beta_1^t) * tmp
  g->sqrt2(*v);
  g->dotDiv(*m, *g, 0., epsilon);
  real alpha = learningRate *
      std::sqrt((real)1 - beta2_power) / ((real)1 - beta1_power);
  theta->add(*theta, 1.0, *g, -alpha);
}

void AdamaxParameterOptimizer(const VectorPtr vecs[],
                              real beta1,
                              real beta2,
                              int64_t step,
                              real alpha) {
  Vector* m = vecs[PARAMETER_MOMENTUM].get();
  Vector* g = vecs[PARAMETER_GRADIENT].get();
  Vector* u = vecs[PARAMETER_WEIGHTED_INFINITY_NORM].get();
  Vector* theta = vecs[PARAMETER_VALUE].get();

  // m_t = \beta_1 * m_{t-1} + (1-\beta_1)* g_t;
  m->add(*g, beta1, 1 - beta1);

  // u_t = max(\beta_2*u_{t-1}, abs(g_t))
  u->mulScalar(beta2);
  g->abs2();
  u->max2(*u, *g);

  // \theta_t = \theta_{t-1} - (\alpha/(1-\beta_1^t))*m_t/u_t
  g->dotDiv(*m, *u);
  real learningRate = alpha / (1 - std::pow(beta1, step));
  theta->add(*theta, 1.0, *g, -learningRate);
}
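AdamaxParameterOptimizer above performs, through explicit Vector calls, the same update that adamaxApply in TrainingAlgorithmOp.cu expresses with matrix expressions. The scalar sketch below restates that update for a single parameter; adamaxStepScalar is an illustrative name and the code does not use the Vector API.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// One AdaMax step for a single parameter, mirroring AdamaxParameterOptimizer:
//   m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
//   u_t = max(beta2 * u_{t-1}, |g_t|)
//   theta_t = theta_{t-1} - (alpha / (1 - beta1^t)) * m_t / u_t
void adamaxStepScalar(float* theta, float grad, float* m, float* u,
                      float beta1, float beta2, int64_t step, float alpha) {
  *m = beta1 * (*m) + (1.0f - beta1) * grad;
  *u = std::max(beta2 * (*u), std::fabs(grad));
  float learningRate = alpha / (1.0f - std::pow(beta1, static_cast<float>(step)));
  *theta -= learningRate * (*m) / (*u);
}

int main() {
  float theta = 1.0f, m = 0.0f, u = 0.0f;
  for (int64_t t = 1; t <= 3; ++t) {
    adamaxStepScalar(&theta, /*grad=*/0.2f, &m, &u,
                     /*beta1=*/0.9f, /*beta2=*/0.999f, t, /*alpha=*/0.002f);
    std::printf("t=%lld theta=%f\n", static_cast<long long>(t), theta);
  }
  return 0;
}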
File diff suppressed because it is too large
File diff suppressed because it is too large