commit 5e7aa8c7e5
parent 855c9e3311
Author: fengjiayi (7 years ago)

    python3

@@ -59,6 +59,14 @@ inline T* Tensor::mutable_data(platform::Place place) {
 }
 inline Tensor ReshapeToMatrix(const Tensor& src, int num_col_dims) {
   int rank = src.dims().size();
+  PADDLE_ENFORCE_GE(
+      rank, 2,
+      "'ReshapeToMatrix()' is only used for flatten high rank "
+      "tensors to matrixs. Can not be used in reshaping vectors.");
+  if (rank == 2) {
+    return src;
+  }
   Tensor res;
   res.ShareDataWith(src);
   res.Resize(flatten_to_2d(src.dims(), num_col_dims));
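For reference, the effect of the new guard and early return can be mimicked outside the framework. The following is a minimal standalone sketch, assuming only standard C++; flatten_to_2d_like and the plain std::vector<int64_t> shape are illustrative stand-ins for Paddle's flatten_to_2d and DDim, not the framework API.

// Standalone sketch (not the Paddle API): collapse dims[0, num_col_dims) into
// the row count and dims[num_col_dims, rank) into the column count, mirroring
// what ReshapeToMatrix asks flatten_to_2d to do after the new rank checks.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> flatten_to_2d_like(const std::vector<int64_t>& dims,
                                        int num_col_dims) {
  int rank = static_cast<int>(dims.size());
  assert(rank >= 2);           // mirrors PADDLE_ENFORCE_GE(rank, 2, ...)
  if (rank == 2) return dims;  // mirrors the new early return of src
  int64_t rows = 1, cols = 1;
  for (int i = 0; i < num_col_dims; ++i) rows *= dims[i];
  for (int i = num_col_dims; i < rank; ++i) cols *= dims[i];
  return {rows, cols};
}

int main() {
  // A rank-3 shape [4, 5, 10] with num_col_dims = 2 becomes [4 * 5, 10].
  auto out = flatten_to_2d_like({4, 5, 10}, 2);
  std::cout << out[0] << " x " << out[1] << "\n";  // prints "20 x 10"
}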

@@ -45,11 +45,9 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
                         "Input(Label) should be 1.");
     }
-    auto out_dim_vec =
-        framework::vectorize(framework::slice_ddim(x_dims, 0, rank - 1));
-    out_dim_vec.push_back(1);
-    ctx->SetOutputDim("Y", framework::make_ddim(out_dim_vec));
+    auto y_dims = x_dims;
+    y_dims[rank - 1] = 1;
+    ctx->SetOutputDim("Y", y_dims);
     ctx->ShareLoD("X", /*->*/ "Y");
   }
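Both the removed and the added lines compute the same output shape: Y keeps X's rank and only the last dimension collapses to 1; the new form simply copies the DDim and overwrites one entry instead of going through slice_ddim, vectorize, and make_ddim. A minimal standalone illustration, with std::vector<int64_t> standing in for framework::DDim (not the Paddle API):

#include <cstdint>
#include <iostream>
#include <vector>

// Sketch of the new shape inference: [d0, ..., d_{rank-2}, C] -> [d0, ..., d_{rank-2}, 1].
std::vector<int64_t> InferCrossEntropyYDims(const std::vector<int64_t>& x_dims) {
  auto y_dims = x_dims;  // auto y_dims = x_dims;
  y_dims.back() = 1;     // y_dims[rank - 1] = 1;  (assumes x_dims is non-empty)
  return y_dims;         // ctx->SetOutputDim("Y", y_dims);
}

int main() {
  auto y = InferCrossEntropyYDims({8, 16, 10});
  std::cout << y[0] << ", " << y[1] << ", " << y[2] << "\n";  // prints "8, 16, 1"
}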

@@ -34,10 +34,9 @@ class CrossEntropyOpKernel : public framework::OpKernel<T> {
     y->mutable_data<T>(ctx.GetPlace());
     int rank = x->dims().size();
-    Tensor x_2d = rank > 2 ? framework::ReshapeToMatrix(*x, rank - 1) : *x;
-    Tensor labels_2d =
-        rank > 2 ? framework::ReshapeToMatrix(*labels, rank - 1) : *labels;
-    Tensor y_2d = rank > 2 ? framework::ReshapeToMatrix(*y, rank - 1) : *y;
+    Tensor x_2d = framework::ReshapeToMatrix(*x, rank - 1);
+    Tensor labels_2d = framework::ReshapeToMatrix(*labels, rank - 1);
+    Tensor y_2d = framework::ReshapeToMatrix(*y, rank - 1);
     math::CrossEntropyFunctor<DeviceContext, T>()(
         ctx.template device_context<DeviceContext>(), &y_2d, &x_2d, &labels_2d,
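Because ReshapeToMatrix now returns src unchanged when rank == 2, the kernel can call it unconditionally and the rank > 2 ternaries become redundant. The shapes handed to CrossEntropyFunctor for a hypothetical rank-3 input are sketched below; N, M, and C are illustrative only, and the snippet is plain C++, not runnable against Paddle.

#include <cstdio>

int main() {
  // Hypothetical input: x of shape [N, M, C] with hard labels of shape [N, M, 1].
  const long N = 8, M = 16, C = 10;
  // ReshapeToMatrix(*t, rank - 1) folds every axis except the last into rows:
  std::printf("x_2d:      [%ld, %ld]\n", N * M, C);   // [128, 10]
  std::printf("labels_2d: [%ld, %ld]\n", N * M, 1L);  // [128, 1]
  std::printf("y_2d:      [%ld, %ld]\n", N * M, 1L);  // [128, 1]
  // For rank == 2 all three calls are identity views sharing the original data.
  return 0;
}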

@@ -32,9 +32,8 @@ class SoftmaxKernel : public framework::OpKernel<T> {
     Out->mutable_data<T>(context.GetPlace());
     int rank = X->dims().size();
-    Tensor X_2d = rank > 2 ? framework::ReshapeToMatrix(*X, rank - 1) : *X;
-    Tensor Out_2d =
-        rank > 2 ? framework::ReshapeToMatrix(*Out, rank - 1) : *Out;
+    Tensor X_2d = framework::ReshapeToMatrix(*X, rank - 1);
+    Tensor Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
     math::SoftmaxFunctor<DeviceContext, T>()(
         context.template device_context<DeviceContext>(), &X_2d, &Out_2d);
@@ -53,11 +52,9 @@ class SoftmaxGradKernel : public framework::OpKernel<T> {
     dX->mutable_data<T>(context.GetPlace());
     int rank = Out->dims().size();
-    Tensor Out_2d =
-        rank > 2 ? framework::ReshapeToMatrix(*Out, rank - 1) : *Out;
-    Tensor dOut_2d =
-        rank > 2 ? framework::ReshapeToMatrix(*dOut, rank - 1) : *dOut;
-    Tensor dX_2d = rank > 2 ? framework::ReshapeToMatrix(*dX, rank - 1) : *dX;
+    Tensor Out_2d = framework::ReshapeToMatrix(*Out, rank - 1);
+    Tensor dOut_2d = framework::ReshapeToMatrix(*dOut, rank - 1);
+    Tensor dX_2d = framework::ReshapeToMatrix(*dX, rank - 1);
     math::SoftmaxGradFunctor<DeviceContext, T>()(
         context.template device_context<DeviceContext>(), &Out_2d, &dOut_2d,
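The softmax forward and backward kernels get the same simplification. The standalone sketch below (plain buffers, not math::SoftmaxFunctor) shows why folding all leading axes into rows is safe: softmax acts independently along the last axis, so each row of the [rows, cols] view produced by ReshapeToMatrix(*X, rank - 1) is one softmax problem.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Row-wise softmax over a [rows, cols] buffer, where rows is the product of
// all leading dims and cols is the last dim of the original tensor.
void RowwiseSoftmax(std::vector<float>* data, int rows, int cols) {
  for (int r = 0; r < rows; ++r) {
    float* row = data->data() + r * cols;
    float max_v = *std::max_element(row, row + cols);  // for numerical stability
    float sum = 0.f;
    for (int c = 0; c < cols; ++c) {
      row[c] = std::exp(row[c] - max_v);
      sum += row[c];
    }
    for (int c = 0; c < cols; ++c) row[c] /= sum;
  }
}

int main() {
  // A [2, 2, 3] tensor viewed as [4, 3]: four independent softmaxes of size 3.
  std::vector<float> x = {1, 2, 3, 0, 0, 0, 3, 2, 1, 5, 1, 1};
  RowwiseSoftmax(&x, /*rows=*/4, /*cols=*/3);
  for (int r = 0; r < 4; ++r) {
    std::printf("%.3f %.3f %.3f\n", x[r * 3], x[r * 3 + 1], x[r * 3 + 2]);
  }
  return 0;
}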
