|
|
|
@ -72,11 +72,21 @@ class Network(object):
|
|
|
|
|
self.__complete_add_op__ = False
|
|
|
|
|
|
|
|
|
|
def infer_shape(self):
|
|
|
|
|
self.complete_add_op()
|
|
|
|
|
self.net.infer_shape(get_cur_scope())
|
|
|
|
|
|
|
|
|
|
def run(self, device_context):
|
|
|
|
|
self.complete_add_op()
|
|
|
|
|
self.net.run(get_cur_scope(), device_context)
|
|
|
|
|
|
|
|
|
|
def __str__(self):
|
|
|
|
|
return str(self.net)
|
|
|
|
|
|
|
|
|
|
def complete_add_op(self):
|
|
|
|
|
if not self.__complete_add_op__:
|
|
|
|
|
self.net.complete_add_op()
|
|
|
|
|
self.__complete_add_op__ = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
net = Network()
|
|
|
|
|