|
|
|
@ -23,59 +23,58 @@ namespace operators {
|
|
|
|
|
|
|
|
|
|
using Tensor = framework::Tensor;
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContext, typename T>
|
|
|
|
|
static inline void TransposeAxisToEnd(const Tensor& x, const Tensor& out,
|
|
|
|
|
Tensor* x_trans, Tensor* out_trans,
|
|
|
|
|
const int axis, std::vector<int> perm,
|
|
|
|
|
const framework::ExecutionContext& ctx) {
|
|
|
|
|
static inline void CalcTransPermAndShapeByAxis(const Tensor& x, const int axis,
|
|
|
|
|
std::vector<int>* perm, std::vector<int>* shape) {
|
|
|
|
|
auto dim_x = x.dims();
|
|
|
|
|
int rank = dim_x.size();
|
|
|
|
|
|
|
|
|
|
if (axis == -1 || axis == rank - 1) {
|
|
|
|
|
*x_trans = x;
|
|
|
|
|
*out_trans = out;
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto& dev_ctx = ctx.template device_context<DeviceContext>();
|
|
|
|
|
std::vector<int> shape;
|
|
|
|
|
for (int i = 0; i < rank - 1; i++) {
|
|
|
|
|
if (i == axis) {
|
|
|
|
|
perm.push_back(rank - 1);
|
|
|
|
|
shape.push_back(dim_x[rank - 1]);
|
|
|
|
|
perm->push_back(rank - 1);
|
|
|
|
|
shape->push_back(dim_x[rank - 1]);
|
|
|
|
|
} else {
|
|
|
|
|
perm.push_back(i);
|
|
|
|
|
shape.push_back(dim_x[i]);
|
|
|
|
|
perm->push_back(i);
|
|
|
|
|
shape->push_back(dim_x[i]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
perm.push_back(axis);
|
|
|
|
|
shape.push_back(dim_x[axis]);
|
|
|
|
|
|
|
|
|
|
x_trans->mutable_data<T>(framework::make_ddim(shape), ctx.GetPlace());
|
|
|
|
|
out_trans->mutable_data<T>(framework::make_ddim(shape), ctx.GetPlace());
|
|
|
|
|
TransCompute<DeviceContext, T>(rank, dev_ctx, x, x_trans, perm);
|
|
|
|
|
TransCompute<DeviceContext, T>(rank, dev_ctx, out, out_trans, perm);
|
|
|
|
|
perm->push_back(axis);
|
|
|
|
|
shape->push_back(dim_x[axis]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContext, typename T>
|
|
|
|
|
class SoftmaxKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
|
auto& dev_ctx = context.template device_context<DeviceContext>();
|
|
|
|
|
auto* X = context.Input<Tensor>("X");
|
|
|
|
|
auto* Out = context.Output<Tensor>("Out");
|
|
|
|
|
const int axis = context.Attr<int>("axis");
|
|
|
|
|
int rank = X->dims().size();
|
|
|
|
|
|
|
|
|
|
// allocate memory on device.
|
|
|
|
|
Out->mutable_data<T>(context.GetPlace());
|
|
|
|
|
|
|
|
|
|
std::vector<int> perm, shape;
|
|
|
|
|
CalcTransPermAndShapeByAxis(*X, axis, &perm, &shape);
|
|
|
|
|
|
|
|
|
|
Tensor X_2d, Out_2d;
|
|
|
|
|
Tensor X_trans, Out_trans;
|
|
|
|
|
std::vector<int> perm;
|
|
|
|
|
TransposeAxisToEnd<DeviceContext, T>(*X, *Out, &X_trans, &Out_trans, axis,
|
|
|
|
|
perm, context);
|
|
|
|
|
if (axis != -1 && axis != rank - 1) {
|
|
|
|
|
X_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
|
|
|
|
|
Out_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
|
|
|
|
|
TransCompute<DeviceContext, T>(rank, dev_ctx, *X, &X_trans, perm);
|
|
|
|
|
TransCompute<DeviceContext, T>(rank, dev_ctx, *Out, &Out_trans, perm);
|
|
|
|
|
X_2d = framework::ReshapeToMatrix(X_trans, rank - 1);
|
|
|
|
|
Out_2d = framework::ReshapeToMatrix(Out_trans, rank - 1);
|
|
|
|
|
} else {
|
|
|
|
|
X_2d = framework::ReshapeToMatrix(*X, rank - 1);
|
|
|
|
|
Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int rank = X->dims().size();
|
|
|
|
|
Tensor X_2d = framework::ReshapeToMatrix(X_trans, rank - 1);
|
|
|
|
|
Tensor Out_2d = framework::ReshapeToMatrix(Out_trans, rank - 1);
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_ON_INFERENCE
|
|
|
|
|
math::SoftmaxFunctor<DeviceContext, T, true>()(
|
|
|
|
@ -86,7 +85,6 @@ class SoftmaxKernel : public framework::OpKernel<T> {
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
if (axis != -1 && axis != rank - 1) {
|
|
|
|
|
auto& dev_ctx = context.template device_context<DeviceContext>();
|
|
|
|
|
TransCompute<DeviceContext, T>(rank, dev_ctx, Out_trans, Out, perm);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -96,21 +94,44 @@ template <typename DeviceContext, typename T>
|
|
|
|
|
class SoftmaxGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
|
auto& dev_ctx = context.template device_context<DeviceContext>();
|
|
|
|
|
auto* Out = context.Input<Tensor>("Out");
|
|
|
|
|
auto* dOut = context.Input<Tensor>(framework::GradVarName("Out"));
|
|
|
|
|
auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
|
|
|
|
|
const int axis = context.Attr<int>("axis");
|
|
|
|
|
int rank = Out->dims().size();
|
|
|
|
|
|
|
|
|
|
// allocate memory on device.
|
|
|
|
|
dX->mutable_data<T>(context.GetPlace());
|
|
|
|
|
|
|
|
|
|
int rank = Out->dims().size();
|
|
|
|
|
Tensor Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
|
|
|
|
|
Tensor dOut_2d = framework::ReshapeToMatrix(*dOut, rank - 1);
|
|
|
|
|
Tensor dX_2d = framework::ReshapeToMatrix(*dX, rank - 1);
|
|
|
|
|
std::vector<int> perm, shape;
|
|
|
|
|
CalcTransPermAndShapeByAxis(*dX, axis, &perm, &shape);
|
|
|
|
|
|
|
|
|
|
Tensor dX_2d, Out_2d, dOut_2d;
|
|
|
|
|
Tensor dX_trans, Out_trans, dOut_trans;
|
|
|
|
|
if (axis != -1 && axis != rank - 1) {
|
|
|
|
|
dX_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
|
|
|
|
|
Out_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
|
|
|
|
|
dOut_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
|
|
|
|
|
TransCompute<DeviceContext, T>(rank, dev_ctx, *dX, &dX_trans, perm);
|
|
|
|
|
TransCompute<DeviceContext, T>(rank, dev_ctx, *Out, &Out_trans, perm);
|
|
|
|
|
TransCompute<DeviceContext, T>(rank, dev_ctx, *dOut, &dOut_trans, perm);
|
|
|
|
|
dX_2d = framework::ReshapeToMatrix(dX_trans, rank - 1);
|
|
|
|
|
Out_2d = framework::ReshapeToMatrix(Out_trans, rank - 1);
|
|
|
|
|
dOut_2d = framework::ReshapeToMatrix(dOut_trans, rank - 1);
|
|
|
|
|
} else {
|
|
|
|
|
dX_2d = framework::ReshapeToMatrix(*dX, rank - 1);
|
|
|
|
|
Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
|
|
|
|
|
dOut_2d = framework::ReshapeToMatrix(*dOut, rank - 1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
math::SoftmaxGradFunctor<DeviceContext, T>()(
|
|
|
|
|
context.template device_context<DeviceContext>(), &Out_2d, &dOut_2d,
|
|
|
|
|
&dX_2d);
|
|
|
|
|
|
|
|
|
|
if (axis != -1 && axis != rank - 1) {
|
|
|
|
|
TransCompute<DeviceContext, T>(rank, dev_ctx, dX_trans, dX, perm);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|