|
|
|
@ -146,9 +146,20 @@ class SequentialCell(Cell):
|
|
|
|
|
cell.set_grad(flag)
|
|
|
|
|
|
|
|
|
|
def append(self, cell):
|
|
|
|
|
"""Appends a given cell to the end of the list."""
|
|
|
|
|
"""Appends a given cell to the end of the list.
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
>>> conv = nn.Conv2d(3, 2, 3, pad_mode='valid')
|
|
|
|
|
>>> bn = nn.BatchNorm2d(2)
|
|
|
|
|
>>> relu = nn.ReLU()
|
|
|
|
|
>>> seq = nn.SequentialCell([conv, bn])
|
|
|
|
|
>>> seq.append(relu)
|
|
|
|
|
>>> x = Tensor(np.random.random((1, 3, 4, 4)), dtype=mindspore.float32)
|
|
|
|
|
>>> seq(x)
|
|
|
|
|
"""
|
|
|
|
|
if _valid_cell(cell):
|
|
|
|
|
self._cells[str(len(self))] = cell
|
|
|
|
|
self.cell_list = list(self._cells.values())
|
|
|
|
|
return self
|
|
|
|
|
|
|
|
|
|
def construct(self, input_data):
|
|
|
|
|