102 lines
3.1 KiB
Python
102 lines
3.1 KiB
Python
import json
|
|
from datetime import datetime
|
|
|
|
from fastapi import BackgroundTasks
|
|
from fastapi.testclient import TestClient
|
|
|
|
from src.api.main import app
|
|
from src.api.models.flight import Flight
|
|
|
|
# Shared in-process HTTP client bound to the application under test; every test
# in this module issues its requests through it.
client = TestClient(app=app)
|
|
|
|
|
|
# JSON body that test_post_flight sends to POST /flights.
# The timestamps are pre-rendered with isoformat() so the whole dict is plain
# JSON-serialisable data (str / int values only).
creating_flight = dict(
    flight_code="ABC123",
    status="pending",
    origin="SLA",
    destination="AEP",
    departure_time=datetime(2023, 10, 23, 12).isoformat(),
    arrival_time=datetime(2023, 10, 24, 12).isoformat(),
    gate="10",
    user_id=1,
)
|
|
|
|
|
|
def test_post_flight(test_database, get_flight):
    """POST /flights creates a flight that can then be looked up by its returned id.

    Fixtures:
        test_database: database handle; its ``query(Flight).delete()`` is used to
            clear existing flights first (NOTE(review): presumably a SQLAlchemy
            session — no commit is issued here, confirm the fixture handles it).
        get_flight: callable taking a flight id and returning an object with a
            ``flight_code`` attribute (presumably a DB lookup — confirm).
    """
    test_database.query(Flight).delete()

    # ``json=`` serialises the payload and sets ``Content-Type: application/json``;
    # passing a pre-encoded string via ``data=`` is deprecated by httpx, which
    # newer Starlette/FastAPI TestClient versions are built on.
    response = client.post("/flights", json=creating_flight)

    # Check the HTTP status before reading the body, so a failed request is
    # reported with its error payload instead of as a KeyError on "id".
    assert 200 <= response.status_code < 300, response.text

    created = response.json()
    db_flight = get_flight(created["id"])

    assert db_flight.flight_code == creating_flight["flight_code"]
|
|
|
|
|
|
def add_task(self, func, *args, **kwargs) -> None:
    """Stand-in for ``BackgroundTasks.add_task`` that does nothing.

    It takes the same arguments as the real method (including ``self``, because
    it is patched onto the class) and ignores all of them. With it installed via
    ``monkeypatch``, requests that would enqueue a background task enqueue
    nothing.
    """
|
|
|
|
|
|
def test_patch_flight(test_database, create_flight, flight_to_create, monkeypatch):
    """PATCH /flights/{id} updates the status and echoes the updated flight.

    ``BackgroundTasks.add_task`` is replaced by the module-level no-op
    ``add_task``, so nothing the endpoint tries to schedule is actually queued.

    Fixtures:
        test_database: database handle used to clear existing flights.
        create_flight: callable that persists ``flight_to_create`` and returns
            an object with an ``id`` attribute.
        flight_to_create: payload passed to ``create_flight``.
        monkeypatch: pytest fixture used to stub ``BackgroundTasks.add_task``.
    """
    monkeypatch.setattr(BackgroundTasks, "add_task", add_task)

    test_database.query(Flight).delete()
    created_flight = create_flight(flight_to_create)

    # ``json=`` encodes the body and sets the JSON Content-Type header, replacing
    # the deprecated raw-string ``data=`` upload.
    response = client.patch(
        f"/flights/{created_flight.id}",
        json={"status": "on-boarding", "user_id": 1},
    )
    assert response.status_code == 200, response.text

    updated = response.json()
    assert updated["id"] == created_flight.id
    assert updated["status"] == "on-boarding"
|
|
|
|
|
|
def test_all_flights(create_flights_on_database):
    """GET /flights returns every seeded flight, in the same order.

    Args:
        create_flights_on_database: fixture yielding the seeded flight objects
            (each exposing ``flight_code``).
    """
    resp = client.get("/flights")

    # Assert the status before decoding: on a non-JSON error response,
    # ``resp.json()`` would raise first and hide the actual status code.
    assert resp.status_code == 200, resp.text

    compare_flight_arrays(resp.json(), create_flights_on_database)
|
|
|
|
|
|
def test_all_flights_by_origin(random_origin, create_flights_on_database):
    """GET /flights?origin=<x> returns exactly the seeded flights from origin <x>.

    If none of the seeded flights has ``random_origin`` as its origin, the
    endpoint is expected to answer 404.

    Args:
        random_origin: origin value to filter on (presumably an airport code
            string — confirm against the fixture).
        create_flights_on_database: fixture yielding the seeded flight objects.
    """
    expected_flights = [
        flight
        for flight in create_flights_on_database
        if flight.origin == random_origin
    ]

    # Let the client build and percent-encode the query string rather than
    # interpolating the value into the URL by hand.
    resp = client.get("/flights", params={"origin": random_origin})

    if expected_flights:
        assert resp.status_code == 200, resp.text
        # Decode only after the status is confirmed, and only when it is used.
        compare_flight_arrays(resp.json(), expected_flights)
    else:
        assert resp.status_code == 404, resp.text
|
|
|
|
|
|
def test_all_flights_updated_since(
    last_update, random_origin, create_flights_on_database
):
    """GET /flights filtered by origin and ``lastUpdated`` returns matching flights.

    A flight is expected when ``flight.last_updated >= last_update`` and its
    origin equals ``random_origin``; with no such flight the endpoint is
    expected to answer 404.

    Args:
        last_update: lower bound compared against ``flight.last_updated``
            (NOTE(review): presumably a datetime — confirm against the fixture).
        random_origin: origin value to filter on.
        create_flights_on_database: fixture yielding the seeded flight objects.
    """
    expected_flights = [
        flight
        for flight in create_flights_on_database
        if flight.last_updated >= last_update and flight.origin == random_origin
    ]

    # ``str(last_update)`` is exactly the text the old f-string interpolated, but
    # passing it via ``params`` percent-encodes it. Interpolated raw, the '+' of a
    # UTC offset (e.g. "...12:00:00+00:00") is decoded server-side as a space and
    # corrupts the timestamp.
    resp = client.get(
        "/flights",
        params={"origin": random_origin, "lastUpdated": str(last_update)},
    )

    if expected_flights:
        assert resp.status_code == 200, resp.text
        compare_flight_arrays(resp.json(), expected_flights)
    else:
        assert resp.status_code == 404, resp.text
|
|
|
|
|
|
def compare_flight_arrays(retrieved, pivotal):
    """Assert two flight sequences agree on length and, pairwise, on flight code.

    Args:
        retrieved: decoded JSON list from the API; each item is indexed with
            ``["flight_code"]``.
        pivotal: expected flight objects; each exposes a ``flight_code``
            attribute.

    The comparison is positional: element i of ``retrieved`` is checked
    against element i of ``pivotal``.
    """
    assert len(retrieved) == len(pivotal)
    pairs = zip(retrieved, pivotal)
    for payload, expected in pairs:
        actual_code = payload["flight_code"]
        expected_code = expected.flight_code
        assert actual_code == expected_code
|