You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
620 lines
17 KiB
620 lines
17 KiB
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ============================================================================
|
|
|
|
"""array_ops"""
|
|
|
|
from .. import operations as P
|
|
from ..operations import _grad_ops as G
|
|
from ..operations import _inner_ops as inner
|
|
from ..composite.multitype_ops.zeros_like_impl import zeros_like
|
|
from .. import functional as F
|
|
from .grad_base import bprop_getters
|
|
from ..primitive import constexpr
|
|
from ... import context
|
|
from ...common import dtype as mstype
|
|
|
|
# Shared primitive instances used by the bprop definitions in this module.
invert_permutation = P.InvertPermutation()
logical_and = P.LogicalAnd()
reduce_sum = P.ReduceSum()
reshape = P.Reshape()
shape_op = P.Shape()
transpose = P.Transpose()
unsorted_segment_sum = P.UnsortedSegmentSum()
|
|
|
|
|
|
@bprop_getters.register(P.Fill)
def get_bprop_fill(self):
    """Build the backward function for Fill.

    The returned bprop yields a zero gradient for `dims` and for `x`;
    no entry is returned for `dtype`.
    """

    def bprop(dtype, dims, x, out, dout):
        d_dims = zeros_like(dims)
        d_x = zeros_like(x)
        return d_dims, d_x

    return bprop
|
|
|
|
|
|
@bprop_getters.register(P.DType)
def get_bprop_dtype(self):
    """Build the backward function for DType.

    The returned bprop yields a one-element tuple holding a zero gradient
    for the input `x`.
    """

    def bprop(x, out, dout):
        dx = zeros_like(x)
        return (dx,)

    return bprop
|
|
|
|
|
|
@bprop_getters.register(P.Cast)
def get_bprop_cast(self):
    """Build the backward function for Cast.

    The incoming gradient is cast back to the dtype of the original input
    `x`; the target-type argument `t` receives a zero gradient.
    """
    dtype_of = P.DType()
    cast_op = P.Cast()

    def bprop(x, t, out, dout):
        src_dtype = dtype_of(x)
        return cast_op(dout, src_dtype), zeros_like(t)

    return bprop
|
|
|
|
|
|
@bprop_getters.register(P.Shape)
def get_bprop_shape(self):
    """Build the backward function for Shape.

    The returned bprop yields a one-element tuple holding a zero gradient
    for the input `x`.
    """

    def bprop(x, out, dout):
        dx = zeros_like(x)
        return (dx,)

    return bprop
|
|
|
|
|
|
@bprop_getters.register(P.Split)
def get_bprop_split(self):
    """Build the backward function for Split.

    The gradient of the input is obtained by concatenating the output
    gradients `dout` along the axis stored on the primitive (`self.axis`).
    """
    split_axis = self.axis

    def bprop(x, out, dout):
        merge = P.Concat(split_axis)
        return (merge(dout),)

    return bprop
|
|
|
|
|
|
@bprop_getters.register(P.Rank)
def get_bprop_rank(self):
    """Build the backward function for Rank.

    The returned bprop yields a one-element tuple holding a zero gradient
    for the input `x`.
    """

    def bprop(x, out, dout):
        dx = zeros_like(x)
        return (dx,)

    return bprop
|
|
|
|
|
|
@bprop_getters.register(P.Reshape)
def get_bprop_reshape(self):
    """Build the backward function for Reshape.

    The incoming gradient is reshaped back to the shape of `x`; the shape
    argument `shp` receives a zero gradient.
    """

    def bprop(x, shp, out, dout):
        dx = reshape(dout, shape_op(x))
        return dx, zeros_like(shp)

    return bprop
|
|
|
|
|
|
@bprop_getters.register(P.ExpandDims)
def get_bprop_expand_dims(self):
    """Build the backward function for ExpandDims.

    The incoming gradient is reshaped back to the shape of `x`; the `axis`
    argument receives a zero gradient.
    """

    def bprop(x, axis, out, dout):
        dx = reshape(dout, shape_op(x))
        return dx, zeros_like(axis)

    return bprop
