parent
e6f20bfb64
commit
c4c2562e8b
After Width: | Height: | Size: 36 KiB |
After Width: | Height: | Size: 37 KiB |
@ -0,0 +1,17 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
# Run train_eval.py as a background job; stdout and stderr are both written to result.log.
python3 train_eval.py > result.log 2>&1 &
|
@ -0,0 +1,99 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
MovieLens Environment.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import random
|
||||||
|
import numpy as np
|
||||||
|
from mindspore import Tensor
|
||||||
|
|
||||||
|
_MAX_NUM_ACTIONS = 1682
_NUM_USERS = 943


def load_movielens_data(data_file):
    """Load a ratings file into a dense user-by-item ratings matrix.

    Every non-blank line is split on whitespace. The first three fields are
    read as a 1-based user id, a 1-based item id and a rating; any further
    fields are ignored.

    Args:
        data_file(str): path of the ratings file, e.g. 'ua.base'.

    Returns:
        numpy.ndarray, matrix of shape (_NUM_USERS, _MAX_NUM_ACTIONS) that is
        zero wherever no rating was read.

    Raises:
        ValueError: if a user id or item id is smaller than 1.
        IndexError: if an id exceeds the matrix size or a line has fewer
            than three fields.
    """
    ratings_matrix = np.zeros([_NUM_USERS, _MAX_NUM_ACTIONS])
    with open(data_file, 'r') as f:
        # Iterate the file lazily instead of materialising every line first.
        for line in f:
            row_infos = line.split()
            # A blank line (e.g. a trailing newline) carries no rating.
            if not row_infos:
                continue
            user_id = int(row_infos[0])
            item_id = int(row_infos[1])
            rating = float(row_infos[2])
            # Ids are 1-based; 0 or a negative id would otherwise wrap around
            # to the end of the matrix through negative indexing.
            if user_id < 1 or item_id < 1:
                raise ValueError(
                    f'user id and item id must be >= 1, got: {line.strip()!r}')
            ratings_matrix[user_id - 1, item_id - 1] = rating
    return ratings_matrix
|
||||||
|
|
||||||
|
|
||||||
|
class MovieLensEnv:
    """
    MovieLens dataset environment for bandit algorithms.

    The ratings matrix is cut down to the first `num_movies` items, users whose
    ratings over those items do not sum to a positive value are dropped, and an
    SVD truncated to `rank_k` components supplies both the per-(user, item)
    features and the rewards.

    Args:
        data_file(str): path of movielens file, e.g. 'ua.base'.
        num_movies(int): number of movies for choices.
        rank_k(int): the dim of feature.

    Returns:
        Environment for bandit algorithms.
    """
    def __init__(self, data_file, num_movies, rank_k):
        # Initialization
        self._num_actions = num_movies
        self._context_dim = rank_k

        # Load Movielens dataset
        self._data_matrix = load_movielens_data(data_file)
        # Keep only the first items
        self._data_matrix = self._data_matrix[:, :num_movies]
        # Keep the users whose ratings over the kept items sum to more than 0
        nonzero_users = list(
            np.nonzero(
                np.sum(
                    self._data_matrix,
                    axis=1) > 0.0)[0])
        self._data_matrix = self._data_matrix[nonzero_users, :]
        # Affine rescale: a stored 0 maps to -1 and a stored 5 maps to 1
        self._data_matrix = 0.4 * (self._data_matrix - 2.5)

        # Compute the SVD and only keep the largest rank_k singular values
        u, s, vh = np.linalg.svd(self._data_matrix, full_matrices=False)
        u_hat = u[:, :rank_k] * np.sqrt(s[:rank_k])
        v_hat = np.transpose(np.transpose(
            vh[:rank_k, :]) * np.sqrt(s[:rank_k]))
        # Rank-k reconstruction; row i holds the rewards of user i
        self._approx_ratings_matrix = np.matmul(
            u_hat, v_hat).astype(np.float32)

        # Prepare feature for user i and item j: u[i,:] * vh[:,j]
        # (elementwise product of user feature and item feature)
        # The reward above equals dot(feature[i, j], s[:rank_k]), so the linear
        # parameter has exactly rank_k entries. Exposing all of `s` would give
        # a vector whose length differs from the feature dim whenever rank_k
        # is smaller than the number of singular values.
        self._ground_truth = s[:rank_k]
        self._current_user = 0
        # Shape (num_kept_users, num_movies, rank_k) through broadcasting
        self._feature = np.expand_dims(u[:, :rank_k], axis=1) * \
            np.expand_dims(np.transpose(vh[:rank_k, :]), axis=0)
        self._feature = self._feature.astype(np.float32)

    @property
    def ground_truth(self):
        """The first rank_k singular values, i.e. the weights that map a feature to its reward."""
        return self._ground_truth

    def observation(self):
        """random select a user and return its feature."""
        # random.randint is inclusive on both ends, hence the -1
        sampled_user = random.randint(0, self._data_matrix.shape[0] - 1)
        # Remembered so that current_rewards() refers to the same user
        self._current_user = sampled_user
        return Tensor(self._feature[sampled_user])

    def current_rewards(self):
        """rewards for current user."""
        return Tensor(self._approx_ratings_matrix[self._current_user])
