# fids/flights-domain/flights-information/src/api/cruds/flight.py
#
# 193 lines
# 5.5 KiB
# Python

from datetime import datetime, timedelta
from sqlalchemy import Date
from sqlalchemy.orm import Session
from sqlalchemy.sql import func
from src.api.models.flight import Flight
from src.api.schemas.flight import Flight as FlightPydantic
def is_flight_unique(db: Session, flight: FlightPydantic, id: int = None):
    """Return True when no stored flight duplicates ``flight``.

    Two flights count as duplicates when their flight code, origin,
    destination, departure time and arrival time are all equal.

    Args:
        db: Session used to run the query.
        flight: Object exposing ``flight_code``, ``origin``, ``destination``,
            ``departure_time`` and ``arrival_time``.
        id: ``id`` of a row to leave out of the comparison (for example the
            row currently being updated). ``None`` compares against every row.

    Returns:
        bool: True if no matching row exists, False otherwise.
    """
    query = db.query(Flight).filter(
        Flight.flight_code == flight.flight_code,
        Flight.origin == flight.origin,
        Flight.destination == flight.destination,
        Flight.departure_time == flight.departure_time,
        Flight.arrival_time == flight.arrival_time,
    )
    # Test against None explicitly: a truthiness test would treat an id of 0
    # as "no id given" and silently skip the exclusion.
    if id is not None:
        query = query.filter(Flight.id != id)
    return query.count() == 0
def is_flight_collision(
    db: Session,
    flight: Flight,
    id: int = None,
    *,
    time_window: timedelta = timedelta(minutes=30),
):
    """Return True when another flight uses the same gate at about the same time.

    A collision is a stored flight with the same ``gate`` and ``origin`` whose
    ``departure_time`` lies within ``time_window`` before or after the
    departure time of ``flight`` (bounds inclusive, via SQL ``BETWEEN``).

    Args:
        db: Session used to run the query.
        flight: Object exposing ``gate``, ``origin`` and ``departure_time``.
            ``departure_time`` may be a ``datetime`` or a string in the
            ``"%Y-%m-%d %I:%M %p"`` format.
        id: ``id`` of a row to leave out of the comparison. ``None`` compares
            against every row.
        time_window: Half-width of the window around the departure time.
            Defaults to 30 minutes.

    Returns:
        bool: True if at least one colliding row exists. Always False when
        ``flight.gate`` is falsy (no gate assigned, nothing to collide with).

    Raises:
        ValueError: If ``departure_time`` is a string that does not match the
            expected format (raised by ``datetime.strptime``).

    Note:
        When ``departure_time`` is not already a ``datetime`` the parsed value
        is written back onto ``flight``, i.e. the argument is mutated.
    """
    if not flight.gate:
        return False

    if not isinstance(flight.departure_time, datetime):
        flight.departure_time = datetime.strptime(
            flight.departure_time, "%Y-%m-%d %I:%M %p"
        )

    earliest = flight.departure_time - time_window
    latest = flight.departure_time + time_window

    query = db.query(Flight).filter(
        Flight.gate == flight.gate,
        Flight.origin == flight.origin,
        Flight.departure_time.between(earliest, latest),
    )
    # Test against None explicitly: a truthiness test would treat an id of 0
    # as "no id given" and make the flight collide with itself.
    if id is not None:
        query = query.filter(Flight.id != id)
    return query.count() > 0
def get_flight_by_id(db: Session, flight_id: int):
    """Fetch the flight whose ``id`` equals ``flight_id``.

    Args:
        db: Session used to run the query.
        flight_id: Value matched against ``Flight.id``.

    Returns:
        The first matching ``Flight``, or ``None`` when no row matches.
    """
    lookup = db.query(Flight).filter(Flight.id == flight_id)
    return lookup.first()
