update thor gradient op

fix bug update thor gradient op

fix bug

fix format

fix format

fix docstring

fix ops docstring

change algo type

pylint
pull/3991/head
mamba_ni 5 years ago
parent 7c03073143
commit c1dbc5a090

@@ -0,0 +1,100 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "convert_gradient_impl.cuh"
template <typename T>
__global__ void ConvertGradientKernel(const size_t size, const size_t height_h, const size_t height_w,
const size_t batchwidth, const size_t width, T *input_addr, T *output_addr) {
for (size_t pointIdx = blockIdx.x * blockDim.x + threadIdx.x; pointIdx < (size); pointIdx += blockDim.x * gridDim.x) {
size_t dst_batchIdx = pointIdx / (height_h * height_w);
size_t dst_batchIdxX = dst_batchIdx / batchwidth;
size_t dst_batchIdxY = dst_batchIdx % batchwidth;
size_t dst_x = (pointIdx - dst_batchIdx * height_h * height_w) / height_w;
size_t dst_y = (pointIdx - dst_batchIdx * height_h * height_w) % height_w;
size_t src_coordinate = dst_batchIdxX * height_h * width + dst_x * width + dst_batchIdxY * height_w + dst_y;
output_addr[pointIdx] = input_addr[src_coordinate];
}
}
template <typename T>
__global__ void ConvertGradientBackKernel(const size_t size, const size_t height_h, const size_t height_w,
const size_t batchwidth, const size_t width, T *input_addr, T *output_addr) {
for (size_t pointIdx = blockIdx.x * blockDim.x + threadIdx.x; pointIdx < (size); pointIdx += blockDim.x * gridDim.x) {
size_t dst_batchIdx = pointIdx / (height_h * height_w);
size_t dst_batchIdxX = dst_batchIdx / batchwidth;
size_t dst_batchIdxY = dst_batchIdx % batchwidth;
size_t dst_x = (pointIdx - dst_batchIdx * height_h * height_w) / height_w;
size_t dst_y = (pointIdx - dst_batchIdx * height_h * height_w) % height_w;
size_t src_coordinate = dst_batchIdxX * height_h * width + dst_x * width + dst_batchIdxY * height_w + dst_y;
output_addr[src_coordinate] = input_addr[pointIdx];
}
}
template <typename T>
__global__ void ConvertGradientBackKernel(const size_t size, const size_t height_h, const size_t height_w,
const size_t ori_h, const size_t ori_w, const size_t batchwidth,
const size_t width, T *input_addr, T *output_addr) {
for (size_t pointIdx = blockIdx.x * blockDim.x + threadIdx.x; pointIdx < (size); pointIdx += blockDim.x * gridDim.x) {
size_t dst_batchIdx = pointIdx / (height_h * height_w);
size_t dst_batchIdxX = dst_batchIdx / batchwidth;
size_t dst_batchIdxY = dst_batchIdx % batchwidth;
size_t dst_x = (pointIdx - dst_batchIdx * height_h * height_w) / height_w;
size_t dst_y = (pointIdx - dst_batchIdx * height_h * height_w) % height_w;
size_t src_x = dst_batchIdxX * height_h + dst_x;
size_t src_y = dst_batchIdxY * height_w + dst_y;
if (src_x < ori_h && src_y < ori_w) {
size_t src_coordinate = src_x * ori_w + src_y;
output_addr[src_coordinate] = input_addr[pointIdx];
}
}
}
template <typename T>
void ConvertGradient(const size_t size, const size_t height_h, const size_t height_w, const size_t batchwidth,
const size_t width, T *input_addr, T *output_addr, cudaStream_t cuda_stream) {
ConvertGradientKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, height_h, height_w, batchwidth, width,
input_addr, output_addr);
}
template <typename T>
void ConvertGradientBack(const size_t size, const size_t height_h, const size_t height_w, const size_t batchwidth,
const size_t width, T *input_addr, T *output_addr, cudaStream_t cuda_stream) {
ConvertGradientBackKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, height_h, height_w, batchwidth,
width, input_addr, output_addr);
}
template <typename T>
void ConvertGradientBack(const size_t size, const size_t height_h, const size_t height_w, const size_t ori_h,
const size_t ori_w, const size_t batchwidth, const size_t width, T *input_addr, T *output_addr,
cudaStream_t cuda_stream) {
ConvertGradientBackKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(
size, height_h, height_w, ori_h, ori_w, batchwidth, width, input_addr, output_addr);
}
template void ConvertGradient<float>(const size_t size, const size_t height_h, const size_t height_w,
const size_t batchwidth, const size_t width, float *input_addr, float *output_addr,
cudaStream_t cuda_stream);
template void ConvertGradientBack<float>(const size_t size, const size_t height_h, const size_t height_w,
const size_t batchwidth, const size_t width, float *input_addr,
float *output_addr, cudaStream_t cuda_stream);
template void ConvertGradientBack<float>(const size_t size, const size_t height_h, const size_t height_w,
const size_t ori_h, const size_t ori_w, const size_t batchwidth,
const size_t width, float *input_addr, float *output_addr,
cudaStream_t cuda_stream);
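
The index arithmetic above is just a blocked layout change. Below is a minimal NumPy sketch of the mapping, assuming a row-major (batch_h*h, batch_w*w) source matrix; the helper names are illustrative and not part of the kernel interface (the kernel passes batchwidth = batch_w and width = batch_w * w rather than batch_h/batch_w explicitly):

import numpy as np

def convert_gradient_ref(x, h, w, batch_h, batch_w):
    # Tile a row-major (batch_h*h, batch_w*w) matrix into a batch of
    # (h, w) blocks, ordered row-block first, then column-block.
    return (x.reshape(batch_h, h, batch_w, w)
             .transpose(0, 2, 1, 3)
             .reshape(batch_h * batch_w, h, w))

def convert_gradient_back_ref(blocks, h, w, batch_h, batch_w, ori_h, ori_w):
    # Inverse mapping of the padded variant: reassemble the blocks into the
    # full padded matrix, then drop the padded rows/columns.
    full = (blocks.reshape(batch_h, batch_w, h, w)
                  .transpose(0, 2, 1, 3)
                  .reshape(batch_h * h, batch_w * w))
    return full[:ori_h, :ori_w]

The non-padded ConvertGradientBackKernel performs the same reassembly without the final crop.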

@@ -0,0 +1,34 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_CONVERTGRADIENT_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_CONVERTGRADIENT_H_
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void ConvertGradient(const size_t size, const size_t height_h, const size_t height_w, const size_t batchwidth,
const size_t width, T *input_addr, T *output_addr, cudaStream_t cuda_stream);
template <typename T>
void ConvertGradientBack(const size_t size, const size_t height_h, const size_t height_w, const size_t batchwidth,
const size_t width, T *input_addr, T *output_addr, cudaStream_t cuda_stream);
template <typename T>
void ConvertGradientBack(const size_t size, const size_t height_h, const size_t height_w, const size_t ori_h,
const size_t ori_w, const size_t batchwidth, const size_t width, T *input_addr, T *output_addr,
cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_CONVERTGRADIENT_H_

@@ -0,0 +1,28 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/math/update_thor_gradient.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(UpdateThorGradient,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
UpdateThorGradientGpuKernel, float)
} // namespace kernel
} // namespace mindspore

@@ -0,0 +1,241 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_UPDATE_THOR_GRADIENT_GPU_KERNEL_H
#define MINDSPORE_UPDATE_THOR_GRADIENT_GPU_KERNEL_H
#include <cublas_v2.h>
#include <cuda_runtime_api.h>
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/kernel_constants.h"
#include "backend/kernel_compiler/gpu/cuda_impl/convert_gradient_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/pad_impl.cuh"
#include "utils/convert_utils.h"
namespace mindspore {
namespace kernel {
struct GradientSize {
size_t batch_h;
size_t batch_w;
size_t h;
size_t w;
size_t ori_h;
size_t ori_w;
size_t pad_h;
size_t pad_w;
bool need_convert;
cudaDataType_t dtype;
};
template <typename T>
class UpdateThorGradientGpuKernel : public GpuKernel {
public:
UpdateThorGradientGpuKernel() : split_dim(128) {}
~UpdateThorGradientGpuKernel() = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
auto input1_addr = GetDeviceAddress<T>(inputs, 0);
auto input2_addr = GetDeviceAddress<T>(inputs, 1);
auto input3_addr = GetDeviceAddress<T>(inputs, 2);
auto workspace1_addr = GetDeviceAddress<T>(workspace, 0);
T *workspace2_addr = nullptr;
T *workspace3_addr = nullptr;
if (gradient_size.need_convert) {
workspace2_addr = GetDeviceAddress<T>(workspace, 1);
workspace3_addr = GetDeviceAddress<T>(workspace, 2);
}
T *workspace4_addr = nullptr;
auto output_addr = GetDeviceAddress<T>(outputs, 0);
if (gradient_size.pad_h != 0 || gradient_size.pad_w != 0) {
workspace4_addr = GetDeviceAddress<T>(workspace, 3);
const size_t size = (gradient_size.ori_h + gradient_size.pad_h) * (gradient_size.ori_w + gradient_size.pad_w);
CalPad(size, input2_addr, 1, 1, gradient_size.ori_h, gradient_size.ori_w,
gradient_size.ori_h + gradient_size.pad_h, gradient_size.ori_w + gradient_size.pad_w, 0, 0, 0.0,
workspace4_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
cudaMemsetAsync(workspace1_addr, 0,
gradient_size.w * gradient_size.h * gradient_size.batch_w * gradient_size.batch_h * sizeof(T),
reinterpret_cast<cudaStream_t>(stream_ptr));
input2_addr = workspace4_addr;
}
const float alpha = 1;
const float beta = 0;
const int lda = SizeToInt(gradient_size.h);
const int ldb = SizeToInt(gradient_size.ori_w + gradient_size.pad_w);
const int ldc = SizeToInt(gradient_size.ori_w + gradient_size.pad_w);
auto stride_a = SizeToInt(gradient_size.h * gradient_size.h);
auto stride_b = SizeToInt(gradient_size.h * (gradient_size.ori_w + gradient_size.pad_w));
auto stride_c = SizeToInt(gradient_size.h * (gradient_size.ori_w + gradient_size.pad_w));
try {
CHECK_CUBLAS_RET_WITH_EXCEPT(
cublasGemmStridedBatchedEx(handle_, CUBLAS_OP_N, CUBLAS_OP_N, SizeToInt(gradient_size.ori_w),
SizeToInt(gradient_size.h), SizeToInt(gradient_size.h), &alpha, input2_addr,
gradient_size.dtype, ldb, stride_b, input1_addr, gradient_size.dtype, lda, stride_a,
&beta, workspace1_addr, gradient_size.dtype, ldc, stride_c, gradient_size.batch_h,
CUDA_R_32F, algo_),
"cublasSgemm Call Fail");
} catch (const std::exception &e) {
MS_LOG(EXCEPTION) << "Encountered an exception: " << e.what() << " when invoking cublasGemmStridedBatchedEx";
}
auto r_input_addr = workspace1_addr;
if (gradient_size.need_convert) {
size_t size = gradient_size.batch_w * gradient_size.batch_h * gradient_size.w * gradient_size.h;
ConvertGradient(size, gradient_size.h, gradient_size.w, gradient_size.batch_w,
gradient_size.batch_w * gradient_size.w, workspace1_addr, workspace2_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
r_input_addr = workspace2_addr;
}
const int lda_r = SizeToInt(gradient_size.w);
const int ldb_r = SizeToInt(gradient_size.w);
const int ldc_r = SizeToInt(gradient_size.w);
stride_a = SizeToInt(gradient_size.h * gradient_size.w);
stride_b = SizeToInt(gradient_size.w * gradient_size.w);
stride_c = SizeToInt(gradient_size.h * gradient_size.w);
auto r_output_addr = output_addr;
if (gradient_size.need_convert) {
r_output_addr = workspace3_addr;
}
CHECK_CUBLAS_RET_WITH_EXCEPT(
cublasGemmStridedBatchedEx(handle_, CUBLAS_OP_N, CUBLAS_OP_N, SizeToInt(gradient_size.w),
SizeToInt(gradient_size.h), SizeToInt(gradient_size.w), &alpha, input3_addr,
gradient_size.dtype, ldb_r, stride_b, r_input_addr, gradient_size.dtype, lda_r,
stride_a, &beta, r_output_addr, gradient_size.dtype, ldc_r, stride_c,
gradient_size.batch_h * gradient_size.batch_w, CUDA_R_32F, algo_),
"cublasSgemm Call Fail");
if (gradient_size.need_convert) {
size_t size = gradient_size.batch_w * gradient_size.batch_h * gradient_size.w * gradient_size.h;
if (gradient_size.pad_h == 0 && gradient_size.pad_w == 0) {
ConvertGradientBack(size, gradient_size.h, gradient_size.w, gradient_size.batch_w,
gradient_size.batch_w * gradient_size.w, r_output_addr, output_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
} else {
ConvertGradientBack(size, gradient_size.h, gradient_size.w, gradient_size.ori_h, gradient_size.ori_w,
gradient_size.batch_w, gradient_size.ori_w, r_output_addr, output_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
}
}
return true;
}
bool Init(const CNodePtr &kernel_node) override {
handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCublasHandle();
SetProperty(kernel_node);
InitSizeLists();
return true;
}
protected:
void InitSizeLists() override {
size_t unit_size = sizeof(T);
size_t input_size_ = gradient_size.h * gradient_size.h * gradient_size.batch_h * unit_size;
input_size_list_.push_back(input_size_);
input_size_ = gradient_size.ori_h * gradient_size.ori_w * unit_size;
input_size_list_.push_back(input_size_);
input_size_ = gradient_size.w * gradient_size.w * gradient_size.batch_w * unit_size;
input_size_list_.push_back(input_size_);
size_t output_size = gradient_size.ori_h * gradient_size.ori_w * unit_size;
output_size_list_.push_back(output_size);
size_t workspace_size_ = 0;
workspace_size_ = gradient_size.w * gradient_size.h * gradient_size.batch_w * gradient_size.batch_h * unit_size;
workspace_size_list_.push_back(workspace_size_);
if (gradient_size.need_convert) {
workspace_size_ = gradient_size.w * gradient_size.h * gradient_size.batch_w * gradient_size.batch_h * unit_size;
workspace_size_list_.push_back(workspace_size_);
workspace_size_ = gradient_size.w * gradient_size.h * gradient_size.batch_w * gradient_size.batch_h * unit_size;
workspace_size_list_.push_back(workspace_size_);
}
if (gradient_size.pad_h != 0 || gradient_size.pad_w != 0) {
workspace_size_ =
(gradient_size.ori_w + gradient_size.pad_w) * (gradient_size.ori_h + gradient_size.pad_h) * unit_size;
workspace_size_list_.push_back(workspace_size_);
}
}
private:
void SetProperty(const CNodePtr &kernel_node) {
auto matrix_a_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto gradient_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
auto matrix_g_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
split_dim = size_t(GetAttr<int>(kernel_node, "split_dim"));
gradient_size.batch_h = gradient_shape[0] / split_dim;
gradient_size.batch_w = gradient_shape[1] / split_dim;
if (gradient_size.batch_h * split_dim != gradient_shape[0]) {
gradient_size.batch_h += 1;
if (gradient_shape[0] > split_dim) {
gradient_size.h = split_dim;
gradient_size.pad_h = gradient_size.batch_h * split_dim - gradient_shape[0];
} else {
gradient_size.h = gradient_shape[0];
gradient_size.pad_h = 0;
}
} else {
gradient_size.h = split_dim;
gradient_size.pad_h = 0;
}
if (gradient_size.batch_w * split_dim != gradient_shape[1]) {
gradient_size.batch_w += 1;
if (gradient_shape[1] > split_dim) {
gradient_size.w = split_dim;
gradient_size.pad_w = gradient_size.batch_w * split_dim - gradient_shape[1];
} else {
gradient_size.w = gradient_shape[1];
gradient_size.pad_w = 0;
}
} else {
gradient_size.w = split_dim;
gradient_size.pad_w = 0;
}
if (gradient_size.batch_w * gradient_size.w <= split_dim) {
gradient_size.need_convert = false;
} else {
gradient_size.need_convert = true;
}
gradient_size.ori_w = gradient_shape[1];
gradient_size.ori_h = gradient_shape[0];
gradient_size.dtype = GetCudaDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 1)));
}
size_t split_dim;
struct GradientSize gradient_size;
cublasHandle_t handle_;
cublasGemmAlgo_t algo_ = CUBLAS_GEMM_DEFAULT;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif  // MINDSPORE_UPDATE_THOR_GRADIENT_GPU_KERNEL_H
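
For reference, the per-dimension block/pad arithmetic in SetProperty above can be summarized by the following Python sketch; the helper name and example values are illustrative assumptions, not part of the kernel:

def split_one_dim(dim, split_dim):
    # Mirror of SetProperty for a single dimension: returns the number of
    # blocks, the block size, and the zero padding appended to reach a
    # whole number of blocks.
    batch = dim // split_dim
    if batch * split_dim != dim:
        batch += 1
        if dim > split_dim:
            return batch, split_dim, batch * split_dim - dim
        return batch, dim, 0
    return batch, split_dim, 0

# e.g. split_one_dim(2048, 128) -> (16, 128, 0)
#      split_one_dim(1000, 128) -> (8, 128, 24)
#      split_one_dim(100, 128)  -> (1, 100, 0)

need_convert is then true whenever the gradient spans more than one column block (batch_w * w > split_dim), which is exactly when the batched layout-conversion kernels are required.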

@@ -83,10 +83,10 @@ from . import _quant_ops
from ._quant_ops import *
from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, PopulationCount,
CheckValid, MakeRefKey, Partial, Depend, CheckBprop, Push, Pull)
from .thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg2Col, CusMatMulCubeDenseLeft,
CusMatMulCubeFraczRightMul, CusMatMulCube, CusMatrixCombine, CusTranspose02314,
CusMatMulCubeDenseRight,
CusMatMulCubeFraczLeftCast, Im2Col)
from ._thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg2Col, CusMatMulCubeDenseLeft,
CusMatMulCubeFraczRightMul, CusMatMulCube, CusMatrixCombine, CusTranspose02314,
CusMatMulCubeDenseRight,
CusMatMulCubeFraczLeftCast, Im2Col, UpdateThorGradient)
from .sparse_ops import SparseToDense
__all__ = [

@@ -562,3 +562,48 @@ class Im2Col(PrimitiveWithInfer):
valid_types = [mstype.float16, mstype.float32]
validator.check_tensor_type_same(args, valid_types, self.name)
return x_dtype
class UpdateThorGradient(PrimitiveWithInfer):
"""
Updates the THOR gradient with the approximate Fisher information matrix (for the GPU backend).
The rank of input_x1 must be `3`, representing the A matrix.
The rank of input_x2 must be `2`, representing the first-order gradient.
The rank of input_x3 must be `4`, representing the G matrix.
Inputs:
- **input_x1** (Tensor) - The first input is the diagonal part of the covariance matrix of the feature map.
Supported dtype: [float32].
- **input_x2** (Tensor) - The second input is the corresponding first-order gradient. Supported dtype: [float32].
- **input_x3** (Tensor) - The third input is the diagonal part of the covariance matrix of dout. Supported dtype: [float32].
Outputs:
Tensor, with the same shape as input_x2; it is used to update the weights.
Examples:
>>> input_x1 = Tensor(np.random.rand(16, 128, 128).astype(np.float32))
>>> input_x2 = Tensor(np.random.rand(2048, 1024).astype(np.float32))
>>> temp_x3 = np.random.rand(8, 128, 128).astype(np.float32)
>>> input_x3 = np.zeros((16, 8, 128, 128)).astype(np.float32)
>>> for i in range(16):
...     input_x3[i, :, :, :] = temp_x3
>>> input_x3 = Tensor(input_x3)
>>> update_thor_gradient = P.UpdateThorGradient(split_dim=128)
>>> output = update_thor_gradient(input_x1, input_x2, input_x3)
"""
@prim_attr_register
def __init__(self, split_dim=0):
"""init UpdateThorGradient"""
self.init_prim_io_names(inputs=['x1', 'x2', 'x3'], outputs=['y'])
self.split_dim = split_dim
self.add_prim_attr('split_dim', self.split_dim)
def infer_shape(self, x1_shape, x2_shape, x3_shape):
return x2_shape
def infer_dtype(self, x1_dtype, x2_dtype, x3_dtype):
validator.check_tensor_type_same({'x1_dtype': x1_dtype, 'x2_dtype': x2_dtype, 'x3_dtype': x3_dtype},
[mstype.float32], self.name)
return x2_dtype
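
Putting the pieces together, my reading of the GPU kernel above is a block-wise A·grad·G product. A minimal NumPy sketch, assuming both gradient dimensions divide evenly by split_dim (padding is handled inside the kernel), with x1 of shape (bh, split_dim, split_dim), x2 of shape (bh*split_dim, bw*split_dim) and x3 of shape (bh, bw, split_dim, split_dim); the helper name is illustrative:

import numpy as np

def update_thor_gradient_ref(x1, x2, x3, split_dim):
    # For every (split_dim x split_dim) block of the gradient x2, left-multiply
    # by the matching block of x1 and right-multiply by the matching block of x3.
    bh, bw = x2.shape[0] // split_dim, x2.shape[1] // split_dim
    out = np.empty_like(x2)
    for i in range(bh):
        for j in range(bw):
            rows = slice(i * split_dim, (i + 1) * split_dim)
            cols = slice(j * split_dim, (j + 1) * split_dim)
            out[rows, cols] = x1[i] @ x2[rows, cols] @ x3[i, j]
    return out

With the shapes from the docstring example (split_dim=128, x2 of shape (2048, 1024)), bh is 16 and bw is 8, matching x1 of shape (16, 128, 128) and x3 of shape (16, 8, 128, 128).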