diff --git a/graphengine b/graphengine
index 43a715bc46..d345a800a4 160000
--- a/graphengine
+++ b/graphengine
@@ -1 +1 @@
-Subproject commit 43a715bc461fd70b7837051a2f47f0a1b19c5859
+Subproject commit d345a800a4f7c32eb768ea48667d1ce00b841748
diff --git a/mindspore/ccsrc/transform/op_declare.cc b/mindspore/ccsrc/transform/op_declare.cc
index 5ec54b2037..76fec9e21c 100644
--- a/mindspore/ccsrc/transform/op_declare.cc
+++ b/mindspore/ccsrc/transform/op_declare.cc
@@ -1133,7 +1133,7 @@ INPUT_MAP(SparseApplyAdagradD) = {
   {1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(grad)}, {4, INPUT_DESC(indices)}};
 ATTR_MAP(SparseApplyAdagradD) = {{"lr", ATTR_DESC(lr, AnyTraits<float>())},
                                  {"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
-OUTPUT_MAP(SparseApplyAdagradD) = {{0, OUTPUT_DESC(var)}};
+OUTPUT_MAP(SparseApplyAdagradD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(accum)}};
 
 // SparseApplyFtrlD
 INPUT_MAP(SparseApplyFtrlD) = {{1, INPUT_DESC(var)},
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index 9750549dc5..8f877adc18 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -2433,7 +2433,10 @@ class SparseApplyAdagrad(PrimitiveWithInfer):
           The shape of `indices` must be the same as `grad` in first dimension, the type must be int32.
 
     Outputs:
-        Tensor, has the same shape and type as `var`.
+        Tuple of 2 Tensor, the updated parameters.
+
+        - **var** (Tensor) - The same shape and data type as `var`.
+        - **accum** (Tensor) - The same shape and data type as `accum`.
     """
 
     @prim_attr_register
@@ -2448,13 +2451,13 @@ class SparseApplyAdagrad(PrimitiveWithInfer):
         validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name)
         validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name)
         validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name)
-        return var_shape
+        return var_shape, accum_shape
 
     def infer_dtype(self, var_type, accum_type, grad_type, indices_type):
         args = {'var': var_type, 'accum': accum_type, 'grad': grad_type}
         validator.check_tensor_type_same(args, (mstype.float32,), self.name)
         validator.check_tensor_type_same({'indices': indices_type}, [mstype.int32], self.name)
-        return var_type
+        return var_type, accum_type
 
 
 class LARSUpdate(PrimitiveWithInfer):
diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py
index 1dea7b6502..1c0f5eb5fe 100755
--- a/tests/ut/python/ops/test_ops.py
+++ b/tests/ut/python/ops/test_ops.py
@@ -814,7 +814,7 @@ test_case_nn_ops = [
     ('SparseApplyAdagrad', {
         'block': P.SparseApplyAdagrad(0.5),
         'desc_inputs': [[3, 3], [3, 3], [3, 3], Tensor(np.ones((3,), np.int32))],
-        'desc_bprop': [3, 3],
+        'desc_bprop': [[3, 3], [3, 3]],
         'skip': ['backward']}),
     ('Flatten_1', {
         'block': NetForFlatten(),
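
The net effect of this patch is that `SparseApplyAdagrad` now returns a tuple `(var, accum)` rather than a single tensor, so callers must unpack two outputs. Below is a minimal usage sketch, not part of the diff: the wrapper class `Net` is hypothetical, and the plain-Tensor inputs mirror the unit test above (in real training code `var` and `accum` would normally be Parameters updated in place).

```python
import numpy as np

import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P


class Net(nn.Cell):
    """Hypothetical wrapper around the SparseApplyAdagrad primitive."""

    def __init__(self):
        super(Net, self).__init__()
        self.sparse_apply_adagrad = P.SparseApplyAdagrad(0.5)

    def construct(self, var, accum, grad, indices):
        # Previously returned a single Tensor (var); after this change it
        # returns the tuple (var, accum).
        return self.sparse_apply_adagrad(var, accum, grad, indices)


# Shapes chosen to satisfy infer_shape: indices is rank 1, grad_shape[0]
# equals indices_shape[0], and var_shape[1:] equals grad_shape[1:].
var = Tensor(np.ones((3, 3), np.float32))
accum = Tensor(np.ones((3, 3), np.float32))
grad = Tensor(np.ones((3, 3), np.float32))
indices = Tensor(np.arange(3, dtype=np.int32))

out_var, out_accum = Net()(var, accum, grad, indices)
```

Exposing `accum` as a second output is what the other hunks follow from: the GE `OUTPUT_MAP` gains `{1, OUTPUT_DESC(accum)}`, `infer_shape`/`infer_dtype` return pairs, and the test's `desc_bprop` becomes a list of two shapes to match.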