|
|
|
@ -192,10 +192,6 @@ class RMSELoss(_Loss):
|
|
|
|
|
Outputs:
|
|
|
|
|
Tensor, weighted loss float tensor.
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.
|
|
|
|
|
ValueError: If the dimensions are different.
|
|
|
|
|
|
|
|
|
|
Supported Platforms:
|
|
|
|
|
``Ascend`` ``GPU``
|
|
|
|
|
|
|
|
|
@ -212,7 +208,6 @@ class RMSELoss(_Loss):
|
|
|
|
|
self.MSELoss = MSELoss()
|
|
|
|
|
|
|
|
|
|
def construct(self, logits, label):
|
|
|
|
|
_check_shape(logits.shape, label.shape)
|
|
|
|
|
rmse_loss = F.sqrt(self.MSELoss(logits, label))
|
|
|
|
|
|
|
|
|
|
return rmse_loss
|
|
|
|
@ -482,16 +477,17 @@ class MultiClassDiceLoss(_Loss):
|
|
|
|
|
obtained through the binary loss of each category, and then the average value.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
weights (Union[Tensor, None]): Tensor of shape `[num_classes, dim]`.
|
|
|
|
|
weights (Union[Tensor, None]): Tensor of shape `[num_classes, dim]`. The weight shape[0] should be equal to
|
|
|
|
|
y shape[1].
|
|
|
|
|
ignore_indiex (Union[int, None]): Class index to ignore.
|
|
|
|
|
activation (Union[str, Cell]): Activate function applied to the output of the fully connected layer, eg. 'ReLU'.
|
|
|
|
|
Default: 'softmax'. Choose from: ['softmax', 'logsoftmax', 'relu', 'relu6', 'tanh','Sigmoid']
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
|
- **y_pred** (Tensor) - Tensor of shape (N, C, ...). y_pred dimension should be greater than 1.
|
|
|
|
|
The data type must be float16 or float32.
|
|
|
|
|
The data type must be float16 or float32.
|
|
|
|
|
- **y** (Tensor) - Tensor of shape (N, C, ...). y dimension should be greater than 1.
|
|
|
|
|
The data type must be float16 or float32.
|
|
|
|
|
The data type must be float16 or float32.
|
|
|
|
|
|
|
|
|
|
Outputs:
|
|
|
|
|
Tensor, a tensor of shape with the per-example sampled MultiClass Dice Losses.
|
|
|
|
@ -533,6 +529,7 @@ class MultiClassDiceLoss(_Loss):
|
|
|
|
|
self.reshape = P.Reshape()
|
|
|
|
|
|
|
|
|
|
def construct(self, logits, label):
|
|
|
|
|
_check_shape(logits.shape, label.shape)
|
|
|
|
|
_check_ndim_multi(logits.ndim, label.ndim)
|
|
|
|
|
total_loss = 0
|
|
|
|
|
|
|
|
|
|