@@ -19,11 +19,12 @@ import mindspore.nn as nn
 from mindspore.ops import operations as P
 from mindspore.common.tensor import Tensor
 from mindspore.ops import functional as F
+import mindspore.common.dtype as mstype
 
 
 def weight_init_ones(shape):
     """Weight init."""
-    return Tensor(np.array(np.ones(shape).astype(np.float32) * 0.01).astype(np.float16))
+    return Tensor(np.array(np.ones(shape).astype(np.float32) * 0.01).astype(np.float32))
 
 
 def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mode='pad'):
@@ -32,15 +33,15 @@ def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0, pad_mod
     weights = weight_init_ones(shape)
     return nn.Conv2d(in_channels, out_channels,
                      kernel_size=kernel_size, stride=stride, padding=padding,
-                     pad_mode=pad_mode, weight_init=weights, has_bias=False)
+                     pad_mode=pad_mode, weight_init=weights, has_bias=False).to_float(mstype.float16)
 
 
 def _BatchNorm2dInit(out_chls, momentum=0.1, affine=True, use_batch_statistics=True):
     """Batchnorm2D wrapper."""
-    gamma_init = Tensor(np.array(np.ones(out_chls)).astype(np.float16))
-    beta_init = Tensor(np.array(np.ones(out_chls) * 0).astype(np.float16))
-    moving_mean_init = Tensor(np.array(np.ones(out_chls) * 0).astype(np.float16))
-    moving_var_init = Tensor(np.array(np.ones(out_chls)).astype(np.float16))
+    gamma_init = Tensor(np.array(np.ones(out_chls)).astype(np.float32))
+    beta_init = Tensor(np.array(np.ones(out_chls) * 0).astype(np.float32))
+    moving_mean_init = Tensor(np.array(np.ones(out_chls) * 0).astype(np.float32))
+    moving_var_init = Tensor(np.array(np.ones(out_chls)).astype(np.float32))
     return nn.BatchNorm2d(out_chls, momentum=momentum, affine=affine, gamma_init=gamma_init,
                           beta_init=beta_init, moving_mean_init=moving_mean_init,