!260 fix embeddinglookupgrad when param shape is one dim

Merge pull request !260 from wuxuejian/grad_embeddinglookup_fix
pull/3198/head
mindspore-ci-bot 5 years ago committed by Gitee
commit 45b6b2f4e3

@ -230,8 +230,9 @@ def get_bprop_embedding_look_up(self):
    # Reshape the 'new_indices'
    new_indices_shape_changed = (size_op(new_indices),)
    new_indices = reshape_op(new_indices, new_indices_shape_changed)
-   x_shp_tail = x_shp[1:]
-   actual_dout_shape_changed = new_indices_shape_changed + x_shp_tail
+   actual_dout_shape_changed = new_indices_shape_changed
+   if len(x_shp) > 1:
+       actual_dout_shape_changed += x_shp[1:]
    actual_dout = reshape_op(dout, actual_dout_shape_changed)
    return (new_indices, actual_dout, x_shp), zeros_like(indices), zeros_like(offset)
return bprop

@ -15,6 +15,7 @@
 import numpy as np
 import mindspore.context as context
+import mindspore.nn as nn
 import mindspore.common.dtype as mstype
 from mindspore import Tensor
 from mindspore.ops import operations as P

Loading…
Cancel
Save