|
|
|
@ -163,8 +163,12 @@ class CPUPyramidHashOPKernel : public framework::OpKernel<T> {
|
|
|
|
|
int _space_len) const {
|
|
|
|
|
for (unsigned int j = 0; j != _num_emb; j += _rand_len) {
|
|
|
|
|
unsigned int pos = XXH32(hash_id, len * sizeof(T), j) % _space_len;
|
|
|
|
|
memcpy(top_pos + j, const_cast<float*>(weights + pos),
|
|
|
|
|
_rand_len * sizeof(T));
|
|
|
|
|
if (_rand_len == 16) {
|
|
|
|
|
memcpy(top_pos + j, const_cast<float*>(weights + pos), 16 * sizeof(T));
|
|
|
|
|
} else {
|
|
|
|
|
memcpy(top_pos + j, const_cast<float*>(weights + pos),
|
|
|
|
|
_rand_len * sizeof(T));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -322,6 +326,8 @@ class PyramidHashOpGrad : public framework::OperatorWithKernel {
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("W"), true, "Input(W) should not be null.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("DropPos"), true,
|
|
|
|
|
"Input(DropPos) should not be null.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("X_Temp_Out"), true,
|
|
|
|
|
"Input(X_Temp_Out) should not be null.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->HasInput(framework::GradVarName("Out")), true,
|
|
|
|
|
"Input(Out@GRAD) of PyramidHashGradOp should not be null.");
|
|
|
|
@ -347,6 +353,7 @@ class PyramidHashGradOpMaker : public framework::SingleGradOpMaker<T> {
|
|
|
|
|
op_desc_ptr->SetInput("X", this->Input("X"));
|
|
|
|
|
op_desc_ptr->SetInput("W", this->Input("W"));
|
|
|
|
|
op_desc_ptr->SetInput("DropPos", this->Output("DropPos"));
|
|
|
|
|
op_desc_ptr->SetInput("X_Temp_Out", this->Output("X_Temp_Out"));
|
|
|
|
|
|
|
|
|
|
op_desc_ptr->SetInput(framework::GradVarName("Out"),
|
|
|
|
|
this->OutputGrad("Out"));
|
|
|
|
@ -380,13 +387,8 @@ class CPUPyramidHashOPGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
int _space_len = ctx.Attr<int>("space_len");
|
|
|
|
|
int _pyramid_layer = ctx.Attr<int>("pyramid_layer");
|
|
|
|
|
|
|
|
|
|
const auto* bottom_data_ori = bottom->data<int32_t>();
|
|
|
|
|
Tensor buff;
|
|
|
|
|
buff.Resize(framework::make_ddim({bottom->dims()[0], bottom->dims()[1]}));
|
|
|
|
|
T* bottom_data = buff.mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
for (size_t i = 0; i < bottom->dims()[0]; i++) {
|
|
|
|
|
bottom_data[i] = bottom_data_ori[i];
|
|
|
|
|
}
|
|
|
|
|
auto* buff = ctx.Input<LoDTensor>("X_Temp_Out");
|
|
|
|
|
auto* bottom_data = buff->data<T>();
|
|
|
|
|
|
|
|
|
|
int _slot_len = bottom->dims()[0];
|
|
|
|
|
if (_slot_len == bottom->lod()[0].size() - 1 &&
|
|
|
|
|