198 lines
5.7 KiB
Python
198 lines
5.7 KiB
Python
from datetime import datetime, timedelta
|
|
|
|
from sqlalchemy import Date
|
|
from sqlalchemy.orm import Session
|
|
from sqlalchemy.sql import func
|
|
|
|
from src.api.models.flight import Flight
|
|
from src.api.schemas.flight import Flight as FlightPydantic
|
|
|
|
|
|
def is_flight_unique(db: Session, flight: FlightPydantic, id: int = None):
    """Return True if no other stored flight has the same identifying fields.

    A flight counts as a duplicate when another row matches on
    ``flight_code``, ``origin``, ``destination``, ``departure_time`` and
    ``arrival_time``.

    Args:
        db: Active SQLAlchemy session.
        flight: Candidate flight. Only the five fields above are read, so any
            object exposing them (the Pydantic schema or an ORM ``Flight``)
            can be passed.
        id: Primary key of the flight being updated, if any. That row is
            excluded so a flight never conflicts with itself.

    Returns:
        ``True`` when no matching row exists, ``False`` otherwise.
    """
    criteria = [
        Flight.flight_code == flight.flight_code,
        Flight.origin == flight.origin,
        Flight.destination == flight.destination,
        Flight.departure_time == flight.departure_time,
        Flight.arrival_time == flight.arrival_time,
    ]
    # Compare with None explicitly so a falsy id is still excluded. This also
    # avoids passing a bare Python ``True`` into ``filter``.
    if id is not None:
        criteria.append(Flight.id != id)
    # NOTE(review): unlike is_flight_collision, rows with status "Deleted" are
    # not excluded here. Confirm that deleted flights should still block reuse.

    # EXISTS stops at the first match instead of counting every duplicate.
    duplicate_query = db.query(Flight).filter(*criteria)
    return not db.query(duplicate_query.exists()).scalar()
|
|
|
|
|
|
def is_flight_collision(
    db: Session,
    flight: Flight,
    id: int = None,
    *,
    window: timedelta = timedelta(minutes=30),
):
    """Return True if another non-deleted flight holds the same gate slot.

    Two flights collide when they share ``origin`` and ``gate`` and their
    departure times are within ``window`` of each other. The bounds are
    inclusive (SQL ``BETWEEN``). Rows whose status is ``"Deleted"`` are
    ignored.

    Args:
        db: Active SQLAlchemy session.
        flight: Candidate flight. ``departure_time`` may be a ``datetime`` or
            a string in ``"%Y-%m-%d %I:%M %p"`` format
            (e.g. ``"2024-05-01 09:30 AM"``).
        id: Primary key of the flight being updated, if any. That row is
            excluded from the search.
        window: Half-width of the departure-time window. Defaults to
            30 minutes.

    Returns:
        ``False`` immediately when the flight has no gate. Otherwise, whether
        any conflicting row exists.

    Side effects:
        A non-``datetime`` ``departure_time`` on ``flight`` is replaced with
        the parsed ``datetime``. This is kept for compatibility with callers
        that may rely on it.

    Raises:
        ValueError: if a string ``departure_time`` does not match the format.
    """
    if not flight.gate:
        return False

    if not isinstance(flight.departure_time, datetime):
        flight.departure_time = datetime.strptime(
            flight.departure_time, "%Y-%m-%d %I:%M %p"
        )

    departure = flight.departure_time
    criteria = [
        Flight.gate == flight.gate,
        Flight.origin == flight.origin,
        Flight.departure_time.between(departure - window, departure + window),
        Flight.status != "Deleted",
    ]
    # Compare with None explicitly so a falsy id is still excluded.
    if id is not None:
        criteria.append(Flight.id != id)

    # EXISTS stops at the first conflicting row instead of counting them all.
    conflict_query = db.query(Flight).filter(*criteria)
    return bool(db.query(conflict_query.exists()).scalar())
|
|
|
|
|
|
def get_flight_by_id(db: Session, flight_id: int):
    """Look up a single flight by primary key.

    Returns the matching ``Flight`` row, or ``None`` when no row has
    ``flight_id``.
    """
    matches = db.query(Flight).filter_by(id=flight_id)
    return matches.first()
