make reverse op support negative axis (#21925)

* make reverse op support negative axis
release/1.7
mapingshuo 6 years ago committed by GitHub
parent 03479469a7
commit c3e1954918
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -31,7 +31,13 @@ class ReverseOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(!axis.empty(), "'axis' can not be empty."); PADDLE_ENFORCE(!axis.empty(), "'axis' can not be empty.");
for (int a : axis) { for (int a : axis) {
PADDLE_ENFORCE_LT(a, x_dims.size(), PADDLE_ENFORCE_LT(a, x_dims.size(),
"The axis must be less than input tensor's rank."); paddle::platform::errors::OutOfRange(
"The axis must be less than input tensor's rank."));
PADDLE_ENFORCE_GE(
a, -x_dims.size(),
paddle::platform::errors::OutOfRange(
"The axis must be greater than the negative number of "
"input tensor's rank."));
} }
ctx->SetOutputDim("Out", x_dims); ctx->SetOutputDim("Out", x_dims);
} }

@ -28,7 +28,11 @@ struct ReverseFunctor {
reverse_axis[i] = false; reverse_axis[i] = false;
} }
for (int a : axis) { for (int a : axis) {
reverse_axis[a] = true; if (a >= 0) {
reverse_axis[a] = true;
} else {
reverse_axis[Rank + a] = true;
}
} }
auto in_eigen = framework::EigenTensor<T, Rank>::From(in); auto in_eigen = framework::EigenTensor<T, Rank>::From(in);

@ -47,23 +47,47 @@ class TestCase0(TestReverseOp):
self.axis = [1] self.axis = [1]
class TestCase0(TestReverseOp):
def initTestCase(self):
self.x = np.random.random((3, 40)).astype('float64')
self.axis = [-1]
class TestCase1(TestReverseOp): class TestCase1(TestReverseOp):
def initTestCase(self): def initTestCase(self):
self.x = np.random.random((3, 40)).astype('float64') self.x = np.random.random((3, 40)).astype('float64')
self.axis = [0, 1] self.axis = [0, 1]
class TestCase0(TestReverseOp):
def initTestCase(self):
self.x = np.random.random((3, 40)).astype('float64')
self.axis = [0, -1]
class TestCase2(TestReverseOp): class TestCase2(TestReverseOp):
def initTestCase(self): def initTestCase(self):
self.x = np.random.random((3, 4, 10)).astype('float64') self.x = np.random.random((3, 4, 10)).astype('float64')
self.axis = [0, 2] self.axis = [0, 2]
class TestCase2(TestReverseOp):
def initTestCase(self):
self.x = np.random.random((3, 4, 10)).astype('float64')
self.axis = [0, -2]
class TestCase3(TestReverseOp): class TestCase3(TestReverseOp):
def initTestCase(self): def initTestCase(self):
self.x = np.random.random((3, 4, 10)).astype('float64') self.x = np.random.random((3, 4, 10)).astype('float64')
self.axis = [1, 2] self.axis = [1, 2]
class TestCase3(TestReverseOp):
def initTestCase(self):
self.x = np.random.random((3, 4, 10)).astype('float64')
self.axis = [-1, -2]
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

Loading…
Cancel
Save