parent
9ee77b1f41
commit
d773c6c94e
@ -0,0 +1,74 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" For the PR that only modified the unit test, get cases in pull request. """
|
||||
|
||||
import os
|
||||
import json
|
||||
from github import Github
|
||||
|
||||
PADDLE_ROOT = os.getenv('PADDLE_ROOT', '/paddle/')
|
||||
|
||||
|
||||
class PRChecker(object):
|
||||
""" PR Checker. """
|
||||
|
||||
def __init__(self):
|
||||
self.github = Github(os.getenv('GITHUB_API_TOKEN'), timeout=60)
|
||||
self.repo = self.github.get_repo('PaddlePaddle/Paddle')
|
||||
self.pr = None
|
||||
|
||||
def init(self):
|
||||
""" Get pull request. """
|
||||
pr_id = os.getenv('GIT_PR_ID')
|
||||
if not pr_id:
|
||||
print('No PR ID')
|
||||
exit(0)
|
||||
self.pr = self.repo.get_pull(int(pr_id))
|
||||
|
||||
def get_pr_files(self):
|
||||
""" Get files in pull request. """
|
||||
page = 0
|
||||
file_list = []
|
||||
while True:
|
||||
files = self.pr.get_files().get_page(page)
|
||||
if not files:
|
||||
break
|
||||
for f in files:
|
||||
file_list.append(PADDLE_ROOT + f.filename)
|
||||
page += 1
|
||||
return file_list
|
||||
|
||||
def get_pr_ut(self):
|
||||
""" Get unit tests in pull request. """
|
||||
ut_list = []
|
||||
file_ut_map = None
|
||||
cmd = 'wget -q --no-check-certificate https://sys-p0.bj.bcebos.com/prec/file_ut.json'
|
||||
os.system(cmd)
|
||||
with open('file_ut.json') as jsonfile:
|
||||
file_ut_map = json.load(jsonfile)
|
||||
for f in self.get_pr_files():
|
||||
if f not in file_ut_map:
|
||||
return ''
|
||||
if f.endswith('.h') or f.endswith('.cu'):
|
||||
return ''
|
||||
else:
|
||||
ut_list.extend(file_ut_map.get(f))
|
||||
ut_list = list(set(ut_list))
|
||||
return ' '.join(ut_list)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pr_checker = PRChecker()
|
||||
pr_checker.init()
|
||||
print(pr_checker.get_pr_ut())
|
Loading…
Reference in new issue