# Copyright (c) 2019-2021 Almar Klein - This code is subject to the MIT license
"""
The itemdb library allows you to store and retrieve Python dicts in a
database on the local filesystem, in an easy, fast, and reliable way.

Based on the rock-solid and ACID compliant SQLite, but with easy and
explicit transactions using a ``with`` statement. It provides a simple
object-based API, with the flexibility to store (JSON-compatible) items
with arbitrary fields, and add indices when needed.
"""

import os
import json
import queue
import asyncio
import sqlite3
import threading


# Package version string; ``version_info`` is the same version as a tuple of
# ints, which is convenient for comparisons.
__version__ = "1.2.0"
version_info = tuple(map(int, __version__.split(".")))

# Public names exported by a star-import of this module.
__all__ = ["ItemDB", "AsyncItemDB", "asyncify"]

# Module-level JSON codecs used to serialize items on write and to parse them
# on read. ``ensure_ascii=True`` makes the encoder escape non-ASCII characters,
# so the stored text is plain ASCII.
json_encode = json.JSONEncoder(ensure_ascii=True).encode
json_decode = json.JSONDecoder().decode


# Notes:
#
# * Setting isolation_level to None turns on autocommit mode. We need to do
#   this to prevent Python from issuing BEGIN before DML statements.
# * Using a connection object as a context manager auto-commits/rollbacks a
#   transaction.
# * We should close cursor objects as soon as possible, because they can hold
#   back waiting writers. That's why we don't have an iterator.
# * MongoDB's approach of db.tablename.push() looks nice, but I don't like
#   the "magical" side of it, especially since the db does not know its tables.
#   Also it makes the code more complex, introduces an extra class, and
#   increases the risk of preventing a db from closing (by holding a table).


def asyncify(func):
    """Wrap a normal function into an awaitable co-routine. Can be used
    as a decorator.

    The original function will be executed in a separate thread. This
    allows async code to execute io-bound code (like querying a sqlite
    database) without stalling.

    Note that the code in func must be thread-safe. It's probably best to
    isolate the io-bound parts of your code and only wrap these.

    Parameters
    ----------
    func : callable
        The (blocking) function to wrap. Objects without a ``__name__``
        (e.g. ``functools.partial``) are accepted as well.

    Returns
    -------
    A coroutine function with the same call signature as ``func``. Awaiting
    it yields the return value of ``func``, or raises whatever ``func`` raised.
    If the awaiting task is cancelled, the thread still runs to completion,
    but its outcome is discarded.
    """

    # Callables such as functools.partial have no __name__; fall back to the
    # type name so that wrapping them does not fail.
    func_name = getattr(func, "__name__", type(func).__name__)

    def threaded_func(loop, future, args, kwargs):
        try:
            result = func(*args, **kwargs)
        except BaseException as e:
            setter, outcome = future.set_exception, e
        else:
            setter, outcome = future.set_result, result

        def resolve():
            # The awaiting task may have been cancelled in the meantime, which
            # cancels the future; setting a result then would raise
            # InvalidStateError inside the event loop.
            if not future.done():
                setter(outcome)

        try:
            loop.call_soon_threadsafe(resolve)
        except RuntimeError:
            # The loop was closed while the function was running: nobody can
            # be awaiting the future anymore, so the outcome is dropped.
            pass

    async def asyncified_func(*args, **kwargs):
        loop = asyncio.get_running_loop()
        future = loop.create_future()
        threading.Thread(
            name="asyncify " + func_name,
            target=threaded_func,
            args=(loop, future, args, kwargs),
        ).start()
        return await future

    asyncified_func.__name__ = "asyncified_" + func_name
    asyncified_func.__doc__ = getattr(func, "__doc__", None)
    return asyncified_func


