"""API tests for the ``/flights`` endpoints.

The fixtures used here (``test_database``, ``get_flight``, ``create_flight``,
``flight_to_create``, ``create_flights_on_database``, ``random_origin`` and
``last_update``) are not defined in this module. Presumably they come from a
``conftest.py``; confirm their exact semantics there.
"""
import json
from datetime import datetime

from fastapi.testclient import TestClient

from src.api.main import app
from src.api.models.flight import Flight

client = TestClient(app)

# Request payload for creating a flight. The datetimes are pre-serialized to
# ISO-8601 strings so the dict is JSON-encodable as-is.
creating_flight = {
    "flight_code": "ABC123",
    "status": "pending",
    "origin": "SLA",
    "destination": "AEP",
    "departure_time": datetime(2023, 10, 23, 12, 0, 0).isoformat(),
    "arrival_time": datetime(2023, 10, 24, 12, 0, 0).isoformat(),
    "gate": "10",
    "user_id": 1,
}


def _assert_success(resp):
    """Fail with the status code and body unless *resp* is a 2xx response."""
    assert 200 <= resp.status_code < 300, (resp.status_code, resp.text)


def _assert_filtered_response(resp, expected_flights):
    """Check a filtered ``GET /flights`` response against the expected flights.

    If any flights are expected, the response must be 200 and its JSON list
    must match *expected_flights*. Otherwise the response must be 404.
    """
    if expected_flights:
        assert resp.status_code == 200
        # Decode the body only on the success path, where a JSON list is expected.
        compare_flight_arrays(resp.json(), expected_flights)
    else:
        assert resp.status_code == 404


def test_post_flight(test_database, get_flight):
    """POST /flights creates a flight that ``get_flight`` can load by the returned id."""
    test_database.query(Flight).delete()

    # ``json=`` serializes the payload and sets Content-Type: application/json.
    resp = client.post("/flights", json=creating_flight)
    # Check the status first, so a failed request reports its status and body
    # instead of surfacing as a KeyError on "id" below.
    _assert_success(resp)

    db_retrieved_flight = get_flight(resp.json()["id"])
    assert db_retrieved_flight.flight_code == creating_flight["flight_code"]


def test_patch_flight(test_database, create_flight, flight_to_create):
    """PATCH /flights/{id} updates the status and echoes the updated flight."""
    test_database.query(Flight).delete()
    created_flight = create_flight(flight_to_create)

    resp = client.patch(
        f"/flights/{created_flight.id}",
        json={"status": "on-boarding", "user_id": 1},
    )
    assert resp.status_code == 200

    data = resp.json()
    assert data["id"] == created_flight.id
    assert data["status"] == "on-boarding"


def test_all_flights(create_flights_on_database):
    """GET /flights returns every flight created by the fixture."""
    resp = client.get("/flights")
    assert resp.status_code == 200
    compare_flight_arrays(resp.json(), create_flights_on_database)


def test_all_flights_by_origin(random_origin, create_flights_on_database):
    """GET /flights?origin=... returns only flights with that origin (404 if none)."""
    expected = [
        flight
        for flight in create_flights_on_database
        if flight.origin == random_origin
    ]
    # ``params=`` percent-encodes the value instead of splicing it into the URL.
    resp = client.get("/flights", params={"origin": random_origin})
    _assert_filtered_response(resp, expected)


def test_all_flights_updated_since(
    last_update, random_origin, create_flights_on_database
):
    """GET /flights filtered by origin and ``lastUpdated`` (404 if none match)."""
    # NOTE(review): assumes ``last_updated`` and ``last_update`` are mutually
    # comparable (e.g. both datetimes); confirm against the fixtures.
    expected = [
        flight
        for flight in create_flights_on_database
        if flight.last_updated >= last_update and flight.origin == random_origin
    ]
    # ``str()`` sends the same text the URL previously interpolated. Passing it
    # via ``params=`` percent-encodes characters such as the space in
    # ``str(datetime)`` or a ``+`` in a UTC offset, which a query parser would
    # otherwise read as a space.
    resp = client.get(
        "/flights",
        params={"origin": random_origin, "lastUpdated": str(last_update)},
    )
    _assert_filtered_response(resp, expected)


def compare_flight_arrays(retrieved, pivotal):
    """Assert API flights match expected flight objects by ``flight_code``.

    Args:
        retrieved: Decoded JSON list of flight dicts returned by the API.
        pivotal: Expected flight objects, compared position by position.

    The comparison is order-sensitive. It relies on the API returning flights
    in the same order as *pivotal* (TODO confirm this ordering is guaranteed).
    """
    assert len(retrieved) == len(pivotal)
    for retrieved_flight, expected in zip(retrieved, pivotal):
        assert retrieved_flight["flight_code"] == expected.flight_code