@@ -42,13 +42,13 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
     int64_t code_length = math::FindLastSet(num_classes - 1);
     int64_t batch_size = in->dims()[0];
     framework::Tensor sum;
-    math::SetConstant<DeviceContext, T> zero;
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
     auto* pre_out_data = pre_out->mutable_data<T>(
         framework::make_ddim({batch_size, code_length}), ctx.GetPlace());
     auto pre_out_mat = EigenMatrix<T>::From(*pre_out);
     // Not all class(leaf) nodes' path lengths equal code_length, thus init as
     // 0s can avoid out of path's loss.
+    math::SetConstant<DeviceContext, T> zero;
     zero(dev_ctx, pre_out, static_cast<T>(0.0));
     auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
     math::RowwiseSum<DeviceContext, T> row_sum;
@@ -72,6 +72,10 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
     // use softrelu to calculate cross entropy
     pre_out_mat.device(place) = (static_cast<T>(1.0) + pre_out_mat.exp()).log();
     row_sum(dev_ctx, *pre_out, &sum);
+    // TODO(guosheng): Subtract the out of path's loss, since not all
+    // class(leaf) nodes' path lengths equal code_length. But it won't break the
+    // gradient check since both have the out of path's loss and will cancel out
+    // each other.
     out_mat.device(place) = sum_mat + out_mat;
   }
 };
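
Note on the forward hunks above: "softrelu" is the softplus function, softplus(z) = log(1 + exp(z)), applied element-wise to the clipped pre-activations of the nodes on each label's path before the row-wise sum, and code_length = FindLastSet(num_classes - 1) appears to be the number of bits needed to address num_classes leaves, i.e. the longest such path. Since PreOut is zero-initialized, entries past the end of a shorter path still contribute softplus(0) = log 2 to the row sum; that is the out-of-path loss the new TODO refers to. The gradient hunk below never recomputes the pre-activation: it recovers the sigmoid straight from the stored softplus value via sigmoid(z) = 1 - exp(-softplus(z)). A minimal standalone sketch of that identity (illustrative C++ only, not part of the patch):

#include <cmath>
#include <cstdio>

int main() {
  double z = 0.7;                                     // pre-activation of one path node (arbitrary value)
  double softplus = std::log(1.0 + std::exp(z));      // what the forward kernel stores in PreOut
  double sigmoid = 1.0 / (1.0 + std::exp(-z));        // d softplus(z) / d z
  double recovered = 1.0 - 1.0 / std::exp(softplus);  // what the backward kernel derives from PreOut
  std::printf("%f %f\n", sigmoid, recovered);         // prints the same value twice
  return 0;
}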
@@ -90,33 +94,38 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
     auto* pre_out = ctx.Input<framework::Tensor>("PreOut");
     auto* out_grad =
         ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
+    framework::Tensor pre_out_grad;
+
+    pre_out_grad.mutable_data<T>(pre_out->dims(), ctx.GetPlace());
+    in_grad->mutable_data<T>(ctx.GetPlace());
+    w_grad->mutable_data<T>(ctx.GetPlace());
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
     math::SetConstant<DeviceContext, T> zero;
     zero(dev_ctx, in_grad, static_cast<T>(0.0));
     zero(dev_ctx, w_grad, static_cast<T>(0.0));
+
     size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
-    int64_t code_length = math::FindLastSet(num_classes - 1);
-    int64_t batch_size = in->dims()[0];
-    framework::Tensor pre_out_grad;
-    pre_out_grad.mutable_data<T>(
-        framework::make_ddim({batch_size, code_length}), ctx.GetPlace());
+    math::MatrixBitCodeFunctor<T> bit_code(num_classes, label->data<int64_t>());
+
     auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
     auto pre_out_mat = EigenMatrix<T>::From(*pre_out);
     auto pre_out_grad_mat = EigenMatrix<T>::From(pre_out_grad);
-    math::MatrixBitCodeFunctor<T> bit_code(num_classes, label->data<int64_t>());
-    Eigen::array<int, 2> bcast({{1, static_cast<int>(pre_out_grad.dims()[1])}});
     auto out_grad_mat = EigenMatrix<T>::From(*out_grad);
-    pre_out_grad_mat = out_grad_mat.broadcast(bcast);
+    Eigen::array<int, 2> bcast({{1, static_cast<int>(pre_out_grad.dims()[1])}});
+
+    // softrelu derivative
+    pre_out_grad_mat.device(place) =
+        static_cast<T>(1.0) - static_cast<T>(1.0) / pre_out_mat.exp();
+    bit_code.Sub(&pre_out_grad);  // the gradient of clip(w * x + b)
     pre_out_grad_mat.device(place) =
-        pre_out_grad_mat *
-        (static_cast<T>(1.0) -
-         static_cast<T>(1.0) / pre_out_mat.exp());  // softrelu derivative
-    bit_code.Sub(&pre_out_grad);
+        pre_out_grad_mat * out_grad_mat.broadcast(bcast);
+    // TODO(guosheng): multiply pre_out_grad with subgradient of clipping to
+    // be consistent with the clipping in forward.
     if (bias_grad) {
+      bias_grad->mutable_data<T>(ctx.GetPlace());
+      zero(dev_ctx, bias_grad, static_cast<T>(0.0));
       bit_code.AddGrad(pre_out_grad, bias_grad);
     }
-    in_grad->mutable_data<T>(ctx.GetPlace());
-    w_grad->mutable_data<T>(ctx.GetPlace());
     bit_code.MulGradWeight(pre_out_grad, w_grad, *in);
     bit_code.MulGradError(pre_out_grad, *w, in_grad);
   }
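
Note on the gradient hunk above: PreOut holds p = softplus(z) from the forward pass, so the rewritten derivative line computes 1 - 1/exp(p) = 1 - exp(-softplus(z)) = sigmoid(z), the derivative of softplus with respect to the pre-activation z. As I read bit_code.Sub, it then subtracts the node's path bit, giving the usual sigmoid(z) - bit form of the hierarchical-sigmoid gradient, which is finally scaled by the Out gradient broadcast across each row; this is the same identity sketched after the forward kernel. As the remaining TODO notes, the subgradient of the forward clipping is not applied here.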