|
|
@ -74,7 +74,7 @@ class AdamOpXPUKernel : public framework::OpKernel<T> {
|
|
|
|
"output size is 1, but received "
|
|
|
|
"output size is 1, but received "
|
|
|
|
"value is:%d.",
|
|
|
|
"value is:%d.",
|
|
|
|
beta2_pow_out->numel()));
|
|
|
|
beta2_pow_out->numel()));
|
|
|
|
|
|
|
|
|
|
|
|
T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
|
|
|
|
T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
|
|
|
|
if (ctx.HasInput("Beta1Tensor")) {
|
|
|
|
if (ctx.HasInput("Beta1Tensor")) {
|
|
|
|
auto* beta1_tensor = ctx.Input<framework::Tensor>("Beta1Tensor");
|
|
|
|
auto* beta1_tensor = ctx.Input<framework::Tensor>("Beta1Tensor");
|
|
|
@ -88,30 +88,53 @@ class AdamOpXPUKernel : public framework::OpKernel<T> {
|
|
|
|
if (grad_var->IsType<framework::LoDTensor>()) {
|
|
|
|
if (grad_var->IsType<framework::LoDTensor>()) {
|
|
|
|
auto& grad = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Grad"), "Input",
|
|
|
|
auto& grad = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Grad"), "Input",
|
|
|
|
"Grad", "Adam");
|
|
|
|
"Grad", "Adam");
|
|
|
|
|
|
|
|
|
|
|
|
auto& dev_ctx = ctx.template device_context<DeviceContext>();
|
|
|
|
auto& dev_ctx = ctx.template device_context<DeviceContext>();
|
|
|
|
|
|
|
|
const T* beta1_pow_ptr = beta1_pow.template data<T>();
|
|
|
|
|
|
|
|
const T* beta2_pow_ptr = beta2_pow.template data<T>();
|
|
|
|
|
|
|
|
Tensor xpu_beta1_pow;
|
|
|
|
|
|
|
|
Tensor xpu_beta2_pow;
|
|
|
|
|
|
|
|
if (beta1_pow.place() == platform::CPUPlace() &&
|
|
|
|
|
|
|
|
beta2_pow.place() == platform::CPUPlace()) {
|
|
|
|
|
|
|
|
TensorCopy(beta1_pow, ctx.GetPlace(), dev_ctx, &xpu_beta1_pow);
|
|
|
|
|
|
|
|
TensorCopy(beta2_pow, ctx.GetPlace(), dev_ctx, &xpu_beta2_pow);
|
|
|
|
|
|
|
|
dev_ctx.Wait();
|
|
|
|
|
|
|
|
beta1_pow_ptr = xpu_beta1_pow.template data<T>();
|
|
|
|
|
|
|
|
beta2_pow_ptr = xpu_beta2_pow.template data<T>();
|
|
|
|
|
|
|
|
}
|
|
|
|
int r = xpu::adam(
|
|
|
|
int r = xpu::adam(
|
|
|
|
dev_ctx.x_context(), grad.template data<T>(), mom1.template data<T>(),
|
|
|
|
dev_ctx.x_context(), grad.template data<T>(), mom1.template data<T>(),
|
|
|
|
mom2.template data<T>(), param.template data<T>(),
|
|
|
|
mom2.template data<T>(), param.template data<T>(), beta1_pow_ptr,
|
|
|
|
beta1_pow.template data<T>(), beta2_pow.template data<T>(), beta1,
|
|
|
|
beta2_pow_ptr, beta1, beta2, epsilon, lr.template data<T>(),
|
|
|
|
beta2, epsilon, lr.template data<T>(),
|
|
|
|
|
|
|
|
mom1_out.template mutable_data<T>(ctx.GetPlace()),
|
|
|
|
mom1_out.template mutable_data<T>(ctx.GetPlace()),
|
|
|
|
mom2_out.template mutable_data<T>(ctx.GetPlace()),
|
|
|
|
mom2_out.template mutable_data<T>(ctx.GetPlace()),
|
|
|
|
param_out.template mutable_data<T>(ctx.GetPlace()), param.numel());
|
|
|
|
param_out.template mutable_data<T>(ctx.GetPlace()), param.numel());
|
|
|
|
|
|
|
|
|
|
|
|
const float* ptr0 = beta1_pow.template data<T>();
|
|
|
|
//update in cpu and then copy to xpu
|
|
|
|
float* ptr1 = beta1_pow_out->mutable_data<T>(ctx.GetPlace());
|
|
|
|
if (beta1_pow.place() == platform::CPUPlace() &&
|
|
|
|
float cpudata;
|
|
|
|
beta2_pow.place() == platform::CPUPlace()) {
|
|
|
|
xpu_memcpy(&cpudata, ptr0, sizeof(float), XPU_DEVICE_TO_HOST);
|
|
|
|
const T* beta1_pow_p = beta1_pow.template data<T>();
|
|
|
|
cpudata = cpudata * beta1;
|
|
|
|
beta1_pow_out->mutable_data<T>(platform::CPUPlace())[0] =
|
|
|
|
xpu_memcpy(ptr1, &cpudata, sizeof(float), XPU_HOST_TO_DEVICE);
|
|
|
|
beta1 * beta1_pow_p[0];
|
|
|
|
|
|
|
|
const T* beta2_pow_p = beta2_pow.template data<T>();
|
|
|
|
const float* ptr2 = beta2_pow.template data<T>();
|
|
|
|
beta2_pow_out->mutable_data<T>(platform::CPUPlace())[0] =
|
|
|
|
float* ptr3 = beta2_pow_out->mutable_data<T>(ctx.GetPlace());
|
|
|
|
beta2 * beta2_pow_p[0];
|
|
|
|
float cpudata1;
|
|
|
|
} else {
|
|
|
|
xpu_memcpy(&cpudata1, ptr2, sizeof(float), XPU_DEVICE_TO_HOST);
|
|
|
|
T cpu_beta1_pow_out_data;
|
|
|
|
cpudata1 = cpudata1 * beta2;
|
|
|
|
T cpu_beta2_pow_out_data;
|
|
|
|
xpu_memcpy(ptr3, &cpudata1, sizeof(float), XPU_HOST_TO_DEVICE);
|
|
|
|
xpu_memcpy(&cpu_beta1_pow_out_data, beta1_pow_ptr, sizeof(T),
|
|
|
|
|
|
|
|
XPU_DEVICE_TO_HOST);
|
|
|
|
|
|
|
|
cpu_beta1_pow_out_data = cpu_beta1_pow_out_data * beta1;
|
|
|
|
|
|
|
|
xpu_memcpy(&cpu_beta2_pow_out_data, beta2_pow_ptr, sizeof(T),
|
|
|
|
|
|
|
|
XPU_DEVICE_TO_HOST);
|
|
|
|
|
|
|
|
cpu_beta2_pow_out_data = cpu_beta2_pow_out_data * beta2;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
T* beta1_pow_out_p = beta1_pow_out->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
|
|
|
T* beta2_pow_out_p = beta2_pow_out->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
|
|
|
xpu_memcpy(beta1_pow_out_p, &cpu_beta1_pow_out_data, sizeof(T),
|
|
|
|
|
|
|
|
XPU_HOST_TO_DEVICE);
|
|
|
|
|
|
|
|
xpu_memcpy(beta2_pow_out_p, &cpu_beta2_pow_out_data, sizeof(T),
|
|
|
|
|
|
|
|
XPU_HOST_TO_DEVICE);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true,
|
|
|
|
PADDLE_ENFORCE_EQ(r == xpu::Error_t::SUCCESS, true,
|
|
|
|
platform::errors::External(
|
|
|
|
platform::errors::External(
|
|
|
|