|
|
|
|
|
|
@bprop_getters.register(P.Squeeze)
def get_bprop_squeeze(self):
    """Build the backward function for Squeeze.

    The incoming gradient is reshaped back to the shape of the input `x`.
    """

    def bprop(x, out, dout):
        dx = reshape(dout, shape_op(x))
        return (dx,)

    return bprop
|
|
|
|
|
|
@bprop_getters.register(P.Flatten)
def get_bprop_flatten(self):
    """Build the backward function for Flatten.

    Delegates to G.FlattenGrad, passing the incoming gradient together with
    the shape of the original input `x`.
    """
    grad_op = G.FlattenGrad()

    def bprop(x, out, dout):
        x_shape = shape_op(x)
        return (grad_op(dout, x_shape),)

    return bprop
|
|
|
|
|
|
@constexpr
def _tile_shape(multiples, shapex):
    """Interleave multiples and shapex, e.g. [1, 2], [3, 4] -> [1, 3, 2, 4].

    When the two sequences differ in length, the shorter one is padded on
    the left with ones before interleaving, so (2, 3, 4) and (5, 6) give
    (2, 1, 3, 5, 4, 6), and (2,) and (5, 6) give (1, 5, 2, 6).

    Args:
        multiples: sequence of per-axis repeat counts.
        shapex: sequence holding the shape of the tiled input.

    Returns:
        tuple of length 2 * max(len(multiples), len(shapex)) with the
        entries of `multiples` at even positions and those of `shapex`
        at odd positions.
    """
    len_cmp = len(multiples) - len(shapex)
    # Left-pad the shorter operand with ones so both have the same length.
    # The earlier hand-written loop did not advance the shape index while
    # padding `multiples`, so it always ended in an IndexError when
    # len(multiples) < len(shapex).
    padded_multiples = (1,) * max(-len_cmp, 0) + tuple(multiples)
    padded_shape = (1,) * max(len_cmp, 0) + tuple(shapex)
    ret = []
    for multiple, dim in zip(padded_multiples, padded_shape):
        ret.append(multiple)
        ret.append(dim)
    return tuple(ret)
|
|
|
|
|
|
@bprop_getters.register(P.Tile)
def get_bprop_tile(self):
    """Build the backward function for Tile.

    The incoming gradient is reshaped to the interleaved
    (multiple, dim, multiple, dim, ...) layout produced by `_tile_shape`,
    summed over the even axes, and reshaped back to the shape of `x`.
    The `multiples` argument receives a zero gradient.
    """

    def bprop(x, multiples, out, dout):
        x_shape = shape_op(x)
        interleaved = _tile_shape(multiples, x_shape)
        # Even positions (start 0, step 2) are the axes that get summed.
        sum_axes = F.make_range(0, len(interleaved), 2)
        expanded = reshape(dout, interleaved)
        summed = reduce_sum(expanded, sum_axes)
        return reshape(summed, x_shape), zeros_like(multiples)

    return bprop
|
|
|
|
|
|
@bprop_getters.register(inner.EmbeddingLookup)
def get_bprop_embedding_lookup(self):
    """Generate bprop for EmbeddingLookup.

    The returned bprop yields, in order: a tuple
    `(new_indices, actual_dout, x_shp)` standing in for the gradient of `x`,
    followed by zero gradients for `indices`, `offset`,
    `reduce_scatter_flag` and `split_num`.
    """
    host_sub = P.Sub().add_prim_attr('primitive_target', 'CPU')
    host_reshape = P.Reshape().add_prim_attr('primitive_target', 'CPU')
    # `size_op` was referenced but never defined in this module; create it
    # here so the block is self-contained.
    size_op = P.Size()

    def bprop_sparse(x, indices, offset, reduce_scatter_flag, split_num, out, dout):
        x_shp = shape_op(x)
        if reduce_scatter_flag is True:
            elu_grad = G.EmbeddingLookupCommGrad()
            actual_dout = elu_grad(dout, split_num)
        else:
            actual_dout = dout
        # Sub is a two-input primitive: pass both operands rather than the
        # already-subtracted value as a single argument.
        new_indices = host_sub(indices, offset)
        # Reshape the 'new_indices' into a 1-D tensor
        new_indices_shape_changed = (size_op(new_indices),)
        new_indices = host_reshape(new_indices, new_indices_shape_changed)
        # Reshape the 'actual_dout' to (number of indices,) + x_shp[1:]
        x_shp_tail = x_shp[1:]
        actual_dout_shape_changed = new_indices_shape_changed + x_shp_tail
        actual_dout = host_reshape(actual_dout, actual_dout_shape_changed)
        # One entry per forward input: the zero gradients are built from the
        # original `indices` and from `offset` (the old code referenced an
        # undefined name `axis` and the reshaped `new_indices`).
        return (new_indices, actual_dout, x_shp), zeros_like(indices), zeros_like(offset), \
               zeros_like(reduce_scatter_flag), zeros_like(split_num)

    return bprop_sparse
