215 lines
7.3 KiB
Python
215 lines
7.3 KiB
Python
#!/usr/bin/env python3
|
|
# Licensed to the Apache Software Foundation (ASF) under one
|
|
# or more contributor license agreements. See the NOTICE file
|
|
# distributed with this work for additional information
|
|
# regarding copyright ownership. The ASF licenses this file
|
|
# to you under the Apache License, Version 2.0 (the
|
|
# "License"); you may not use this file except in compliance
|
|
# with the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing,
|
|
# software distributed under the License is distributed on an
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
# KIND, either express or implied. See the License for the
|
|
# specific language governing permissions and limitations
|
|
# under the License.
|
|
|
|
import base64
|
|
import json
|
|
import logging
|
|
import os
|
|
import re
|
|
import subprocess
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
from urllib import error, request
|
|
|
|
# Sentinel token: a GitHubRepo constructed with this as its token only logs
# requests and answers them from its test data instead of contacting GitHub.
DRY_RUN = object()
|
|
|
|
|
|
def compress_query(query: str) -> str:
    """Collapse every run of whitespace in a GraphQL ``query`` to one space.

    Newlines are treated like any other whitespace (replaced, not deleted), so
    field names written on adjacent lines are not fused into a single token.
    Leading and trailing whitespace is removed.
    """
    return re.sub(r"\s+", " ", query).strip()
|
|
|
|
|
|
def post(url: str, body: Optional[Any] = None, auth: Optional[Tuple[str, str]] = None):
    """POST ``body`` as JSON to ``url`` and return the raw response bytes.

    Args:
        url: absolute URL to send the request to.
        body: JSON-serializable payload; ``None`` is serialized as the empty
            JSON string ``""``.
        auth: optional ``(username, password)`` pair, sent as HTTP Basic auth.
    """
    logging.info(f"Requesting POST to {url} with {body}")
    req = request.Request(url, headers={}, method="POST")

    if auth is not None:
        credentials = base64.b64encode(f"{auth[0]}:{auth[1]}".encode()).decode()
        req.add_header("Authorization", f"Basic {credentials}")

    payload = json.dumps("" if body is None else body).encode("utf-8")
    req.add_header("Content-Type", "application/json; charset=utf-8")
    req.add_header("Content-Length", str(len(payload)))

    with request.urlopen(req, payload) as resp:
        return resp.read()
|
|
|
|
|
|
def dry_run_token(is_dry_run: bool) -> Any:
    """Pick the token for a GitHubRepo.

    Returns the ``DRY_RUN`` sentinel when ``is_dry_run`` is true; otherwise the
    ``GITHUB_TOKEN`` environment variable (``KeyError`` if it is unset).
    """
    return DRY_RUN if is_dry_run else os.environ["GITHUB_TOKEN"]
