add lookup_table_dequant_op (#22900)
add lookup_table_dequant_oprevert-22710-feature/integrated_ps_api
parent
a020a25797
commit
5ba9dfc16a
@ -0,0 +1,128 @@
|
||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/lookup_table_dequant_op.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "paddle/fluid/framework/no_need_buffer_vars_inference.h"
|
||||
#include "paddle/fluid/framework/var_type_inference.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
class LookupTableDequantOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
PADDLE_ENFORCE_EQ(
|
||||
ctx->HasInput("W"), true,
|
||||
platform::errors::InvalidArgument(
|
||||
"Input(W) of LookupTableDequantOp should not be null."));
|
||||
PADDLE_ENFORCE_EQ(
|
||||
ctx->HasInput("Ids"), true,
|
||||
platform::errors::InvalidArgument(
|
||||
"Input(Ids) of LookupTableDequantOp should not be null."));
|
||||
PADDLE_ENFORCE_EQ(
|
||||
ctx->HasOutput("Out"), true,
|
||||
platform::errors::InvalidArgument(
|
||||
"Output(Out) of LookupTableDequantOp should not be null."));
|
||||
|
||||
auto table_dims = ctx->GetInputDim("W");
|
||||
auto ids_dims = ctx->GetInputDim("Ids");
|
||||
int ids_rank = ids_dims.size();
|
||||
VLOG(5) << "ids rank is " << ids_rank << std::endl;
|
||||
PADDLE_ENFORCE_EQ(
|
||||
table_dims.size(), 2,
|
||||
platform::errors::InvalidArgument(
|
||||
"ShapeError: The dimensions of the 'lookup table' must be 2. "
|
||||
"But received lookup table's dimensions = %d, "
|
||||
"lookup table's shape = [%s].",
|
||||
table_dims.size(), table_dims));
|
||||
PADDLE_ENFORCE_EQ(
|
||||
ids_dims[ids_rank - 1], 1,
|
||||
platform::errors::InvalidArgument(
|
||||
"ShapeError: The last dimensions of the 'Ids' tensor must be 1. "
|
||||
"But received Ids's last dimensions = %d, Ids's shape = [%s].",
|
||||
ids_dims[ids_rank - 1], ids_dims));
|
||||
|
||||
auto output_dims =
|
||||
framework::vectorize(framework::slice_ddim(ids_dims, 0, ids_rank - 1));
|
||||
PADDLE_ENFORCE_GE(table_dims[1], 2,
|
||||
platform::errors::InvalidArgument(
|
||||
"the second dim of table_dims should be "
|
||||
"greater or equal to 2, but the actual shape "
|
||||
"is [%s]",
|
||||
table_dims));
|
||||
|
||||
output_dims.push_back((table_dims[1] - 2) * 4);
|
||||
ctx->SetOutputDim("Out", framework::make_ddim(output_dims));
|
||||
|
||||
if (ctx->GetOutputsVarType("Out")[0] ==
|
||||
framework::proto::VarType::LOD_TENSOR) {
|
||||
ctx->ShareLoD("Ids", /*->*/ "Out");
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
framework::OpKernelType GetExpectedKernelType(
|
||||
const framework::ExecutionContext& ctx) const override {
|
||||
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "W");
|
||||
return framework::OpKernelType(data_type, ctx.device_context());
|
||||
}
|
||||
};
|
||||
|
||||
class LookupTableDequantOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
void Make() override {
|
||||
AddInput("W",
|
||||
"(Tensor) The input represents embedding tensors, "
|
||||
"This tensor is a quantized tensor");
|
||||
AddInput("Ids",
|
||||
"An input with type int64 "
|
||||
"contains the ids to be looked up in W. "
|
||||
"The last dimension size must be 1.");
|
||||
AddOutput("Out", "The lookup results, which have the same type as W.");
|
||||
AddAttr<int64_t>("padding_idx",
|
||||
"(int64, default -1) "
|
||||
"If the value is -1, it makes no effect to lookup. "
|
||||
"Otherwise the given value indicates padding the output "
|
||||
"with zeros whenever lookup encounters it in Ids.")
|
||||
.SetDefault(kNoPadding);
|
||||
AddComment(R"DOC(
|
||||
Lookup Table Dequant Operator.
|
||||
|
||||
The `W` input is a quantized parameter for the sake of saving memories.
|
||||
This operator first index embeddings with `Ids`,
|
||||
then dequantizes them and contact them as output (`Out`).
|
||||
|
||||
The input Ids can carry the LoD (Level of Details) information,
|
||||
or not. And the output only shares the LoD information with input Ids.
|
||||
|
||||
)DOC");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OPERATOR(
|
||||
lookup_table_dequant, ops::LookupTableDequantOp,
|
||||
ops::LookupTableDequantOpMaker,
|
||||
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
|
||||
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
|
||||
REGISTER_OP_CPU_KERNEL(lookup_table_dequant,
|
||||
ops::LookupTableDequantKernel<float>);
|
@ -0,0 +1,109 @@
|
||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "paddle/fluid/framework/eigen.h"
|
||||
#include "paddle/fluid/framework/lod_tensor.h"
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/framework/selected_rows.h"
|
||||
#include "paddle/fluid/framework/var_type_traits.h"
|
||||
#include "paddle/fluid/operators/math/blas.h"
|
||||
|
||||
#ifdef PADDLE_WITH_DISTRIBUTE
|
||||
#include "paddle/fluid/operators/distributed/parameter_prefetch.h"
|
||||
#endif
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using Tensor = framework::Tensor;
|
||||
using LoDTensor = framework::LoDTensor;
|
||||
using SelectedRows = framework::SelectedRows;
|
||||
using DDim = framework::DDim;
|
||||
|
||||
template <typename T>
|
||||
void dequant(const unsigned char *in, T *out, float min, float max,
|
||||
int emb_size, int pow_2_bits) {
|
||||
float scale = (max - min) / pow_2_bits;
|
||||
for (int i = 0; i < emb_size; ++i) {
|
||||
T x = scale * static_cast<int>(in[i]) + min;
|
||||
out[i] = x;
|
||||
}
|
||||
}
|
||||
|
||||
constexpr int64_t kNoPadding = -1;
|
||||
|
||||
template <typename T>
|
||||
class LookupTableDequantKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext &context) const override {
|
||||
auto *ids_t = context.Input<LoDTensor>("Ids"); // int tensor
|
||||
auto *output_t = context.Output<LoDTensor>("Out"); // float tensor
|
||||
auto *table_var = context.InputVar("W");
|
||||
|
||||
auto id_name = context.InputNames("Ids").front();
|
||||
auto embedding_name = context.InputNames("W").front();
|
||||
auto out_name = context.OutputNames("Out").front();
|
||||
|
||||
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
|
||||
auto *ids = ids_t->data<int64_t>();
|
||||
int64_t ids_numel = ids_t->numel();
|
||||
|
||||
PADDLE_ENFORCE_GE(
|
||||
table_var->Type(), framework::VarTypeTrait<LoDTensor>::kId,
|
||||
platform::errors::InvalidArgument("lookup table must be LodTensor"));
|
||||
auto *table_t = context.Input<LoDTensor>("W");
|
||||
int64_t row_number = table_t->dims()[0];
|
||||
int64_t quant_number = table_t->dims()[1];
|
||||
int64_t row_width = (quant_number - 2) * 4;
|
||||
|
||||
auto *table = table_t->data<float>();
|
||||
auto *output = output_t->mutable_data<T>(context.GetPlace());
|
||||
int pow_2_bits = static_cast<int>(pow(2, 8));
|
||||
for (int64_t i = 0; i < ids_numel; ++i) {
|
||||
if (padding_idx != kNoPadding && ids[i] == padding_idx) {
|
||||
memset(output + i * row_width, 0, row_width * sizeof(T));
|
||||
} else {
|
||||
PADDLE_ENFORCE_LT(
|
||||
ids[i], row_number,
|
||||
platform::errors::InvalidArgument(
|
||||
"Variable value (input) of OP(fluid.layers.embedding) "
|
||||
"expected >= 0 and < %ld, but got %ld. Please check input "
|
||||
"value.",
|
||||
row_number, ids[i]));
|
||||
PADDLE_ENFORCE_GE(
|
||||
ids[i], 0,
|
||||
platform::errors::InvalidArgument(
|
||||
"Variable value (input) of OP(fluid.layers.embedding) "
|
||||
"expected >= 0 and < %ld, but got %ld. Please check input "
|
||||
"value.",
|
||||
row_number, ids[i]));
|
||||
float min = *(table + ids[i] * quant_number);
|
||||
float max = *(table + ids[i] * quant_number + 1);
|
||||
int offset = ids[i] * quant_number + 2;
|
||||
const unsigned char *tensor_buf =
|
||||
reinterpret_cast<const unsigned char *>(table + offset);
|
||||
dequant(tensor_buf, output + i * row_width, min, max, row_width,
|
||||
pow_2_bits);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,55 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
from op_test import OpTest, skip_check_grad_ci
|
||||
import paddle.fluid.core as core
|
||||
from paddle.fluid.op import Operator
|
||||
import paddle.compat as cpt
|
||||
import paddle.fluid as fluid
|
||||
from paddle.fluid import Program, program_guard
|
||||
import struct
|
||||
|
||||
|
||||
class TestLookupTableDequantOp(OpTest):
|
||||
def setUp(self):
|
||||
self.op_type = "lookup_table_dequant"
|
||||
table = np.random.random((17, 32)).astype("float32")
|
||||
ids = np.random.randint(0, 17, 4).astype("int64")
|
||||
ids_expand = np.expand_dims(ids, axis=1)
|
||||
self.inputs = {'W': table, 'Ids': ids_expand}
|
||||
|
||||
# calculate output
|
||||
output = []
|
||||
for id in ids:
|
||||
tmp = []
|
||||
min, max = table[id][0], table[id][1]
|
||||
for val in table[id][2:]:
|
||||
tmp += [
|
||||
int(x) * (max - min) / pow(2, 8) + min
|
||||
for x in bytearray(struct.pack("f", val))
|
||||
]
|
||||
output.append(tmp)
|
||||
|
||||
self.outputs = {'Out': np.asarray(output, dtype="float32")}
|
||||
|
||||
def test_check_output(self):
|
||||
self.check_output()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in new issue