|
|
|
|
|
|
@bprop_getters.register(P.Transpose)
def get_bprop_transpose(self):
    """Build the backward function for Transpose.

    The incoming gradient is transposed with the inverse of `perm`; the
    `perm` argument itself receives a zero gradient.
    """

    def bprop(x, perm, out, dout):
        inverse_perm = invert_permutation(perm)
        dx = transpose(dout, inverse_perm)
        return dx, zeros_like(perm)

    return bprop
|
|
|
|
|
|
@bprop_getters.register(P.Concat)
def get_bprop_concat(self):
    """Build the backward function for Concat.

    For every input tensor, the matching region of `dout` is cut out with
    Slice, using the offsets computed by G.ConcatOffset along `self.axis`
    and the shape of that input. The slices are returned as one tuple.
    """
    concat_axis = self.axis

    def bprop(x, out, dout):
        offsets = G.ConcatOffset(F.tuple_len(x), concat_axis)(x)
        grads = ()
        for i in range(F.tuple_len(x)):
            piece = P.Slice()(dout, offsets[i], shape_op(x[i]))
            grads = grads + (piece,)
        return (grads,)

    return bprop
|
|
|
|
|
|
@constexpr
def _slice_grad_pad(begins, sizes, shapes):
    """Compute per-axis (before, after) paddings for the Slice gradient.

    For each axis the pair is (begin, shape - begin - size). Iteration
    stops at the shortest of the three input sequences.

    Returns:
        tuple of 2-tuples, one per axis.
    """
    paddings = []
    for start, length, full in zip(begins, sizes, shapes):
        trailing = full - start - length
        paddings.append((start, trailing))
    return tuple(paddings)
|
|
|
|
|
|
@bprop_getters.register(P.Slice)
def get_bprop_slice(self):
    """Build the backward function for Slice.

    Two implementations are defined. When the context's 'device_target' is
    "GPU" or "CPU", the variant based on G.SliceGrad is returned; otherwise
    the variant that pads `dout` back to the shape of `x` with P.Pad is
    returned. Both give zero gradients for `begin` and `size`.
    """

    def bprop(x, begin, size, out, dout):
        paddings = _slice_grad_pad(begin, size, shape_op(x))
        pad_op = P.Pad(paddings)
        return (pad_op(dout), zeros_like(begin), zeros_like(size))

    def bprop_grad(x, begin, size, out, dout):
        slice_grad = G.SliceGrad()
        return (slice_grad(dout, x, begin, size), zeros_like(begin), zeros_like(size))

    device_target = context.get_context('device_target')
    if device_target in ("GPU", "CPU"):
        return bprop_grad
    return bprop
|
|
|
|
|
|
@constexpr
def _generate_shape_index(out_shape, indices_shape, axis):
    """Build a permutation that moves the index axes of the output to the front.

    The axes [axis, axis + len(indices_shape)) come first, followed by the
    remaining axes in their original order. A negative `axis` is shifted by
    len(out_shape) - len(indices_shape) + 1.

    Returns:
        tuple of ints, the permutation.
    """
    n_out = len(out_shape)
    n_ind = len(indices_shape)
    if axis < 0:
        axis = axis + n_out - n_ind + 1
    all_axes = tuple(range(n_out))
    stop = axis + n_ind
    leading = tuple(range(axis, stop))
    return leading + all_axes[:axis] + all_axes[stop:]