|
|
|
|
|
|
def get_flights(db: Session, page: int = 1, limit: int = 8):
    """Return one page of flights together with the total flight count.

    Args:
        db: Active SQLAlchemy session.
        page: 1-based page number. Values below 1 are treated as page 1.
        limit: Maximum number of flights per page.

    Returns:
        ``(flights, total)``. ``flights`` is the requested page ordered by
        primary key, and ``total`` counts every row in the table.
    """
    page = max(page, 1)
    offset = (page - 1) * limit
    total = db.query(Flight).count()
    # OFFSET/LIMIT without ORDER BY returns rows in an unspecified order, so
    # consecutive pages could overlap or skip rows (for example after updates).
    # Ordering by primary key keeps pagination stable.
    flights = (
        db.query(Flight)
        .order_by(Flight.id)
        .offset(offset)
        .limit(limit)
        .all()
    )
    return flights, total
|
|
|
|
|
|
def create_flight(db: Session, flight: FlightPydantic):
    """Insert a new flight and return the persisted ORM row.

    Args:
        db: Active SQLAlchemy session. The insert is committed here.
        flight: Validated flight payload to store.

    Returns:
        The new ``Flight`` row, refreshed from the database after commit.

    Raises:
        ValueError("non-unique"): a flight with the same code, origin,
            destination, departure time and arrival time already exists.
    """
    if not is_flight_unique(db, flight):
        # Same message update_flight uses, so callers can handle both alike.
        raise ValueError("non-unique")
    # NOTE(review): update_flight also rejects gate collisions via
    # is_flight_collision, but creation does not. Confirm this is intended.

    db_flight = Flight(
        flight_code=flight.flight_code,
        status=flight.status,
        origin=flight.origin,
        destination=flight.destination,
        departure_time=flight.departure_time,
        arrival_time=flight.arrival_time,
        gate=flight.gate,
        user_id=flight.user_id,
    )
    db.add(db_flight)
    db.commit()
    # Reload the row so values assigned by the database are available.
    db.refresh(db_flight)
    return db_flight
|
|
|
|
|
|
def update_flight_status(db: Session, status, id):
    """Change a flight's status, allowed only for the flight's own user_id.

    Args:
        db: Active SQLAlchemy session. The change is committed here.
        status: Object carrying ``status`` (the new value) and ``user_id``
            (the requester, compared against the flight's ``user_id``).
        id: Primary key of the flight to change.

    Returns:
        The ``Flight`` row, refreshed after commit.

    Raises:
        KeyError: no flight has this id.
        PermissionError: the flight's ``user_id`` differs from
            ``status.user_id``.
    """
    flight_row = db.query(Flight).filter(Flight.id == id).first()
    if flight_row is None:
        raise KeyError
    if flight_row.user_id != status.user_id:
        raise PermissionError

    flight_row.status = status.status
    # func.now() is a SQL expression, so the timestamp comes from the database.
    flight_row.last_updated = func.now()
    db.commit()
    db.refresh(flight_row)
    return flight_row
|
|
|
|
|
|
def update_flight(db: Session, update_data, id):
    """Apply a partial update to a flight after uniqueness and gate checks.

    Args:
        db: Active SQLAlchemy session. The change is committed here.
        update_data: Mapping of ``Flight`` attribute names to new values.
            ``user_id`` is used for the checks but never written back.
        id: Primary key of the flight to update.

    Returns:
        The ``Flight`` row, refreshed after commit.

    Raises:
        KeyError: no flight has this id.
        ValueError("non-unique"): an identifying field changed and the result
            duplicates another flight.
        ValueError("collision"): a route, time or gate field changed and the
            result clashes with another flight's gate slot.
    """
    flight_row = db.query(Flight).filter(Flight.id == id).first()
    if flight_row is None:
        raise KeyError

    # TODO(review): an ownership check (matching user_id, or an admin role) is
    # disabled here, so no permission check is performed on this path.

    # Build what the flight would look like after the update. Keys starting
    # with "_" (such as SQLAlchemy's instance state) are skipped.
    # NOTE(review): this is a transient Flight built from the loaded attribute
    # dict. If relationship attributes happen to be loaded, backref cascades
    # could attach it to the session. Verify, or use a plain namespace instead.
    merged = {**flight_row.__dict__, **update_data}
    candidate = Flight(
        **{name: val for name, val in merged.items() if not name.startswith("_")}
    )

    def differs(*fields):
        # True as soon as any listed field differs between candidate and row.
        return any(
            getattr(candidate, field) != getattr(flight_row, field)
            for field in fields
        )

    identity_fields = (
        "flight_code",
        "destination",
        "origin",
        "departure_time",
        "arrival_time",
    )
    if differs(*identity_fields) and not is_flight_unique(db, candidate, id):
        raise ValueError("non-unique")

    slot_fields = ("destination", "origin", "departure_time", "arrival_time", "gate")
    if differs(*slot_fields) and is_flight_collision(db, candidate, id):
        raise ValueError("collision")

    for name, val in update_data.items():
        if name == "user_id":
            continue  # user_id is never overwritten through this path
        setattr(flight_row, name, val)
    flight_row.last_updated = func.now()

    db.commit()
    db.refresh(flight_row)
    return flight_row
