parent
f47767b361
commit
fc182bf315
@ -1,90 +0,0 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# less required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Parameter
|
||||
from mindspore import dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops.operations import nn_ops as nps
|
||||
from mindspore.common.initializer import initializer
|
||||
|
||||
def weight_variable(shape):
|
||||
init_value = initializer('Normal', shape, mstype.float32)
|
||||
return Parameter(init_value)
|
||||
|
||||
class Conv3D(nn.Cell):
|
||||
def __init__(self,
|
||||
in_channel,
|
||||
out_channel,
|
||||
kernel_size,
|
||||
mode=1,
|
||||
pad_mode="valid",
|
||||
pad=0,
|
||||
stride=1,
|
||||
dilation=1,
|
||||
group=1,
|
||||
data_format="NCDHW",
|
||||
bias_init="zeros",
|
||||
has_bias=True):
|
||||
super().__init__()
|
||||
self.weight_shape = (out_channel, in_channel, kernel_size[0], kernel_size[1], kernel_size[2])
|
||||
self.weight = weight_variable(self.weight_shape)
|
||||
self.conv = nps.Conv3D(out_channel=out_channel, kernel_size=kernel_size, mode=mode, \
|
||||
pad_mode=pad_mode, pad=pad, stride=stride, dilation=dilation, \
|
||||
group=group, data_format=data_format)
|
||||
self.bias_init = bias_init
|
||||
self.has_bias = has_bias
|
||||
self.bias_add = P.BiasAdd(data_format=data_format)
|
||||
if self.has_bias:
|
||||
self.bias = Parameter(initializer(self.bias_init, [out_channel]), name='bias')
|
||||
|
||||
def construct(self, x):
|
||||
output = self.conv(x, self.weight)
|
||||
if self.has_bias:
|
||||
output = self.bias_add(output, self.bias)
|
||||
return output
|
||||
|
||||
class Conv3DTranspose(nn.Cell):
|
||||
def __init__(self,
|
||||
in_channel,
|
||||
out_channel,
|
||||
kernel_size,
|
||||
mode=1,
|
||||
pad=0,
|
||||
stride=1,
|
||||
dilation=1,
|
||||
group=1,
|
||||
output_padding=0,
|
||||
data_format="NCDHW",
|
||||
bias_init="zeros",
|
||||
has_bias=True):
|
||||
super().__init__()
|
||||
self.weight_shape = (in_channel, out_channel, kernel_size[0], kernel_size[1], kernel_size[2])
|
||||
self.weight = weight_variable(self.weight_shape)
|
||||
self.conv_transpose = nps.Conv3DTranspose(in_channel=in_channel, out_channel=out_channel,\
|
||||
kernel_size=kernel_size, mode=mode, pad=pad, stride=stride, \
|
||||
dilation=dilation, group=group, output_padding=output_padding, \
|
||||
data_format=data_format)
|
||||
self.bias_init = bias_init
|
||||
self.has_bias = has_bias
|
||||
self.bias_add = P.BiasAdd(data_format=data_format)
|
||||
if self.has_bias:
|
||||
self.bias = Parameter(initializer(self.bias_init, [out_channel]), name='bias')
|
||||
|
||||
def construct(self, x):
|
||||
output = self.conv_transpose(x, self.weight)
|
||||
if self.has_bias:
|
||||
output = self.bias_add(output, self.bias)
|
||||
return output
|
Loading…
Reference in new issue