def get_flights(db: Session, skip: int = 0, limit: int = 100):
    """Return one page of flights.

    Args:
        db: Session used to run the query.
        skip: Number of rows to skip (SQL ``OFFSET``).
        limit: Maximum number of rows to return (SQL ``LIMIT``).

    Returns:
        list: Up to ``limit`` ``Flight`` rows, ordered by ``id``.
    """
    # OFFSET/LIMIT without ORDER BY leaves the row order up to the database,
    # so consecutive pages could overlap or skip rows. Ordering by id makes
    # pagination deterministic.
    return (
        db.query(Flight)
        .order_by(Flight.id)
        .offset(skip)
        .limit(limit)
        .all()
    )
def create_flight(db: Session, flight: FlightPydantic):
    """Persist a new flight and return the stored row.

    Args:
        db: Session used to insert and commit the row.
        flight: Incoming flight data; its ``flight_code``, ``status``,
            ``origin``, ``destination``, ``departure_time``, ``arrival_time``,
            ``gate`` and ``user_id`` attributes are copied onto the new row.

    Returns:
        The newly created ``Flight``, refreshed from the database after the
        commit.

    Raises:
        ValueError: If ``is_flight_unique`` reports that an equivalent flight
            already exists.
    """
    if not is_flight_unique(db, flight):
        raise ValueError

    copied_fields = (
        "flight_code",
        "status",
        "origin",
        "destination",
        "departure_time",
        "arrival_time",
        "gate",
        "user_id",
    )
    record = Flight(**{name: getattr(flight, name) for name in copied_fields})

    db.add(record)
    db.commit()
    # Reload the row so database-populated columns are available to the caller.
    db.refresh(record)
    return record
def update_flight_status(db: Session, status, id):
    """Change the status of an existing flight.

    Args:
        db: Session used to load and commit the row.
        status: Object exposing ``status`` (the new value) and ``user_id``
            (the user requesting the change).
        id: Value matched against ``Flight.id``.

    Returns:
        The updated ``Flight``, refreshed from the database after the commit.

    Raises:
        KeyError: If no flight with the given ``id`` exists.
        PermissionError: If ``status.user_id`` differs from the ``user_id``
            stored on the flight.
    """
    record = db.query(Flight).filter(Flight.id == id).first()
    if record is None:
        raise KeyError
    if record.user_id != status.user_id:
        raise PermissionError

    record.status = status.status
    # Let the database stamp the modification time.
    record.last_updated = func.now()

    db.commit()
    db.refresh(record)
    return record
def _any_field_differs(left, right, names):
    """Return True if any attribute listed in ``names`` differs between the two objects."""
    return any(getattr(left, name) != getattr(right, name) for name in names)


def update_flight(db: Session, update_data, id):
    """Apply a partial update to an existing flight.

    The stored row is merged with ``update_data`` into a temporary ``Flight``
    that is only used for validation: uniqueness is re-checked when an
    identifying field changes, and gate collisions are re-checked when the
    route, the times or the gate change. Only then are the values from
    ``update_data`` copied onto the stored row and committed.

    Args:
        db: Session used to load and commit the row.
        update_data: Mapping of attribute name to new value. It must contain
            ``"user_id"``; every key in it is set on the stored row.
        id: Value matched against ``Flight.id``.

    Returns:
        The updated ``Flight``, refreshed from the database after the commit.

    Raises:
        KeyError: If no flight with the given ``id`` exists.
        PermissionError: If ``update_data`` has no ``"user_id"`` or it differs
            from the ``user_id`` stored on the flight.
        ValueError: With message ``"non-unique"`` if the updated flight would
            duplicate another one, or ``"collision"`` if it would collide with
            another flight at the same gate.
    """
    db_flight = db.query(Flight).filter(Flight.id == id).first()
    if db_flight is None:
        raise KeyError

    # A payload without "user_id" cannot prove ownership. Reject it here so
    # the dictionary lookup does not raise KeyError, which this function
    # reserves for "no such flight".
    if (
        "user_id" not in update_data
        or db_flight.user_id != update_data["user_id"]
    ):
        raise PermissionError

    # Stored values overlaid with the requested changes. Keys starting with
    # "_" (such as SQLAlchemy's internal ``_sa_instance_state``) are not
    # constructor arguments and are dropped.
    merged = {
        key: value
        for key, value in {**db_flight.__dict__, **update_data}.items()
        if not key.startswith("_")
    }
    new_flight = Flight(**merged)

    identity_fields = (
        "flight_code",
        "destination",
        "origin",
        "departure_time",
        "arrival_time",
    )
    if _any_field_differs(
        new_flight, db_flight, identity_fields
    ) and not is_flight_unique(db, new_flight, id):
        raise ValueError("non-unique")

    schedule_fields = (
        "destination",
        "origin",
        "departure_time",
        "arrival_time",
        "gate",
    )
    if _any_field_differs(
        new_flight, db_flight, schedule_fields
    ) and is_flight_collision(db, new_flight, id):
        raise ValueError("collision")

    for key, value in update_data.items():
        setattr(db_flight, key, value)
    # Let the database stamp the modification time.
    db_flight.last_updated = func.now()

    db.commit()
    db.refresh(db_flight)
    return db_flight
