|
|
|
@ -114,6 +114,36 @@ void FakeQuantPerChannelGpuKernel::InitSizeLists() {
|
|
|
|
|
workspace_size_list_.push_back(workspace_size_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void FakeQuantPerChannelGpuKernel::CalFakeQuantizeForTraining(float *input, float *output, float *input_min,
|
|
|
|
|
float *input_max, float *d_nudge_min, float *d_nudge_max,
|
|
|
|
|
float *d_scale, uintptr_t stream_ptr) {
|
|
|
|
|
// calculate the input min and max according by the parameter ema and ema_decay.
|
|
|
|
|
CalMinMaxPerChannel(input, input_min, input_max, input_size_ / sizeof(float), channel_out_, ema_decay_, ema_,
|
|
|
|
|
reinterpret_cast<cudaStream_t>(stream_ptr));
|
|
|
|
|
// control flow for quant_delay
|
|
|
|
|
if (global_step_ >= quant_delay_) {
|
|
|
|
|
// real launch
|
|
|
|
|
CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale, channel_out_,
|
|
|
|
|
reinterpret_cast<cudaStream_t>(stream_ptr));
|
|
|
|
|
CalFakeQuantizePerChannel(input, output, input_size_ / sizeof(float), channel_out_, d_nudge_min, d_nudge_max,
|
|
|
|
|
d_scale, symmetric_, reinterpret_cast<cudaStream_t>(stream_ptr));
|
|
|
|
|
} else {
|
|
|
|
|
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(output, input, input_size_, cudaMemcpyDeviceToDevice),
|
|
|
|
|
"Copy gpu memory failed.");
|
|
|
|
|
}
|
|
|
|
|
global_step_++;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void FakeQuantPerChannelGpuKernel::CalFakeQuantizeForInfer(float *input, float *output, float *input_min,
|
|
|
|
|
float *input_max, float *d_nudge_min, float *d_nudge_max,
|
|
|
|
|
float *d_scale, uintptr_t stream_ptr) {
|
|
|
|
|
// real launch
|
|
|
|
|
CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale, channel_out_,
|
|
|
|
|
reinterpret_cast<cudaStream_t>(stream_ptr));
|
|
|
|
|
CalFakeQuantizePerChannel(input, output, input_size_ / sizeof(float), channel_out_, d_nudge_min, d_nudge_max, d_scale,
|
|
|
|
|
symmetric_, reinterpret_cast<cudaStream_t>(stream_ptr));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool FakeQuantPerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
|
|
|
|
|
const std::vector<AddressPtr> &workspace,
|
|
|
|
|
const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) {
|
|
|
|
@ -126,11 +156,8 @@ bool FakeQuantPerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
|
|
|
|
|
if (input == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "FakeQuantPerChannelGpuKernel input is null.";
|
|
|
|
|
}
|
|
|
|
|
if (input_min == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "FakeQuantPerChannelGpuKernel input min is null.";
|
|
|
|
|
}
|
|
|
|
|
if (input_max == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "FakeQuantPerChannelGpuKernel input max is null.";
|
|
|
|
|
if (input_min == nullptr || input_max == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "FakeQuantPerChannelGpuKernel input min or max is null.";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Allocate space for device copies
|
|
|
|
@ -143,30 +170,11 @@ bool FakeQuantPerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
|
|
|
|
|
"Malloc gpu memory failed");
|
|
|
|
|
CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_nudge_max), sizeof(float) * channel_out_),
|
|
|
|
|
"Malloc gpu memory failed");
|
|
|
|
|
int total_size = input_size_ / sizeof(float);
|
|
|
|
|
bool symmetric = false;
|
|
|
|
|
|
|
|
|
|
if (training_) {
|
|
|
|
|
// calculate the input min and max according by the parameter ema and ema_decay.
|
|
|
|
|
CalMinMaxPerChannel(input, input_min, input_max, total_size, channel_out_, ema_decay_, ema_,
|
|
|
|
|
reinterpret_cast<cudaStream_t>(stream_ptr));
|
|
|
|
|
// control flow for quant_delay
|
|
|
|
|
if (global_step_ >= quant_delay_) {
|
|
|
|
|
// real launch
|
|
|
|
|
CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale, channel_out_,
|
|
|
|
|
reinterpret_cast<cudaStream_t>(stream_ptr));
|
|
|
|
|
CalFakeQuantizePerChannel(input, output, total_size, channel_out_, d_nudge_min, d_nudge_max, d_scale, symmetric,
|
|
|
|
|
reinterpret_cast<cudaStream_t>(stream_ptr));
|
|
|
|
|
} else {
|
|
|
|
|
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(output, input, input_size_, cudaMemcpyDeviceToDevice),
|
|
|
|
|
"Copy gpu memory failed.");
|
|
|
|
|
}
|
|
|
|
|
global_step_++;
|
|
|
|
|
CalFakeQuantizeForTraining(input, output, input_min, input_max, d_nudge_min, d_nudge_max, d_scale, stream_ptr);
|
|
|
|
|
} else {
|
|
|
|
|
// real launch
|
|
|
|
|
CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale, channel_out_,
|
|
|
|
|
reinterpret_cast<cudaStream_t>(stream_ptr));
|
|
|
|
|
CalFakeQuantizePerChannel(input, output, total_size, channel_out_, d_nudge_min, d_nudge_max, d_scale, symmetric,
|
|
|
|
|
reinterpret_cast<cudaStream_t>(stream_ptr));
|
|
|
|
|
CalFakeQuantizeForInfer(input, output, input_min, input_max, d_nudge_min, d_nudge_max, d_scale, stream_ptr);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Cleanup
|
|
|
|
|