diff --git a/.gitmodules b/.gitmodules index a241b6d69b..a024019b14 100644 --- a/.gitmodules +++ b/.gitmodules @@ -12,4 +12,4 @@ url = https://github.com/protocolbuffers/protobuf.git [submodule "graphengine"] path = graphengine - url = https://gitee.com/mindspore/graphengine.git + url = https://gitee.com/ms-incubator/graphengine.git diff --git a/CMakeLists.txt b/CMakeLists.txt index dc07ccae8b..6b69c510d5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -7,6 +7,9 @@ endif () include(${CMAKE_SOURCE_DIR}/cmake/options.cmake) set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/cmake/modules/") +if (ENABLE_GE) + add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0) +endif () if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin") set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O2 -Werror -Wno-return-std-move -Wno-unused-private-field -Wno-unused-lambda-capture -Wno-sign-compare -Wno-overloaded-virtual -Wno-unneeded-internal-declaration -Wno-unused-variable -Wno-pessimizing-move -Wno-inconsistent-missing-override -DHALF_ENABLE_CPP11_USER_LITERALS=0 -D_FORTIFY_SOURCE=2") diff --git a/graphengine b/graphengine index 579dcb75a9..c27e428e96 160000 --- a/graphengine +++ b/graphengine @@ -1 +1 @@ -Subproject commit 579dcb75a990b533f9182733a6424f2bd66f0f23 +Subproject commit c27e428e9698dd4f9b198008596676bc2d1b49aa diff --git a/mindspore/ccsrc/CMakeLists.txt b/mindspore/ccsrc/CMakeLists.txt index 1e1c650239..48a3f5d65e 100644 --- a/mindspore/ccsrc/CMakeLists.txt +++ b/mindspore/ccsrc/CMakeLists.txt @@ -127,7 +127,7 @@ endif() if (ENABLE_GE) if(ENABLE_TRAIN) - target_link_libraries(mindspore ge_client_train hccl) + target_link_libraries(mindspore ge_runner hccl) else () target_link_libraries(mindspore ge_client) endif () diff --git a/mindspore/ccsrc/transform/op_declare.cc b/mindspore/ccsrc/transform/op_declare.cc index 59ccb24168..735c9aac09 100644 --- a/mindspore/ccsrc/transform/op_declare.cc +++ b/mindspore/ccsrc/transform/op_declare.cc @@ -470,7 +470,7 @@ INPUT_MAP(ApplyAdam) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(m)}, {10, INPUT_DESC(grad)}}; ATTR_MAP(ApplyAdam) = {{"use_locking", ATTR_DESC(use_locking, AnyTraits())}, {"use_nesterov", ATTR_DESC(use_nesterov, AnyTraits())}}; -OUTPUT_MAP(ApplyAdam) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(m)}, {2, OUTPUT_DESC(v)}}; +OUTPUT_MAP(ApplyAdam) = {{0, OUTPUT_DESC(var)}}; // Relu6 INPUT_MAP(Relu6) = {{1, INPUT_DESC(x)}}; @@ -823,7 +823,7 @@ OUTPUT_MAP(RealDiv) = {{0, OUTPUT_DESC(y)}}; // Cast INPUT_MAP(Cast) = {{1, INPUT_DESC(x)}}; INPUT_ATTR_MAP(Cast) = {{2, ATTR_DESC(dst_type, AnyTraits())}}; -ATTR_MAP(Cast) = {{"Truncate", ATTR_DESC(truncate, AnyTraits())}}; +ATTR_MAP(Cast) = EMPTY_ATTR_MAP; OUTPUT_MAP(Cast) = {{0, OUTPUT_DESC(y)}}; // Reciprocal @@ -1153,7 +1153,7 @@ INPUT_MAP(SparseApplyAdagradD) = { {1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(grad)}, {4, INPUT_DESC(indices)}}; ATTR_MAP(SparseApplyAdagradD) = {{"lr", ATTR_DESC(lr, AnyTraits())}, {"use_locking", ATTR_DESC(use_locking, AnyTraits())}}; -OUTPUT_MAP(SparseApplyAdagradD) = {{0, OUTPUT_DESC(var)}}; +OUTPUT_MAP(SparseApplyAdagradD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(accum)}}; // SparseApplyFtrlD INPUT_MAP(SparseApplyFtrlD) = {{1, INPUT_DESC(var)}, diff --git a/mindspore/context.py b/mindspore/context.py index bf6439a7d5..89fb56b843 100644 --- a/mindspore/context.py +++ b/mindspore/context.py @@ -453,11 +453,13 @@ def reset_auto_parallel_context(): _reset_auto_parallel_context() -@args_type_check(mode=int, precompile_only=bool, device_target=str, device_id=int, save_graphs=bool, - save_graphs_path=str, save_ms_model=bool, save_ms_model_path=str, enable_dump=bool, - save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str, - enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool, - check_bprop=bool) +@args_type_check(mode=int, precompile_only=bool, device_target=str, + device_id=int, enable_ir_fusion=bool, save_graphs=bool, + enable_task_sink=bool, save_graphs_path=str, enable_loop_sink=bool, + enable_mem_reuse=bool, save_ms_model=bool, save_ms_model_path=str, enable_gpu_summary=bool, + enable_auto_mixed_precision=bool, enable_dump=bool, save_dump_path=str, + enable_reduce_precision=bool, enable_dynamic_memory=bool, graph_memory_max_size=str, + variable_memory_max_size=str, enable_profiling=bool, profiling_options=str) def set_context(**kwargs): """ Sets context for running environment. diff --git a/mindspore/nn/optim/optimizer.py b/mindspore/nn/optim/optimizer.py index 658ffb7b46..28c5d9e939 100755 --- a/mindspore/nn/optim/optimizer.py +++ b/mindspore/nn/optim/optimizer.py @@ -292,7 +292,6 @@ class Optimizer(Cell): current_dynamic_lr = self.gather(self.learning_rate[i], self.global_step, 0) lr += (current_dynamic_lr,) F.control_depend(lr, self.assignadd(self.global_step, 1)) - else: lr = self.learning_rate if self.dynamic_lr: diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py index 9f1ccdf5a9..fa34ac545f 100755 --- a/mindspore/ops/_grad/grad_nn_ops.py +++ b/mindspore/ops/_grad/grad_nn_ops.py @@ -516,6 +516,18 @@ def get_bprop_l2_loss(self): return bprop +@bprop_getters.register(P.RNNTLoss) +def get_bprop_rnnt_loss(self): + """Grad definition for `RNNTLoss` operation.""" + expand = P.ExpandDims() + + def bprop(acts, labels, act_lens, label_lens, out, dout): + grad_loss = out[1] + grad = grad_loss * expand(expand(expand(dout[0], -1), -1), -1) + return grad, zeros_like(labels), zeros_like(act_lens), zeros_like(label_lens) + return bprop + + @bprop_getters.register(P.PReLU) def get_bprop_prelu(self): """Grad definition for `PReLU` operation.""" diff --git a/mindspore/ops/_op_impl/aicpu/__init__.py b/mindspore/ops/_op_impl/aicpu/__init__.py index 37d008940d..bb490d050b 100644 --- a/mindspore/ops/_op_impl/aicpu/__init__.py +++ b/mindspore/ops/_op_impl/aicpu/__init__.py @@ -24,3 +24,6 @@ from .flatten import _flatten_aicpu from .squeeze import _squeeze_aicpu from .expand_dims import _expand_dims_aicpu from .random_choice_with_mask import _random_choice_with_mask_aicpu +from .ctcloss import _ctcloss_aicpu +from .rnnt_loss import _rnnt_loss_aicpu +from .random_categorical import _random_categorical_aicpu diff --git a/mindspore/ops/_op_impl/aicpu/ctcloss.py b/mindspore/ops/_op_impl/aicpu/ctcloss.py new file mode 100644 index 0000000000..c393cb04b6 --- /dev/null +++ b/mindspore/ops/_op_impl/aicpu/ctcloss.py @@ -0,0 +1,42 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""CTCLoss op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType +ctcloss_op_info = AiCPURegOp("CTCLoss") \ + .fusion_type("OPAQUE") \ + .input(0, "inputs", "required") \ + .input(1, "labels_indices", "required") \ + .input(2, "labels_values", "required") \ + .input(3, "sequence_length", "required") \ + .output(0, "loss", "required") \ + .output(1, "gradient", "required") \ + .attr("preprocess_collapse_repeated", "bool") \ + .attr("ctc_merge_repeated", "bool") \ + .attr("ignore_longer_outputs_than_inputs", "bool") \ + .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.I32_Default, DataType.I32_Default, + DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.I32_Default, DataType.I32_Default, + DataType.F64_Default, DataType.F64_Default) \ + .dtype_format(DataType.F32_NCHW, DataType.I64_NCHW, DataType.I32_NCHW, DataType.I32_NCHW, + DataType.F32_NCHW, DataType.F32_NCHW) \ + .dtype_format(DataType.F64_NCHW, DataType.I64_NCHW, DataType.I32_NCHW, DataType.I32_NCHW, + DataType.F64_NCHW, DataType.F64_NCHW) \ + .get_op_info() + +@op_info_register(ctcloss_op_info) +def _ctcloss_aicpu(): + """CTCLoss AiCPU register""" + return diff --git a/mindspore/ops/_op_impl/aicpu/random_categorical.py b/mindspore/ops/_op_impl/aicpu/random_categorical.py new file mode 100644 index 0000000000..a0c6f64c97 --- /dev/null +++ b/mindspore/ops/_op_impl/aicpu/random_categorical.py @@ -0,0 +1,48 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""RandomCategorical op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +random_categorical_op_info = AiCPURegOp("RandomCategorical") \ + .fusion_type("OPAQUE") \ + .input(0, "logits", "required") \ + .input(1, "num_sample", "required") \ + .input(2, "seed", "required") \ + .output(0, "output", "required") \ + .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.I32_Default, DataType.I16_Default) \ + .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I16_Default) \ + .dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.I32_Default, DataType.I16_Default) \ + .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.I32_Default, DataType.I64_Default) \ + .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I64_Default) \ + .dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.I32_Default, DataType.I64_Default) \ + .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.I64_Default, DataType.I16_Default) \ + .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.I64_Default, DataType.I16_Default) \ + .dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I16_Default) \ + .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.I64_Default, DataType.I32_Default) \ + .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.I64_Default, DataType.I32_Default) \ + .dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I32_Default) \ + .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \ + .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \ + .dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \ + .get_op_info() + +@op_info_register(random_categorical_op_info) +def _random_categorical_aicpu(): + """RandomCategorical AiCPU register""" + return diff --git a/mindspore/ops/_op_impl/aicpu/rnnt_loss.py b/mindspore/ops/_op_impl/aicpu/rnnt_loss.py new file mode 100644 index 0000000000..d35d102048 --- /dev/null +++ b/mindspore/ops/_op_impl/aicpu/rnnt_loss.py @@ -0,0 +1,37 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""RNNTLoss op""" +from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType + +rnnt_loss_op_info = AiCPURegOp("RNNTLoss") \ + .fusion_type("OPAQUE") \ + .input(0, "acts", "required") \ + .input(1, "labels", "required") \ + .input(2, "input_lengths", "required") \ + .input(3, "label_lengths", "required") \ + .output(0, "costs", "required") \ + .output(1, "grads", "required") \ + .attr("blank_label", "int") \ + .dtype_format(DataType.F32_NCHW, DataType.I32_NCHW, DataType.I32_NCHW, DataType.I32_NCHW, DataType.F32_NCHW, + DataType.F32_NCHW) \ + .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, + DataType.F32_Default, DataType.F32_Default) \ + .get_op_info() + +@op_info_register(rnnt_loss_op_info) +def _rnnt_loss_aicpu(): + """RNNTLoss AiCPU register""" + return diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index 71c11f492d..a5c2e9edbb 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -51,7 +51,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AssignAdd, AssignSub, Atan2 Reciprocal, CumSum, Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh) -from .random_ops import (RandomChoiceWithMask) +from .random_ops import (RandomChoiceWithMask, RandomCategorical) from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm, BiasAdd, Conv2D, DepthwiseConv2dNative, @@ -66,6 +66,7 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm, ResizeBilinear, Sigmoid, SigmoidCrossEntropyWithLogits, SmoothL1Loss, Softmax, Softplus, + RNNTLoss, SoftmaxCrossEntropyWithLogits, ROIAlign, SparseSoftmaxCrossEntropyWithLogits, Tanh, TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl, @@ -155,6 +156,7 @@ __all__ = [ 'HSigmoid', 'Tanh', 'RandomChoiceWithMask', + 'RandomCategorical', 'ResizeBilinear', 'ScalarSummary', 'ImageSummary', @@ -183,6 +185,7 @@ __all__ = [ 'SmoothL1Loss', 'L2Loss', 'CTCLoss', + 'RNNTLoss', 'ReduceAll', 'ScalarToArray', 'ScalarToTensor', diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index ed7237b04c..98a3ccd9a7 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -1608,6 +1608,61 @@ class L2Loss(PrimitiveWithInfer): return x_type +class RNNTLoss(PrimitiveWithInfer): + """ + Computes the RNNTLoss and its gradient with respect to the softmax outputs. + + Args: + blank_label (int): blank label. Default: 0. + + Inputs: + - **acts** (Tensor[float32]) - Tensor of shape :math:`(B, T, U, V)`. + - **labels** (Tensor[int32]) - Tensor of shape :math:`(B, N)`. + - **input_lengths** (Tensor[int32]) - Tensor of shape :math:`(B,)`. + - **label_lebgths** (Tensor[int32]) - Tensor of shape :math:`(B,)`. + + Outputs: + - **costs** (Tensor[int32]) - Tensor of shape :math:`(B,)`. + - **grads** (Tensor[int32]) - Has the same shape as `acts`. + + Examples: + >>> B, T, U, V = 1, 2, 3, 5 + >>> acts = np.random.random((B, T, U, V)).astype(np.float32) + >>> labels = np.array([[1, 2]]).astype(np.int32) + >>> input_length = np.array([T] * B).astype(np.int32) + >>> label_length = np.array([len(l) for l in labels]).astype(np.int32) + >>> rnnt_loss = P.RNNTLoss(blank_label=blank) + >>> costs, grads = rnnt_loss(Tensor(acts), Tensor(labels), Tensor(input_length), Tensor(label_length)) + """ + @prim_attr_register + def __init__(self, blank_label=0): + validator.check_value_type('blank_label', blank_label, [int], self.name) + self.init_prim_io_names(inputs=['acts', 'labels', 'input_length', 'label_length'], + outputs=['costs', 'grads']) + + def infer_shape(self, acts_shape, labels_shape, input_length_shape, label_length_shape): + validator.check_integer('acts_rank', len(acts_shape), 4, Rel.EQ, self.name) + validator.check_integer('labels_rank', len(labels_shape), 2, Rel.EQ, self.name) + validator.check_integer('input_length_rank', len(input_length_shape), 1, Rel.EQ, self.name) + validator.check_integer('label_length_rank', len(label_length_shape), 1, Rel.EQ, self.name) + validator.check('labels shape[0]', labels_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name) + validator.check('input_length size', input_length_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name) + validator.check('label_length size', label_length_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name) + costs_shape = (acts_shape[0],) + return (costs_shape, acts_shape) + + def infer_dtype(self, acts_type, labels_type, input_length_type, label_length_type): + validator.check_subclass("acts_type", acts_type, mstype.tensor, self.name) + validator.check_subclass("labels_type", labels_type, mstype.tensor, self.name) + validator.check_subclass("input_length_type", input_length_type, mstype.tensor, self.name) + validator.check_subclass("label_length_type", label_length_type, mstype.tensor, self.name) + validator.check_tensor_type_same({"acts_type": acts_type}, [mstype.float32], self.name) + validator.check_tensor_type_same({"labels_type": labels_type}, [mstype.int32], self.name) + validator.check_tensor_type_same({"input_length_type": input_length_type}, [mstype.int32], self.name) + validator.check_tensor_type_same({"label_length_type": label_length_type}, [mstype.int32], self.name) + return (acts_type, acts_type) + + class SGD(PrimitiveWithInfer): """ Computes stochastic gradient descent (optionally with momentum). diff --git a/mindspore/ops/operations/random_ops.py b/mindspore/ops/operations/random_ops.py index 2692b43b46..77201c25f9 100644 --- a/mindspore/ops/operations/random_ops.py +++ b/mindspore/ops/operations/random_ops.py @@ -64,3 +64,61 @@ class RandomChoiceWithMask(PrimitiveWithInfer): def infer_dtype(self, x_dtype): validator.check_tensor_type_same({'x': x_dtype}, [mstype.bool_], self.name) return (mstype.int32, mstype.bool_) + + +class RandomCategorical(PrimitiveWithInfer): + """ + Generates random samples from a given categorical distribution tensor. + + Args: + dtype (mindspore.dtype): The type of output. Its value should be one of [mindspore.int16, + mindspore.int32, mindspore.int64]. Default: mindspore.int64. + + Inputs: + - **logits** (Tensor) - The input tensor. 2-D Tensor with shape [batch_size, num_classes]. + - **num_sample** (int) - Number of sample to be drawn. Only constant values is allowed. + - **seed** (int) - Random seed. Default: 0. + + Outputs: + - **output** (Tensor) - The output Tensor with shape [batch_size, num_samples]. + + Examples: + >>> class Net(nn.Cell): + >>> def __init__(self, num_sample): + >>> super(Net, self).__init__() + >>> self.random_categorical = P.RandomCategorical(mindspore.int64) + >>> self.num_sample = num_sample + >>> def construct(self, logits, seed=0): + >>> return self.random_categorical(logits, self.num_sample, seed) + >>> + >>> x = np.random.random((10, 5)).astype(np.float32) + >>> net = Net(8) + >>> output = net(Tensor(x)) + """ + @prim_attr_register + def __init__(self, dtype=mstype.int64): + """Init RandomCategorical""" + self.dtype = dtype + + valid_values = (mstype.int32, mstype.int16, mstype.int64) + validator.check_type_name("dtype", dtype, valid_values, self.name) + self.init_prim_io_names(inputs=['logits', 'num_samples', 'seed'], + outputs=['output']) + + def __infer__(self, logits, num_samples, seed): + logits_dtype = logits['dtype'] + valid_types = (mstype.float32, mstype.float16, mstype.float64) + validator.check_tensor_type_same({'logits': logits_dtype}, valid_types, self.name) + num_samples_v = num_samples['value'] + seed_v = seed['value'] + validator.check_value_type('num_samples', num_samples_v, (int,), self.name) + validator.check_value_type('seed', seed_v, (int,), self.name) + validator.check_integer("num_samples", num_samples_v, 0, Rel.GT, self.name) + x_shape = list(logits['shape']) + if len(x_shape) != 2: + raise ValueError("RandomCategorical shape should be 2-dimension.") + ndim = len(x_shape) - 1 + x_shape[ndim] = num_samples_v + return {'shape': (x_shape), + 'dtype': (self.dtype), + 'value': None} diff --git a/tests/st/ops/ascend/test_aicpu_ops/test_random_categorical.py b/tests/st/ops/ascend/test_aicpu_ops/test_random_categorical.py new file mode 100644 index 0000000000..6304e8b111 --- /dev/null +++ b/tests/st/ops/ascend/test_aicpu_ops/test_random_categorical.py @@ -0,0 +1,38 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import mindspore +from mindspore import Tensor +from mindspore.ops import operations as P +import mindspore.nn as nn +from mindspore.common.api import ms_function +import numpy as np +import mindspore.context as context +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") +class Net(nn.Cell): + def __init__(self, num_sample): + super(Net, self).__init__() + self.random_categorical = P.RandomCategorical(mindspore.int64) + self.num_sample = num_sample + + def construct(self, logits, seed=0): + return self.random_categorical(logits, self.num_sample, seed) + +def test_net(): + x = np.random.random((10, 5)).astype(np.float32) + net = Net(8) + output = net(Tensor(x)) + print(x) + print(output.asnumpy()) + print(output.dtype()) diff --git a/tests/st/ops/ascend/test_aicpu_ops/test_rnnt_loss.py b/tests/st/ops/ascend/test_aicpu_ops/test_rnnt_loss.py new file mode 100644 index 0000000000..c7e2df07f8 --- /dev/null +++ b/tests/st/ops/ascend/test_aicpu_ops/test_rnnt_loss.py @@ -0,0 +1,43 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import mindspore as ms +from mindspore import Tensor +from mindspore.ops import operations as P +import mindspore.nn as nn +from mindspore.common.api import ms_function +import numpy as np +import mindspore.context as context +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.rnnt_loss = P.RNNTLoss(blank_label=0) + + def construct(self, acts, labels, act_lens, label_lens): + return self.rnnt_loss(acts, labels, act_lens, label_lens) + + +def test_net(): + B, T, U, V = 1, 2, 3, 5 + acts = np.random.random((B, T, U, V)).astype(np.float32) + labels = np.array([[np.random.randint(1, V-1) for _ in range(U-1)]]).astype(np.int32) + input_length = np.array([T] * B).astype(np.int32) + label_length = np.array([len(l) for l in labels]).astype(np.int32) + + rnnt_loss = Net() + costs, grads = rnnt_loss(Tensor(acts), Tensor(labels), Tensor(input_length), Tensor(label_length)) + print(Tensor(acts), Tensor(labels), Tensor(input_length), Tensor(label_length)) + print(costs.asnumpy()) + print(grads.asnumpy()) diff --git a/tests/ut/cpp/CMakeLists.txt b/tests/ut/cpp/CMakeLists.txt index 9213e41450..5e30b074a3 100644 --- a/tests/ut/cpp/CMakeLists.txt +++ b/tests/ut/cpp/CMakeLists.txt @@ -129,7 +129,7 @@ add_executable(ut_tests ${UT_SRCS} ${MINDSPORE_SRC_LIST} ${UT_SUTB_SRC_LIST}) if (ENABLE_GE) if(ENABLE_TRAIN) - target_link_libraries(ut_tests PRIVATE graph ge_client_train) + target_link_libraries(ut_tests PRIVATE graph ge_runner) else() target_link_libraries(ut_tests PRIVATE graph ge_client) endif() diff --git a/tests/ut/python/optimizer/test_optimize_with_parameter_groups.py b/tests/ut/python/optimizer/test_optimize_with_parameter_groups.py index 96c3c936b2..705c85be26 100644 --- a/tests/ut/python/optimizer/test_optimize_with_parameter_groups.py +++ b/tests/ut/python/optimizer/test_optimize_with_parameter_groups.py @@ -29,7 +29,6 @@ context.set_context(mode=context.GRAPH_MODE) class LeNet5(nn.Cell): """ LeNet5 definition """ - def __init__(self): super(LeNet5, self).__init__() self.conv1 = nn.Conv2d(1, 6, 5, pad_mode='valid')