support dygraph in xpu place (#30051)

* support dygraph in xpu place; test=develop

* fix cpu/gpu compile error; test=develop

* fix compile error; test=develop

* fix xpu compile error; testd=develop
revert-31562-mean
hong 4 years ago committed by GitHub
parent eea7090c26
commit 297fff1a79
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -30,6 +30,9 @@
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/profiler.h"
#ifdef PADDLE_WITH_XPU
#include "xpu/refactor/math.h"
#endif
namespace paddle {
namespace imperative {
@ -81,12 +84,20 @@ class TensorAddFunctor : public boost::static_visitor<> {
    blas.AXPY(numel_, 1., x_, y_);
  }
#ifdef PADDLE_WITH_XPU
// XPU build: accumulate the gradient directly on the XPU device.
// Fetches the device context for `place` from the global pool and calls the
// XPU refactor math API; by analogy with the CPU path above
// (blas.AXPY(numel_, 1., x_, y_)) this presumably computes y_ += x_
// element-wise over numel_ elements — verify against the xpu::add signature.
void operator()(const platform::XPUPlace& place) {
platform::XPUDeviceContext* ctx = dynamic_cast<platform::XPUDeviceContext*>(
platform::DeviceContextPool::Instance().Get(place));
// NOTE(review): ctx is dereferenced without a null check — assumes the pool
// always returns an XPUDeviceContext for an XPUPlace; confirm, or guard the
// cast. The int return code of xpu::add is also ignored — TODO confirm that
// is intentional.
xpu::add<T>(ctx->x_context(), x_, y_, y_, static_cast<int>(numel_));
}
#else
void operator()(const platform::XPUPlace& place) {
  PADDLE_THROW(platform::errors::PermissionDenied(
      "Gradient accumulation on place (%s) "
      "is not supported in imperative mode",
      place));
}
#endif
#ifdef PADDLE_WITH_CUDA
void operator()(const platform::CUDAPlace& place) {
@ -162,11 +173,14 @@ void TensorAdd(const framework::Variable& src, framework::Variable* dst) {
}
PADDLE_TENSOR_ADD(float);
#ifndef PADDLE_WITH_XPU
// NOTE(phlrain): xpu only support float
PADDLE_TENSOR_ADD(double);
// NOTE(chenweihang): only support complex grad tensor accumulated,
// support selected rows if needed in the future
PADDLE_TENSOR_ADD(platform::complex64);
PADDLE_TENSOR_ADD(platform::complex128);
#endif
#undef PADDLE_TENSOR_ADD

Loading…
Cancel
Save