support img2col for resnet50_thor GPU

primitive for im2col

fix bug

clang code format

clang format fix

fix pylint

fix license

delete useless code
pull/3924/head
mamba_ni 5 years ago
parent 657b547116
commit 4fce4c7c34

@ -0,0 +1,40 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "identity_impl.cuh"
#include <iostream>
template <typename T>
__global__ void IdentityKernel(const size_t size, const size_t dim, T *output_addr) {
for (size_t pointIdx = blockIdx.x * blockDim.x + threadIdx.x; pointIdx < (size); pointIdx += blockDim.x * gridDim.x) {
size_t batchIdx = pointIdx / (dim * dim);
size_t dst_x = (pointIdx - batchIdx * dim * dim) / dim;
size_t dst_y = (pointIdx - batchIdx * dim * dim) % dim;
if (dst_x == dst_y) {
output_addr[pointIdx] = 1;
} else {
output_addr[pointIdx] = 0;
}
}
}
template <typename T>
void Identity(const size_t size, const size_t dim, T *output_addr, cudaStream_t cuda_stream) {
IdentityKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, dim, output_addr);
return;
}
template void Identity<float>(const size_t size, const size_t dim, float *output_addr, cudaStream_t cuda_stream);

@ -0,0 +1,24 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_IDENTITY_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_IDENTITY_H_
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void Identity(const size_t size, const size_t dim, T *output_addr, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MATRIXSPLIT_H_

@ -0,0 +1,72 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "matrix_combine_impl.cuh"
#include <iostream>
template <typename T>
__global__ void MatrixCombineKernel(const size_t size, const size_t src_height, const size_t src_width,
const size_t dst_width, T *input_addr, T *output_addr) {
for (size_t pointIdx = blockIdx.x * blockDim.x + threadIdx.x; pointIdx < (size); pointIdx += blockDim.x * gridDim.x) {
size_t batchIdx = pointIdx / (src_height * src_width);
size_t src_h = (pointIdx - batchIdx * src_height * src_width) / src_width;
size_t src_w = (pointIdx - batchIdx * src_height * src_width) % src_width;
size_t dst_h = src_height * batchIdx + src_h;
size_t dst_w = src_width * batchIdx + src_w;
output_addr[dst_h * dst_width + dst_w] = input_addr[pointIdx];
}
}
template <typename T>
__global__ void MatrixCombineKernel(const size_t size, const size_t src_height, const size_t src_width,
const size_t dst_width, const size_t res_width, const size_t batch, T *input_addr,
T *output_addr) {
for (size_t pointIdx = blockIdx.x * blockDim.x + threadIdx.x; pointIdx < (size); pointIdx += blockDim.x * gridDim.x) {
size_t batchIdx = pointIdx / (src_height * src_width);
if (batchIdx != (batch - 1)) {
size_t src_h = (pointIdx - batchIdx * src_height * src_width) / src_width;
size_t src_w = (pointIdx - batchIdx * src_height * src_width) % src_width;
size_t dst_h = src_height * batchIdx + src_h;
size_t dst_w = src_width * batchIdx + src_w;
output_addr[dst_h * dst_width + dst_w] = input_addr[pointIdx];
} else {
size_t src_h = (pointIdx - (batch - 1) * src_height * src_width) / res_width;
size_t src_w = (pointIdx - (batch - 1) * src_height * src_width) % res_width;
size_t src_coordinate = (batch - 1) * src_height * src_width + src_h * src_width + src_w;
size_t dst_h = src_height * (batch - 1) + src_h;
size_t dst_w = src_width * (batch - 1) + src_w;
output_addr[dst_h * dst_width + dst_w] = input_addr[src_coordinate];
}
}
}
template <typename T>
void MatrixCombine(const size_t size, const size_t src_height, const size_t src_width, const size_t dst_width,
const size_t residual, const size_t res_width, const size_t batch, T *input_addr, T *output_addr,
cudaStream_t cuda_stream) {
if (residual == 0) {
MatrixCombineKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, src_height, src_width, dst_width,
input_addr, output_addr);
} else {
MatrixCombineKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, src_height, src_width, dst_width,
res_width, batch, input_addr, output_addr);
}
return;
}
template void MatrixCombine<float>(const size_t size, const size_t src_height, const size_t src_width,
const size_t dst_width, const size_t residual, const size_t res_width,
const size_t batch, float *input_addr, float *output_addr, cudaStream_t cuda_stream);

