|
|
@ -37,7 +37,7 @@ class TestScatterOp(OpTest):
|
|
|
|
self.check_output()
|
|
|
|
self.check_output()
|
|
|
|
|
|
|
|
|
|
|
|
def test_check_grad(self):
|
|
|
|
def test_check_grad(self):
|
|
|
|
self.check_grad(['Updates'], 'Out', in_place=True)
|
|
|
|
self.check_grad(["X", "Updates"], "Out")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestScatterOp0(OpTest):
|
|
|
|
class TestScatterOp0(OpTest):
|
|
|
@ -56,7 +56,7 @@ class TestScatterOp0(OpTest):
|
|
|
|
self.check_output()
|
|
|
|
self.check_output()
|
|
|
|
|
|
|
|
|
|
|
|
def test_check_grad(self):
|
|
|
|
def test_check_grad(self):
|
|
|
|
self.check_grad(['Updates'], 'Out', in_place=True)
|
|
|
|
self.check_grad(["X", "Updates"], "Out")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestScatterOp1(OpTest):
|
|
|
|
class TestScatterOp1(OpTest):
|
|
|
@ -78,7 +78,7 @@ class TestScatterOp1(OpTest):
|
|
|
|
self.check_output()
|
|
|
|
self.check_output()
|
|
|
|
|
|
|
|
|
|
|
|
def test_check_grad(self):
|
|
|
|
def test_check_grad(self):
|
|
|
|
self.check_grad(['Updates'], 'Out', in_place=True)
|
|
|
|
self.check_grad(["X", "Updates"], "Out")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@unittest.skipIf(not core.is_compiled_with_cuda(),
|
|
|
|
@unittest.skipIf(not core.is_compiled_with_cuda(),
|
|
|
@ -102,7 +102,7 @@ class TestScatterOp2(OpTest):
|
|
|
|
def test_check_grad(self):
|
|
|
|
def test_check_grad(self):
|
|
|
|
if core.is_compiled_with_cuda():
|
|
|
|
if core.is_compiled_with_cuda():
|
|
|
|
place = core.CUDAPlace(0)
|
|
|
|
place = core.CUDAPlace(0)
|
|
|
|
self.check_grad_with_place(place, ['Updates'], 'Out', in_place=True)
|
|
|
|
self.check_grad_with_place(place, ['X', 'Updates'], 'Out')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@unittest.skipIf(not core.is_compiled_with_cuda(),
|
|
|
|
@unittest.skipIf(not core.is_compiled_with_cuda(),
|
|
|
@ -130,7 +130,7 @@ class TestScatterOp3(OpTest):
|
|
|
|
def test_check_grad(self):
|
|
|
|
def test_check_grad(self):
|
|
|
|
if core.is_compiled_with_cuda():
|
|
|
|
if core.is_compiled_with_cuda():
|
|
|
|
place = core.CUDAPlace(0)
|
|
|
|
place = core.CUDAPlace(0)
|
|
|
|
self.check_grad_with_place(place, ['Updates'], 'Out', in_place=True)
|
|
|
|
self.check_grad_with_place(place, ['X', 'Updates'], 'Out')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestScatterOp4(OpTest):
|
|
|
|
class TestScatterOp4(OpTest):
|
|
|
@ -148,7 +148,7 @@ class TestScatterOp4(OpTest):
|
|
|
|
self.check_output()
|
|
|
|
self.check_output()
|
|
|
|
|
|
|
|
|
|
|
|
def test_check_grad(self):
|
|
|
|
def test_check_grad(self):
|
|
|
|
self.check_grad(['Updates'], 'Out', in_place=True)
|
|
|
|
self.check_grad(['X', 'Updates'], 'Out')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@unittest.skipIf(not core.is_compiled_with_cuda(),
|
|
|
|
@unittest.skipIf(not core.is_compiled_with_cuda(),
|
|
|
@ -172,7 +172,7 @@ class TestScatterOp5(OpTest):
|
|
|
|
def test_check_grad(self):
|
|
|
|
def test_check_grad(self):
|
|
|
|
if core.is_compiled_with_cuda():
|
|
|
|
if core.is_compiled_with_cuda():
|
|
|
|
place = core.CUDAPlace(0)
|
|
|
|
place = core.CUDAPlace(0)
|
|
|
|
self.check_grad_with_place(place, ['Updates'], 'Out', in_place=True)
|
|
|
|
self.check_grad_with_place(place, ['X', 'Updates'], 'Out')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestScatterAPI(unittest.TestCase):
|
|
|
|
class TestScatterAPI(unittest.TestCase):
|
|
|
|