# Copyright 2020 James Arcus <jimbo@ucc.asn.au>
# Released under the terms of the GNU GPL

################################################################################

import os
import requests

################################################################################

# Formatting helpers

def PATH(endpoint):
    """Return the absolute Cloudflare API v4 URL for the relative *endpoint*."""
    base_url = "https://api.cloudflare.com/client/v4"
    return "".join([base_url, endpoint])

def HEADERS(token):
    """Build the JSON request headers, authenticating with *token* as a Bearer credential."""
    headers = dict()
    headers["Content-Type"] = "application/json"
    headers["Authorization"] = " ".join(("Bearer", token))
    return headers

def RESULT(response):
    """Unwrap the JSON envelope returned by a Cloudflare API v4 call.

    Parameters:
        response: an object whose ``json()`` method returns the decoded body
            (e.g. a ``requests.Response``).

    Returns:
        The envelope's ``"result"`` value when ``"success"`` is truthy, or an
        empty list for the special-cased error described below.

    Raises:
        RuntimeError: with ``repr()`` of the envelope's ``"errors"`` list for
            any other unsuccessful response, including a malformed envelope
            that lacks an ``"errors"`` list.
    """
    details = response.json()

    if details["success"]:
        return details["result"]

    # Treat a missing/null "errors" entry as an empty list so a malformed
    # envelope surfaces as the documented RuntimeError, not a bare KeyError.
    errors = details.get("errors") or []

    # Special case for a bad query: the API answers with a single error whose
    # code is 0 (it looks like a permission error; unsure why). It is treated
    # as "no results" rather than a failure.
    if len(errors) == 1 and errors[0].get("code") == 0:
        return []

    raise RuntimeError(repr(errors))

################################################################################

# API methods

def delete(token, endpoint, *, timeout=30):
    """Send an authenticated DELETE to *endpoint* and unwrap the result.

    Parameters:
        token: API token, sent as a Bearer credential.
        endpoint: path relative to the API v4 base URL, e.g. "/zones".
        timeout: seconds handed to ``requests``; without a timeout a stalled
            connection can block forever. Pass None to wait indefinitely.

    Returns:
        The value produced by RESULT for the response.
    """
    response = requests.delete(
        PATH(endpoint),
        headers=HEADERS(token),
        timeout=timeout,
    )
    return RESULT(response)

def get(token, endpoint, query=None, *, timeout=30):
    """Send an authenticated GET to *endpoint* and unwrap the result.

    Parameters:
        token: API token, sent as a Bearer credential.
        endpoint: path relative to the API v4 base URL, e.g. "/zones".
        query: optional mapping encoded as URL query parameters.
        timeout: seconds handed to ``requests``; without a timeout a stalled
            connection can block forever. Pass None to wait indefinitely.

    Returns:
        The value produced by RESULT for the response.
    """
    response = requests.get(
        PATH(endpoint),
        params=query,
        headers=HEADERS(token),
        timeout=timeout,
    )
    return RESULT(response)

def patch(token, endpoint, data, *, timeout=30):
    """Send an authenticated PATCH with JSON body *data* and unwrap the result.

    Parameters:
        token: API token, sent as a Bearer credential.
        endpoint: path relative to the API v4 base URL.
        data: JSON-serialisable request body.
        timeout: seconds handed to ``requests``; without a timeout a stalled
            connection can block forever. Pass None to wait indefinitely.

    Returns:
        The value produced by RESULT for the response.
    """
    response = requests.patch(
        PATH(endpoint),
        json=data,
        headers=HEADERS(token),
        timeout=timeout,
    )
    return RESULT(response)

def post(token, endpoint, data, *, timeout=30):
    """Send an authenticated POST with JSON body *data* and unwrap the result.

    Parameters:
        token: API token, sent as a Bearer credential.
        endpoint: path relative to the API v4 base URL.
        data: JSON-serialisable request body.
        timeout: seconds handed to ``requests``; without a timeout a stalled
            connection can block forever. Pass None to wait indefinitely.

    Returns:
        The value produced by RESULT for the response.
    """
    response = requests.post(
        PATH(endpoint),
        json=data,
        headers=HEADERS(token),
        timeout=timeout,
    )
    return RESULT(response)

