diff --git a/MANIFEST.in b/MANIFEST.in index c4329a1..64cc9d0 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,2 +1,2 @@ -include *.txt *.ini *.cfg *.rst +include *.txt *.ini *.cfg *.rst .htpasswd .env recursive-include bifrost *.txt diff --git a/bifrost/auth.py b/bifrost/auth.py index c19d226..29fff3b 100644 --- a/bifrost/auth.py +++ b/bifrost/auth.py @@ -1,22 +1,44 @@ import os from crypt import crypt +from functools import lru_cache + +import pkg_resources +from fastapi import Depends, HTTPException, status +from fastapi.security import HTTPBasic, HTTPBasicCredentials + +from . import config + +security = HTTPBasic() +@lru_cache() +def get_settings(): + return config.Settings -# def setup_auth(app): -# htpasswd: str = app.config.HTPASSWD -# print(htpasswd) -# -# @auth.verify_password -# def htpasswd(username, password): -# if not os.path.isfile(htpasswd): -# return None -# users = {} -# with open(htpasswd) as f: -# for line in f: -# login, pwd = line.split(':') -# users[login] = pwd.rstrip('\n') -# if username in users: -# return crypt(password, users[username]) == users[username] -# else: -# return False + +def validate(username: str, password: str, settings: config.Settings): + file = pkg_resources.resource_filename("bifrost", "../" + settings.htpasswd) + if not os.path.isfile(file): + return None + users = {} + with open(file, "r") as f: + for line in f: + login, pwd = line.split(":") + users[login] = pwd.rstrip("\n") + if username in users: + return crypt(password, users[username]) == users[username] + else: + return False + + +def get_current_username( + credentials: HTTPBasicCredentials = Depends(security), + settings: config.Settings = Depends(get_settings), +): + if not validate(credentials.username, credentials.password, settings): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect email or password", + headers={"WWW-Authenticate": "Basic"}, + ) + return credentials.username diff --git a/bifrost/digital_ocean.py b/bifrost/digital_ocean.py index 195beee..a68dbbb 100644 --- a/bifrost/digital_ocean.py +++ b/bifrost/digital_ocean.py @@ -3,11 +3,23 @@ import requests def create_domain_a_record(domain, name, ip_address, api_url_base, api_token, logger): - headers = {'Content-Type': 'application/json', 'Authorization': 'Bearer {0}'.format(api_token)} + headers = { + "Content-Type": "application/json", + "Authorization": "Bearer {0}".format(api_token), + } api_url = f"{api_url_base}/domains/{domain}/records" - data = {"type": "A", "name": name, "data": ip_address, "priority": None, "port": None, "ttl": 1800, "weight": None, - "flags": None, "tag": None} + data = { + "type": "A", + "name": name, + "data": ip_address, + "priority": None, + "port": None, + "ttl": 1800, + "weight": None, + "flags": None, + "tag": None, + } response = requests.post(api_url, headers=headers, json=data) if response.status_code > 499: @@ -28,15 +40,22 @@ def create_domain_a_record(domain, name, ip_address, api_url_base, api_token, lo logger.error(f"[!] [{response.status_code}] Unexpected redirect") return None elif response.status_code == 201: - logger.info(f'{domain} updated to {ip_address}') + logger.info(f"{domain} updated to {ip_address}") return json.loads(response.content) else: - logger.error(f"[?] Unexpected Error: [HTTP {response.status_code}]: Content: {response.content}") + logger.error( + f"[?] Unexpected Error: [HTTP {response.status_code}]: Content: {response.content}" + ) return None -def update_domain_a_record(domain, record_id, new_ip_address, api_url_base, api_token, logger): - headers = {'Content-Type': 'application/json', 'Authorization': 'Bearer {0}'.format(api_token)} +def update_domain_a_record( + domain, record_id, new_ip_address, api_url_base, api_token, logger +): + headers = { + "Content-Type": "application/json", + "Authorization": "Bearer {0}".format(api_token), + } api_url = f"{api_url_base}/domains/{domain}/records/{record_id}" data = {"data": new_ip_address} @@ -53,28 +72,37 @@ def update_domain_a_record(domain, record_id, new_ip_address, api_url_base, api_ return None elif response.status_code > 399: logger.error(f"[!] [{response.status_code}] Bad Request") - logger.error(f"Domain: {domain}, Record ID: {record_id}, New IP Address: {new_ip_address}") + logger.error( + f"Domain: {domain}, Record ID: {record_id}, New IP Address: {new_ip_address}" + ) logger.error(response.content) return None elif response.status_code > 299: logger.error(f"[!] [{response.status_code}] Unexpected redirect") return None elif response.status_code == 201: - logger.info(f'{domain} updated to {new_ip_address}') + logger.info(f"{domain} updated to {new_ip_address}") return json.loads(response.content) else: - logger.error(f"[?] Unexpected Error: [HTTP {response.status_code}]: Content: {response.content}") + logger.error( + f"[?] Unexpected Error: [HTTP {response.status_code}]: Content: {response.content}" + ) return None def list_all_domain_records(domain, api_url_base, api_token, logger): - headers = {'Content-Type': 'application/json', 'Authorization': 'Bearer {0}'.format(api_token)} + headers = { + "Content-Type": "application/json", + "Authorization": "Bearer {0}".format(api_token), + } api_url = f"{api_url_base}/domains/{domain}/records" response = requests.get(api_url, headers=headers) if response.status_code == 200: - return json.loads(response.content.decode('utf-8'))["domain_records"] + return json.loads(response.content.decode("utf-8"))["domain_records"] else: - logger.error(f"[?] Unexpected Error: [HTTP {response.status_code}]: Content: {response.content}") + logger.error( + f"[?] Unexpected Error: [HTTP {response.status_code}]: Content: {response.content}" + ) return None @@ -89,10 +117,20 @@ def get_domain_a_record(domain, name, api_url_base, api_token, logger): def update_domain(domain, name, new_ip_address, api_url_base, api_token, logger): record = get_domain_a_record(domain, name, api_url_base, api_token, logger) if record is None: - logger.info(f'Creating domain a record for {domain} with ip address {new_ip_address}') - create_domain_a_record(domain, name, new_ip_address, api_url_base, api_token, logger) + logger.info( + f"Creating domain a record for {domain} with ip address {new_ip_address}" + ) + create_domain_a_record( + domain, name, new_ip_address, api_url_base, api_token, logger + ) elif record["data"] != new_ip_address: - logger.info(f'Updating domain a record for {domain}. Old ip address {record["data"]}, new ip address {new_ip_address}') - update_domain_a_record(domain, record["id"], new_ip_address, api_url_base, api_token, logger) + logger.info( + f'Updating domain a record for {domain}. Old ip address {record["data"]}, new ip address {new_ip_address}' + ) + update_domain_a_record( + domain, record["id"], new_ip_address, api_url_base, api_token, logger + ) else: - logger.info(f'Not updating domain a record for {domain}. IP address {record["data"]} is current') + logger.info( + f'Not updating domain a record for {domain}. IP address {record["data"]} is current' + ) diff --git a/bifrost/main.py b/bifrost/main.py index 14d03d7..eae8cac 100644 --- a/bifrost/main.py +++ b/bifrost/main.py @@ -11,17 +11,9 @@ app = FastAPI() async def root(): return {"message": "Hello World"} + app.include_router(routers.router) def init(): uvicorn.run(app, host=settings.host, port=settings.port) - # app.add_route(update_view, '/update', methods=['GET']) - # # setup_auth(app) - - # app.run( - # host=app.config.HOST, - # port=app.config.PORT, - # debug=app.config.DEBUG, - # access_log=app.config.ACCESS_LOG - # ) diff --git a/bifrost/routers.py b/bifrost/routers.py index 41b8ddf..fd06279 100644 --- a/bifrost/routers.py +++ b/bifrost/routers.py @@ -2,6 +2,7 @@ from functools import lru_cache import logging as logger from fastapi import APIRouter, Header, Request, Depends +from .auth import get_current_username from .digital_ocean import update_domain from . import config @@ -12,22 +13,31 @@ router = APIRouter() def get_settings(): return config.Settings + @router.get("/update") -def update_view(domain: str, request: Request, ip: str = None, x_forwarded_for: str = Header(None), settings: config.Settings = Depends(get_settings)): - logger.info('Here is your log') +def update_view( + domain: str, + request: Request, + ip: str = None, + x_forwarded_for: str = Header(None), + username: str = Depends(get_current_username), + settings: config.Settings = Depends(get_settings), +): + logger.info("Here is your log") if ip is not None: current_ip = ip elif x_forwarded_for is not None: current_ip = x_forwarded_for else: current_ip = request.ip - name, domain = domain.split('.', maxsplit=1) - update_domain(domain, name, current_ip, settings.api_url_base, settings.api_token, logger) - return {"domain": domain, "name": name, "current_ip": current_ip, "api_base": settings.api_url_base, "api_token": settings.api_token} - - -# @forbidden_view_config() -# def basic_challenge(request): -# response = HTTPUnauthorized() -# response.headers.update(forget(request)) -# return response + name, domain = domain.split(".", maxsplit=1) + update_domain( + domain, name, current_ip, settings.api_url_base, settings.api_token, logger + ) + return { + "domain": domain, + "name": name, + "current_ip": current_ip, + "api_base": settings.api_url_base, + "api_token": settings.api_token, + }