def get_flights_by_origin(db: Session, origin: str, future: str):
    """List the flights leaving from ``origin``.

    Args:
        db: Session used to run the query.
        origin: Value matched against ``Flight.origin``.
        future: When truthy, only flights whose departure date is the
            database's current date or later are returned. Any non-empty
            string counts as truthy.
            NOTE(review): confirm which values callers actually pass.

    Returns:
        list: All matching ``Flight`` rows.
    """
    matches = db.query(Flight).filter(Flight.origin == origin)
    if future:
        # Compare dates only, so flights earlier today are still included.
        matches = matches.filter(
            Flight.departure_time.cast(Date) >= func.current_date()
        )
    return matches.all()
def get_flights_by_destination(db: Session, destination: str, future: str):
    """List the flights heading to ``destination``.

    Args:
        db: Session used to run the query.
        destination: Value matched against ``Flight.destination``.
        future: When truthy, only flights whose arrival date is the
            database's current date or later are returned. Any non-empty
            string counts as truthy.
            NOTE(review): confirm which values callers actually pass.

    Returns:
        list: All matching ``Flight`` rows.
    """
    query = db.query(Flight).filter(Flight.destination == destination)
    if future:
        # Filter on the arrival date, the same column that
        # get_flights_update_destination uses. Filtering on the departure
        # date would drop flights that left on an earlier day but have not
        # yet arrived.
        query = query.filter(
            Flight.arrival_time.cast(Date) >= func.current_date()
        )
    return query.all()
def get_flights_update_origin(db: Session, origin: str, lastUpdate: str):
    """List flights from ``origin`` that changed at or after ``lastUpdate``.

    Args:
        db: Session used to run the query.
        origin: Value matched against ``Flight.origin``.
        lastUpdate: Lower bound (inclusive) compared with
            ``Flight.last_updated``.

    Returns:
        list: ``Flight`` rows that match the origin, were updated at or after
        ``lastUpdate``, and whose departure date is the database's current
        date or later.
    """
    same_origin = Flight.origin == origin
    changed_since = Flight.last_updated >= lastUpdate
    departs_today_or_later = (
        Flight.departure_time.cast(Date) >= func.current_date()
    )
    # Multiple criteria passed to filter() are combined with AND.
    return (
        db.query(Flight)
        .filter(same_origin, changed_since, departs_today_or_later)
        .all()
    )
def get_flights_update_destination(db: Session, destination: str, lastUpdate: str):
    """List flights to ``destination`` that changed at or after ``lastUpdate``.

    Args:
        db: Session used to run the query.
        destination: Value matched against ``Flight.destination``.
        lastUpdate: Lower bound (inclusive) compared with
            ``Flight.last_updated``.

    Returns:
        list: ``Flight`` rows that match the destination, were updated at or
        after ``lastUpdate``, and whose arrival date is the database's current
        date or later.
    """
    same_destination = Flight.destination == destination
    changed_since = Flight.last_updated >= lastUpdate
    arrives_today_or_later = (
        Flight.arrival_time.cast(Date) >= func.current_date()
    )
    # Multiple criteria passed to filter() are combined with AND.
    return (
        db.query(Flight)
        .filter(same_destination, changed_since, arrives_today_or_later)
        .all()
    )