commit
ec025a1c4c
@ -0,0 +1,70 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "backend/kernel_compiler/gpu/cuda_impl/adagrad_impl.cuh"
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__device__ __forceinline__ T SqrtFunc(T input) {
|
||||||
|
return sqrt(input);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ __forceinline__ half SqrtFunc(half input) {
|
||||||
|
return hsqrt(input);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__global__ void ApplyAdagradKernel(const size_t size,
|
||||||
|
const bool update_slots,
|
||||||
|
const T *learning_rate,
|
||||||
|
const T *gradient,
|
||||||
|
T *variable,
|
||||||
|
T *accumulation) {
|
||||||
|
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
|
||||||
|
if (update_slots) {
|
||||||
|
accumulation[i] += gradient[i] * gradient[i];
|
||||||
|
}
|
||||||
|
variable[i] -= learning_rate[0] * gradient[i] / SqrtFunc(accumulation[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void ApplyAdagrad(const size_t size,
|
||||||
|
const bool update_slots,
|
||||||
|
const T *learning_rate,
|
||||||
|
const T *gradient,
|
||||||
|
T *variable,
|
||||||
|
T *accumulation,
|
||||||
|
cudaStream_t cuda_stream) {
|
||||||
|
ApplyAdagradKernel<<< GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(
|
||||||
|
size, update_slots, learning_rate, gradient, variable, accumulation);
|
||||||
|
}
|
||||||
|
|
||||||
|
template void ApplyAdagrad<float>(const size_t size,
|
||||||
|
const bool update_slots,
|
||||||
|
const float *learning_rate,
|
||||||
|
const float *gradient,
|
||||||
|
float *variable,
|
||||||
|
float *accumulation,
|
||||||
|
cudaStream_t cuda_stream);
|
||||||
|
|
||||||
|
template void ApplyAdagrad<half>(const size_t size,
|
||||||
|
const bool update_slots,
|
||||||
|
const half *learning_rate,
|
||||||
|
const half *gradient,
|
||||||
|
half *variable,
|
||||||
|
half *accumulation,
|
||||||
|
cudaStream_t cuda_stream);
|
@ -0,0 +1,30 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAGRAD_IMPL_H_
|
||||||
|
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAGRAD_IMPL_H_
|
||||||
|
|
||||||
|
#include "runtime/device/gpu/cuda_common.h"
|
||||||
|
template <typename T>
|
||||||
|
void ApplyAdagrad(const size_t size,
|
||||||
|
const bool update_slots,
|
||||||
|
const T *learning_rate,
|
||||||
|
const T *gradient,
|
||||||
|
T *variable,
|
||||||
|
T *accumulation,
|
||||||
|
cudaStream_t stream);
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAGRAD_IMPL_H_
|
@ -0,0 +1,40 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "backend/kernel_compiler/gpu/nn/adagrad_gpu_kernel.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace kernel {
|
||||||
|
MS_REG_GPU_KERNEL_ONE(ApplyAdagrad,
|
||||||
|
KernelAttr()
|
||||||
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
|
.AddInputAttr(kNumberTypeFloat32)
|
||||||
|
.AddOutputAttr(kNumberTypeFloat32)
|
||||||
|
.AddOutputAttr(kNumberTypeFloat32),
|
||||||
|
AdagradGpuKernel, float)
|
||||||
|
MS_REG_GPU_KERNEL_ONE(ApplyAdagrad,
|
||||||
|
KernelAttr()
|
||||||
|
.AddInputAttr(kNumberTypeFloat16)
|
||||||
|
.AddInputAttr(kNumberTypeFloat16)
|
||||||
|
.AddInputAttr(kNumberTypeFloat16)
|
||||||
|
.AddInputAttr(kNumberTypeFloat16)
|
||||||
|
.AddOutputAttr(kNumberTypeFloat16)
|
||||||
|
.AddOutputAttr(kNumberTypeFloat16),
|
||||||
|
AdagradGpuKernel, half)
|
||||||
|
} // namespace kernel
|
||||||
|
} // namespace mindspore
|
@ -0,0 +1,104 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_ADAGRAD_GPU_KERNEL_H
|
||||||
|
#define MINDSPORE_ADAGRAD_GPU_KERNEL_H
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||||
|
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||||
|
#include "backend/kernel_compiler/gpu/cuda_impl/adagrad_impl.cuh"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace kernel {
|
||||||
|
template <typename T>
|
||||||
|
class AdagradGpuKernel : public GpuKernel {
|
||||||
|
public:
|
||||||
|
AdagradGpuKernel()
|
||||||
|
: variable_size_(0), accumulation_size_(0), learning_rate_size_(0), gradient_size_(0), update_slots(true) {}
|
||||||
|
|
||||||
|
~AdagradGpuKernel() override = default;
|
||||||
|
|
||||||
|
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||||
|
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||||
|
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||||
|
|
||||||
|
bool Init(const CNodePtr &kernel_node) override {
|
||||||
|
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||||
|
if (input_num != 4) {
|
||||||
|
MS_LOG(ERROR) << "Input number is " << input_num << ", but adagrad needs 4 inputs.";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
variable_size_ = sizeof(T);
|
||||||
|
accumulation_size_ = sizeof(T);
|
||||||
|
learning_rate_size_ = sizeof(T);
|
||||||
|
gradient_size_ = sizeof(T);
|
||||||
|
|
||||||
|
auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||||
|
for (size_t i = 0; i < variable_shape.size(); i++) {
|
||||||
|
variable_size_ *= variable_shape[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
auto accumulation_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
|
||||||
|
for (size_t i = 0; i < accumulation_shape.size(); i++) {
|
||||||
|
accumulation_size_ *= accumulation_shape[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
auto gradient_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3);
|
||||||
|
for (size_t i = 0; i < gradient_shape.size(); i++) {
|
||||||
|
gradient_size_ *= gradient_shape[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
InitSizeLists();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
|
||||||
|
void *stream_ptr) override {
|
||||||
|
T *variable = GetDeviceAddress<T>(inputs, 0);
|
||||||
|
T *accumulation = GetDeviceAddress<T>(inputs, 1);
|
||||||
|
T *learning_rate = GetDeviceAddress<T>(inputs, 2);
|
||||||
|
T *gradient = GetDeviceAddress<T>(inputs, 3);
|
||||||
|
ApplyAdagrad(inputs[0]->size / sizeof(T), update_slots, learning_rate, gradient, variable, accumulation,
|
||||||
|
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
void InitSizeLists() override {
|
||||||
|
input_size_list_.push_back(variable_size_);
|
||||||
|
input_size_list_.push_back(accumulation_size_);
|
||||||
|
input_size_list_.push_back(learning_rate_size_);
|
||||||
|
input_size_list_.push_back(gradient_size_);
|
||||||
|
output_size_list_.push_back(0);
|
||||||
|
output_size_list_.push_back(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
size_t variable_size_;
|
||||||
|
size_t accumulation_size_;
|
||||||
|
size_t learning_rate_size_;
|
||||||
|
size_t gradient_size_;
|
||||||
|
bool update_slots;
|
||||||
|
|
||||||
|
std::vector<size_t> input_size_list_;
|
||||||
|
std::vector<size_t> output_size_list_;
|
||||||
|
std::vector<size_t> workspace_size_list_;
|
||||||
|
};
|
||||||
|
} // namespace kernel
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_ADAGRAD_GPU_KERNEL_H
|
@ -0,0 +1,61 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import mindspore.context as context
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import Tensor, Parameter
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
import mindspore.common.dtype as mstype
|
||||||
|
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||||
|
|
||||||
|
var_np = np.random.rand(3, 3).astype(np.float32)
|
||||||
|
accum_np = np.random.rand(3, 3).astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.apply_adagrad = P.ApplyAdagrad()
|
||||||
|
self.var = Parameter(Tensor(var_np), name="var")
|
||||||
|
self.accum = Parameter(Tensor(accum_np), name="accum")
|
||||||
|
|
||||||
|
def construct(self, lr, grad):
|
||||||
|
self.apply_adagrad(self.var, self.accum, lr, grad)
|
||||||
|
return self.var, self.accum
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_gpu_training
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_apply_adagrad():
|
||||||
|
# numpy op
|
||||||
|
grident_np = np.random.rand(3, 3).astype(np.float32)
|
||||||
|
expect_accum_np = accum_np + grident_np * grident_np
|
||||||
|
expect_var_np = var_np - (0.001 * grident_np * (1 / np.sqrt(expect_accum_np + 1e-6)))
|
||||||
|
|
||||||
|
net = Net()
|
||||||
|
lr = Tensor(0.001, mstype.float32)
|
||||||
|
grad = Tensor(grident_np)
|
||||||
|
out = net(lr, grad)
|
||||||
|
res_var_mindspore = out[0].asnumpy()
|
||||||
|
res_accum_mindspore = out[1].asnumpy()
|
||||||
|
eps = np.array([1e-6 for i in range(9)]).reshape(3, 3)
|
||||||
|
|
||||||
|
assert np.all(expect_var_np - res_var_mindspore < eps)
|
||||||
|
assert np.all(expect_accum_np - res_accum_mindspore < eps)
|
Loading…
Reference in new issue