|
|
|
|
|
|
def get_flights_by_origin(db: Session, origin: str, future: str):
    """Return flights whose ``origin`` equals the given value.

    Args:
        db: Active SQLAlchemy session.
        origin: Origin value to match exactly.
        future: When truthy, only flights departing today or later (by the
            database's ``CURRENT_DATE``) are returned. The strings ``"false"``,
            ``"0"``, ``"no"``, ``"off"`` and blank strings (case-insensitive)
            count as false.

    Returns:
        List of matching ``Flight`` rows.
    """
    # ``future`` is annotated as str. A plain truthiness test would treat the
    # string "false" as True and wrongly restrict the results.
    if isinstance(future, str):
        future = future.strip().lower() not in ("", "0", "false", "no", "off")

    query = db.query(Flight).filter(Flight.origin == origin)
    if future:
        query = query.filter(Flight.departure_time.cast(Date) >= func.current_date())
    return query.all()
|
|
|
|
|
|
def get_flights_by_destination(db: Session, destination: str, future: str):
    """Return flights whose ``destination`` equals the given value.

    Args:
        db: Active SQLAlchemy session.
        destination: Destination value to match exactly.
        future: When truthy, only flights departing today or later (by the
            database's ``CURRENT_DATE``) are returned. The strings ``"false"``,
            ``"0"``, ``"no"``, ``"off"`` and blank strings (case-insensitive)
            count as false.

    Returns:
        List of matching ``Flight`` rows.
    """
    # ``future`` is annotated as str. A plain truthiness test would treat the
    # string "false" as True and wrongly restrict the results.
    if isinstance(future, str):
        future = future.strip().lower() not in ("", "0", "false", "no", "off")

    query = db.query(Flight).filter(Flight.destination == destination)
    if future:
        # NOTE(review): this filters on departure_time, while
        # get_flights_update_destination filters on arrival_time. Confirm
        # which one the arrivals view should use.
        query = query.filter(Flight.departure_time.cast(Date) >= func.current_date())
    return query.all()
|
|
|
|
|
|
def get_flights_update_origin(db: Session, origin: str, lastUpdate: str):
    """Return flights from ``origin`` changed since ``lastUpdate``.

    A row qualifies when its ``last_updated`` is at or after ``lastUpdate``
    and its departure date is today or later (database ``CURRENT_DATE``).
    ``lastUpdate`` is passed straight to the database for comparison;
    presumably it is a timestamp string the database can parse (verify with
    callers).
    """
    departs_today_or_later = Flight.departure_time.cast(Date) >= func.current_date()
    return (
        db.query(Flight)
        .filter(
            Flight.origin == origin,
            Flight.last_updated >= lastUpdate,
            departs_today_or_later,
        )
        .all()
    )
|
|
|
|
|
|
def get_flights_update_destination(db: Session, destination: str, lastUpdate: str):
    """Return flights to ``destination`` changed since ``lastUpdate``.

    A row qualifies when its ``last_updated`` is at or after ``lastUpdate``
    and its arrival date is today or later (database ``CURRENT_DATE``).
    ``lastUpdate`` is passed straight to the database for comparison;
    presumably it is a timestamp string the database can parse (verify with
    callers).
    """
    arrives_today_or_later = Flight.arrival_time.cast(Date) >= func.current_date()
    return (
        db.query(Flight)
        .filter(
            Flight.destination == destination,
            Flight.last_updated >= lastUpdate,
            arrives_today_or_later,
        )
        .all()
    )
|