add dequantize_abs_max op and modify lookup_table op (#20899)
* add int8 kernel to lookup_table op and add dequantize op test=develop * change paddle_enforce to paddle_enforce_eq test=develop * change copyright and change some not suitable code test=develop * remove debug log test=develop * replace GetInputType with IndicateVarDataType test=develop * fix EmptyGradMaker test=develop * fix diff between cpu and gpu test=develop * use memcopy when int8_t test=develop (branch: revert-21172-masked_select_api)
parent
a6ce2306f9
commit
f0b1518438
@ -0,0 +1,98 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/dequantize_abs_max_op.h"
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
template <typename T>
|
||||
struct DequantizeFunctor<platform::CPUDeviceContext, T> {
|
||||
void operator()(const platform::CPUDeviceContext& dev_ctx,
|
||||
const framework::Tensor* in, const framework::Tensor* scale,
|
||||
float max_range, framework::Tensor* out) {
|
||||
const float* scale_factor = scale->data<float>();
|
||||
const T* input_data = in->data<T>();
|
||||
float* output_data = out->mutable_data<float>(dev_ctx.GetPlace());
|
||||
int ind = in->numel();
|
||||
for (size_t i = 0; i < (unsigned)ind; i++) {
|
||||
output_data[i] = scale_factor[0] * input_data[i] / max_range;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template struct DequantizeFunctor<platform::CPUDeviceContext, int8_t>;
|
||||
|
||||
class DequantizeMaxAbsOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
DequantizeMaxAbsOp(const std::string& type,
|
||||
const framework::VariableNameMap& inputs,
|
||||
const framework::VariableNameMap& outputs,
|
||||
const framework::AttributeMap& attrs)
|
||||
: OperatorWithKernel(type, inputs, outputs, attrs) {}
|
||||
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
|
||||
"Input(X) of DequantizeMaxAbsOp should not be null.");
|
||||
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
|
||||
"Output(Out) of DequantizeMaxAbsOp should not be null.");
|
||||
|
||||
ctx->ShareDim("X", /*->*/ "Out");
|
||||
ctx->ShareLoD("X", /*->*/ "Out");
|
||||
}
|
||||
|
||||
framework::OpKernelType GetExpectedKernelType(
|
||||
const framework::ExecutionContext& ctx) const {
|
||||
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
|
||||
auto type = framework::OpKernelType(data_type, ctx.device_context());
|
||||
return type;
|
||||
}
|
||||
};
|
||||
|
||||
// Proto maker for dequantize_abs_max: declares the op's inputs, output,
// attribute and user-facing documentation. The strings below are part of the
// op's registered interface and must not change.
class DequantizeMaxAbsOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "(int8 Tensor) The input with int8 type is the "
             "low precision tensor.");
    AddInput("Scale", "(float) The scale in quantization stage.");
    AddOutput("Out",
              "(float32 Tensor) The output is the dequantized high "
              "precision tensor.");
    AddAttr<float>("max_range", "(float) The max range in quantization stage.");
    AddComment(R"DOC(
DequantizeMaxAbsOp operator.

This calculation is an opposite operation of QuantizeMaxAbsOp:

$$Out = \frac{scale*X}{ max\_range }$$

)DOC");
  }
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;

// dequantize_abs_max has no gradient; register empty grad makers for both
// the static-graph (OpDesc) and imperative (OpBase) modes.
REGISTER_OPERATOR(
    dequantize_abs_max, ops::DequantizeMaxAbsOp, ops::DequantizeMaxAbsOpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
// Only an int8 input kernel is registered on CPU.
REGISTER_OP_CPU_KERNEL(dequantize_abs_max,
                       ops::DequantizeMaxAbsKernel<CPU, int8_t>);
|
@ -0,0 +1,55 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/dequantize_abs_max_op.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
// One thread per element: out[i] = in[i] * scale[0] / max_range.
// Threads past `num` exit immediately.
template <typename T>
__global__ void KeDequantize(const T* in, const float* scale, float max_range,
                             int num, float* out) {
  const int idx = blockDim.x * blockIdx.x + threadIdx.x;
  if (idx >= num) return;
  out[idx] = in[idx] * scale[0] / max_range;
}
|
||||
|
||||
// GPU specialization of the dequantize functor: launches KeDequantize with
// one thread per input element on the device context's stream.
template <typename T>
struct DequantizeFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& dev_ctx,
                  const framework::Tensor* in, const framework::Tensor* scale,
                  float max_range, framework::Tensor* out) {
    const int num = in->numel();
    const T* in_data = in->data<T>();
    const float* scale_data = scale->data<float>();
    float* out_data = out->mutable_data<float>(dev_ctx.GetPlace());

    const int threads = 512;
    const int blocks = (num + threads - 1) / threads;

    KeDequantize<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
        in_data, scale_data, max_range, num, out_data);
  }
};

