From c3360a84cd2a2edd73e51df6ec14a782d1ef4c52 Mon Sep 17 00:00:00 2001 From: lizhenyu Date: Wed, 17 Jun 2020 16:35:43 +0800 Subject: [PATCH] add ftrl optimizer --- .../ccsrc/kernel/gpu/cuda_impl/ftrl_impl.cu | 87 ++++++++++++ .../ccsrc/kernel/gpu/cuda_impl/ftrl_impl.cuh | 26 ++++ .../ccsrc/kernel/gpu/nn/ftrl_gpu_kernel.cc | 46 +++++++ .../ccsrc/kernel/gpu/nn/ftrl_gpu_kernel.h | 130 ++++++++++++++++++ tests/st/ops/gpu/test_ftrl_op.py | 78 +++++++++++ 5 files changed, 367 insertions(+) create mode 100644 mindspore/ccsrc/kernel/gpu/cuda_impl/ftrl_impl.cu create mode 100644 mindspore/ccsrc/kernel/gpu/cuda_impl/ftrl_impl.cuh create mode 100644 mindspore/ccsrc/kernel/gpu/nn/ftrl_gpu_kernel.cc create mode 100644 mindspore/ccsrc/kernel/gpu/nn/ftrl_gpu_kernel.h create mode 100644 tests/st/ops/gpu/test_ftrl_op.py diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/ftrl_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/ftrl_impl.cu new file mode 100644 index 0000000000..ea6ffdbbdc --- /dev/null +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/ftrl_impl.cu @@ -0,0 +1,87 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "kernel/gpu/cuda_impl/ftrl_impl.cuh" + +template +__device__ __forceinline__ T PowFunc(T x, T y) { + return pow(x, y); +} + +template <> +__device__ __forceinline__ half PowFunc(half x, half y) { + return __float2half(pow(__half2float(x), __half2float(y))); +} + +template +__device__ __forceinline__ bool CompareFunc(T x, T y) { + return abs(x) > y; +} + +template <> +__device__ __forceinline__ bool CompareFunc(half x, half y) { + return abs(__half2float(x)) > __half2float(y); +} + +template +__device__ __forceinline__ T Sgn(T x) { + return static_cast(x != 0 ? (x > 0 ? 1 : -1) : 0); +} + +template <> +__device__ __forceinline__ half Sgn(half x) { + return __float2half(__half2float(x) != 0 ? (__half2float(x) > 0 ? 1 : -1) : 0); +} + +template +__global__ void ApplyFtrlKernel(const size_t size, const T *gradient, const T *learning_rate, + const T *l1_regularization, const T *l2_regularization, const T *learning_rate_power, + T *variable, T *accumulation, T *linear) { + const T two = static_cast(2.0); + const T learning_rate_power_val = -learning_rate_power[0]; + + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) { + const T cur_accumulation = accumulation[i] + gradient[i] * gradient[i]; + const T accumulation_power = PowFunc(accumulation[i], learning_rate_power_val); + const T cur_accumulation_power = PowFunc(cur_accumulation, learning_rate_power_val); + const T sigma = (cur_accumulation_power - accumulation_power) / learning_rate[0]; + + linear[i] += gradient[i] - sigma * variable[i]; + variable[i] = CompareFunc(linear[i], l1_regularization[0]) + ? ((l1_regularization[0] * Sgn(linear[i]) - linear[i]) / + (cur_accumulation_power / learning_rate[0] + two * l2_regularization[0])) + : static_cast(0); + accumulation[i] = cur_accumulation; + } +} + +template +void ApplyFtrl(const size_t size, const T *gradient, const T *learning_rate, const T *l1_regularization, + const T *l2_regularization, const T *learning_rate_power, T *variable, T *accumulation, T *linear, + cudaStream_t cuda_stream) { + ApplyFtrlKernel<<>>(size, gradient, learning_rate, l1_regularization, + l2_regularization, learning_rate_power, variable, + accumulation, linear); +} + +template void ApplyFtrl(const size_t size, const float *gradient, const float *learning_rate, + const float *l1_regularization, const float *l2_regularization, + const float *learning_rate_power, float *variable, float *accumulation, float *linear, + cudaStream_t cuda_stream); +template void ApplyFtrl(const size_t size, const half *gradient, const half *learning_rate, + const half *l1_regularization, const half *l2_regularization, + const half *learning_rate_power, half *variable, half *accumulation, half *linear, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/ftrl_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/ftrl_impl.cuh new file mode 100644 index 0000000000..ba4a8fa816 --- /dev/null +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/ftrl_impl.cuh @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FTRL_IMPL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FTRL_IMPL_H_ + +#include "device/gpu/cuda_common.h" +template +void ApplyFtrl(const size_t size, const T *gradient, const T *learning_rate, const T *l1_regularization, + const T *l2_regularization, const T *learning_rate_power, T *variable, T *accumulation, T *linear, + cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FTRL_IMPL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/ftrl_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/ftrl_gpu_kernel.cc new file mode 100644 index 0000000000..4d30130931 --- /dev/null +++ b/mindspore/ccsrc/kernel/gpu/nn/ftrl_gpu_kernel.cc @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "kernel/gpu/nn/ftrl_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(ApplyFtrl, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + FtrlGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(ApplyFtrl, + KernelAttr() + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddInputAttr(kNumberTypeFloat16) + .AddOutputAttr(kNumberTypeFloat16), + FtrlGpuKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/ftrl_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/ftrl_gpu_kernel.h new file mode 100644 index 0000000000..9e2153965b --- /dev/null +++ b/mindspore/ccsrc/kernel/gpu/nn/ftrl_gpu_kernel.h @@ -0,0 +1,130 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_FTRL_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_FTRL_GPU_KERNEL_H_ + +#include +#include "kernel/gpu/gpu_kernel.h" +#include "kernel/gpu/gpu_kernel_factory.h" +#include "kernel/gpu/cuda_impl/ftrl_impl.cuh" +namespace mindspore { +namespace kernel { +template +class FtrlGpuKernel : public GpuKernel { + public: + FtrlGpuKernel() + : variable_size_(0), + accumulation_size_(0), + linear_size_(0), + gradient_size_(0), + learning_rate_size_(0), + l1_regularization_size_(0), + l2_regularization_size_(0), + learning_rate_power_size_(0) {} + + ~FtrlGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, const std::vector &, + void *stream_ptr) override { + T *variable = GetDeviceAddress(inputs, 0); + T *accumulation = GetDeviceAddress(inputs, 1); + T *linear = GetDeviceAddress(inputs, 2); + T *gradient = GetDeviceAddress(inputs, 3); + T *learning_rate = GetDeviceAddress(inputs, 4); + T *l1_regularization = GetDeviceAddress(inputs, 5); + T *l2_regularization = GetDeviceAddress(inputs, 6); + T *learning_rate_power = GetDeviceAddress(inputs, 7); + ApplyFtrl(inputs[0]->size / sizeof(T), gradient, learning_rate, l1_regularization, l2_regularization, + learning_rate_power, variable, accumulation, linear, reinterpret_cast(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 8) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but ftrl needs 8 inputs."; + return false; + } + + variable_size_ = sizeof(T); + accumulation_size_ = sizeof(T); + linear_size_ = sizeof(T); + gradient_size_ = sizeof(T); + learning_rate_size_ = sizeof(T); + l1_regularization_size_ = sizeof(T); + l2_regularization_size_ = sizeof(T); + learning_rate_power_size_ = sizeof(T); + + auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + for (size_t i = 0; i < variable_shape.size(); i++) { + variable_size_ *= variable_shape[i]; + } + + auto accumulation_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + for (size_t i = 0; i < accumulation_shape.size(); i++) { + accumulation_size_ *= accumulation_shape[i]; + } + + auto linear_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); + for (size_t i = 0; i < linear_shape.size(); i++) { + linear_size_ *= linear_shape[i]; + } + + auto gradient_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3); + for (size_t i = 0; i < gradient_shape.size(); i++) { + gradient_size_ *= gradient_shape[i]; + } + + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(variable_size_); + input_size_list_.push_back(accumulation_size_); + input_size_list_.push_back(linear_size_); + input_size_list_.push_back(gradient_size_); + input_size_list_.push_back(learning_rate_size_); + input_size_list_.push_back(l1_regularization_size_); + input_size_list_.push_back(l2_regularization_size_); + input_size_list_.push_back(learning_rate_power_size_); + output_size_list_.push_back(0); + } + + private: + size_t variable_size_; + size_t accumulation_size_; + size_t linear_size_; + size_t gradient_size_; + size_t learning_rate_size_; + size_t l1_regularization_size_; + size_t l2_regularization_size_; + size_t learning_rate_power_size_; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_FTRL_GPU_KERNEL_H_ diff --git a/tests/st/ops/gpu/test_ftrl_op.py b/tests/st/ops/gpu/test_ftrl_op.py new file mode 100644 index 0000000000..df40d65b9f --- /dev/null +++ b/tests/st/ops/gpu/test_ftrl_op.py @@ -0,0 +1,78 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.nn import Dense +from mindspore.nn import TrainOneStepCell, WithLossCell +from mindspore.nn.optim import FTRL +from mindspore.ops import operations as P + +context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + + +class NetFtrl(nn.Cell): + def __init__(self): + super(NetFtrl, self).__init__() + self.batch_size = 1 + self.reshape = P.Reshape() + weight = Tensor(np.ones([10, 16]).astype(np.float32) * 0.01) + self.fc1 = Dense(16, 10, weight_init=weight) + + def construct(self, input_x): + output = self.reshape(input_x, (self.batch_size, -1)) + output = self.fc1(output) + return output + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_ftrl(): + epoch = 3 + net = NetFtrl() + optimizer = FTRL(filter(lambda x: x.requires_grad, + net.get_parameters()), learning_rate=0.01) + criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) + net_with_criterion = WithLossCell(net, criterion) + train_network = TrainOneStepCell( + net_with_criterion, optimizer) + train_network.set_train() + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + losses1 = [] + for _ in range(epoch): + data = Tensor(np.arange(0, 16).reshape( + 1, 1, 4, 4).astype(np.float32) * 0.01) + label = Tensor(np.array([0]).astype(np.int32)) + loss = train_network(data, label) + losses1.append(loss.asnumpy()) + assert losses1[0] > losses1[1] + assert losses1[1] > losses1[2] + + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + losses2 = [] + for _ in range(epoch): + data = Tensor(np.arange(0, 16).reshape( + 1, 1, 4, 4).astype(np.float32) * 0.01) + label = Tensor(np.array([0]).astype(np.int32)) + loss = train_network(data, label) + losses2.append(loss.asnumpy()) + assert losses2[0] > losses2[1] + assert losses2[1] > losses2[2]