remove pow to speed up in dequantize_log op (#24607)

* remove pow in speed up in dequantize_log test=develop

* remove pow in speed up in dequantize_log test=develop

* fix unittest test=develop
v1.8
Liufang Sang 5 years ago committed by GitHub
parent 9fd1aad6e7
commit 55b664a131
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -31,9 +31,9 @@ struct DequantizeFunctor<platform::CPUDeviceContext, T> {
int ind = in->numel();
for (size_t i = 0; i < (unsigned)ind; i++) {
if (input_data[i] < 0) {
output_data[i] = -std::pow(2.0, dict_data[input_data[i] + 128]);
output_data[i] = -dict_data[input_data[i] + 128];
} else {
output_data[i] = std::pow(2.0, dict_data[input_data[i]]);
output_data[i] = dict_data[input_data[i]];
}
}
}

@ -26,9 +26,9 @@ __global__ void KeDequantize(const T* in, const float* dict, int num,
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < num) {
if (in[idx] < 0) {
out[idx] = -std::pow(static_cast<float>(2.0), dict[in[idx] + 128]);
out[idx] = -dict[in[idx] + 128];
} else {
out[idx] = std::pow(static_cast<float>(2.0), dict[in[idx]]);
out[idx] = dict[in[idx]];
}
}
}

@ -26,9 +26,9 @@ def dequantize_log(x, dict_data):
output_data_f = output_data.flatten()
for i in range(x_f.size):
if x_f[i] < 0:
output_data_f[i] = -np.power(2, dict_data[x_f[i] + 128])
output_data_f[i] = -dict_data[x_f[i] + 128]
else:
output_data_f[i] = np.power(2, dict_data[x_f[i]])
output_data_f[i] = dict_data[x_f[i]]
return output_data_f.reshape(x.shape)

Loading…
Cancel
Save