|
|
|
@ -65,32 +65,11 @@ class IthOutputCell(nn.Cell):
|
|
|
|
|
self.output_index = output_index
|
|
|
|
|
|
|
|
|
|
def construct(self, *inputs):
|
|
|
|
|
raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
def construct1(self, x1):
|
|
|
|
|
predict = self.network(x1)[self.output_index]
|
|
|
|
|
return predict
|
|
|
|
|
|
|
|
|
|
def construct2(self, x1, x2):
|
|
|
|
|
predict = self.network(x1, x2)[self.output_index]
|
|
|
|
|
return predict
|
|
|
|
|
|
|
|
|
|
def construct3(self, x1, x2, x3):
|
|
|
|
|
predict = self.network(x1, x2, x3)[self.output_index]
|
|
|
|
|
return predict
|
|
|
|
|
|
|
|
|
|
def construct4(self, x1, x2, x3, x4):
|
|
|
|
|
predict = self.network(x1, x2, x3, x4)[self.output_index]
|
|
|
|
|
return predict
|
|
|
|
|
|
|
|
|
|
def construct5(self, x1, x2, x3, x4, x5):
|
|
|
|
|
predict = self.network(x1, x2, x3, x4, x5)[self.output_index]
|
|
|
|
|
predict = self.network(*inputs)[self.output_index]
|
|
|
|
|
return predict
|
|
|
|
|
|
|
|
|
|
def get_output_cell(network, num_input, output_index, training=True):
|
|
|
|
|
net = IthOutputCell(network, output_index)
|
|
|
|
|
f = getattr(net, 'construct%d' % num_input)
|
|
|
|
|
setattr(net, "construct", f)
|
|
|
|
|
set_block_training(net, training)
|
|
|
|
|
return net
|
|
|
|
|
|
|
|
|
|