!14689 fix batch_to_space&space_to_batch&SquaredDifference bug

From: @yeyunpeng2020
Reviewed-by: @ddwsky,@jpc_chenjianping
Signed-off-by: @ddwsky
pull/14689/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 0e127e86ed

@ -124,11 +124,8 @@ int BatchToSpaceInferShape(const TensorC *const *inputs, size_t inputs_size, Ten
return ret; return ret;
} }
if (inputs_size == 3) { if (inputs_size == 3) {
if (inputs[0]->data_ == NULL) {
return NNACL_INFER_INVALID;
}
if (inputs[1]->data_ == NULL || inputs[2]->data_ == NULL) { if (inputs[1]->data_ == NULL || inputs[2]->data_ == NULL) {
return NNACL_ERR; return NNACL_INFER_INVALID;
} }
int ret = SetOutputShapeFromInput(inputs, outputs); int ret = SetOutputShapeFromInput(inputs, outputs);
return ret; return ret;

@ -121,11 +121,8 @@ int SpaceToBatchNdInferShape(const TensorC *const *inputs, size_t inputs_size, T
} }
} }
if (inputs_size == 3) { if (inputs_size == 3) {
if (inputs[0]->data_ == NULL) {
return NNACL_INFER_INVALID;
}
if (inputs[1]->data_ == NULL || inputs[2]->data_ == NULL) { if (inputs[1]->data_ == NULL || inputs[2]->data_ == NULL) {
return NNACL_ERR; return NNACL_INFER_INVALID;
} }
int ret = SpaceSetOutputShapeFromInput(inputs, inputs_size, outputs, outputs_size, parameter); int ret = SpaceSetOutputShapeFromInput(inputs, inputs_size, outputs, outputs_size, parameter);
if (ret != NNACL_OK) { if (ret != NNACL_OK) {

@ -518,13 +518,30 @@ __kernel void BroadcastNHWC4SquaredDifference(__read_only image2d_t input_a, __r
int X = get_global_id(0); // C4 int X = get_global_id(0); // C4
int Y = get_global_id(1); // w int Y = get_global_id(1); // w
int Z = get_global_id(2); // H int Z = get_global_id(2); // H
if (X >= output_shape.w || Y >= output_shape.z || Z >= output_shape.y) {
if (X >= output_shape.w || Y >= output_shape.z || Z >= output_shape.y * output_shape.x) {
return; return;
} }
int H = Z % output_shape.y;
FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(Y * a_shape.w + X, Z)); int N = Z / output_shape.y;
FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(X, 0)); int a_c = X < a_shape.w ? X : 0;
FLT4 result = pown((a - b), (int4)2); int a_w = Y < a_shape.z ? Y : 0;
int a_h = H < a_shape.y ? H : 0;
int a_n = N < a_shape.x ? N : 0;
FLT4 a = READ_IMAGE(input_a, smp_none, (int2)(a_w * a_shape.w + a_c, a_n * a_shape.y + a_h));
int b_c = X < b_shape.w ? X : 0;
int b_w = Y < b_shape.z ? Y : 0;
int b_h = H < b_shape.y ? H : 0;
int b_n = N < b_shape.x ? N : 0;
FLT4 b = READ_IMAGE(input_b, smp_none, (int2)(b_w * b_shape.w + b_c, b_n * b_shape.y + b_h));
FLT4 result;
if (broadcastC_flag == 0) {
result = pown((a - b), (int4)2);
} else if (broadcastC_flag == 1) {
result = pown((a.x - b), (int4)2);
} else {
result = pown((a - b.x), (int4)2);
}
result = clamp(result, (FLT)(act_min), (FLT)(act_max)); result = clamp(result, (FLT)(act_min), (FLT)(act_max));
WRITE_IMAGE(output, (int2)(Y * output_shape.w + X, Z), result); WRITE_IMAGE(output, (int2)(Y * output_shape.w + X, Z), result);
} }

@ -82,8 +82,7 @@ int OpenCLKernel::GetImageSize(size_t idx, lite::opencl::ImageSize *img_size) {
img_dtype = CL_HALF_FLOAT; img_dtype = CL_HALF_FLOAT;
break; break;
} }
case kNumberTypeInt8: case kNumberTypeInt8: {
case kNumberTypeUInt8: {
img_dtype = CL_SIGNED_INT8; img_dtype = CL_SIGNED_INT8;
break; break;
} }

Loading…
Cancel
Save