@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/operators/optimizers/adam_op.h"
-#include <gflags/gflags.h>
+#include "gflags/gflags.h"

 namespace paddle {
 namespace operators {
@@ -74,7 +74,7 @@ class AdamOpXPUKernel : public framework::OpKernel<T> {
                           "output size is 1, but received "
                           "value is:%d.",
                           beta2_pow_out->numel()));

     T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
     if (ctx.HasInput("Beta1Tensor")) {
       auto* beta1_tensor = ctx.Input<framework::Tensor>("Beta1Tensor");
@@ -109,7 +109,7 @@ class AdamOpXPUKernel : public framework::OpKernel<T> {
         mom2_out.template mutable_data<T>(ctx.GetPlace()),
         param_out.template mutable_data<T>(ctx.GetPlace()), param.numel());

-    //update in cpu and then copy to xpu
+    // update in cpu and then copy to xpu
     if (beta1_pow.place() == platform::CPUPlace() &&
         beta2_pow.place() == platform::CPUPlace()) {
       const T* beta1_pow_p = beta1_pow.template data<T>();
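For readers unfamiliar with the Beta1Tensor input touched in the second hunk: beta1 starts from the op attribute and is overridden when a runtime tensor input is present, so a schedule can change beta1 between steps without rebuilding the program. Below is a minimal standalone sketch of that pattern; the KernelContext and ResolveBeta1 names are invented for illustration and are not Paddle's actual API — only the control flow mirrors the kernel.

#include <iostream>
#include <map>
#include <string>

// Invented stand-in for the framework's execution context.
struct KernelContext {
  std::map<std::string, float> attrs;          // static op attributes
  std::map<std::string, float> tensor_inputs;  // optional runtime scalars
  bool HasInput(const std::string& name) const {
    return tensor_inputs.count(name) > 0;
  }
};

float ResolveBeta1(const KernelContext& ctx) {
  // Start from the compile-time attribute...
  float beta1 = ctx.attrs.at("beta1");
  // ...but let a runtime tensor input take precedence if it was fed.
  if (ctx.HasInput("Beta1Tensor")) {
    beta1 = ctx.tensor_inputs.at("Beta1Tensor");
  }
  return beta1;
}

int main() {
  KernelContext ctx;
  ctx.attrs["beta1"] = 0.9f;
  std::cout << ResolveBeta1(ctx) << "\n";  // 0.9 (attribute default)
  ctx.tensor_inputs["Beta1Tensor"] = 0.95f;
  std::cout << ResolveBeta1(ctx) << "\n";  // 0.95 (tensor override)
}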
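The "update in cpu and then copy to xpu" branch in the third hunk refers to Adam's beta-power bookkeeping: the kernel carries beta1^t and beta2^t across steps, and when those two scalars live in host memory it advances them on the CPU and copies the results to the device, plausibly because a device launch for a single multiply would cost more than the copy. A sketch of that idea follows, under stated assumptions: CopyScalarToDevice is a hypothetical stand-in for the framework's host-to-device memory-copy utility, and the loop only simulates the per-step update.

#include <cstdio>

// Hypothetical placeholder for a real host-to-XPU transfer.
template <typename T>
void CopyScalarToDevice(T* device_dst, const T& host_src) {
  *device_dst = host_src;
}

int main() {
  const float beta1 = 0.9f, beta2 = 0.999f;
  float beta1_pow = 1.0f, beta2_pow = 1.0f;        // beta^0, host copies
  float dev_beta1_pow = 0.f, dev_beta2_pow = 0.f;  // stand-ins for XPU memory

  for (int step = 1; step <= 3; ++step) {
    // Advance beta1^t and beta2^t on the CPU...
    beta1_pow *= beta1;
    beta2_pow *= beta2;
    // ...then copy the updated scalars to the device for the next step.
    CopyScalarToDevice(&dev_beta1_pow, beta1_pow);
    CopyScalarToDevice(&dev_beta2_pow, beta2_pow);
    std::printf("step %d: beta1^t=%.6f beta2^t=%.6f\n", step, dev_beta1_pow,
                dev_beta2_pow);
  }
}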