#!/usr/bin/env python3

from json import dumps
from requests import codes, Session


class APIClientException(Exception):
    """Error raised by the API client, optionally carrying the HTTP response.

    Args:
        message: human-readable description of the failure.
        rsp: the ``requests`` response that triggered the error, if any.
    """

    def __init__(self, message, rsp=None):
        super(APIClientException, self).__init__(message)
        # Kept for callers that read ``e.message`` (not provided by Py3 exceptions).
        self.message = message
        self.response = rsp

    @property
    def text(self):
        """Body of the failing response, or ``''`` when there is no response."""
        # Compare against None explicitly: ``requests.Response.__bool__`` returns
        # ``response.ok``, so every 4xx/5xx response object is falsy.
        return self.response.text if self.response is not None else ''

    @property
    def status_code(self):
        """HTTP status of the failing response, or ``''`` when there is no response."""
        return self.response.status_code if self.response is not None else ''


class BaseClient(object):
    """Minimal JSON-over-HTTP client with header-based token authentication.

    Subclasses configure authentication by overriding the class attributes
    below and implementing :meth:`get_token`.
    """

    # Name of the request header that carries the auth token.
    auth_header = ""
    # Path (relative to ``base_url``) of the login endpoint.
    auth_endpoint = ""
    # JSON keys under which the username / password are posted to ``auth_endpoint``.
    username_key = ""
    password_key = ""

    def __init__(self, base_url='https://app.threatseq.org/api'):
        """Create a client bound to ``base_url`` using a JSON ``requests`` session."""
        self.base_url = base_url
        self.s = Session()
        self.s.headers.update({'content-type': 'application/json'})

    def close(self):
        """Release the connections held by the underlying session."""
        self.s.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    @staticmethod
    def get_token(response):
        """Return the auth header value built from the decoded login ``response``.

        Must be implemented by subclasses.
        """
        raise NotImplementedError()

    def handleErrors(self, response):
        """Raise :class:`APIClientException` unless ``response`` has a 2xx status.

        Any 2xx status (200 OK, 201 Created, 204 No Content, ...) counts as
        success, so create-style endpoints are not reported as failures.
        """
        if 200 <= response.status_code < 300:
            return
        if response.status_code == codes.forbidden:
            msg = "{} Authentication Error - Forbidden"
        elif response.status_code == codes.unauthorized:
            msg = "{} Authentication Error - Unauthorized"
        else:
            # Not an auth failure (e.g. 404, 500): do not label it as one.
            msg = "{} API Error - Unhandled"
        raise APIClientException(msg.format(response.status_code), response)

    def login(self, username, password):
        """POST the credentials to ``auth_endpoint`` and return the decoded JSON body.

        Raises:
            APIClientException: if the server does not answer with a 2xx status.
        """
        response = self.s.post(
            self.base_url + self.auth_endpoint,
            data=dumps({self.username_key: username, self.password_key: password}),
        )
        self.handleErrors(response)
        return response.json()

    def authenticate(self, username, password):
        """Log in and install the resulting token as ``auth_header`` on the session.

        Any previously installed token is removed first, so a failed login leaves
        the client without a token rather than holding a stale one.
        Returns the decoded login response.
        """
        self.s.headers.pop(self.auth_header, None)
        response = self.login(username, password)
        self.s.headers[self.auth_header] = self.get_token(response)
        return response

    def endpoint(self, endpoint_url, method='GET', **kwargs):
        """Send an authenticated request to ``base_url + endpoint_url``.

        Args:
            endpoint_url: path relative to ``base_url``.
            method: HTTP method name, e.g. ``'GET'``, ``'POST'``, ``'PUT'``, ``'DELETE'``.
            **kwargs: forwarded to ``requests.Session.request`` (``data``, ``params``, ...).

        Returns:
            The decoded JSON body, or ``None`` when the response has no body.

        Raises:
            APIClientException: if :meth:`authenticate` has not been called, or
                the server answers with a non-2xx status.
        """
        if self.auth_header not in self.s.headers:
            raise APIClientException("Must authenticate first!")
        # Dispatch on the real method instead of mapping everything non-GET to POST.
        response = self.s.request(method, self.base_url + endpoint_url, **kwargs)
        self.handleErrors(response)
        # e.g. 204 No Content carries no JSON to decode.
        return response.json() if response.content else None


class APIClient(BaseClient):
    """Client that logs in via ``/auth`` and sends a ``JWT`` ``Authorization`` header."""

    # Credentials are posted to base_url + "/auth" under the keys "name" and "key".
    auth_endpoint = "/auth"
    username_key = "name"
    password_key = "key"
    # Header that carries the token on every authenticated request.
    auth_header = "Authorization"

    @staticmethod
    def get_token(response):
        """Build the ``Authorization`` header value from the decoded login body."""
        token = response['access_token']
        return "JWT {token}".format(token=token)


class ThreatSEQClient(APIClient):
    """Client for the ThreatSEQ external orders/sequences API (``/ext/...``)."""

    def create_order(self, name, fasta_data, is_nucleotide=True):
        """Submit a new order.

        Args:
            name: name for the order.
            fasta_data: sequence content, sent as the ``content`` field.
            is_nucleotide: ``True`` sends type ``'nucleotide'``, ``False`` sends
                ``'amino acid'``.

        Returns:
            The decoded JSON response from ``POST /ext/orders``.
        """
        payload = {
            'name': name,
            'content': fasta_data,
            'type': 'nucleotide' if is_nucleotide else 'amino acid',
        }
        return self.endpoint("/ext/orders", 'POST', data=dumps(payload))

    def get_orders(self, name=""):
        """List orders, optionally passing ``name`` as the ``q`` query parameter."""
        if name:
            # ``params`` lets requests URL-encode the value; interpolating it into
            # the URL would break on characters such as '&', '#' or '='.
            return self.endpoint("/ext/orders", params={'q': name})
        return self.endpoint("/ext/orders")

    def get_order(self, order_id):
        """Fetch a single order by id."""
        return self.endpoint("/ext/orders/{}".format(order_id))

    def get_annotations(self, order_id, sequence_id):
        """Fetch the annotations of one sequence within an order."""
        return self.endpoint(
            "/ext/orders/{}/sequences/{}/annotations".format(order_id, sequence_id))

    def get_sequences(self, order_id, filters=None):
        """List the sequences of an order.

        Args:
            order_id: id of the order.
            filters: optional iterable of filter values; each is sent as its own
                ``filters=<value>`` query parameter.
        """
        url = "/ext/orders/{}/sequences".format(order_id)
        if filters:
            # A list value makes requests emit one URL-encoded ``filters=`` pair
            # per item, joined with '&'.
            return self.endpoint(url, params={'filters': list(filters)})
        return self.endpoint(url)

    def get_sequence(self, order_id, sequence_id):
        """Fetch a single sequence of an order."""
        return self.endpoint(
            "/ext/orders/{}/sequences/{}".format(order_id, sequence_id))


# ----------------------------------------------------------------------------

# # Example Usage:

# # >>> from threatseq import APIClientException, ThreatSEQClient
# # >>> c = ThreatSEQClient()
# # >>> c.authenticate('API-USER-NAME', 'API-KEY')
# # >>> c.get_orders()
# # >>> try:
# # ...     c.get_order(999)
# # ... except APIClientException as e:
# # ...     print(str(e), e.status_code, e.text, e.response)

if __name__ == '__main__':
    # Running this module as a script performs no action; import it instead.
    ...
