|
|
|
@ -400,9 +400,9 @@ kernel::LiteKernel *CpuConvInt8KernelCreator(const std::vector<lite::tensor::Ten
|
|
|
|
|
kernel::LiteKernel *kernel;
|
|
|
|
|
auto filter_quant_size = inputs[kWeightIndex]->GetQuantParams().size();
|
|
|
|
|
if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) {
|
|
|
|
|
kernel = new (std::nothrow) kernel::Convolution3x3Int8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
|
|
|
|
kernel = new (std::nothrow) kernel::ConvolutionInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
|
|
|
|
} else if (kernel_h == 1 && kernel_w == 1 && filter_quant_size == 1) {
|
|
|
|
|
kernel = new (std::nothrow) kernel::Convolution1x1Int8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
|
|
|
|
kernel = new (std::nothrow) kernel::ConvolutionInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
|
|
|
|
} else {
|
|
|
|
|
kernel = new (std::nothrow) kernel::ConvolutionInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
|
|
|
|
}
|
|
|
|
|