add optimizer:dpsgd,test=develop (#19915)
parent 37f76407b0
commit 766bd529d1
@ -0,0 +1,107 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/optimizers/dpsgd_op.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

class DpsgdOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE_EQ(ctx->HasInput("Param"), true,
                      "Input(Param) of DpsgdOp should not be null.");
    PADDLE_ENFORCE_EQ(ctx->HasInput("Grad"), true,
                      "Input(Grad) of DpsgdOp should not be null.");
    PADDLE_ENFORCE_EQ(ctx->HasInput("LearningRate"), true,
                      "Input(LearningRate) of DpsgdOp should not be null.");
    PADDLE_ENFORCE_EQ(
        ctx->GetInputsVarType("Param").front(),
        framework::proto::VarType::LOD_TENSOR,
        "The input Var(%s)'s type should be LoDTensor, but the received is %s",
        ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front());
    PADDLE_ENFORCE_EQ(
        ctx->GetInputsVarType("Grad").front(),
        framework::proto::VarType::LOD_TENSOR,
        "The input Var(%s)'s type should be LoDTensor, but the received is %s",
        ctx->Inputs("Grad").front(), ctx->GetInputsVarType("Grad").front());

    PADDLE_ENFORCE_EQ(ctx->HasOutput("ParamOut"), true,
                      "Output(ParamOut) of DpsgdOp should not be null.");

    auto lr_dims = ctx->GetInputDim("LearningRate");
    PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
                      "Learning rate should have 1 element");
    auto param_dims = ctx->GetInputDim("Param");
    PADDLE_ENFORCE_EQ(
        param_dims, ctx->GetInputDim("Grad"),
        "Param and Grad input of DpsgdOp should have same dimension");

    ctx->SetOutputDim("ParamOut", param_dims);
  }

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    return framework::OpKernelType(ctx.Input<Tensor>("Param")->type(),
                                   ctx.GetPlace());
  }
};

class DpsgdOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Param", "(Tensor) Input parameter");
    AddInput("Grad", "(Tensor) Input gradient");
    AddInput("LearningRate", "(Tensor) Learning rate");

    AddOutput("ParamOut", "(Tensor) Output parameter");

AddAttr<float>("clip",
|
||||||
|
"(float, default 0.9) "
|
||||||
|
"Exponential decay rate for the "
|
||||||
|
"1st moment estimates.")
|
||||||
|
.SetDefault(10.0f);
|
||||||
|
AddAttr<float>("batch_size",
|
||||||
|
"(float, default 0.999) "
|
||||||
|
"exponential decay rate for the weighted "
|
||||||
|
"infinity norm estimates.")
|
||||||
|
.SetDefault(16.0f);
|
||||||
|
AddAttr<float>("sigma",
|
||||||
|
"(float, default 1.0e-8) "
|
||||||
|
"Constant for numerical stability")
|
||||||
|
.SetDefault(1.0f);
|
||||||
|
AddComment(R"DOC(
|
||||||
|
Dpsgd Optimizer.
|
||||||
|
|
||||||
|
We implement the Dpsgd optimizer according to CCS16 paper -
|
||||||
|
Deep Learning with Differential Privacy.
|
||||||
|
|
||||||
|
Dpsgd updates:
|
||||||
|
CCS16 - Deep Learning with Differential Privacy.
|
||||||
|
[https://arxiv.org/abs/1607.00133]
|
||||||
|
|
||||||
|
)DOC");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace operators
|
||||||
|
} // namespace paddle
|
||||||
|
|
||||||
|
namespace ops = paddle::operators;
|
||||||
|
REGISTER_OP_WITHOUT_GRADIENT(dpsgd, ops::DpsgdOp, ops::DpsgdOpMaker);
|
||||||
|
REGISTER_OP_CPU_KERNEL(
|
||||||
|
dpsgd, ops::DpsgdOpKernel<paddle::platform::CPUDeviceContext, float>,
|
||||||
|
ops::DpsgdOpKernel<paddle::platform::CPUDeviceContext, double>);
|
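Editorial note: in the notation of the attributes above, the update performed by the CPU kernel in dpsgd_op.h (next file) can be summarized as

\[
\begin{aligned}
s &= \max\!\left(1,\ \frac{\lVert g \rVert_2}{\mathrm{clip}}\right), \qquad
z \sim \mathcal{N}(0,\ \mathrm{sigma}^2)\ \ \text{(a single scalar draw per update)}, \\
\theta_{t+1} &= \theta_t - \mathrm{lr}\cdot\left(\frac{g}{s} + \frac{z}{\mathrm{batch\_size}}\right),
\end{aligned}
\]

where g is the gradient of the current lot. This is a summary of what the code computes, not necessarily the exact formulation used in the paper.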
@ -0,0 +1,114 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <math.h>
#include <stdlib.h>
#include <time.h>  // time(NULL) seeds the noise generator below
#include <iostream>
#include <random>  // std::minstd_rand, std::uniform_real_distribution
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
class DpsgdOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    const auto *param_var = ctx.InputVar("Param");
    PADDLE_ENFORCE_EQ(param_var->IsType<framework::LoDTensor>(), true,
                      "The Var(%s)'s type should be LoDTensor, "
                      "but the received is %s",
                      ctx.Inputs("Param").front(),
                      framework::ToTypeName(param_var->Type()));

    const auto *grad_var = ctx.InputVar("Grad");
    PADDLE_ENFORCE_EQ(grad_var->IsType<framework::LoDTensor>(), true,
                      "The Var(%s)'s type should be LoDTensor, "
                      "but the received is %s",
                      ctx.Inputs("Grad").front(),
                      framework::ToTypeName(grad_var->Type()));

    const auto *learning_rate = ctx.Input<framework::Tensor>("LearningRate");

    const auto *param = ctx.Input<framework::Tensor>("Param");
    const auto *grad = ctx.Input<framework::Tensor>("Grad");

    auto *param_out = ctx.Output<framework::Tensor>("ParamOut");

    auto sz = param_out->numel();
    PADDLE_ENFORCE_EQ(param->numel(), sz);
    PADDLE_ENFORCE_EQ(grad->numel(), sz);

    const T *lr = learning_rate->data<T>();
    const T *param_data = param->data<T>();
    const T *grad_data = grad->data<T>();

    T *out_data = param_out->mutable_data<T>(ctx.GetPlace());

    T clip = static_cast<T>(ctx.Attr<float>("clip"));
    T batch_size = static_cast<T>(ctx.Attr<float>("batch_size"));
    T sigma = static_cast<T>(ctx.Attr<float>("sigma"));
    // compute clipping
    float l2_norm = 0.0;
    for (int64_t i = 0; i < grad->numel(); ++i) {
      l2_norm = l2_norm + grad_data[i] * grad_data[i];
    }
    l2_norm = std::sqrt(l2_norm);

    float scale = 1.0;
    if (l2_norm > clip) {
      scale = l2_norm / clip;
    }

    // generate gaussian noise.
    // [https://en.wikipedia.org/wiki/Box-Muller_transform]
    float V1, V2, S;
    float X;
    float mu = 0.0;
    float U1, U2;
    unsigned seed = (unsigned int)(time(NULL));
    std::minstd_rand engine;
    engine.seed(seed);
    std::uniform_real_distribution<T> dist(0.0, 1.0);
    do {
      // srand((unsigned int)(time(NULL)));
      // U1 = (rand() * 1.0) / RAND_MAX;
      // U2 = (rand() * 1.0) / RAND_MAX;
      // U1 = rand_rr(&seed) * (1.0 / RAND_MAX);
      // U2 = rand_rr(&seed) * (1.0 / RAND_MAX);
      U1 = dist(engine);
      U2 = dist(engine);
      V1 = 2 * U1 - 1;
      V2 = 2 * U2 - 1;
      S = V1 * V1 + V2 * V2;
    } while (S >= 1 || S == 0);

    X = V1 * sqrt(-2 * log(S) / S);

    float gaussian_noise = mu + X * sigma;

    // update parameters
    for (int64_t i = 0; i < grad->numel(); ++i) {
      out_data[i] =
          param_data[i] -
          lr[0] * (grad_data[i] / scale + gaussian_noise / batch_size);
    }
    // CCS16 - Deep Learning with Differential Privacy.
    // [https://arxiv.org/abs/1607.00133]
  }  // Compute
};   // class
}  // namespace operators
}  // namespace paddle
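For reference, a minimal NumPy sketch of the update performed by DpsgdOpKernel::Compute above. The helper name dpsgd_update and its defaults are illustrative only and not part of the patch; the single scalar noise draw mirrors the one Box-Muller sample the kernel adds to every element.

import numpy as np


def dpsgd_update(param, grad, lr, clip=10.0, batch_size=16.0, sigma=1.0,
                 rng=np.random):
    """NumPy mirror of DpsgdOpKernel::Compute (editorial sketch, not the op)."""
    # Scale the gradient so that its L2 norm does not exceed `clip`.
    l2_norm = np.sqrt(np.sum(grad * grad))
    scale = l2_norm / clip if l2_norm > clip else 1.0
    # The kernel adds one scalar Gaussian sample, averaged over the lot size.
    noise = rng.normal(loc=0.0, scale=sigma)
    return param - lr * (grad / scale + noise / batch_size)


# Example: new_param = dpsgd_update(param, grad, lr=0.001, sigma=0.0) gives a
# plain SGD step, which is the case exercised by the unit test below.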
@ -0,0 +1,73 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
from op_test import OpTest


class TestDpsgdOp(OpTest):
    def setUp(self):
        '''Test Dpsgd Operator with supplied attributes
        '''
        self.op_type = "dpsgd"
        param = np.random.uniform(-1, 1, (102, 105)).astype("float32")
        grad = np.random.uniform(-1, 1, (102, 105)).astype("float32")

        # sigma = 0 disables the noise and the clip threshold is far above the
        # gradient's L2 norm, so the op reduces to plain SGD and the reference
        # result can be computed deterministically in dpsgd_step below.
        learning_rate = 0.001
        clip = 10000.0
        batch_size = 16.0
        sigma = 0.0

        self.inputs = {
            'Param': param,
            'Grad': grad,
            'LearningRate': np.array([learning_rate]).astype("float32")
        }

        self.attrs = {'clip': clip, 'batch_size': batch_size, 'sigma': sigma}

        param_out = dpsgd_step(self.inputs, self.attrs)

        self.outputs = {'ParamOut': param_out}

    def test_check_output(self):
        self.check_output()
def dpsgd_step(inputs, attributes):
    '''
    Simulate one step of the dpsgd optimizer
    :param inputs: dict of inputs
    :param attributes: dict of attributes
    :return: the updated parameter; with sigma = 0 and no clipping this is
             a plain SGD step
    '''
    param = inputs['Param']
    grad = inputs['Grad']
    lr = inputs['LearningRate']

    # clip, batch_size and sigma are read for completeness; they do not affect
    # the result because the test uses sigma = 0 and a clip threshold larger
    # than the gradient norm.
    clip = attributes['clip']
    batch_size = attributes['batch_size']
    sigma = attributes['sigma']

    param_out = param - lr * grad

    return param_out

if __name__ == "__main__":
    unittest.main()