parent
c14f3e8ff2
commit
0ab678e9e2
@ -0,0 +1,23 @@
|
||||
from paddle.v2.framework.network import Network
|
||||
import paddle.v2.framework.core as core
|
||||
import unittest
|
||||
|
||||
|
||||
class TestNet(unittest.TestCase):
|
||||
def test_net_all(self):
|
||||
net = Network()
|
||||
out = net.add_two(X="X", Y="Y")
|
||||
fc_out = net.fc(X=out, W="w")
|
||||
net.complete_add_op()
|
||||
self.assertTrue(isinstance(fc_out, core.Variable))
|
||||
self.assertEqual(
|
||||
'''Op(naive_net), inputs:(@EMPTY@, X, Y, w), outputs:(@TEMP@fc@0, add_two@OUT@0, fc@OUT@1).
|
||||
Op(add_two), inputs:(X, Y), outputs:(add_two@OUT@0).
|
||||
Op(fc), inputs:(add_two@OUT@0, w, @EMPTY@), outputs:(fc@OUT@1, @TEMP@fc@0).
|
||||
Op(mul), inputs:(add_two@OUT@0, w), outputs:(@TEMP@fc@0).
|
||||
Op(sigmoid), inputs:(@TEMP@fc@0), outputs:(fc@OUT@1).
|
||||
''', str(net))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue