# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Operators for gradients."""
from ..._c_expression import signature_rw as sig_rw
from ..._c_expression import signature_kind as sig_kind
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
from ..._checkparam import ParamValidator as validator
from ..._checkparam import Rel, check_int_positive, check_bool
from .._utils import _get_concat_offset
from ...common import dtype as mstype
class AbsGrad(PrimitiveWithInfer):
"""Computes gradients for abs operation."""
@prim_attr_register
def __init__(self):
"""init AbsGrad"""
def infer_shape(self, y, dy):
return y
def infer_dtype(self, y, dy):
return y
class ACosGrad(PrimitiveWithInfer):
"""
Computes the gradient for the ACos operation element-wise.
Returns:
Tensor, has the same type as input.
"""
@prim_attr_register
def __init__(self):
"""init ACosGrad"""
def infer_shape(self, x, dout):
validator.check_param_equal("x", x, "dout", dout)
return x
def infer_dtype(self, x, dout):
args = {"x": x, "dout": dout}
validator.check_type_same(args, mstype.number_type)
return x
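# Illustrative note (not in the original source): ACosGrad applies the chain rule for
# acos, whose derivative is d/dx acos(x) = -1 / sqrt(1 - x^2). A minimal NumPy sketch
# of the backward computation, using a hypothetical helper that is not part of the
# framework:
def _acos_grad_reference(x, dout):
    """Reference backward pass for acos: dout * (-1 / sqrt(1 - x**2)) (illustration only)."""
    import numpy as np
    return dout * (-1.0 / np.sqrt(1.0 - x * x))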
class BatchNormGrad(PrimitiveWithInfer):
"""Performs grad of BatchNorm operation."""
@prim_attr_register
def __init__(self, is_training=False, epsilon=1e-5):
self.is_training = validator.check_type('is_training', is_training, (bool,))
self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT)
self.add_prim_attr('data_format', "NCHW")
def infer_shape(self, y_backprop_shape, x_shape, scale_shape, reserve_1_shape, reserve_2_shape, reserve_3_shape):
validator.check("BatchNorm y_backprop_shape", y_backprop_shape, "BatchNorm x_shape", x_shape)
return (x_shape, scale_shape, scale_shape, reserve_1_shape, reserve_2_shape)
def infer_dtype(self, y_backprop_type, x_type, scale_type, reserve_1_type, reserve_2_type, reserve_3_type):
return (x_type, scale_type, scale_type, reserve_1_type, reserve_2_type)
class BiasAddGrad(Primitive):
"""Computes gradients of BiasAdd."""
@prim_attr_register
def __init__(self):
self.init_prim_io_names(inputs=['dout'], outputs=['output'])
self.add_prim_attr('data_format', 'NCHW')
def __call__(self, d_output):
raise NotImplementedError
class BinaryCrossEntropyGrad(PrimitiveWithInfer):
"""Computes gradients for `BinaryCrossEntropy` operation."""
@prim_attr_register
def __init__(self, reduction='mean'):
self.reduction = validator.check_string('reduction', reduction, ['none', 'mean', 'sum'])
def infer_shape(self, x_shape, y_shape, doutput_shape, weight_shape):
validator.check_param_equal('x_shape', x_shape, 'y_shape', y_shape)
if weight_shape:
validator.check_param_equal('y_shape', y_shape, 'weight_shape', weight_shape)
return x_shape
def infer_dtype(self, x_type, y_type, doutput_type, weight_type):
args = {'x_type': x_type, 'y_type': y_type, 'doutput_type': doutput_type}
validator.check_type_same(args, (mstype.float16, mstype.float32))
if weight_type:
validator.check_two_types_same('x_type', x_type, 'weight_type', weight_type)
return x_type
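# Illustrative note (not in the original source): for binary cross entropy
# L = -w * (y * log(x) + (1 - y) * log(1 - x)), the gradient with respect to the
# prediction x is w * (x - y) / (x * (1 - x)); with reduction='mean' it is additionally
# scaled by 1 / N. A hedged reference sketch with a hypothetical helper name:
def _binary_cross_entropy_grad_reference(x, y, dout, weight=None, reduction='mean'):
    """Reference gradient of BCE w.r.t. x (illustration only, assumes NumPy arrays)."""
    w = 1.0 if weight is None else weight
    grad = w * (x - y) / (x * (1.0 - x))
    if reduction == 'mean':
        grad = grad / x.size
    return grad * dout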
class ConcatOffset(PrimitiveWithInfer):
"""Primitive for computing Concat's gradient."""
@prim_attr_register
def __init__(self, N=2, axis=0):
"""init ConcatOffset"""
def __infer__(self, input_x):
axis = self.axis
x_shp = input_x['shape']
x_type = input_x['dtype']
offset, _, axis = _get_concat_offset(x_shp, x_type, axis)
self.add_prim_attr('T', x_type[0].element_type())
offset_values = []
for i in range(len(x_shp)):
values = []
for j in range(len(x_shp[0])):
value = 0
if j == axis:
value = offset[i]
values.append(value)
offset_values.append(tuple(values))
out = {'shape': None,
'dtype': None,
'value': tuple(offset_values)}
return out
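# Illustrative note (not in the original source): ConcatOffset returns, for each input,
# the offset of that input along the concat axis in the concatenated output; every other
# axis gets offset 0. A hypothetical reference helper and a worked example:
def _concat_offset_reference(shapes, axis):
    """Compute per-input offsets along the concat axis (illustration only)."""
    offsets, acc = [], 0
    for shp in shapes:
        offsets.append(tuple(acc if j == axis else 0 for j in range(len(shp))))
        acc += shp[axis]
    return tuple(offsets)
# e.g. _concat_offset_reference([(2, 3), (2, 5), (2, 4)], axis=1) -> ((0, 0), (0, 3), (0, 8))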
class Conv2DBackpropFilter(PrimitiveWithInfer):
"""
Computes the gradients of convolution with respect to the filter.
Args:
out_channel (int): The dimensionality of the output space.
kernel_size (Union[int, tuple[int]]): The size of the convolution window.
pad_mode (str): The padding mode, which can be "valid", "same" or "pad". Default: "valid".
pad (int): The pad value to fill. Default: 0.
mode (int): 0 for math convolution, 1 for cross-correlation convolution,
2 for deconvolution, 3 for depthwise convolution. Default: 1.
stride (tuple): The stride to apply conv filter. Default: (1, 1).
dilation (tuple): Specifies the dilation rate to use for dilated convolution. Default: (1, 1, 1, 1).
group (int): Splits input into groups. Default: 1.
Returns:
Tensor, the gradients of convolution.
"""
@prim_attr_register
def __init__(self,
out_channel,
kernel_size,
pad_mode="valid",
pad=0,
pad_list=(0, 0, 0, 0),
mode=1,
stride=(1, 1),
dilation=(1, 1, 1, 1),
group=1):
"""init Convolution"""
self.init_prim_io_names(inputs=['out_backprop', 'input', 'filter_sizes'], outputs=['output'])
self.out_channel = out_channel
self.kernel_size = kernel_size
self.mode = mode
pad_mode = pad_mode.upper()
self.add_prim_attr('pad_mode', pad_mode)
self.pad = pad
if isinstance(stride, tuple) and len(stride) == 4:
self.stride = (stride[2], stride[3])
self.add_prim_attr('stride', self.stride)
self.dilation = dilation
self.group = group
self.add_prim_attr('data_format', "NCHW")
def __infer__(self, doutput, x, w_size):
w_size_v = w_size['value']
validator.check_type('w_size', w_size_v, [tuple])
for i, dim_len in enumerate(w_size_v):
validator.check_type("w_size[%d]" % i, dim_len, [int])
validator.check_typename('x_dtype', x['dtype'], [mstype.int8, mstype.int32, mstype.float16, mstype.float32])
validator.check_two_types_same('doutput_dtype', doutput['dtype'], 'x_dtype', x['dtype'])
out = {
'value': None,
'shape': w_size_v,
'dtype': doutput['dtype'],
}
return out
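# Illustrative usage sketch (not in the original source); the shapes and argument values
# below are hypothetical. Per the __infer__ above, the operator takes the output gradient,
# the forward input and the desired filter shape, and returns a tensor with that filter
# shape and the dtype of the output gradient.
def _conv2d_backprop_filter_example():
    """Build a filter-gradient op for a hypothetical 7x7, 3->64 channel convolution."""
    grad_filter_op = Conv2DBackpropFilter(out_channel=64, kernel_size=7,
                                          pad_mode="same", stride=(1, 1))
    # Would be invoked roughly as: grad_filter_op(dout, x, (64, 3, 7, 7))
    return grad_filter_op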
class DepthwiseConv2dNativeBackpropFilter(PrimitiveWithInfer):
"""
Returns the gradient of filter for DepthwiseConv2dNative.
Applies depthwise conv2d to the input, generating channel_multiplier output channels for each input channel.
Refer to class DepthwiseConv2dNative for more details.
Args:
channel_multiplier (int): The multiplier for the original output convolution.
kernel_size (int or tuple): The size of the conv kernel.
mode (int): 0 for math convolution, 1 for cross-correlation convolution,
2 for deconvolution, 3 for depthwise convolution. Default: 3.
pad_mode (str): The mode to fill padding which can be: "valid", "same" or "pad". Default: "valid".
pad (int): The pad value to fill. Default: 0.
pads (tuple): The pad list like (top, bottom, left, right). Default: (0, 0, 0, 0).
stride (int): The stride to apply conv filter. Default: 1.
dilation (int): Specifies the space to use between kernel elements. Default: 1.
group (int): Splits input into groups. Default: 1.
Returns:
Tensor, the value is the gradient of filter for DepthwiseConv2dNative.
"""
@prim_attr_register
def __init__(self,
channel_multiplier,
kernel_size,
pad_mode="valid",
pad=0,
pads=(0, 0, 0, 0),
mode=3,
stride=1,
dilation=1,
group=1):
"""init Convolution"""
self.init_prim_io_names(inputs=['input', 'filter_size', 'dout'], outputs=['output'])
self.channel_multiplier = channel_multiplier
self.kernel_size = kernel_size
self.mode = mode
self.pad_mode = pad_mode
self.pad = pad
self.pads = pads
self.stride = stride
self.dilation = dilation
self.group = group
self.add_prim_attr('data_format', "NCHW")
def __call__(self, x, w_size, dout):
raise NotImplementedError
def __infer__(self, x, w_size, dout):
w_size_v = w_size['value']
args = {'x_dtype': x['dtype'], 'dout_type': dout['dtype']}
validator.check_type_same(args, mstype.number_type)
out = {
'value': None,
'shape': w_size_v,
'dtype': dout['dtype'],
}
return out
class DepthwiseConv2dNativeBackpropInput(PrimitiveWithInfer):
"""
Returns the gradient of input for DepthwiseConv2dNative.
Applies depthwise conv2d to the input, generating channel_multiplier output channels for each input channel.
Args:
channel_multiplier (int): The multiplier for the original output convolution.
kernel_size (int or tuple): The size of the conv kernel.
mode (int): 0 for math convolution, 1 for cross-correlation convolution,
2 for deconvolution, 3 for depthwise convolution. Default: 3.
pad_mode (str): The mode to fill padding, which can be "valid", "same" or "pad". Default: "valid".
pad (int): The pad value to fill. Default: 0.
pads (tuple): The pad list like (top, bottom, left, right). Default: (0, 0, 0, 0).
stride (int): The stride to apply conv filter. Default: 1.
dilation (int): Specifies the space to use between kernel elements. Default: 1.
group (int): Splits input into groups. Default: 1.
Returns:
Tensor, the value is the gradient of input for DepthwiseConv2dNative.
"""
@prim_attr_register
def __init__(self,
channel_multiplier,
kernel_size,
pad_mode="valid",
pad=0,
pads=(0, 0, 0, 0),
mode=3,
stride=1,
dilation=1,
group=1):
"""init Convolution"""
self.init_prim_io_names(inputs=['input_size', 'filter', 'dout'], outputs=['output'])
self.channel_multiplier = channel_multiplier
self.kernel_size = kernel_size
self.mode = mode
self.pad_mode = pad_mode
self.pad = pad
self.pads = pads
self.stride = stride
self.dilation = dilation
self.group = group
self.add_prim_attr('data_format', "NCHW")
def __call__(self, x_size, w, dout):
raise NotImplementedError
def __infer__(self, x_size, w, dout):
args = {'w_dtype': w['dtype'], 'dout_type': dout['dtype']}
validator.check_type_same(args, mstype.number_type)
x_size_v = x_size['value']
out = {
'value': None,
'shape': x_size_v,
'dtype': dout['dtype'],
}
return out
class FlattenGrad(PrimitiveWithInfer):
"""Performs gradients of Flatten."""
@prim_attr_register
def __init__(self):
self.init_prim_io_names(inputs=['x', 'shape'], outputs=['output'])
def __infer__(self, *args):
out = {
'value': None,
'shape': args[1]['value'],
'dtype': args[0]['dtype'],
}
return out
class FusedBatchNormGrad(Primitive):
"""Gradients of FusedBatchNorm operation."""
@prim_attr_register
def __init__(self, epsilon=0.0, momentum=0.1):
self.init_prim_io_names(inputs=['dy', 'x', 'scale', 'save_mean', 'save_inv_variance'],
outputs=['dx', 'bn_scale', 'bn_bias'])
def __call__(self, dy, x, scale, save_mean, save_inv_variance):
raise NotImplementedError
class GeluGrad(PrimitiveWithInfer):
"""Gradients of Gelu operation."""
@prim_attr_register
def __init__(self):
"""init GeluGrad"""
def infer_shape(self, y_backprop_shape, x_shape, y_shape):
return x_shape
def infer_dtype(self, y_backprop_dtype, x_dtype, y_dtype):
validator.check_typename("y_backprop_dtype", y_backprop_dtype, (mstype.float16, mstype.float32))
validator.check_typename("x_dtype", x_dtype, (mstype.float16, mstype.float32))
validator.check_typename("y_dtype", y_dtype, (mstype.float16, mstype.float32))
return x_dtype
class _PoolGrad(PrimitiveWithInfer):
"""Gradients of the max/avg pool operation."""
@prim_attr_register
def __init__(self, ksize, strides, padding="VALID"):
self.init_prim_io_names(inputs=['x_origin', 'out_origin', 'grad'], outputs=['output'])
validator.check_type('ksize', ksize, [int, tuple])
validator.check_type('strides', strides, [int, tuple])
self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME'])
self.add_prim_attr("padding", self.padding)
self.is_maxpoolgradwithargmax = (self.name == "MaxPoolGradWithArgmax")
if not self.is_maxpoolgradwithargmax:
self.add_prim_attr('data_format', "NCHW")
if isinstance(ksize, int):
validator.check_integer("ksize", ksize, 1, Rel.GE)
if self.is_maxpoolgradwithargmax:
self.ksize = (1, ksize, ksize, 1)
else:
self.ksize = (1, 1, ksize, ksize)
else:
ksize_error = ValueError(f"The 'ksize' passed to operator {self.name} should be a positive int number "
f"or a tuple of two or four positive int numbers, but got {ksize}")
if len(ksize) != 2 and len(ksize) != 4:
raise ksize_error
for ksize_val in ksize:
if not isinstance(ksize_val, int) or (ksize_val <= 0):
raise ksize_error
if len(ksize) == 2 and self.is_maxpoolgradwithargmax:
self.ksize = (1, ksize[0], ksize[1], 1)
elif len(ksize) == 2 and not self.is_maxpoolgradwithargmax:
self.ksize = (1, 1, ksize[0], ksize[1])
else:
self.ksize = ksize
self.add_prim_attr("ksize", self.ksize)
if isinstance(strides, int):
validator.check_integer("strides", strides, 1, Rel.GE)
if self.is_maxpoolgradwithargmax:
self.strides = (1, strides, strides, 1)
else:
self.strides = (1, 1, strides, strides)
else:
strides_error = ValueError(f"The 'strides' passed to operator {self.name} should be a positive int number "
f"or a tuple of two or four positive int numbers, but got {strides}")
if len(strides) != 2 and len(strides) != 4:
raise strides_error
for strides_val in strides:
if not isinstance(strides_val, int) or (strides_val <= 0):
raise strides_error
if len(strides) == 2 and self.is_maxpoolgradwithargmax:
self.strides = (1, strides[0], strides[1], 1)
elif len(strides) == 2 and not self.is_maxpoolgradwithargmax:
self.strides = (1, 1, strides[0], strides[1])
else:
self.strides = strides
self.add_prim_attr("strides", self.strides)
class AvgPoolGrad(_PoolGrad):
"""Gradients of the avg pool operation."""
@prim_attr_register
def __init__(self, ksize=1, strides=1, padding="VALID"):
super(AvgPoolGrad, self).__init__(ksize, strides, padding)
def __infer__(self, origin_input, dout):
out = {
'value': None,
'shape': tuple(origin_input['value']),
'dtype': dout['dtype'],
}
return out
class AvgPoolGradGpu(_PoolGrad):
"""Gradients of the avg pool operation for gpu."""
@prim_attr_register
def __init__(self, ksize=1, strides=1, padding="VALID"):
super(AvgPoolGradGpu, self).__init__(ksize, strides, padding)
def infer_shape(self, x1_shape, x2_shape, grad_shape):
return x1_shape
def infer_dtype(self, x1_dtype, x2_dtype, grad_dtype):
return x1_dtype
class MaxPoolGrad(_PoolGrad):
"""Performs gradients of the max pool operation."""
@prim_attr_register
def __init__(self, ksize=1, strides=1, padding="VALID"):
super(MaxPoolGrad, self).__init__(ksize, strides, padding)
def infer_shape(self, x1_shape, x2_shape, grad_shape):
return x1_shape
def infer_dtype(self, x1_dtype, x2_dtype, grad_dtype):
return x1_dtype
class MaximumGrad(Primitive):
"""Grad for maximum."""
@prim_attr_register
def __init__(self, grad_x=True, grad_y=True):
"""Init MaximumGrad"""
def __call__(self, x, y, dout):
raise NotImplementedError
class MaxPoolGradWithArgmax(_PoolGrad):
"""Computes the gradients of MaxPoolWithArgmax."""
@prim_attr_register
def __init__(self, ksize=1, strides=1, padding="VALID",):
self.init_prim_io_names(inputs=['x', 'grad', 'argmax'], outputs=['output'])
super(MaxPoolGradWithArgmax, self).__init__(ksize, strides, padding)
def infer_shape(self, x_shape, grad_shape, argmax_shape):
if not grad_shape:
raise TypeError("The dout of MaxPoolGradWithArgmax should be a Tensor.")
return x_shape
def infer_dtype(self, x_dtype, grad_dtype, argmax_dtype):
return grad_dtype
class MinimumGrad(Primitive):
"""Grad for minimum."""
@prim_attr_register
def __init__(self, grad_x=True, grad_y=True):
"""Init MinimumGrad"""
def __call__(self, x, y, dout):
raise NotImplementedError
class L2NormalizeGrad(PrimitiveWithInfer):
r"""
Gradients of L2 normalize.
Args:
axis (int): The begin axis for the input to apply L2 normalize. Default: 0.
epsilon (float): A small value added for numerical stability. Default: 1e-4.
Inputs:
- **input_x** (Tensor) - Should be the input `weight` of forward operator L2Normalize.
- **out** (Tensor) - Should be the output of forward operator L2Normalize.
- **dout** (Tensor) - The backprop of the next layer.
Outputs:
Tensor, gradients of L2Normalize `input_x`.
"""
@prim_attr_register
def __init__(self, axis=0, epsilon=1e-4):
validator.check_type('axis', axis, [int])
validator.check_type('epsilon', epsilon, [int, float])
def infer_shape(self, input_x, out, dout):
validator.check_param_equal('input_x', input_x, 'out', out)
validator.check_param_equal('input_x', input_x, 'dout', dout)
return input_x
def infer_dtype(self, input_x, out, dout):
args = {'input_x': input_x, 'out': out, 'dout': dout}
validator.check_type_same(args, mstype.number_type)
return input_x
class LayerNormGrad(Primitive):
"""
Computes the gradients of the LayerNorm operation.
This operator calculates the gradients of layernorm with respect to the input, gamma and beta.
Args:
begin_norm_axis (int): The begin axis for the input to apply layernorm. Default: 1.
begin_params_axis (int): The begin axis for the parameter input to apply layernorm. Default: 1.
Returns:
tuple, a tuple of 3 tensors (the gradients of the layernorm input, gamma and beta).
"""
@prim_attr_register
def __init__(self, begin_norm_axis=1, begin_params_axis=1):
"""init"""
self.begin_norm_axis = validator.check_type('begin_norm_axis', begin_norm_axis, [int])
self.begin_params_axis = validator.check_type('begin_params_axis', begin_params_axis, [int])
def __call__(self, x, dy, variance, mean, gamma):
raise NotImplementedError
class LogSoftmaxGrad(PrimitiveWithInfer):
"""Computes gradient for the Log Softmax activation."""
@prim_attr_register
def __init__(self, axis=-1):
"""init LogSoftmaxGrad"""
validator.check_type("axis", axis, [int])
def infer_shape(self, dout, logits):
rank = len(logits)
validator.check_int_range('axis', self.axis, -rank - 1, rank, Rel.INC_BOTH)
return logits
def infer_dtype(self, dout, logits):
validator.check_subclass("logits", logits, mstype.tensor)
return logits
class LSTMGradData(PrimitiveWithInfer):
"""Computes the data gradients of LSTM."""
@prim_attr_register
def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
self.input_size = check_int_positive(input_size)
self.hidden_size = check_int_positive(hidden_size)
self.num_layers = check_int_positive(num_layers)
self.has_bias = check_bool(has_bias)
self.bidirectional = check_bool(bidirectional)
self.dropout = validator.check_type("dropout", dropout, [float])
self.dropout = validator.check_number_range('dropout', dropout, 0, 1, Rel.INC_BOTH)
if bidirectional:
self.num_directions = 2
else:
self.num_directions = 1
def infer_shape(self, y_shape, dy_shape, dhy_shape, dcy_shape, w_shape,
hx_shape, cx_shape, reserve_shape, state_shape):
# dhy and dcy should have the same shape
validator.check_integer("h_shape", len(dhy_shape), 3, Rel.EQ)
validator.check_integer("h_shape", len(dhy_shape), len(dcy_shape), Rel.EQ)
validator.check_integer("h_shape[0]", dhy_shape[0], dcy_shape[0], Rel.EQ)
validator.check_integer("h_shape[1]", dhy_shape[1], dcy_shape[1], Rel.EQ)
validator.check_integer("h_shape[2]", dhy_shape[2], dcy_shape[2], Rel.EQ)
validator.check_integer("h_shape[0]", dhy_shape[0], self.num_layers * self.num_directions, Rel.EQ)
validator.check_integer("h_shape[2]", dhy_shape[2], self.hidden_size, Rel.EQ)
# dy: (seq_len, batch_size, hidden_size * num_directions)
validator.check_integer("dy_shape", len(dy_shape), 3, Rel.EQ)
validator.check_integer("dy[1]", dy_shape[1], dhy_shape[1], Rel.EQ)
validator.check_integer("dy[2]", dy_shape[2], self.hidden_size * self.num_directions, Rel.EQ)
# (seq_len, batch_size, input_size)
dx_shape = (y_shape[0], y_shape[1], self.input_size)
dhx_shape = dhy_shape
dcx_shape = dcy_shape
return (dx_shape, dhx_shape, dcx_shape)
def infer_dtype(self, y_dtype, dy_dtype, dhy_dtype, dcy_dtype, w_dtype,
hx_dtype, cx_dtype, reserve_dtype, state_dtype):
validator.check_typename("dy_dtype", dy_dtype, (mstype.float32, mstype.float16))
validator.check_typename("dhy_dtype", dhy_dtype, (mstype.float32, mstype.float16))
validator.check_typename("dcy_dtype", dcy_dtype, (mstype.float32, mstype.float16))
validator.check_typename("datatype", dy_dtype, (dhy_dtype.element_type(),))
validator.check_typename("datatype", dy_dtype, (dcy_dtype.element_type(),))
return (dy_dtype, dy_dtype, dy_dtype)
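# Illustrative note (not in the original source): with hypothetical sizes seq_len=10,
# batch_size=32, input_size=100, hidden_size=128, num_layers=2 and bidirectional=False,
# the expected inputs are dy of shape (10, 32, 128) and dhy/dcy of shape (2, 32, 128)
# (num_layers * num_directions on the first axis), and the outputs are dx of shape
# (10, 32, 100) with dhx/dcx matching dhy/dcy. A small reference helper:
def _lstm_grad_data_shapes_reference(seq_len, batch_size, input_size, hidden_size,
                                     num_layers, bidirectional):
    """Return the (dx, dhx, dcx) shapes inferred above (illustration only)."""
    num_directions = 2 if bidirectional else 1
    dhx_shape = (num_layers * num_directions, batch_size, hidden_size)
    return (seq_len, batch_size, input_size), dhx_shape, dhx_shape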
class LSTMGradWeight(PrimitiveWithInfer):
"""Computes the weight gradients of LSTM."""
@prim_attr_register
def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout):
self.input_size = check_int_positive(input_size)
self.hidden_size = check_int_positive(hidden_size)
self.num_layers = check_int_positive(num_layers)
self.has_bias = check_bool(has_bias)
self.bidirectional = check_bool(bidirectional)
self.dropout = validator.check_type("dropout", dropout, [float])
self.dropout = validator.check_number_range('dropout', dropout, 0, 1, Rel.INC_BOTH)
if bidirectional:
self.num_directions = 2
else:
self.num_directions = 1
def infer_shape(self, x_shape, hx_shape, y_shape, reserve_shape, state_shape):
weight_size = 0
gate_size = 4 * self.hidden_size
for layer in range(self.num_layers):
for _ in range(self.num_directions):
input_layer_size = self.input_size if layer == 0 else self.hidden_size * self.num_directions
weight_size += gate_size * input_layer_size
weight_size += gate_size * self.hidden_size
if self.has_bias:
weight_size += 2 * gate_size
return (weight_size, 1, 1)
def infer_dtype(self, x_dtype, hx_dtype, y_dtype, reserve_dtype, state_dtype):
return hx_dtype
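# Illustrative worked example (not in the original source): per layer and direction,
# infer_shape above adds gate_size * input_layer_size + gate_size * hidden_size
# (plus 2 * gate_size when has_bias), where gate_size = 4 * hidden_size. For
# hypothetical sizes input_size=100, hidden_size=128, num_layers=1, unidirectional and
# has_bias=True: gate_size = 512 and the total is 512*100 + 512*128 + 2*512 = 117760,
# so the inferred output shape is (117760, 1, 1).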
class PReLUGrad(PrimitiveWithInfer):
r"""
Gradients of PReLU operation.
Inputs:
- **y_backprop** (Tensor) - Representing the backprop of the next layer.
- **input_x** (Tensor) - Should be the input `input_x` of forward operator PRelu.
- **weight** (Tensor) - Float Tensor, w > 0, should be the input `weight` of forward operator PRelu.
Outputs:
Tensor, with the same type as `input_x`.
"""
@prim_attr_register
def __init__(self):
pass
def infer_shape(self, y_backprop_shape, A_shape, w_shape):
return y_backprop_shape, w_shape
def infer_dtype(self, y_backprop_dtype, A_dtype, w_dtype):
validator.check_typename("y_backprop_dtype", y_backprop_dtype, (mstype.float16, mstype.float32))
validator.check_typename("A_dtype", A_dtype, (mstype.float16, mstype.float32))
validator.check_typename("w_dtype", w_dtype, (mstype.float16, mstype.float32))
return y_backprop_dtype, w_dtype
class ReluGrad(Primitive):
"""Performs grad of Relu operation."""
@prim_attr_register
def __init__(self):
"""init ReluGrad"""
self.init_prim_io_names(inputs=['y_backprop', 'x'], outputs=['output'])
def __call__(self, y_backprop, x):
raise NotImplementedError
class ReLU6Grad(PrimitiveWithInfer):
"""Performs grad of ReLU6 operation."""
@prim_attr_register
def __init__(self):
self.init_prim_io_names(inputs=['y_grad', 'x'], outputs=['output'])
def __call__(self, y_grad, x):
raise NotImplementedError
def infer_shape(self, y_grad_shape, x_shape):
return x_shape
def infer_dtype(self, y_grad_dtype, x_dtype):
validator.check_typename("y_grad_dtype", y_grad_dtype, (mstype.float16, mstype.float32))
validator.check_typename("x_dtype", x_dtype, (mstype.float16, mstype.float32))
return x_dtype
class ReluGradV2(PrimitiveWithInfer):
"""Performs grad of ReLUV2 operation."""
@prim_attr_register
def __init__(self):
self.init_prim_io_names(inputs=['gradients', 'mask'], outputs=['output'])
def __call__(self, gradients, mask):
raise NotImplementedError
def infer_shape(self, gradients_shape, mask_shape):
return gradients_shape
def infer_dtype(self, gradients_dtype, mask_dtype):
args_type = {'gradients': gradients_dtype, 'mask': mask_dtype}
validator.check_args_tensor(args_type)
validator.check_typename("gradients_dtype", gradients_dtype, mstype.number_type)
validator.check_typename("mask_dtype", mask_dtype, (mstype.uint8,))
return gradients_dtype
class EluGrad(PrimitiveWithInfer):
"""Performs grad of Elu operation."""
@prim_attr_register
def __init__(self):
"""Init EluGrad"""
def infer_shape(self, y_grad_shape, x_shape):
return x_shape
def infer_dtype(self, y_grad_dtype, x_dtype):
args_type = {'y_grad': y_grad_dtype, 'x': x_dtype}
validator.check_args_tensor(args_type)
args_dtype = {'y_grad_dtype': y_grad_dtype, 'x_dtype': x_dtype}
validator.check_type_same(args_dtype, mstype.float_type)
return x_dtype
class ResizeBilinearGrad(PrimitiveWithInfer):
"""Performs grad of ResizeBilinear operation."""
@prim_attr_register
def __init__(self, align_corners=False):
"""init"""
def infer_shape(self, dout_shape, orig_shape):
return orig_shape
def infer_dtype(self, dout_dtype, orig_type):
return dout_dtype
class ResizeNearestNeighborGrad(PrimitiveWithInfer):
"""
Compute gradient of `ResizeNearestNeighbor` operator.
Note:
The shape of input parameter `size` must be (height, width).
Args:
align_corners (bool): Whether the centers of the 4 corner pixels of the input
and output tensors are aligned. Default: False.
"""
@prim_attr_register
def __init__(self, align_corners=False):
"""Init ResizeNearestNeighborGrad"""
self.init_prim_io_names(inputs=['grads', 'size'], outputs=['y'])
def __infer__(self, grads, size):
shp = (grads['shape'][0],) + (grads['shape'][1],) + size['value']
return {'shape': shp,
'dtype': grads['dtype'],
'value': None}
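# Illustrative note (not in the original source): the inferred output shape keeps the
# batch and channel dimensions of `grads` and appends the requested spatial size. For
# hypothetical grads of shape (32, 3, 112, 112) and size (224, 224), the output shape
# is (32, 3, 224, 224).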
class ROIAlignGrad(PrimitiveWithInfer):
"""
ROIAlignGrad operator.
Args:
pooled_height (int): The output feature height.
pooled_width (int): The output feature width.
spatial_scale (float): The feature stride.
sample_num (int): Number of sampling points. Default: 2.
"""
@prim_attr_register
def __init__(self, xdiff_shape, pooled_height, pooled_width, spatial_scale, sample_num=2):
"""init ROIAlignGrad"""
validator.check_type("pooled_height", pooled_height, [int])
validator.check_type("pooled_width", pooled_width, [int])
validator.check_type("spatial_scale", spatial_scale, [float])
validator.check_type("sample_num", sample_num, [int])
validator.check_type("xdiff_shape", xdiff_shape, [tuple])
self.xdiff_shape = xdiff_shape
self.pooled_height = pooled_height
self.pooled_width = pooled_width
self.spatial_scale = spatial_scale
self.sample_num = sample_num
def infer_shape(self, ydiff_shape, rois_shape):
return self.xdiff_shape
def infer_dtype(self, ydiff_type, rois_type):
return ydiff_type
class SigmoidGrad(PrimitiveWithInfer):
"""Gets the gradient of Sigmoid operation."""
@prim_attr_register
def __init__(self):
pass
def infer_shape(self, out, dout):
return out
def infer_dtype(self, out, dout):
validator.check_typename("dout dtype", dout, (mstype.float16, mstype.float32))
validator.check_typename("out dtype", out, (mstype.float16, mstype.float32))
args = {"out type": out, "dout type": dout}
validator.check_type_same(args, mstype.number_type)
return out
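# Illustrative note (not in the original source): since sigmoid'(x) = s * (1 - s) with
# s = sigmoid(x), the backward pass is roughly dout * out * (1 - out). A hypothetical
# reference helper:
def _sigmoid_grad_reference(out, dout):
    """Reference sigmoid backward: dout * out * (1 - out) (illustration only)."""
    return dout * out * (1.0 - out)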
class HSigmoidGrad(PrimitiveWithInfer):
"""Gets the gradient of HSigmoid operation."""
@prim_attr_register
def __init__(self):
self.init_prim_io_names(inputs=['y_grad', 'x'], outputs=['output'])
def infer_shape(self, y_grad_shape, x_shape):
return x_shape
def infer_dtype(self, y_grad_dtype, x_dtype):
validator.check_typename("y_grad dtype", y_grad_dtype, (mstype.float16, mstype.float32))
validator.check_typename("x dtype", x_dtype, (mstype.float16, mstype.float32))
return x_dtype
class HSwishGrad(PrimitiveWithInfer):
"""Gets the gradient of HSwish operation."""
@prim_attr_register
def __init__(self):
self.init_prim_io_names(inputs=['y_grad', 'x'], outputs=['output'])
def infer_shape(self, y_grad_shape, x_shape):
return x_shape
def infer_dtype(self, y_grad_dtype, x_dtype):
validator.check_typename("y_grad dtype", y_grad_dtype, (mstype.float16, mstype.float32))
validator.check_typename("x_ dtype", x_dtype, (mstype.float16, mstype.float32))
return x_dtype
class SigmoidCrossEntropyWithLogitsGrad(PrimitiveWithInfer):
"""Computes the gradients of `SigmoidCrossEntropyWithLogits`."""
@prim_attr_register
def __init__(self):
"""Init SigmoidCrossEntropyWithLogitsGrad"""
self.init_prim_io_names(inputs=['x', 'y', 'dout'], outputs=['x_grad'])
def infer_shape(self, x_shape, y_shape, dout_shape):
validator.check_param_equal("x_shape", x_shape, "y_shape", y_shape)
validator.check_param_equal("x_shape", x_shape, "dout_shape", dout_shape)
return x_shape
def infer_dtype(self, x_dtype, y_dtype, dout_dtype):
args = {"x_dtype": x_dtype, "y_dtype": y_dtype, 'dout_dtype': dout_dtype}
validator.check_type_same(args, mstype.number_type)
return dout_dtype
class SliceGrad(PrimitiveWithInfer):
"""Reverse of slice."""
@prim_attr_register
def __init__(self):
"""init SliceGrad"""
self.init_prim_io_names(inputs=['dy', 'x', 'begin', 'size'], outputs=['dx'])
def __infer__(self, dy, x, begin, size):
dy_shape, x_shape, size_value = dy['shape'], x['shape'], size['value']
dy_shape_len = len(dy_shape)
for i in range(dy_shape_len):
validator.check(f'dy_shape[{i}]', dy_shape[i], f'x_shape[{i}]', x_shape[i], Rel.LE)
validator.check(f'dy_shape[{i}]', dy_shape[i], f'size_shape[{i}]', size_value[i], Rel.EQ)
return {'shape': x_shape,
'dtype': x['dtype'],
'value': None}
class SmoothL1LossGrad(PrimitiveWithInfer):
"""Computes gradient for prediction on SmoothL1Loss."""
@prim_attr_register
def __init__(self, sigma=1.0):
pass
def infer_shape(self, prediction, target, dloss):
validator.check_param_equal('prediction', prediction, 'target', target)
validator.check_param_equal('prediction', prediction, 'dloss', dloss)
return prediction
def infer_dtype(self, prediction, target, dloss):
args = {"prediction": prediction, "target": target, 'dloss': dloss}
validator.check_type_same(args, mstype.number_type)
return dloss
class StridedSliceGrad(PrimitiveWithInfer):
"""
Performs grad of StridedSlice operation.
Args:
begin_mask (int): Start indexing the slice. Default: 0.
end_mask (int): End indexing the slice. Default: 0.
ellipsis_mask (int): An int32 mask. Default: 0.
new_axis_mask (int): An int32 mask. Default: 0.
shrink_axis_mask (int): An int32 mask. Default: 0.
Returns:
Tensor, has the same shape as the input.
"""
@prim_attr_register
def __init__(self,
begin_mask=0,
end_mask=0,
ellipsis_mask=0,
new_axis_mask=0,
shrink_axis_mask=0):
"""init StrideSliceGrad"""
validator.check_type('begin_mask', begin_mask, [int])
validator.check_type('end_mask', end_mask, [int])
validator.check_type('ellipsis_mask', ellipsis_mask, [int])
validator.check_type('new_axis_mask', new_axis_mask, [int])
validator.check_type('shrink_axis_mask', shrink_axis_mask, [int])
self.init_prim_io_names(inputs=['dy', 'shapex', 'begin', 'end', 'strides'], outputs=['output'])
def __infer__(self, dy, shapex, begin, end, strides):
return {'shape': shapex['value'],
'dtype': dy['dtype'],
'value': None}
class TanhGrad(PrimitiveWithInfer):
"""Computes gradient of hyperbolic tangent of input element-wise."""
@prim_attr_register
def __init__(self):
pass
def infer_shape(self, out, dout):
return out
def infer_dtype(self, out, dout):
validator.check_subclass("out", out, mstype.tensor)
validator.check_subclass("dout", dout, mstype.tensor)
args = {"out type": out, "dout type": dout}
validator.check_type_same(args, mstype.number_type)
return out
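# Illustrative note (not in the original source): since tanh'(x) = 1 - tanh(x)^2, the
# backward pass is roughly dout * (1 - out * out). A hypothetical reference helper:
def _tanh_grad_reference(out, dout):
    """Reference tanh backward: dout * (1 - out**2) (illustration only)."""
    return dout * (1.0 - out * out)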
class MirrorPadGrad(PrimitiveWithInfer):
"""Gradients of MirrorPad operation."""
@prim_attr_register
def __init__(self, mode="REFLECT"):
"""init MirrorPad"""
validator.check_string('mode', mode, ['REFLECT', 'SYMMETRIC'])
self.mode = mode
def __infer__(self, dout, paddings, x):
validator.check_subclass("dout", dout['dtype'], mstype.tensor)
validator.check_subclass("paddings", paddings['dtype'], mstype.tensor)
validator.check_subclass("input_x", x['dtype'], mstype.tensor)
return {'shape': x['shape'],
'dtype': dout['dtype'],
'value': None}
class RefToEmbed(Primitive):
r"""
Make a key from Ref.
The key is a symbolic_key, an embedding of a Parameter, which is used as the key of the variable in env_type;
items can be retrieved through the operation `env_get_item` with the symbolic_key instance. The `Parameter` is a ref.
Inputs:
- **input** (Ref) - Target ref, ref is short for reference. The value of a Parameter is a ref.
Outputs:
symbolic_key, made from the Ref.
Examples:
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self.weight = mindspore.Parameter(1.0, name='weight')
>>>
>>> def construct(self):
>>> key = RefToEmbed()(self.weight)
>>> return key, self.weight
"""
__mindspore_signature__ = (
('variable', sig_rw.RW_REF, sig_kind.KIND_POSITIONAL_KEYWORD),
)
@prim_attr_register
def __init__(self):
pass