@ -0,0 +1,27 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MATRIXCOMBINE_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MATRIXCOMBINE_H_
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void MatrixCombine(const size_t size, const size_t src_height, const size_t src_width, const size_t dst_width,
const size_t residual, const size_t res_width, const size_t batch, T *input_addr, T *output_addr,
cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MATRIXCOMBINE_H_

@ -0,0 +1,70 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "matrix_split_impl.cuh"
#include <iostream>
template <typename T>
__global__ void MatrixSplitKernel(const size_t size, const size_t split_dim, const size_t dim, T *input_addr,
T *output_addr) {
for (size_t pointIdx = blockIdx.x * blockDim.x + threadIdx.x; pointIdx < (size); pointIdx += blockDim.x * gridDim.x) {
size_t batchIdx = pointIdx / (split_dim * split_dim);
size_t dst_x = (pointIdx - batchIdx * split_dim * split_dim) / split_dim;
size_t dst_y = (pointIdx - batchIdx * split_dim * split_dim) % split_dim;
size_t src_coordinate = (batchIdx * split_dim + dst_x) * dim + batchIdx * split_dim + dst_y;
output_addr[pointIdx] = input_addr[src_coordinate];
}
}
template <typename T>
__global__ void MatrixSplitKernel(const size_t size, const size_t split_dim, const size_t dim, const size_t res_dim,
T *input_addr, T *output_addr) {
for (size_t pointIdx = blockIdx.x * blockDim.x + threadIdx.x; pointIdx < (size); pointIdx += blockDim.x * gridDim.x) {
size_t batchIdx = pointIdx / (split_dim * split_dim);
size_t dst_x = (pointIdx - batchIdx * split_dim * split_dim) / split_dim;
size_t dst_y = (pointIdx - batchIdx * split_dim * split_dim) % split_dim;
size_t src_coordinate = (batchIdx * split_dim + dst_x) * dim + batchIdx * split_dim + dst_y;
size_t batch_lower = dim / split_dim;
if (batchIdx < batch_lower) {
output_addr[pointIdx] = input_addr[src_coordinate];
} else {
if (dst_x < res_dim && dst_y < res_dim) {
output_addr[pointIdx] = input_addr[src_coordinate];
} else if (dst_x == dst_y) {
output_addr[pointIdx] = 1;
} else {
output_addr[pointIdx] = 0;
}
}
}
}
template <typename T>
void MatrixSplit(const size_t size, const size_t split_dim, const size_t dim, T *input_addr, T *output_addr,
cudaStream_t cuda_stream) {
size_t batch = dim / split_dim;
size_t res_dim = dim - batch * split_dim;
if (res_dim == 0) {
MatrixSplitKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, split_dim, dim, input_addr, output_addr);
} else {
MatrixSplitKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, split_dim, dim, res_dim, input_addr,
output_addr);
}
return;
}
template void MatrixSplit<float>(const size_t size, const size_t split_dim, const size_t dim, float *input_addr,
float *output_addr, cudaStream_t cuda_stream);

@ -0,0 +1,25 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MATRIXSPLIT_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MATRIXSPLIT_H_
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void MatrixSplit(const size_t size, const size_t split_dim, const size_t dim, T *input_addr, T *output_addr,
cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MATRIXSPLIT_H_

@ -0,0 +1,26 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/nn/im2col_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(Im2Col, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
Im2ColGpuFwdKernel, float)
MS_REG_GPU_KERNEL_ONE(Im2Col, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
Im2ColGpuFwdKernel, half)
} // namespace kernel
} // namespace mindspore

