support forward unique

pull/8269/head
yao_yf 4 years ago
parent e094b63f87
commit 31819bb4a7

@ -27,7 +27,7 @@ void DynamicShapeKernel::Execute() {
}
auto prev_output_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, 0);
auto output_shape = std::vector<int64_t>(SizeToLong(prev_output_shape.size()));
std::vector<int64_t> output_shape = {SizeToLong(prev_output_shape.size())};
auto output_type = TypeId::kNumberTypeInt64;

@ -62,7 +62,7 @@ const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const An
continue;
}
}
if (AnfAlgo::IsNodeDynamicShape(cnode) &&
if (AnfAlgo::IsDynamicShape(cnode) &&
DynamicShapeConstInputToAttr.find(AnfAlgo::GetCNodeName(cnode)) == DynamicShapeConstInputToAttr.end()) {
MS_LOG(INFO) << "current node is dynamic shape " << cnode->fullname_with_scope();
continue;

@ -541,7 +541,6 @@ void AscendSession::BuildDynamicKernel(const std::shared_ptr<KernelGraph> &kerne
void AscendSession::MemoryAlloc(KernelGraph *kernel_graph) const {
MS_LOG(INFO) << "Start!";
MS_EXCEPTION_IF_NULL(kernel_graph);
opt::RemoveNopNode(kernel_graph);
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
MS_EXCEPTION_IF_NULL(runtime_instance);
runtime_instance->AssignMemory(kernel_graph);

@ -471,7 +471,21 @@ bool KernelRuntime::KernelMemNotReuse(const AnfNodePtr &node) { return false; }
DeviceAddressPtr KernelRuntime::PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index) {
MS_EXCEPTION_IF_NULL(anf_node);
if (!anf_node->isa<CNode>()) {
MS_LOG(EXCEPTION) << "anf_node should be a cnode";
}
auto cnode = anf_node->cast<CNodePtr>();
if (opt::IsNopNode(cnode)) {
size_t kNopNodeInputSize = 2;
size_t kNopNodeRealInputIndex = 1;
if (cnode->size() != kNopNodeInputSize) {
MS_LOG(EXCEPTION) << cnode->fullname_with_scope() << " has invalid input size: " << cnode->size();
}
auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, index);
return PreAssignCNodeMemory(cnode->input(kNopNodeRealInputIndex), input_node_with_index.second);
}
auto kernel_mod = AnfAlgo::GetKernelMod(anf_node);
MS_EXCEPTION_IF_NULL(kernel_mod);
auto output_sizes = kernel_mod->GetOutputSizeList();
if (output_sizes.size() <= index) {
MS_LOG(EXCEPTION) << "Previous node output size < node index";

@ -126,6 +126,7 @@ class Parameter(MetaTensor_):
self.is_param_ps = False
self._cast_type = None
self.init_in_server = False
self._unique = False
self.is_in_parallel = _is_in_parallel_mode()
@staticmethod
@ -238,6 +239,15 @@ class Parameter(MetaTensor_):
def sliced(self, sliced_):
self._sliced = sliced_
@property
def unique(self):
"""whether the parameter is already unique or not."""
return self._unique
@unique.setter
def unique(self, unique_):
self._unique = unique_
@property
def is_init(self):
"""

@ -433,6 +433,17 @@ AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr
result_shp.push_back(input_shp[idx]);
indices.insert(idx);
}
ShapeVector max_shp;
ShapeVector min_shp;
if (input->shape()->max_shape().size() == input_shp.size() &&
input->shape()->min_shape().size() == input_shp.size()) {
for (size_t i = 0; i < perm_vec.size(); i++) {
size_t idx = static_cast<size_t>(perm_vec[i]);
max_shp.push_back(input->shape()->max_shape()[idx]);
min_shp.push_back(input->shape()->min_shape()[idx]);
}
return std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(result_shp, min_shp, max_shp));
}
return std::make_shared<AbstractTensor>(input->element(), std::make_shared<Shape>(result_shp));
}

@ -405,10 +405,9 @@ AbstractBasePtr InferImplAllGather(const AnalysisEnginePtr &, const PrimitivePtr
if (tmp_shape.empty()) {
MS_LOG(EXCEPTION) << "shape size is 0";
}
if (tmp_shape[0] % rank_size != 0) {
MS_LOG(EXCEPTION) << "first dimension of x should be divided by rank_size";
if (tmp_shape[0] > 0) {
tmp_shape[0] = tmp_shape[0] * rank_size;
}
tmp_shape[0] = tmp_shape[0] / rank_size;
return std::make_shared<AbstractTensor>(x->element(), std::make_shared<Shape>(tmp_shape));
}

@ -149,6 +149,7 @@ class EmbeddingLookup(Cell):
max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32
or None. Default: None
sparse (bool): Using sparse mode. When 'target' is set to 'CPU', 'sparse' has to be true. Default: True.
Inputs:
- **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
Specifies the indices of elements of the original Tensor. Values can be out of range of embedding_table,
@ -161,6 +162,8 @@ class EmbeddingLookup(Cell):
Examples:
>>> input_indices = Tensor(np.array([[1, 0], [3, 2]]), mindspore.int32)
>>> out = nn.EmbeddingLookup(4,2)(input_indices)
>>> output.shape
(2, 2, 2)
"""
BATCH_SLICE = "batch_slice"
FIELD_SLICE = "field_slice"
@ -188,6 +191,12 @@ class EmbeddingLookup(Cell):
name='embedding_table')
parallel_mode = _get_parallel_mode()
is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
self.forward_unique = False
self.gather_revert = P.GatherV2()
self.unique = P.Unique().shard(((1,),))
self.reshape = P.Reshape()
self.shape = P.Shape()
indices_shape_size = 2
if slice_mode == "field_slice" and is_auto_parallel:
if not manual_shapes:
raise ValueError("in slice field mode, the manual_shapes should not be none")
@ -200,18 +209,32 @@ class EmbeddingLookup(Cell):
self.gatherv2.shard(((get_group_size(), 1), (1, get_group_size())))
self.embeddinglookup.shard(((get_group_size(), 1), (1, get_group_size())))
elif slice_mode == "table_row_slice" and is_auto_parallel:
self.gatherv2.shard(((get_group_size(), 1), (1, 1)))
self.embeddinglookup.shard(((get_group_size(), 1), (1, 1)))
if target == 'DEVICE':
indices_shape_size = 1
self.gather_revert.shard(((1, 1), (1,)))
self.forward_unique = True
indices_strategy = (1,)*indices_shape_size
self.gatherv2.shard(((get_group_size(), 1), indices_strategy))
self.embeddinglookup.shard(((get_group_size(), 1), indices_strategy))
elif slice_mode == "table_column_slice" and is_auto_parallel:
self.gatherv2.shard(((1, get_group_size()), (1, 1)))
self.embeddinglookup.shard(((1, get_group_size()), (1, 1)))
if target == 'DEVICE':
indices_shape_size = 1
self.gather_revert.shard(((1, get_group_size()), (1,)))
self.forward_unique = True
indices_strategy = (1,)*indices_shape_size
self.gatherv2.shard(((1, get_group_size()), indices_strategy))
self.embeddinglookup.shard(((1, get_group_size()), indices_strategy))
elif slice_mode == "batch_slice" and is_auto_parallel:
self.gatherv2.shard(((1, 1), (get_group_size(), 1)))
self.embeddinglookup.shard(((1, 1), (get_group_size(), 1)))
indices_strategy = [get_group_size()]
indices_strategy.extend([1]*(indices_shape_size - 1))
indices_strategy = tuple(indices_strategy)
self.gatherv2.shard(((1, 1), indices_strategy))
self.embeddinglookup.shard(((1, 1), indices_strategy))
else:
if is_auto_parallel:
raise ValueError("slice_mode should support mode in nn.EmbeddingLookup, but get "
+ str(slice_mode))
self.embedding_table.unique = self.forward_unique
self.max_norm = max_norm
if self.max_norm is not None:
self.max_norm = validator.check_positive_float(self.max_norm, 'max_norm', self.cls_name)
@ -221,7 +244,15 @@ class EmbeddingLookup(Cell):
if self.target == "CPU":
out = self.embeddinglookup(self.embedding_table, indices, 0)
else:
out = self.gatherv2(self.embedding_table, indices, 0)
if self.forward_unique:
shp = self.shape(indices) + (self.embedding_size,)
indices_flatten = self.reshape(indices, (-1,))
unique_id, unique_idx = self.unique(indices_flatten)
weight_unique = self.gatherv2(unique_id)
weight_flatten = self.gather_revert(weight_unique, unique_idx, 0)
out = self.reshape(weight_flatten, shp)
else:
out = self.gatherv2(self.embedding_table, indices, 0)
if self.max_norm is not None:
axis = _make_axis_range(F.rank(indices), F.rank(out))
clip_by_norm = ClipByNorm(axis)

@ -144,6 +144,11 @@ class Optimizer(Cell):
decay_filter = lambda x: 'beta' not in x.name and 'gamma' not in x.name
self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
self.exec_weight_decay = self.weight_decay > 0
# when a parameter has been unique, there is no need do another unique in optimizer.
for param in self.parameters:
if param.unique:
self._unique = False
break
ps_filter = lambda x: x.is_param_ps
self.ps_parameters = tuple(ps_filter(x) for x in self.parameters)
self.reciprocal_scale = 1.0 / loss_scale

@ -67,6 +67,7 @@ from .mul_ds import _mul_ds_tbe
from .real_div import _real_div_tbe
from .real_div_ds import _real_div_ds_tbe
from .relu import _relu_tbe
from .relu_ds import _relu_ds_tbe
from .relu_grad import _relu_grad_tbe
from .relu6 import _relu6_tbe
from .relu6_grad import _relu6_grad_tbe

@ -0,0 +1,40 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ReLU op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
relu_op_info = TBERegOp("ReLU") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("relu.so") \
.compute_cost(10) \
.kernel_name("relu") \
.partial_flag(True) \
.dynamic_shape(True) \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.op_pattern("formatAgnostic") \
.dtype_format(DataType.I8_None, DataType.I8_None) \
.dtype_format(DataType.I32_None, DataType.I32_None) \
.dtype_format(DataType.F16_None, DataType.F16_None) \
.dtype_format(DataType.F32_None, DataType.F32_None) \
.get_op_info()
@op_info_register(relu_op_info)
def _relu_ds_tbe():
"""Relu TBE register"""
return

@ -163,7 +163,8 @@ class AllGather(PrimitiveWithInfer):
def infer_shape(self, x_shape):
validator.check_positive_int(len(x_shape), "x shape", self.name)
x_shape[0] = x_shape[0] * self.rank_size
if x_shape[0] > 0:
x_shape[0] = x_shape[0] * self.rank_size
return x_shape
def infer_dtype(self, x_dtype):
@ -213,7 +214,8 @@ class _HostAllGather(PrimitiveWithInfer):
def infer_shape(self, x_shape):
validator.check_positive_int(len(x_shape), "x shape", self.name)
x_shape[0] = x_shape[0] * self.group_size
if x_shape[0] > 0:
x_shape[0] = x_shape[0] * self.group_size
return x_shape
def infer_dtype(self, x_dtype):

@ -720,7 +720,7 @@ def _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even):
if field_size > 0:
from mindspore.parallel._tensor import _reshape_param_data_with_weight
merged_tensor = _reshape_param_data_with_weight(all_gather_tensor, dev_mat, [field_size])
merged_tensor = _reshape_param_data_with_weight(all_gather_tensor, dev_mat, field_size)
else:
from mindspore.parallel._tensor import _reshape_param_data

@ -43,7 +43,7 @@ do
python -s ${self_path}/../train_and_eval_auto_parallel.py --data_path=$DATASET --epochs=$EPOCH_SIZE --vocab_size=$VOCAB_SIZE --emb_dim=$EMB_DIM --dropout_flag=1 --host_device_mix=1 >train_deep$i.log 2>&1 &
elif [ $MODE == "field_slice_host_device_mix" ]; then
python -s ${self_path}/../train_and_eval_auto_parallel.py --data_path=$DATASET --epochs=$EPOCH_SIZE --vocab_size=$VOCAB_SIZE --emb_dim=$EMB_DIM --dropout_flag=1 --host_device_mix=1 --full_batch=1 --field_slice=1 >train_deep$i.log 2>&1 &
elif [ $MODE == "backward_unique" ]; then
elif [ $MODE == "forward_unique" ]; then
python -s ${self_path}/../train_and_eval_auto_parallel.py --data_path=$DATASET --epochs=$EPOCH_SIZE --vocab_size=$VOCAB_SIZE --emb_dim=$EMB_DIM --dropout_flag=1 --sparse=1 >train_deep$i.log 2>&1 &
else
python -s ${self_path}/../train_and_eval_auto_parallel.py --data_path=$DATASET --epochs=$EPOCH_SIZE --vocab_size=$VOCAB_SIZE --emb_dim=$EMB_DIM --dropout_flag=1 --host_device_mix=0 >train_deep$i.log 2>&1 &

@ -38,7 +38,7 @@ do
user=$(get_node_user ${cluster_config_path} ${node})
passwd=$(get_node_passwd ${cluster_config_path} ${node})
echo "------------------${user}@${node}---------------------"
if [ $MODE == "host_device_mix" ] || [ $MODE == "field_slice_host_device_mix" ] || [ $MODE == "backward_unique" ]; then
if [ $MODE == "host_device_mix" ] || [ $MODE == "field_slice_host_device_mix" ] || [ $MODE == "forward_unique" ]; then
ssh_pass ${node} ${user} ${passwd} "mkdir -p ${execute_path}; cd ${execute_path}; bash ${SCRIPTPATH}/run_auto_parallel_train_cluster.sh ${RANK_SIZE} ${RANK_START} ${EPOCH_SIZE} ${VOCAB_SIZE} ${EMB_DIM} ${DATASET} ${ENV_SH} ${MODE} ${RANK_TABLE_FILE}"
else
echo "[ERROR] mode is wrong"

@ -88,7 +88,7 @@ class EvalCallBack(Callback):
Args:
print_per_step (int): Print loss every times. Default: 1.
"""
def __init__(self, model, eval_dataset, auc_metric, config, print_per_step=1, host_device_mix=False):
def __init__(self, model, eval_dataset, auc_metric, config, print_per_step=1):
super(EvalCallBack, self).__init__()
if not isinstance(print_per_step, int) or print_per_step < 0:
raise ValueError("print_per_step must be int and >= 0.")
@ -99,7 +99,7 @@ class EvalCallBack(Callback):
self.aucMetric.clear()
self.eval_file_name = config.eval_file_name
self.eval_values = []
self.host_device_mix = host_device_mix
self.sparse = config.sparse
self.config = config
def epoch_end(self, run_context):
@ -116,7 +116,7 @@ class EvalCallBack(Callback):
ParallelMode.DATA_PARALLEL):
rank_id = get_rank()
start_time = time.time()
out = self.model.eval(self.eval_dataset, dataset_sink_mode=(not self.host_device_mix))
out = self.model.eval(self.eval_dataset, dataset_sink_mode=(not self.sparse))
end_time = time.time()
eval_time = int(end_time - start_time)

@ -48,6 +48,7 @@ def argparse_init():
parser.add_argument("--parameter_server", type=int, default=0, help="Open parameter server of not")
parser.add_argument("--field_slice", type=int, default=0, help="Enable split field mode or not")
parser.add_argument("--sparse", type=int, default=0, help="Enable sparse or not")
parser.add_argument("--deep_table_slice_mode", type=str, default="column_slice", help="column_slice/row_slice")
return parser
@ -86,6 +87,7 @@ class WideDeepConfig():
self.field_slice = False
self.manual_shape = None
self.sparse = False
self.deep_table_slice_mode = "column_slice"
def argparse_init(self):
"""
@ -121,5 +123,6 @@ class WideDeepConfig():
self.parameter_server = args.parameter_server
self.field_slice = bool(args.field_slice)
self.sparse = bool(args.sparse)
self.deep_table_slice_mode = args.deep_table_slice_mode
if self.host_device_mix == 1:
self.sparse = True

@ -198,19 +198,29 @@ class WideDeepModel(nn.Cell):
self.tile = P.Tile()
self.concat = P.Concat(axis=1)
self.cast = P.Cast()
self.unique = P.Unique().shard(((1,),))
self.wide_gatherv2 = P.GatherV2()
self.deep_gatherv2 = P.GatherV2()
if is_auto_parallel and sparse and not is_field_slice:
self.dense_layer_1.dropout.dropout_do_mask.shard(((1, get_group_size()),))
self.dense_layer_1.dropout.dropout.shard(((1, get_group_size()),))
self.dense_layer_1.matmul.shard(((1, get_group_size()), (get_group_size(), 1)))
target = 'DEVICE'
if host_device_mix:
target = 'CPU'
self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, target=target,
slice_mode=nn.EmbeddingLookup.TABLE_COLUMN_SLICE)
self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1, target=target,
slice_mode=nn.EmbeddingLookup.TABLE_ROW_SLICE)
self.deep_mul.shard(((1, 1, get_group_size()), (1, 1, 1)))
self.deep_reshape.add_prim_attr("skip_redistribution", True)
if target == 'DEVICE':
self.wide_mul.shard(((1, 1, 1), (1, 1, 1)))
if config.deep_table_slice_mode == "column_slice":
self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, target=target,
slice_mode=nn.EmbeddingLookup.TABLE_COLUMN_SLICE)
self.dense_layer_1.dropout.dropout_do_mask.shard(((1, get_group_size()),))
self.dense_layer_1.dropout.dropout.shard(((1, get_group_size()),))
self.dense_layer_1.matmul.shard(((1, get_group_size()), (get_group_size(), 1)))
self.dense_layer_1.matmul.add_prim_attr("field_size", self.field_size)
self.deep_mul.shard(((1, 1, get_group_size()), (1, 1, 1)))
self.deep_reshape.add_prim_attr("skip_redistribution", True)
else:
self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, target=target,
slice_mode=nn.EmbeddingLookup.TABLE_ROW_SLICE)
self.reduce_sum.add_prim_attr("cross_batch", True)
self.embedding_table = self.deep_embeddinglookup.embedding_table
elif is_auto_parallel and host_device_mix and is_field_slice and config.full_batch and config.manual_shape:
@ -247,13 +257,15 @@ class WideDeepModel(nn.Cell):
id_hldr: batch ids;
wt_hldr: batch weights;
"""
mask = self.reshape(wt_hldr, (self.batch_size, self.field_size, 1))
# Wide layer
wide_id_weight = self.wide_embeddinglookup(id_hldr)
# Deep layer
deep_id_embs = self.deep_embeddinglookup(id_hldr)
mask = self.reshape(wt_hldr, (self.batch_size, self.field_size, 1))
# Wide layer
wx = self.wide_mul(wide_id_weight, mask)
wide_out = self.reshape(self.reduce_sum(wx, 1) + self.wide_b, (-1, 1))
# Deep layer
deep_id_embs = self.deep_embeddinglookup(id_hldr)
vx = self.deep_mul(deep_id_embs, mask)
deep_in = self.deep_reshape(vx, (-1, self.field_size * self.emb_dim))
deep_in = self.dense_layer_1(deep_in)
@ -333,7 +345,8 @@ class TrainStepWrap(nn.Cell):
parameter_server (Bool): Whether run in parameter server mode. Default: False
"""
def __init__(self, network, sens=1024.0, host_device_mix=False, parameter_server=False, sparse=False):
def __init__(self, network, sens=1024.0, host_device_mix=False, parameter_server=False,
sparse=False):
super(TrainStepWrap, self).__init__()
parallel_mode = context.get_auto_parallel_context("parallel_mode")
is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)

