/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#ifdef PADDLE_WITH_XPU

#include <algorithm>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/place.h"
#include "xpu/refactor/math.h"

namespace paddle {
namespace operators {
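// Builds the (x_shape, y_shape) vector pair handed to xpu::broadcast for two
// rank-aligned shapes, where y is the broadcasted output shape. Matching dims
// are copied through; a differing dim i is split into (1, x[i]) on the x side
// and (y[i] / x[i], x[i]) on the y side, which assumes x[i] divides y[i].
// For example, tracing the loop below with x = [4, 1, 5] and y = [4, 3, 5]
// gives x_v = {4, 1, 1, 5} and y_v = {4, 3, 1, 5}.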
static std::pair<std::vector<int>, std::vector<int>> XPUDimsToBroadcastVector(
    const framework::DDim& x, const framework::DDim& y) {
  std::vector<int> x_v;
  std::vector<int> y_v;
  int y_size = y.size();
  for (int i = 0; i < y_size; ++i) {
    if (x[i] == y[i]) {
      x_v.push_back(y[i]);
      y_v.push_back(y[i]);
      continue;
    }
    x_v.push_back(1);
    x_v.push_back(x[i]);
    y_v.push_back(y[i] / x[i]);
    y_v.push_back(x[i]);
  }
  return std::make_pair(x_v, y_v);
}
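// Builds the (input_shape, reduce_axes) pair handed to xpu::reduce_sum when a
// gradient computed at the broadcasted shape x must be reduced back to the
// original shape y. An axis is marked for reduction where y is 1, where the
// dims disagree, or where y has run out of dims. For example, with
// x = [4, 3, 5] and y = [4, 1, 5] this returns ({4, 3, 5}, {1}); with a
// scalar y (numel == 1), every axis of x is reduced.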
static std::pair<std::vector<int>, std::vector<int>> XPUReducesAxisVector(
    const framework::DDim& x, const framework::DDim& y) {
  std::vector<int> x_vector;
  std::vector<int> axis_v;
  PADDLE_ENFORCE_GT(
      x.size(), 0,
      platform::errors::OutOfRange(
          "The dims size of x should be greater than 0, but x shape is %s.",
          x.to_str()));
  PADDLE_ENFORCE_GT(
      y.size(), 0,
      platform::errors::OutOfRange(
          "The dims size of y should be greater than 0, but y shape is %s.",
          y.to_str()));

  int y_nums = framework::product(y);
  x_vector = framework::vectorize<int>(x);
  if (y_nums == 1) {
    // y is a scalar: reduce over every axis of x.
    for (int i = 0; i < x.size(); ++i) {
      axis_v.push_back(i);
    }
    return std::make_pair(x_vector, axis_v);
  }
  int yidx = 0;
  for (size_t i = 0; i < x_vector.size(); ++i) {
    if (yidx >= y.size() || y[yidx] == 1) {
      axis_v.push_back(i);
      yidx++;
      continue;
    }
    if (x_vector[i] != y[yidx]) {
      axis_v.push_back(i);
      continue;
    }
    yidx++;
  }
  return std::make_pair(x_vector, axis_v);
}
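// Runs a binary elementwise op on XPU: broadcasts X and/or Y up to the output
// shape when their element counts differ from Out, then applies the given
// device functor (signature: int(Context*, const T*, const T*, T*, int)) over
// the flattened data.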
template <typename T>
void XPUElementwise(
    const framework::ExecutionContext& ctx,
    std::function<int(xpu::Context*, const T*, const T*, T*, int)> func) {
  auto x_var = ctx.InputVar("X");
  PADDLE_ENFORCE_NE(x_var, nullptr, platform::errors::InvalidArgument(
                                        "Cannot get input Variable X"));
  PADDLE_ENFORCE_EQ(
      x_var->IsType<framework::LoDTensor>(), true,
      platform::errors::InvalidArgument(
          "XPU only supports LoDTensor, but Input(X) is not a LoDTensor"));

  auto x = x_var->Get<framework::LoDTensor>();
  auto* y = ctx.Input<framework::LoDTensor>("Y");
  auto* z = ctx.Output<framework::LoDTensor>("Out");
  z->mutable_data<T>(ctx.GetPlace());
  auto x_dims = x.dims();
  auto y_dims = y->dims();
  int max_dim = std::max(x_dims.size(), y_dims.size());
  int axis = ctx.Attr<int>("axis");
  axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);

  PADDLE_ENFORCE_GE(
      axis, 0,
      platform::errors::InvalidArgument(
          "Axis should be greater than or equal to 0, but received axis is %d.",
          axis));
  PADDLE_ENFORCE_LT(axis, max_dim,
                    platform::errors::InvalidArgument(
                        "Axis should be less than %d, but received axis is %d.",
                        max_dim, axis));

  // Align both shapes to max_dim and compute the broadcasted output shape.
  std::vector<int> x_dims_array(max_dim);
  std::vector<int> y_dims_array(max_dim);
  std::vector<int> out_dims_array(max_dim);
  GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(),
                         y_dims_array.data(), out_dims_array.data(), max_dim,
                         axis);
  framework::DDim out_dim = framework::make_ddim(out_dims_array);

  const T* x_data = x.data<T>();
  const T* y_data = y->data<T>();
  T* z_data = z->data<T>();
  bool need_wait = false;
  framework::Tensor x_broadcast_tensor;
  framework::Tensor y_broadcast_tensor;
  auto& dev_ctx =
      ctx.template device_context<paddle::platform::XPUDeviceContext>();
  int ret = xpu::SUCCESS;

  // Broadcast x to the output shape if needed.
  if (x.numel() != z->numel()) {
    std::pair<std::vector<int>, std::vector<int>> bcast_v =
        XPUDimsToBroadcastVector(framework::make_ddim(x_dims_array), out_dim);

    ret = xpu::broadcast<T>(dev_ctx.x_context(), x_data,
                            x_broadcast_tensor.mutable_data<T>(
                                ctx.GetPlace(), z->numel() * sizeof(T)),
                            bcast_v.first, bcast_v.second);
    PADDLE_ENFORCE_EQ(
        ret, xpu::SUCCESS,
        platform::errors::External(
            "XPU broadcast kernel error in XPUElementwise, error code %d",
            ret));
    need_wait = true;
    x_data = x_broadcast_tensor.data<T>();
  }

  // Broadcast y to the output shape if needed.
  if (y->numel() != z->numel()) {
    std::pair<std::vector<int>, std::vector<int>> bcast_v =
        XPUDimsToBroadcastVector(framework::make_ddim(y_dims_array), out_dim);
    ret = xpu::broadcast<T>(dev_ctx.x_context(), y_data,
                            y_broadcast_tensor.mutable_data<T>(
                                ctx.GetPlace(), z->numel() * sizeof(T)),
                            bcast_v.first, bcast_v.second);
    PADDLE_ENFORCE_EQ(
        ret, xpu::SUCCESS,
        platform::errors::External(
            "XPU broadcast kernel error in XPUElementwise, error code %d",
            ret));
    need_wait = true;
    y_data = y_broadcast_tensor.data<T>();
  }

  // Apply the elementwise functor on the (possibly broadcast) inputs.
  int len = z->numel();
  ret = func(dev_ctx.x_context(), x_data, y_data, z_data, len);
  PADDLE_ENFORCE_EQ(
      ret, xpu::SUCCESS,
      platform::errors::External(
          "XPU elementwise kernel error in XPUElementwise, error code %d",
          ret));

  if (need_wait && dev_ctx.x_context()->xpu_stream) {
    dev_ctx.Wait();
  }
}
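// A forward kernel would typically just forward to the helper above with a
// matching device routine. A minimal sketch (assuming xpu::add<T> from
// xpu/refactor/math.h has the int(Context*, const T*, const T*, T*, int)
// signature):
//
//   template <typename T>
//   class ElementwiseAddXPUKernel : public framework::OpKernel<T> {
//    public:
//     void Compute(const framework::ExecutionContext& ctx) const override {
//       XPUElementwise<T>(ctx, xpu::add<T>);
//     }
//   };

// Runs the gradient of a binary elementwise op on XPU: broadcasts X/Y up to
// the output shape when needed (only if use_x_y_data is true), calls the
// given grad functor to produce dX/dY at the broadcasted shape, and finally
// reduce_sums each gradient back to its original shape when it was broadcast
// in the forward pass.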
template <typename T>
void XPUElementwiseGrad(const framework::ExecutionContext& ctx,
                        std::function<int(xpu::Context*, const T*, const T*,
                                          const T*, const T*, T*, T*, int len)>
                            func,
                        bool use_x_y_data) {
  auto* x = ctx.Input<framework::Tensor>("X");
  auto* y = ctx.Input<framework::Tensor>("Y");
  auto* dz = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
  auto* z = dz;
  auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
  auto* dy = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
  int axis = ctx.Attr<int>("axis");
  const framework::DDim& x_dims = x->dims();
  const framework::DDim& y_dims = y->dims();
  int max_dim = std::max(x_dims.size(), y_dims.size());
  axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
  PADDLE_ENFORCE_GE(
      axis, 0,
      platform::errors::InvalidArgument(
          "Axis should be greater than or equal to 0, but received axis is %d.",
          axis));
  PADDLE_ENFORCE_LT(axis, max_dim,
                    platform::errors::InvalidArgument(
                        "Axis should be less than %d, but received axis is %d.",
                        max_dim, axis));

  // Align both shapes to max_dim and compute the broadcasted output shape.
  std::vector<int> x_dims_array(max_dim);
  std::vector<int> y_dims_array(max_dim);
  std::vector<int> out_dims_array(max_dim);
  GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(),
                         y_dims_array.data(), out_dims_array.data(), max_dim,
                         axis);
  framework::DDim out_dim = framework::make_ddim(out_dims_array);

  int len = framework::product(out_dim);

  framework::Tensor x_broadcast_tensor;
  framework::Tensor y_broadcast_tensor;

  framework::Tensor dx_local_tensor;
  framework::Tensor dy_local_tensor;

  bool need_wait = false;
  // If the functor does not need the forward inputs (use_x_y_data == false),
  // the forward output z is substituted for both of them.
  const T* x_data = use_x_y_data ? x->data<T>() : z->data<T>();
  const T* y_data = use_x_y_data ? y->data<T>() : z->data<T>();

  const T* z_data = z->data<T>();
  const T* dz_data = dz->data<T>();

  bool dx_need_reduce = (dx != nullptr) && (dx->numel() != len);
  bool dy_need_reduce = (dy != nullptr) && (dy->numel() != len);

  // When a gradient output is absent or must later be reduced, write it to a
  // local scratch tensor of the broadcasted size instead.
  T* dx_data =
      ((dx == nullptr) || dx_need_reduce)
          ? (dx_local_tensor.mutable_data<T>(ctx.GetPlace(), len * sizeof(T)))
          : (dx->mutable_data<T>(ctx.GetPlace()));

  T* dy_data =
      ((dy == nullptr) || dy_need_reduce)
          ? (dy_local_tensor.mutable_data<T>(ctx.GetPlace(), len * sizeof(T)))
          : (dy->mutable_data<T>(ctx.GetPlace()));

  int ret = xpu::SUCCESS;
  auto& dev_ctx =
      ctx.template device_context<paddle::platform::XPUDeviceContext>();

  // Broadcast x to the output shape if needed.
  if (use_x_y_data && x->numel() != len) {
    std::pair<std::vector<int>, std::vector<int>> bcast_v =
        XPUDimsToBroadcastVector(framework::make_ddim(x_dims_array), out_dim);
    ret = xpu::broadcast<T>(
        dev_ctx.x_context(), x_data,
        x_broadcast_tensor.mutable_data<T>(ctx.GetPlace(), len * sizeof(T)),
        bcast_v.first, bcast_v.second);
    PADDLE_ENFORCE_EQ(ret, xpu::SUCCESS,
                      platform::errors::External(
                          "XPU broadcast kernel error, error code %d", ret));
    need_wait = true;
    x_data = x_broadcast_tensor.data<T>();
  }

  // Broadcast y to the output shape if needed.
  if (use_x_y_data && y->numel() != len) {
    std::pair<std::vector<int>, std::vector<int>> bcast_v =
        XPUDimsToBroadcastVector(framework::make_ddim(y_dims_array), out_dim);
    ret = xpu::broadcast<T>(
        dev_ctx.x_context(), y_data,
        y_broadcast_tensor.mutable_data<T>(ctx.GetPlace(), len * sizeof(T)),
        bcast_v.first, bcast_v.second);
    PADDLE_ENFORCE_EQ(ret, xpu::SUCCESS,
                      platform::errors::External(
                          "XPU broadcast kernel error, error code %d", ret));
    need_wait = true;
    y_data = y_broadcast_tensor.data<T>();
  }

  // Compute dX and dY at the broadcasted shape.
  ret = func(dev_ctx.x_context(), x_data, y_data, z_data, dz_data, dx_data,
             dy_data, len);
  PADDLE_ENFORCE_EQ(ret, xpu::SUCCESS,
                    platform::errors::External(
                        "XPU binary grad kernel error in "
                        "XPUElementwiseGrad, error code %d",
                        ret));

  // Reduce gradients back to the original input shapes when they were
  // broadcast in the forward pass.
  if (dx_need_reduce) {
    const framework::DDim& dx_dims = dx->dims();
    std::pair<std::vector<int>, std::vector<int>> reduce_v =
        XPUReducesAxisVector(out_dim, dx_dims);
    ret = xpu::reduce_sum<T>(dev_ctx.x_context(), dx_data,
                             dx->mutable_data<T>(ctx.GetPlace()),
                             reduce_v.first, reduce_v.second);
    PADDLE_ENFORCE_EQ(ret, xpu::SUCCESS,
                      platform::errors::External(
                          "XPU reduce_sum kernel error in "
                          "XPUElementwiseGrad, error code %d",
                          ret));
    need_wait = true;
  }

  if (dy_need_reduce) {
    const framework::DDim& dy_dims = dy->dims();
    std::pair<std::vector<int>, std::vector<int>> reduce_v =
        XPUReducesAxisVector(out_dim, dy_dims);
    ret = xpu::reduce_sum<T>(dev_ctx.x_context(), dy_data,
                             dy->mutable_data<T>(ctx.GetPlace()),
                             reduce_v.first, reduce_v.second);
    PADDLE_ENFORCE_EQ(ret, xpu::SUCCESS,
                      platform::errors::External(
                          "XPU reduce_sum kernel error in "
                          "XPUElementwiseGrad, error code %d",
                          ret));
    need_wait = true;
  }

  if (need_wait && dev_ctx.x_context()->xpu_stream) {
    dev_ctx.Wait();
  }
}
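// A grad kernel could forward to the helper above with any functor matching
// int(Context*, const T* x, const T* y, const T* z, const T* dz, T* dx,
// T* dy, int len). A minimal sketch (illustrative only; "my_binary_grad" is a
// placeholder name, not a routine from xpu/refactor/math.h):
//
//   template <typename T>
//   class ElementwiseMyOpGradXPUKernel : public framework::OpKernel<T> {
//    public:
//     void Compute(const framework::ExecutionContext& ctx) const override {
//       XPUElementwiseGrad<T>(ctx, my_binary_grad<T>, true /*use_x_y_data*/);
//     }
//   };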
}  // namespace operators
}  // namespace paddle

#endif