|
|
@ -159,7 +159,7 @@ class ParallelOpTest(BaseParallelForTest):
|
|
|
|
|
|
|
|
|
|
|
|
def test_simple_fc(self):
|
|
|
|
def test_simple_fc(self):
|
|
|
|
self.run_test(
|
|
|
|
self.run_test(
|
|
|
|
callback=ParallelOpTest.__network__,
|
|
|
|
callback=self.__network__,
|
|
|
|
feed={
|
|
|
|
feed={
|
|
|
|
'img': numpy.random.random(size=(51, 784)).astype('float32')
|
|
|
|
'img': numpy.random.random(size=(51, 784)).astype('float32')
|
|
|
|
},
|
|
|
|
},
|
|
|
@ -167,10 +167,31 @@ class ParallelOpTest(BaseParallelForTest):
|
|
|
|
|
|
|
|
|
|
|
|
def test_fc_with_tiny_data(self):
|
|
|
|
def test_fc_with_tiny_data(self):
|
|
|
|
self.run_test(
|
|
|
|
self.run_test(
|
|
|
|
callback=ParallelOpTest.__network__,
|
|
|
|
callback=self.__network__,
|
|
|
|
feed={'img': numpy.random.random(size=(1, 784)).astype('float32')},
|
|
|
|
feed={'img': numpy.random.random(size=(1, 784)).astype('float32')},
|
|
|
|
fetch=['fc1.w@GRAD'])
|
|
|
|
fetch=['fc1.w@GRAD'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ParallelOpTestMultipleInput(BaseParallelForTest):
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
|
|
|
def __network__():
|
|
|
|
|
|
|
|
x = fluid.layers.data(shape=[784], dtype='float32', name='img1', stop_gradient=False)
|
|
|
|
|
|
|
|
y = fluid.layers.data(shape=[784], dtype='float32', name='img2', stop_gradient=False)
|
|
|
|
|
|
|
|
yield [x, y]
|
|
|
|
|
|
|
|
x = x + y
|
|
|
|
|
|
|
|
hidden = fluid.layers.fc(input=x, size=200, param_attr='fc1.w')
|
|
|
|
|
|
|
|
loss = fluid.layers.mean(x=hidden)
|
|
|
|
|
|
|
|
yield loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_simple_fc(self):
|
|
|
|
|
|
|
|
self.run_test(
|
|
|
|
|
|
|
|
callback=self.__network__,
|
|
|
|
|
|
|
|
feed={
|
|
|
|
|
|
|
|
'img1': numpy.random.random(size=(51, 784)).astype('float32'),
|
|
|
|
|
|
|
|
'img2': numpy.random.random(size=(51, 784)).astype('float32')
|
|
|
|
|
|
|
|
},
|
|
|
|
|
|
|
|
fetch=['fc1.w@GRAD'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
if __name__ == '__main__':
|
|
|
|
unittest.main()
|
|
|
|
unittest.main()
|
|
|
|