|
|
|
|
|
|
@constexpr
def _generate_inverse_index(x_shape, axis):
    """Build a permutation that moves axis 0 to position `axis`.

    A negative `axis` is shifted by len(x_shape) first.

    Returns:
        tuple of ints: axes 1..axis, then 0, then the axes after `axis`.
    """
    n_dims = len(x_shape)
    if axis < 0:
        axis = axis + n_dims
    all_axes = tuple(range(n_dims))
    split_at = axis + 1
    head = all_axes[1:split_at]
    tail = all_axes[split_at:]
    return head + (0,) + tail
|
|
|
|
|
|
@bprop_getters.register(P.GatherV2)
def get_bprop_gather_v2(self):
    """Build the backward function for GatherV2.

    The incoming gradient is transposed so the index axes come first,
    accumulated per index with UnsortedSegmentSum (segment count taken from
    the size of `x` along `axis`), and transposed back. `indices` and
    `axis` receive zero gradients.
    """

    def bprop(x, indices, axis, out, dout):
        # Rank-0 gradient or indices are promoted to rank 1 first.
        if F.rank(dout) == 0:
            dout = P.ExpandDims()(dout, -1)
        if F.rank(indices) == 0:
            indices = P.ExpandDims()(indices, -1)
        x_shape = shape_op(x)
        dout_shape = shape_op(dout)
        indices_shape = shape_op(indices)
        # Example: out_shape:(3,2,3) axis 1 -> (1,0,2)
        forward_perm = _generate_shape_index(dout_shape, indices_shape, axis)
        moved = transpose(dout, forward_perm)
        summed = unsorted_segment_sum(moved, indices, x_shape[axis])
        # Example: out_shape:(3,2,3) axis 2 -> (1,2,0)
        backward_perm = _generate_inverse_index(x_shape, axis)
        dx = transpose(summed, backward_perm)
        return dx, zeros_like(indices), zeros_like(axis)

    return bprop
|
|
|
|
|
|
@bprop_getters.register(P.Range)
def get_bprop_range(self):
    """Build the backward function for Range.

    The returned bprop yields a one-element tuple holding a zero gradient
    for the input `x`.
    """

    def bprop(x, out, dout):
        dx = zeros_like(x)
        return (dx,)

    return bprop
|
|
|
|
|
|
@bprop_getters.register(P.Pack)
def get_bprop_pack(self):
    """Build the backward function for Pack.

    The incoming gradient is unpacked along `self.axis` with P.Unpack and
    returned as the single gradient entry.
    """
    pack_axis = self.axis

    def bprop(x, out, dout):
        unpack_op = P.Unpack(pack_axis)
        return (unpack_op(dout),)

    return bprop
|
|
|
|
|
|
@bprop_getters.register(P.Unpack)
def get_bprop_unpack(self):
    """Build the backward function for Unpack.

    The incoming gradients are packed along `self.axis` with P.Pack and
    returned as the single gradient entry.
    """
    unpack_axis = self.axis

    def bprop(x, out, dout):
        pack_op = P.Pack(unpack_axis)
        return (pack_op(dout),)

    return bprop
|
|
|
|
|
|
@bprop_getters.register(P.StridedSlice)
def get_bprop_strided_slice(self):
    """Build the backward function for StridedSlice.

    Delegates to G.StridedSliceGrad, configured with the five mask
    attributes of the forward primitive. `begin`, `end` and `strides`
    receive zero gradients.
    """
    grad_op = G.StridedSliceGrad(
        self.begin_mask,
        self.end_mask,
        self.ellipsis_mask,
        self.new_axis_mask,
        self.shrink_axis_mask,
    )

    def bprop(x, begin, end, strides, out, dout):
        x_shape = shape_op(x)
        dx = grad_op(dout, x_shape, begin, end, strides)
        d_begin = zeros_like(begin)
        d_end = zeros_like(end)
        d_strides = zeros_like(strides)
        return dx, d_begin, d_end, d_strides

    return bprop
