from datetime import datetime, timedelta

from sqlalchemy import Date
from sqlalchemy.orm import Session
from sqlalchemy.sql import func

from src.api.models.flight import Flight
from src.api.schemas.flight import Flight as FlightPydantic


def is_flight_unique(db: Session, flight: FlightPydantic, id: int = None):
    """Report whether no stored flight shares this flight's identifying fields.

    Two flights count as duplicates when their flight_code, origin,
    destination, departure_time and arrival_time are all equal.

    Args:
        db: Session used to run the lookup.
        flight: Object exposing the five compared attributes. Only those
            attributes are read from it.
        id: Primary key of a row to leave out of the comparison (used when
            checking an update against the row being updated). A falsy
            value (``None`` or ``0``) disables the exclusion.

    Returns:
        ``True`` when no matching row exists, ``False`` otherwise.
    """
    # With no id to exclude, the first criterion degenerates to a constant
    # true so the remaining equality checks alone decide the result.
    not_self = Flight.id != id if id else True

    duplicates = db.query(Flight).filter(
        not_self,
        Flight.flight_code == flight.flight_code,
        Flight.origin == flight.origin,
        Flight.destination == flight.destination,
        Flight.departure_time == flight.departure_time,
        Flight.arrival_time == flight.arrival_time,
    )
    return duplicates.count() == 0


def is_flight_collision(db: Session, flight: Flight, id: int = None):
    """Report whether another flight uses the same gate around the same time.

    A collision is a stored flight with the same gate and origin whose
    departure_time lies within 30 minutes (inclusive, either side) of this
    flight's departure_time and whose status is not ``"Deleted"``.

    Side effect: when ``flight.departure_time`` is not already a
    ``datetime`` it is parsed with the format ``"%Y-%m-%d %I:%M %p"`` and
    the parsed value is written back onto ``flight``. This only happens
    when the flight has a gate, because gate-less flights return early.

    Args:
        db: Session used to run the lookup.
        flight: Flight to check; its gate, origin and departure_time are read.
        id: Primary key of a row to ignore (the flight being updated). A
            falsy value (``None`` or ``0``) disables the exclusion.

    Returns:
        ``True`` if at least one colliding flight exists, else ``False``.
        Always ``False`` when the flight has no gate.
    """
    if not flight.gate:
        return False

    if not isinstance(flight.departure_time, datetime):
        # Normalise in place so the arithmetic below works on a datetime.
        flight.departure_time = datetime.strptime(
            flight.departure_time, "%Y-%m-%d %I:%M %p"
        )

    margin = timedelta(minutes=30)
    earliest = flight.departure_time - margin
    latest = flight.departure_time + margin

    conflicting = db.query(Flight).filter(
        Flight.id != id if id else True,
        Flight.gate == flight.gate,
        Flight.origin == flight.origin,
        Flight.departure_time.between(earliest, latest),
        Flight.status != "Deleted",
    )
    return conflicting.count() > 0


def get_flight_by_id(db: Session, flight_id: int):
    """Fetch the flight whose primary key equals ``flight_id``.

    Returns:
        The matching ``Flight``, or ``None`` when no row matches.
    """
    matching = db.query(Flight).filter(Flight.id == flight_id)
    return matching.first()


def get_flights(db: Session, page: int = 1, limit: int = 8):
    """Return one page of flights together with the total number of flights.

    Args:
        db: Session used to run the queries.
        page: 1-based page number. Values of zero or below are treated as
            page 1.
        limit: Maximum number of flights per page.

    Returns:
        A ``(flights, count)`` tuple: the list of flights on the requested
        page and the number of flights across all pages.
    """
    if page <= 0:
        page = 1
    skip = (page - 1) * limit
    count = db.query(Flight).count()
    # OFFSET/LIMIT without ORDER BY leaves the row order up to the database,
    # so consecutive pages could repeat or skip rows. Ordering by the primary
    # key makes pagination deterministic.
    flights = (
        db.query(Flight).order_by(Flight.id).offset(skip).limit(limit).all()
    )
    return flights, count


def create_flight(db: Session, flight: FlightPydantic):
    """Persist a new flight after checking that it is not a duplicate.

    Args:
        db: Session used for the uniqueness check and the insert.
        flight: Incoming flight data; its fields are copied one-to-one
            onto a new ``Flight`` row.

    Returns:
        The newly stored ``Flight``, refreshed from the database after commit.

    Raises:
        ValueError: If ``is_flight_unique`` reports an existing duplicate.
    """
    if not is_flight_unique(db, flight):
        raise ValueError

    # Fields carried over verbatim from the incoming payload.
    copied_fields = (
        "flight_code",
        "status",
        "origin",
        "destination",
        "departure_time",
        "arrival_time",
        "gate",
        "user_id",
    )
    record = Flight(**{name: getattr(flight, name) for name in copied_fields})

    db.add(record)
    db.commit()
    db.refresh(record)
    return record


def update_flight_status(db: Session, status, id):
    """Change the status of an existing flight on behalf of its owner.

    Args:
        db: Session used to load and update the flight.
        status: Object exposing ``status`` (the new value) and ``user_id``
            (the user requesting the change).
        id: Primary key of the flight to update.

    Returns:
        The updated ``Flight``, refreshed from the database after commit.

    Raises:
        KeyError: If no flight has the given id.
        PermissionError: If ``status.user_id`` differs from the flight's
            ``user_id``.
    """
    target = db.query(Flight).filter(Flight.id == id).first()
    if target is None:
        raise KeyError
    if target.user_id != status.user_id:
        raise PermissionError

    target.status = status.status
    # Timestamp is computed by the database when the update is flushed.
    target.last_updated = func.now()

    db.commit()
    db.refresh(target)
    return target


