Merge branch 'incubator-master' into sync_05177ff9_6b1715a7

Branch: pull/1775/head
Author: jonyguo
Commit: 228061818c

.gitmodules

@@ -12,4 +12,4 @@
url = https://github.com/protocolbuffers/protobuf.git
[submodule "graphengine"]
path = graphengine
- url = https://gitee.com/mindspore/graphengine.git
+ url = https://gitee.com/ms-incubator/graphengine.git

@@ -7,6 +7,9 @@ endif ()
include(${CMAKE_SOURCE_DIR}/cmake/options.cmake)
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/cmake/modules/")
+ if (ENABLE_GE)
+ add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0)
+ endif ()
if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O2 -Werror -Wno-return-std-move -Wno-unused-private-field -Wno-unused-lambda-capture -Wno-sign-compare -Wno-overloaded-virtual -Wno-unneeded-internal-declaration -Wno-unused-variable -Wno-pessimizing-move -Wno-inconsistent-missing-override -DHALF_ENABLE_CPP11_USER_LITERALS=0 -D_FORTIFY_SOURCE=2")

@@ -1 +1 @@
- Subproject commit 579dcb75a990b533f9182733a6424f2bd66f0f23
+ Subproject commit c27e428e9698dd4f9b198008596676bc2d1b49aa

@@ -127,7 +127,7 @@ endif()
if (ENABLE_GE)
if(ENABLE_TRAIN)
- target_link_libraries(mindspore ge_client_train hccl)
+ target_link_libraries(mindspore ge_runner hccl)
else ()
target_link_libraries(mindspore ge_client)
endif ()

@@ -470,7 +470,7 @@ INPUT_MAP(ApplyAdam) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(m)},
{10, INPUT_DESC(grad)}};
ATTR_MAP(ApplyAdam) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())},
{"use_nesterov", ATTR_DESC(use_nesterov, AnyTraits<bool>())}};
- OUTPUT_MAP(ApplyAdam) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(m)}, {2, OUTPUT_DESC(v)}};
+ OUTPUT_MAP(ApplyAdam) = {{0, OUTPUT_DESC(var)}};
// Relu6
INPUT_MAP(Relu6) = {{1, INPUT_DESC(x)}};
@@ -823,7 +823,7 @@ OUTPUT_MAP(RealDiv) = {{0, OUTPUT_DESC(y)}};
// Cast
INPUT_MAP(Cast) = {{1, INPUT_DESC(x)}};
INPUT_ATTR_MAP(Cast) = {{2, ATTR_DESC(dst_type, AnyTraits<GEType>())}};
ATTR_MAP(Cast) = {{"Truncate", ATTR_DESC(truncate, AnyTraits<bool>())}};
ATTR_MAP(Cast) = EMPTY_ATTR_MAP;
OUTPUT_MAP(Cast) = {{0, OUTPUT_DESC(y)}};
// Reciprocal
@@ -1153,7 +1153,7 @@ INPUT_MAP(SparseApplyAdagradD) = {
{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(grad)}, {4, INPUT_DESC(indices)}};
ATTR_MAP(SparseApplyAdagradD) = {{"lr", ATTR_DESC(lr, AnyTraits<float>())},
{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
- OUTPUT_MAP(SparseApplyAdagradD) = {{0, OUTPUT_DESC(var)}};
+ OUTPUT_MAP(SparseApplyAdagradD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(accum)}};
// SparseApplyFtrlD
INPUT_MAP(SparseApplyFtrlD) = {{1, INPUT_DESC(var)},

@@ -453,11 +453,13 @@ def reset_auto_parallel_context():
_reset_auto_parallel_context()
- @args_type_check(mode=int, precompile_only=bool, device_target=str, device_id=int, save_graphs=bool,
-                save_graphs_path=str, save_ms_model=bool, save_ms_model_path=str, enable_dump=bool,
-                save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str,
-                enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool,
-                check_bprop=bool)
+ @args_type_check(mode=int, precompile_only=bool, device_target=str,
+                device_id=int, enable_ir_fusion=bool, save_graphs=bool,
+                enable_task_sink=bool, save_graphs_path=str, enable_loop_sink=bool,
+                enable_mem_reuse=bool, save_ms_model=bool, save_ms_model_path=str, enable_gpu_summary=bool,
+                enable_auto_mixed_precision=bool, enable_dump=bool, save_dump_path=str,
+                enable_reduce_precision=bool, enable_dynamic_memory=bool, graph_memory_max_size=str,
+                variable_memory_max_size=str, enable_profiling=bool, profiling_options=str)
def set_context(**kwargs):
"""
Sets context for running environment.
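The decorator change above only widens the set of keyword arguments whose types are validated before set_context runs. As a minimal sketch of the mechanism, here is a standalone re-implementation; it is an illustrative assumption, not MindSpore's actual args_type_check helper:

import functools

def args_type_check(**type_spec):
    # Illustrative stand-in: reject kwargs whose values do not match
    # the declared type before the wrapped function ever runs.
    def decorator(func):
        @functools.wraps(func)
        def wrapper(**kwargs):
            for name, value in kwargs.items():
                expected = type_spec.get(name)
                if expected is not None and not isinstance(value, expected):
                    raise TypeError("%s must be %s, got %s"
                                    % (name, expected.__name__, type(value).__name__))
            return func(**kwargs)
        return wrapper
    return decorator

@args_type_check(mode=int, save_graphs=bool, variable_memory_max_size=str)
def set_context(**kwargs):
    print("context updated:", kwargs)

set_context(mode=0, save_graphs=False, variable_memory_max_size="5GB")  # passes
# set_context(mode="graph") would raise TypeError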

@@ -292,7 +292,6 @@ class Optimizer(Cell):
current_dynamic_lr = self.gather(self.learning_rate[i], self.global_step, 0)
lr += (current_dynamic_lr,)
- F.control_depend(lr, self.assignadd(self.global_step, 1))
else:
lr = self.learning_rate
if self.dynamic_lr:

@@ -516,6 +516,18 @@ def get_bprop_l2_loss(self):
return bprop
+ @bprop_getters.register(P.RNNTLoss)
+ def get_bprop_rnnt_loss(self):
+     """Grad definition for `RNNTLoss` operation."""
+     expand = P.ExpandDims()
+     def bprop(acts, labels, act_lens, label_lens, out, dout):
+         grad_loss = out[1]
+         grad = grad_loss * expand(expand(expand(dout[0], -1), -1), -1)
+         return grad, zeros_like(labels), zeros_like(act_lens), zeros_like(label_lens)
+     return bprop
@bprop_getters.register(P.PReLU)
def get_bprop_prelu(self):
"""Grad definition for `PReLU` operation."""

@@ -24,3 +24,6 @@ from .flatten import _flatten_aicpu
from .squeeze import _squeeze_aicpu
from .expand_dims import _expand_dims_aicpu
from .random_choice_with_mask import _random_choice_with_mask_aicpu
+ from .ctcloss import _ctcloss_aicpu
+ from .rnnt_loss import _rnnt_loss_aicpu
+ from .random_categorical import _random_categorical_aicpu

@@ -0,0 +1,42 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CTCLoss op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
ctcloss_op_info = AiCPURegOp("CTCLoss") \
.fusion_type("OPAQUE") \
.input(0, "inputs", "required") \
.input(1, "labels_indices", "required") \
.input(2, "labels_values", "required") \
.input(3, "sequence_length", "required") \
.output(0, "loss", "required") \
.output(1, "gradient", "required") \
.attr("preprocess_collapse_repeated", "bool") \
.attr("ctc_merge_repeated", "bool") \
.attr("ignore_longer_outputs_than_inputs", "bool") \
.dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.I32_Default, DataType.I32_Default,
DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.I32_Default, DataType.I32_Default,
DataType.F64_Default, DataType.F64_Default) \
.dtype_format(DataType.F32_NCHW, DataType.I64_NCHW, DataType.I32_NCHW, DataType.I32_NCHW,
DataType.F32_NCHW, DataType.F32_NCHW) \
.dtype_format(DataType.F64_NCHW, DataType.I64_NCHW, DataType.I32_NCHW, DataType.I32_NCHW,
DataType.F64_NCHW, DataType.F64_NCHW) \
.get_op_info()
@op_info_register(ctcloss_op_info)
def _ctcloss_aicpu():
"""CTCLoss AiCPU register"""
return
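The labels_indices (int64) and labels_values (int32) inputs registered above follow the usual sparse encoding of a batch of variable-length label sequences; that reading is inferred from the input names and dtypes, and the helper below is illustrative rather than part of the diff:

import numpy as np

def to_sparse_labels(label_lists):
    # Flatten variable-length sequences into (indices, values),
    # where indices[k] = (batch, position) of values[k].
    indices, values = [], []
    for b, seq in enumerate(label_lists):
        for t, v in enumerate(seq):
            indices.append((b, t))
            values.append(v)
    return (np.array(indices, dtype=np.int64),
            np.array(values, dtype=np.int32))

labels_indices, labels_values = to_sparse_labels([[1, 2, 3], [2, 4]])
# labels_indices: shape (5, 2) int64; labels_values: shape (5,) int32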

@@ -0,0 +1,48 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""RandomCategorical op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
random_categorical_op_info = AiCPURegOp("RandomCategorical") \
.fusion_type("OPAQUE") \
.input(0, "logits", "required") \
.input(1, "num_sample", "required") \
.input(2, "seed", "required") \
.output(0, "output", "required") \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.I32_Default, DataType.I16_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I16_Default) \
.dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.I32_Default, DataType.I16_Default) \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.I32_Default, DataType.I64_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I64_Default) \
.dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.I32_Default, DataType.I64_Default) \
.dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.I64_Default, DataType.I16_Default) \
.dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.I64_Default, DataType.I16_Default) \
.dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I16_Default) \
.dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.I64_Default, DataType.I32_Default) \
.dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.I64_Default, DataType.I32_Default) \
.dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I32_Default) \
.dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
.dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
.get_op_info()
@op_info_register(random_categorical_op_info)
def _random_categorical_aicpu():
"""RandomCategorical AiCPU register"""
return
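The eighteen dtype_format rows above are the full Cartesian product of three choices: the logits dtype (float16/32/64), the shared num_sample/seed dtype (int32 or int64, always matching each other), and the output dtype (int16/32/64). A short sketch that reproduces the same enumeration, using plain strings in place of the DataType constants:

from itertools import product

logits_types = ["F16", "F32", "F64"]
index_types = ["I32", "I64"]   # num_sample and seed always share this dtype
output_types = ["I16", "I32", "I64"]

rows = [(lg, ix, ix, out)
        for ix, out, lg in product(index_types, output_types, logits_types)]
assert len(rows) == 18  # 2 index dtypes x 3 output dtypes x 3 logits dtypes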

@@ -0,0 +1,37 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""RNNTLoss op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
rnnt_loss_op_info = AiCPURegOp("RNNTLoss") \
.fusion_type("OPAQUE") \
.input(0, "acts", "required") \
.input(1, "labels", "required") \
.input(2, "input_lengths", "required") \
.input(3, "label_lengths", "required") \
.output(0, "costs", "required") \
.output(1, "grads", "required") \
.attr("blank_label", "int") \
.dtype_format(DataType.F32_NCHW, DataType.I32_NCHW, DataType.I32_NCHW, DataType.I32_NCHW, DataType.F32_NCHW,
DataType.F32_NCHW) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
@op_info_register(rnnt_loss_op_info)
def _rnnt_loss_aicpu():
"""RNNTLoss AiCPU register"""
return

@@ -51,7 +51,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AssignAdd, AssignSub, Atan2,
Reciprocal, CumSum,
Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e,
Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh)
- from .random_ops import (RandomChoiceWithMask)
+ from .random_ops import (RandomChoiceWithMask, RandomCategorical)
from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
BiasAdd, Conv2D,
DepthwiseConv2dNative,
@@ -66,6 +66,7 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
ResizeBilinear, Sigmoid,
SigmoidCrossEntropyWithLogits,
SmoothL1Loss, Softmax, Softplus,
+ RNNTLoss,
SoftmaxCrossEntropyWithLogits, ROIAlign,
SparseSoftmaxCrossEntropyWithLogits, Tanh,
TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl,
@@ -155,6 +156,7 @@ __all__ = [
'HSigmoid',
'Tanh',
'RandomChoiceWithMask',
+ 'RandomCategorical',
'ResizeBilinear',
'ScalarSummary',
'ImageSummary',
@@ -183,6 +185,7 @@ __all__ = [
'SmoothL1Loss',
'L2Loss',
'CTCLoss',
+ 'RNNTLoss',
'ReduceAll',
'ScalarToArray',
'ScalarToTensor',

@@ -1608,6 +1608,61 @@ class L2Loss(PrimitiveWithInfer):
return x_type
class RNNTLoss(PrimitiveWithInfer):
"""
Computes the RNNTLoss and its gradient with respect to the softmax outputs.
Args:
blank_label (int): blank label. Default: 0.
Inputs:
- **acts** (Tensor[float32]) - Tensor of shape :math:`(B, T, U, V)`.
- **labels** (Tensor[int32]) - Tensor of shape :math:`(B, N)`.
- **input_lengths** (Tensor[int32]) - Tensor of shape :math:`(B,)`.
- **label_lengths** (Tensor[int32]) - Tensor of shape :math:`(B,)`.
Outputs:
- **costs** (Tensor[float32]) - Tensor of shape :math:`(B,)`.
- **grads** (Tensor[float32]) - Has the same shape as `acts`.
Examples:
>>> B, T, U, V = 1, 2, 3, 5
>>> acts = np.random.random((B, T, U, V)).astype(np.float32)
>>> labels = np.array([[1, 2]]).astype(np.int32)
>>> input_length = np.array([T] * B).astype(np.int32)
>>> label_length = np.array([len(l) for l in labels]).astype(np.int32)
>>> rnnt_loss = P.RNNTLoss(blank_label=0)
>>> costs, grads = rnnt_loss(Tensor(acts), Tensor(labels), Tensor(input_length), Tensor(label_length))
"""
@prim_attr_register
def __init__(self, blank_label=0):
validator.check_value_type('blank_label', blank_label, [int], self.name)
self.init_prim_io_names(inputs=['acts', 'labels', 'input_length', 'label_length'],
outputs=['costs', 'grads'])
def infer_shape(self, acts_shape, labels_shape, input_length_shape, label_length_shape):
validator.check_integer('acts_rank', len(acts_shape), 4, Rel.EQ, self.name)
validator.check_integer('labels_rank', len(labels_shape), 2, Rel.EQ, self.name)
validator.check_integer('input_length_rank', len(input_length_shape), 1, Rel.EQ, self.name)
validator.check_integer('label_length_rank', len(label_length_shape), 1, Rel.EQ, self.name)
validator.check('labels shape[0]', labels_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name)
validator.check('input_length size', input_length_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name)
validator.check('label_length size', label_length_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name)
costs_shape = (acts_shape[0],)
return (costs_shape, acts_shape)
def infer_dtype(self, acts_type, labels_type, input_length_type, label_length_type):
validator.check_subclass("acts_type", acts_type, mstype.tensor, self.name)
validator.check_subclass("labels_type", labels_type, mstype.tensor, self.name)
validator.check_subclass("input_length_type", input_length_type, mstype.tensor, self.name)
validator.check_subclass("label_length_type", label_length_type, mstype.tensor, self.name)
validator.check_tensor_type_same({"acts_type": acts_type}, [mstype.float32], self.name)
validator.check_tensor_type_same({"labels_type": labels_type}, [mstype.int32], self.name)
validator.check_tensor_type_same({"input_length_type": input_length_type}, [mstype.int32], self.name)
validator.check_tensor_type_same({"label_length_type": label_length_type}, [mstype.int32], self.name)
return (acts_type, acts_type)
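For orientation, acts follows the usual RNN-T layout: B is the batch size, T the number of encoder time steps, U the target length plus one (the extra position accounts for the blank-initial state), and V the vocabulary size including the blank label. The infer_shape checks above only tie the batch dimensions together; a quick illustration of inputs that satisfy them (the U/V reading is standard RNN-T convention, not stated in the diff):

import numpy as np

B, T, U, V = 1, 2, 3, 5                       # batch, time, target len + 1, vocab
acts = np.zeros((B, T, U, V), np.float32)
labels = np.zeros((B, U - 1), np.int32)       # U - 1 real labels per sequence
input_length = np.full((B,), T, np.int32)
label_length = np.full((B,), U - 1, np.int32)
# infer_shape only requires the leading dims to agree; costs comes back as (B,)
assert acts.shape[0] == labels.shape[0] == input_length.shape[0] == label_length.shape[0]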
class SGD(PrimitiveWithInfer):
"""
Computes stochastic gradient descent (optionally with momentum).

@@ -64,3 +64,61 @@ class RandomChoiceWithMask(PrimitiveWithInfer):
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, [mstype.bool_], self.name)
return (mstype.int32, mstype.bool_)
class RandomCategorical(PrimitiveWithInfer):
"""
Generates random samples from a given categorical distribution tensor.
Args:
dtype (mindspore.dtype): The type of output. Its value should be one of [mindspore.int16,
mindspore.int32, mindspore.int64]. Default: mindspore.int64.
Inputs:
- **logits** (Tensor) - The input tensor. 2-D Tensor with shape [batch_size, num_classes].
- **num_sample** (int) - Number of samples to be drawn. Only constant values are allowed.
- **seed** (int) - Random seed. Default: 0.
Outputs:
- **output** (Tensor) - The output Tensor with shape [batch_size, num_samples].
Examples:
>>> class Net(nn.Cell):
>>> def __init__(self, num_sample):
>>> super(Net, self).__init__()
>>> self.random_categorical = P.RandomCategorical(mindspore.int64)
>>> self.num_sample = num_sample
>>> def construct(self, logits, seed=0):
>>> return self.random_categorical(logits, self.num_sample, seed)
>>>
>>> x = np.random.random((10, 5)).astype(np.float32)
>>> net = Net(8)
>>> output = net(Tensor(x))
"""
@prim_attr_register
def __init__(self, dtype=mstype.int64):
"""Init RandomCategorical"""
self.dtype = dtype
valid_values = (mstype.int32, mstype.int16, mstype.int64)
validator.check_type_name("dtype", dtype, valid_values, self.name)
self.init_prim_io_names(inputs=['logits', 'num_samples', 'seed'],
outputs=['output'])
def __infer__(self, logits, num_samples, seed):
logits_dtype = logits['dtype']
valid_types = (mstype.float32, mstype.float16, mstype.float64)
validator.check_tensor_type_same({'logits': logits_dtype}, valid_types, self.name)
num_samples_v = num_samples['value']
seed_v = seed['value']
validator.check_value_type('num_samples', num_samples_v, (int,), self.name)
validator.check_value_type('seed', seed_v, (int,), self.name)
validator.check_integer("num_samples", num_samples_v, 0, Rel.GT, self.name)
x_shape = list(logits['shape'])
if len(x_shape) != 2:
raise ValueError("RandomCategorical shape should be 2-dimension.")
ndim = len(x_shape) - 1
x_shape[ndim] = num_samples_v
return {'shape': x_shape,
'dtype': self.dtype,
'value': None}
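Semantically, RandomCategorical draws num_sample class indices per batch row from the distribution implied by that row of logits, in the spirit of tf.random.categorical. A hedged NumPy equivalent follows; the softmax-then-choice formulation is an assumption about the math, not a description of the AICPU kernel:

import numpy as np

def random_categorical_np(logits, num_sample, seed=0):
    # Assumed semantics: sample class ids per row from softmax(logits).
    rng = np.random.default_rng(seed)
    z = logits.astype(np.float64)
    z -= z.max(axis=1, keepdims=True)          # stabilize the softmax
    probs = np.exp(z) / np.exp(z).sum(axis=1, keepdims=True)
    return np.stack([rng.choice(z.shape[1], size=num_sample, p=row)
                     for row in probs]).astype(np.int64)

samples = random_categorical_np(np.random.random((10, 5)), 8)
assert samples.shape == (10, 8)                # [batch_size, num_samples]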

@@ -0,0 +1,38 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore
from mindspore import Tensor
from mindspore.ops import operations as P
import mindspore.nn as nn
from mindspore.common.api import ms_function
import numpy as np
import mindspore.context as context
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
class Net(nn.Cell):
def __init__(self, num_sample):
super(Net, self).__init__()
self.random_categorical = P.RandomCategorical(mindspore.int64)
self.num_sample = num_sample
def construct(self, logits, seed=0):
return self.random_categorical(logits, self.num_sample, seed)
def test_net():
x = np.random.random((10, 5)).astype(np.float32)
net = Net(8)
output = net(Tensor(x))
print(x)
print(output.asnumpy())
print(output.dtype)

@@ -0,0 +1,43 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore as ms
from mindspore import Tensor
from mindspore.ops import operations as P
import mindspore.nn as nn
from mindspore.common.api import ms_function
import numpy as np
import mindspore.context as context
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.rnnt_loss = P.RNNTLoss(blank_label=0)
def construct(self, acts, labels, act_lens, label_lens):
return self.rnnt_loss(acts, labels, act_lens, label_lens)
def test_net():
B, T, U, V = 1, 2, 3, 5
acts = np.random.random((B, T, U, V)).astype(np.float32)
labels = np.array([[np.random.randint(1, V-1) for _ in range(U-1)]]).astype(np.int32)
input_length = np.array([T] * B).astype(np.int32)
label_length = np.array([len(l) for l in labels]).astype(np.int32)
rnnt_loss = Net()
costs, grads = rnnt_loss(Tensor(acts), Tensor(labels), Tensor(input_length), Tensor(label_length))
print(Tensor(acts), Tensor(labels), Tensor(input_length), Tensor(label_length))
print(costs.asnumpy())
print(grads.asnumpy())

@@ -129,7 +129,7 @@ add_executable(ut_tests ${UT_SRCS} ${MINDSPORE_SRC_LIST} ${UT_SUTB_SRC_LIST})
if (ENABLE_GE)
if(ENABLE_TRAIN)
- target_link_libraries(ut_tests PRIVATE graph ge_client_train)
+ target_link_libraries(ut_tests PRIVATE graph ge_runner)
else()
target_link_libraries(ut_tests PRIVATE graph ge_client)
endif()

@@ -29,7 +29,6 @@ context.set_context(mode=context.GRAPH_MODE)
class LeNet5(nn.Cell):
""" LeNet5 definition """
def __init__(self):
super(LeNet5, self).__init__()
self.conv1 = nn.Conv2d(1, 6, 5, pad_mode='valid')
