From 44942be2df6238993a2162ec4e45285beca7650b Mon Sep 17 00:00:00 2001
From: An Xiao <369376805@qq.com>
Date: Mon, 22 Mar 2021 15:10:50 +0800
Subject: [PATCH] Add IPT Ascend

* Merge branch 'IPT' of gitee.com:xiaoan95/mindspore into ipt_ascend
* Add IPT Ascend
---
 model_zoo/research/cv/IPT/eval.py    |   5 +-
 model_zoo/research/cv/IPT/src/ipt.py | 368 +++++++++++++--------------
 2 files changed, 183 insertions(+), 190 deletions(-)

diff --git a/model_zoo/research/cv/IPT/eval.py b/model_zoo/research/cv/IPT/eval.py
index dafbb05f6b..8f70618168 100755
--- a/model_zoo/research/cv/IPT/eval.py
+++ b/model_zoo/research/cv/IPT/eval.py
@@ -23,7 +23,7 @@ from mindspore import context
 import mindspore.dataset as de
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 
-context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", device_id=0)
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=0)
 
 
 def main():
@@ -46,11 +46,12 @@ def main():
     net_m.set_train(False)
     num_imgs = train_de_dataset.get_dataset_size()
     psnrs = np.zeros((num_imgs, 1))
+    inference = ipt.IPT_post(net_m, args)
     for batch_idx, imgs in enumerate(train_loader):
         lr = imgs['LR']
         hr = imgs['HR']
         hr_np = np.float32(hr.asnumpy())
-        pred = net_m.infrc(lr)
+        pred = inference.forward(lr)
         pred_np = np.float32(pred.asnumpy())
         pred_np = quantize(pred_np, 255)
         psnr = calc_psnr(pred_np, hr_np, 4, 255.0, y_only=True)
diff --git a/model_zoo/research/cv/IPT/src/ipt.py b/model_zoo/research/cv/IPT/src/ipt.py
index 8f51a974f0..1e142bdf28 100755
--- a/model_zoo/research/cv/IPT/src/ipt.py
+++ b/model_zoo/research/cv/IPT/src/ipt.py
@@ -23,9 +23,6 @@ from mindspore.common.tensor import Tensor
 from mindspore.common.parameter import Parameter
 
-# from mindspore.ops.primitive import constexpr
-# import IPython
-
 
 class MultiheadAttention(nn.Cell):
     """
     Apply multi-headed attention from "from_tensor" to "to_tensor".
@@ -85,7 +82,7 @@ class MultiheadAttention(nn.Cell):
         self.shape_q_2d = (-1, q_tensor_width)
         self.shape_k_2d = (-1, k_tensor_width)
         self.shape_v_2d = (-1, v_tensor_width)
-        self.hidden_width = hidden_width
+        self.hidden_width = int(hidden_width)
         # units = num_attention_heads * self.size_per_head
         if self.same_dim:
             self.in_proj_layer = \
@@ -132,46 +129,49 @@ class MultiheadAttention(nn.Cell):
         self.softmax_cast = P.Cast()
         self.matmul_dense = P.MatMul(transpose_b=True)
         self.split = P.Split(0, 3)
+        self.equal = P.Equal()
+        self.shape = P.Shape()
 
-    def construct(self, tensor_q, tensor_k, tensor_v, batch_size, seq_length, attention_mask=None):
+    def construct(self, tensor_q, tensor_k, tensor_v, attention_mask=None):
         """Apply multihead attention."""
-        self.batch_size = batch_size
-        shape_qkv = (self.batch_size, -1,
+        batch_size, seq_length, _ = self.shape(tensor_q)
+        shape_qkv = (batch_size, -1,
                      self.num_attention_heads, self.size_per_head)
-        shape_linear = (self.batch_size * seq_length,
+        shape_linear = (batch_size * seq_length,
                         self.num_attention_heads * self.size_per_head)
-        if self.do_return_2d_tensor:
-            shape_return = (self.batch_size * seq_length,
+        if self.do_return_2d_tensor is True:
+            shape_return = (batch_size * seq_length,
                             self.num_attention_heads * self.size_per_head)
             if seq_length == -1:
                 shape_return = (-1, self.num_attention_heads *
                                 self.size_per_head)
         else:
-            shape_return = (self.batch_size, seq_length,
+            shape_return = (batch_size, seq_length,
                             self.num_attention_heads * self.size_per_head)
 
         tensor_q_2d = self.reshape(tensor_q, self.shape_q_2d)
         tensor_k_2d = self.reshape(tensor_k, self.shape_k_2d)
         tensor_v_2d = self.reshape(tensor_v, self.shape_v_2d)
-        if P.Equal()(tensor_q_2d, tensor_v_2d)[0][0]:
+        if self.equal(tensor_q_2d, tensor_v_2d) is True:
             x = self.matmul_dense(self.in_proj_layer, tensor_q_2d)
             query_out, key_out, value_out = self.split(x)
-        elif self.same_dim:
-            _start = int(0)
-            _end = int(self.hidden_width)
+        elif self.same_dim is True:
+            _start = 0
+            _end = self.hidden_width
             _w = self.in_proj_layer[_start:_end, :]
             # _b = None
             query_out = self.matmul_dense(_w, tensor_q_2d)
-            _start = int(self.hidden_width)
-            _end = int(self.hidden_width * 2)
+            _start = self.hidden_width
+            _end = self.hidden_width * 2
             _w = self.in_proj_layer[_start:_end, :]
             # _b = None
             key_out = self.matmul_dense(_w, tensor_k_2d)
-            _start = int(self.hidden_width * 2)
+            _start = self.hidden_width * 2
+            _end = None
             _w = self.in_proj_layer[_start:]
             # _b = None
@@ -247,7 +247,7 @@ class TransformerEncoderLayer(nn.Cell):
         permute_recover = (b, n, d)
         src2 = self.norm1(src)
         q = k = self.with_pos_embed(src2, pos)
-        src2 = self.self_attn(q, k, src2, batch_size=b, seq_length=n)
+        src2 = self.self_attn(q, k, src2)
         src = src + self.dropout1(src2)
         src2 = self.norm2(src)
         src2 = self.reshape(src2, permute_linear)
@@ -301,13 +301,12 @@ class TransformerDecoderLayer(nn.Cell):
         permute_recover = (b, n, d)
         tgt2 = self.norm1(tgt)
         q = k = self.with_pos_embed(tgt2, query_pos)
-        tgt2 = self.self_attn(q, k, tensor_v=tgt2, batch_size=b, seq_length=n)
+        tgt2 = self.self_attn(q, k, tensor_v=tgt2)
         tgt = tgt + self.dropout1(tgt2)
         tgt2 = self.norm2(tgt)
         tgt2 = self.multihead_attn(tensor_q=self.with_pos_embed(tgt2, query_pos),
                                    tensor_k=self.with_pos_embed(memory, pos),
-                                   tensor_v=memory,
-                                   batch_size=b, seq_length=n)
+                                   tensor_v=memory,)
         tgt = tgt + self.dropout2(tgt2)
         tgt2 = self.norm3(tgt)
         tgt2 = self.reshape(tgt2, permute_linear)
@@ -393,6 +392,7 @@ class VisionTransformer(nn.Cell):
                 num_layers,
                 hidden_dim,
                 num_queries,
+                idx,
                positional_encoding_type="learned",
                dropout_rate=0,
                norm=False,
@@ -422,7 +422,7 @@ class VisionTransformer(nn.Cell):
         self.no_pos = no_pos
 
         self.unf = _unfold_(patch_dim)
-        self.fold = _fold_(patch_dim)
+        self.fold = _fold_(patch_dim, output_shape=(img_dim, img_dim))
 
         if self.mlp is not True:
             self.linear_encoding = nn.Dense(
@@ -437,7 +437,6 @@ class VisionTransformer(nn.Cell):
             self.query_embed = nn.Embedding(
                 num_queries, embedding_dim * self.seq_length)
 
-
         encoder_layer = TransformerEncoderLayer(
             embedding_dim, num_heads, hidden_dim, dropout_rate)
         self.encoder = TransformerEncoder(encoder_layer, num_layers)
@@ -455,30 +454,31 @@ class VisionTransformer(nn.Cell):
         )
 
         self.dropout_layer1 = nn.Dropout(1. - dropout_rate)
-
-    def construct(self, x, query_idx):
+        self.query_idx = idx
+        self.query_idx_tensor = Tensor(idx, mstype.int32)
+    def construct(self, x):
         """ipt"""
         B, _, _, _ = x.shape
         x = self.unf(x)
         B, N, _ = x.shape
 
         if self.mlp is not True:
-            x = self.reshape(x, (int(B * N), -1))
+            x = self.reshape(x, (B * N, -1))
             x = self.dropout_layer1(self.linear_encoding(x)) + x
 
             x = self.reshape(x, (B, N, -1))
         query_embed = self.tile(
-            self.reshape(self.query_embed.embedding_table[int(
-                query_idx)], (1, self.seq_length, self.embedding_dim)),
+            self.reshape(self.query_embed(self.query_idx_tensor), (1, self.seq_length, self.embedding_dim)),
             (B, 1, 1))
 
         if not self.no_pos:
             pos = self.position_encoding(x)
-
-        x = self.encoder(x + pos)
+            x = self.encoder(x + pos)
+        else:
+            x = self.encoder(x)
         x = self.decoder(x, x, query_pos=query_embed)
 
         if self.mlp is not True:
-            x = self.reshape(x, (int(B * N), -1))
+            x = self.reshape(x, (B * N, -1))
            x = self.mlp_head(x) + x
 
            x = self.reshape(x, (B, N, -1))
        x = self.fold(x)
@@ -542,9 +542,9 @@ class ResBlock(nn.Cell):
 def _pixelsf_(x, scale):
     """ipt"""
     N, C, iH, iW = x.shape
-    oH = int(iH * scale)
-    oW = int(iW * scale)
-    oC = int(C // (scale ** 2))
+    oH = iH * scale
+    oW = iW * scale
+    oC = C // (scale ** 2)
 
     output = P.Reshape()(x, (N, oC, scale, scale, iH, iW))
 
@@ -565,11 +565,12 @@ class SmallUpSampler(nn.Cell):
         self.conv = conv(n_feats, upsize * upsize * n_feats, 3, bias)
         self.reshape = P.Reshape()
         self.upsize = upsize
+        self.pixelsf = _pixelsf_
 
     def construct(self, x):
         """ipt"""
         x = self.conv(x)
-        output = _pixelsf_(x, self.upsize)
+        output = self.pixelsf(x, self.upsize)
         return output
 
@@ -628,7 +629,8 @@ class IPT(nn.Cell):
             dropout_rate=args.dropout_rate,
             mlp=args.no_mlp,
             pos_every=args.pos_every,
-            no_pos=args.no_pos)
+            no_pos=args.no_pos,
+            idx=self.scale_idx)
 
         self.tail = nn.CellList([
             nn.SequentialCell(
@@ -645,7 +647,7 @@ class IPT(nn.Cell):
         """ipt"""
         x = self.sub_mean(x)
         x = self.head[self.scale_idx](x)
-        res = self.body(x, self.scale_idx)
+        res = self.body(x)
         res += x
         x = self.tail[self.scale_idx](res)
         x = self.add_mean(x)
@@ -654,30 +656,43 @@ class IPT(nn.Cell):
 
     def set_scale(self, scale_idx):
         """ipt"""
+        self.body.query_idx = scale_idx
         self.scale_idx = scale_idx
 
-    def infrc(self, x):
-        """ipt"""
-        forward_function = self.forward_chop_new
-        return forward_function(x)
 
+class IPT_post():
+    """ipt"""
+    def __init__(self, model, args):
+        super(IPT_post, self).__init__()
+        self.model = model
+        self.args = args
+        self.scale_idx = 0
+        self.reshape = P.Reshape()
+        self.tile = P.Tile()
+        self.transpose = P.Transpose()
+        self.cc_0 = P.Concat(axis=0)
+        self.cc_2 = P.Concat(axis=2)
+        self.cc_3 = P.Concat(axis=3)
 
-    def forward_chop_new(self, x, shave=12, batchsize=64):
+    def set_scale(self, scale_idx):
+        """ipt"""
+        self.model.set_scale(scale_idx)
+        self.scale_idx = scale_idx
+
+    def forward(self, x, shave=12, batchsize=64):
         """ipt"""
         h, w = x.shape[-2:]
         padsize = int(self.args.patch_size)
         shave = int(self.args.patch_size / 4)
 
         scale = self.args.scale[self.scale_idx]
-
         h_cut = (h - padsize) % (padsize - shave)
         w_cut = (w - padsize) % (padsize - shave)
 
         unf_1 = _stride_unfold_(padsize, stride=padsize - shave)
-        x_unfold = unf_1(x)
+        x_unfold = unf_1.compute(x)
         x_unfold = self.transpose(x_unfold, (1, 0, 2))  # transpose(0,2)
-
         x_hw_cut = x[:, :, (h - padsize):, (w - padsize):]
-        y_hw_cut = self.construct(x_hw_cut)
+        y_hw_cut = self.model(x_hw_cut)
 
         x_h_cut = x[:, :, (h - padsize):, :]
         x_w_cut = x[:, :, :, (w - padsize):]
@@ -696,66 +711,71 @@
         x_unfold = self.reshape(
             x_unfold, (x_unfold.shape[0], -1, padsize, padsize))
         x_range = x_unfold.shape[0] // batchsize + \
             (x_unfold.shape[0] % batchsize != 0)
-
-        cc_0 = P.Concat(axis=0)
         for i in range(x_range):
             if i == 0:
-                y_unfold = self.construct(
+                y_unfold = self.model(
                     x_unfold[i * batchsize:(i + 1) * batchsize, :, :, :])
             else:
-                y_unfold = cc_0((y_unfold, self.construct(
+                y_unfold = self.cc_0((y_unfold, self.model(
                     x_unfold[i * batchsize:(i + 1) * batchsize, :, :, :])))
         y_unf_shape_0 = y_unfold.shape[0]
         fold_1 = \
             _stride_fold_(padsize * scale, output_shape=((h - h_cut) * scale,
                                                          (w - w_cut) * scale),
                           stride=padsize * scale - shave * scale)
-        y = fold_1(self.transpose(self.reshape(
+        y = fold_1.compute(self.transpose(self.reshape(
             y_unfold, (y_unf_shape_0, -1, 1)), (2, 0, 1)))
-        cc_2 = P.Concat(axis=2)
-        cc_3 = P.Concat(axis=3)
-        y = cc_2((y_h_top, y[:, :, padsize * scale:, :]))
-        y = cc_3((y_w_top, y[:, :, :, padsize * scale:]))
+        if y[:, :, padsize * scale:, :].shape[2] == 0:
+            y = y_h_top
+        else:
+            y = self.cc_2((y_h_top, y[:, :, padsize * scale:, :]))
+        if y[:, :, :, padsize * scale:].shape[3] == 0:
+            y = y_w_top
+        else:
+            y = self.cc_3((y_w_top, y[:, :, :, padsize * scale:]))
         y_unfold = y_unfold[:, :,
                             int(shave / 2 * scale):padsize * scale - int(shave / 2 * scale),
                             int(shave / 2 * scale):padsize * scale - int(shave / 2 * scale)]
         fold_2 = _stride_fold_(padsize * scale - shave * scale,
                                output_shape=((h - h_cut - shave) * scale,
                                              (w - w_cut - shave) * scale),
                                stride=padsize * scale - shave * scale)
-        y_inter = fold_2(self.transpose(self.reshape(
+        y_inter = fold_2.compute(self.transpose(self.reshape(
             y_unfold, (y_unf_shape_0, -1, 1)), (2, 0, 1)))
-        y = cc_3((cc_3((y[:, :, :, :int(shave / 2 * scale)], cc_2((cc_2((y[:, :, :int(shave / 2 * scale), int(shave / 2 * scale):(w - w_cut) * scale - int(shave / 2 * scale)], y_inter)), y[:, :, (h - h_cut) * scale - int(shave / 2 * scale):, int(shave / 2 * scale):(w - w_cut) * scale - int(shave / 2 * scale)])))), y[:, :, :, (w - w_cut) * scale - int(shave / 2 * scale):])) #pylint: disable=line-too-long
-        y = cc_2((y[:, :, :y.shape[2] - int((padsize - h_cut) / 2 * scale), :],
-                  y_h_cut[:, :, int((padsize - h_cut) / 2 * scale + 0.5):, :]))
-        y_w_cat = cc_2((y_w_cut[:, :, :y_w_cut.shape[2] - int((padsize - h_cut) / 2 * scale), :],
-                        y_hw_cut[:, :, int((padsize - h_cut) / 2 * scale + 0.5):, :]))
-        y = cc_3((y[:, :, :, :y.shape[3] - int((padsize - w_cut) / 2 * scale)],
-                  y_w_cat[:, :, :, int((padsize - w_cut) / 2 * scale + 0.5):]))
+        concat1 = self.cc_2((y[:, :, :int(shave / 2 * scale), int(shave / 2 * scale):(w - w_cut) * scale - int(shave / 2 * scale)], y_inter)) #pylint: disable=line-too-long
+        concat2 = self.cc_2((concat1, y[:, :, (h - h_cut) * scale - int(shave / 2 * scale):, int(shave / 2 * scale):(w - w_cut) * scale - int(shave / 2 * scale)])) #pylint: disable=line-too-long
+        concat3 = self.cc_3((y[:, :, :, :int(shave / 2 * scale)], concat2))
+        y = self.cc_3((concat3, y[:, :, :, (w - w_cut) * scale - int(shave / 2 * scale):])) #pylint: disable=line-too-long
+        y = self.cc_2((y[:, :, :y.shape[2] - int((padsize - h_cut) / 2 * scale), :], y_h_cut[:, :, int((padsize - h_cut) / 2 * scale + 0.5):, :])) #pylint: disable=line-too-long
+
+        y_w_cat = self.cc_2((y_w_cut[:, :, :y_w_cut.shape[2] - int((padsize - h_cut) / 2 * scale), :],
+                             y_hw_cut[:, :, int((padsize - h_cut) / 2 * scale + 0.5):, :]))
+        y = self.cc_3((y[:, :, :, :y.shape[3] - int((padsize - w_cut) / 2 * scale)],
+                       y_w_cat[:, :, :, int((padsize - w_cut) / 2 * scale + 0.5):]))
+
         return y
 
     def cut_h_new(self, x_h_cut, h, w, h_cut, w_cut, padsize, shave, scale, batchsize):
         """ipt"""
         unf_1 = _stride_unfold_(padsize, stride=padsize - shave)
-        x_h_cut_unfold = unf_1(x_h_cut)
+        x_h_cut_unfold = unf_1.compute(x_h_cut)
         x_h_cut_unfold = self.transpose(x_h_cut_unfold, (1, 0, 2))
 
         x_h_cut_unfold = self.reshape(
             x_h_cut_unfold, (x_h_cut_unfold.shape[0], -1, padsize, padsize))
         x_range = x_h_cut_unfold.shape[0] // batchsize + \
             (x_h_cut_unfold.shape[0] % batchsize != 0)
-        cc_0 = P.Concat(axis=0)
         for i in range(x_range):
             if i == 0:
-                y_h_cut_unfold = self.construct(
+                y_h_cut_unfold = self.model(
                     x_h_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :])
             else:
                 y_h_cut_unfold = \
-                    cc_0((y_h_cut_unfold, self.construct(
+                    self.cc_0((y_h_cut_unfold, self.model(
                         x_h_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :])))
         y_h_cut_unfold_shape_0 = y_h_cut_unfold.shape[0]
         fold_1 = \
             _stride_fold_(padsize * scale,
                           output_shape=(padsize * scale,
                                         (w - w_cut) * scale),
                           stride=padsize * scale - shave * scale)
-        y_h_cut = fold_1(self.transpose(self.reshape(
+        y_h_cut = fold_1.compute(self.transpose(self.reshape(
             y_h_cut_unfold, (y_h_cut_unfold_shape_0, -1, 1)), (2, 0, 1)))
         y_h_cut_unfold = y_h_cut_unfold[:, :, :, int(
             shave / 2 * scale):padsize * scale - int(shave / 2 * scale)]
@@ -763,37 +783,35 @@
                               output_shape=(padsize * scale,
                                             (w - w_cut - shave) * scale),
                               stride=padsize * scale - shave * scale)
-        y_h_cut_inter = fold_2(self.transpose(self.reshape(
+        y_h_cut_inter = fold_2.compute(self.transpose(self.reshape(
             y_h_cut_unfold, (y_h_cut_unfold_shape_0, -1, 1)), (2, 0, 1)))
-        cc_3 = P.Concat(axis=3)
-        y_h_cut = cc_3((cc_3((y_h_cut[:, :, :, :int(shave / 2 * scale)],
-                              y_h_cut_inter)), y_h_cut[:, :, :, (w - w_cut) * scale - int(shave / 2 * scale):]))
+        concat1 = self.cc_3((y_h_cut[:, :, :, :int(shave / 2 * scale)], y_h_cut_inter))
+        y_h_cut = self.cc_3((concat1, y_h_cut[:, :, :, (w - w_cut) * scale - int(shave / 2 * scale):]))
         return y_h_cut
 
     def cut_w_new(self, x_w_cut, h, w, h_cut, w_cut, padsize, shave, scale, batchsize):
         """ipt"""
         unf_1 = _stride_unfold_(padsize, stride=padsize - shave)
-        x_w_cut_unfold = unf_1(x_w_cut)
+        x_w_cut_unfold = unf_1.compute(x_w_cut)
         x_w_cut_unfold = self.transpose(x_w_cut_unfold, (1, 0, 2))
 
         x_w_cut_unfold = self.reshape(
             x_w_cut_unfold, (x_w_cut_unfold.shape[0], -1, padsize, padsize))
         x_range = x_w_cut_unfold.shape[0] // batchsize + \
             (x_w_cut_unfold.shape[0] % batchsize != 0)
-        cc_0 = P.Concat(axis=0)
         for i in range(x_range):
             if i == 0:
-                y_w_cut_unfold = self.construct(
+                y_w_cut_unfold = self.model(
                     x_w_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :])
             else:
-                y_w_cut_unfold = cc_0((y_w_cut_unfold,
-                                       self.construct(x_w_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :])))
+                y_w_cut_unfold = self.cc_0((y_w_cut_unfold,
+                                            self.model(x_w_cut_unfold[i * batchsize:(i + 1) *
                                                                       batchsize, :, :, :])))
         y_w_cut_unfold_shape_0 = y_w_cut_unfold.shape[0]
         fold_1 = _stride_fold_(padsize * scale,
                                output_shape=((h - h_cut) * scale,
                                              padsize * scale),
                                stride=padsize * scale - shave * scale)
-        y_w_cut = fold_1(self.transpose(self.reshape(
+        y_w_cut = fold_1.compute(self.transpose(self.reshape(
             y_w_cut_unfold, (y_w_cut_unfold_shape_0, -1, 1)), (2, 0, 1)))
         y_w_cut_unfold = y_w_cut_unfold[:, :, int(
             shave / 2 * scale):padsize * scale - int(shave / 2 * scale), :]
@@ -801,19 +819,18 @@
                                output_shape=((h - h_cut - shave) * scale,
                                              padsize * scale),
                                stride=padsize * scale - shave * scale)
-        y_w_cut_inter = fold_2(self.transpose(self.reshape(
+        y_w_cut_inter = fold_2.compute(self.transpose(self.reshape(
             y_w_cut_unfold, (y_w_cut_unfold_shape_0, -1, 1)), (2, 0, 1)))
-        cc_2 = P.Concat(axis=2)
-        y_w_cut = cc_2((cc_2((y_w_cut[:, :, :int(shave / 2 * scale), :],
-                              y_w_cut_inter)), y_w_cut[:, :, (h - h_cut) * scale - int(shave / 2 * scale):, :]))
+        concat1 = self.cc_2((y_w_cut[:, :, :int(shave / 2 * scale), :], y_w_cut_inter))
+        y_w_cut = self.cc_2((concat1, y_w_cut[:, :, (h - h_cut) * scale - int(shave / 2 * scale):, :]))
         return y_w_cut
 
 
-class _stride_unfold_(nn.Cell):
-    """ipt"""
-
-    def __init__(
-            self, kernel_size, stride=-1):
+class _stride_unfold_():
+    '''stride'''
+
+    def __init__(self,
+                 kernel_size,
+                 stride=-1):
         super(_stride_unfold_, self).__init__()
 
         if stride == -1:
@@ -821,28 +838,24 @@
         else:
             self.stride = stride
         self.kernel_size = kernel_size
-        self.reshape = P.Reshape()
-        self.transpose = P.Transpose()
+        self.unfold = _unfold_(kernel_size)
 
-    def construct(self, x):
-        """ipt"""
+    def compute(self, x):
+        """stride"""
+        x = x.asnumpy()
         N, C, H, W = x.shape
         leftup_idx_x = []
         leftup_idx_y = []
-        nh = int((H - self.kernel_size) / self.stride + 1)
-        nw = int((W - self.kernel_size) / self.stride + 1)
+        nh = (H - self.kernel_size) // self.stride + 1
+        nw = (W - self.kernel_size) // self.stride + 1
         for i in range(nh):
             leftup_idx_x.append(i * self.stride)
         for i in range(nw):
             leftup_idx_y.append(i * self.stride)
         NumBlock_x = len(leftup_idx_x)
         NumBlock_y = len(leftup_idx_y)
-        zeroslike = P.ZerosLike()
-        cc_2 = P.Concat(axis=2)
-        cc_3 = P.Concat(axis=3)
-        unf_x = P.Zeros()((N, C, NumBlock_x * self.kernel_size,
-                           NumBlock_y * self.kernel_size), mstype.float32)
+        unf_x = np.zeros((N, C, NumBlock_x * self.kernel_size, NumBlock_y * self.kernel_size), dtype=np.float32)
         N, C, H, W = unf_x.shape
         for i in range(NumBlock_x):
             for j in range(NumBlock_y):
@@ -852,23 +865,28 @@
                 org_j = leftup_idx_y[j]
                 fills = x[:, :, org_i:org_i + self.kernel_size,
                           org_j:org_j + self.kernel_size]
-                unf_x += cc_3((cc_3((zeroslike(unf_x[:, :, :, :unf_j]),
-                                     cc_2(
-                                         (cc_2((zeroslike(unf_x[:, :, :unf_i, unf_j:unf_j + self.kernel_size]), fills)),
-                                          zeroslike(unf_x[:, :, unf_i + self.kernel_size:,
                                                          unf_j:unf_j + self.kernel_size]))))),
-                               zeroslike(unf_x[:, :, :, unf_j + self.kernel_size:])))
+                zeros2 = np.zeros(unf_x[:, :, :unf_i, unf_j:unf_j + self.kernel_size].shape)
+                concat1 = np.concatenate((zeros2, fills), axis=2)
+                zeros3 = np.zeros(unf_x[:, :, unf_i + self.kernel_size:, unf_j:unf_j + self.kernel_size].shape)
+                concat2 = np.concatenate((concat1, zeros3), axis=2)
+                zeros1 = np.zeros(unf_x[:, :, :, :unf_j].shape)
+                concat3 = np.concatenate((zeros1, concat2), axis=3)
+                zeros4 = np.zeros(unf_x[:, :, :, unf_j + self.kernel_size:].shape)
+                concat4 = np.concatenate((concat3, zeros4), axis=3)
+                unf_x += concat4
+        unf_x = Tensor(unf_x, mstype.float32)
         y = self.unfold(unf_x)
         return y
 
 
-class _stride_fold_(nn.Cell):
-    """ipt"""
+class _stride_fold_():
+    '''stride'''
 
-    def __init__(
-            self, kernel_size, output_shape=(-1, -1), stride=-1):
+    def __init__(self,
+                 kernel_size,
+                 output_shape=(-1, -1),
+                 stride=-1):
         super(_stride_fold_, self).__init__()
-
         if isinstance(kernel_size, (list, tuple)):
             self.kernel_size = kernel_size
         else:
@@ -880,66 +898,49 @@
             self.stride = stride
         self.output_shape = output_shape
 
-        self.reshape = P.Reshape()
-        self.transpose = P.Transpose()
-        self.fold = _fold_(kernel_size)
-
-    def construct(self, x):
-        """ipt"""
-        cc_2 = P.Concat(axis=2)
-        cc_3 = P.Concat(axis=3)
-        zeroslike = P.ZerosLike()
-        if self.output_shape[0] == -1:
-            large_x = self.fold(x)
-            N, C, H, _ = large_x.shape
-            leftup_idx = []
-            for i in range(0, H, self.kernel_size[0]):
-                leftup_idx.append(i)
-            NumBlock = len(leftup_idx)
-            fold_x = P.Zeros()((N, C, (NumBlock - 1) * self.stride + self.kernel_size[0],
-                                (NumBlock - 1) * self.stride + self.kernel_size[0]), mstype.float32)
-
-            for i in range(NumBlock):
-                for j in range(NumBlock):
-                    fold_i = i * self.stride
-                    fold_j = j * self.stride
-                    org_i = leftup_idx[i]
-                    org_j = leftup_idx[j]
-                    fills = large_x[:, :, org_i:org_i + self.kernel_size[0],
-                                    org_j:org_j + self.kernel_size[1]]
-                    fold_x += cc_3((cc_3((zeroslike(fold_x[:, :, :, :fold_j]), cc_2((cc_2((zeroslike(fold_x[:, :, :fold_i, fold_j:fold_j + self.kernel_size[1]]), fills)), zeroslike(fold_x[:, :, fold_i + self.kernel_size[0]:, fold_j:fold_j + self.kernel_size[1]]))))), zeroslike(fold_x[:, :, :, fold_j + self.kernel_size[1]:]))) #pylint: disable=line-too-long
-            y = fold_x
-        else:
-            NumBlock_x = int(
-                (self.output_shape[0] - self.kernel_size[0]) / self.stride + 1)
-            NumBlock_y = int(
-                (self.output_shape[1] - self.kernel_size[1]) / self.stride + 1)
-            large_shape = [NumBlock_x * self.kernel_size[0],
-                           NumBlock_y * self.kernel_size[1]]
-            self.fold = _fold_(self.kernel_size, large_shape)
-            large_x = self.fold(x)
-            N, C, H, _ = large_x.shape
-            leftup_idx_x = []
-            leftup_idx_y = []
-            for i in range(NumBlock_x):
-                leftup_idx_x.append(i * self.kernel_size[0])
-            for i in range(NumBlock_y):
-                leftup_idx_y.append(i * self.kernel_size[1])
-            fold_x = P.Zeros()((N, C, (NumBlock_x - 1) * self.stride + self.kernel_size[0],
-                                (NumBlock_y - 1) * self.stride + self.kernel_size[1]), mstype.float32)
-            for i in range(NumBlock_x):
-                for j in range(NumBlock_y):
-                    fold_i = i * self.stride
-                    fold_j = j * self.stride
-                    org_i = leftup_idx_x[i]
-                    org_j = leftup_idx_y[j]
-                    fills = large_x[:, :, org_i:org_i + self.kernel_size[0],
-                                    org_j:org_j + self.kernel_size[1]]
-                    fold_x += cc_3((cc_3((zeroslike(fold_x[:, :, :, :fold_j]), cc_2((cc_2((zeroslike(fold_x[:, :, :fold_i, fold_j:fold_j + self.kernel_size[1]]), fills)), zeroslike(fold_x[:, :, fold_i + self.kernel_size[0]:, fold_j:fold_j + self.kernel_size[1]]))))), zeroslike(fold_x[:, :, :, fold_j + self.kernel_size[1]:]))) #pylint: disable=line-too-long
-            y = fold_x
+        self.NumBlock_x = (self.output_shape[0] - self.kernel_size[0]) // self.stride + 1
+        self.NumBlock_y = (self.output_shape[1] - self.kernel_size[1]) // self.stride + 1
+        self.large_shape = [self.NumBlock_x * self.kernel_size[0], self.NumBlock_y * self.kernel_size[1]]
+        self.fold = _fold_(self.kernel_size, self.large_shape)
+
+    def compute(self, x):
+        '''stride'''
+        NumBlock_x = self.NumBlock_x
+        NumBlock_y = self.NumBlock_y
+        large_x = self.fold(x)
+        large_x = large_x.asnumpy()
+        N, C, _, _ = large_x.shape
+        leftup_idx_x = []
+        leftup_idx_y = []
+        for i in range(NumBlock_x):
+            leftup_idx_x.append(i * self.kernel_size[0])
+        for i in range(NumBlock_y):
+            leftup_idx_y.append(i * self.kernel_size[1])
+        fold_x = np.zeros((N, C, (NumBlock_x - 1) * self.stride + self.kernel_size[0], (NumBlock_y - 1) * self.stride + self.kernel_size[1]), dtype=np.float32) #pylint: disable=line-too-long
+        for i in range(NumBlock_x):
+            for j in range(NumBlock_y):
+                fold_i = i * self.stride
+                fold_j = j * self.stride
+                org_i = leftup_idx_x[i]
+                org_j = leftup_idx_y[j]
+                fills = large_x[:, :, org_i:org_i + self.kernel_size[0], org_j:org_j + self.kernel_size[1]]
+                t2 = fold_x[:, :, :fold_i, fold_j:fold_j + self.kernel_size[1]]
+                zeros2 = np.zeros(t2.shape)
+                concat1 = np.concatenate((zeros2, fills), axis=2)
+                t3 = fold_x[:, :, fold_i + self.kernel_size[0]:, fold_j:fold_j + self.kernel_size[1]]
+                zeros3 = np.zeros(t3.shape)
+                concat2 = np.concatenate((concat1, zeros3), axis=2)
+                t1 = fold_x[:, :, :, :fold_j]
+                zeros1 = np.zeros(t1.shape)
+                concat3 = np.concatenate((zeros1, concat2), axis=3)
+                t4 = fold_x[:, :, :, fold_j + self.kernel_size[1]:]
+                zeros4 = np.zeros(t4.shape)
+                concat4 = np.concatenate((concat3, zeros4), axis=3)
+                fold_x += concat4
+        y = Tensor(fold_x, mstype.float32)
         return y
 
-
 class _unfold_(nn.Cell):
     """ipt"""
@@ -957,20 +958,16 @@
     def construct(self, x):
         """ipt"""
         N, C, H, W = x.shape
-        numH = int(H / self.kernel_size)
-        numW = int(W / self.kernel_size)
+        numH = H // self.kernel_size
+        numW = W // self.kernel_size
         if numH * self.kernel_size != H or numW * self.kernel_size != W:
             x = x[:, :, :numH * self.kernel_size, :numW * self.kernel_size]
         output_img = self.reshape(x, (N, C, numH, self.kernel_size, W))
 
         output_img = self.transpose(output_img, (0, 1, 2, 4, 3))
-
-        output_img = self.reshape(output_img, (N, C, int(
-            numH * numW), self.kernel_size, self.kernel_size))
-
-        output_img = self.transpose(output_img, (0, 2, 1, 4, 3))
-
-        output_img = self.reshape(output_img, (N, int(numH * numW), -1))
+        output_img = self.reshape(output_img, (N, C, numH, -1, self.kernel_size, self.kernel_size))
+        output_img = self.transpose(output_img, (0, 2, 3, 1, 5, 4))
+        output_img = self.reshape(output_img, (N, numH * numW, -1))
         return output_img
 
@@ -994,22 +991,17 @@
 
         self.reshape = P.Reshape()
         self.transpose = P.Transpose()
+        self.sqrt = P.Sqrt()
+        self.cast = P.Cast()
 
     def construct(self, x):
         """ipt"""
         N, C, L = x.shape
-        org_C = int(L / self.kernel_size[0] / self.kernel_size[1])
-        if self.output_shape[0] == -1:
-            numH = int(np.sqrt(C))
-            numW = int(np.sqrt(C))
-            org_H = int(numH * self.kernel_size[0])
-            org_W = org_H
-        else:
-            org_H = int(self.output_shape[0])
-            org_W = int(self.output_shape[1])
-            numH = int(org_H / self.kernel_size[0])
-            numW = int(org_W / self.kernel_size[1])
-
+        org_C = L // (self.kernel_size[0] * self.kernel_size[1])
+        org_H = self.output_shape[0]
+        org_W = self.output_shape[1]
+        numH = org_H // self.kernel_size[0]
+        numW = org_W // self.kernel_size[1]
         output_img = self.reshape(
             x, (N, C, org_C, self.kernel_size[0], self.kernel_size[1]))
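
A minimal usage sketch of the inference path this patch introduces, mirroring eval.py above. Here net_m, args, and lr are assumed to be the trained IPT network, the parsed options, and one low-resolution NCHW batch from the dataset, exactly as set up in that script:

    # Wrap the trained network in the IPT_post helper and run the tiled
    # unfold/forward/fold pass on one batch. The set_scale call is optional
    # here, since scale_idx defaults to 0 (i.e. args.scale[0]).
    inference = ipt.IPT_post(net_m, args)
    inference.set_scale(0)
    pred = inference.forward(lr)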