diff --git a/mindspore/ops/_grad/grad_array_ops.py b/mindspore/ops/_grad/grad_array_ops.py index a2a808781e..b53a7412fc 100644 --- a/mindspore/ops/_grad/grad_array_ops.py +++ b/mindspore/ops/_grad/grad_array_ops.py @@ -219,6 +219,24 @@ def get_bprop_embedding_lookup(self): return bprop_sparse +@bprop_getters.register(P.EmbeddingLookup) +def get_bprop_embedding_look_up(self): + """Generate bprop for EmbeddingLookup""" + sub_op = P.Sub() + reshape_op = P.Reshape() + def bprop(x, indices, offset, out, dout): + x_shp = shape_op(x) + new_indices = sub_op(indices, offset) + # Reshape the 'new_indices' + new_indices_shape_changed = (size_op(new_indices),) + new_indices = reshape_op(new_indices, new_indices_shape_changed) + x_shp_tail = x_shp[1:] + actual_dout_shape_changed = new_indices_shape_changed + x_shp_tail + actual_dout = reshape_op(dout, actual_dout_shape_changed) + return (new_indices, actual_dout, x_shp), zeros_like(indices), zeros_like(offset) + return bprop + + @bprop_getters.register(P.Transpose) def get_bprop_transpose(self): """Generate bprop for Transpose"""