class ItemDB:
    """A transactional database for storage and retrieval of dict items.

    Parameters
    ----------
    filename : str
        The file to open. Use ":memory:" for an in-memory db.


    The items in the database can be any JSON serializable dictionary.
    Indices can be defined for specific fields to enable fast selection
    of items based on these fields. Indices can be marked as unique to
    make a field mandatory and *identify* items based on that field.

    Transactions are done by using the ``with`` statement, and are mandatory
    for all operations that write to the database.
    """

    def __init__(self, filename):
        self._mtime = -1
        if os.path.isfile(filename):
            self._mtime = os.path.getmtime(filename)
        # isolation_level=None selects autocommit mode, so that transactions
        # are only started explicitly (in __enter__).
        self._conn = sqlite3.connect(
            filename, timeout=60, isolation_level=None, check_same_thread=False
        )
        self._cur = None  # the cursor of the active transaction, if any
        self._indices_per_table = {}  # cache: table name -> set of index names

    @property
    def mtime(self):
        """The time that the database file was last modified, as a Unix timestamp.

        Is -1 if the file did not exist, or if the filename is not represented
        on the filesystem.
        """
        return self._mtime

    def __enter__(self):
        """Start a transaction. Raises IOError if one is already active."""
        if self._cur is not None:
            raise IOError("Already in a transaction")
        cur = self._conn.cursor()
        try:
            cur.execute("BEGIN IMMEDIATE")
        except BaseException:
            # Only register the cursor once the transaction really started.
            # Otherwise a failed BEGIN (e.g. a lock timeout) would leave this
            # object stuck in the "Already in a transaction" state.
            cur.close()
            raise
        self._cur = cur
        return self

    def __exit__(self, type, value, traceback):
        """End the transaction: commit on success, roll back on an exception."""
        self._cur.close()
        self._cur = None
        # Decide on the presence of an exception, not on its truthiness: an
        # exception object that happens to be falsy must still roll back.
        if type is not None or value is not None:
            self._conn.rollback()
            self._indices_per_table.clear()  # we cannot trust this cache anymore
        else:
            self._conn.commit()

    def __del__(self):
        # The connection is missing if __init__ failed before creating it.
        conn = getattr(self, "_conn", None)
        if conn is not None:
            conn.close()

    def close(self):
        """Close the database connection.

        This will be automatically called when the instance is deleted.
        But since it can be held e.g. in a traceback, consider using
        ``with closing(db):``.
        """
        self._conn.close()

    def get_table_names(self):
        """Return a (sorted) list of table names present in the database."""
        cur = self._conn.cursor()
        try:
            cur.execute("SELECT name FROM sqlite_master WHERE type='table';")
            table_names = {x[0] for x in cur}
        finally:
            cur.close()
        return sorted(table_names)

    def get_indices(self, table_name):
        """Get a set of index names for the given table.

        Parameters
        ----------
        table_name : str
            The name of the table to get the indices for.
            *To avoid SQL injection, this arg should not be based on unsafe data.*


        Names prefixed with "!" represent fields that are required and
        unique. Raises KeyError if the table does not exist, TypeError if
        the table name is not a str, and ValueError if it is not an
        identifier.

        The returned set is the object held in the internal cache, so it
        should not be modified by the caller.
        """
        # Use cached?
        try:
            return self._indices_per_table[table_name]
        except KeyError:
            pass
        except TypeError:
            # Unhashable table name (e.g. a list)
            raise TypeError(f"Table name must be str, not {table_name}.") from None

        # Check table name
        if not isinstance(table_name, str):
            raise TypeError(f"Table name must be str, not {table_name}")
        elif not table_name.isidentifier():
            raise ValueError(f"Table name must be an identifier, not '{table_name}'")

        # Get columns for the table (cid, name, type, notnull, default, pk)
        cur = self._conn.cursor()
        try:
            cur.execute(f"PRAGMA table_info('{table_name}');")
            found_indices = {(x[3] * "!" + x[1]) for x in cur}  # includes !_ob
        finally:
            cur.close()

        # Cache and return - or fail. An existing table always has the _ob
        # column, so an empty result means that the table does not exist.
        if found_indices:
            found_indices.difference_update({"!_ob", "_ob"})
            self._indices_per_table[table_name] = found_indices
            return found_indices
        else:
            raise KeyError(f"Table {table_name} not present, maybe use ensure_table()?")

    def ensure_table(self, table_name, *indices):
        """Ensure that the given table exists and has the given indices.

        Parameters
        ----------
        table_name : str
            The name of the table to make sure exists.
            *To avoid SQL injection, this arg should not be based on unsafe data.*
        indices : varargs
            A sequence of strings, representing index names. Fields that are
            indexed can be queried with e.g. ``select()``.
            *To avoid SQL injection, this arg should not be based on unsafe data.*


        If an index name is prefixed with "!", it indicates a field that is
        mandatory and unique. Note that new unique indices cannot be added
        when the table already exists.

        This method returns as quickly as possible when the table
        already exists and has the appropriate indices. Returns the
        ItemDB object, so calls to this method can be stacked.

        Although this call may modify the database, one does not need
        to call this in a transaction.
        """

        if not all(isinstance(x, str) for x in indices):
            raise TypeError("Indices must be str")

        # Select missing indices
        try:
            missing_indices = set(indices).difference(self.get_indices(table_name))
        except KeyError:
            # The table does not exist yet. Use a marker that cannot be a
            # field name, to trigger the creation of the table below.
            missing_indices = {"--table--"}

        # Do we need to do some work? Allow being used under a context and not
        if missing_indices:
            if self._cur is not None:
                self._ensure_table_helper1(table_name, indices, missing_indices)
            else:
                with self:
                    self._ensure_table_helper1(table_name, indices, missing_indices)

        return self  # allow stacking this function

    def _ensure_table_helper1(self, table_name, indices, missing_indices):
        """Create/complete the table and re-put items affected by new indices."""
        # Make sure the table is complete
        self._ensure_table_helper2(table_name, indices)
        self._indices_per_table.pop(table_name, None)  # let it refresh
        # Update values that already had a value for the just added columns/indices
        items = [
            item
            for item in self.select_all(table_name)
            if any(x.lstrip("!") in item for x in missing_indices)
        ]
        self.put(table_name, *items)

    def _ensure_table_helper2(self, table_name, indices):
        """Slow version to ensure table.

        Must be called within a transaction. Raises ValueError for column
        names that are not identifiers, and IndexError for the reserved name
        '_ob', for a unique index that is added after table creation, or for
        a plain index on a column that was created as unique.
        """

        cur = self._cur

        # Check the column names
        for fieldname in indices:
            key = fieldname.lstrip("!")
            if not key.isidentifier():
                raise ValueError("Column names must be identifiers.")
            elif key == "_ob":
                raise IndexError("Column names cannot be '_ob' (name is reserved).")

        # Ensure the table.
        # If there is one unique key, make it the primary key and omit rowid.
        # This results in smaller and faster databases.
        text = f"CREATE TABLE IF NOT EXISTS {table_name} (_ob TEXT NOT NULL"
        unique_keys = sorted(x.lstrip("!") for x in indices if x.startswith("!"))
        if len(unique_keys) == 1:
            index_key = unique_keys[0]
            text += f", {index_key} NOT NULL PRIMARY KEY) WITHOUT ROWID;"
        else:
            for index_key in unique_keys:
                text += f", {index_key} NOT NULL UNIQUE"
            text += ");"
        cur.execute(text)

        # Ensure the columns and indices
        cur.execute(f"PRAGMA table_info('{table_name}');")
        found_indices = {(x[3] * "!" + x[1]) for x in cur}

        for fieldname in sorted(indices):
            index_key = fieldname.lstrip("!")
            if fieldname not in found_indices:
                if fieldname.startswith("!"):
                    raise IndexError(
                        f"Cannot add unique index {fieldname!r} after the table has been created."
                    )
                elif fieldname in {x.lstrip("!") for x in found_indices}:
                    raise IndexError(f"Given index {fieldname!r} should be unique.")
                cur.execute(f"ALTER TABLE {table_name} ADD {index_key};")
            cmd = "CREATE INDEX IF NOT EXISTS"
            cur.execute(
                f"{cmd} idx_{table_name}_{index_key} ON {table_name} ({index_key})"
            )

    def delete_table(self, table_name):
        """Delete the table with the given name.

        Parameters
        ----------
        table_name : str
            The name of the table to delete.
            *To avoid SQL injection, this arg should not be based on unsafe data.*


        Be aware that this deletes the whole table, including all of
        its items.

        This method must be called within a transaction. Can raise
        KeyError if an invalid table is given, or IOError if not used
        within a transaction.
        """
        self.get_indices(table_name)  # Fail with KeyError for invalid table name
        cur = self._cur
        if cur is None:
            raise IOError("Can only use delete_table() within a transaction.")
        self._indices_per_table.pop(table_name, None)
        cur.execute(f"DROP TABLE {table_name}")

    def rename_table(self, table_name, new_table_name):
        """Rename a table.

        Parameters
        ----------
        table_name : str
            The current name of the table.
            *To avoid SQL injection, this arg should not be based on unsafe data.*
        new_table_name : str
            The new name.
            *To avoid SQL injection, this arg should not be based on unsafe data.*


        This method must be called within a transaction. Can raise
        KeyError if an invalid table is given, TypeError if the new name
        is not a str identifier, or IOError if not used within a
        transaction.
        """
        self.get_indices(table_name)  # Fail with KeyError for invalid table name
        if not (isinstance(new_table_name, str) and new_table_name.isidentifier()):
            # Report the name that is actually at fault: the new one.
            raise TypeError(
                f"Table name must be a str identifier, not '{new_table_name}'"
            )
        cur = self._cur
        if cur is None:
            raise IOError("Can only use rename_table() within a transaction.")
        self._indices_per_table.pop(table_name, None)
        cur.execute(f"ALTER TABLE {table_name} RENAME TO {new_table_name}")

    def count_all(self, table_name):
        """Get the total number of items in the given table.

        Can raise KeyError if an invalid table is given.
        """
        self.get_indices(table_name)  # Fail with KeyError for invalid table name
        cur = self._conn.cursor()
        try:
            cur.execute(f"SELECT COUNT(*) FROM {table_name}")
            return cur.fetchone()[0]
        finally:
            cur.close()

    def count(self, table_name, query, *save_args):
        """Get the number of items in the given table that match the given query.

        Parameters
        ----------
        table_name : str
            The name of the table to count items in.
            *To avoid SQL injection, this arg should not be based on unsafe data.*
        query : str
            The query to select items on.
            *To avoid SQL injection, this arg should not be based on unsafe data;
            use save_args for end-user input.*
        save_args : varargs
            The values to select items on.


        Examples::

            # Count the persons older than 20
            db.count("persons", "age > ?", 20)
            # Count the persons older than a given value
            db.count("persons", "age > ?", min_age)
            # Use AND and OR for more precise queries
            db.count("persons", "age > ? AND age < ?", min_age, max_age)

        See ``select()`` for details on queries.

        Can raise KeyError if an invalid table is given, IndexError if an
        invalid field is used in the query, or sqlite3.OperationalError for
        an invalid query.
        """
        self.get_indices(table_name)  # Fail with KeyError for invalid table name
        cur = self._conn.cursor()
        try:
            cur.execute(f"SELECT COUNT(*) FROM {table_name} WHERE {query}", save_args)
            return cur.fetchone()[0]
        except sqlite3.OperationalError as err:
            if "no such column" in str(err).lower():
                raise IndexError(str(err)) from err
            raise
        finally:
            cur.close()

    def select_all(self, table_name):
        """Get all items in the given table. See ``select()`` for details."""
        self.get_indices(table_name)  # Fail with KeyError for invalid table name
        cur = self._conn.cursor()
        try:
            cur.execute(f"SELECT _ob FROM {table_name}")
            return [json_decode(x[0]) for x in cur]
        finally:
            cur.close()

    def select(self, table_name, query, *save_args):
        """Get the items in the given table that match the given query.

        Parameters
        ----------
        table_name : str
            The name of the table to select items in.
            *To avoid SQL injection, this arg should not be based on unsafe data.*
        query : str
            The query to select items on.
            *To avoid SQL injection, this arg should not be based on unsafe data;
            use save_args for end-user input.*
        save_args : varargs
            The values to select items on.


        The query follows SQLite syntax and can only include indexed
        fields. If needed, use ensure_table() to add indices. The query
        is always fast (which is why this method is called 'select', and
        not 'search').

        Examples::

            # Select the persons older than 20
            db.select("persons", "age > ?",  20)
            # Select the persons older than a given age
            db.select("persons", "age > ?", min_age)
            # Use AND and OR for more precise queries
            db.select("persons", "age > ? AND age < ?", min_age, max_age)

        There is no method to filter items based on non-indexed fields,
        because this is easy using a list comprehension, e.g.::

            items = db.select_all("persons")
            items = [i for i in items if i["age"] > 20]

        Can raise KeyError if an invalid table is given, IndexError if an
        invalid field is used in the query, or sqlite3.OperationalError for
        an invalid query.
        """
        self.get_indices(table_name)  # Fail with KeyError for invalid table name
        # It is tempting to make this a generator, but also dangerous because
        # the cursor might not be closed if the generator is stored somewhere
        # and not run through the end.
        cur = self._conn.cursor()
        try:
            cur.execute(f"SELECT _ob FROM {table_name} WHERE {query}", save_args)
            return [json_decode(x[0]) for x in cur]
        except sqlite3.OperationalError as err:
            if "no such column" in str(err).lower():
                raise IndexError(str(err)) from err
            raise
        finally:
            cur.close()

    def select_one(self, table_name, query, *args):
        """Get the first item in the given table that match the given query.

        Parameters
        ----------
        table_name : str
            The name of the table to select an item in.
            *To avoid SQL injection, this arg should not be based on unsafe data.*
        query : str
            The query to select the item on.
            *To avoid SQL injection, this arg should not be based on unsafe data;
            use args for end-user input.*
        args : varargs
            The values to select the item on.


        Returns None if there was no match. See ``select()`` for details.
        """
        items = self.select(table_name, query, *args)
        return items[0] if items else None

    def put(self, table_name, *items):
        """Put one or more items into the given table.

        Parameters
        ----------
        table_name : str
            The name of the table to put the item(s) in.
            *To avoid SQL injection, this arg should not be based on unsafe data.*
        items : varargs
            The dicts to add. Keys that match an index can later be used for
            fast querying.


        This method must be called within a transaction. Can raise
        KeyError if an invalid table is given, IOError if not used
        within a transaction, TypeError if an item is not a (JSON
        serializable) dict, or IndexError if an item does not have a
        required field.
        """
        cur = self._cur
        if cur is None:
            raise IOError("Can only use put() within a transaction.")

        # Get indices - fail with KeyError for invalid table name
        indices = self.get_indices(table_name)

        for item in items:
            if not isinstance(item, dict):
                raise TypeError("Expecing each item to be a dict")

            # The full item goes into the _ob column as JSON; the indexed
            # fields are additionally copied into their own columns.
            index_keys = "_ob"
            row_plac = "?"
            row_vals = [json_encode(item)]  # Can raise TypeError
            for fieldname in indices:
                index_key = fieldname.lstrip("!")
                if index_key in item:
                    index_keys += ", " + index_key
                    row_plac += ", ?"
                    row_vals.append(item[index_key])
                elif fieldname.startswith("!"):
                    raise IndexError(f"Item does not have required field {index_key!r}")

            cur.execute(
                f"INSERT OR REPLACE INTO {table_name} ({index_keys}) VALUES ({row_plac})",
                row_vals,
            )

    def put_one(self, table_name, **item):
        """Put an item into the given table using kwargs.

        Parameters
        ----------
        table_name : str
            The name of the table to put the item(s) in.
            *To avoid SQL injection, this arg should not be based on unsafe data.*
        item : kwargs
            The dict to add. Keys that match an index can later be used for
            fast querying.


        This method must be called within a transaction.
        """
        self.put(table_name, item)

    def delete(self, table_name, query, *save_args):
        """Delete items from the given table.

        Parameters
        ----------
        table_name : str
            The name of the table to delete items from.
            *To avoid SQL injection, this arg should not be based on unsafe data.*
        query : str
            The query to select the items to delete.
            *To avoid SQL injection, this arg should not be based on unsafe data;
            use save_args for end-user input.*
        save_args : varargs
            The values to select the item on.


        Examples::

            # Delete the persons older than 20
            db.delete("persons", "age > ?", 20)
            #  Delete the persons older than a given age
            db.delete("persons", "age > ?", min_age)
            # Use AND and OR for more precise queries
            db.delete("persons", "age > ? AND age < ?", min_age, max_age)

        See ``select()`` for details on queries.

        This method must be called within a transaction. Can raise
        KeyError if an invalid table is given, IOError if not used
        within a transaction, IndexError if an invalid field is used
        in the query, or sqlite3.OperationalError for an invalid query.
        """
        self.get_indices(table_name)  # Fail with KeyError for invalid table name
        cur = self._cur
        if cur is None:
            raise IOError("Can only use delete() within a transaction.")
        # Note: the cursor belongs to the surrounding transaction and is closed
        # in __exit__. It must stay open here, otherwise any further write in
        # the same transaction would fail.
        try:
            cur.execute(f"DELETE FROM {table_name} WHERE {query}", save_args)
        except sqlite3.OperationalError as err:
            if "no such column" in str(err).lower():
                raise IndexError(str(err)) from err
            raise


