|
|
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/operators/optimizers/adam_op.h"
|
|
|
|
|
#include <gflags/gflags.h>
|
|
|
|
|
#include "gflags/gflags.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
@ -109,7 +109,7 @@ class AdamOpXPUKernel : public framework::OpKernel<T> {
|
|
|
|
|
mom2_out.template mutable_data<T>(ctx.GetPlace()),
|
|
|
|
|
param_out.template mutable_data<T>(ctx.GetPlace()), param.numel());
|
|
|
|
|
|
|
|
|
|
//update in cpu and then copy to xpu
|
|
|
|
|
// update in cpu and then copy to xpu
|
|
|
|
|
if (beta1_pow.place() == platform::CPUPlace() &&
|
|
|
|
|
beta2_pow.place() == platform::CPUPlace()) {
|
|
|
|
|
const T* beta1_pow_p = beta1_pow.template data<T>();
|
|
|
|
|