|
|
|
@ -86,7 +86,6 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
trans(ctx.template device_context<DeviceContext>(), pre_out_data,
|
|
|
|
|
pre_out_data + pre_out->numel(), pre_out_data,
|
|
|
|
|
ClipFunctor<T>(static_cast<T>(-40.0), static_cast<T>(40.0)));
|
|
|
|
|
pre_out_mat = -1 * pre_out_mat;
|
|
|
|
|
bit_code->Sum(*pre_out, out, static_cast<T>(-1));
|
|
|
|
|
// use softrelu to calculate cross entropy
|
|
|
|
|
pre_out_mat.device(place) = (static_cast<T>(1.0) + pre_out_mat.exp()).log();
|
|
|
|
@ -162,16 +161,9 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
bias_grad->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
zero(dev_ctx, bias_grad, static_cast<T>(0.0));
|
|
|
|
|
bit_code->AddGrad(pre_out_grad, bias_grad);
|
|
|
|
|
auto bias_grad_mat = EigenMatrix<T>::From(*bias_grad);
|
|
|
|
|
bias_grad_mat = -1 * bias_grad_mat;
|
|
|
|
|
}
|
|
|
|
|
bit_code->MulGradWeight(pre_out_grad, w_grad, *in);
|
|
|
|
|
bit_code->MulGradError(pre_out_grad, *w, in_grad);
|
|
|
|
|
auto w_grad_mat = EigenMatrix<T>::From(*w_grad);
|
|
|
|
|
auto in_grad_mat = EigenMatrix<T>::From(*in_grad);
|
|
|
|
|
|
|
|
|
|
w_grad_mat = -1 * w_grad_mat;
|
|
|
|
|
in_grad_mat = -1 * in_grad_mat;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|