|
@ -0,0 +1,143 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
Linear UCB with local differential privacy.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import mindspore
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
|
||||||
|
|
||||||
|
class LinUCB(nn.Cell):
    """
    Linear UCB bandit learning with local differential privacy.

    Calling the cell picks an action for one round and returns Gaussian-perturbed
    statistics of the chosen action; `server_update` accumulates those statistics
    and re-estimates `theta`.

    Args:
        context_dim(int): dim of input feature.
        epsilon(float): epsilon for private parameter.
        delta(float): delta for private parameter.
        alpha(float): failure probability.
        T(float/int): number of iterations.

    Returns:
        Tuple of Tensors: gradients to update parameters and optimal action.
    """
    def __init__(self, context_dim, epsilon=100, delta=0.1, alpha=0.1, T=1e5):
        super(LinUCB, self).__init__()
        self.matmul = P.MatMul()
        self.expand_dims = P.ExpandDims()
        self.transpose = P.Transpose()
        self.reduce_sum = P.ReduceSum()
        self.squeeze = P.Squeeze(1)
        self.argmax = P.Argmax()
        self.reduce_max = P.ReduceMax()

        # Basic variables
        self._context_dim = context_dim
        self._epsilon = epsilon
        self._delta = delta
        self._alpha = alpha
        self._T = int(T)

        # Parameters
        self._V = Tensor(
            np.zeros(
                (context_dim,
                 context_dim),
                dtype=np.float32))
        self._u = Tensor(np.zeros((context_dim,), dtype=np.float32))
        self._theta = Tensor(np.zeros((context_dim,), dtype=np.float32))

        # \sigma = 4*\sqrt{2*\ln{\frac{1.25}{\delta}}}/\epsilon
        # The factor 2 inside the square root belongs to the Gaussian mechanism
        # bound; leaving it out makes the noise too small by sqrt(2).
        self._sigma = 4 * \
            math.sqrt(2 * math.log(1.25 / self._delta)) / self._epsilon
        self._c = 0.1
        self._step = 1
        # construct() reads self._beta, so give it the step-1 value here; the
        # cell is then usable even if update_status() has not been called yet.
        # self._c is deliberately left at its initial value.
        _, self._beta = self._gamma_beta(self._step)
        self._regret = 0
        self._current_regret = 0
        self.inverse_matrix()

    @property
    def theta(self):
        """Current estimate of the linear parameter."""
        return self._theta

    @property
    def regret(self):
        """Sum of the per-round regrets recorded by construct()."""
        return self._regret

    @property
    def current_regret(self):
        """Regret recorded by the most recent construct() call."""
        return self._current_regret

    def inverse_matrix(self):
        """compute the inverse matrix of parameter matrix."""
        # Regularise with c * I before inverting
        Vc = self._V + Tensor(np.eye(self._context_dim,
                                     dtype=np.float32)) * self._c
        self._Vc_inv = Tensor(np.linalg.inv(Vc.asnumpy()), mindspore.float32)

    def _gamma_beta(self, step):
        """Return (gamma, beta) for `step`; steps below 1 are treated as 1."""
        t = max(step, 1)
        T = self._T
        d = self._context_dim
        alpha = self._alpha
        sigma = self._sigma

        gamma = sigma * \
            math.sqrt(t) * (4 * math.sqrt(d) + 2 * math.log(2 * T / alpha))
        beta = 2 * sigma * math.sqrt(d * math.log(T)) + (math.sqrt(
            3 * gamma) + sigma * math.sqrt(d * t / gamma)) * d * math.log(T)
        return gamma, beta

    def update_status(self, step):
        """update status variables."""
        gamma, self._beta = self._gamma_beta(step)
        # The new regulariser only takes effect at the next inverse_matrix() call
        self._c = 2 * gamma

    def construct(self, x, rewards):
        """compute the perturbed gradients for parameters."""
        # Choose optimal action
        # x is 2-D with one row per action; its second dim must match theta
        x_transpose = self.transpose(x, (1, 0))
        # Estimated reward of every action: x @ theta
        scores_a = self.squeeze(self.matmul(x, self.expand_dims(self._theta, 1)))
        # Exploration bonus of every action: the quadratic form x_a^T Vc_inv x_a
        scores_b = x_transpose * self.matmul(self._Vc_inv, x_transpose)
        scores_b = self.reduce_sum(scores_b, 0)
        scores = scores_a + self._beta * scores_b
        max_a = self.argmax(scores)
        xa = x[max_a]
        # Outer product xa xa^T
        xaxat = self.matmul(self.expand_dims(xa, -1), self.expand_dims(xa, 0))
        y = rewards[max_a]
        y_max = self.reduce_max(rewards)
        y_diff = y_max - y
        # Regret of this round: best available reward minus the chosen one
        self._current_regret = float(y_diff.asnumpy())
        self._regret += self._current_regret

        # Prepare noise
        # Symmetric Gaussian matrix: sample, keep the upper triangle, mirror it
        B = np.random.normal(0, self._sigma, size=xaxat.shape)
        B = np.triu(B)
        B += B.transpose() - np.diag(B.diagonal())
        B = Tensor(B.astype(np.float32))
        Xi = np.random.normal(0, self._sigma, size=xa.shape)
        Xi = Tensor(Xi.astype(np.float32))

        # Add noise and update parameters
        return xaxat + B, xa * y + Xi, max_a

    def server_update(self, xaxat, xay):
        """update parameters with perturbed gradients."""
        self._V += xaxat
        self._u += xay
        # Refresh the inverse with the accumulated V and the current c
        self.inverse_matrix()
        theta = self.matmul(self._Vc_inv, self.expand_dims(self._u, 1))
        self._theta = self.squeeze(theta)
