from datetime import datetime, timedelta
from typing import Optional

from sqlalchemy import Date
from sqlalchemy.orm import Session
from sqlalchemy.sql import func

from src.api.models.flight import Flight
from src.api.schemas.flight import Flight as FlightPydantic

# Two flights at the same gate and origin collide when their departure times
# are within this distance of each other (inclusive, via SQL BETWEEN).
_COLLISION_WINDOW = timedelta(minutes=30)

# Format used to parse a departure_time that arrives as a string.
_DEPARTURE_TIME_FORMAT = "%Y-%m-%d %I:%M %p"

# Fields that together identify a flight; changing any of them in an update
# requires re-checking uniqueness.
_IDENTITY_FIELDS = (
    "flight_code",
    "destination",
    "origin",
    "departure_time",
    "arrival_time",
)

# Fields that affect gate scheduling; changing any of them in an update
# requires re-checking for gate collisions.
_SCHEDULE_FIELDS = (
    "destination",
    "origin",
    "departure_time",
    "arrival_time",
    "gate",
)


def _exclude_id(query, id: Optional[int]):
    """Return ``query`` restricted to rows whose id differs from ``id``.

    When ``id`` is None the query is returned unchanged, so a brand-new flight
    is compared against every existing row. When validating an update, pass
    the edited flight's id so it is not compared with itself.
    """
    if id is not None:
        query = query.filter(Flight.id != id)
    return query


def is_flight_unique(
    db: Session, flight: FlightPydantic, id: Optional[int] = None
) -> bool:
    """Return True if no other flight shares this flight's identifying fields.

    A flight is a duplicate when another row has the same flight_code, origin,
    destination, departure_time and arrival_time. ``flight`` may be any object
    exposing those attributes (callers in this module pass both the Pydantic
    schema and a transient ``Flight`` instance).

    Args:
        db: Active database session.
        flight: The candidate flight.
        id: Id of the flight being updated, excluded from the comparison.
    """
    query = db.query(Flight).filter(
        Flight.flight_code == flight.flight_code,
        Flight.origin == flight.origin,
        Flight.destination == flight.destination,
        Flight.departure_time == flight.departure_time,
        Flight.arrival_time == flight.arrival_time,
    )
    return _exclude_id(query, id).count() == 0


def is_flight_collision(db: Session, flight: Flight, id: Optional[int] = None) -> bool:
    """Return True if another non-deleted flight uses the same gate too closely.

    A collision is another flight with the same gate and origin whose
    departure_time lies within ``_COLLISION_WINDOW`` (inclusive) of this
    flight's departure_time, ignoring flights whose status is "Deleted".
    A flight without a gate never collides.

    NOTE: if ``flight.departure_time`` is not already a ``datetime``, it is
    parsed with ``_DEPARTURE_TIME_FORMAT`` and written back onto ``flight``
    (the caller's object is mutated).

    Args:
        db: Active database session.
        flight: The candidate flight.
        id: Id of the flight being updated, excluded from the comparison.

    Raises:
        ValueError: if a string departure_time does not match the format.
    """
    if not flight.gate:
        return False

    if not isinstance(flight.departure_time, datetime):
        flight.departure_time = datetime.strptime(
            flight.departure_time, _DEPARTURE_TIME_FORMAT
        )

    window_start = flight.departure_time - _COLLISION_WINDOW
    window_end = flight.departure_time + _COLLISION_WINDOW
    query = db.query(Flight).filter(
        Flight.gate == flight.gate,
        Flight.origin == flight.origin,
        Flight.departure_time.between(window_start, window_end),
        Flight.status != "Deleted",
    )
    return _exclude_id(query, id).count() > 0


def get_flight_by_id(db: Session, flight_id: int):
    """Return the flight with ``flight_id``, or None if it does not exist."""
    return db.query(Flight).filter(Flight.id == flight_id).first()


def get_flights(db: Session, page: int = 1, limit: int = 8):
    """Return one page of flights together with the total flight count.

    Pages are 1-based; a page <= 0 is treated as page 1. Rows are ordered by
    id so consecutive pages are stable: without an ORDER BY the database may
    return OFFSET/LIMIT windows in any order, repeating or skipping rows.

    Returns:
        A tuple ``(flights, total)``: the Flight rows on this page and the
        number of Flight rows across all pages.
    """
    page = max(page, 1)
    total = db.query(Flight).count()
    flights = (
        db.query(Flight)
        .order_by(Flight.id)
        .offset((page - 1) * limit)
        .limit(limit)
        .all()
    )
    return flights, total


