|
|
@ -26,9 +26,6 @@ class HashOp : public framework::OperatorWithKernel {
|
|
|
|
: OperatorWithKernel(type, inputs, outputs, attrs) {}
|
|
|
|
: OperatorWithKernel(type, inputs, outputs, attrs) {}
|
|
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
if (ctx->IsRuntime()) {
|
|
|
|
|
|
|
|
return;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"),
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"),
|
|
|
|
"Input(X) of HashOp should not be null.");
|
|
|
|
"Input(X) of HashOp should not be null.");
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
@ -57,6 +54,8 @@ $$Out = scale * X$$
|
|
|
|
)DOC");
|
|
|
|
)DOC");
|
|
|
|
AddAttr<int>("num_hash", "").SetDefault(1);
|
|
|
|
AddAttr<int>("num_hash", "").SetDefault(1);
|
|
|
|
AddAttr<int>("mod_by", "").SetDefault(100000);
|
|
|
|
AddAttr<int>("mod_by", "").SetDefault(100000);
|
|
|
|
|
|
|
|
AddAttr<bool>(framework::kAllKernelsMustComputeRuntimeShape, "")
|
|
|
|
|
|
|
|
.SetDefault(true);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|