|
|
|
@ -21,6 +21,7 @@ import mindspore as ms
|
|
|
|
|
from mindspore.common.api import _executor
|
|
|
|
|
from mindspore.ops import composite as C
|
|
|
|
|
from mindspore.ops.operations.comm_ops import _VirtualDataset
|
|
|
|
|
from mindspore.nn.wrap.cell_wrapper import VirtualDatasetCellTriple
|
|
|
|
|
from mindspore import context
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -73,6 +74,29 @@ def test_virtual_dataset_3_input():
|
|
|
|
|
net.set_auto_parallel()
|
|
|
|
|
_executor.compile(net, x, y, b)
|
|
|
|
|
|
|
|
|
|
def test_virtualdataset_cell_3_inputs():
|
|
|
|
|
class Net(nn.Cell):
|
|
|
|
|
def __init__(self, strategy0, strategy1, strategy2, strategy3):
|
|
|
|
|
super().__init__()
|
|
|
|
|
self.matmul1 = P.MatMul().set_strategy(strategy1)
|
|
|
|
|
self.matmul2 = P.MatMul().set_strategy(strategy2)
|
|
|
|
|
self.gelu = P.Gelu().set_strategy(strategy3)
|
|
|
|
|
|
|
|
|
|
def construct(self, x, y, b):
|
|
|
|
|
out = self.gelu(self.matmul1(x, y))
|
|
|
|
|
out = self.matmul2(out, b)
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
net = GradWrap(VirtualDatasetCellTriple(NetWithLoss(Net(None, None, None, None))))
|
|
|
|
|
context.set_context(save_graphs=True)
|
|
|
|
|
context.set_auto_parallel_context(parallel_mode="auto_parallel")
|
|
|
|
|
context.set_auto_parallel_context(device_num=8, global_rank=0)
|
|
|
|
|
x = Tensor(np.ones([128, 32]), dtype=ms.float32)
|
|
|
|
|
y = Tensor(np.ones([32, 64]), dtype=ms.float32)
|
|
|
|
|
b = Tensor(np.ones([64, 2048]), dtype=ms.float32)
|
|
|
|
|
net.set_auto_parallel()
|
|
|
|
|
_executor.compile(net, x, y, b)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
test_virtual_dataset_3_input()
|
|
|
|
|