|
|
|
|
|
|
@bprop_getters.register(P.Eye)
def get_bprop_eye(self):
    """Build the backward function for Eye.

    All three inputs (`n`, `m`, `t`) receive zero gradients.
    """

    def bprop(n, m, t, out, dout):
        d_n = zeros_like(n)
        d_m = zeros_like(m)
        d_t = zeros_like(t)
        return d_n, d_m, d_t

    return bprop
|
|
|
|
|
|
@bprop_getters.register(P.Select)
def get_bprop_select(self):
    """Build the backward function for Select.

    `cond` receives a zero gradient. The gradient of `x` is `dout` where
    `cond` holds and zero elsewhere; the gradient of `y` is the opposite.
    """
    select_op = P.Select()

    def bprop(cond, x, y, out, dout):
        d_cond = zeros_like(cond)
        d_x = select_op(cond, dout, zeros_like(x))
        d_y = select_op(cond, zeros_like(y), dout)
        return d_cond, d_x, d_y

    return bprop
|
|
|
|
|
|
@bprop_getters.register(P.OnesLike)
def get_bprop_oneslike(self):
    """Build the backward function for OnesLike.

    The returned bprop yields a one-element tuple holding a zero gradient
    for the input `x`.
    """

    def bprop(x, out, dout):
        dx = zeros_like(x)
        return (dx,)

    return bprop
|
|
|
|
|
|
@bprop_getters.register(P.ZerosLike)
def get_bprop_zeroslike(self):
    """Build the backward function for ZerosLike.

    The returned bprop yields a one-element tuple holding a zero gradient
    for the input `x`.
    """

    def bprop(x, out, dout):
        dx = zeros_like(x)
        return (dx,)

    return bprop
|
|
|
|
|
|
@bprop_getters.register(P.ResizeNearestNeighbor)
def get_bprop_resize_nearest_neighbor(self):
    """Build the backward function for ResizeNearestNeighbor.

    Delegates to G.ResizeNearestNeighborGrad (configured with
    `self.align_corners`), passing the incoming gradient and the sizes of
    axes 2 and 3 of the original input.
    """
    grad_op = G.ResizeNearestNeighborGrad(self.align_corners)

    def bprop(inputs, out, dout):
        in_shape = shape_op(inputs)
        # 2 and 3 represent the height and width
        target_size = (in_shape[2], in_shape[3])
        dx = grad_op(dout, target_size)
        return (dx,)

    return bprop
|
|
|
|
|
|
@bprop_getters.register(P.GatherNd)
def get_bprop_gather_nd(self):
    """Build the backward function for GatherNd.

    The incoming gradient is scattered with P.ScatterNd at `indices` into a
    tensor with the shape of `x`; `indices` receives a zero gradient.
    """
    scatter = P.ScatterNd()

    def bprop(x, indices, out, dout):
        x_shape = shape_op(x)
        dx = scatter(indices, dout, x_shape)
        return dx, zeros_like(indices)

    return bprop
|
|
|
|
|
|
@bprop_getters.register(P.ScatterNd)
def get_bprop_scatter_nd(self):
    """Build the backward function for ScatterNd.

    The gradient of the updates `x` is gathered from `dout` at `indices`
    with P.GatherNd; `indices` and `shape` receive zero gradients.
    """
    gather = P.GatherNd()

    def bprop(indices, x, shape, out, dout):
        d_indices = zeros_like(indices)
        d_x = gather(dout, indices)
        d_shape = zeros_like(shape)
        return d_indices, d_x, d_shape

    return bprop
|
|
|
|
|
|
@bprop_getters.register(P.ScatterNdUpdate)
def get_bprop_scatter_nd_update(self):
    """Build the backward function for ScatterNdUpdate.

    `x` receives `dout` unchanged, `indices` receives a zero gradient, and
    the gradient of `update` is gathered from `dout` at `indices` with
    P.GatherNd.
    """
    gather = P.GatherNd()

    def bprop(x, indices, update, out, dout):
        d_indices = zeros_like(indices)
        d_update = gather(dout, indices)
        return dout, d_indices, d_update

    return bprop
