pipeline_split adapt parallel

pull/9291/head
lichenever 4 years ago
parent cffe2c94fe
commit 78e131cf15

@ -98,7 +98,7 @@ class DeviceManager {
std::map<std::string, std::string> group_to_rank_; // the key is hash name, value is rank list
int64_t global_rank_ = 0; // the real rank in all devices
int64_t stage_num_ = 0; // the stage num
int64_t stage_num_ = 1; // the stage num
int64_t stage_id_ = 0; // the stage id of the global_rank_
int64_t rank_index_in_stage_ = 0; // the index of this rank in its stage
int64_t stage_device_num_ = 0; // the device num of one stage
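For reference, a minimal Python sketch of the rank-to-stage bookkeeping these fields imply, assuming ranks are split into contiguous, equal-sized stages (the function and variable names here are illustrative, not the DeviceManager API):

# Illustrative only: how global_rank_ maps to stage_id_ and rank_index_in_stage_
# under a contiguous rank-to-stage partition.
def stage_bookkeeping(global_rank, device_num, stage_num):
    stage_device_num = device_num // stage_num             # devices per stage
    stage_id = global_rank // stage_device_num             # which stage this rank is in
    rank_index_in_stage = global_rank % stage_device_num   # position inside its stage
    return stage_id, rank_index_in_stage, stage_device_num

# Example: 8 devices, 2 stages, rank 5 -> stage 1, index 1 within the stage.
assert stage_bookkeeping(5, 8, 2) == (1, 1, 4)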

@ -75,7 +75,8 @@ const std::set<std::string> BLACK_LIST = {TUPLE_GETITEM,
EMBED,
CREATINSTANCE,
REF_TO_EMBED,
STOP_GRADIENT};
STOP_GRADIENT,
SEND};
const std::set<std::string> BATCH_PARALLEL_BLACK_LIST = {PACK, TENSOR_SCATTER_UPDATE, MIN_MAX_UPDATE_PER_LAYER};

@ -182,6 +182,8 @@ constexpr char SIGMOID_CROSS_ENTROPY_WITH_LOGITS[] = "SigmoidCrossEntropyWithLog
constexpr char MATMUL[] = "MatMul";
constexpr char GELU[] = "Gelu";
constexpr char TANH[] = "Tanh";
constexpr char RECEIVE[] = "Receive";
constexpr char SEND[] = "Send";
constexpr char SHAPE_OP[] = "Shape";
constexpr char SOFTMAX[] = "Softmax";
constexpr char LOG_SOFTMAX[] = "LogSoftmax";

@ -19,13 +19,18 @@
#include <utility>
#include <string>
#include <memory>
#include "ir/value.h"
#include "ir/graph_utils.h"
#include "base/base.h"
#include "frontend/parallel/step_parallel.h"
#include "frontend/parallel/graph_util/generate_graph.h"
namespace mindspore {
namespace parallel {
using TensorLayoutPtr = std::shared_ptr<TensorLayout>;
using TensorInfoPtr = std::shared_ptr<TensorInfo>;
typedef struct {
ValueListPtr shape;
TypePtr type;
@ -59,8 +64,10 @@ class PipelineTransformer {
void InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &use_node, int index,
int user_node_stage, int node_stage);
void CutBorder(const FuncGraphPtr &graph);
void ElimRootParameter();
bool IsStageNode(const CNodePtr &node);
std::pair<OperatorInfoPtr, TensorInfoPtr> GetOpInfo(const AnfNodePtr &node);
OperatorInfoPtr CreateOpInfo(const CNodePtr &cnode);
bool IsPipelineCareNode(const CNodePtr &cnode);
std::pair<CNodePtr, FuncGraphPtr> FindSensNode();
FuncGraphManagerPtr manager_;
int64_t stage_;

@ -1752,7 +1752,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_traini
SetVirtualDatasetStrategy(cnode);
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
if (prim->name() == MAKE_TUPLE || prim->name() == MAKE_LIST) {
if (prim->name() == MAKE_TUPLE || prim->name() == MAKE_LIST || prim->name() == RECEIVE) {
continue;
}
auto attrs = prim->attrs();
@ -2420,6 +2420,13 @@ std::vector<std::pair<CNodePtr, LossNodeInfo>> GetSensLossPairs(const FuncGraphP
return sens_loss_pairs;
}
bool IsLastStage() {
MS_EXCEPTION_IF_NULL(g_device_manager);
auto stage_num = g_device_manager->stage_num();
auto stage_id = g_device_manager->stage_id();
return ((stage_num - 1) == stage_id);
}
void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
const FuncGraphManagerPtr &manager) {
MS_EXCEPTION_IF_NULL(root);
@ -2432,7 +2439,9 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
for (auto &pair : sens_loss_pairs) {
// If the shape of the grad-sens tensor is not [] or [1], use get tensor slice to handle it.
// If the type of the sens node is not Tensor, it is not supported yet; do nothing by default.
StepSplitSens(pair);
if (IsLastStage()) {
StepSplitSens(pair);
}
}
for (auto &node : all_nodes) {
@ -2448,13 +2457,15 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
MS_EXCEPTION_IF_NULL(distribute_operator);
// insert forward ops
InsertForwardOps(distribute_operator, cnode);
if (!IsSomePrimitive(cnode, RECEIVE)) {
InsertForwardOps(distribute_operator, cnode);
}
// insert redistribution ops
StepRedistribution(cnode, distribute_operator, cnode, tensor_redistribution, cnode);
// insert backward ops
if (has_backward) {
if (has_backward && !IsSomePrimitive(cnode, RECEIVE)) {
BackwardCommunication(distribute_operator, cnode, sens_loss_pairs);
}
@ -2468,7 +2479,7 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
MS_EXCEPTION_IF_NULL(node);
if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>();
if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>()) {
if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>() || IsSomePrimitive(cnode, RECEIVE)) {
continue;
}
@ -2895,7 +2906,7 @@ ParameterUsersInfo FindParameterNodeUsers(const AnfNodePtr &node, bool (*IsCareN
for (auto &candidate : candidate_set) {
auto candidate_node = candidate.first;
auto c = candidate_node->cast<CNodePtr>();
if (c == nullptr || !c->has_user_data<OperatorInfo>()) {
if (c == nullptr || !c->has_user_data<OperatorInfo>() || IsSomePrimitive(c, RECEIVE)) {
continue;
}
(void)parameter_user_info.second.second.insert(candidate);
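The step_parallel.cc hunks above all follow the same pattern: Receive nodes are excluded from strategy extraction, forward/backward communication insertion, and parameter-user bookkeeping, and the sens split only runs on the last pipeline stage. A hedged sketch of that gating in Python (the helper is a stand-in for the C++ IsLastStage, not an existing API):

def is_last_stage(stage_num, stage_id):
    """Mirrors IsLastStage() above: true only on the final pipeline stage."""
    return stage_id == stage_num - 1

# How the hunks above use this predicate (summary of the call sites, not the C++ itself):
#  * StepSplitSens runs only when is_last_stage(...) is true, because the
#    grad-sens tensor exists only on the stage that computes the loss.
#  * InsertForwardOps and BackwardCommunication are skipped for Receive nodes,
#    while StepRedistribution still applies to them.
#  * Receive nodes are likewise skipped in ExtractInformation and in
#    FindParameterNodeUsers, since they carry no operator strategy of their own.

# Example: with 2 stages, only stage 1 splits the sens.
assert is_last_stage(stage_num=2, stage_id=1)
assert not is_last_stage(stage_num=2, stage_id=0)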

@ -131,6 +131,10 @@ std::shared_ptr<TensorLayout> FindPrevLayout(const AnfNodePtr &node);
void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes);
StrategyPtr GenerateBatchParallelStrategy(const OperatorInfoPtr operator_, const PrimitivePtr prim);
bool IsLastStage();
// Add node for whole graph
void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
const FuncGraphManagerPtr &manager);

@ -21,6 +21,7 @@
#include "utils/comm_manager.h"
#include "frontend/parallel/context.h"
#include "frontend/parallel/pipeline_transformer/pipeline_transformer.h"
#include "frontend/parallel/step_parallel.h"
namespace mindspore {
namespace pipeline {
@ -59,7 +60,7 @@ static int64_t InferStage(int64_t rank_id, int64_t stage_num, int64_t device_num
// Only auto_parallel and semi_auto_parallel support PipelineSplit
bool PipelineSplit(const ResourcePtr &res) {
auto parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
if (parallel_mode != parallel::SEMI_AUTO_PARALLEL || parallel_mode != parallel::AUTO_PARALLEL) {
if (parallel_mode != parallel::SEMI_AUTO_PARALLEL && parallel_mode != parallel::AUTO_PARALLEL) {
MS_LOG(INFO) << "Only auto_parallel and semi_auto_parallel support pipeline split.";
return true;
}
@ -80,6 +81,9 @@ bool PipelineSplit(const ResourcePtr &res) {
}
auto stage = InferStage(global_rank, stage_num, device_num);
auto per_stage_rank_num = device_num / stage_num;
if (parallel::ParallelInit() != parallel::SUCCESS) {
MS_LOG(EXCEPTION) << "parallel init failed.";
}
auto transformer =
std::make_shared<parallel::PipelineTransformer>(manager, stage, root, global_rank, per_stage_rank_num);
// step1: Do color graph
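The change from || to && above is the substantive fix in this hunk: with ||, the condition holds for every mode (a value can never equal both constants at once), so PipelineSplit bailed out early even in the supported modes. A quick illustration of the two forms in Python:

SEMI_AUTO_PARALLEL, AUTO_PARALLEL = "semi_auto_parallel", "auto_parallel"
mode = "semi_auto_parallel"

# Buggy form: always True, so pipeline split was skipped even for supported modes.
assert (mode != SEMI_AUTO_PARALLEL or mode != AUTO_PARALLEL) is True

# Fixed form: True only when the mode is neither supported value.
assert (mode != SEMI_AUTO_PARALLEL and mode != AUTO_PARALLEL) is False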

@ -20,9 +20,10 @@ from .. import operations as P
from ...common.tensor import RowTensor
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from ..operations.comm_ops import (AllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast,
_GetTensorSlice, _MirrorOperator, ReduceOp, Send, Receive,
_GetTensorSlice, _MirrorOperator, ReduceOp,
ReduceScatter, _HostReduceScatter, _VirtualDiv, AllSwap)
from .grad_base import bprop_getters
from ..operations._inner_ops import Send, Receive
@bprop_getters.register(AllReduce)

@ -36,7 +36,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
Unique, GatherD, Identity, SequenceMask)
from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast,
_MirrorOperator, ReduceOp, _VirtualDataset,
_VirtualDiv, _GetTensorSlice, Send, Receive,
_VirtualDiv, _GetTensorSlice,
_HostAllGather, _HostReduceScatter)
from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
TensorSummary, HistogramSummary, Print, Assert)

@ -21,6 +21,7 @@ from ... import context
from ...common import dtype as mstype
from ..primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register
from ..operations.math_ops import _infer_shape_reduce
from ...communication.management import get_rank, GlobalComm, _get_group
class ExtractImagePatches(PrimitiveWithInfer):
@ -371,6 +372,116 @@ class MatrixDiagPart(PrimitiveWithInfer):
return out_shape
class Send(PrimitiveWithInfer):
"""
Send tensors to the specified dest_rank.
Note:
Send and Receive must be used in combination and have the same sr_tag.
Send must be used between servers.
Args:
sr_tag (int): A required integer identifying the send/recv message tag. The message
will be received by the Receive op with the same "sr_tag".
dest_rank (int): A required integer identifying the destination rank.
group (str): The communication group to work on. Default: "hccl_world_group/nccl_world_group".
Inputs:
- **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
Examples:
>>> import mindspore.ops.operations as ops
>>> import mindspore.nn as nn
>>> from mindspore.communication import init
>>> from mindspore import Tensor
>>> import numpy as np
>>>
>>> init()
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self.depend = ops.Depend()
>>> self.send = ops.Send(sr_tag=0, dest_rank=8, group="hccl_world_group")
>>>
>>> def construct(self, x):
>>> out = self.depend(x, self.send(x))
>>> return out
>>>
>>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
>>> net = Net()
>>> output = net(input_)
"""
@prim_attr_register
def __init__(self, sr_tag, dest_rank, group=GlobalComm.WORLD_COMM_GROUP):
self.rank = get_rank(_get_group(group))
self.sr_tag = sr_tag
self.group = group
def infer_shape(self, x_shape):
self.add_prim_attr("shape", x_shape)
return x_shape
def infer_dtype(self, x_dtype):
self.add_prim_attr("dtype", x_dtype)
return x_dtype
class Receive(PrimitiveWithInfer):
"""
Receive tensors from the specified src_rank.
Note:
Send and Receive must be used in combination and have the same sr_tag.
Receive must be used between servers.
Args:
sr_tag (int): A required integer identifying the send/recv message tag. The message
will be sent by the Send op with the same "sr_tag".
src_rank (int): A required integer identifying the source rank.
shape (list[int]): A required list identifying the shape of the tensor to be received.
dtype (Type): A required Type identifying the type of the tensor to be received. The supported types:
int8, int16, int32, float16, float32.
group (str): The communication group to work on. Default: "hccl_world_group/nccl_world_group".
Outputs:
- **output** (Tensor) - The tensor received from src_rank, with the shape given by `shape` and the data type given by `dtype`.
Examples:
>>> import mindspore.ops.operations as ops
>>> import mindspore.nn as nn
>>> from mindspore.communication import init
>>> from mindspore import Tensor
>>> import numpy as np
>>>
>>> init()
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self.recv = ops.Receive(sr_tag=0, src_rank=0, shape=[2, 8], dtype=np.float32,
>>> group="hccl_world_group")
>>>
>>> def construct(self):
>>> out = self.recv()
>>> return out
>>>
>>> net = Net()
>>> output = net()
"""
@prim_attr_register
def __init__(self, sr_tag, src_rank, shape, dtype, group=GlobalComm.WORLD_COMM_GROUP):
self.rank = get_rank(_get_group(group))
self.tag = sr_tag
self.shape = shape
self.dtype = dtype
self.group = group
def infer_shape(self, x_shape=None):
return self.shape
def infer_dtype(self, x_dtype=None):
return self.dtype
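Since the two docstring examples above show each side in isolation, here is a minimal two-rank pairing sketch. It is an illustration only, not part of the change: it assumes exactly 2 devices in the world group, rank 0 sending to rank 1, the primitives imported from operations._inner_ops as this commit places them, and the SendRecvNet name is hypothetical.

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.communication.management import init, get_rank
from mindspore.ops import operations as P
from mindspore.ops.operations._inner_ops import Send, Receive

init()
rank = get_rank()

class SendRecvNet(nn.Cell):
    def __init__(self):
        super(SendRecvNet, self).__init__()
        self.depend = P.Depend()
        if rank == 0:   # sender side
            self.send = Send(sr_tag=0, dest_rank=1)
        else:           # receiver side: shape/dtype must match the sent tensor
            self.recv = Receive(sr_tag=0, src_rank=0, shape=[2, 8], dtype=mstype.float32)

    def construct(self, x):
        if rank == 0:
            return self.depend(x, self.send(x))
        return self.recv()

net = SendRecvNet()
out = net(Tensor(np.ones([2, 8]).astype(np.float32)))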
class MatrixSetDiag(PrimitiveWithInfer):
r"""
Modifies the batched diagonal part of a batched tensor.

@ -116,117 +116,6 @@ class AllReduce(PrimitiveWithInfer):
return x_dtype
class Send(PrimitiveWithInfer):
"""
Send tensors from src_rank to the specified dest_rank.
Note:
Send and Recveive must be used in combination and have same sr_tag.
Send must be used between servers.
Args:
sr_tag (int): A required integer identifying the send/recv message tag. The message will
will be received by the Receive op with the same "sr_tag".
dest_rank (int): A required integer identifying the destination rank.
group (str): The communication group to work on. Default: "hccl_world_group/nccl_world_group".
Inputs:
- **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
Examples:
>>> import mindspore.ops.operations as ops
>>> import mindspore.nn as nn
>>> from mindspore.communication import init
>>> from mindspore import Tensor
>>> import numpy as np
>>>
>>> init()
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self.depend = ops.Depend()
>>> self.send = ops.Send(st_tag=0, dest_rank=8, group="hccl_world_group")
>>>
>>> def construct(self, x):
>>> out = self.depend(x, self.send(x))
>>> return out
>>>
>>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
>>> net = Net()
>>> output = net(input_)
"""
@prim_attr_register
def __init__(self, sr_tag, dest_rank, group=GlobalComm.WORLD_COMM_GROUP):
self.rank = get_rank(_get_group(group))
self.sr_tag = sr_tag
self.group = group
def infer_shape(self, x_shape):
self.add_prim_attr("shape", x_shape)
return x_shape
def infer_dtype(self, x_dtype):
self.add_prim_attr("dtype", x_dtype)
return x_dtype
class Receive(PrimitiveWithInfer):
"""
receive tensors from src_rank.
Note:
Send and Recveive must be used in combination and have same sr_tag.
Receive must be used between servers.
Args:
sr_tag (int): A required integer identifying the send/recv message tag. The message will
will be send by the Send op with the same "sr_tag".
src_rank (int): A required integer identifying the source rank.
shape (list[int]): A required list identifying the shape of the tensor to be received.
dtype (Type): A required Type indentifying the type of the tensor to be received. The supported types:
int8, int16, int32, float16, float32.
group (str): The communication group to work on. Default: "hccl_world_group/nccl_world_group".
Inputs:
- **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
Examples:
>>> import mindspore.ops.operations as ops
>>> import mindspore.nn as nn
>>> from mindspore.communication import init
>>> from mindspore import Tensor
>>> import numpy as np
>>>
>>> init()
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self.recv = ops.Receive(st_tag=0, src_rank=0, shape=[2, 8], dtype=np.float32,
>>> group="hccl_world_group")
>>>
>>> def construct(self, x):
>>> out = self.depend(x, self.recv(x))
>>> return out
>>>
>>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
>>> net = Net()
>>> output = net(input_)
"""
@prim_attr_register
def __init__(self, sr_tag, src_rank, shape, dtype, group=GlobalComm.WORLD_COMM_GROUP):
self.rank = get_rank(_get_group(group))
self.tag = sr_tag
self.shape = shape
self.dtype = dtype
self.group = group
def infer_shape(self, x_shape=None):
return self.shape
def infer_dtype(self, x_dtype=None):
return self.dtype
class AllGather(PrimitiveWithInfer):
"""
Gathers tensors from the specified communication group.

@ -21,6 +21,7 @@ from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
from mindspore.communication.management import init, NCCL_WORLD_COMM_GROUP, get_rank, get_group_size
from mindspore.ops import operations as P
from mindspore.ops.operations._inner_ops import Send, Receive
from mindspore.common import dtype as mstype
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
@ -38,7 +39,7 @@ class SendNet(nn.Cell):
super(SendNet, self).__init__()
self.x = Parameter(initializer(Tensor(x), x.shape), name='x')
self.depend = P.Depend()
self.send = P.Send(sr_tag=0, dest_rank=rank+size//2, group=NCCL_WORLD_COMM_GROUP)
self.send = Send(sr_tag=0, dest_rank=rank+size//2, group=NCCL_WORLD_COMM_GROUP)
def construct(self):
out = self.depend(self.x, self.send(self.x))
@ -47,8 +48,8 @@ class SendNet(nn.Cell):
class RecvNet(nn.Cell):
def __init__(self):
super(RecvNet, self).__init__()
self.recv = P.Receive(sr_tag=0, src_rank=rank-size//2, shape=[3, 3, 3, 3], dtype=mstype.float32,
group=NCCL_WORLD_COMM_GROUP)
self.recv = Receive(sr_tag=0, src_rank=rank-size//2, shape=[3, 3, 3, 3], dtype=mstype.float32,
group=NCCL_WORLD_COMM_GROUP)
def construct(self):
out = self.recv()

@ -1,91 +0,0 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _executor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from tests.ut.python.ops.test_math_ops import VirtualLoss
grad_all = C.GradOperation(get_all=True)
class NetWithLoss(nn.Cell):
def __init__(self, network):
super(NetWithLoss, self).__init__()
self.loss = VirtualLoss()
self.network = network
def construct(self, x, y):
predict = self.network(x, y)
return self.loss(predict)
class GradWrap(nn.Cell):
def __init__(self, network):
super(GradWrap, self).__init__()
self.network = network
def construct(self, x, y):
return grad_all(self.network)(x, y)
class Net(nn.Cell):
def __init__(self, axis=0, stage1=0, stage2=0, strategy1=None, strategy2=None, shape=None, target=""):
super().__init__()
if shape is None:
shape = [64, 64]
self.gatherv2 = P.GatherV2().shard(strategy1).add_prim_attr("primitive_target", target)
self.mul = P.Mul().shard(strategy2)
self.index = Tensor(np.ones(shape), dtype=ms.int32)
self.gatherv2.set_stage(stage1)
self.mul.set_stage(stage2)
self.axis = axis
def construct(self, x, y):
out = self.gatherv2(x, self.index, self.axis)
out = self.mul(out, y)
return out
def test_gatherv2_semi_samestage1():
context.set_auto_parallel_context(device_num=8, global_rank=0, \
parallel_mode="semi_auto_parallel", pipeline_stages=2)
strategy1 = ((1, 2), (1, 1))
strategy2 = ((2, 1, 1), (2, 1, 1))
net = GradWrap(NetWithLoss(Net(0, 0, 0, strategy1, strategy2)))
net.set_auto_parallel()
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
net.set_train()
_executor.compile(net, x, y)
def test_gatherv2_semi_samestage2():
context.set_auto_parallel_context(device_num=8, global_rank=5, \
parallel_mode="semi_auto_parallel", pipeline_stages=2)
strategy1 = ((1, 2), (1, 1))
strategy2 = ((2, 1, 1), (2, 1, 1))
net = GradWrap(NetWithLoss(Net(0, 1, 1, strategy1, strategy2)))
net.set_auto_parallel()
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
net.set_train()
_executor.compile(net, x, y)

@ -0,0 +1,109 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import context
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.train.model import Model
class DatasetLenet():
def __init__(self, data, label, length=3):
self.data = data
self.label = label
self.index = 1
self.length = length
def __iter__(self):
return self
def __next__(self):
if self.index >= self.length:
raise StopIteration
self.index += 1
return self.data, self.label
def reset(self):
self.index = 0
def get_dataset_size(self):
return 32
def get_repeat_count(self):
return 1
def get_batch_size(self):
return 32
def create_tuple_iterator(self, num_epochs=1):
return self
class MatMulCell(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.param = Parameter(initializer("zeros", [64, 64]), name="param")
self.param1 = Parameter(initializer("zeros", [64, 64]), name="param1")
self.matmul = P.MatMul().shard(strategy1)
self.matmul1 = P.MatMul().shard(strategy2)
def construct(self, x):
out = self.matmul(x, self.param)
out = self.matmul1(out, self.param1)
return out
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.block = nn.CellList()
for i in range(2):
cell = MatMulCell(strategy1, strategy2)
cell.stage = i
self.block.append(cell)
def construct(self, x):
for i in range(2):
x = self.block[i](x)
return x
class PipelineSplit(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.cell = Net(strategy1, strategy2)
def construct(self, x, label):
x = self.cell(x)
return x
def test_pipeline_split():
context.set_auto_parallel_context(device_num=8, global_rank=4, pipeline_stages=2)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
data = Tensor(np.ones([32, 64]), dtype=ms.float32)
label = Tensor(np.ones([64, 64]), dtype=ms.float32)
strategy1 = ((4, 1), (1, 1))
strategy2 = ((2, 1), (1, 1))
net = PipelineSplit(strategy1, strategy2)
params = net.cell.block[1].trainable_params()
dataset = DatasetLenet(data, label, 3)
optimizer = nn.Lamb(params, learning_rate=0.01)
model = Model(net, optimizer=optimizer)
model.train(2, dataset, dataset_sink_mode=False)
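For the new UT above, the configuration pins the rank placement: device_num=8 with pipeline_stages=2 gives 4 ranks per stage, so global_rank=4 falls in stage 1 and compiles block[1], whose parameters are the ones handed to the optimizer. A quick check of that arithmetic, assuming InferStage uses the contiguous rank-to-stage mapping sketched earlier:

device_num, pipeline_stages, global_rank = 8, 2, 4
per_stage_rank_num = device_num // pipeline_stages   # 4 ranks per stage
stage = global_rank // per_stage_rank_num            # rank 4 -> stage 1
assert (per_stage_rank_num, stage) == (4, 1)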