def put(token, endpoint, data=None, *, timeout=30):
    """Send an authenticated PUT with optional JSON body *data* and unwrap the result.

    Parameters:
        token: API token, sent as a Bearer credential.
        endpoint: path relative to the API v4 base URL.
        data: optional JSON-serialisable request body.
        timeout: seconds handed to ``requests``; without a timeout a stalled
            connection can block forever. Pass None to wait indefinitely.

    Returns:
        The value produced by RESULT for the response.
    """
    response = requests.put(
        PATH(endpoint),
        json=data,
        headers=HEADERS(token),
        timeout=timeout,
    )
    return RESULT(response)

################################################################################

# API endpoints

def list_zones(token, query=None):
    """Return the zones visible to *token*, optionally filtered by *query* parameters."""
    zones_endpoint = "/zones"
    return get(token, zones_endpoint, query=query)

def zone_details(token, zone_id):
    """Fetch the detail record of the zone identified by *zone_id*."""
    zone_endpoint = "/zones/{}".format(zone_id)
    return get(token, zone_endpoint)

def list_dns_records(token, zone_id, query=None):
    """Return the DNS records of zone *zone_id*, optionally filtered by *query* parameters."""
    records_endpoint = "/zones/{}/dns_records".format(zone_id)
    return get(token, records_endpoint, query=query)

def create_dns_record(token, zone_id, data):
    """POST the record description *data* to zone *zone_id*'s DNS records collection."""
    records_endpoint = "/zones/{}/dns_records".format(zone_id)
    return post(token, records_endpoint, data=data)

def dns_record_details(token, zone_id, record_id):
    """Fetch the single DNS record *record_id* from zone *zone_id*."""
    record_endpoint = "/zones/{}/dns_records/{}".format(zone_id, record_id)
    return get(token, record_endpoint)

def update_dns_record(token, zone_id, record_id, data):
    """PUT *data* over DNS record *record_id* in zone *zone_id*."""
    record_endpoint = "/zones/{}/dns_records/{}".format(zone_id, record_id)
    return put(token, record_endpoint, data=data)

def patch_dns_record(token, zone_id, record_id, data):
    """PATCH DNS record *record_id* in zone *zone_id* with the fields in *data*."""
    record_endpoint = "/zones/{}/dns_records/{}".format(zone_id, record_id)
    return patch(token, record_endpoint, data=data)

def delete_dns_record(token, zone_id, record_id):
    """Send a DELETE for DNS record *record_id* in zone *zone_id*."""
    record_endpoint = "/zones/{}/dns_records/{}".format(zone_id, record_id)
    return delete(token, record_endpoint)

################################################################################

# Query helpers

def LOOKUP_ONCE(seq, key):
    """Return ``seq[0][key]`` when *seq* holds exactly one item.

    Returns None for an empty sequence and raises ValueError (listing the
    results) when there is more than one item.
    """
    count = len(seq)
    if count > 1:
        raise ValueError("Multiple results: " + repr(seq))
    if count == 1:
        return seq[0][key]
    return None

def SET_UNLESS_NONE(map, key, value):
    """Store *value* under *key* in the mapping *map*, skipping it when it is None.

    Falsy values other than None (0, False, "") are still stored.
    """
    if value is None:
        return
    map[key] = value

################################################################################

# API actions

def get_zone_id(token, zone_name):
    """Look up the ID of the zone named *zone_name*.

    Returns None when no zone matches and raises ValueError (via LOOKUP_ONCE)
    when several do.
    """
    matches = list_zones(token, {"name": zone_name})
    return LOOKUP_ONCE(matches, key="id")

