parent
							
								
									e6f20bfb64
								
							
						
					
					
						commit
						c4c2562e8b
					
				| After Width: | Height: | Size: 36 KiB | 
| After Width: | Height: | Size: 37 KiB | 
| @ -0,0 +1,17 @@ | |||||||
|  | #!/bin/bash | ||||||
|  | # Copyright 2020 Huawei Technologies Co., Ltd | ||||||
|  | # | ||||||
|  | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | # you may not use this file except in compliance with the License. | ||||||
|  | # You may obtain a copy of the License at | ||||||
|  | # | ||||||
|  | # http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  | # | ||||||
|  | # Unless required by applicable law or agreed to in writing, software | ||||||
|  | # distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | # See the License for the specific language governing permissions and | ||||||
|  | # limitations under the License. | ||||||
|  | # ============================================================================ | ||||||
|  | 
 | ||||||
|  | python3 train_eval.py > result.log 2>&1 &  | ||||||
| @ -0,0 +1,99 @@ | |||||||
|  | # Copyright 2020 Huawei Technologies Co., Ltd | ||||||
|  | # | ||||||
|  | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | # you may not use this file except in compliance with the License. | ||||||
|  | # You may obtain a copy of the License at | ||||||
|  | # | ||||||
|  | # http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  | # | ||||||
|  | # Unless required by applicable law or agreed to in writing, software | ||||||
|  | # distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | # See the License for the specific language governing permissions and | ||||||
|  | # limitations under the License. | ||||||
|  | # ============================================================================ | ||||||
|  | """ | ||||||
|  | MovieLens Environment. | ||||||
|  | """ | ||||||
|  | 
 | ||||||
|  | import random | ||||||
|  | import numpy as np | ||||||
|  | from mindspore import Tensor | ||||||
|  | 
 | ||||||
|  | _MAX_NUM_ACTIONS = 1682 | ||||||
|  | _NUM_USERS = 943 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def load_movielens_data(data_file): | ||||||
|  |     """Loads the movielens data and returns the ratings matrix.""" | ||||||
|  |     ratings_matrix = np.zeros([_NUM_USERS, _MAX_NUM_ACTIONS]) | ||||||
|  |     with open(data_file, 'r') as f: | ||||||
|  |         for line in f.readlines(): | ||||||
|  |             row_infos = line.strip().split() | ||||||
|  |             user_id = int(row_infos[0]) | ||||||
|  |             item_id = int(row_infos[1]) | ||||||
|  |             rating = float(row_infos[2]) | ||||||
|  |             ratings_matrix[user_id - 1, item_id - 1] = rating | ||||||
|  |     return ratings_matrix | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class MovieLensEnv: | ||||||
|  |     """ | ||||||
|  |     MovieLens dataset environment for bandit algorithms. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         data_file(str): path of movielens file, e.g. 'ua.base'. | ||||||
|  |         num_movies(int): number of movies for choices. | ||||||
|  |         rank_k(int): the dim of feature. | ||||||
|  | 
 | ||||||
|  |     Returns: | ||||||
|  |         Environment for bandit algorithms. | ||||||
|  |     """ | ||||||
|  |     def __init__(self, data_file, num_movies, rank_k): | ||||||
|  |         # Initialization | ||||||
|  |         self._num_actions = num_movies | ||||||
|  |         self._context_dim = rank_k | ||||||
|  | 
 | ||||||
|  |         # Load Movielens dataset | ||||||
|  |         self._data_matrix = load_movielens_data(data_file) | ||||||
|  |         # Keep only the first items | ||||||
|  |         self._data_matrix = self._data_matrix[:, :num_movies] | ||||||
|  |         # Filter the users with at least one rating score | ||||||
|  |         nonzero_users = list( | ||||||
|  |             np.nonzero( | ||||||
|  |                 np.sum( | ||||||
|  |                     self._data_matrix, | ||||||
|  |                     axis=1) > 0.0)[0]) | ||||||
|  |         self._data_matrix = self._data_matrix[nonzero_users, :] | ||||||
|  |         # Normalize the data_matrix into -1~1 | ||||||
|  |         self._data_matrix = 0.4 * (self._data_matrix - 2.5) | ||||||
|  | 
 | ||||||
|  |         # Compute the SVD # Only keep the largest rank_k singular values | ||||||
|  |         u, s, vh = np.linalg.svd(self._data_matrix, full_matrices=False) | ||||||
|  |         u_hat = u[:, :rank_k] * np.sqrt(s[:rank_k]) | ||||||
|  |         v_hat = np.transpose(np.transpose( | ||||||
|  |             vh[:rank_k, :]) * np.sqrt(s[:rank_k])) | ||||||
|  |         self._approx_ratings_matrix = np.matmul( | ||||||
|  |             u_hat, v_hat).astype(np.float32) | ||||||
|  | 
 | ||||||
