From c2333ae195dac5968ffdd9665ee9917ffb1d3de8 Mon Sep 17 00:00:00 2001
From: baihuawei
Date: Thu, 30 Jul 2020 20:12:40 +0800
Subject: [PATCH] add gpu python

---
 mindspore/ops/_grad/grad_nn_ops.py    | 11 +++++
 mindspore/ops/operations/__init__.py  |  3 +-
 mindspore/ops/operations/_grad_ops.py | 22 +++++++++
 mindspore/ops/operations/nn_ops.py    | 70 +++++++++++++++++++++++++++
 4 files changed, 105 insertions(+), 1 deletion(-)

diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py
index 61c7e40960..96209755bc 100755
--- a/mindspore/ops/_grad/grad_nn_ops.py
+++ b/mindspore/ops/_grad/grad_nn_ops.py
@@ -732,6 +732,17 @@ def get_bprop_binary_cross_entropy(self):
     return bprop
 
 
+@bprop_getters.register(P.KLDivLoss)
+def get_bprop_kl_div_loss(self):
+    """Grad definition for `KLDivLoss` operation."""
+    grad = G.KLDivLossGrad(self.reduction)
+
+    def bprop(x, y, out, dout):
+        dx, dy = grad(x, y, dout)
+        return dx, dy
+
+    return bprop
+
 
 @bprop_getters.register(P.Dropout)
 def get_bprop_dropout(self):
diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py
index 71f84cdeba..d10b41aeea 100644
--- a/mindspore/ops/operations/__init__.py
+++ b/mindspore/ops/operations/__init__.py
@@ -73,7 +73,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Appl
                      SmoothL1Loss, Softmax, Softsign, Softplus, LRN, RNNTLoss,
                      SoftmaxCrossEntropyWithLogits, ROIAlign,
                      SparseSoftmaxCrossEntropyWithLogits, Tanh,
-                     TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl,
+                     TopK, BinaryCrossEntropy, KLDivLoss, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl,
                      ApplyProximalAdagrad, SparseApplyProximalAdagrad, SparseApplyAdagradV2, SparseApplyFtrlV2,
                      FusedSparseFtrl, FusedSparseProximalAdagrad,
                      ApplyAdaMax, ApplyAdadelta, ApplyAdagrad, ApplyAdagradV2,
@@ -305,6 +305,7 @@ __all__ = [
     "LSTM",
     "Abs",
     "BinaryCrossEntropy",
+    "KLDivLoss",
     "SparseApplyAdagrad",
     "SparseApplyAdagradV2",
     "SpaceToDepth",
diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py
index 5e5e56f708..fb1e6d1228 100644
--- a/mindspore/ops/operations/_grad_ops.py
+++ b/mindspore/ops/operations/_grad_ops.py
@@ -145,6 +145,23 @@ class BiasAddGrad(Primitive):
         raise NotImplementedError
 
 
+class KLDivLossGrad(PrimitiveWithInfer):
+    """Computes gradients for `KLDivLoss` operation."""
+
+    @prim_attr_register
+    def __init__(self, reduction='mean'):
+        self.reduction = validator.check_string('reduction', reduction, ['none', 'mean', 'sum'], self.name)
+
+    def infer_shape(self, x_shape, y_shape, doutput_shape):
+        validator.check('x_shape', x_shape, 'y_shape', y_shape, Rel.EQ, self.name)
+        return x_shape, y_shape
+
+    def infer_dtype(self, x_type, y_type, doutput_type):
+        args = {'x_type': x_type, 'y_type': y_type, 'doutput_type': doutput_type}
+        validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name)
+        return x_type, y_type
+
+
 class BinaryCrossEntropyGrad(PrimitiveWithInfer):
     """Computes gradients for `BinaryCrossEntropy` operation."""
 
@@ -406,6 +423,7 @@ class FusedBatchNormGrad(Primitive):
     def __call__(self, dy, x, scale, save_mean, save_inv_variance):
         raise NotImplementedError
 
+
 class BNTrainingReduceGrad(PrimitiveWithInfer):
     """Gradients of FusedBatchNorm operation."""
 
@@ -420,6 +438,7 @@ class BNTrainingReduceGrad(PrimitiveWithInfer):
     def infer_dtype(self, grads, x, diff_scale, diff_offset, scale, batch_mean, batch_variance):
         return grads
 
+
 class BNTrainingUpdateGrad(PrimitiveWithInfer):
"""Gradients of FusedBatchNorm operation.""" @@ -434,6 +453,7 @@ class BNTrainingUpdateGrad(PrimitiveWithInfer): def infer_dtype(self, grads, x, batch_mean, batch_variance): return (batch_mean, batch_variance) + class GeluGrad(PrimitiveWithInfer): """Gradients of Gelu operation.""" @@ -1319,6 +1339,7 @@ class EmbeddingLookupCommGrad(PrimitiveWithInfer): This works ONLY when 'reduce_scatter_flag' is True in 'EmbeddingLookup'. Roughly speaking, this primitive is implemented by StridedSlice --> _HostAllGather --> Concat. This primitive runs on host. """ + @prim_attr_register def __init__(self): self.init_prim_io_names(inputs=['dy', 'split_num'], outputs=['output']) @@ -1519,6 +1540,7 @@ class InvGrad(PrimitiveWithInfer): class LRNGrad(PrimitiveWithInfer): """Computes gradients for LRN operation.""" + @prim_attr_register def __init__(self, depth_radius=5, bias=1.0, alpha=1.0, beta=0.5): self.init_prim_io_names(inputs=['grads', 'x', 'y'], outputs=['z']) diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 8573c348e8..257820d471 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -3360,6 +3360,76 @@ class FusedSparseProximalAdagrad(PrimitiveWithInfer): validator.check_tensor_type_same({'indices': indices_dtype}, valid_types, self.name) return var_dtype, accum_dtype +class KLDivLoss(PrimitiveWithInfer): + r""" + Computes the Kullback-Leibler divergence between the target and the output. + + Note: + Sets input as :math:`x`, input label as :math:`y`, output as :math:`\ell(x, y)`. + Let, + + .. math:: + L = \{l_1,\dots,l_N\}^\top, \quad + l_n = y_n \cdot (\log y_n - x_n) + + Then, + + .. math:: + \ell(x, y) = \begin{cases} + L, & \text{if reduction} = \text{'none';}\\ + \operatorname{mean}(L), & \text{if reduction} = \text{'mean';}\\ + \operatorname{sum}(L), & \text{if reduction} = \text{'sum'.} + \end{cases} + + Args: + reduction (str): Specifies the reduction to apply to the output. + Its value should be one of 'none', 'mean', 'sum'. Default: 'mean'. + + Inputs: + - **input_x** (Tensor) - The input Tensor. The data type must be float32. + - **input_y** (Tensor) - The label Tensor which has same shape as `input_x`. The data type must be float32. + + Outputs: + Tensor or Scalar, if `reduction` is 'none', then output is a tensor and same shape as `input_x`. + Otherwise it is a scalar. 
+
+    Examples:
+        >>> import mindspore
+        >>> import mindspore.nn as nn
+        >>> import numpy as np
+        >>> from mindspore import Tensor
+        >>> from mindspore.ops import operations as P
+        >>> class Net(nn.Cell):
+        >>>     def __init__(self):
+        >>>         super(Net, self).__init__()
+        >>>         self.kldiv_loss = P.KLDivLoss()
+        >>>     def construct(self, x, y):
+        >>>         result = self.kldiv_loss(x, y)
+        >>>         return result
+        >>>
+        >>> net = Net()
+        >>> input_x = Tensor(np.array([0.2, 0.7, 0.1]), mindspore.float32)
+        >>> input_y = Tensor(np.array([0., 1., 0.]), mindspore.float32)
+        >>> result = net(input_x, input_y)
+    """
+
+    @prim_attr_register
+    def __init__(self, reduction='mean'):
+        self.reduction = validator.check_string('reduction', reduction, ['none', 'mean', 'sum'], self.name)
+
+    def infer_shape(self, x_shape, y_shape):
+        validator.check('x_shape', x_shape, 'y_shape', y_shape, Rel.EQ, self.name)
+        if self.reduction in ('mean', 'sum'):
+            shape = []
+        else:
+            shape = x_shape
+        return shape
+
+    def infer_dtype(self, x_type, y_type):
+        args = {'x': x_type, 'y': y_type}
+        valid_types = (mstype.float16, mstype.float32)
+        validator.check_tensor_type_same(args, valid_types, self.name)
+        return x_type
 
 class BinaryCrossEntropy(PrimitiveWithInfer):
     r"""
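
As a quick cross-check of the formula in the `KLDivLoss` docstring above, the expected output can be reproduced with plain NumPy. This reference is not part of the patch; treating 0 * log(0) as 0 is an assumed convention rather than something the docstring spells out.

import numpy as np

def kl_div_reference(x, y, reduction='mean'):
    """Reference for l_n = y_n * (log(y_n) - x_n) with 'none'/'mean'/'sum' reduction."""
    x = np.asarray(x, dtype=np.float32)
    y = np.asarray(y, dtype=np.float32)
    # Compute y * log(y), taking the value to be 0 where y == 0 (assumed convention).
    y_log_y = np.where(y > 0, y * np.log(np.where(y > 0, y, 1.0)), 0.0)
    loss = y_log_y - y * x
    if reduction == 'mean':
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    return loss  # 'none': same shape as x, matching infer_shape above

# Same numbers as the docstring example; the mean-reduced value is about -0.23333.
print(kl_div_reference([0.2, 0.7, 0.1], [0., 1., 0.]))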
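The bprop getter registered in grad_nn_ops.py is what automatic differentiation dispatches to when `KLDivLoss` appears in a differentiated graph; it forwards `dout` to the `KLDivLossGrad` primitive and returns gradients for both inputs. Below is a minimal sketch of exercising that path, assuming a build that ships the matching GPU kernels and a MindSpore version whose `C.GradOperation` is constructed without a name argument; it is not taken from the patch.

import numpy as np
import mindspore
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.ops import composite as C
from mindspore.ops import operations as P

context.set_context(device_target="GPU")  # assumes the GPU kernels this patch targets are available

class KLDivNet(nn.Cell):
    def __init__(self, reduction='mean'):
        super(KLDivNet, self).__init__()
        self.kldiv_loss = P.KLDivLoss(reduction)

    def construct(self, x, y):
        return self.kldiv_loss(x, y)

class GradNet(nn.Cell):
    """Returns gradients of `net` with respect to all of its inputs."""
    def __init__(self, net):
        super(GradNet, self).__init__()
        self.net = net
        self.grad = C.GradOperation(get_all=True)  # older releases required a name string here

    def construct(self, x, y):
        return self.grad(self.net)(x, y)

x = Tensor(np.array([0.2, 0.7, 0.1]), mindspore.float32)
y = Tensor(np.array([0., 1., 0.]), mindspore.float32)
dx, dy = GradNet(KLDivNet())(x, y)  # backward pass runs the registered bprop / KLDivLossGrad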