diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adagrad_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adagrad_impl.cu
new file mode 100644
index 0000000000..b1a0eb2514
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adagrad_impl.cu
@@ -0,0 +1,70 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/cuda_impl/adagrad_impl.cuh"
+
+template <typename T>
+__device__ __forceinline__ T SqrtFunc(T input) {
+  return sqrt(input);
+}
+
+template <>
+__device__ __forceinline__ half SqrtFunc(half input) {
+  return hsqrt(input);
+}
+
+template <typename T>
+__global__ void ApplyAdagradKernel(const size_t size,
+                                   const bool update_slots,
+                                   const T *learning_rate,
+                                   const T *gradient,
+                                   T *variable,
+                                   T *accumulation) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
+    if (update_slots) {
+      accumulation[i] += gradient[i] * gradient[i];
+    }
+    variable[i] -= learning_rate[0] * gradient[i] / SqrtFunc(accumulation[i]);
+  }
+}
+
+template <typename T>
+void ApplyAdagrad(const size_t size,
+                  const bool update_slots,
+                  const T *learning_rate,
+                  const T *gradient,
+                  T *variable,
+                  T *accumulation,
+                  cudaStream_t cuda_stream) {
+  ApplyAdagradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(
+    size, update_slots, learning_rate, gradient, variable, accumulation);
+}
+
+template void ApplyAdagrad<float>(const size_t size,
+                                  const bool update_slots,
+                                  const float *learning_rate,
+                                  const float *gradient,
+                                  float *variable,
+                                  float *accumulation,
+                                  cudaStream_t cuda_stream);
+
+template void ApplyAdagrad<half>(const size_t size,
+                                 const bool update_slots,
+                                 const half *learning_rate,
+                                 const half *gradient,
+                                 half *variable,
+                                 half *accumulation,
+                                 cudaStream_t cuda_stream);
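For readers checking the kernel math: each element follows the standard Adagrad rule, accumulating the squared gradient (when update_slots is true) and scaling the step by the inverse square root of the accumulator. A minimal NumPy sketch of the same computation (illustrative only, not part of the patch; the function name is made up):

import numpy as np

def adagrad_reference(var, accum, lr, grad, update_slots=True):
    # Mirrors ApplyAdagradKernel: accum += grad * grad, then
    # var -= lr * grad / sqrt(accum).
    if update_slots:
        accum = accum + grad * grad
    var = var - lr * grad / np.sqrt(accum)
    return var, accum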
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adagrad_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adagrad_impl.cuh
new file mode 100644
index 0000000000..3cfbd776e9
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adagrad_impl.cuh
@@ -0,0 +1,30 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAGRAD_IMPL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAGRAD_IMPL_H_
+
+#include "runtime/device/gpu/cuda_common.h"
+template <typename T>
+void ApplyAdagrad(const size_t size,
+                  const bool update_slots,
+                  const T *learning_rate,
+                  const T *gradient,
+                  T *variable,
+                  T *accumulation,
+                  cudaStream_t stream);
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAGRAD_IMPL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adagrad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adagrad_gpu_kernel.cc
new file mode 100644
index 0000000000..25c459c14b
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adagrad_gpu_kernel.cc
@@ -0,0 +1,40 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/nn/adagrad_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(ApplyAdagrad,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      AdagradGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(ApplyAdagrad,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16),
+                      AdagradGpuKernel, half)
+}  // namespace kernel
+}  // namespace mindspore
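The two MS_REG_GPU_KERNEL_ONE entries above register the kernel for float32 and float16, with four inputs (var, accum, lr, grad) and two outputs (the updated var and accum). A rough Python-side orientation for that signature (illustrative sketch, not part of the patch; in practice the primitive is called from inside a Cell, as in the test added below):

import numpy as np
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P

apply_adagrad = P.ApplyAdagrad()
var = Parameter(Tensor(np.ones((3, 3), np.float32)), name="var")
accum = Parameter(Tensor(np.ones((3, 3), np.float32)), name="accum")
lr = Tensor(np.float32(0.001))
grad = Tensor(np.ones((3, 3), np.float32))
# The four inputs line up with the AddInputAttr entries; float16 tensors
# would select the kNumberTypeFloat16 registration instead.
# updated_var, updated_accum = apply_adagrad(var, accum, lr, grad)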
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adagrad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adagrad_gpu_kernel.h
new file mode 100644
index 0000000000..a6cf718b64
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adagrad_gpu_kernel.h
@@ -0,0 +1,104 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_ADAGRAD_GPU_KERNEL_H
+#define MINDSPORE_ADAGRAD_GPU_KERNEL_H
+
+#include <vector>
+#include "backend/kernel_compiler/gpu/gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
+#include "backend/kernel_compiler/gpu/cuda_impl/adagrad_impl.cuh"
+
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class AdagradGpuKernel : public GpuKernel {
+ public:
+  AdagradGpuKernel()
+    : variable_size_(0), accumulation_size_(0), learning_rate_size_(0), gradient_size_(0), update_slots(true) {}
+
+  ~AdagradGpuKernel() override = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 4) {
+      MS_LOG(ERROR) << "Input number is " << input_num << ", but adagrad needs 4 inputs.";
+      return false;
+    }
+    variable_size_ = sizeof(T);
+    accumulation_size_ = sizeof(T);
+    learning_rate_size_ = sizeof(T);
+    gradient_size_ = sizeof(T);
+
+    auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    for (size_t i = 0; i < variable_shape.size(); i++) {
+      variable_size_ *= variable_shape[i];
+    }
+
+    auto accumulation_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
+    for (size_t i = 0; i < accumulation_shape.size(); i++) {
+      accumulation_size_ *= accumulation_shape[i];
+    }
+
+    auto gradient_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3);
+    for (size_t i = 0; i < gradient_shape.size(); i++) {
+      gradient_size_ *= gradient_shape[i];
+    }
+
+    InitSizeLists();
+    return true;
+  }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
+              void *stream_ptr) override {
+    T *variable = GetDeviceAddress<T>(inputs, 0);
+    T *accumulation = GetDeviceAddress<T>(inputs, 1);
+    T *learning_rate = GetDeviceAddress<T>(inputs, 2);
+    T *gradient = GetDeviceAddress<T>(inputs, 3);
+    ApplyAdagrad(inputs[0]->size / sizeof(T), update_slots, learning_rate, gradient, variable, accumulation,
+                 reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(variable_size_);
+    input_size_list_.push_back(accumulation_size_);
+    input_size_list_.push_back(learning_rate_size_);
+    input_size_list_.push_back(gradient_size_);
+    output_size_list_.push_back(0);
+    output_size_list_.push_back(0);
+  }
+
+ private:
+  size_t variable_size_;
+  size_t accumulation_size_;
+  size_t learning_rate_size_;
+  size_t gradient_size_;
+  bool update_slots;
+
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_ADAGRAD_GPU_KERNEL_H
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adam_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adam_gpu_kernel.h
index 8ac9839dcf..4965a2652c 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adam_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adam_gpu_kernel.h
@@ -64,7 +64,7 @@ class AdamGpuKernel : public GpuKernel {
   bool Init(const CNodePtr &kernel_node) override {
     size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
     if (input_num != 10) {
-      MS_LOG(ERROR) << "Input number is " << input_num << ", but ftrl needs 10 inputs.";
+      MS_LOG(ERROR) << "Input number is " << input_num << ", but adam needs 10 inputs.";
       return false;
     }
 
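AdagradGpuKernel::Init in the header above sizes every input buffer as sizeof(T) multiplied by the product of the inferred shape dimensions (the scalar learning rate stays at sizeof(T)). The same bookkeeping in NumPy terms (illustrative only, not part of the patch; the helper name is made up):

import numpy as np

def buffer_bytes(shape, dtype=np.float32):
    # Equivalent of: size = sizeof(T); for (i) size *= shape[i];
    return int(np.prod(shape, dtype=np.int64)) * np.dtype(dtype).itemsize

# buffer_bytes((3, 3)) == 36 bytes, matching the 3x3 float32 tensors in the test below.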
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index aa38da0687..d54a23804d 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -4700,7 +4700,7 @@ class ApplyAdagrad(PrimitiveWithInfer):
         - **accum** (Tensor) - The same shape and data type as `accum`.
 
     Supported Platforms:
-        ``Ascend``
+        ``Ascend`` ``CPU`` ``GPU``
 
     Examples:
         >>> import numpy as np
diff --git a/tests/st/ops/gpu/test_adagrad_op.py b/tests/st/ops/gpu/test_adagrad_op.py
new file mode 100644
index 0000000000..7153595f55
--- /dev/null
+++ b/tests/st/ops/gpu/test_adagrad_op.py
@@ -0,0 +1,61 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor, Parameter
+from mindspore.ops import operations as P
+import mindspore.common.dtype as mstype
+
+context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+
+var_np = np.random.rand(3, 3).astype(np.float32)
+accum_np = np.random.rand(3, 3).astype(np.float32)
+
+
+class Net(nn.Cell):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.apply_adagrad = P.ApplyAdagrad()
+        self.var = Parameter(Tensor(var_np), name="var")
+        self.accum = Parameter(Tensor(accum_np), name="accum")
+
+    def construct(self, lr, grad):
+        self.apply_adagrad(self.var, self.accum, lr, grad)
+        return self.var, self.accum
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_apply_adagrad():
+    # numpy reference implementation
+    gradient_np = np.random.rand(3, 3).astype(np.float32)
+    expect_accum_np = accum_np + gradient_np * gradient_np
+    expect_var_np = var_np - (0.001 * gradient_np * (1 / np.sqrt(expect_accum_np + 1e-6)))
+
+    net = Net()
+    lr = Tensor(0.001, mstype.float32)
+    grad = Tensor(gradient_np)
+    out = net(lr, grad)
+    res_var_mindspore = out[0].asnumpy()
+    res_accum_mindspore = out[1].asnumpy()
+    eps = np.array([1e-6 for i in range(9)]).reshape(3, 3)
+
+    assert np.all(expect_var_np - res_var_mindspore < eps)
+    assert np.all(expect_accum_np - res_accum_mindspore < eps)
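One note on the assertions in the new test: they compare signed differences against eps, so a result that comes out much larger than the expected value would still pass. If a tighter check is wanted, an absolute-difference comparison could be used instead (illustrative suggestion only, not part of the patch; the tolerance is left slack because the NumPy reference adds 1e-6 inside the sqrt while the kernel does not):

    assert np.all(np.abs(expect_var_np - res_var_mindspore) < 1e-4)
    assert np.all(np.abs(expect_accum_np - res_accum_mindspore) < 1e-4)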