|  |         # Prepare feature for user i and item j: u[i,:] * vh[:,j] | ||||||
|  |         # (elementwise product of user feature and item feature) | ||||||
|  |         self._ground_truth = s | ||||||
|  |         self._current_user = 0 | ||||||
|  |         self._feature = np.expand_dims(u[:, :rank_k], axis=1) * \ | ||||||
|  |             np.expand_dims(np.transpose(vh[:rank_k, :]), axis=0) | ||||||
|  |         self._feature = self._feature.astype(np.float32) | ||||||
|  | 
 | ||||||
|  |     @property | ||||||
|  |     def ground_truth(self): | ||||||
|  |         return self._ground_truth | ||||||
|  | 
 | ||||||
|  |     def observation(self): | ||||||
|  |         """random select a user and return its feature.""" | ||||||
|  |         sampled_user = random.randint(0, self._data_matrix.shape[0] - 1) | ||||||
|  |         self._current_user = sampled_user | ||||||
|  |         return Tensor(self._feature[sampled_user]) | ||||||
|  | 
 | ||||||
|  |     def current_rewards(self): | ||||||
|  |         """rewards for current user.""" | ||||||
|  |         return Tensor(self._approx_ratings_matrix[self._current_user]) | ||||||
| @ -0,0 +1,143 @@ | |||||||
|  | # Copyright 2020 Huawei Technologies Co., Ltd | ||||||
|  | # | ||||||
|  | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | # you may not use this file except in compliance with the License. | ||||||
|  | # You may obtain a copy of the License at | ||||||
|  | # | ||||||
|  | # http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  | # | ||||||
|  | # Unless required by applicable law or agreed to in writing, software | ||||||
|  | # distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | # See the License for the specific language governing permissions and | ||||||
|  | # limitations under the License. | ||||||
|  | # ============================================================================ | ||||||
|  | """ | ||||||
|  | Linear UCB with locally differentially private. | ||||||
|  | """ | ||||||
|  | 
 | ||||||
|  | import math | ||||||
|  | import numpy as np | ||||||
|  | 
 | ||||||
|  | import mindspore | ||||||
|  | import mindspore.nn as nn | ||||||
|  | from mindspore import Tensor | ||||||
|  | from mindspore.ops import operations as P | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class LinUCB(nn.Cell): | ||||||
|  |     """ | ||||||
|  |     Linear UCB with locally differentially private bandits learning. | ||||||
|  | 
 | ||||||
|  |     Args: | ||||||
|  |         context_dim(int): dim of input feature. | ||||||
|  |         epsilon(float): epsilon for private parameter. | ||||||
|  |         delta(float): delta for private parameter. | ||||||
|  |         alpha(float): failure probability. | ||||||
|  |         T(float/int): number of iterations. | ||||||
|  | 
 | ||||||
|  |     Returns: | ||||||
|  |         Tuple of Tensors: gradients to update parameters and optimal action. | ||||||
|  |     """ | ||||||
|  |     def __init__(self, context_dim, epsilon=100, delta=0.1, alpha=0.1, T=1e5): | ||||||
|  |         super(LinUCB, self).__init__() | ||||||
|  |         self.matmul = P.MatMul() | ||||||
|  |         self.expand_dims = P.ExpandDims() | ||||||
|  |         self.transpose = P.Transpose() | ||||||
|  |         self.reduce_sum = P.ReduceSum() | ||||||
|  |         self.squeeze = P.Squeeze(1) | ||||||
|  |         self.argmax = P.Argmax() | ||||||
|  |         self.reduce_max = P.ReduceMax() | ||||||
|  | 
 | ||||||
|  |         # Basic variables | ||||||
|  |         self._context_dim = context_dim | ||||||
|  |         self._epsilon = epsilon | ||||||
|  |         self._delta = delta | ||||||
|  |         self._alpha = alpha | ||||||
|  |         self._T = int(T) | ||||||
|  | 
 | ||||||
|  |         # Parameters | ||||||
|  |         self._V = Tensor( | ||||||
|  |             np.zeros( | ||||||
|  |                 (context_dim, | ||||||
|  |                  context_dim), | ||||||
|  |                 dtype=np.float32)) | ||||||
|  |         self._u = Tensor(np.zeros((context_dim,), dtype=np.float32)) | ||||||
|  |         self._theta = Tensor(np.zeros((context_dim,), dtype=np.float32)) | ||||||
|  | 
 | ||||||
