diff --git a/mindspore/lite/nnacl/infer/batch_to_space_infer.c b/mindspore/lite/nnacl/infer/batch_to_space_infer.c index 5746ddf714..36f0261bcd 100644 --- a/mindspore/lite/nnacl/infer/batch_to_space_infer.c +++ b/mindspore/lite/nnacl/infer/batch_to_space_infer.c @@ -124,11 +124,8 @@ int BatchToSpaceInferShape(const TensorC *const *inputs, size_t inputs_size, Ten return ret; } if (inputs_size == 3) { - if (inputs[0]->data_ == NULL) { - return NNACL_INFER_INVALID; - } if (inputs[1]->data_ == NULL || inputs[2]->data_ == NULL) { - return NNACL_ERR; + return NNACL_INFER_INVALID; } int ret = SetOutputShapeFromInput(inputs, outputs); return ret; diff --git a/mindspore/lite/nnacl/infer/space_to_batch_nd_infer.c b/mindspore/lite/nnacl/infer/space_to_batch_nd_infer.c index 84304424e2..322c67db0f 100644 --- a/mindspore/lite/nnacl/infer/space_to_batch_nd_infer.c +++ b/mindspore/lite/nnacl/infer/space_to_batch_nd_infer.c @@ -121,11 +121,8 @@ int SpaceToBatchNdInferShape(const TensorC *const *inputs, size_t inputs_size, T } } if (inputs_size == 3) { - if (inputs[0]->data_ == NULL) { - return NNACL_INFER_INVALID; - } if (inputs[1]->data_ == NULL || inputs[2]->data_ == NULL) { - return NNACL_ERR; + return NNACL_INFER_INVALID; } int ret = SpaceSetOutputShapeFromInput(inputs, inputs_size, outputs, outputs_size, parameter); if (ret != NNACL_OK) { diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/arithmetic.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/arithmetic.cl index 20280cb72b..3740338daa 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/cl/arithmetic.cl +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/arithmetic.cl @@ -518,13 +518,30 @@ __kernel void BroadcastNHWC4SquaredDifference(__read_only image2d_t input_a, __r int X = get_global_id(0); // C4 int Y = get_global_id(1); // w int Z = get_global_id(2); // H - if (X >= output_shape.w || Y >= output_shape.z || Z >= output_shape.y) { + + if (X >= output_shape.w || Y >= output_shape.z || Z >= output_shape.y * output_shape.x) { return; } - - FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(Y * a_shape.w + X, Z)); - FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(X, 0)); - FLT4 result = pown((a - b), (int4)2); + int H = Z % output_shape.y; + int N = Z / output_shape.y; + int a_c = X < a_shape.w ? X : 0; + int a_w = Y < a_shape.z ? Y : 0; + int a_h = H < a_shape.y ? H : 0; + int a_n = N < a_shape.x ? N : 0; + FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(a_w * a_shape.w + a_c, a_n * a_shape.y + a_h)); + int b_c = X < b_shape.w ? X : 0; + int b_w = Y < b_shape.z ? Y : 0; + int b_h = H < b_shape.y ? H : 0; + int b_n = N < b_shape.x ? N : 0; + FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(b_w * b_shape.w + b_c, b_n * b_shape.y + b_h)); + FLT4 result; + if (broadcastC_flag == 0) { + result = pown((a - b), (int4)2); + } else if (broadcastC_flag == 1) { + result = pown((a.x - b), (int4)2); + } else { + result = pown((a - b.x), (int4)2); + } result = clamp(result, (FLT)(act_min), (FLT)(act_max)); WRITE_IMAGE(output, (int2)(Y * output_shape.w + X, Z), result); } diff --git a/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.cc b/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.cc index 7f9013cac7..feb461bf7b 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.cc @@ -82,8 +82,7 @@ int OpenCLKernel::GetImageSize(size_t idx, lite::opencl::ImageSize *img_size) { img_dtype = CL_HALF_FLOAT; break; } - case kNumberTypeInt8: - case kNumberTypeUInt8: { + case kNumberTypeInt8: { img_dtype = CL_SIGNED_INT8; break; }