|
|
|
@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/operators/math/softmax.h"
|
|
|
|
|
#include "paddle/fluid/operators/softmax_op.h"
|
|
|
|
|
#include "paddle/fluid/framework/op_registry.h"
|
|
|
|
|
#include "paddle/fluid/operators/softmax_op.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
@ -25,7 +25,8 @@ template <typename T>
|
|
|
|
|
class SoftmaxCUDNNKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
|
auto& dev_ctx = context.template device_context<platform::CUDADeviceContext>();
|
|
|
|
|
auto& dev_ctx =
|
|
|
|
|
context.template device_context<platform::CUDADeviceContext>();
|
|
|
|
|
auto* X = context.Input<Tensor>("X");
|
|
|
|
|
auto* Out = context.Output<Tensor>("Out");
|
|
|
|
|
const int axis = context.Attr<int>("axis");
|
|
|
|
@ -41,9 +42,12 @@ class SoftmaxCUDNNKernel : public framework::OpKernel<T> {
|
|
|
|
|
Tensor X_trans, Out_trans;
|
|
|
|
|
if (axis != -1 && axis != rank - 1) {
|
|
|
|
|
X_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
|
|
|
|
|
Out_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
|
|
|
|
|
TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *X, &X_trans, perm);
|
|
|
|
|
TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *Out, &Out_trans, perm);
|
|
|
|
|
Out_trans.mutable_data<T>(framework::make_ddim(shape),
|
|
|
|
|
context.GetPlace());
|
|
|
|
|
TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *X, &X_trans,
|
|
|
|
|
perm);
|
|
|
|
|
TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *Out,
|
|
|
|
|
&Out_trans, perm);
|
|
|
|
|
X_2d = framework::ReshapeToMatrix(X_trans, rank - 1);
|
|
|
|
|
Out_2d = framework::ReshapeToMatrix(Out_trans, rank - 1);
|
|
|
|
|
} else {
|
|
|
|
@ -52,11 +56,12 @@ class SoftmaxCUDNNKernel : public framework::OpKernel<T> {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
math::SoftmaxCUDNNFunctor<T>()(
|
|
|
|
|
context.template device_context<platform::CUDADeviceContext>(),
|
|
|
|
|
&X_2d, &Out_2d);
|
|
|
|
|
context.template device_context<platform::CUDADeviceContext>(), &X_2d,
|
|
|
|
|
&Out_2d);
|
|
|
|
|
|
|
|
|
|
if (axis != -1 && axis != rank - 1) {
|
|
|
|
|
TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, Out_trans, Out, perm);
|
|
|
|
|
TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, Out_trans,
|
|
|
|
|
Out, perm);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -65,7 +70,8 @@ template <typename T>
|
|
|
|
|
class SoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
|
auto& dev_ctx = context.template device_context<platform::CUDADeviceContext>();
|
|
|
|
|
auto& dev_ctx =
|
|
|
|
|
context.template device_context<platform::CUDADeviceContext>();
|
|
|
|
|
auto* Out = context.Input<Tensor>("Out");
|
|
|
|
|
auto* dOut = context.Input<Tensor>(framework::GradVarName("Out"));
|
|
|
|
|
auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
|
|
|
|
@ -82,11 +88,16 @@ class SoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
|
|
|
|
|
Tensor dX_trans, Out_trans, dOut_trans;
|
|
|
|
|
if (axis != -1 && axis != rank - 1) {
|
|
|
|
|
dX_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
|
|
|
|
|
Out_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
|
|
|
|
|
dOut_trans.mutable_data<T>(framework::make_ddim(shape), context.GetPlace());
|
|
|
|
|
TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *dX, &dX_trans, perm);
|
|
|
|
|
TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *Out, &Out_trans, perm);
|
|
|
|
|
TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *dOut, &dOut_trans, perm);
|
|
|
|
|
Out_trans.mutable_data<T>(framework::make_ddim(shape),
|
|
|
|
|
context.GetPlace());
|
|
|
|
|
dOut_trans.mutable_data<T>(framework::make_ddim(shape),
|
|
|
|
|
context.GetPlace());
|
|
|
|
|
TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *dX,
|
|
|
|
|
&dX_trans, perm);
|
|
|
|
|
TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *Out,
|
|
|
|
|
&Out_trans, perm);
|
|
|
|
|
TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, *dOut,
|
|
|
|
|
&dOut_trans, perm);
|
|
|
|
|
dX_2d = framework::ReshapeToMatrix(dX_trans, rank - 1);
|
|
|
|
|
Out_2d = framework::ReshapeToMatrix(Out_trans, rank - 1);
|
|
|
|
|
dOut_2d = framework::ReshapeToMatrix(dOut_trans, rank - 1);
|
|
|
|
@ -97,11 +108,12 @@ class SoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
math::SoftmaxGradCUDNNFunctor<T>()(
|
|
|
|
|
context.template device_context<platform::CUDADeviceContext>(),
|
|
|
|
|
&Out_2d, &dOut_2d, &dX_2d);
|
|
|
|
|
context.template device_context<platform::CUDADeviceContext>(), &Out_2d,
|
|
|
|
|
&dOut_2d, &dX_2d);
|
|
|
|
|
|
|
|
|
|
if (axis != -1 && axis != rank - 1) {
|
|
|
|
|
TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, dX_trans, dX, perm);
|
|
|
|
|
TransCompute<platform::CUDADeviceContext, T>(rank, dev_ctx, dX_trans, dX,
|
|
|
|
|
perm);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|