class AsyncItemDB:
    """An async version of ItemDB. The API is exactly the same, except
    that all methods are async, and one must use `async with` instead
    of the normal `with`.

    Instances are created with ``db = await AsyncItemDB(filename)``. Each
    instance owns a worker thread in which all database calls are executed,
    in the order in which they were requested.
    """

    async def __new__(cls, filename):
        self = super().__new__(cls)
        self._closed = True  # nothing to shut down until the worker runs
        self.db = None
        self._loop = asyncio.get_running_loop()
        self._queue = queue.Queue()
        self._thread = Thread4AsyncItemDB(self._queue)
        self._thread.start()
        self._closed = False
        try:
            # Create the ItemDB object in the worker thread too
            self.db = self._thread.db = await self._handle(ItemDB, filename)
        except BaseException:
            # Opening failed: stop the worker, or it would wait forever.
            self._shutdown()
            raise
        return self

    @property
    def mtime(self):
        """The modification time of the database file, see ``ItemDB.mtime``."""
        return self.db.mtime

    async def _handle(self, function, *args, **kwargs):
        """Run the function in the worker thread and return its result."""
        if self._closed:
            # The worker has been told to stop, so the call would never be
            # processed. Fail like a closed sqlite3 connection does.
            raise sqlite3.ProgrammingError("Cannot operate on a closed database.")
        future = self._loop.create_future()
        self._queue.put_nowait((future, function, args, kwargs))
        return await future

    def _shutdown(self):
        """Ask the worker to close the db and stop. Does nothing the second time.

        Returns the future for the close call, or None if no close call
        was queued.
        """
        # The flag is missing when construction failed very early.
        if getattr(self, "_closed", True):
            return None
        self._closed = True
        future = None
        if self.db is not None:
            future = self._loop.create_future()
            self._queue.put_nowait((future, self.db.close, (), {}))
        # A tuple of Nones tells the worker thread to stop
        self._queue.put_nowait((None, None, None, None))
        return future

    async def __aenter__(self):
        return await self._handle(self.db.__enter__)

    async def __aexit__(self, type, value, traceback):
        return await self._handle(self.db.__exit__, type, value, traceback)

    def __del__(self):
        self._shutdown()

    async def close(self):
        """Close the database connection and stop the worker thread.

        Calling this more than once is allowed; later calls do nothing.
        """
        future = self._shutdown()
        if future is not None:
            return await future

    async def get_table_names(self, *args, **kwargs):
        return await self._handle(self.db.get_table_names, *args, **kwargs)

    async def get_indices(self, *args, **kwargs):
        return await self._handle(self.db.get_indices, *args, **kwargs)

    async def ensure_table(self, *args, **kwargs):
        return await self._handle(self.db.ensure_table, *args, **kwargs)

    async def delete_table(self, *args, **kwargs):
        return await self._handle(self.db.delete_table, *args, **kwargs)

    async def rename_table(self, *args, **kwargs):
        return await self._handle(self.db.rename_table, *args, **kwargs)

    async def count_all(self, *args, **kwargs):
        return await self._handle(self.db.count_all, *args, **kwargs)

    async def count(self, *args, **kwargs):
        return await self._handle(self.db.count, *args, **kwargs)

    async def select_all(self, *args, **kwargs):
        return await self._handle(self.db.select_all, *args, **kwargs)

    async def select(self, *args, **kwargs):
        return await self._handle(self.db.select, *args, **kwargs)

    async def select_one(self, *args, **kwargs):
        return await self._handle(self.db.select_one, *args, **kwargs)

    async def put(self, *args, **kwargs):
        return await self._handle(self.db.put, *args, **kwargs)

    async def put_one(self, *args, **kwargs):
        return await self._handle(self.db.put_one, *args, **kwargs)

    async def delete(self, *args, **kwargs):
        return await self._handle(self.db.delete, *args, **kwargs)