|  |         # \sigma = 4*\sqrt{2*\ln{\farc{1.25}{\delta}}}/\epsilon | ||||||
|  |         self._sigma = 4 * \ | ||||||
|  |             math.sqrt(math.log(1.25 / self._delta)) / self._epsilon | ||||||
|  |         self._c = 0.1 | ||||||
|  |         self._step = 1 | ||||||
|  |         self._regret = 0 | ||||||
|  |         self._current_regret = 0 | ||||||
|  |         self.inverse_matrix() | ||||||
|  | 
 | ||||||
|  |     @property | ||||||
|  |     def theta(self): | ||||||
|  |         return self._theta | ||||||
|  | 
 | ||||||
|  |     @property | ||||||
|  |     def regret(self): | ||||||
|  |         return self._regret | ||||||
|  | 
 | ||||||
|  |     @property | ||||||
|  |     def current_regret(self): | ||||||
|  |         return self._current_regret | ||||||
|  | 
 | ||||||
|  |     def inverse_matrix(self): | ||||||
|  |         """compute the inverse matrix of parameter matrix.""" | ||||||
|  |         Vc = self._V + Tensor(np.eye(self._context_dim, | ||||||
|  |                                      dtype=np.float32)) * self._c | ||||||
|  |         self._Vc_inv = Tensor(np.linalg.inv(Vc.asnumpy()), mindspore.float32) | ||||||
|  | 
 | ||||||
|  |     def update_status(self, step): | ||||||
|  |         """update status variables.""" | ||||||
|  |         t = max(step, 1) | ||||||
|  |         T = self._T | ||||||
|  |         d = self._context_dim | ||||||
|  |         alpha = self._alpha | ||||||
|  |         sigma = self._sigma | ||||||
|  | 
 | ||||||
|  |         gamma = sigma * \ | ||||||
|  |             math.sqrt(t) * (4 * math.sqrt(d) + 2 * math.log(2 * T / alpha)) | ||||||
|  |         self._c = 2 * gamma | ||||||
|  |         self._beta = 2 * sigma * math.sqrt(d * math.log(T)) + (math.sqrt( | ||||||
|  |             3 * gamma) + sigma * math.sqrt(d * t / gamma)) * d * math.log(T) | ||||||
|  | 
 | ||||||
|  |     def construct(self, x, rewards): | ||||||
|  |         """compute the perturbed gradients for parameters.""" | ||||||
|  |         # Choose optimal action | ||||||
|  |         x_transpose = self.transpose(x, (1, 0)) | ||||||
|  |         scores_a = self.squeeze(self.matmul(x, self.expand_dims(self._theta, 1))) | ||||||
|  |         scores_b = x_transpose * self.matmul(self._Vc_inv, x_transpose) | ||||||
|  |         scores_b = self.reduce_sum(scores_b, 0) | ||||||
|  |         scores = scores_a + self._beta * scores_b | ||||||
|  |         max_a = self.argmax(scores) | ||||||
|  |         xa = x[max_a] | ||||||
|  |         xaxat = self.matmul(self.expand_dims(xa, -1), self.expand_dims(xa, 0)) | ||||||
|  |         y = rewards[max_a] | ||||||
|  |         y_max = self.reduce_max(rewards) | ||||||
|  |         y_diff = y_max - y | ||||||
|  |         self._current_regret = float(y_diff.asnumpy()) | ||||||
|  |         self._regret += self._current_regret | ||||||
|  | 
 | ||||||
|  |         # Prepare noise | ||||||
|  |         B = np.random.normal(0, self._sigma, size=xaxat.shape) | ||||||
|  |         B = np.triu(B) | ||||||
|  |         B += B.transpose() - np.diag(B.diagonal()) | ||||||
|  |         B = Tensor(B.astype(np.float32)) | ||||||
|  |         Xi = np.random.normal(0, self._sigma, size=xa.shape) | ||||||
|  |         Xi = Tensor(Xi.astype(np.float32)) | ||||||
|  | 
 | ||||||
|  |         # Add noise and update parameters | ||||||
|  |         return xaxat + B, xa * y + Xi, max_a | ||||||
|  | 
 | ||||||