|
|
|
|
|
|
@bprop_getters.register(P.Argmax)
def get_bprop_argmax(self):
    """Build the backward function for Argmax.

    The returned bprop yields a one-element tuple holding a zero gradient
    for the input `x`.
    """

    def bprop(x, out, dout):
        dx = zeros_like(x)
        return (dx,)

    return bprop
|
|
|
|
|
|
@bprop_getters.register(P.Argmin)
def get_bprop_argmin(self):
    """Build the backward function for Argmin.

    The returned bprop yields a one-element tuple holding a zero gradient
    for the input `x`.
    """

    def bprop(x, out, dout):
        dx = zeros_like(x)
        return (dx,)

    return bprop
|
|
|
|
|
|
@bprop_getters.register(P.SpaceToDepth)
def get_bprop_space_to_depth(self):
    """Build the backward function for SpaceToDepth.

    The incoming gradient is passed through P.DepthToSpace constructed with
    the same `block_size` as the forward primitive.
    """
    inverse_op = P.DepthToSpace(self.block_size)

    def bprop(x, out, dout):
        dx = inverse_op(dout)
        return (dx,)

    return bprop
|
|
|
|
|
|
@bprop_getters.register(P.DepthToSpace)
def get_bprop_depth_to_space(self):
    """Build the backward function for DepthToSpace.

    The incoming gradient is passed through P.SpaceToDepth constructed with
    the same `block_size` as the forward primitive.
    """
    inverse_op = P.SpaceToDepth(self.block_size)

    def bprop(x, out, dout):
        dx = inverse_op(dout)
        return (dx,)

    return bprop
|
|
|
|
|
|
@bprop_getters.register(P.Diag)
def get_bprop_diag(self):
    """Build the backward function for Diag.

    The incoming gradient is passed through P.DiagPart.
    """
    diag_part = P.DiagPart()

    def bprop(x, out, dout):
        dx = diag_part(dout)
        return (dx,)

    return bprop
|
|
|
|
|
|
@bprop_getters.register(P.DiagPart)
def get_bprop_diag_part(self):
    """Build the backward function for DiagPart.

    The incoming gradient is passed through P.Diag.
    """
    diag = P.Diag()

    def bprop(x, out, dout):
        dx = diag(dout)
        return (dx,)

    return bprop
|
|
|
|
|
|
def _GatherDropNegatives(params,
                         ids,
                         zero_clipped_indices=None,
                         is_positive=None):
    """Gather rows of `params` at `ids`, zeroing results for negative ids.

    Args:
        params: tensor to gather from along axis 0.
        ids: index tensor; only used when `zero_clipped_indices` or
            `is_positive` is not supplied.
        zero_clipped_indices: optional precomputed `maximum(ids, 0)`.
        is_positive: optional precomputed mask, already shaped like the
            gathered result, marking entries whose id is >= 0.

    Returns:
        tuple of (masked gathered values, zero_clipped_indices, is_positive)
        so the last two can be reused by a later call.
    """
    select_op = P.Select()
    fill_op = P.Fill()
    rank_op = P.Rank()
    ge_op = P.GreaterEqual()
    gather_op = P.GatherV2()
    max_op = P.Maximum()

    if zero_clipped_indices is None:
        zero_clipped_indices = max_op(ids, zeros_like(ids))
    gathered = gather_op(params, zero_clipped_indices, 0)
    if is_positive is None:
        is_positive = ge_op(ids, 0)
        # Append trailing 1-sized axes so the mask can be expanded to the
        # rank of the gathered values.
        mask_shape = shape_op(is_positive)
        for _ in range(rank_op(gathered) - rank_op(is_positive)):
            mask_shape += (1,)
        is_positive = reshape(is_positive, mask_shape)
        # AND-ing with an all-true tensor of the gathered shape expands the
        # mask to that full shape.
        all_true = fill_op(mstype.bool_, shape_op(gathered), 1)
        is_positive = logical_and(is_positive, all_true)
    masked = select_op(is_positive, gathered, zeros_like(gathered))
    return (masked, zero_clipped_indices, is_positive)
