fix formax. test=develop

revert-16555-model_data_cryption_link_all_lib
dengkaipeng 6 years ago
parent d54005a7f4
commit ceb31d30f0

@ -40,7 +40,7 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
softmax->mutable_data<T>(context.GetPlace());
loss->mutable_data<T>(context.GetPlace());
int axis_dim = logits->dims()[logits->dims().size()-1];
int axis_dim = logits->dims()[logits->dims().size() - 1];
auto& dev_ctx =
context.template device_context<platform::CPUDeviceContext>();

@ -67,9 +67,10 @@ class CudnnCTCKernel : public framework::OpKernel<T> {
softmax_logits.mutable_data<T>(logits->dims(), ctx.GetPlace());
softmax_logits.set_lod(logits_lod);
int rank = logits->dims().size();
int axis_dim = logits->dims()[rank - 1];
Tensor in_2d = framework::ReshapeToMatrix(*logits, rank - 1);
Tensor out_2d = framework::ReshapeToMatrix(softmax_logits, rank - 1);
math::SoftmaxFunctor<DeviceContext, T, false>()(dev_ctx, -1, &in_2d,
math::SoftmaxFunctor<DeviceContext, T, false>()(dev_ctx, axis_dim, &in_2d,
&out_2d);
// ctc needs sequences data stored in transposed padding format

Loading…
Cancel
Save