|
|
@ -13,11 +13,12 @@
|
|
|
|
# limitations under the License.
|
|
|
|
# limitations under the License.
|
|
|
|
# ============================================================================
|
|
|
|
# ============================================================================
|
|
|
|
import numpy as np
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
|
|
|
import mindspore as ms
|
|
|
|
import mindspore as ms
|
|
|
|
from mindspore import context, Tensor, Parameter
|
|
|
|
from mindspore import context, Tensor, Parameter
|
|
|
|
from mindspore.nn import Cell, TrainOneStepCell, Momentum
|
|
|
|
|
|
|
|
from mindspore.ops import operations as P
|
|
|
|
|
|
|
|
from mindspore.common.api import _executor
|
|
|
|
from mindspore.common.api import _executor
|
|
|
|
|
|
|
|
from mindspore.nn import Cell
|
|
|
|
|
|
|
|
from mindspore.ops import operations as P
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Net(Cell):
|
|
|
|
class Net(Cell):
|
|
|
@ -54,15 +55,15 @@ def test_train_and_eval():
|
|
|
|
context.set_context(save_graphs=True, mode=0)
|
|
|
|
context.set_context(save_graphs=True, mode=0)
|
|
|
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16)
|
|
|
|
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16)
|
|
|
|
strategy1 = ((4, 4), (4, 4))
|
|
|
|
strategy1 = ((4, 4), (4, 4))
|
|
|
|
strategy2 = ((4, 4), )
|
|
|
|
strategy2 = ((4, 4),)
|
|
|
|
net = Net(_w1, strategy1, strategy2)
|
|
|
|
net = Net(_w1, strategy1, strategy2)
|
|
|
|
eval_net = EvalNet(net, strategy2=strategy2)
|
|
|
|
eval_net = EvalNet(net, strategy2=strategy2)
|
|
|
|
net.set_train()
|
|
|
|
net.set_train()
|
|
|
|
net.set_auto_parallel()
|
|
|
|
net.set_auto_parallel()
|
|
|
|
_executor.compile(net, _x, _b, phase='train', auto_parallel_mode=True)
|
|
|
|
_executor.compile(net, _x, _b, phase='train', auto_parallel_mode=True)
|
|
|
|
|
|
|
|
|
|
|
|
eval_net.set_train(mode=False)
|
|
|
|
eval_net.set_train(mode=False)
|
|
|
|
eval_net.set_auto_parallel()
|
|
|
|
eval_net.set_auto_parallel()
|
|
|
|
_executor.compile(eval_net, _x, _b, phase='eval', auto_parallel_mode=True)
|
|
|
|
_executor.compile(eval_net, _x, _b, phase='eval', auto_parallel_mode=True)
|
|
|
|
|
|
|
|
|
|
|
|
context.reset_auto_parallel_context()
|
|
|
|
context.reset_auto_parallel_context()
|