|
@ -0,0 +1,89 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
"""
|
||||||
|
train/eval.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
from src.dataset import MovieLensEnv
|
||||||
|
from src.linucb import LinUCB
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
    """Parse the command-line options of the train/eval script.

    Returns:
        argparse.Namespace with data_file, rank_k, num_actions, epsilon,
        delta, alpha and iter_num.
    """
    # (flag, type, default, help) for every supported option
    option_table = (
        ('--data_file', str, 'ua.base', 'data file for movielens'),
        ('--rank_k', int, 20, 'rank for data matrix'),
        ('--num_actions', int, 20, 'movie number for choices'),
        ('--epsilon', float, 8e5, 'epsilon for differentially private'),
        ('--delta', float, 1e-1, 'delta for differentially private'),
        ('--alpha', float, 1e-1, 'failure probability'),
        # float rather than int, so that values such as 1e6 are accepted
        ('--iter_num', float, 1e6, 'iteration number for training'),
    )

    cli = argparse.ArgumentParser()
    for flag, value_type, fallback, description in option_table:
        cli.add_argument(flag, type=value_type, default=fallback,
                         help=description)
    return cli.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Read the options once and derive the values that get reused below
    opts = parse_args()
    total_steps = int(opts.iter_num)
    eps_tag = f'{opts.epsilon:.1e}'

    # build environment
    env = MovieLensEnv(opts.data_file, opts.num_actions, opts.rank_k)

    # Linear UCB
    learner = LinUCB(opts.rank_k, epsilon=opts.epsilon, delta=opts.delta,
                     alpha=opts.alpha, T=opts.iter_num)

    print('start')
    regret_curve = []
    t_begin = time.time()
    for step in range(total_steps):
        # Draw a user, then fetch the rewards of that same user
        features = env.observation()
        user_rewards = env.current_rewards()

        # Client side: refresh the status for this round and pick an action
        learner.update_status(step + 1)
        noisy_xxt, noisy_xy, _ = learner(features, user_rewards)
        regret_curve.append(float(learner.regret))

        # Server side: fold the perturbed statistics into the model
        learner.server_update(noisy_xxt, noisy_xy)

        # L1 distance between the estimate and the environment's ground truth
        theta_gap = np.abs(learner.theta.asnumpy() - env.ground_truth).sum()
        print(f'--> Step: {step}, diff: {theta_gap:.3f},'
              f'current_regret: {learner.current_regret:.3f},'
              f'cumulative regret: {learner.regret:.3f}')
    elapsed = time.time() - t_begin

    print(f'Regret: {learner.regret}, cost time: {elapsed:.3f}s')
    print(f'theta: {learner.theta.asnumpy()}')
    print(f' gt: {env.ground_truth}')

    # Save the regret curve as raw numbers and as a figure
    np.save(f'e_{eps_tag}.npy', regret_curve)
    step_axis = range(len(regret_curve))
    plt.plot(step_axis, regret_curve, label=f'epsilon={eps_tag}')
    plt.legend()
    plt.savefig(f'regret_{eps_tag}.png')
|
Loading…
Reference in new issue