@ -83,7 +83,10 @@ from . import _quant_ops
from ._quant_ops import *
from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, PopulationCount,
CheckValid, MakeRefKey, Partial, Depend, CheckBprop, Push, Pull)
from .thor_ops import *
from .thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg2Col, CusMatMulCubeDenseLeft,
CusMatMulCubeFraczRightMul, CusMatMulCube, CusMatrixCombine, CusTranspose02314,
CusMatMulCubeDenseRight,
CusMatMulCubeFraczLeftCast, Im2Col)
from .sparse_ops import SparseToDense
__all__ = [

@ -13,9 +13,12 @@
# limitations under the License.
# ============================================================================
"""thor_ops"""
import math
from ..primitive import prim_attr_register, PrimitiveWithInfer
from ...common import dtype as mstype
from ..._checkparam import Validator as validator
from ..._checkparam import Rel
__all__ = ["CusBatchMatMul",
"CusCholeskyTrsm",
@ -31,6 +34,37 @@ __all__ = ["CusBatchMatMul",
]
def _check_positive_int_or_tuple(arg_name, arg_value, prim_name, allow_four=False, ret_four=False):
"""
Checks whether an argument is a positive int or tuple with 2 or 4(when allow_four is True) positive int elements.
"""
def _raise_message():
raise ValueError(f"For '{prim_name}' attr '{arg_name}' should be an positive int number or a tuple of two "
f"{'or four ' if allow_four else ''}positive int numbers, but got {arg_value}")
def _get_return_value():
if isinstance(arg_value, int):
ret = (1, 1, arg_value, arg_value) if ret_four else (arg_value, arg_value)
elif len(arg_value) == 2:
ret = (1, 1, arg_value[0], arg_value[1]) if ret_four else arg_value
elif len(arg_value) == 4:
if not allow_four:
_raise_message()
ret = arg_value if ret_four else (arg_value[2], arg_value[3])
else:
_raise_message()
return ret
validator.check_value_type(arg_name, arg_value, (int, tuple), prim_name)
ret_value = _get_return_value()
for item in ret_value:
if isinstance(item, int) and item > 0:
continue
_raise_message()
return ret_value
class CusBatchMatMul(PrimitiveWithInfer):
"""
Multiplies matrix `a` by matrix `b` in batch.
@ -360,6 +394,7 @@ class CusTranspose02314(PrimitiveWithInfer):
"""init CusTranspose02314"""
self.init_prim_io_names(inputs=['x1'], outputs=['y'])
from mindspore.ops._op_impl._custom_op.transpose02314_impl import CusTranspose02314
def get_bprop(self):
def bprop(x, out, dout):
return (C.zeros_like(x),)
@ -446,3 +481,84 @@ class CusMatMulCubeFraczLeftCast(PrimitiveWithInfer):
def infer_dtype(self, data1_dtype, data2_dtype):
return mstype.float16
class Im2Col(PrimitiveWithInfer):
"""
extract image pathes from image.
The rank of input_x1 must be `4`, data_format is "NCHW".
Inputs:
- **input_x1** (Tensor) - The feature map.
The shape of the tensor is :math:`(N, C, H, W)`.
Outputs:
Tensor.
Examples:
>>> input_x = Tensor(np.random.rand(32, 3, 224, 224).astype(np.float16))
>>> img2col = P.CusMatMulCubeDenseLeft(kernel_size=7, pad=3, stride=2)
>>> output = img2col(input_x)
"""
@prim_attr_register
def __init__(self,
kernel_size,
pad_mode="valid",
pad=0,
stride=1,
dilation=1):
"""init Im2Col"""
self.init_prim_io_names(inputs=['x'], outputs=['output'])
self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name)
self.add_prim_attr('kernel_size', self.kernel_size)
self.stride = _check_positive_int_or_tuple('stride', stride, self.name, allow_four=True, ret_four=True)
self.add_prim_attr('stride', self.stride)
self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True)
self.add_prim_attr('dilation', self.dilation)
validator.check_value_type('pad', pad, (int,), self.name)
self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name)
self.pad = validator.check_pad_value_by_mode(pad_mode, pad, self.name)
if self.pad_mode == 'pad':
validator.check_integer('pad', self.pad, 0, Rel.GE, self.name)
self.add_prim_attr('data_format', "NCHW")
def infer_shape(self, x_shape):
validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name)
kernel_size_h = self.kernel_size[0]
kernel_size_w = self.kernel_size[1]
stride_h = self.stride[2]
stride_w = self.stride[3]
dilation_h = self.dilation[2]
dilation_w = self.dilation[3]
if self.pad_mode == "valid":
h_out = math.ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h)
w_out = math.ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w)
pad_top, pad_bottom, pad_left, pad_right = 0, 0, 0, 0
elif self.pad_mode == "same":
h_out = math.ceil(x_shape[2] / stride_h)
w_out = math.ceil(x_shape[3] / stride_w)
pad_needed_h = max(0, (h_out - 1) * stride_h + dilation_h * (kernel_size_h - 1) + 1 - x_shape[2])
pad_top = math.floor(pad_needed_h / 2)
pad_bottom = pad_needed_h - pad_top
pad_needed_w = max(0, (w_out - 1) * stride_w + dilation_w * (kernel_size_w - 1) + 1 - x_shape[3])
pad_left = math.floor(pad_needed_w / 2)
pad_right = pad_needed_w - pad_left
elif self.pad_mode == 'pad':
pad_top, pad_bottom, pad_left, pad_right = self.pad, self.pad, self.pad, self.pad
h_out = 1 + (x_shape[2] + 2 * self.pad - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) / stride_h
w_out = 1 + (x_shape[3] + 2 * self.pad - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) / stride_w
h_out = math.floor(h_out)
w_out = math.floor(w_out)
self.pad_list = [pad_top, pad_bottom, pad_left, pad_right]
self.add_prim_attr('pad_list', (pad_top, pad_bottom, pad_left, pad_right))
batch_size = x_shape[0]
channel = x_shape[1]
k_h = kernel_size_h
k_w = kernel_size_w
out_shape = [channel, k_h, k_w, batch_size, h_out, w_out]
return out_shape
def infer_dtype(self, x_dtype):
args = {'x': x_dtype}
valid_types = [mstype.float16, mstype.float32]
validator.check_tensor_type_same(args, valid_types, self.name)
return x_dtype

Loading…
Cancel
Save