@ -40,8 +40,8 @@ def get_WideDeep_net(config):
WideDeep_net = WideDeepModel(config)
loss_net = NetWithLossClass(WideDeep_net, config)
loss_net = VirtualDatasetCellTriple(loss_net)
train_net = TrainStepWrap(
loss_net, host_device_mix=bool(config.host_device_mix), sparse=config.sparse)
train_net = TrainStepWrap(loss_net, host_device_mix=bool(config.host_device_mix),
sparse=config.sparse)
eval_net = PredictWithSigmoid(WideDeep_net)
eval_net = VirtualDatasetCellTriple(eval_net)
return train_net, eval_net
@ -122,7 +122,7 @@ def train_and_eval(config):
metrics={"auc": auc_metric})
eval_callback = EvalCallBack(
model, ds_eval, auc_metric, config, host_device_mix=host_device_mix)
model, ds_eval, auc_metric, config)
callback = LossCallBack(config=config, per_print_times=20)
ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size()*epochs,
@ -146,7 +146,7 @@ if __name__ == "__main__":
context.set_context(variable_memory_max_size="24GB")
context.set_context(enable_sparse=True)
init()
if wide_deep_config.host_device_mix == 1:
if wide_deep_config.sparse:
context.set_auto_parallel_context(
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, gradients_mean=True)
else:

@ -37,6 +37,8 @@ def argparse_init():
parser.add_argument("--ckpt_path", type=str, default="./checkpoints/")
parser.add_argument("--eval_file_name", type=str, default="eval.log")
parser.add_argument("--loss_file_name", type=str, default="loss.log")
parser.add_argument("--sparse", type=int, default=0, help="Enable sparse or not")
parser.add_argument("--deep_table_slice_mode", type=str, default="column_slice", help="column_slice/row_slice")
return parser
@ -66,6 +68,8 @@ class WideDeepConfig():
self.loss_file_name = "loss.log"
self.ckpt_path = "./checkpoints/"
self.stra_ckpt = "./strategy_train.ckpt"
self.sparse = False
self.deep_table_slice_mode = "column_slice"
def argparse_init(self):
"""
@ -94,3 +98,7 @@ class WideDeepConfig():
self.loss_file_name = args.loss_file_name
self.ckpt_path = args.ckpt_path
self.stra_ckpt = args.stra_ckpt
self.sparse = bool(args.sparse)
self.deep_table_slice_mode = args.deep_table_slice_mode
if self.host_device_mix == 1:
self.sparse = True

@ -93,7 +93,7 @@ def test_unique_row_split():
self.embedding_lookp = P.GatherV2().shard(((8, 1), (1,)))
self.embedding_table = Parameter(initializer('normal', [2000, 128]),
name='embedding_table')
self.gatherv2 = P.GatherV2().shard(((1, 1), (8,)))
self.gatherv2 = P.GatherV2().shard(((1, 1), (1,)))
self.reshape = P.Reshape()
self.matmul = P.MatMul()
self.mul_weight = Parameter(Tensor(np.full([32, 64, 1], 0.5, dtype=np.float32)), name="mul_weight")
@ -108,7 +108,7 @@ def test_unique_row_split():
return vx
size = 8
context.set_auto_parallel_context(device_num=size, global_rank=0, parallel_mode="stand_alone")
context.set_auto_parallel_context(device_num=size, global_rank=0, parallel_mode="semi_auto_parallel")
x = Tensor(np.ones([32, 64]), dtype=ms.int32)
net = Net()
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)

Loading…
Cancel
Save