fix unittest. test=develop

revert-16555-model_data_cryption_link_all_lib
dengkaipeng 6 years ago
parent 90bd038d35
commit d54005a7f4

@ -40,10 +40,12 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
softmax->mutable_data<T>(context.GetPlace());
loss->mutable_data<T>(context.GetPlace());
int axis_dim = logits->dims()[logits->dims().size()-1];
auto& dev_ctx =
context.template device_context<platform::CPUDeviceContext>();
math::SoftmaxFunctor<platform::CPUDeviceContext, T, false>()(
dev_ctx, -1, logits, softmax);
dev_ctx, axis_dim, logits, softmax);
math::CrossEntropyFunctor<platform::CPUDeviceContext, T>()(
dev_ctx, loss, softmax, labels, context.Attr<bool>("soft_label"),
context.Attr<int>("ignore_index"));

Loading…
Cancel
Save