fix interpolate cu. test=develop (#17101)

feature/fluid_trt_int8
Kaipeng Deng 6 years ago committed by GitHub
parent aca60e9a20
commit 10c487eb21
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -286,7 +286,7 @@ class InterpolateGradOpCUDAKernel : public framework::OpKernel<T> {
float scale = ctx.Attr<float>("scale"); float scale = ctx.Attr<float>("scale");
if (scale > 0) { if (scale > 0) {
out_h = in_h * scale; out_h = in_h * scale;
out_w - in_w* scale; out_w = in_w * scale;
} }
auto out_size = ctx.Input<Tensor>("OutSize"); auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) { if (out_size != nullptr) {

@ -305,7 +305,7 @@ class TestBilinearInterpWithMethod3(TestBilinearInterpOp):
class TestBilinearInterpScale1(TestBilinearInterpOp): class TestBilinearInterpScale1(TestBilinearInterpOp):
def init_test_case(self): def init_test_case(self):
self.interp_method = 'bilinear' self.interp_method = 'bilinear'
self.input_shape = [2, 3, 16, 32] self.input_shape = [2, 3, 5, 7]
self.out_h = 60 self.out_h = 60
self.out_w = 25 self.out_w = 25
self.scale = 2. self.scale = 2.
@ -316,7 +316,7 @@ class TestBilinearInterpScale1(TestBilinearInterpOp):
class TestBilinearInterpScale2(TestBilinearInterpOp): class TestBilinearInterpScale2(TestBilinearInterpOp):
def init_test_case(self): def init_test_case(self):
self.interp_method = 'bilinear' self.interp_method = 'bilinear'
self.input_shape = [2, 3, 16, 32] self.input_shape = [2, 3, 5, 7]
self.out_h = 60 self.out_h = 60
self.out_w = 25 self.out_w = 25
self.scale = 1. self.scale = 1.
@ -327,7 +327,7 @@ class TestBilinearInterpScale2(TestBilinearInterpOp):
class TestBilinearInterpScale3(TestBilinearInterpOp): class TestBilinearInterpScale3(TestBilinearInterpOp):
def init_test_case(self): def init_test_case(self):
self.interp_method = 'bilinear' self.interp_method = 'bilinear'
self.input_shape = [2, 3, 16, 32] self.input_shape = [2, 3, 5, 7]
self.out_h = 60 self.out_h = 60
self.out_w = 25 self.out_w = 25
self.scale = 1.5 self.scale = 1.5

@ -259,7 +259,7 @@ class TestNearestInterpWithoutCorners(TestNearestInterpOp):
class TestNearestNeighborInterpScale1(TestNearestInterpOp): class TestNearestNeighborInterpScale1(TestNearestInterpOp):
def init_test_case(self): def init_test_case(self):
self.interp_method = 'nearest' self.interp_method = 'nearest'
self.input_shape = [3, 2, 32, 16] self.input_shape = [3, 2, 7, 5]
self.out_h = 64 self.out_h = 64
self.out_w = 32 self.out_w = 32
self.scale = 2. self.scale = 2.
@ -270,7 +270,7 @@ class TestNearestNeighborInterpScale1(TestNearestInterpOp):
class TestNearestNeighborInterpScale2(TestNearestInterpOp): class TestNearestNeighborInterpScale2(TestNearestInterpOp):
def init_test_case(self): def init_test_case(self):
self.interp_method = 'nearest' self.interp_method = 'nearest'
self.input_shape = [3, 2, 32, 16] self.input_shape = [3, 2, 5, 7]
self.out_h = 64 self.out_h = 64
self.out_w = 32 self.out_w = 32
self.scale = 1.5 self.scale = 1.5
@ -281,7 +281,7 @@ class TestNearestNeighborInterpScale2(TestNearestInterpOp):
class TestNearestNeighborInterpScale3(TestNearestInterpOp): class TestNearestNeighborInterpScale3(TestNearestInterpOp):
def init_test_case(self): def init_test_case(self):
self.interp_method = 'nearest' self.interp_method = 'nearest'
self.input_shape = [3, 2, 32, 16] self.input_shape = [3, 2, 7, 5]
self.out_h = 64 self.out_h = 64
self.out_w = 32 self.out_w = 32
self.scale = 1. self.scale = 1.

Loading…
Cancel
Save