|  |     def server_update(self, xaxat, xay): | ||||||
|  |         """update parameters with perturbed gradients.""" | ||||||
|  |         self._V += xaxat | ||||||
|  |         self._u += xay | ||||||
|  |         self.inverse_matrix() | ||||||
|  |         theta = self.matmul(self._Vc_inv, self.expand_dims(self._u, 1)) | ||||||
|  |         self._theta = self.squeeze(theta) | ||||||
| @ -0,0 +1,89 @@ | |||||||
|  | # Copyright 2020 Huawei Technologies Co., Ltd | ||||||
|  | # | ||||||
|  | # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  | # you may not use this file except in compliance with the License. | ||||||
|  | # You may obtain a copy of the License at | ||||||
|  | # | ||||||
|  | # http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  | # | ||||||
|  | # Unless required by applicable law or agreed to in writing, software | ||||||
|  | # distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  | # See the License for the specific language governing permissions and | ||||||
|  | # limitations under the License. | ||||||
|  | # ============================================================================ | ||||||
|  | """ | ||||||
|  | train/eval. | ||||||
|  | """ | ||||||
|  | 
 | ||||||
|  | import argparse | ||||||
|  | import time | ||||||
|  | import numpy as np | ||||||
|  | import matplotlib.pyplot as plt | ||||||
|  | 
 | ||||||
|  | from src.dataset import MovieLensEnv | ||||||
|  | from src.linucb import LinUCB | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def parse_args(): | ||||||
|  |     """parse args""" | ||||||
|  |     parser = argparse.ArgumentParser() | ||||||
|  |     parser.add_argument('--data_file', type=str, default='ua.base', | ||||||
|  |                         help='data file for movielens') | ||||||
|  |     parser.add_argument('--rank_k', type=int, default=20, | ||||||
|  |                         help='rank for data matrix') | ||||||
|  |     parser.add_argument('--num_actions', type=int, default=20, | ||||||
|  |                         help='movie number for choices') | ||||||
|  |     parser.add_argument('--epsilon', type=float, default=8e5, | ||||||
|  |                         help='epsilon for differentially private') | ||||||
|  |     parser.add_argument('--delta', type=float, default=1e-1, | ||||||
|  |                         help='delta for differentially private') | ||||||
|  |     parser.add_argument('--alpha', type=float, default=1e-1, | ||||||
|  |                         help='failure probability') | ||||||
|  |     parser.add_argument('--iter_num', type=float, default=1e6, | ||||||
|  |                         help='iteration number for training') | ||||||
|  | 
 | ||||||
|  |     args_opt = parser.parse_args() | ||||||
|  |     return args_opt | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | if __name__ == '__main__': | ||||||
|  |     # build environment | ||||||
|  |     args = parse_args() | ||||||
|  |     env = MovieLensEnv(args.data_file, args.num_actions, args.rank_k) | ||||||
|  | 
 | ||||||
|  |     # Linear UCB | ||||||
|  |     lin_ucb = LinUCB( | ||||||
|  |         args.rank_k, | ||||||
|  |         epsilon=args.epsilon, | ||||||
|  |         delta=args.delta, | ||||||
|  |         alpha=args.alpha, | ||||||
|  |         T=args.iter_num) | ||||||
|  | 
 | ||||||
|  |     print('start') | ||||||
|  |     start_time = time.time() | ||||||
|  |     cumulative_regrets = [] | ||||||
|  |     for i in range(int(args.iter_num)): | ||||||
|  |         x = env.observation() | ||||||
|  |         rewards = env.current_rewards() | ||||||
|  |         lin_ucb.update_status(i + 1) | ||||||
|  |         xaxat, xay, max_a = lin_ucb(x, rewards) | ||||||
|  |         cumulative_regrets.append(float(lin_ucb.regret)) | ||||||
|  |         lin_ucb.server_update(xaxat, xay) | ||||||
|  |         diff = np.abs(lin_ucb.theta.asnumpy() - env.ground_truth).sum() | ||||||
|  |         print( | ||||||
|  |             f'--> Step: {i}, diff: {diff:.3f},' | ||||||
|  |             f'current_regret: {lin_ucb.current_regret:.3f},' | ||||||
|  |             f'cumulative regret: {lin_ucb.regret:.3f}') | ||||||
|  |     end_time = time.time() | ||||||
|  |     print(f'Regret: {lin_ucb.regret}, cost time: {end_time-start_time:.3f}s') | ||||||
|  |     print(f'theta: {lin_ucb.theta.asnumpy()}') | ||||||
|  |     print(f'   gt: {env.ground_truth}') | ||||||
|  | 
 | ||||||
|  |     np.save(f'e_{args.epsilon:.1e}.npy', cumulative_regrets) | ||||||
|  |     plt.plot( | ||||||
|  |         range(len(cumulative_regrets)), | ||||||
|  |         cumulative_regrets, | ||||||
|  |         label=f'epsilon={args.epsilon:.1e}') | ||||||
|  |     plt.legend() | ||||||
|  |     plt.savefig(f'regret_{args.epsilon:.1e}.png') | ||||||
					Loading…
					
					
				
		Reference in new issue