@@ -52,10 +52,48 @@ struct GPUAdam;
struct CPUAdam;

template <typename T, typename Flavour>
struct AdamFunctor;
class AdamFunctor;

template <typename T>
struct AdamFunctor<T, GPUAdam> {
class BetaPowFunctor {
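  // Keeps the running powers beta1^t and beta2^t (held in Beta1Pow/Beta2Pow)
  // in step with the optimizer: each update multiplies the stored power by its
  // beta once more and writes the result to the *_out buffers.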
 private:
  T beta1_;
  T beta2_;
  const T* beta1_pow_;
  const T* beta2_pow_;
  T* beta1_pow_out_;
  T* beta2_pow_out_;

 public:
  BetaPowFunctor(T beta1, T beta2, const T* beta1_pow, const T* beta2_pow,
                 T* beta1_pow_out, T* beta2_pow_out)
      : beta1_(beta1),
        beta2_(beta2),
        beta1_pow_(beta1_pow),
        beta2_pow_(beta2_pow),
        beta1_pow_out_(beta1_pow_out),
        beta2_pow_out_(beta2_pow_out) {}
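  // Single-index step: beta_pow_out[i] = beta_pow[i] * beta for both betas.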
  inline HOSTDEVICE void update_step(size_t i) const {
    T beta1_pow_i = beta1_pow_[i];
    T beta2_pow_i = beta2_pow_[i];

    beta1_pow_out_[i] = beta1_pow_i * beta1_;
    beta2_pow_out_[i] = beta2_pow_i * beta2_;
  }

  inline HOSTDEVICE void operator()(size_t i) const { update_step(i); }
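  // Serial variant used on the CPU paths below; the beta pow tensors are
  // typically single-element, so no ForRange launch is needed.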
  inline HOSTDEVICE void apply_update(size_t limit) const {
    for (size_t i = 0; i < limit; ++i) {
      update_step(i);
    }
  }
};

template <typename T>
class AdamFunctor<T, GPUAdam> {
 private:
  T beta1_;
  T beta2_;
  T epsilon_;
@@ -71,6 +109,7 @@ struct AdamFunctor<T, GPUAdam> {
  const T* param_;
  T* param_out_;

 public:
  AdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow,
              const T* beta2_pow, const T* mom1, T* mom1_out, const T* mom2,
              T* mom2_out, const T* lr, const T* grad, const T* param,
@@ -114,7 +153,8 @@ struct AdamFunctor<T, GPUAdam> {
};

template <typename T>
struct AdamFunctor<T, CPUAdam> {
class AdamFunctor<T, CPUAdam> {
 private:
  T beta1_;
  T beta2_;
  T epsilon_;
@@ -130,6 +170,7 @@ struct AdamFunctor<T, CPUAdam> {
  const T* param_;
  T* param_out_;

 public:
  AdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow,
              const T* beta2_pow, const T* mom1, T* mom1_out, const T* mom2,
              T* mom2_out, const T* lr, const T* grad, const T* param,
@@ -179,10 +220,11 @@ struct AdamFunctor<T, CPUAdam> {
};

template <typename T, typename Flavour>
struct SparseAdamFunctor;
class SparseAdamFunctor;

template <typename T>
struct SparseAdamFunctor<T, GPUAdam> {
class SparseAdamFunctor<T, GPUAdam> {
 private:
  T beta1_;
  T beta2_;
  T epsilon_;
@@ -203,6 +245,7 @@ struct SparseAdamFunctor<T, GPUAdam> {
  int64_t row_count_;
  bool lazy_mode_;

 public:
  SparseAdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow,
                    const T* beta2_pow, const T* mom1, T* mom1_out,
                    const T* mom2, T* mom2_out, const T* lr, const T* grad,
@@ -261,7 +304,8 @@ struct SparseAdamFunctor<T, GPUAdam> {
};

template <typename T>
struct SparseAdamFunctor<T, CPUAdam> {
class SparseAdamFunctor<T, CPUAdam> {
 private:
  T beta1_;
  T beta2_;
  T epsilon_;
@@ -281,6 +325,7 @@ struct SparseAdamFunctor<T, CPUAdam> {
  int64_t row_numel_;
  int64_t row_count_;

 public:
  SparseAdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow,
                    const T* beta2_pow, const T* mom1, T* mom1_out,
                    const T* mom2, T* mom2_out, const T* lr, const T* grad,
@@ -397,6 +442,10 @@ class AdamOpKernel : public framework::OpKernel<T> {
        Ref(ctx.Output<LoDTensor>("Moment1Out"), "Must set Moment1Out");
    auto& mom2_out =
        Ref(ctx.Output<LoDTensor>("Moment2Out"), "Must set Moment2Out");
    auto& beta1_pow_out =
        Ref(ctx.Output<LoDTensor>("Beta1PowOut"), "Must set Beta1PowOut");
    auto& beta2_pow_out =
        Ref(ctx.Output<LoDTensor>("Beta2PowOut"), "Must set Beta2PowOut");

    T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
    if (ctx.HasInput("Beta1Tensor")) {
@@ -408,6 +457,14 @@ class AdamOpKernel : public framework::OpKernel<T> {
      auto* beta2_tensor = ctx.Input<framework::Tensor>("Beta2Tensor");
      beta2 = static_cast<T>(GetAttrFromTensor(beta2_tensor));
    }

    VLOG(3) << "beta1_pow.numel() : " << beta1_pow.numel()
            << "beta2_pow.numel() : " << beta2_pow.numel();
    VLOG(3) << "param.numel(): " << param.numel();
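    // Builds the functor that writes beta1_pow * beta1 and beta2_pow * beta2
    // into Beta1PowOut / Beta2PowOut; applied via apply_update on the CPU
    // paths and via ForRange on the GPU paths below.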
    BetaPowFunctor<T> beta_functor(
        beta1, beta2, beta1_pow.template data<T>(),
        beta2_pow.template data<T>(),
        beta1_pow_out.template mutable_data<T>(ctx.GetPlace()),
        beta2_pow_out.template mutable_data<T>(ctx.GetPlace()));

    if (grad_var->IsType<framework::LoDTensor>()) {
      auto& grad = Ref(ctx.Input<LoDTensor>("Grad"), "Must set Grad");
@@ -423,6 +480,7 @@ class AdamOpKernel : public framework::OpKernel<T> {
            param.template data<T>(),
            param_out.template mutable_data<T>(ctx.GetPlace()));
        functor(param.numel());
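        // update beta1 and beta2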
        beta_functor.apply_update(beta2_pow.numel());
      } else if (platform::is_gpu_place(ctx.GetPlace())) {
        AdamFunctor<T, GPUAdam> functor(
            beta1, beta2, epsilon, beta1_pow.template data<T>(),
@@ -433,11 +491,16 @@ class AdamOpKernel : public framework::OpKernel<T> {
            lr.template data<T>(), grad.template data<T>(),
            param.template data<T>(),
            param_out.template mutable_data<T>(ctx.GetPlace()));
        // update param and moment
        platform::ForRange<DeviceContext> for_range(
            static_cast<const DeviceContext&>(ctx.device_context()),
            param.numel());
        for_range(functor);
        // update beta1 and beta2
        platform::ForRange<DeviceContext> for_range_beta(
            static_cast<const DeviceContext&>(ctx.device_context()),
            beta2_pow.numel());
        for_range_beta(beta_functor);
      }
    } else if (grad_var->IsType<framework::SelectedRows>()) {
      auto& grad =
@@ -485,6 +548,8 @@ class AdamOpKernel : public framework::OpKernel<T> {
            lr.template data<T>(), grad_data, param.template data<T>(),
            param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel,
            grad_merge.rows().size(), lazy_mode);
        // update beta1 and beta2
        beta_functor.apply_update(beta2_pow.numel());
        if (lazy_mode) {
          VLOG(3) << "run cpu lazy mode";
          size_t row_count = grad_merge.rows().size();
@@ -574,6 +639,11 @@ class AdamOpKernel : public framework::OpKernel<T> {
            static_cast<const DeviceContext&>(ctx.device_context()),
            param.numel());
        for_range(functor);
        // update beta1 and beta2
        platform::ForRange<DeviceContext> for_range_beta(
            static_cast<const DeviceContext&>(ctx.device_context()),
            beta2_pow.numel());
        for_range_beta(beta_functor);
      }
    } else {
      PADDLE_THROW("Variable type not supported by adam_op");