|
|
|
|
|
|
class GitHubRepo:
    """Thin client for GitHub's REST and GraphQL APIs, scoped to ``user/repo``.

    Passing the ``DRY_RUN`` sentinel as ``token`` makes this a dry run: no
    request is sent, each call is logged and answered from ``test_data`` via
    ``testing_response``.
    """

    GRAPHQL_URL = "https://api.github.com/graphql"

    def __init__(self, user, repo, token, test_data=None):
        # API token string, or the DRY_RUN sentinel to disable network access
        self.token = token
        self.user = user
        self.repo = repo
        # Canned dry-run responses keyed "[<call number>] <METHOD> - <url>"
        self.test_data = test_data
        # Count of testing_response() calls; supplies the "[N]" in test_data keys
        self.num_calls = 0
        # put/patch/post/get/delete take URLs relative to this repository prefix
        self.base = f"https://api.github.com/repos/{user}/{repo}/"

    def headers(self):
        """Return the HTTP headers that authenticate a request with ``self.token``."""
        return {
            "Authorization": f"Bearer {self.token}",
        }

    def dry_run(self) -> bool:
        """Return True if this client was constructed with the DRY_RUN sentinel."""
        return self.token == DRY_RUN

    def graphql(
        self, query: str, variables: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Run a GraphQL ``query`` with ``variables`` and return the decoded response.

        Raises:
            RuntimeError: if the response is not a JSON object with a ``"data"``
                field; the message includes the query, variables and response.
        """
        query = compress_query(query)
        if variables is None:
            variables = {}

        response = self._request(
            self.GRAPHQL_URL,
            {"query": query, "variables": variables},
            method="POST",
        )
        if self.dry_run():
            # NOTE(review): in dry-run mode _request() has already consumed one
            # testing_response() call; this second call advances num_calls again
            # and its result is the one returned, so every graphql() call uses
            # two "[N]" slots of test_data. Kept because existing fixtures may be
            # numbered that way -- confirm before simplifying to `return response`.
            return self.testing_response("POST", self.GRAPHQL_URL)

        if not isinstance(response, dict) or "data" not in response:
            try:
                details = json.dumps(response, indent=2)
            except TypeError:
                # _request() returns raw bytes when the body is not valid JSON
                details = repr(response)
            msg = f"Error fetching data with query:\n{query}\n\nvariables:\n{variables}\n\nerror:\n{details}"
            raise RuntimeError(msg)
        return response

    def testing_response(self, method: str, url: str) -> Any:
        """Return the canned dry-run response for the next call.

        Increments ``num_calls`` and looks up ``"[num_calls] method - url"`` in
        ``test_data``; returns ``{}`` (and logs the key) when there is no entry.
        """
        self.num_calls += 1
        key = f"[{self.num_calls}] {method} - {url}"
        if self.test_data is not None and key in self.test_data:
            return self.test_data[key]
        logging.info(f"Unknown URL in dry run: {key}")
        return {}

    @staticmethod
    def _decode_json(content: bytes) -> Any:
        """Parse response ``content`` as JSON, or return it unchanged if it is not JSON.

        Empty bodies (e.g. from a 204 No Content reply) are not valid JSON, so
        they come back as ``b""`` instead of raising.
        """
        try:
            return json.loads(content)
        except ValueError:
            # JSONDecodeError, or UnicodeDecodeError for bodies that are not UTF-8/16/32
            return content

    def _request(
        self, full_url: str, body: Dict[str, Any], method: str
    ) -> Dict[str, Any]:
        """Send ``body`` as JSON to ``full_url`` using HTTP ``method``.

        Returns the decoded JSON response, or the raw response bytes when the
        body is not JSON. In dry-run mode nothing is sent and the canned
        ``testing_response`` is returned instead.

        Raises:
            RuntimeError: on an HTTP error status; the message includes the
                error response body.
        """
        if self.dry_run():
            logging.info(
                f"Dry run, would have requested a {method} to {full_url} with {body}"
            )
            return self.testing_response(method, full_url)

        logging.info(f"Requesting {method} to {full_url} with {body}")
        req = request.Request(full_url, headers=self.headers(), method=method.upper())
        req.add_header("Content-Type", "application/json; charset=utf-8")
        data = json.dumps(body).encode("utf-8")
        req.add_header("Content-Length", str(len(data)))

        try:
            with request.urlopen(req, data) as response:
                content = response.read()
        except error.HTTPError as e:
            msg = str(e)
            error_data = e.read().decode()
            raise RuntimeError(f"Error response: {msg}\n{error_data}") from e

        logging.info(f"Got response from {full_url}: {content}")
        return self._decode_json(content)

    def put(self, url: str, data: Dict[str, Any]) -> Dict[str, Any]:
        """PUT ``data`` as JSON to ``self.base + url`` (see ``_request``)."""
        return self._request(self.base + url, data, method="PUT")

    def patch(self, url: str, data: Dict[str, Any]) -> Dict[str, Any]:
        """PATCH ``data`` as JSON to ``self.base + url`` (see ``_request``)."""
        return self._request(self.base + url, data, method="PATCH")

    def post(self, url: str, data: Dict[str, Any]) -> Dict[str, Any]:
        """POST ``data`` as JSON to ``self.base + url`` (see ``_request``)."""
        return self._request(self.base + url, data, method="POST")

    def get(self, url: str) -> Dict[str, Any]:
        """GET ``self.base + url`` and return the decoded JSON (raw bytes if not JSON).

        HTTP errors raised by ``urlopen`` propagate unchanged.
        """
        if self.dry_run():
            # Unlike put/patch/post, the dry-run key uses the relative url
            logging.info(f"Dry run, would have requested a GET to {url}")
            return self.testing_response("GET", url)
        url = self.base + url
        logging.info(f"Requesting GET to {url}")
        req = request.Request(url, headers=self.headers())
        with request.urlopen(req) as response:
            content = response.read()
        return self._decode_json(content)

    def delete(self, url: str) -> Dict[str, Any]:
        """DELETE ``self.base + url`` and return the decoded JSON (raw bytes if not JSON).

        HTTP errors raised by ``urlopen`` propagate unchanged.
        """
        if self.dry_run():
            # Unlike put/patch/post, the dry-run key uses the relative url
            logging.info(f"Dry run, would have requested a DELETE to {url}")
            return self.testing_response("DELETE", url)
        url = self.base + url
        logging.info(f"Requesting DELETE to {url}")
        req = request.Request(url, headers=self.headers(), method="DELETE")
        with request.urlopen(req) as response:
            content = response.read()
        return self._decode_json(content)
|
|
|
|
|
|
def parse_remote(remote: str) -> Tuple[str, str]:
    """
    Get a GitHub (user, repo) pair out of a git remote.

    Accepts URL remotes (``https://github.com/user/repo.git``,
    ``ssh://git@github.com/user/repo``) and scp-style SSH remotes
    (``git@github.com:user/repo.git``); the ``.git`` suffix and a trailing
    ``/`` are optional. When set, the ``DEBUG_USER`` / ``DEBUG_REPO``
    environment variables override the parsed values.

    Raises:
        RuntimeError: if no non-empty user and repo can be extracted.
    """
    if "://" in remote:
        # URL remote: the last two path components are the user and the repo
        parts = remote.rstrip("/").split("/")
        if len(parts) < 2:
            raise RuntimeError(f"Unable to parse remote '{remote}'")
        user, repo = parts[-2], parts[-1]
        # Strip only a trailing ".git"; str.replace would also mangle names
        # such as "user.github.io"
        if repo.endswith(".git"):
            repo = repo[: -len(".git")]
    else:
        # scp-style SSH remote: "<host>:<user>/<repo>[.git][/]"
        m = re.search(r":(.*)/(.+?)(?:\.git)?/?$", remote)
        if m is None:
            raise RuntimeError(f"Unable to parse remote '{remote}'")
        user, repo = m.groups()

    if not user or not repo:
        raise RuntimeError(f"Unable to parse remote '{remote}'")

    user = os.getenv("DEBUG_USER", user)
    repo = os.getenv("DEBUG_REPO", repo)
    return user, repo
|
|
|
|
|
|
def git(command, **kwargs):
    """Run ``git`` with the argument list ``command`` and return its stripped stdout.

    Extra keyword arguments are forwarded to ``subprocess.run``. Only stdout is
    captured (decoded as UTF-8); stderr is not captured unless a caller passes
    it through ``kwargs``.

    Raises:
        RuntimeError: if the process exits with a non-zero status; the message
            includes the captured stdout.
    """
    argv = ["git"] + command
    logging.info(f"Running {argv}")
    completed = subprocess.run(
        argv, stdout=subprocess.PIPE, encoding="utf-8", **kwargs
    )
    if completed.returncode == 0:
        return completed.stdout.strip()
    raise RuntimeError(f"Command failed {argv}:\nstdout:\n{completed.stdout}")
|
|
|
|
|
|
def find_ccs(body: str) -> List[str]:
    """Return the usernames listed in ``cc @user1 @user2 ...`` mentions in ``body``.

    Each username is reported once, in order of first appearance, so the
    result does not depend on string hashing.
    """
    names = [
        name
        for full, _last in re.findall(r"(cc( @[-A-Za-z0-9]+)+)", body, flags=re.MULTILINE)
        # Pull names straight out of the "cc @a @b" run; removing "cc " with
        # str.replace would also cut it from inside names that end in "cc".
        for name in re.findall(r"@([-A-Za-z0-9]+)", full)
    ]
    # dict.fromkeys de-duplicates while keeping first-seen order
    return list(dict.fromkeys(names))
|