|
|
|
|
|
|
@bprop_getters.register(P.UnsortedSegmentMin)
def get_bprop_unsorted_segment_min(self):
    """Build the backward function for UnsortedSegmentMin.

    Entries of `x` equal to their segment's output (and with a non-negative
    segment id) are the selected ones. Each segment's gradient is divided
    by the number of selected entries in that segment and routed back to
    them; every other entry gets zero. `segment_ids` and `num_segments`
    receive zero gradients.
    """
    select_op = P.Select()
    dtype_of = P.DType()
    div_op = P.RealDiv()
    cast_op = P.Cast()
    equal_op = P.Equal()

    def bprop(x, segment_ids, num_segments, out, dout):
        gathered_out, clipped_ids, positive_mask = _GatherDropNegatives(out, segment_ids)
        selected = logical_and(equal_op(x, gathered_out), positive_mask)
        # Count how many entries share each segment's output value.
        selected_as_num = cast_op(selected, dtype_of(dout))
        counts = unsorted_segment_sum(selected_as_num, segment_ids, num_segments)
        shared_grads = div_op(dout, counts)
        # Reuse the clipped ids and mask from the first gather.
        grads_per_entry, _, _ = _GatherDropNegatives(shared_grads, None,
                                                     clipped_ids, positive_mask)
        dx = select_op(selected, grads_per_entry, zeros_like(grads_per_entry))
        return dx, zeros_like(segment_ids), zeros_like(num_segments)

    return bprop
|
|
|
|
|
|
@bprop_getters.register(P.SpaceToBatch)
def get_bprop_space_to_batch(self):
    """Build the backward function for SpaceToBatch.

    The incoming gradient is passed through P.BatchToSpace constructed with
    the forward primitive's `block_size` and `paddings`.
    """
    inverse_op = P.BatchToSpace(self.block_size, self.paddings)

    def bprop(x, out, dout):
        return (inverse_op(dout),)

    return bprop
|
|
|
|
|
|
@bprop_getters.register(P.BatchToSpace)
def get_bprop_batch_to_space(self):
    """Build the backward function for BatchToSpace.

    The incoming gradient is passed through P.SpaceToBatch constructed with
    the forward primitive's `block_size` and `crops`.
    """
    inverse_op = P.SpaceToBatch(self.block_size, self.crops)

    def bprop(x, out, dout):
        return (inverse_op(dout),)

    return bprop
|
|
|
|
|
|
@bprop_getters.register(P.SpaceToBatchND)
def get_bprop_space_to_batch_nd(self):
    """Build the backward function for SpaceToBatchND.

    The incoming gradient is passed through P.BatchToSpaceND constructed
    with the forward primitive's `block_shape` and `paddings`.
    """
    inverse_op = P.BatchToSpaceND(self.block_shape, self.paddings)

    def bprop(x, out, dout):
        return (inverse_op(dout),)

    return bprop
|
|
|
|
|
|
@bprop_getters.register(P.BatchToSpaceND)
def get_bprop_batch_to_space_nd(self):
    """Build the backward function for BatchToSpaceND.

    The incoming gradient is passed through P.SpaceToBatchND constructed
    with the forward primitive's `block_shape` and `crops`.
    """
    inverse_op = P.SpaceToBatchND(self.block_shape, self.crops)

    def bprop(x, out, dout):
        return (inverse_op(dout),)

    return bprop
|
|
|
|
|
|
@bprop_getters.register(P.ReverseSequence)
def get_bprop_reverse_sequence(self):
    """Build the backward function for ReverseSequence.

    The incoming gradient is passed through another P.ReverseSequence with
    the same `batch_dim_` and `seq_dim_` and the same `seq_lengths`;
    `seq_lengths` itself receives a zero gradient.
    """
    reverse_op = P.ReverseSequence(batch_dim=self.batch_dim_, seq_dim=self.seq_dim_)

    def bprop(x, seq_lengths, out, dout):
        return reverse_op(dout, seq_lengths), zeros_like(seq_lengths)

    return bprop
|