|
|
|
@ -62,7 +62,7 @@ class HierarchicalSigmoidOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Ids"), "Input(Ids) should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("PreOut"),
|
|
|
|
@ -87,19 +87,18 @@ class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
void Make() override {
|
|
|
|
|
AddInput("X",
|
|
|
|
|
"(Tensor, required) The input Tensor, which the shape is"
|
|
|
|
|
"[N * D], which N is the size of mini-batch,"
|
|
|
|
|
"[N, D], which N is the size of mini-batch,"
|
|
|
|
|
"D is the embded size");
|
|
|
|
|
AddInput("W",
|
|
|
|
|
"(Tensor, required), The parameters of hierarchical "
|
|
|
|
|
"sigmoid operator, each of them is s a 3-D tensor, the shape is"
|
|
|
|
|
"sigmoid operator, each of them is s a 2-D tensor, the shape is"
|
|
|
|
|
"[num_classes - 1, D]");
|
|
|
|
|
AddInput("Ids",
|
|
|
|
|
AddInput("Label",
|
|
|
|
|
"(Tensor, required), The labels of training data. It's a"
|
|
|
|
|
"1-D tensor, which the shape is [1, N]");
|
|
|
|
|
AddInput("Bias",
|
|
|
|
|
"(Tensor, optional), The bias is a 1-D tensor, "
|
|
|
|
|
"which is applied to the output, the shape is"
|
|
|
|
|
"[1, num_classes -1]");
|
|
|
|
|
"(Tensor, optional), The bias is a tensor with shape"
|
|
|
|
|
"[1, num_classes - 1]");
|
|
|
|
|
AddOutput("Out",
|
|
|
|
|
"(Tensor, required) The output of hierarchical sigmoid operator."
|
|
|
|
|
"the shape is [N, 1]");
|
|
|
|
@ -111,7 +110,7 @@ class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
.SetDefault(2);
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
The hierarchical sigmoid operator organize the classes into a binary tree.
|
|
|
|
|
At each node, a sigmoid function is used to caculate the probability of
|
|
|
|
|
At each node, a sigmoid function is used to calculate the probability of
|
|
|
|
|
belonging to the right branch. This idea is from
|
|
|
|
|
"F. Morin, Y. Bengio (AISTATS 05):
|
|
|
|
|
Hierarchical Probabilistic Neural Network Language Model."
|
|
|
|
@ -124,7 +123,7 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Ids"), "Input(Ids) should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("PreOut"),
|
|
|
|
|
"Input(Preout) should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("W")),
|
|
|
|
@ -155,9 +154,14 @@ REGISTER_OPERATOR(hierarchical_sigmoid, ops::HierarchicalSigmoidOp,
|
|
|
|
|
ops::HierarchicalSigmoidOpMaker<int>,
|
|
|
|
|
paddle::framework::DefaultGradOpDescMaker<true>);
|
|
|
|
|
// Registers the operator name `hierarchical_sigmoid_grad` with the
// ops::HierarchicalSigmoidGradOp class. No proto maker or grad-desc maker is
// passed here, unlike the `hierarchical_sigmoid` registration just before it.
REGISTER_OPERATOR(hierarchical_sigmoid_grad, ops::HierarchicalSigmoidGradOp);
|
|
|
|
|
// Registers a single CPU kernel for `hierarchical_sigmoid`:
// HierarchicalSigmoidOpKernel<CPUDeviceContext, float>.
// NOTE(review): the same op name is passed to REGISTER_OP_CPU_KERNEL again
// further down with both float and double kernels; per the hunk header
// (-155,9 +154,14) this float-only form looks like the superseded side of the
// diff -- confirm that only one registration remains in the real file.
REGISTER_OP_CPU_KERNEL(hierarchical_sigmoid,
|
|
|
|
|
ops::HierarchicalSigmoidOpKernel<
|
|
|
|
|
paddle::platform::CPUDeviceContext, float>);
|
|
|
|
|
// Registers a single CPU kernel for `hierarchical_sigmoid_grad`:
// HierarchicalSigmoidGradOpKernel<CPUDeviceContext, float>.
// NOTE(review): the same op name is passed to REGISTER_OP_CPU_KERNEL again
// further down with both float and double kernels; this float-only form looks
// like the superseded side of the diff -- confirm that only one registration
// remains in the real file.
REGISTER_OP_CPU_KERNEL(hierarchical_sigmoid_grad,
|
|
|
|
|
ops::HierarchicalSigmoidGradOpKernel<
|
|
|
|
|
paddle::platform::CPUDeviceContext, float>);
|
|
|
|
|
// Registers the CPU kernels for `hierarchical_sigmoid`: two instantiations of
// HierarchicalSigmoidOpKernel on CPUDeviceContext, one with `float` and one
// with `double` as the second template argument.
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
hierarchical_sigmoid,
|
|
|
|
|
ops::HierarchicalSigmoidOpKernel<paddle::platform::CPUDeviceContext, float>,
|
|
|
|
|
ops::HierarchicalSigmoidOpKernel<paddle::platform::CPUDeviceContext,
|
|
|
|
|
double>);
|
|
|
|
|
// Registers the CPU kernels for `hierarchical_sigmoid_grad`: two
// instantiations of HierarchicalSigmoidGradOpKernel on CPUDeviceContext, one
// with `float` and one with `double` as the second template argument.
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
hierarchical_sigmoid_grad,
|
|
|
|
|
ops::HierarchicalSigmoidGradOpKernel<paddle::platform::CPUDeviceContext,
|
|
|
|
|
float>,
|
|
|
|
|
ops::HierarchicalSigmoidGradOpKernel<paddle::platform::CPUDeviceContext,
|
|
|
|
|
double>);
|
|
|
|
|