template struct DequantizeFunctor<platform::CUDADeviceContext, int8_t>;
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
using CUDA = paddle::platform::CUDADeviceContext;
// Only an int8 input kernel is registered on CUDA, mirroring the CPU side.
REGISTER_OP_CUDA_KERNEL(dequantize_abs_max,
                        ops::DequantizeMaxAbsKernel<CUDA, int8_t>);
|
@ -0,0 +1,50 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/ddim.h"
|
||||
#include "paddle/fluid/framework/eigen.h"
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
// Device-agnostic dequantize functor: out = scale[0] * in / max_range.
// Specialized per device context in the corresponding .cc (CPU) and .cu (GPU)
// translation units.
template <typename DeviceContext, typename T>
struct DequantizeFunctor {
  void operator()(const DeviceContext& dev_ctx, const framework::Tensor* in,
                  const framework::Tensor* scale, float max_range,
                  framework::Tensor* out);
};
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
class DequantizeMaxAbsKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
virtual void Compute(const framework::ExecutionContext& ctx) const {
|
||||
auto* in = ctx.Input<framework::Tensor>("X");
|
||||
auto* scale = ctx.Input<framework::Tensor>("Scale");
|
||||
auto* out = ctx.Output<framework::Tensor>("Out");
|
||||
|
||||
float max_range = ctx.Attr<float>("max_range");
|
||||
|
||||
auto& dev_ctx = ctx.template device_context<DeviceContext>();
|
||||
out->mutable_data<float>(dev_ctx.GetPlace());
|
||||
|
||||
DequantizeFunctor<DeviceContext, T>()(dev_ctx, in, scale, max_range, out);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,66 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
import math
|
||||
from op_test import OpTest
|
||||
|
||||
|
||||
def quantize_max_abs(x, max_range):
    """Symmetric quantization reference: scale is max(|x|) and values are
    rounded into [-max_range, max_range]."""
    scale = np.abs(x).max()
    quantized = np.round(x / scale * max_range)
    return quantized, scale
|
||||
|
||||
|
||||
def dequantize_max_abs(x, scale, max_range):
    """Reference inverse of quantize_max_abs: rescale quantized values back
    to floats via x * (scale / max_range)."""
    return x * (scale / max_range)
|
||||
|
||||
|
||||
class TestDequantizeMaxAbsOp(OpTest):
    """Checks dequantize_abs_max against the numpy reference implementation."""

    def set_args(self):
        # 8-bit symmetric quantization: max_range = 2^7 - 1 = 127.
        self.num_bits = 8
        self.max_range = math.pow(2, self.num_bits - 1) - 1
        self.data_type = "int8"

    def setUp(self):
        self.set_args()
        self.op_type = "dequantize_abs_max"
        # Quantize random int8 data, then compute the expected dequantized
        # output with the reference helpers above.
        raw = np.random.randn(31, 65).astype(self.data_type)
        quantized, scale = quantize_max_abs(raw, self.max_range)
        expected = dequantize_max_abs(quantized, scale, self.max_range)

        self.inputs = {
            'X': np.array(quantized).astype(self.data_type),
            'Scale': np.array(scale).astype('float32')
        }
        self.attrs = {'max_range': self.max_range}
        self.outputs = {'Out': expected}

    def test_check_output(self):
        self.check_output()
|
||||
|
||||
|
||||
class TestDequantizeMaxAbsOp5Bits(TestDequantizeMaxAbsOp):
    # Same test flow as the base class but with 5-bit quantization
    # (max_range = 2^4 - 1 = 15); the tensor is still stored as int8.
    def set_args(self):
        self.num_bits = 5
        self.max_range = math.pow(2, self.num_bits - 1) - 1
        self.data_type = "int8"
|
||||
|
||||
|
||||
# Run the test suite when this file is executed directly.
if __name__ == "__main__":
    unittest.main()
|
Loading…
Reference in new issue