def create_flight(db: Session, flight: FlightPydantic):
    """Insert a new flight built from ``flight`` and return the stored row.

    Raises:
        ValueError: if a flight with the same identifying fields exists
            (see ``is_flight_unique``).
    """
    if not is_flight_unique(db, flight):
        raise ValueError

    db_flight = Flight(
        flight_code=flight.flight_code,
        status=flight.status,
        origin=flight.origin,
        destination=flight.destination,
        departure_time=flight.departure_time,
        arrival_time=flight.arrival_time,
        gate=flight.gate,
        user_id=flight.user_id,
    )
    db.add(db_flight)
    db.commit()
    db.refresh(db_flight)
    return db_flight


def update_flight_status(db: Session, status, id):
    """Set the status of flight ``id`` and refresh its last_updated timestamp.

    ``status`` must expose ``status`` and ``user_id`` attributes (presumably a
    request schema; confirm against the caller).

    Raises:
        KeyError: if no flight has this id.
        PermissionError: if ``status.user_id`` differs from the flight's
            user_id.
    """
    db_flight = db.query(Flight).filter(Flight.id == id).first()
    if db_flight is None:
        raise KeyError
    if db_flight.user_id != status.user_id:
        raise PermissionError

    db_flight.status = status.status
    db_flight.last_updated = func.now()
    db.commit()
    db.refresh(db_flight)
    return db_flight


def _fields_changed(new, old, fields) -> bool:
    """Return True if ``new`` and ``old`` differ in any attribute named in ``fields``."""
    return any(getattr(new, name) != getattr(old, name) for name in fields)


def update_flight(db: Session, update_data, id):
    """Apply the field -> value mapping ``update_data`` to flight ``id``.

    The merged state (current columns overlaid with ``update_data``) is
    validated before anything is written:

    * if an identifying field changed, it must still be unique;
    * if a scheduling field changed, it must not collide at its gate.

    A ``user_id`` key in ``update_data`` takes part in validation but is never
    written to the stored flight. ``last_updated`` is always refreshed.

    Raises:
        KeyError: if no flight has this id.
        ValueError: "non-unique" or "collision" when validation fails.
    """
    db_flight = db.query(Flight).filter(Flight.id == id).first()
    if db_flight is None:
        raise KeyError

    # Build a transient Flight holding the would-be state. Keys starting with
    # "_" (e.g. SQLAlchemy's per-instance state attribute) are skipped because
    # they are not constructor arguments.
    merged = {**db_flight.__dict__, **update_data}
    new_flight = Flight(
        **{key: value for key, value in merged.items() if not key.startswith("_")}
    )

    if _fields_changed(new_flight, db_flight, _IDENTITY_FIELDS) and not is_flight_unique(
        db, new_flight, id
    ):
        raise ValueError("non-unique")

    if _fields_changed(new_flight, db_flight, _SCHEDULE_FIELDS) and is_flight_collision(
        db, new_flight, id
    ):
        raise ValueError("collision")

    for key, value in update_data.items():
        if key != "user_id":
            setattr(db_flight, key, value)
    db_flight.last_updated = func.now()
    db.commit()
    db.refresh(db_flight)
    return db_flight


def _flights_matching(db: Session, column, value, future):
    """Return flights where ``column == value``.

    When ``future`` is truthy, only flights whose departure date is today or
    later (per the database's current_date) are returned.
    """
    criteria = [column == value]
    if future:
        criteria.append(Flight.departure_time.cast(Date) >= func.current_date())
    return db.query(Flight).filter(*criteria).all()


def get_flights_by_origin(db: Session, origin: str, future: str):
    """Return flights departing from ``origin``.

    When ``future`` is truthy, only flights departing today or later.
    """
    return _flights_matching(db, Flight.origin, origin, future)


def get_flights_by_destination(db: Session, destination: str, future: str):
    """Return flights arriving at ``destination``.

    When ``future`` is truthy, only flights *departing* today or later.
    """
    return _flights_matching(db, Flight.destination, destination, future)


def _flights_updated_since(db: Session, column, value, last_update, date_column):
    """Return flights with ``column == value`` and ``last_updated >= last_update``.

    The date of ``date_column`` must also be today or later.
    """
    return (
        db.query(Flight)
        .filter(
            column == value,
            Flight.last_updated >= last_update,
            date_column.cast(Date) >= func.current_date(),
        )
        .all()
    )


def get_flights_update_origin(db: Session, origin: str, lastUpdate: str):
    """Return upcoming flights from ``origin`` modified at or after ``lastUpdate``.

    "Upcoming" means the departure_time date is today or later.
    """
    return _flights_updated_since(
        db, Flight.origin, origin, lastUpdate, Flight.departure_time
    )


def get_flights_update_destination(db: Session, destination: str, lastUpdate: str):
    """Return upcoming flights to ``destination`` modified at or after ``lastUpdate``.

    Unlike the origin variant, "upcoming" here means the arrival_time date is
    today or later.
    """
    return _flights_updated_since(
        db, Flight.destination, destination, lastUpdate, Flight.arrival_time
    )