class Thread4AsyncItemDB(threading.Thread):
    """Thread that does the work for the AsyncItemDB.

    The thread takes ``(future, function, args, kwargs)`` tuples from the
    given queue, calls ``function(*args, **kwargs)``, and hands the outcome
    to the future via the event loop that the future belongs to. A tuple
    whose future is None stops the thread.
    """

    _count = 0  # number of threads created so far, used to name them

    def __init__(self, queue):
        Thread4AsyncItemDB._count += 1
        super().__init__(name=f"AsyncItemDB_{Thread4AsyncItemDB._count}")
        self.daemon = True
        self._queue = queue
        self.db = None

    def run(self) -> None:
        while True:
            # Continues running until all queue items are processed,
            # even after closed (so we can finalize all futures)
            future, function, args, kwargs = self._queue.get()
            if future is None:
                break

            # Only the call itself is guarded. A failure to deliver the
            # outcome must not be mistaken for a failure of the function.
            try:
                result = function(*args, **kwargs)
            except BaseException as e:
                self._resolve_future(future, "set_exception", e)
            else:
                self._resolve_future(future, "set_result", result)

    @staticmethod
    def _resolve_future(future, setter_name, outcome):
        """Set the outcome on the future, from within the future's event loop."""

        def resolve():
            # The future may have been cancelled in the meantime
            if not future.done():
                getattr(future, setter_name)(outcome)

        try:
            future.get_loop().call_soon_threadsafe(resolve)
        except RuntimeError:
            # The event loop has been closed, so nobody can be waiting for
            # this future. Drop the outcome and keep processing the queue.
            pass