From e3387e65e499e71d6c24a3fe938900c22c5b4941 Mon Sep 17 00:00:00 2001 From: wangdongxu Date: Fri, 29 Jan 2021 10:53:27 +0800 Subject: [PATCH] fix opencl winograd fp16 bug --- mindspore/lite/src/runtime/kernel/opencl/cl/winograd.cl | 4 ++++ mindspore/lite/src/runtime/opencl/opencl_runtime.cc | 8 ++++---- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/winograd.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/winograd.cl index b9b1644c2b..6eb7f96a9f 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/cl/winograd.cl +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/winograd.cl @@ -52,6 +52,10 @@ __kernel void Winograd4x4To36(__read_only image2d_t input, // height=N*H for (int x = 0; x < 6; x++) { acc += BtD_row[x] * Bt[y * 6 + x]; } +#if FP16_ENABLE + acc = min(acc, HALF_MAX); + acc = max(acc, -HALF_MAX); +#endif WRITE_IMAGE(output, (int2)(tile_hw, y_idx + y), acc); } } diff --git a/mindspore/lite/src/runtime/opencl/opencl_runtime.cc b/mindspore/lite/src/runtime/opencl/opencl_runtime.cc index 989bbee529..471e17977f 100644 --- a/mindspore/lite/src/runtime/opencl/opencl_runtime.cc +++ b/mindspore/lite/src/runtime/opencl/opencl_runtime.cc @@ -367,12 +367,12 @@ int OpenCLRuntime::BuildKernel(cl::Kernel &kernel, const std::string &program_na std::string build_option = default_build_option_; if (fp16_enable_) { build_option += - " -DFLT=half -DFLT4=half4 -DFLT16=half16 -DAS_FLT4=as_half4 -DAS_UINT4=as_ushort4 -DUINT4=ushort4 " - "-DWRITE_IMAGE=write_imageh -DREAD_IMAGE=read_imageh -DTO_FLT=convert_half -DTO_FLT4=convert_half4"; + " -DFP16_ENABLE=1 -DFLT=half -DFLT4=half4 -DFLT16=half16 -DAS_FLT4=as_half4 -DAS_UINT4=as_ushort4 -DUINT4=ushort4" + " -DWRITE_IMAGE=write_imageh -DREAD_IMAGE=read_imageh -DTO_FLT=convert_half -DTO_FLT4=convert_half4"; } else { build_option += - " -DFLT=float -DFLT4=float4 -DFLT16=float16 -DAS_FLT4=as_float4 -DAS_UINT4=as_uint4 -DUINT4=uint4 " - "-DWRITE_IMAGE=write_imagef -DREAD_IMAGE=read_imagef -DTO_FLT=convert_float -DTO_FLT4=convert_float4"; + " -DFP16_ENABLE=0 -DFLT=float -DFLT4=float4 -DFLT16=float16 -DAS_FLT4=as_float4 -DAS_UINT4=as_uint4 -DUINT4=uint4" + " -DWRITE_IMAGE=write_imagef -DREAD_IMAGE=read_imagef -DTO_FLT=convert_float -DTO_FLT4=convert_float4"; } build_option = std::accumulate(build_options_ext.begin(), build_options_ext.end(), build_option,