|
|
@ -136,7 +136,7 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
|
|
|
|
sum.mutable_data<T>(framework::make_ddim(sum_dims), ctx.GetPlace());
|
|
|
|
sum.mutable_data<T>(framework::make_ddim(sum_dims), ctx.GetPlace());
|
|
|
|
auto sum_mat = EigenMatrix<T>::From(sum);
|
|
|
|
auto sum_mat = EigenMatrix<T>::From(sum);
|
|
|
|
out->mutable_data<T>(ctx.GetPlace());
|
|
|
|
out->mutable_data<T>(ctx.GetPlace());
|
|
|
|
auto out_mat = framework::EigenVector<T>::Flatten(*out);
|
|
|
|
auto out_mat = framework::EigenMatrix<T>::From(*out);
|
|
|
|
if (bias) {
|
|
|
|
if (bias) {
|
|
|
|
bit_code->Add(*bias, pre_out);
|
|
|
|
bit_code->Add(*bias, pre_out);
|
|
|
|
}
|
|
|
|
}
|
|
|
|