def list_all_records(token, zone_id, *, per_page=100):
    """Return every DNS record in zone *zone_id*, following pagination.

    The DNS-records listing endpoint returns results one page at a time
    (``page`` / ``per_page`` query parameters). A single unparameterised
    request therefore yields only the first page, so this walks the pages
    until one comes back short.

    Parameters:
        token: API token.
        zone_id: the zone whose records are listed.
        per_page: page size requested from the API.

    Returns:
        A list of all record dicts, in the order the API returned them.
    """
    records = []
    page = 1
    while True:
        batch = list_dns_records(token, zone_id, {"page": page, "per_page": per_page})
        records.extend(batch)
        # A short (or empty) page means there is nothing further to fetch.
        if len(batch) < per_page:
            return records
        page += 1

def create_record(token, zone_id,
        record_type, record_name, record_content,
        record_ttl=1, record_priority=None, record_proxied=None):
    """Create a DNS record in zone *zone_id* and return the API result.

    type, name, content and ttl are always sent. priority and proxied are only
    added to the request body when they are not None.
    NOTE(review): the default ttl of 1 presumably selects Cloudflare's
    "automatic" TTL; confirm against the API docs.
    """
    details = dict(
        type=record_type,
        name=record_name,
        content=record_content,
        ttl=record_ttl,
    )

    optional_fields = (("priority", record_priority), ("proxied", record_proxied))
    details.update(
        (field, value) for field, value in optional_fields if value is not None
    )

    return create_dns_record(token, zone_id, details)

def get_record_id(token, zone_id, record_name, record_type=None):
    """Look up the ID of the DNS record named *record_name* in zone *zone_id*.

    Parameters:
        token: API token.
        zone_id: the zone to search.
        record_name: the fully-qualified record name.
        record_type: optional record type (e.g. "A", "AAAA", "MX"). When one
            name carries several records, this narrows the match to a single
            record instead of tripping the "Multiple results" error.

    Returns:
        The record ID, or None when nothing matches.

    Raises:
        ValueError: via LOOKUP_ONCE when more than one record matches.
    """
    query = {"name": record_name}
    SET_UNLESS_NONE(query, "type", record_type)
    result = list_dns_records(token, zone_id, query)
    return LOOKUP_ONCE(result, "id")

def change_record(token, zone_id, record_id,
    record_type=None, record_name=None, record_content=None,
    record_ttl=None, record_proxied=None, record_priority=None):
    """Partially update DNS record *record_id* in zone *zone_id* via PATCH.

    Only arguments that are not None are placed in the request body, so
    omitted fields are not sent. ``record_priority`` is appended last so
    existing positional callers are unaffected. It allows changing the
    priority field that ``create_record`` can already set.

    Returns:
        The result of the PATCH request.
    """
    changes = {}

    SET_UNLESS_NONE(changes, "type", record_type)
    SET_UNLESS_NONE(changes, "name", record_name)
    SET_UNLESS_NONE(changes, "content", record_content)
    SET_UNLESS_NONE(changes, "ttl", record_ttl)
    SET_UNLESS_NONE(changes, "priority", record_priority)
    SET_UNLESS_NONE(changes, "proxied", record_proxied)

    return patch_dns_record(token, zone_id, record_id, changes)

def delete_record(token, zone_id, record_id):
    """Remove DNS record *record_id* from zone *zone_id*, returning the API result."""
    outcome = delete_dns_record(token, zone_id=zone_id, record_id=record_id)
    return outcome

###############################################################################

# User helpers

def FORMAT_NAME(zone, prefix):
    """Combine a record *prefix* and a *zone* into a fully-qualified name.

    A single trailing dot on *zone* is dropped. The magic prefix '@' denotes
    the zone root and yields the zone name itself; any other prefix is joined
    to the zone with a dot.

    Examples:
        FORMAT_NAME("example.com.", "www") -> "www.example.com"
        FORMAT_NAME("example.com", "@")    -> "example.com"
    """
    # Cloudflare doesn't use trailing dots. endswith() also copes with an
    # empty zone, where indexing zone[-1] would raise IndexError.
    if zone.endswith('.'):
        zone = zone[:-1]

    # Magic zone root prefix
    if prefix == '@':
        return zone
    return f"{prefix}.{zone}"