add dequantize_log_op and make pyramid hash support int8 weight (#22548)
* add dequantize_log_op and make pyramid hash support int8 weight test=develop
* add unittest and update pyramid hash op test=develop
* remove paddle_enforce test=develop
* fix error message test=develop
* remove incorrect commit test=develop
* fix error message in log_dequantize test=develop
* change 2019 to 2020 test=develop
* remove useless check_grad test=develop

revert-23830-2.0-beta
parent e5fef8f38a
commit 4db031902d
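For context before the diff: a minimal NumPy sketch of the log-quantization scheme this operator reverses. The quantize side is not part of this PR, so quantize_log below is a hypothetical helper inferred from the dequantize kernels: a 128-entry float dictionary stores log2 magnitudes, codes 0..127 keep a positive sign, and codes -128..-1 select dict[code + 128] with a negative sign.

import numpy as np

# Hypothetical encoder (assumption, not part of this PR): map each nonzero
# float to the int8 code whose dictionary entry is closest to log2(|value|).
def quantize_log(x, dict_data):
    codes = np.zeros(x.shape, dtype=np.int8)
    flat_x = x.ravel()
    flat_c = codes.ravel()  # view into `codes` (array is contiguous)
    for i in range(flat_x.size):
        v = float(flat_x[i])  # assumes v != 0 so log2(|v|) is finite
        k = int(np.argmin(np.abs(dict_data - np.log2(abs(v)))))
        flat_c[i] = k if v >= 0 else k - 128  # negative values offset the code
    return codes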
@@ -0,0 +1,103 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/dequantize_log_op.h"
#include <math.h>
#include <string>
#include <vector>

namespace paddle {
namespace operators {

template <typename T>
struct DequantizeFunctor<platform::CPUDeviceContext, T> {
  void operator()(const platform::CPUDeviceContext& dev_ctx,
                  const framework::Tensor* in, const framework::Tensor* dict,
                  framework::Tensor* out) {
    const float* dict_data = dict->data<float>();
    const T* input_data = in->data<T>();
    float* output_data = out->mutable_data<float>(dev_ctx.GetPlace());
    int num = in->numel();
    for (int i = 0; i < num; i++) {
      if (input_data[i] < 0) {
        // negative codes index the dictionary at code + 128 with flipped sign
        output_data[i] = -pow(2, dict_data[input_data[i] + 128]);
      } else {
        output_data[i] = pow(2, dict_data[input_data[i]]);
      }
    }
  }
};

template struct DequantizeFunctor<platform::CPUDeviceContext, int8_t>;

class DequantizeLogOp : public framework::OperatorWithKernel {
 public:
  DequantizeLogOp(const std::string& type,
                  const framework::VariableNameMap& inputs,
                  const framework::VariableNameMap& outputs,
                  const framework::AttributeMap& attrs)
      : OperatorWithKernel(type, inputs, outputs, attrs) {}

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
                      platform::errors::NotFound(
                          "Input(X) of DequantizeLogOp is not found."));
    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
                      platform::errors::NotFound(
                          "Output(Out) of DequantizeLogOp is not found."));

    ctx->ShareDim("X", /*->*/ "Out");
    ctx->ShareLoD("X", /*->*/ "Out");
  }

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
    return framework::OpKernelType(data_type, ctx.device_context());
  }
};

class DequantizeLogOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "(int8 Tensor) The input with int8 type is the "
             "low precision tensor.");
    AddInput("Dict", "(float) The Dict in quantization stage.");
    AddOutput("Out",
              "(float32 Tensor) The output is the dequantized high "
              "precision tensor.");
    AddComment(R"DOC(
DequantizeLogOp operator.

This calculation is an opposite operation of QuantizeLogOp:

    Out[i] = -2^Dict[X[i] + 128],  if X[i] <  0
    Out[i] =  2^Dict[X[i]],        if X[i] >= 0

)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;

REGISTER_OPERATOR(
    dequantize_log, ops::DequantizeLogOp, ops::DequantizeLogOpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(dequantize_log, ops::DequantizeLogKernel<CPU, int8_t>);
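To make the indexing convention in the kernels above concrete, a small worked example in plain NumPy (the toy dictionary values are illustrative only): a non-negative code reads the dictionary directly, while a negative code is shifted by +128 and the result is negated.

import numpy as np

dict_data = np.linspace(-3.0, 4.0, 128).astype('float32')  # toy log2 table

code = 5                                    # non-negative code
print(np.power(2, dict_data[code]))         # +2^dict[5]

code = -3                                   # negative code
print(-np.power(2, dict_data[code + 128]))  # -2^dict[125], sign flipped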
@@ -0,0 +1,58 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/dequantize_log_op.h"

namespace paddle {
namespace operators {

template <typename T>
__global__ void KeDequantize(const T* in, const float* dict, int num,
                             float* out) {
  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx < num) {
    if (in[idx] < 0) {
      out[idx] = -pow(2, dict[in[idx] + 128]);
    } else {
      out[idx] = pow(2, dict[in[idx]]);
    }
  }
}

template <typename T>
struct DequantizeFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& dev_ctx,
                  const framework::Tensor* in, const framework::Tensor* dict,
                  framework::Tensor* out) {
    const T* in_data = in->data<T>();
    const float* dict_data = dict->data<float>();
    float* out_data = out->mutable_data<float>(dev_ctx.GetPlace());

    int num = in->numel();
    int block = 512;
    int grid = (num + block - 1) / block;

    KeDequantize<T><<<grid, block, 0, dev_ctx.stream()>>>(in_data, dict_data,
                                                          num, out_data);
  }
};

template struct DequantizeFunctor<platform::CUDADeviceContext, int8_t>;

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
using CUDA = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(dequantize_log, ops::DequantizeLogKernel<CUDA, int8_t>);
@@ -0,0 +1,46 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
struct DequantizeFunctor {
  void operator()(const DeviceContext& dev_ctx, const framework::Tensor* in,
                  const framework::Tensor* dict, framework::Tensor* out);
};

template <typename DeviceContext, typename T>
class DequantizeLogKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in = ctx.Input<framework::Tensor>("X");
    auto* dict = ctx.Input<framework::Tensor>("Dict");
    auto* out = ctx.Output<framework::Tensor>("Out");

    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    out->mutable_data<float>(dev_ctx.GetPlace());

    DequantizeFunctor<DeviceContext, T>()(dev_ctx, in, dict, out);
  }
};

}  // namespace operators
}  // namespace paddle
File diff suppressed because it is too large
@@ -0,0 +1,53 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
from op_test import OpTest


def dequantize_log(x, dict_data):
    output_data = np.zeros_like(x).astype('float32')
    x_f = x.flatten()
    output_data_f = output_data.flatten()
    for i in range(x_f.size):
        # cast to a Python int so the +128 offset cannot overflow int8
        code = int(x_f[i])
        if code < 0:
            output_data_f[i] = -np.power(2, dict_data[code + 128])
        else:
            output_data_f[i] = np.power(2, dict_data[code])
    return output_data_f.reshape(x.shape)


class TestDequantizeLogOp(OpTest):
    def setUp(self):
        self.op_type = "dequantize_log"
        x = np.random.randint(low=-128, high=127, size=(20, 10)).astype('int8')
        dict_data = np.random.random(128).astype('float32')
        xdq = dequantize_log(x, dict_data)

        self.inputs = {
            'X': np.array(x).astype('int8'),
            'Dict': np.array(dict_data).astype('float32')
        }
        self.outputs = {'Out': xdq}

    def test_check_output(self):
        self.check_output()


if __name__ == "__main__":
    unittest.main()
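As a sanity check on the assumptions above, the hypothetical quantize_log sketch from the top of this page can be composed with the dequantize_log reference from this unittest into a round trip; under those assumptions, the log2-domain reconstruction error is bounded by half the dictionary spacing. A sketch, assuming both functions are in scope:

import numpy as np

rng = np.random.RandomState(0)
w = rng.uniform(0.1, 8.0, size=(4, 5)).astype('float32')
w *= rng.choice([-1.0, 1.0], size=w.shape).astype('float32')

dict_data = np.linspace(-4.0, 3.0, 128).astype('float32')  # covers log2(|w|)
codes = quantize_log(w, dict_data)        # hypothetical encoder sketched above
w_hat = dequantize_log(codes, dict_data)  # reference from the unittest

# error in the log2 domain is at most half the dictionary spacing
err = np.max(np.abs(np.log2(np.abs(w)) - np.log2(np.abs(w_hat))))
print(err, "<=", (3.0 - (-4.0)) / 127 / 2)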