Perfect the annotations of ops and context, and support bool equal

pull/6201/head
buxue 5 years ago
parent 952135362f
commit c3c06514d7

@@ -14,7 +14,7 @@
# ============================================================================
"""
The context of mindspore, used to configure the current execution environment,
-including execution mode, execution backend and other feature switches.
+includes the execution mode, execution backend and other feature switches.
"""
import os
import time
@@ -338,40 +338,40 @@ def set_auto_parallel_context(**kwargs):
Note:
Attribute name is required for setting attributes.
-If a program has tasks with different parallel modes, then before setting new parallel mode for
+If a program has tasks with different parallel modes, then before setting a new parallel mode for the
next task, interface mindspore.context.reset_auto_parallel_context() needs to be called to reset
the configuration.
-Setting or changing parallel modes must be called before any Initializer created, or RuntimeError
-may be raised when compile network.
+Setting or changing parallel modes must be done before creating any Initializer, otherwise
+RuntimeError may be raised when compiling the network.
Args:
device_num (int): Available device number, the value must be in [1, 4096]. Default: 1.
global_rank (int): Global rank id, the value must be in [0, 4095]. Default: 0.
gradients_mean (bool): Whether to perform mean operator after all-reduce of mirror.
"stand_alone" do not support gradients_mean. Default: False.
gradient_fp32_sync (bool): Gradients allreduce by fp32 even though gradients is fp16 if this flag is True..
"stand_alone" does not support `gradients_mean`. Default: False.
gradient_fp32_sync (bool): Gradients allreduce by fp32, even though gradients is fp16 if this flag is True..
"stand_alone", "data_parallel" and "hybrid_parallel" do not support
gradient_fp32_sync. Default: True.
parallel_mode (str): There are five kinds of parallel modes, "stand_alone", "data_parallel",
"hybrid_parallel", "semi_auto_parallel" and "auto_parallel". Default: "stand_alone".
-- stand_alone: Only one processor working.
+- stand_alone: Only one processor is working.
-- data_parallel: Distributing the data across different processors.
+- data_parallel: Distributes the data across different processors.
-- hybrid_parallel: Achieving data parallelism and model parallelism manually.
+- hybrid_parallel: Achieves data parallelism and model parallelism manually.
-- semi_auto_parallel: Achieving data parallelism and model parallelism by
+- semi_auto_parallel: Achieves data parallelism and model parallelism by
setting parallel strategies.
-- auto_parallel: Achieving parallelism automatically.
+- auto_parallel: Achieves parallelism automatically.
auto_parallel_search_mode (str): There are two kinds of search modes, "recursive_programming"
and "dynamic_programming". Default: "dynamic_programming".
- recursive_programming: Recursive programming search mode.
- dynamic_programming: Dynamic programming search mode.
-parameter_broadcast (bool): Indicating whether to broadcast parameters before training.
+parameter_broadcast (bool): Whether to broadcast parameters before training.
"stand_alone", "semi_auto_parallel" and "auto_parallel" do not support parameter
broadcast. Default: False.
strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
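
For context, a minimal usage sketch of the API documented above (not part of the commit; argument values are illustrative):
>>> from mindspore import context
>>> context.set_auto_parallel_context(device_num=8, global_rank=0)
>>> context.set_auto_parallel_context(parallel_mode="data_parallel", gradients_mean=True)
>>> # Reset before a task that uses a different parallel mode, as the Note requires.
>>> context.reset_auto_parallel_context()
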
@@ -468,7 +468,7 @@ def set_context(**kwargs):
When the `save_graphs` attribute is set to True, attribute of `save_graphs_path` is used to set the
intermediate compilation graph storage path. By default, the graphs are saved in the current directory.
-As for other configurations and arguments, please refer to the corresponding module
+For other configurations and arguments, please refer to the corresponding module
description, the configuration is optional and can be enabled when needed.
Note:
@@ -498,9 +498,9 @@
Args:
mode (int): Running in GRAPH_MODE(0) or PYNATIVE_MODE(1). Default: PYNATIVE_MODE(1).
-device_target (str): The target device to run, support "Ascend", "GPU", "CPU". Default: "Ascend".
-device_id (int): Id of target device, the value must be in [0, device_num_per_host-1],
-while device_num_per_host should no more than 4096. Default: 0.
+device_target (str): The target device to run, supports "Ascend", "GPU", and "CPU". Default: "Ascend".
+device_id (int): ID of the target device, the value must be in [0, device_num_per_host-1],
+while device_num_per_host should be no more than 4096. Default: 0.
save_graphs (bool): Whether to save graphs. Default: False.
save_graphs_path (str): Path to save graphs. Default: "."
enable_auto_mixed_precision (bool): Whether to enable auto mixed precision. Default: False.
@@ -509,33 +509,34 @@ def set_context(**kwargs):
reserve_class_name_in_scope (bool) : Whether to save the network class name in the scope. Default: True.
enable_reduce_precision (bool): Whether to enable precision reduction. Default: True.
enable_dump (bool): Whether to enable dump. Default: False.
-save_dump_path (str): When the program is executed on Ascend, operators can dump data here.
+save_dump_path (str): When the program is executed on Ascend, operators can dump data in this path.
The root dump path is configured in /home/HwHiAiUser/ide_daemon/ide_daemon.cfg.
So the real dump path is "{configured root dump path}/{`save_dump_path`}". Default: ".".
-variable_memory_max_size (str): Sets variable memory max size. Default: "0GB".
+variable_memory_max_size (str): Set the maximum size of the variable memory. Default: "0GB".
enable_profiling (bool): Whether to open profiling. Default: False.
-profiling_options (str): Sets profiling collection options, operators can profiling data here.
-Profiling collection options, the values are as follows, supporting the collection of multiple data.
+profiling_options (str): Set profiling collection options; operators can collect profiling data here.
+The values of the profiling collection options are as follows, and multiple kinds of data can be collected.
- training_trace: collect iterative trajectory data, that is, the training task and software information of
the AI software stack, to achieve performance analysis of the training task, focusing on data
enhancement, forward and backward calculation, gradient aggregation update and other related data.
- task_trace: collect task trajectory data, that is, the hardware information of the HWTS/AICore of
-the Ascend 910 processor, and analyze the information of start and end of the task.
+the Ascend 910 processor, and analyze the information of the beginning and end of the task.
- op_trace: collect single operator performance data.
-The profiling can choose training_trace, task_trace, training_trace and task_trace combination and
-separated by colons; single operator can choose op_trace, op_trace cannot be combined with
-training_trace and task_trace. Default: "training_trace".
+The profiling can choose `training_trace`, `task_trace`, or a combination of
+`training_trace` and `task_trace` separated by colons;
+a single operator can choose `op_trace`, and `op_trace` cannot be combined with
+`training_trace` and `task_trace`. Default: "training_trace".
check_bprop (bool): Whether to check bprop. Default: False.
-max_device_memory (str): Sets the maximum memory available for device, currently only supported on GPU.
-The format is "xxGB". Default: "1024GB".
-print_file_path (str): The path of print data to save. If this parameter is set, print data is saved to
-a file by default, and turn off printing to the screen. If the file already exists, add a timestamp
+max_device_memory (str): Sets the maximum memory available for devices.
+Currently, it is only supported on GPU. The format is "xxGB". Default: "1024GB".
+print_file_path (str): The path to save print data. If this parameter is set, print data is saved to
+a file by default, and printing to the screen is turned off. If the file already exists, add a timestamp
suffix to the file. Default: ''.
enable_sparse (bool): Whether to enable sparsity feature. Default: False.
-max_call_depth(int): Specify the function call depth limit. Default: 1000.
+max_call_depth (int): Specify the maximum depth of function calls. Default: 1000.
Raises:
ValueError: If input key is not an attribute in context.
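
For context, a minimal usage sketch of set_context (not part of the commit; values are illustrative):
>>> from mindspore import context
>>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=0)
>>> context.set_context(save_graphs=True, save_graphs_path="./graphs")
>>> context.set_context(enable_profiling=True, profiling_options="training_trace:task_trace")
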
@@ -614,13 +615,13 @@ class ParallelMode:
There are five kinds of parallel modes, "STAND_ALONE", "DATA_PARALLEL",
"HYBRID_PARALLEL", "SEMI_AUTO_PARALLEL" and "AUTO_PARALLEL". Default: "STAND_ALONE".
-- STAND_ALONE: Only one processor working.
-- DATA_PARALLEL: Distributing the data across different processors.
-- HYBRID_PARALLEL: Achieving data parallelism and model parallelism manually.
-- SEMI_AUTO_PARALLEL: Achieving data parallelism and model parallelism by setting parallel strategies.
-- AUTO_PARALLEL: Achieving parallelism automatically.
+- STAND_ALONE: Only one processor is working.
+- DATA_PARALLEL: Distributes the data across different processors.
+- HYBRID_PARALLEL: Achieves data parallelism and model parallelism manually.
+- SEMI_AUTO_PARALLEL: Achieves data parallelism and model parallelism by setting parallel strategies.
+- AUTO_PARALLEL: Achieves parallelism automatically.
-MODE_LIST: The list for all supported parallel modes.
+MODE_LIST: The list of all supported parallel modes.
"""
STAND_ALONE = "stand_alone"
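
A small usage sketch of these constants (not part of the commit; illustrative only):
>>> from mindspore import context
>>> from mindspore.context import ParallelMode
>>> context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
>>> ParallelMode.SEMI_AUTO_PARALLEL
'semi_auto_parallel'
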

@@ -26,6 +26,21 @@ using ".register" decorator
"""
+@equal.register("Bool", "Bool")
+def _equal_bool(x, y):
+    """
+    Determine if two bool objects are equal.
+
+    Args:
+        x (bool): The first input bool object.
+        y (bool): The second input bool object.
+
+    Returns:
+        bool, True if x == y, otherwise False.
+    """
+    return F.bool_eq(x, y)
@equal.register("Number", "Number")
def _equal_scalar(x, y):
"""

@@ -123,6 +123,7 @@ string_concat = Primitive('string_concat')
bool_not = Primitive("bool_not")
bool_or = Primitive("bool_or")
bool_and = Primitive("bool_and")
+bool_eq = Primitive("bool_eq")
logical_and = P.LogicalAnd()
logical_or = P.LogicalOr()
logical_not = P.LogicalNot()
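
What the new registration enables, as a minimal sketch (not part of the commit; untested here, and scalar handling may vary by version): in graph mode, `==` between two bool values now resolves through the ("Bool", "Bool") entry above to the `bool_eq` primitive.
>>> import mindspore.nn as nn
>>> from mindspore import context
>>> context.set_context(mode=context.GRAPH_MODE)
>>> class Net(nn.Cell):
...     def construct(self):
...         return (1 > 0) == True  # bool == bool, dispatched to bool_eq
>>> Net()()
True
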

File diff suppressed because it is too large.

@@ -117,7 +117,7 @@ class ImageSummary(PrimitiveWithInfer):
class TensorSummary(PrimitiveWithInfer):
"""
-Output tensor to protocol buffer through tensor summary operator.
+Output a tensor to a protocol buffer through a tensor summary operator.
Inputs:
- **name** (str) - The name of the input variable.

@@ -125,10 +125,10 @@ class TensorAdd(_MathBinaryOp):
the scalar could only be a constant.
Inputs:
-- **input_x** (Union[Tensor, Number, bool]) - The first input is a number or
-a bool or a tensor whose data type is number or bool.
-- **input_y** (Union[Tensor, Number, bool]) - The second input is a number or
-a bool when the first input is a tensor or a tensor whose data type is number or bool.
+- **input_x** (Union[Tensor, Number, bool]) - The first input is a number, or a bool,
+or a tensor whose data type is number or bool.
+- **input_y** (Union[Tensor, Number, bool]) - The second input is a number,
+or a bool when the first input is a tensor, or a tensor whose data type is number or bool.
Outputs:
Tensor, the shape is the same as the one after broadcasting,
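
A short sketch of the broadcasting behavior described above (not part of the commit; values illustrative):
>>> import numpy as np
>>> from mindspore import Tensor
>>> from mindspore.ops import operations as P
>>> add = P.TensorAdd()
>>> x = Tensor(np.ones((1, 3), np.float32))
>>> y = Tensor(np.ones((2, 1), np.float32))
>>> add(x, y).shape  # shapes (1, 3) and (2, 1) broadcast to (2, 3)
(2, 3)
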
@@ -1079,10 +1079,10 @@ class Sub(_MathBinaryOp):
the scalar could only be a constant.
Inputs:
-- **input_x** (Union[Tensor, Number, bool]) - The first input is a number or
-a bool or a tensor whose data type is number or bool.
-- **input_y** (Union[Tensor, Number, bool]) - The second input is a number or
-a bool when the first input is a tensor or a tensor whose data type is number or bool.
+- **input_x** (Union[Tensor, Number, bool]) - The first input is a number, or a bool,
+or a tensor whose data type is number or bool.
+- **input_y** (Union[Tensor, Number, bool]) - The second input is a number,
+or a bool when the first input is a tensor, or a tensor whose data type is number or bool.
Outputs:
Tensor, the shape is the same as the one after broadcasting,
@@ -1157,10 +1157,10 @@ class SquaredDifference(_MathBinaryOp):
the scalar could only be a constant.
Inputs:
-- **input_x** (Union[Tensor, Number, bool]) - The first input is a number or
-a bool or a tensor whose data type is float16, float32, int32 or bool.
-- **input_y** (Union[Tensor, Number, bool]) - The second input is a number or
-a bool when the first input is a tensor or a tensor whose data type is
+- **input_x** (Union[Tensor, Number, bool]) - The first input is a number, or a bool,
+or a tensor whose data type is float16, float32, int32 or bool.
+- **input_y** (Union[Tensor, Number, bool]) - The second input is a number,
+or a bool when the first input is a tensor, or a tensor whose data type is
float16, float32, int32 or bool.
Outputs:
@@ -1863,10 +1863,10 @@ class TruncateDiv(_MathBinaryOp):
the scalar could only be a constant.
Inputs:
-- **input_x** (Union[Tensor, Number, bool]) - The first input is a number or
-a bool or a tensor whose data type is number or bool.
-- **input_y** (Union[Tensor, Number, bool]) - The second input is a number or
-a bool when the first input is a tensor or a tensor whose data type is number or bool.
+- **input_x** (Union[Tensor, Number, bool]) - The first input is a number, or a bool,
+or a tensor whose data type is number or bool.
+- **input_y** (Union[Tensor, Number, bool]) - The second input is a number,
+or a bool when the first input is a tensor, or a tensor whose data type is number or bool.
Outputs:
Tensor, the shape is the same as the one after broadcasting,
@@ -1893,10 +1893,10 @@ class TruncateMod(_MathBinaryOp):
the scalar could only be a constant.
Inputs:
-- **input_x** (Union[Tensor, Number, bool]) - The first input is a number or
-a bool or a tensor whose data type is number or bool.
-- **input_y** (Union[Tensor, Number, bool]) - The second input is a number or
-a bool when the first input is a tensor or a tensor whose data type is number or bool.
+- **input_x** (Union[Tensor, Number, bool]) - The first input is a number, or a bool,
+or a tensor whose data type is number or bool.
+- **input_y** (Union[Tensor, Number, bool]) - The second input is a number,
+or a bool when the first input is a tensor, or a tensor whose data type is number or bool.
Outputs:
Tensor, the shape is the same as the one after broadcasting,
@@ -2048,10 +2048,10 @@ class Xdivy(_MathBinaryOp):
the scalar could only be a constant.
Inputs:
-- **input_x** (Union[Tensor, Number, bool]) - The first input is a number or
-a bool or a tensor whose data type is float16, float32 or bool.
-- **input_y** (Union[Tensor, Number, bool]) - The second input is a number or
-a bool when the first input is a tensor or a tensor whose data type is float16, float32 or bool.
+- **input_x** (Union[Tensor, Number, bool]) - The first input is a number, or a bool,
+or a tensor whose data type is float16, float32 or bool.
+- **input_y** (Union[Tensor, Number, bool]) - The second input is a number,
+or a bool when the first input is a tensor, or a tensor whose data type is float16, float32 or bool.
Outputs:
Tensor, the shape is the same as the one after broadcasting,
@@ -3069,7 +3069,7 @@ class Sign(PrimitiveWithInfer):
Note:
.. math::
sign(x) = \begin{cases} -1, &if\ x < 0 \cr
-0, &if\ x == 0 \cr
+0, &if\ x = 0 \cr
1, &if\ x > 0\end{cases}
Inputs:
@@ -3251,7 +3251,7 @@ class SquareSumAll(PrimitiveWithInfer):
Inputs:
- **input_x1** (Tensor) - The input tensor. The data type must be float16 or float32.
-- **input_x2** (Tensor) - The input tensor same type and shape as the `input_x1`.
+- **input_x2** (Tensor) - The input tensor has the same type and shape as `input_x1`.
Note:
SquareSumAll only supports float16 and float32 data type.

@@ -98,7 +98,7 @@ class Softmax(PrimitiveWithInfer):
Softmax operation.
Applies the Softmax operation to the input tensor on the specified axis.
-Suppose a slice in the given aixs :math:`x` then for each element :math:`x_i`
+Suppose a slice in the given axis is :math:`x`, then for each element :math:`x_i`,
the Softmax function is shown as follows:
.. math::
@@ -107,7 +107,7 @@ class Softmax(PrimitiveWithInfer):
where :math:`N` is the length of the tensor.
Args:
-axis (Union[int, tuple]): The axis to do the Softmax operation. Default: -1.
+axis (Union[int, tuple]): The axis to perform the Softmax operation. Default: -1.
Inputs:
- **logits** (Tensor) - The input of Softmax, with float16 or float32 data type.
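
A quick sketch of the operator documented above (not part of the commit; values illustrative):
>>> import numpy as np
>>> from mindspore import Tensor
>>> from mindspore.ops import operations as P
>>> softmax = P.Softmax(axis=-1)
>>> logits = Tensor(np.array([1.0, 2.0, 3.0], np.float32))
>>> softmax(logits)  # probabilities along the last axis, summing to 1
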
@@ -1549,17 +1549,17 @@ class TopK(PrimitiveWithInfer):
Finds values and indices of the `k` largest entries along the last dimension.
Args:
-sorted (bool): If true, the resulting elements will
+sorted (bool): If True, the obtained elements will
be sorted by the values in descending order. Default: False.
Inputs:
- **input_x** (Tensor) - Input to be computed, data type should be float16, float32 or int32.
-- **k** (int) - Number of top elements to be computed along the last dimension, constant input is needed.
+- **k** (int) - The number of top elements to be computed along the last dimension, constant input is needed.
Outputs:
-Tuple of 2 Tensors, the values and the indices.
+Tuple of 2 tensors, the values and the indices.
-- **values** (Tensor) - The `k` largest elements along each last dimensional slice.
+- **values** (Tensor) - The `k` largest elements in each slice of the last dimension.
- **indices** (Tensor) - The indices of values within the last dimension of input.
Examples:
@@ -1593,7 +1593,7 @@
class SoftmaxCrossEntropyWithLogits(PrimitiveWithInfer):
r"""
-Gets the softmax cross-entropy value between logits and labels which shoule be one-hot encoding.
+Gets the softmax cross-entropy value between logits and labels with one-hot encoding.
Note:
Sets input logits as `X`, input label as `Y`, output as `loss`. Then,
@ -1609,7 +1609,7 @@ class SoftmaxCrossEntropyWithLogits(PrimitiveWithInfer):
- **labels** (Tensor) - Ground truth labels, with shape :math:`(N, C)`, has the same data type with `logits`.
Outputs:
-Tuple of 2 Tensors, the loss shape is `(N,)`, and the dlogits with the same shape as `logits`.
+Tuple of 2 tensors: the `loss` shape is `(N,)`, and `dlogits` has the same shape as `logits`.
Examples:
>>> logits = Tensor([[2, 4, 1, 4, 5], [2, 1, 2, 4, 3]], mindspore.float32)
@@ -1653,7 +1653,7 @@ class SparseSoftmaxCrossEntropyWithLogits(PrimitiveWithInfer):
loss = \sum_{ij} loss_{ij}
Args:
-is_grad (bool): If it's true, this operation returns the computed gradient. Default: False.
+is_grad (bool): If true, this operation returns the computed gradient. Default: False.
Inputs:
- **logits** (Tensor) - Input logits, with shape :math:`(N, C)`. Data type should be float16 or float32.
@@ -4084,19 +4084,19 @@ class SparseApplyAdagrad(PrimitiveWithInfer):
Args:
lr (float): Learning rate.
update_slots (bool): If `True`, `accum` will be updated. Default: True.
-use_locking (bool): If true, the var and accumulation tensors will be protected from being updated.
+use_locking (bool): If true, the `var` and `accumulation` tensors will be protected from being updated.
Default: False.
Inputs:
- **var** (Parameter) - Variable to be updated. The data type must be float16 or float32.
-- **accum** (Parameter) - Accumulation to be updated. The shape and dtype should be the same as `var`.
-- **grad** (Tensor) - Gradient. The shape must be the same as `var`'s shape except first dimension.
-Has the same data type as `var`.
+- **accum** (Parameter) - Accumulation to be updated. The shape and data type should be the same as `var`.
+- **grad** (Tensor) - Gradient. The shape must be the same as `var`'s shape except the first dimension.
+Gradients have the same data type as `var`.
- **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`.
The shape of `indices` must be the same as `grad` in first dimension, the type must be int32.
Outputs:
-Tuple of 2 Tensors, the updated parameters.
+Tuple of 2 tensors, the updated parameters.
- **var** (Tensor) - The same shape and data type as `var`.
- **accum** (Tensor) - The same shape and data type as `accum`.
@@ -4170,20 +4170,20 @@ class SparseApplyAdagradV2(PrimitiveWithInfer):
Args:
lr (float): Learning rate.
epsilon (float): A small value added for numerical stability.
-use_locking (bool): If `True`, the var and accumulation tensors will be protected from being updated.
+use_locking (bool): If `True`, the `var` and `accum` tensors will be protected from being updated.
Default: False.
update_slots (bool): If `True`, the computation logic will be different to `False`. Default: True.
Inputs:
- **var** (Parameter) - Variable to be updated. The data type must be float16 or float32.
-- **accum** (Parameter) - Accumulation to be updated. The shape and dtype should be the same as `var`.
-- **grad** (Tensor) - Gradient. The shape must be the same as `var`'s shape except first dimension.
-Has the same data type as `var`.
+- **accum** (Parameter) - Accumulation to be updated. The shape and data type should be the same as `var`.
+- **grad** (Tensor) - Gradient. The shape must be the same as `var`'s shape except the first dimension.
+Gradients have the same data type as `var`.
- **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`.
The shape of `indices` must be the same as `grad` in first dimension, the type must be int32.
Outputs:
-Tuple of 2 Tensors, the updated parameters.
+Tuple of 2 tensors, the updated parameters.
- **var** (Tensor) - The same shape and data type as `var`.
- **accum** (Tensor) - The same shape and data type as `accum`.
@@ -4361,23 +4361,23 @@ class SparseApplyProximalAdagrad(PrimitiveWithCheck):
RuntimeError exception will be thrown when the data type conversion of Parameter is required.
Args:
-use_locking (bool): If true, the var and accumulation tensors will be protected from being updated.
+use_locking (bool): If true, the `var` and `accum` tensors will be protected from being updated.
Default: False.
Inputs:
- **var** (Parameter) - Variable tensor to be updated. The data type must be float16 or float32.
- **accum** (Parameter) - Variable tensor to be updated, has the same dtype as `var`.
-- **lr** (Union[Number, Tensor]) - The learning rate value. Tshould be a float number or
+- **lr** (Union[Number, Tensor]) - The learning rate value. It should be a float number or
a scalar tensor with float16 or float32 data type.
- **l1** (Union[Number, Tensor]) - l1 regularization strength. should be a float number or
a scalar tensor with float16 or float32 data type.
- **l2** (Union[Number, Tensor]) - l2 regularization strength. should be a float number or
a scalar tensor with float16 or float32 data type.
- **grad** (Tensor) - A tensor of the same type as `var`, for the gradient.
-- **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`.
+- **indices** (Tensor) - A vector of indices in the first dimension of `var` and `accum`.
Outputs:
-Tuple of 2 Tensors, the updated parameters.
+Tuple of 2 tensors, the updated parameters.
- **var** (Tensor) - The same shape and data type as `var`.
- **accum** (Tensor) - The same shape and data type as `accum`.
@@ -4982,16 +4982,16 @@ class SparseApplyFtrl(PrimitiveWithCheck):
Inputs:
- **var** (Parameter) - The variable to be updated. The data type must be float16 or float32.
-- **accum** (Parameter) - The accumulation to be updated, must be same type and shape as `var`.
-- **linear** (Parameter) - the linear coefficient to be updated, must be same type and shape as `var`.
+- **accum** (Parameter) - The accumulation to be updated, must be the same data type and shape as `var`.
+- **linear** (Parameter) - The linear coefficient to be updated, must be the same data type and shape as `var`.
- **grad** (Tensor) - A tensor of the same type as `var`, for the gradient.
-- **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`.
-The shape of `indices` must be the same as `grad` in first dimension. The type must be int32.
+- **indices** (Tensor) - A vector of indices in the first dimension of `var` and `accum`.
+The shape of `indices` must be the same as `grad` in the first dimension. The type must be int32.
Outputs:
-- **var** (Tensor) - Tensor, has the same shape and type as `var`.
-- **accum** (Tensor) - Tensor, has the same shape and type as `accum`.
-- **linear** (Tensor) - Tensor, has the same shape and type as `linear`.
+- **var** (Tensor) - Tensor, has the same shape and data type as `var`.
+- **accum** (Tensor) - Tensor, has the same shape and data type as `accum`.
+- **linear** (Tensor) - Tensor, has the same shape and data type as `linear`.
Examples:
>>> import mindspore
@@ -5074,18 +5074,18 @@ class SparseApplyFtrlV2(PrimitiveWithInfer):
Inputs:
- **var** (Parameter) - The variable to be updated. The data type must be float16 or float32.
-- **accum** (Parameter) - The accumulation to be updated, must be same type and shape as `var`.
-- **linear** (Parameter) - the linear coefficient to be updated, must be same type and shape as `var`.
+- **accum** (Parameter) - The accumulation to be updated, must be the same data type and shape as `var`.
+- **linear** (Parameter) - The linear coefficient to be updated, must be the same data type and shape as `var`.
- **grad** (Tensor) - A tensor of the same type as `var`, for the gradient.
-- **indices** (Tensor) - A vector of indices into the first dimension of `var` and `accum`.
-The shape of `indices` must be the same as `grad` in first dimension. The type must be int32.
+- **indices** (Tensor) - A vector of indices in the first dimension of `var` and `accum`.
+The shape of `indices` must be the same as `grad` in the first dimension. The type must be int32.
Outputs:
Tuple of 3 tensors, the updated parameters.
-- **var** (Tensor) - Tensor, has the same shape and type as `var`.
-- **accum** (Tensor) - Tensor, has the same shape and type as `accum`.
-- **linear** (Tensor) - Tensor, has the same shape and type as `linear`.
+- **var** (Tensor) - Tensor, has the same shape and data type as `var`.
+- **accum** (Tensor) - Tensor, has the same shape and data type as `accum`.
+- **linear** (Tensor) - Tensor, has the same shape and data type as `linear`.
Examples:
>>> import mindspore

@@ -34,7 +34,7 @@ class StandardNormal(PrimitiveWithInfer):
- **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed.
Outputs:
-Tensor. The shape that the input 'shape' denotes. The dtype is float32.
+Tensor. The shape is the same as the input `shape`. The dtype is float32.
Examples:
>>> shape = (4, 16)
@@ -239,13 +239,13 @@ class UniformInt(PrimitiveWithInfer):
Inputs:
- **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed.
-- **minval** (Tensor) - The a distribution parameter.
-It defines the minimum possibly generated value. With int32 data type. Only one number is supported.
-- **maxval** (Tensor) - The b distribution parameter.
-It defines the maximum possibly generated value. With int32 data type. Only one number is supported.
+- **minval** (Tensor) - The distribution parameter, a.
+It defines the minimum possibly generated value, with int32 data type. Only one number is supported.
+- **maxval** (Tensor) - The distribution parameter, b.
+It defines the maximum possibly generated value, with int32 data type. Only one number is supported.
Outputs:
-Tensor. The shape that the input 'shape' denotes. The dtype is int32.
+Tensor. The shape is the same as the input 'shape', and the data type is int32.
Examples:
>>> shape = (4, 16)
@@ -284,7 +284,7 @@
class UniformReal(PrimitiveWithInfer):
r"""
-Produces random floating-point values i, uniformly distributed on the interval [0, 1).
+Produces random floating-point values, uniformly distributed on the interval [0, 1).
Args:
seed (int): Random seed. Must be non-negative. Default: 0.

@@ -29,10 +29,10 @@ class SparseToDense(PrimitiveWithInfer):
Inputs:
- **indices** (Tensor) - The indices of sparse representation.
- **values** (Tensor) - Values corresponding to each row of indices.
-- **dense_shape** (tuple) - A int tuple which specifies the shape of dense tensor.
+- **dense_shape** (tuple) - An int tuple which specifies the shape of the dense tensor.
Returns:
-Tensor, the shape of tensor is dense_shape.
+Tensor, the shape of the tensor is `dense_shape`.
Examples:
>>> indices = Tensor([[0, 1], [1, 2]])

@@ -62,7 +62,13 @@ else
exit ${RET}
fi
-pytest -n 4 --dist=loadfile -v $CURRPATH/parallel $CURRPATH/train $CURRPATH/ops
+pytest -n 4 --dist=loadfile -v $CURRPATH/parallel $CURRPATH/train
RET=$?
if [ ${RET} -ne 0 ]; then
exit ${RET}
fi
+pytest -n 2 --dist=loadfile -v $CURRPATH/ops
+RET=$?
+if [ ${RET} -ne 0 ]; then
+exit ${RET}