def update_flight(db: Session, update_data, id):
    """Apply a partial update to a flight after validating the result.

    The stored flight's current values are overlaid with ``update_data`` to
    form a candidate flight. The candidate is checked for uniqueness and for
    gate collisions, but only when a field relevant to that check actually
    differs from the stored value.

    Args:
        db: Session used to load, validate and update the flight.
        update_data: Mapping of attribute name to new value. A ``user_id``
            entry takes part in building the candidate but is never written
            to the stored flight.
        id: Primary key of the flight to update.

    Returns:
        The updated ``Flight``, refreshed from the database after commit.

    Raises:
        KeyError: If no flight has the given id.
        ValueError: With message ``"non-unique"`` when the candidate
            duplicates another flight, or ``"collision"`` when it collides
            with another flight at the same gate.
    """
    stored = db.query(Flight).filter(Flight.id == id).first()
    if stored is None:
        raise KeyError
    # Ownership/role check is currently disabled:
    # if db_flight.user_id != update_data["user_id"] and role != "admin":
    #     raise PermissionError

    # Overlay the requested changes on the stored values; underscore-prefixed
    # keys from the instance __dict__ are not constructor arguments.
    merged = {**stored.__dict__, **update_data}
    candidate = Flight(
        **{
            name: value
            for name, value in merged.items()
            if not name.startswith("_")
        }
    )

    def differs(*fields):
        """Return True if any named field differs between candidate and stored."""
        return any(
            getattr(candidate, field) != getattr(stored, field)
            for field in fields
        )

    if differs(
        "flight_code", "destination", "origin", "departure_time", "arrival_time"
    ) and not is_flight_unique(db, candidate, id):
        raise ValueError("non-unique")

    if differs(
        "destination", "origin", "departure_time", "arrival_time", "gate"
    ) and is_flight_collision(db, candidate, id):
        raise ValueError("collision")

    for name, value in update_data.items():
        if name == "user_id":
            continue
        setattr(stored, name, value)
    stored.last_updated = func.now()

    db.commit()
    db.refresh(stored)
    return stored


def get_flights_by_origin(db: Session, origin: str, future: str):
    """List flights departing from ``origin``.

    Args:
        db: Session used to run the query.
        origin: Origin value to match exactly.
        future: When truthy, only flights whose departure date is today or
            later (per the database's current date) are returned.
            NOTE(review): annotated as ``str`` but tested for truthiness, so
            any non-empty string enables the filter - confirm what callers
            pass.

    Returns:
        A list of matching ``Flight`` rows.
    """
    flights = db.query(Flight).filter(Flight.origin == origin)
    if future:
        flights = flights.filter(
            Flight.departure_time.cast(Date) >= func.current_date()
        )
    return flights.all()


def get_flights_by_destination(db: Session, destination: str, future: str):
    """List flights heading to ``destination``.

    Args:
        db: Session used to run the query.
        destination: Destination value to match exactly.
        future: When truthy, only flights whose departure date is today or
            later (per the database's current date) are returned.
            NOTE(review): annotated as ``str`` but tested for truthiness, so
            any non-empty string enables the filter - confirm what callers
            pass.

    Returns:
        A list of matching ``Flight`` rows.
    """
    flights = db.query(Flight).filter(Flight.destination == destination)
    if future:
        # NOTE(review): this filters on departure_time, whereas
        # get_flights_update_destination filters on arrival_time - confirm
        # which column is intended for destination-based listings.
        flights = flights.filter(
            Flight.departure_time.cast(Date) >= func.current_date()
        )
    return flights.all()


def get_flights_update_origin(db: Session, origin: str, lastUpdate: str):
    """List not-yet-past flights from ``origin`` changed since ``lastUpdate``.

    Args:
        db: Session used to run the query.
        origin: Origin value to match exactly.
        lastUpdate: Lower bound (inclusive) compared against each flight's
            ``last_updated`` value.

    Returns:
        A list of ``Flight`` rows whose ``last_updated`` is at or after
        ``lastUpdate`` and whose departure date is today or later per the
        database's current date.
    """
    from_origin = Flight.origin == origin
    changed_since = Flight.last_updated >= lastUpdate
    departs_today_or_later = (
        Flight.departure_time.cast(Date) >= func.current_date()
    )
    return (
        db.query(Flight)
        .filter(from_origin, changed_since, departs_today_or_later)
        .all()
    )


def get_flights_update_destination(db: Session, destination: str, lastUpdate: str):
    """List not-yet-past flights to ``destination`` changed since ``lastUpdate``.

    Args:
        db: Session used to run the query.
        destination: Destination value to match exactly.
        lastUpdate: Lower bound (inclusive) compared against each flight's
            ``last_updated`` value.

    Returns:
        A list of ``Flight`` rows whose ``last_updated`` is at or after
        ``lastUpdate`` and whose arrival date is today or later per the
        database's current date.
    """
    to_destination = Flight.destination == destination
    changed_since = Flight.last_updated >= lastUpdate
    arrives_today_or_later = (
        Flight.arrival_time.cast(Date) >= func.current_date()
    )
    return (
        db.query(Flight)
        .filter(to_destination, changed_since, arrives_today_or_later)
        .all()
    )