from collections import defaultdict, deque
from dataclasses import dataclass
from typing import Dict, Iterable, List, Tuple

from sqlalchemy import and_, or_
from sqlalchemy.exc import InvalidRequestError
from sqlalchemy.orm import Session

from platypush.context import get_plugin
from platypush.entities import Entity


@dataclass
class _TaxonomyAwareEntity:
    """
    A support class used to map an entity and its level within a taxonomy to be
    flushed to the database.
    """

    entity: Entity
    level: int


class EntitiesDb:
    """
    This object is a facade around the entities database. It shouldn't be used
    directly. Instead, it is encapsulated by
    :class:`platypush.entities._repo.EntitiesRepository`, which is in charge of
    caching as well.
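
    A minimal usage sketch (``my_entities`` is a hypothetical list of
    :class:`platypush.entities.Entity` objects, and it is assumed that the
    returned SQLAlchemy session can be used as a context manager)::

        entities_db = EntitiesDb()
        with entities_db.get_session() as session:
            persisted = entities_db.upsert(session, my_entities)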
    """

    def get_session(self, *args, **kwargs) -> Session:
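        """
        Get a database session through the ``db`` plugin.
        """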
        db = get_plugin('db')
        assert db
        return db.get_session(*args, **kwargs)

    def fetch(
        self, session: Session, entities: Iterable[Entity]
    ) -> Dict[Tuple[str, str], Entity]:
        """
        Given a set of entities, it returns those that already exist on the database.
        """
        if not entities:
            return {}

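        # Build an OR filter that matches each given entity by its
        # (external_id, plugin) pair, which identifies an entity record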
        entities_filter = or_(
            *[
                and_(
                    Entity.external_id == entity.external_id,
                    Entity.plugin == entity.plugin,
                )
                for entity in entities
            ]
        )

        query = session.query(Entity).filter(entities_filter)
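        # Index the matching records by their entity_key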
        existing_entities = {entity.entity_key: entity for entity in query.all()}

        return {
            entity.entity_key: existing_entities[entity.entity_key]
            for entity in entities
            if entity.entity_key in existing_entities
        }

    @staticmethod
    def _close_batch(batch: List[_TaxonomyAwareEntity], batches: List[List[Entity]]):
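        """
        Append the entities in the current batch (if any) to the list of
        batches and clear the batch.
        """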
        if batch:
            batches.append([item.entity for item in batch])
            batch.clear()

    def _split_entity_batches_for_flush(
        self, entities: Iterable[Entity]
    ) -> List[List[Entity]]:
        """
        This method retrieves the root entities given a list of entities and
        generates batches of "flushable" entities ready for upsert using a BFS
        algorithm.

        This is needed because we want hierarchies of entities to be flushed
        starting from the top layer, once their parents have been appropriately
        rewired. Otherwise, we may end up with conflicts on entities that have
        already been flushed.
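
        For example, given a hypothetical taxonomy (names for illustration
        only) where ``light_1`` and ``light_2`` are children of ``room_1``,
        the generated batches would be ``[[room_1], [light_1, light_2]]``:
        the parent is flushed first, so that its children can be rewired to
        its primary key.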
        """
        # Index children by parent_id and by parent_key
        children_by_parent_id = defaultdict(list)
        children_by_parent_key = defaultdict(list)
        for entity in entities:
            parent_key = None
            parent_id = entity.parent_id
            if entity.parent:
                parent_id = parent_id or entity.parent.id
                parent_key = entity.parent.entity_key

            if parent_id:
                children_by_parent_id[parent_id].append(entity)
            if parent_key:
                children_by_parent_key[parent_key].append(entity)

        # Find the root entities in the hierarchy (i.e. those that have a null
        # parent)
        root_entities = list(
            {
                e.entity_key: e
                for e in entities
                if e.parent is None and e.parent_id is None
            }.values()
        )

        # Prepare a FIFO queue of entities to process through BFS, starting
        # with the root nodes (i.e. level=0)
        entities_to_process = deque(
            _TaxonomyAwareEntity(entity=e, level=0) for e in root_entities
        )

        batches = []
        current_batch = []

        while entities_to_process:
            # Pop the first element in the queue (FIFO order)
            item = entities_to_process.popleft()
            entity = item.entity
            level = item.level

            # If the depth has increased compared to the previous node, flush
            # the current batch and open a new one.
            if current_batch and current_batch[-1].level < level:
                self._close_batch(current_batch, batches)
            current_batch.append(item)

            # Index this node's children by entity key
            children_to_process = {
                e.entity_key: e
                for e in children_by_parent_key.get(entity.entity_key, [])
            }

            # If this entity has already been persisted, add back its children
            # that haven't been updated, so we won't lose those connections
            if entity.id:
                children_to_process.update(
                    {e.entity_key: e for e in children_by_parent_id.get(entity.id, [])}
                )

            # Add all the updated+inserted+existing children to the next layer
            # to be expanded
            entities_to_process.extend(
                _TaxonomyAwareEntity(entity=e, level=level + 1)
                for e in children_to_process.values()
            )

        # Close the last pending batch, if any
        self._close_batch(current_batch, batches)
        return batches

    def upsert(
        self,
        session: Session,
        entities: Iterable[Entity],
    ) -> Iterable[Entity]:
        """
        Persist a set of entities.
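
        :return: The upserted entities, expunged from the session so they can
            be accessed after the session is closed.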
        """
        # Split the entities into flushable batches, one per taxonomy level
        batches = self._split_entity_batches_for_flush(entities)

        # Flush each batch as we process it
        for batch in batches:
            session.add_all(batch)
            session.flush()

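        # Commit the whole unit of work once all the batches have been flushed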
        session.commit()

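        # Deduplicate the flushed entities by their entity_key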
        all_entities = list(
            {
                entity.entity_key: entity for batch in batches for entity in batch
            }.values()
        )

        # Remove all the entities from the existing session, so they can be
        # accessed outside of this context
        for e in all_entities:
            try:
                session.expunge(e)
            except InvalidRequestError:
                pass

        return all_entities