@ -28,6 +28,11 @@
# include "paddle/fluid/platform/device_context.h"
# include "paddle/fluid/platform/device_context.h"
# include "paddle/fluid/platform/enforce.h"
# include "paddle/fluid/platform/enforce.h"
# include "paddle/fluid/platform/profiler.h"
# include "paddle/fluid/platform/profiler.h"
# ifdef PADDLE_WITH_MKLDNN
# include "paddle/fluid/platform/mkldnn_helper.h"
# endif
DECLARE_bool ( use_mkldnn ) ;
namespace paddle {
namespace paddle {
namespace imperative {
namespace imperative {
@ -192,6 +197,9 @@ void VarBase::ClearGradient() {
auto * grad_t =
auto * grad_t =
grad_var_ - > MutableVar ( ) - > GetMutable < framework : : SelectedRows > ( ) ;
grad_var_ - > MutableVar ( ) - > GetMutable < framework : : SelectedRows > ( ) ;
if ( grad_t - > mutable_value ( ) - > IsInitialized ( ) ) {
if ( grad_t - > mutable_value ( ) - > IsInitialized ( ) ) {
# ifdef PADDLE_WITH_MKLDNN
if ( FLAGS_use_mkldnn ) ClearMKLDNNCache ( grad_t - > place ( ) ) ;
# endif
grad_t - > mutable_rows ( ) - > clear ( ) ;
grad_t - > mutable_rows ( ) - > clear ( ) ;
grad_t - > mutable_value ( ) - > clear ( ) ;
grad_t - > mutable_value ( ) - > clear ( ) ;
}
}
@ -202,6 +210,9 @@ void VarBase::ClearGradient() {
auto * dev_ctx =
auto * dev_ctx =
platform : : DeviceContextPool : : Instance ( ) . Get ( grad_t - > place ( ) ) ;
platform : : DeviceContextPool : : Instance ( ) . Get ( grad_t - > place ( ) ) ;
operators : : math : : set_constant ( * dev_ctx , grad_t , 0.0 ) ;
operators : : math : : set_constant ( * dev_ctx , grad_t , 0.0 ) ;
# ifdef PADDLE_WITH_MKLDNN
if ( FLAGS_use_mkldnn ) ClearMKLDNNCache ( grad_t - > place ( ) ) ;
# endif
}
}
}
}
}
}