diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 00000000..a7088e1a Binary files /dev/null and b/.DS_Store differ diff --git a/.gitignore b/.gitignore index 3126ddb2..4bc75f1c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,32 @@ *.pyc myenv/ +venv/ +logs/ +.github +ppo_plane_tensorboard/ + +# Virtual environments +venv/ +venv1/ + +# Python cache and bytecode +__pycache__/ +*.pyc +*.pyo +*.pyd + +# Distribution / packaging +build/ +dist/ +*.egg-info/ + +# VSCode / PyCharm / macOS files +.vscode/ +.idea/ +.DS_Store + +# Jupyter Notebook checkpoints +.ipynb_checkpoints/ + +# Logs +*.log diff --git a/README.md b/README.md index dc547ff1..cd444177 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ [![Demo](https://i.imgur.com/xpUow2f.png)](https://youtu.be/W5fCgqlECeI) -# python_mini_metro -This repo uses `pygame` to implement Mini Metro, a fun 2D strategic game where you try to optimize the max number of passengers your metro system can handle. Both human and program inputs are supported. One of the purposes of this implementation is to enable reinforcement learning agents to be trained on it. +# python_mini_plane +This repo uses `pygame` to implement Mini plane, a fun 2D strategic game where you try to optimize the max number of passengers your plane system can handle. Both human and program inputs are supported. One of the purposes of this implementation is to enable reinforcement learning agents to be trained on it. # Installation `pip install -r requirements.txt` diff --git a/models/.DS_Store b/models/.DS_Store new file mode 100644 index 00000000..64698292 Binary files /dev/null and b/models/.DS_Store differ diff --git a/models/PPO/.DS_Store b/models/PPO/.DS_Store new file mode 100644 index 00000000..bb3edbd7 Binary files /dev/null and b/models/PPO/.DS_Store differ diff --git a/models/PPO/1763002825/.DS_Store b/models/PPO/1763002825/.DS_Store new file mode 100644 index 00000000..5008ddfc Binary files /dev/null and b/models/PPO/1763002825/.DS_Store differ diff --git a/models/final_model.zip b/models/final_model.zip new file mode 100644 index 00000000..3592dec0 Binary files /dev/null and b/models/final_model.zip differ diff --git a/models/vec_normalize.pkl b/models/vec_normalize.pkl new file mode 100644 index 00000000..c90bcf79 Binary files /dev/null and b/models/vec_normalize.pkl differ diff --git a/requirements.txt b/requirements.txt index 1532a747..a1484313 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,47 @@ -numpy==1.24.2 -pygame==2.3.0 -shapely==2.0.1 -shortuuid==1.0.11 +absl-py==2.3.1 +ale-py==0.11.2 +cloudpickle==3.1.2 +contourpy==1.3.3 +cycler==0.12.1 +Farama-Notifications==0.0.4 +filelock==3.20.0 +fonttools==4.60.1 +fsspec==2025.10.0 +grpcio==1.76.0 +gymnasium==1.2.2 +Jinja2==3.1.6 +kiwisolver==1.4.9 +Markdown==3.10 +markdown-it-py==4.0.0 +MarkupSafe==3.0.3 +matplotlib==3.10.7 +mdurl==0.1.2 +mpmath==1.3.0 +networkx==3.5 +numpy==2.2.6 +opencv-python==4.12.0.88 +packaging==25.0 +pandas==2.3.3 +pillow==12.0.0 +protobuf==6.33.0 +psutil==7.1.3 +pygame==2.6.1 +Pygments==2.19.2 +pyparsing==3.2.5 +python-dateutil==2.9.0.post0 +pytz==2025.2 +rich==14.2.0 +sb3_contrib==2.7.0 +setuptools==80.9.0 +shapely==2.1.2 +shortuuid==1.0.13 +six==1.17.0 +stable_baselines3==2.7.0 +sympy==1.14.0 +tensorboard==2.20.0 +tensorboard-data-server==0.7.2 +torch==2.9.1 +tqdm==4.67.1 +typing_extensions==4.15.0 +tzdata==2025.2 +Werkzeug==3.1.3 \ No newline at end of file diff --git a/src/.DS_Store b/src/.DS_Store new file mode 100644 index 00000000..2adc7bf7 Binary files /dev/null and b/src/.DS_Store differ diff --git a/src/config.py b/src/config.py index 1f40a56e..62536860 100644 --- a/src/config.py +++ b/src/config.py @@ -7,19 +7,21 @@ screen_width = 1920 screen_height = 1080 screen_color = (255, 255, 255) - -# station -num_stations = 10 -station_size = 30 -station_capacity = 12 -station_color = (0, 0, 0) -station_shape_type_list = [ +border_padding = 100 + +# airport +num_airports = 10 +airport_size = 30 +airport_capacity = 12 +airport_color = (0, 0, 0) +airport_shape_type_list = [ ShapeType.RECT, ShapeType.CIRCLE, ShapeType.TRIANGLE, ShapeType.CROSS, ] -station_passengers_per_row = 4 +airport_passengers_per_row = 4 +airport_spawn_interval = 1250 # passenger passenger_size = 5 @@ -28,13 +30,13 @@ passenger_spawning_interval_step = 10 * framerate passenger_display_buffer = 3 * passenger_size -# metro -num_metros = 4 -metro_size = 30 -metro_color = (200, 200, 200) -metro_capacity = 6 -metro_speed_per_ms = 150 / 1000 # pixels / ms -metro_passengers_per_row = 3 +# plane +num_planes = 4 +plane_size = 30 +plane_color = (200, 200, 200) +plane_capacity = 6 +plane_speed_per_ms = 150 / 1000 # pixels / ms +plane_passengers_per_row = 3 # path num_paths = 3 @@ -55,3 +57,6 @@ # text score_font_size = 50 score_display_coords = (20, 20) + +airport_max_passengers = 6 +overcrowd_time_limit_ms = 20_000 # 20 seconds diff --git a/src/entity/station.py b/src/entity/airport.py similarity index 52% rename from src/entity/station.py rename to src/entity/airport.py index 1a42f26c..692875e7 100644 --- a/src/entity/station.py +++ b/src/entity/airport.py @@ -3,24 +3,26 @@ import pygame from shortuuid import uuid # type: ignore -from config import station_capacity, station_passengers_per_row, station_size +from config import airport_capacity, airport_passengers_per_row, airport_size from entity.holder import Holder from geometry.point import Point from geometry.shape import Shape -class Station(Holder): +class Airport(Holder): def __init__(self, shape: Shape, position: Point) -> None: super().__init__( shape=shape, - capacity=station_capacity, - id=f"Station-{uuid()}-{shape.type}", + capacity=airport_capacity, + id=f"Airport-{uuid()}-{shape.type}", ) - self.size = station_size + self.size = airport_size self.position = position - self.passengers_per_row = station_passengers_per_row + self.passengers_per_row = airport_passengers_per_row + self.overcrowd_start_time = 0 + self.is_overcrowded = False - def __eq__(self, other: Station) -> bool: + def __eq__(self, other: Airport) -> bool: return self.id == other.id def __hash__(self): diff --git a/src/entity/get_entity.py b/src/entity/get_entity.py index fe35ed80..f62599e8 100644 --- a/src/entity/get_entity.py +++ b/src/entity/get_entity.py @@ -1,26 +1,57 @@ +import random from typing import List -from config import screen_height, screen_width -from entity.metro import Metro -from entity.station import Station -from utils import get_random_position, get_random_station_shape +from config import screen_height, screen_width, airport_size +from entity.plane import Plane +from entity.airport import Airport +from geometry.point import Point +from geometry.type import ShapeType +from utils import get_random_position, get_random_airport_shape, get_shape_from_type -def get_random_station() -> Station: - shape = get_random_station_shape() +def get_new_random_airport() -> Airport: + all_shape_types = list(ShapeType) + + weights_map = { + ShapeType.CIRCLE: 0.40, + ShapeType.TRIANGLE: 0.30, + ShapeType.RECT: 0.20, + ShapeType.CROSS: 0.10 + } + + ordered_weights = [weights_map[shape_type] for shape_type in all_shape_types] + + chosen_shape_type = random.choices(all_shape_types, ordered_weights, k=1)[0] + position = get_random_position(screen_width, screen_height) - return Station(shape, position) + shape = get_shape_from_type(chosen_shape_type, (0, 0, 0), airport_size) + return Airport(shape, position) +def get_initial_airports() -> List[Airport]: + airports: List[Airport] = [] + + initial_shapes = [ShapeType.TRIANGLE, ShapeType.CIRCLE, ShapeType.RECT] + + positions = [ + Point(screen_width * 0.25, screen_height * 0.3), + Point(screen_width * 0.75, screen_height * 0.3), + Point(screen_width * 0.5, screen_height * 0.7), + ] -def get_random_stations(num: int) -> List[Station]: - stations: List[Station] = [] - for _ in range(num): - stations.append(get_random_station()) - return stations + for i, shape_type in enumerate(initial_shapes): + shape = get_shape_from_type(shape_type, (0, 0, 0), airport_size) + airport = Airport(shape, positions[i]) + airports.append(airport) + + return airports +def get_random_airport() -> Airport: + shape = get_random_airport_shape() + position = get_random_position(screen_width, screen_height) + return Airport(shape, position) -def get_metros(num: int) -> List[Metro]: - metros: List[Metro] = [] +def get_planes(num: int) -> List[Plane]: + planes: List[Plane] = [] for _ in range(num): - metros.append(Metro()) - return metros + planes.append(Plane()) + return planes diff --git a/src/entity/metro.py b/src/entity/metro.py deleted file mode 100644 index e51f28a7..00000000 --- a/src/entity/metro.py +++ /dev/null @@ -1,32 +0,0 @@ -import pygame -from shortuuid import uuid # type: ignore - -from config import ( - metro_capacity, - metro_color, - metro_passengers_per_row, - metro_size, - metro_speed_per_ms, -) -from entity.holder import Holder -from entity.segment import Segment -from entity.station import Station -from geometry.rect import Rect - - -class Metro(Holder): - def __init__(self) -> None: - self.size = metro_size - metro_shape = Rect(color=metro_color, width=2 * self.size, height=self.size) - super().__init__( - shape=metro_shape, - capacity=metro_capacity, - id=f"Metro-{uuid()}", - ) - self.current_station: Station | None = None - self.current_segment: Segment | None = None - self.current_segment_idx = 0 - self.path_id = "" - self.speed = metro_speed_per_ms - self.is_forward = True - self.passengers_per_row = metro_passengers_per_row diff --git a/src/entity/path.py b/src/entity/path.py index c36d4846..e9c55f6f 100644 --- a/src/entity/path.py +++ b/src/entity/path.py @@ -5,11 +5,11 @@ from shortuuid import uuid # type: ignore from config import path_width -from entity.metro import Metro +from entity.plane import Plane from entity.padding_segment import PaddingSegment from entity.path_segment import PathSegment from entity.segment import Segment -from entity.station import Station +from entity.airport import Airport from geometry.line import Line from geometry.point import Point from geometry.utils import direction, distance @@ -20,8 +20,8 @@ class Path: def __init__(self, color: Color) -> None: self.id = f"Path-{uuid()}" self.color = color - self.stations: List[Station] = [] - self.metros: List[Metro] = [] + self.airports: List[Airport] = [] + self.planes: List[Plane] = [] self.is_looped = False self.is_being_created = False self.temp_point: Point | None = None @@ -33,8 +33,8 @@ def __init__(self, color: Color) -> None: def __repr__(self) -> str: return self.id - def add_station(self, station: Station) -> None: - self.stations.append(station) + def add_airport(self, airport: Airport) -> None: + self.airports.append(airport) self.update_segments() def update_segments(self) -> None: @@ -42,17 +42,17 @@ def update_segments(self) -> None: self.path_segments = [] self.padding_segments = [] - for i in range(len(self.stations) - 1): + for i in range(len(self.airports) - 1): self.path_segments.append( PathSegment( - self.color, self.stations[i], self.stations[i + 1], self.path_order + self.color, self.airports[i], self.airports[i + 1], self.path_order ) ) if self.is_looped: self.path_segments.append( PathSegment( - self.color, self.stations[-1], self.stations[0], self.path_order + self.color, self.airports[-1], self.airports[0], self.path_order ) ) @@ -78,6 +78,49 @@ def update_segments(self) -> None: self.padding_segments.append(padding_segment) self.segments.append(padding_segment) + def insert_airport_on_segment( + self, + airport_to_insert: Airport, + existing_airport_1: Airport, + existing_airport_2: Airport, + ) -> bool: + """ + Finds a segment between two existing airports and inserts a new airport. + For example, turns a path ...-A-B-... into ...-A-C-B-... + + Args: + airport_to_insert: The new airport object to add to the path. + existing_airport_1: The first airport of the existing segment. + existing_airport_2: The second airport of the existing segment. + + Returns: + True if the segment was found and the airport was inserted, False otherwise. + """ + for i in range(len(self.airports) - 1): + airport_a = self.airports[i] + airport_b = self.airports[i + 1] + + if (airport_a == existing_airport_1 and airport_b == existing_airport_2) or \ + (airport_a == existing_airport_2 and airport_b == existing_airport_1): + + insert_index = i + 1 + self.airports.insert(insert_index, airport_to_insert) + self.update_segments() + return True + + if self.is_looped and len(self.airports) > 1: + airport_a = self.airports[-1] + airport_b = self.airports[0] + + if (airport_a == existing_airport_1 and airport_b == existing_airport_2) or \ + (airport_a == existing_airport_2 and airport_b == existing_airport_1): + + self.airports.append(airport_to_insert) + self.update_segments() + return True + + return False + def draw(self, surface: pygame.surface.Surface, path_order: int) -> None: self.path_order = path_order self.update_segments() @@ -88,7 +131,7 @@ def draw(self, surface: pygame.surface.Surface, path_order: int) -> None: if self.temp_point: temp_line = Line( color=self.color, - start=self.stations[-1].position, + start=self.airports[-1].position, end=self.temp_point, width=path_width, ) @@ -108,59 +151,59 @@ def remove_loop(self) -> None: self.is_looped = False self.update_segments() - def add_metro(self, metro: Metro) -> None: - metro.shape.color = self.color - metro.current_segment = self.segments[metro.current_segment_idx] - metro.position = metro.current_segment.segment_start - metro.path_id = self.id - self.metros.append(metro) - - def move_metro(self, metro: Metro, dt_ms: int) -> None: - assert metro.current_segment is not None - if metro.is_forward: - dst_station = metro.current_segment.end_station - dst_position = metro.current_segment.segment_end + def add_plane(self, plane: Plane) -> None: + plane.shape.color = self.color + plane.current_segment = self.segments[plane.current_segment_idx] + plane.position = plane.current_segment.segment_start + plane.path_id = self.id + self.planes.append(plane) + + def move_plane(self, plane: Plane, dt_ms: int) -> None: + assert plane.current_segment is not None + if plane.is_forward: + dst_airport = plane.current_segment.end_airport + dst_position = plane.current_segment.segment_end else: - dst_station = metro.current_segment.start_station - dst_position = metro.current_segment.segment_start + dst_airport = plane.current_segment.start_airport + dst_position = plane.current_segment.segment_start - start_point = metro.position + start_point = plane.position end_point = dst_position dist = distance(start_point, end_point) direct = direction(start_point, end_point) radians = math.atan2(direct.top, direct.left) degrees = math.degrees(radians) - metro.shape.set_degrees(degrees) - travel_dist_in_dt = metro.speed * dt_ms - # metro is not at one end of segment + plane.shape.set_degrees(degrees) + travel_dist_in_dt = plane.speed * dt_ms + # plane is not at one end of segment if dist > travel_dist_in_dt: - metro.current_station = None - metro.position += direct * travel_dist_in_dt - # metro is at one end of segment + plane.current_airport = None + plane.position += direct * travel_dist_in_dt + # plane is at one end of segment else: - metro.current_station = dst_station + plane.current_airport = dst_airport if len(self.segments) == 1: - metro.is_forward = not metro.is_forward - elif metro.current_segment_idx == len(self.segments) - 1: + plane.is_forward = not plane.is_forward + elif plane.current_segment_idx == len(self.segments) - 1: if self.is_looped: - metro.current_segment_idx = 0 + plane.current_segment_idx = 0 else: - if metro.is_forward: - metro.is_forward = False + if plane.is_forward: + plane.is_forward = False else: - metro.current_segment_idx -= 1 - elif metro.current_segment_idx == 0: - if metro.is_forward: - metro.current_segment_idx += 1 + plane.current_segment_idx -= 1 + elif plane.current_segment_idx == 0: + if plane.is_forward: + plane.current_segment_idx += 1 else: if self.is_looped: - metro.current_segment_idx = len(self.segments) - 1 + plane.current_segment_idx = len(self.segments) - 1 else: - metro.is_forward = True + plane.is_forward = True else: - if metro.is_forward: - metro.current_segment_idx += 1 + if plane.is_forward: + plane.current_segment_idx += 1 else: - metro.current_segment_idx -= 1 + plane.current_segment_idx -= 1 - metro.current_segment = self.segments[metro.current_segment_idx] + plane.current_segment = self.segments[plane.current_segment_idx] diff --git a/src/entity/path_segment.py b/src/entity/path_segment.py index 33988262..277f11ff 100644 --- a/src/entity/path_segment.py +++ b/src/entity/path_segment.py @@ -7,7 +7,7 @@ from config import path_order_shift, path_width from entity.segment import Segment -from entity.station import Station +from entity.airport import Airport from geometry.line import Line from geometry.point import Point from geometry.utils import direction, distance @@ -18,24 +18,24 @@ class PathSegment(Segment): def __init__( self, color: Color, - start_station: Station, - end_station: Station, + start_airport: airport, + end_airport: airport, path_order: int, ) -> None: super().__init__(color) self.id = f"PathSegment-{uuid()}" - self.start_station = start_station - self.end_station = end_station + self.start_airport = start_airport + self.end_airport = end_airport self.path_order = path_order - start_point = start_station.position - end_point = end_station.position + start_point = start_airport.position + end_point = end_airport.position direct = direction(start_point, end_point) buffer_vector = direct * path_order_shift buffer_vector = buffer_vector.rotate(90) - self.segment_start = start_station.position + buffer_vector * self.path_order - self.segment_end = end_station.position + buffer_vector * self.path_order + self.segment_start = start_airport.position + buffer_vector * self.path_order + self.segment_end = end_airport.position + buffer_vector * self.path_order self.line = Line( color=self.color, start=self.segment_start, diff --git a/src/entity/plane.py b/src/entity/plane.py new file mode 100644 index 00000000..375eb2a8 --- /dev/null +++ b/src/entity/plane.py @@ -0,0 +1,32 @@ +import pygame +from shortuuid import uuid # type: ignore + +from config import ( + plane_capacity, + plane_color, + plane_passengers_per_row, + plane_size, + plane_speed_per_ms, +) +from entity.holder import Holder +from entity.segment import Segment +from entity.airport import Airport +from geometry.rect import Rect + + +class Plane(Holder): + def __init__(self) -> None: + self.size = plane_size + plane_shape = Rect(color=plane_color, width=2 * self.size, height=self.size) + super().__init__( + shape=plane_shape, + capacity=plane_capacity, + id=f"Plane-{uuid()}", + ) + self.current_airport: airport | None = None + self.current_segment: Segment | None = None + self.current_segment_idx = 0 + self.path_id = "" + self.speed = plane_speed_per_ms + self.is_forward = True + self.passengers_per_row = plane_passengers_per_row diff --git a/src/entity/segment.py b/src/entity/segment.py index 8157ec8a..5a66de77 100644 --- a/src/entity/segment.py +++ b/src/entity/segment.py @@ -6,7 +6,7 @@ from shortuuid import uuid # type: ignore from config import screen_height, screen_width -from entity.station import Station +from entity.airport import Airport from geometry.line import Line from geometry.point import Point from type import Color @@ -16,8 +16,8 @@ class Segment(ABC): def __init__(self, color: Color) -> None: self.id = f"Segment-{uuid()}" self.color = color - self.start_station: Station | None = None - self.end_station: Station | None = None + self.start_airport: airport | None = None + self.end_airport: airport | None = None self.segment_start: Point self.segment_end: Point self.line: Line diff --git a/src/geometry/utils.py b/src/geometry/utils.py index 7c368446..eeb8d114 100644 --- a/src/geometry/utils.py +++ b/src/geometry/utils.py @@ -10,4 +10,7 @@ def distance(p1: Point, p2: Point) -> float: def direction(p1: Point, p2: Point) -> Point: diff = p2 - p1 diff_magnitude = distance(p1, p2) - return Point(diff.left / diff_magnitude, diff.top / diff_magnitude) + if diff_magnitude == 0: + return Point(0, 0) + else: + return Point(diff.left / diff_magnitude, diff.top / diff_magnitude) \ No newline at end of file diff --git a/src/graph/graph_algo.py b/src/graph/graph_algo.py index b68346f3..e26fb1c7 100644 --- a/src/graph/graph_algo.py +++ b/src/graph/graph_algo.py @@ -1,30 +1,30 @@ from typing import Dict, List from entity.path import Path -from entity.station import Station +from entity.airport import Airport from graph.node import Node -def build_station_nodes_dict(stations: List[Station], paths: List[Path]): - station_nodes: List[Node] = [] +def build_airport_nodes_dict(airports: List[Airport], paths: List[Path]): + airport_nodes: List[Node] = [] connections: List[List[Node]] = [] - station_nodes_dict: Dict[Station, Node] = {} + airport_nodes_dict: Dict[Airport, Node] = {} - for station in stations: - node = Node(station) - station_nodes.append(node) - station_nodes_dict[station] = node + for airport in airports: + node = Node(airport) + airport_nodes.append(node) + airport_nodes_dict[airport] = node for path in paths: if path.is_being_created: continue connection = [] - for station in path.stations: - station_nodes_dict[station].paths.add(path) - connection.append(station_nodes_dict[station]) + for airport in path.airports: + airport_nodes_dict[airport].paths.add(path) + connection.append(airport_nodes_dict[airport]) connections.append(connection) - while len(station_nodes) > 0: - root = station_nodes[0] + while len(airport_nodes) > 0: + root = airport_nodes[0] for connection in connections: for idx in range(len(connection)): node = connection[idx] @@ -33,10 +33,10 @@ def build_station_nodes_dict(stations: List[Station], paths: List[Path]): root.neighbors.add(connection[idx - 1]) if idx + 1 <= len(connection) - 1: root.neighbors.add(connection[idx + 1]) - station_nodes.remove(root) - station_nodes_dict[root.station] = root + airport_nodes.remove(root) + airport_nodes_dict[root.airport] = root - return station_nodes_dict + return airport_nodes_dict def bfs(start: Node, end: Node) -> List[Node]: diff --git a/src/graph/node.py b/src/graph/node.py index 5f8d9900..1ca1da88 100644 --- a/src/graph/node.py +++ b/src/graph/node.py @@ -5,21 +5,21 @@ from shortuuid import uuid # type: ignore from entity.path import Path -from entity.station import Station +from entity.airport import Airport class Node: - def __init__(self, station: Station) -> None: + def __init__(self, airport: Airport) -> None: self.id = f"Node-{uuid()}" - self.station = station + self.airport = airport self.neighbors: Set[Node] = set() self.paths: Set[Path] = set() def __eq__(self, other: Node) -> bool: - return self.station == other.station + return self.airport == other.airport def __hash__(self) -> int: return hash(self.id) def __repr__(self) -> str: - return f"Node-{self.station.__repr__()}" + return f"Node-{self.airport.__repr__()}" diff --git a/src/mediator.py b/src/mediator.py index 17baf51b..d40a5140 100644 --- a/src/mediator.py +++ b/src/mediator.py @@ -1,34 +1,40 @@ from __future__ import annotations +import math import pprint import random from typing import Dict, List import pygame +#diff from main.py -- can define the game as 1 class, so can train mutliple games at once +# game is done as a timestep --> so can run more games at once and train faster from config import ( - num_metros, + num_planes, num_paths, - num_stations, + num_airports, passenger_color, passenger_size, passenger_spawning_interval_step, passenger_spawning_start_step, score_display_coords, score_font_size, + airport_max_passengers, + overcrowd_time_limit_ms, + airport_spawn_interval ) -from entity.get_entity import get_random_stations -from entity.metro import Metro +from entity.get_entity import get_initial_airports, get_new_random_airport +from entity.plane import Plane from entity.passenger import Passenger from entity.path import Path -from entity.station import Station +from entity.airport import Airport from event.event import Event from event.keyboard import KeyboardEvent from event.mouse import MouseEvent from event.type import KeyboardEventType, MouseEventType from geometry.point import Point from geometry.type import ShapeType -from graph.graph_algo import bfs, build_station_nodes_dict +from graph.graph_algo import bfs, build_airport_nodes_dict from graph.node import Node from travel_plan import TravelPlan from type import Color @@ -48,18 +54,19 @@ def __init__(self) -> None: self.passenger_spawning_step = passenger_spawning_start_step self.passenger_spawning_interval_step = passenger_spawning_interval_step self.num_paths = num_paths - self.num_metros = num_metros - self.num_stations = num_stations + self.num_planes = num_planes + self.num_airports = num_airports # UI self.path_buttons = get_path_buttons(self.num_paths) self.path_to_button: Dict[Path, PathButton] = {} self.buttons = [*self.path_buttons] self.font = pygame.font.SysFont("arial", score_font_size) + self.game_over_font = pygame.font.SysFont("arial", 72) # entities - self.stations = get_random_stations(self.num_stations) - self.metros: List[Metro] = [] + self.airports = get_initial_airports() + self.planes: List[Plane] = [] self.paths: List[Path] = [] self.passengers: List[Passenger] = [] self.path_colors: Dict[Color, bool] = {} @@ -78,6 +85,11 @@ def __init__(self) -> None: self.travel_plans: TravelPlans = {} self.is_paused = False self.score = 0 + self.is_game_over = False + self.steps_since_last_airport_spawn = 0 + self.is_extending_path = False + self.original_airports_before_extend: List[Airport] = [] + self.is_old_path_looped = False def assign_paths_to_buttons(self): for path_button in self.path_buttons: @@ -90,34 +102,168 @@ def assign_paths_to_buttons(self): button.assign_path(path) self.path_to_button[path] = button + def spawn_new_airport(self): + new_airport = get_new_random_airport() + self.airports.append(new_airport) + def render(self, screen: pygame.surface.Surface) -> None: for idx, path in enumerate(self.paths): path_order = idx - round(self.num_paths / 2) path.draw(screen, path_order) - for station in self.stations: - station.draw(screen) - for metro in self.metros: - metro.draw(screen) + for airport in self.airports: + airport.draw(screen) + if airport.is_overcrowded: + duration = self.time_ms - airport.overcrowd_start_time_ms + progress_pct = min(duration / overcrowd_time_limit_ms, 1.0) + radius = airport.size + 5 + center_point = airport.position + rect = pygame.Rect( + center_point.left - radius, + center_point.top - radius, + radius * 2, + radius * 2, + ) + start_angle = -math.pi / 2 + end_angle = start_angle + (2 * math.pi * progress_pct) + pygame.draw.arc(screen, (255, 0, 0), rect, start_angle, end_angle, 3) + for plane in self.planes: + plane.draw(screen) for button in self.buttons: button.draw(screen) text_surface = self.font.render(f"Score: {self.score}", True, (0, 0, 0)) screen.blit(text_surface, score_display_coords) + if self.is_game_over: + overlay = pygame.Surface(screen.get_size(), pygame.SRCALPHA) + overlay.fill((0, 0, 0, 150)) + screen.blit(overlay, (0, 0)) + + text_surface = self.game_over_font.render( + "GAME OVER", True, (255, 0, 0) + ) + text_rect = text_surface.get_rect( + center=(screen.get_width() / 2, screen.get_height() / 2) + ) + screen.blit(text_surface, text_rect) + + def create_or_extend_path(self, airport_a: Airport, airport_b: Airport) -> bool: + """Atomically creates a new path or extends an existing one.""" + for path in self.paths: + if not path.airports or path.is_looped: + continue + + extend_from_end = path.airports[-1] == airport_a + extend_from_start = path.airports[0] == airport_a + + if extend_from_end or extend_from_start: + if airport_b in path.airports: + if (extend_from_end and path.airports[0] == airport_b) or \ + (extend_from_start and path.airports[-1] == airport_b): + if len(path.airports) > 2: + path.set_loop() + return True + else: + return False + return False + + if extend_from_start: + path.airports.reverse() + + path.add_airport(airport_b) + return True + + if len(self.paths) < self.num_paths: + assigned_color = (0, 0, 0) + for path_color, taken in self.path_colors.items(): + if not taken: + assigned_color = path_color + self.path_colors[path_color] = True + break + new_path = Path(assigned_color) + self._assign_color_to_path(new_path, assigned_color) + new_path.add_airport(airport_a) + new_path.add_airport(airport_b) + self._add_plane_to_path(new_path) + self.paths.append(new_path) + return True + + return False + + def _assign_color_to_path(self, path: Path, color: Color): + """Assigns a color to a path and marks it as taken.""" + path.color = color + self.path_colors[color] = True + self.path_to_color[path] = color + + def _add_plane_to_path(self, path: Path): + """Adds a new plane to a path if the limit has not been reached.""" + if len(self.planes) < self.num_planes: + plane = Plane() + path.add_plane(plane) + self.planes.append(plane) + + def insert_airport_on_path(self, s_insert: Airport, s1: Airport, s2: Airport) -> bool: + if s_insert == s1 or s_insert == s2 or s1 == s2: + return False + + for path in self.paths: + if path.insert_airport_on_segment(s_insert, s1, s2): + return True + return False + def react_mouse_event(self, event: MouseEvent): entity = self.get_containing_entity(event.position) if event.event_type == MouseEventType.MOUSE_DOWN: self.is_mouse_down = True if entity: - if isinstance(entity, Station): - self.start_path_on_station(entity) + if isinstance(entity, Airport): + if self.is_creating_path: + return + + path_to_extend = None + is_extending_from_start = False + + for path in self.paths: + if not path.airports: + continue + + if path.airports[0] == entity: + path_to_extend = path + is_extending_from_start = True + break + elif path.airports[-1] == entity: + if path.is_looped: + continue + path_to_extend = path + is_extending_from_start = False + break + + if path_to_extend: + self.is_creating_path = True + self.is_extending_path = True + self.path_being_created = path_to_extend + + self.original_airports_before_extend = list(path_to_extend.airports) + self.is_old_path_looped = path_to_extend.is_looped + + if is_extending_from_start: + path_to_extend.airports.reverse() + + if path_to_extend.is_looped: + path_to_extend.remove_loop() + + path_to_extend.is_being_created = True + path_to_extend.remove_temporary_point() + else: + self.start_path_on_airport(entity) elif event.event_type == MouseEventType.MOUSE_UP: self.is_mouse_down = False if self.is_creating_path: assert self.path_being_created is not None - if entity and isinstance(entity, Station): - self.end_path_on_station(entity) + if entity and isinstance(entity, Airport): + self.end_path_on_airport(entity) else: self.abort_path_creation() else: @@ -128,8 +274,8 @@ def react_mouse_event(self, event: MouseEvent): elif event.event_type == MouseEventType.MOUSE_MOTION: if self.is_mouse_down: if self.is_creating_path and self.path_being_created: - if entity and isinstance(entity, Station): - self.add_station_to_path(entity) + if entity and isinstance(entity, Airport): + self.add_airport_to_path(entity) else: self.path_being_created.set_temporary_point(event.position) else: @@ -151,25 +297,25 @@ def react(self, event: Event | None): self.react_keyboard_event(event) def get_containing_entity(self, position: Point): - for station in self.stations: - if station.contains(position): - return station + for airport in self.airports: + if airport.contains(position): + return airport for button in self.buttons: if button.contains(position): return button def remove_path(self, path: Path): self.path_to_button[path].remove_path() - for metro in path.metros: - for passenger in metro.passengers: + for plane in path.planes: + for passenger in plane.passengers: self.passengers.remove(passenger) - self.metros.remove(metro) + self.planes.remove(plane) self.release_color_for_path(path) self.paths.remove(path) self.assign_paths_to_buttons() self.find_travel_plan_for_passengers() - def start_path_on_station(self, station: Station) -> None: + def start_path_on_airport(self, airport: Airport) -> None: if len(self.paths) < self.num_paths: self.is_creating_path = True assigned_color = (0, 0, 0) @@ -180,78 +326,100 @@ def start_path_on_station(self, station: Station) -> None: break path = Path(assigned_color) self.path_to_color[path] = assigned_color - path.add_station(station) + path.add_airport(airport) path.is_being_created = True self.path_being_created = path self.paths.append(path) - def add_station_to_path(self, station: Station) -> None: + def add_airport_to_path(self, airport: Airport) -> None: assert self.path_being_created is not None - if self.path_being_created.stations[-1] == station: + if self.path_being_created.airports[-1] == airport: return - # loop + if ( - len(self.path_being_created.stations) > 1 - and self.path_being_created.stations[0] == station + len(self.path_being_created.airports) > 1 + and self.path_being_created.airports[0] == airport ): self.path_being_created.set_loop() - # non-loop - elif self.path_being_created.stations[0] != station: - if self.path_being_created.is_looped: - self.path_being_created.remove_loop() - self.path_being_created.add_station(station) + return + + if airport in self.path_being_created.airports: + return + + # Any loop should be removed + if self.path_being_created.is_looped: + self.path_being_created.remove_loop() + self.path_being_created.add_airport(airport) def abort_path_creation(self) -> None: assert self.path_being_created is not None self.is_creating_path = False - self.release_color_for_path(self.path_being_created) - self.paths.remove(self.path_being_created) + if self.is_extending_path: + self.path_being_created.airports = self.original_airports_before_extend + if self.is_old_path_looped: + self.path_being_created.set_loop() + else: + self.path_being_created.remove_loop() + + self.path_being_created.is_being_created = False + self.path_being_created.remove_temporary_point() + else: + self.release_color_for_path(self.path_being_created) + self.paths.remove(self.path_being_created) + self.is_extending_path = False + self.original_airports_before_extend = [] + self.is_old_path_looped = False self.path_being_created = None def release_color_for_path(self, path: Path) -> None: - self.path_colors[path.color] = False - del self.path_to_color[path] + if path in self.path_to_color: + color = self.path_to_color[path] + self.path_colors[color] = False + del self.path_to_color[path] def finish_path_creation(self) -> None: assert self.path_being_created is not None + was_new_path = not self.is_extending_path self.is_creating_path = False + self.is_extending_path = False + self.original_airports_before_extend = [] + self.is_old_path_looped = False self.path_being_created.is_being_created = False self.path_being_created.remove_temporary_point() - if len(self.metros) < self.num_metros: - metro = Metro() - self.path_being_created.add_metro(metro) - self.metros.append(metro) + if was_new_path and len(self.planes) < self.num_planes: + plane = Plane() + self.path_being_created.add_plane(plane) + self.planes.append(plane) self.path_being_created = None self.assign_paths_to_buttons() - def end_path_on_station(self, station: Station) -> None: + def end_path_on_airport(self, airport: Airport) -> None: assert self.path_being_created is not None - # current station de-dupe + if self.path_being_created.airports[-1] == airport: + if len(self.path_being_created.airports) > 1: + self.finish_path_creation() + else: + self.abort_path_creation() + return if ( - len(self.path_being_created.stations) > 1 - and self.path_being_created.stations[-1] == station - ): - self.finish_path_creation() - # loop - elif ( - len(self.path_being_created.stations) > 1 - and self.path_being_created.stations[0] == station + len(self.path_being_created.airports) > 1 + and self.path_being_created.airports[0] == airport ): self.path_being_created.set_loop() self.finish_path_creation() - # non-loop - elif self.path_being_created.stations[0] != station: - self.path_being_created.add_station(station) - self.finish_path_creation() - else: + return + if airport in self.path_being_created.airports: self.abort_path_creation() + return + self.path_being_created.add_airport(airport) + self.finish_path_creation() - def get_station_shape_types(self): - station_shape_types: List[ShapeType] = [] - for station in self.stations: - if station.shape.type not in station_shape_types: - station_shape_types.append(station.shape.type) - return station_shape_types + def get_airport_shape_types(self): + airport_shape_types: List[ShapeType] = [] + for airport in self.airports: + if airport.shape.type not in airport_shape_types: + airport_shape_types.append(airport.shape.type) + return airport_shape_types def is_passenger_spawn_time(self) -> bool: return ( @@ -260,101 +428,122 @@ def is_passenger_spawn_time(self) -> bool: ) def spawn_passengers(self): - for station in self.stations: - station_types = self.get_station_shape_types() - other_station_shape_types = [ - x for x in station_types if x != station.shape.type + for airport in self.airports: + airport_types = self.get_airport_shape_types() + other_airport_shape_types = [ + x for x in airport_types if x != airport.shape.type ] - destination_shape_type = random.choice(other_station_shape_types) + destination_shape_type = random.choice(other_airport_shape_types) destination_shape = get_shape_from_type( destination_shape_type, passenger_color, passenger_size ) passenger = Passenger(destination_shape) - if station.has_room(): - station.add_passenger(passenger) + if airport.has_room(): + airport.add_passenger(passenger) self.passengers.append(passenger) def increment_time(self, dt_ms: int) -> None: - if self.is_paused: + if self.is_paused or self.is_game_over: return # record time self.time_ms += dt_ms self.steps += 1 self.steps_since_last_spawn += 1 + self.steps_since_last_airport_spawn += 1 - # move metros + for airport in self.airports: + if len(airport.passengers) > airport_max_passengers: + if not airport.is_overcrowded: + airport.is_overcrowded = True + airport.overcrowd_start_time_ms = self.time_ms + else: + duration = self.time_ms - airport.overcrowd_start_time_ms + if duration >= overcrowd_time_limit_ms: + self.is_game_over = True + break + else: + if airport.is_overcrowded: + airport.is_overcrowded = False + airport.overcrowd_start_time_ms = 0 + + # move planes for path in self.paths: - for metro in path.metros: - path.move_metro(metro, dt_ms) + for plane in path.planes: + path.move_plane(plane, dt_ms) # spawn passengers if self.is_passenger_spawn_time(): self.spawn_passengers() self.steps_since_last_spawn = 0 + # spawn airports + if self.steps_since_last_airport_spawn >= airport_spawn_interval: + self.spawn_new_airport() + self.steps_since_last_airport_spawn = 0 + self.find_travel_plan_for_passengers() self.move_passengers() def move_passengers(self) -> None: - for metro in self.metros: - if metro.current_station: + for plane in self.planes: + if plane.current_airport: passengers_to_remove = [] - passengers_from_metro_to_station = [] - passengers_from_station_to_metro = [] + passengers_from_plane_to_airport = [] + passengers_from_airport_to_plane = [] # queue - for passenger in metro.passengers: + for passenger in plane.passengers: if ( - metro.current_station.shape.type + plane.current_airport.shape.type == passenger.destination_shape.type ): passengers_to_remove.append(passenger) elif ( - self.travel_plans[passenger].get_next_station() - == metro.current_station + self.travel_plans[passenger].get_next_airport() + == plane.current_airport ): - passengers_from_metro_to_station.append(passenger) - for passenger in metro.current_station.passengers: + passengers_from_plane_to_airport.append(passenger) + for passenger in plane.current_airport.passengers: if ( self.travel_plans[passenger].next_path - and self.travel_plans[passenger].next_path.id == metro.path_id # type: ignore + and self.travel_plans[passenger].next_path.id == plane.path_id # type: ignore ): - passengers_from_station_to_metro.append(passenger) + passengers_from_airport_to_plane.append(passenger) # process for passenger in passengers_to_remove: passenger.is_at_destination = True - metro.remove_passenger(passenger) + plane.remove_passenger(passenger) self.passengers.remove(passenger) del self.travel_plans[passenger] self.score += 1 - for passenger in passengers_from_metro_to_station: - if metro.current_station.has_room(): - metro.move_passenger(passenger, metro.current_station) - self.travel_plans[passenger].increment_next_station() - self.find_next_path_for_passenger_at_station( - passenger, metro.current_station + for passenger in passengers_from_plane_to_airport: + if plane.current_airport.has_room(): + plane.move_passenger(passenger, plane.current_airport) + self.travel_plans[passenger].increment_next_airport() + self.find_next_path_for_passenger_at_airport( + passenger, plane.current_airport ) - for passenger in passengers_from_station_to_metro: - if metro.has_room(): - metro.current_station.move_passenger(passenger, metro) + for passenger in passengers_from_airport_to_plane: + if plane.has_room(): + plane.current_airport.move_passenger(passenger, plane) - def get_stations_for_shape_type(self, shape_type: ShapeType): - stations: List[Station] = [] - for station in self.stations: - if station.shape.type == shape_type: - stations.append(station) - random.shuffle(stations) + def get_airports_for_shape_type(self, shape_type: ShapeType): + airports: List[Airport] = [] + for airport in self.airports: + if airport.shape.type == shape_type: + airports.append(airport) + random.shuffle(airports) - return stations + return airports - def find_shared_path(self, station_a: Station, station_b: Station) -> Path | None: + def find_shared_path(self, airport_a: Airport, airport_b: Airport) -> Path | None: for path in self.paths: - stations = path.stations - if (station_a in stations) and (station_b in stations): + airports = path.airports + if (airport_a in airports) and (airport_b in airports): return path return None @@ -364,15 +553,15 @@ def passenger_has_travel_plan(self, passenger: Passenger) -> bool: and self.travel_plans[passenger].next_path is not None ) - def find_next_path_for_passenger_at_station( - self, passenger: Passenger, station: Station + def find_next_path_for_passenger_at_airport( + self, passenger: Passenger, airport: Airport ): - next_station = self.travel_plans[passenger].get_next_station() - assert next_station is not None - next_path = self.find_shared_path(station, next_station) + next_airport = self.travel_plans[passenger].get_next_airport() + assert next_airport is not None + next_path = self.find_shared_path(airport, next_airport) self.travel_plans[passenger].next_path = next_path - def skip_stations_on_same_path(self, node_path: List[Node]): + def skip_airports_on_same_path(self, node_path: List[Node]): assert len(node_path) >= 2 if len(node_path) == 2: return node_path @@ -397,31 +586,31 @@ def skip_stations_on_same_path(self, node_path: List[Node]): return node_path def find_travel_plan_for_passengers(self) -> None: - station_nodes_dict = build_station_nodes_dict(self.stations, self.paths) - for station in self.stations: - for passenger in station.passengers: + airport_nodes_dict = build_airport_nodes_dict(self.airports, self.paths) + for airport in self.airports: + for passenger in airport.passengers: if not self.passenger_has_travel_plan(passenger): - possible_dst_stations = self.get_stations_for_shape_type( + possible_dst_airports = self.get_airports_for_shape_type( passenger.destination_shape.type ) should_set_null_path = True - for possible_dst_station in possible_dst_stations: - start = station_nodes_dict[station] - end = station_nodes_dict[possible_dst_station] + for possible_dst_airport in possible_dst_airports: + start = airport_nodes_dict[airport] + end = airport_nodes_dict[possible_dst_airport] node_path = bfs(start, end) if len(node_path) == 1: # passenger arrived at destination - station.remove_passenger(passenger) + airport.remove_passenger(passenger) self.passengers.remove(passenger) passenger.is_at_destination = True del self.travel_plans[passenger] should_set_null_path = False break elif len(node_path) > 1: - node_path = self.skip_stations_on_same_path(node_path) + node_path = self.skip_airports_on_same_path(node_path) self.travel_plans[passenger] = TravelPlan(node_path[1:]) - self.find_next_path_for_passenger_at_station( - passenger, station + self.find_next_path_for_passenger_at_airport( + passenger, airport ) should_set_null_path = False break diff --git a/src/pilot_planning_env.py b/src/pilot_planning_env.py new file mode 100644 index 00000000..9f0f54db --- /dev/null +++ b/src/pilot_planning_env.py @@ -0,0 +1,317 @@ +import gymnasium as gym +from gymnasium import spaces +import numpy as np +import itertools +from typing import Dict, Any + +from mediator import Mediator +from config import num_paths, screen_width, screen_height, screen_color +from geometry.type import ShapeType +from geometry.utils import distance + +import pygame + +MAX_airportS = 20 +MAX_PATHS = num_paths +MAX_airportS_PER_PATH = 12 + +class PlaneGameEnv(gym.Env): + """A Gymnasium environment for the Python Mini plane game.""" + + metadata = {"render_modes": ["human"], "render_fps": 30} + + def __init__(self, render_mode: str | None = None): + super().__init__() + self.mediator = Mediator() + + self.render_mode = render_mode + self.screen = None + self.clock = None + if self.render_mode == "human": + pygame.init() + pygame.display.set_caption("plane RL Training") + self.screen = pygame.display.set_mode((screen_width, screen_height)) + self.clock = pygame.time.Clock() + + self.shape_types = sorted([e.value for e in ShapeType]) + self.shape_to_idx = {shape: i for i, shape in enumerate(self.shape_types)} + self.num_shape_types = len(self.shape_types) + + self._action_map = self._create_action_map() + self.action_space = spaces.Discrete(len(self._action_map)) + + self.observation_space = self._create_observation_space() + + def _get_action_mask(self) -> np.ndarray: + """ + Generates a boolean mask for valid actions + Basically just lets the model know immeaditly + if it can take certain actions + """ + mask = np.zeros(self.action_space.n, dtype=np.int8) + num_airports = len(self.mediator.airports) + + for action_id, action_info in self._action_map.items(): + action_type = action_info["type"] + is_valid = False + if action_type == "NO_OP": + is_valid = True + elif action_type == "CREATE_OR_EXTEND_PATH": + start_idx, end_idx = action_info["start_idx"], action_info["end_idx"] + if start_idx < num_airports and end_idx < num_airports and start_idx != end_idx: + is_valid = True + elif action_type == "INSERT_airport": + insert_idx, exist1_idx, exist2_idx = action_info["insert_idx"], action_info["exist1_idx"], action_info["exist2_idx"] + if all(i < num_airports for i in [insert_idx, exist1_idx, exist2_idx]) and len({insert_idx, exist1_idx, exist2_idx}) == 3: + s_insert = self.mediator.airports[insert_idx] + s1 = self.mediator.airports[exist1_idx] + s2 = self.mediator.airports[exist2_idx] + for p in self.mediator.paths: + if s_insert in p.airports: continue # Cannot insert a airport already on the path + for i in range(len(p.airports) - 1): + if (p.airports[i] == s1 and p.airports[i+1] == s2) or \ + (p.airports[i] == s2 and p.airports[i+1] == s1): + is_valid = True; break + if is_valid: break + if p.is_looped and len(p.airports) > 1: + if (p.airports[-1] == s1 and p.airports[0] == s2) or \ + (p.airports[-1] == s2 and p.airports[0] == s1): + is_valid = True; break + if is_valid: + mask[action_id] = 1 + return mask + + def render(self): + if self.render_mode != "human" or self.screen is None: + return + + for event in pygame.event.get(): + if event.type == pygame.QUIT: + self.close() + return + + self.screen.fill(screen_color) + self.mediator.render(self.screen) + pygame.display.flip() + self.clock.tick(self.metadata["render_fps"]) + + def close(self): + if self.screen is not None: + pygame.display.quit() + pygame.quit() + self.screen = None + self.clock = None + + def _create_action_map(self) -> Dict[int, Dict[str, Any]]: + """Creates the mapping from discrete action int to game action.""" + action_map = {0: {"type": "NO_OP"}} + action_id = 1 + + airport_pairs = list(itertools.permutations(range(MAX_airportS), 2)) + for start_idx, end_idx in airport_pairs: + action_map[action_id] = {"type": "CREATE_OR_EXTEND_PATH", "start_idx": start_idx, "end_idx": end_idx} + action_id += 1 + + airport_trios = list(itertools.permutations(range(MAX_airportS), 3)) + for insert_idx, exist1_idx, exist2_idx in airport_trios: + action_map[action_id] = {"type": "INSERT_airport", "insert_idx": insert_idx, "exist1_idx": exist1_idx, "exist2_idx": exist2_idx} + action_id += 1 + + return action_map + + def _create_observation_space(self) -> spaces.Box: + """Correctly defines the size of the observation space.""" + # 1 (exists) + 1 (is_connected) + 2 (pos) + 1 (overcrowd) + 1 (timer) + num_shapes (type) + num_shapes (passengers) + airport_obs_size = 1 + 1 + 2 + 1 + 1 + self.num_shape_types + self.num_shape_types + total_airport_obs_size = MAX_airportS * airport_obs_size + + # 1 (exists/is_loop) + MAX_airportS_PER_PATH (airport indices) + path_obs_size = 1 + MAX_airportS_PER_PATH + total_path_obs_size = MAX_PATHS * path_obs_size + + total_size = total_airport_obs_size + total_path_obs_size + # low=-1 for airport indices in paths, high=screen_width just to be safe (though 1.0 is max for most) + return spaces.Box(low=-1.0, high=2.0, shape=(total_size,), dtype=np.float32) + + def _get_obs(self) -> np.ndarray: + """ + The observation is a flattened `spaces.Box` vector composed of two + main parts: + + 1. **airport Data** (for `MAX_airportS`): + - `exists` (1): 1.0 if the airport exists, 0.0 otherwise. + - `is_connected` (1): 1.0 if part of any path, 0.0 otherwise. + - `position` (2): (x, y) normalized by screen dimensions. + - `is_overcrowded` (1): 1.0 if overcrowded, 0.0 otherwise. + - `overcrowd_timer` (1): Normalized time since overcrowding started + (0.0 to 1.0). + - `type` (num_shapes): One-hot encoding of the airport's shape. + - `passengers` (num_shapes): Count of waiting passengers + for each destination shape, normalized by airport capacity. + + 2. **Path Data** (for `MAX_PATHS`): + - `exists/is_loop` (1): 0.0 for non-existent, 1.0 for existing, + 2.0 for looped path. + - `airports` (MAX_airportS_PER_PATH): List of airport indices (-1 + for empty). + + Final vector will look like this: + [ + --- airport 0 (14 floats) --- + exists, is_connected, x_pos, y_pos, is_overcrowd, crowd_timer, + (type_shape_0, type_shape_1, type_shape_2, type_shape_3), <-- 1-hot type + (pass_shape_0, pass_shape_1, pass_shape_2, pass_shape_3), <-- passenger counts + + --- airport 1 (14 floats) --- + exists, is_connected, x_pos, y_pos, is_overcrowd, crowd_timer, + (0, 0, 1, 0), <-- 1-hot (e.g., is a triangle) + (1.2, 0.5, 0.0, 3.1), <-- passenger counts (normalized) + + ... (repeated for all MAX_airportS) ... + + --- airport 19 (14 floats) --- + (0, 0, 0, 0, 0, 0, (0,0,0,0), (0,0,0,0)), <-- all zeros if airport doesn't exist + + --- Path 0 (13 floats) --- + exists_or_loop_status, (idx_0, idx_1, idx_2, ..., idx_11), + + --- Path 1 (13 floats) --- + exists_or_loop_status, (idx_0, idx_1, idx_2, ..., idx_11), + + ... (repeated for all MAX_PATHS) ... + ] + """ + # figure out which airports are connected + airports_in_paths = set() + for path in self.mediator.paths: + for airport in path.airports: + airports_in_paths.add(airport.id) + + obs = np.zeros(self.observation_space.shape, dtype=np.float32) + + # 1 (exists) + 1 (is_connected) + 2 (pos) + 1 (overcrowd) + 1 (timer) + num_shapes (type) + num_shapes (passengers) + airport_chunk_size = 1 + 1 + 2 + 1 + 1 + self.num_shape_types + self.num_shape_types + + for i in range(MAX_airportS): + offset = i * airport_chunk_size + # Create a vector representation of every airport: + if i < len(self.mediator.airports): + airport = self.mediator.airports[i] + + # Base offset for this airport's features + feat_offset = 0 + + # 1. Existence flag + obs[offset + feat_offset] = 1.0 + feat_offset += 1 + + # 2. is_connected + obs[offset + feat_offset] = 1.0 if airport.id in airports_in_paths else 0.0 + feat_offset += 1 + + # 3. Position (2 floats) + obs[offset + feat_offset] = airport.position.left / screen_width + feat_offset += 1 + obs[offset + feat_offset] = airport.position.top / screen_height + feat_offset += 1 + + # 4. is_overcrowded flag + obs[offset + feat_offset] = 1.0 if airport.is_overcrowded else 0.0 + feat_offset += 1 + + # 5. Overcrowd timer + if airport.is_overcrowded: + elapsed = self.mediator.time_ms - airport.overcrowd_start_time_ms + obs[offset + feat_offset] = min(elapsed / 10000.0, 1.0) + feat_offset += 1 + + # 6. airport shape (one-hot) + shape_idx = self.shape_to_idx[airport.shape.type.value] + obs[offset + feat_offset + shape_idx] = 1.0 + feat_offset += self.num_shape_types + + # 7. Passenger counts (per destination shape) + passenger_counts = np.zeros(self.num_shape_types, dtype=np.float32) + for p in airport.passengers: + dest_idx = self.shape_to_idx[p.destination_shape.type.value] + passenger_counts[dest_idx] += 1 + + obs[offset + feat_offset : offset + airport_chunk_size] = passenger_counts / airport.capacity + + path_chunk_size = 1 + MAX_airportS_PER_PATH + airport_offset = MAX_airportS * airport_chunk_size + airport_to_game_idx = {s.id: i for i, s in enumerate(self.mediator.airports)} + + for i in range(MAX_PATHS): + offset = airport_offset + i * path_chunk_size + if i < len(self.mediator.paths): + path = self.mediator.paths[i] + obs[offset] = 1.0 if not path.is_looped else 2.0 # Use 2.0 to signify a loop + path_indices = [-1.0] * MAX_airportS_PER_PATH + for j, airport in enumerate(path.airports): + if j < MAX_airportS_PER_PATH: + path_indices[j] = airport_to_game_idx.get(airport.id, -1.0) + obs[offset + 1 : offset + path_chunk_size] = path_indices + return obs + + def _get_info(self) -> Dict[str, Any]: + """Returns info dict, including the crucial action mask.""" + return { + "score": self.mediator.score, + "steps": self.mediator.steps, + "action_mask": self._get_action_mask() + } + + def reset(self, seed=None, options=None) -> tuple[np.ndarray, Dict[str, Any]]: + super().reset(seed=seed) + self.mediator = Mediator() + if self.render_mode == "human": + self.render() + return self._get_obs(), self._get_info() + + def step(self, action: int) -> tuple[np.ndarray, float, bool, bool, Dict[str, Any]]: + prev_score = self.mediator.score + num_loops_before = sum(1 for p in self.mediator.paths if p.is_looped) + + action_info = self._action_map.get(action) + action_was_valid = False + + if action_info: + action_type = action_info["type"] + if action_type == "NO_OP": + action_was_valid = True + + elif action_type == "CREATE_OR_EXTEND_PATH": + start_idx, end_idx = action_info["start_idx"], action_info["end_idx"] + if start_idx < len(self.mediator.airports) and end_idx < len(self.mediator.airports) and start_idx != end_idx: + start_airport = self.mediator.airports[start_idx] + end_airport = self.mediator.airports[end_idx] + action_was_valid = self.mediator.create_or_extend_path(start_airport, end_airport) + + elif action_type == "INSERT_airport": + insert_idx, exist1_idx, exist2_idx = action_info["insert_idx"], action_info["exist1_idx"], action_info["exist2_idx"] + if all(i < len(self.mediator.airports) for i in [insert_idx, exist1_idx, exist2_idx]): + s_insert = self.mediator.airports[insert_idx] + s1 = self.mediator.airports[exist1_idx] + s2 = self.mediator.airports[exist2_idx] + action_was_valid = self.mediator.insert_airport_on_path(s_insert, s1, s2) + + # Simulate 15 game-ticks + for _ in range(15): + if self.mediator.is_game_over: break + self.mediator.increment_time(16) + + reward = (self.mediator.score - prev_score) * 25.0 + + reward += 0.01 + + if not action_was_valid: + reward -= 1.0 + + terminated = self.mediator.is_game_over + + if self.render_mode == "human": + self.render() + + return self._get_obs(), reward, terminated, False, self._get_info() + \ No newline at end of file diff --git a/src/run_agent.py b/src/run_agent.py new file mode 100644 index 00000000..e4032ed9 --- /dev/null +++ b/src/run_agent.py @@ -0,0 +1,77 @@ +import os +import time +import argparse +import gymnasium as gym +from stable_baselines3 import PPO +from stable_baselines3.common.vec_env import VecNormalize, DummyVecEnv +from stable_baselines3.common.env_util import make_vec_env + +from pilot_planning_env import PlaneGameEnv + +def run_agent(model_folder): + """ + Loads and runs a trained PPO agent with a UI. + """ + + model_path = os.path.join(model_folder, "final_model.zip") + stats_path = os.path.join(model_folder, "vec_normalize.pkl") + + if not os.path.exists(model_path): + print(f"Warning: 'final_model.zip' not found. Searching for latest checkpoint...") + checkpoints = [f for f in os.listdir(model_folder) if f.startswith("plane_rl_model_") and f.endswith(".zip")] + if not checkpoints: + print(f"Error: No model files found in {model_folder}. Aborting.") + return + + checkpoints.sort(key=lambda f: int(f.split('_')[3])) + model_path = os.path.join(model_folder, checkpoints[-1]) + print(f"Loading latest checkpoint: {model_path}") + + if not os.path.exists(stats_path): + print(f"Error: 'vec_normalize.pkl' not found at {stats_path}. This file is required. Aborting.") + return + + def create_eval_env(): + env = PlaneGameEnv(render_mode="human") + env = gym.wrappers.TimeLimit(env, max_episode_steps=5000) + return env + + env = DummyVecEnv([create_eval_env]) + + env = VecNormalize.load(stats_path, env) + env.training = False + env.norm_reward = False + + print(f"Loading model from {model_path}...") + model = PPO.load(model_path, env=env) + print("Model loaded successfully.") + + obs = env.reset() + total_reward = 0 + + print("Starting simulation") + try: + while True: + action, _states = model.predict(obs, deterministic=True) + obs, reward, terminated, info = env.step(action) + + total_reward += reward[0] + if terminated[0]: + print(f"Episode finished. Total Reward: {total_reward}") + total_reward = 0 + print("Resetting environment...") + time.sleep(2) + obs = env.reset() + + except KeyboardInterrupt: + print("\nSimulation stopped by user.") + finally: + env.close() + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run a trained Mini Plane PPO agent with UI.") + parser.add_argument("model_folder", type=str, help="Path to the directory containing the saved model (.zip) and stats (vec_normalize.pkl).") + + args = parser.parse_args() + + run_agent(args.model_folder) \ No newline at end of file diff --git a/src/train.py b/src/train.py new file mode 100644 index 00000000..20a66f3f --- /dev/null +++ b/src/train.py @@ -0,0 +1,94 @@ +import os +import time +import gymnasium as gym +import platform +from stable_baselines3 import PPO +from stable_baselines3.common.callbacks import CheckpointCallback +from stable_baselines3.common.env_util import make_vec_env +from stable_baselines3.common.vec_env import SubprocVecEnv +from stable_baselines3.common.vec_env import VecNormalize +from pilot_planning_env import PlaneGameEnv + + +LOG_DIR = f"logs/256-128-128/" +MODEL_DIR = f"models/PPO/256-128-128/" +TOTAL_TIMESTEPS = 25_000_000 +SAVE_FREQ = 25_000 + +TB_LOG_NAME = "PPO_plane_Run" +net_arch_config = [256, 128, 128] +policy_kwargs = dict(net_arch=net_arch_config) + +def create_env(): + """Helper function to create and wrap the environment.""" + env = PlaneGameEnv(render_mode=None) + env = gym.wrappers.TimeLimit(env, max_episode_steps=5000) + return env + +def train_agent(): + """Initializes and trains the PPO agent.""" + if __name__ == '__main__': + os.makedirs(LOG_DIR, exist_ok=True) + os.makedirs(MODEL_DIR, exist_ok=True) + num_cpu = os.cpu_count() - 1 if os.cpu_count() > 1 else 1 + start_method = 'fork' if platform.system() != 'Windows' else 'spawn' + env = make_vec_env( + create_env, + n_envs=num_cpu, + vec_env_cls=SubprocVecEnv, + vec_env_kwargs=dict(start_method=start_method) + ) + + env = VecNormalize(env, gamma=0.99) + + checkpoint_callback = CheckpointCallback( + save_freq=SAVE_FREQ, + save_path=MODEL_DIR, + name_prefix="plane_rl_model", + save_replay_buffer=True, + save_vecnormalize=True, + ) + + checkpoint_callback = CheckpointCallback( + save_freq=SAVE_FREQ, + save_path=MODEL_DIR, + name_prefix="plane_rl_model", + save_replay_buffer=True, + save_vecnormalize=True, + ) + + model = PPO( + "MlpPolicy", + env, + verbose=1, + tensorboard_log=LOG_DIR, + device="cpu", + n_steps=4096, + learning_rate=1e-5, + policy_kwargs=policy_kwargs, + batch_size=64, + gamma=0.99, + gae_lambda=0.95, + n_epochs=10, + ent_coef=0.01, + vf_coef=0.5, + max_grad_norm=0.5, + ) + + print(f"Starting training on {num_cpu} cores for {TOTAL_TIMESTEPS} timesteps...") + + model.learn( + total_timesteps=TOTAL_TIMESTEPS, + callback=checkpoint_callback, + tb_log_name=TB_LOG_NAME + ) + + final_model_path = os.path.join(MODEL_DIR, "final_model") + model.save(final_model_path) + env.save(os.path.join(MODEL_DIR, "vec_normalize.pkl")) + print(f"Training complete! Final model saved to {final_model_path}") + + env.close() + +if __name__ == '__main__': + train_agent() diff --git a/src/travel_plan.py b/src/travel_plan.py index 2292ec78..8a8ddb6a 100644 --- a/src/travel_plan.py +++ b/src/travel_plan.py @@ -3,33 +3,33 @@ from typing import List from entity.path import Path -from entity.station import Station +from entity.airport import Airport from graph.node import Node - +# determines what trains passengers need to go on -- does BFS essentially class TravelPlan: def __init__( self, node_path: List[Node], ) -> None: self.next_path: Path | None = None - self.next_station: Station | None = None + self.next_airport: Airport | None = None self.node_path = node_path - self.next_station_idx = 0 + self.next_airport_idx = 0 - def get_next_station(self) -> Station | None: + def get_next_airport(self) -> Airport | None: if self.node_path is not None and len(self.node_path) > 0: - next_node = self.node_path[self.next_station_idx] - next_station = next_node.station - self.next_station = next_station - return next_station + next_node = self.node_path[self.next_airport_idx] + next_airport = next_node.airport + self.next_airport = next_airport + return next_airport else: return None - def increment_next_station(self) -> None: - self.next_station_idx += 1 + def increment_next_airport(self) -> None: + self.next_airport_idx += 1 def __repr__(self) -> str: return ( - f"TravelPlan = get on {self.next_path}, then get off at {self.next_station}" + f"TravelPlan = get on {self.next_path}, then get off at {self.next_airport}" ) diff --git a/src/utils.py b/src/utils.py index 717147a4..6a4c77b7 100644 --- a/src/utils.py +++ b/src/utils.py @@ -4,7 +4,7 @@ import numpy as np -from config import passenger_size, station_color, station_shape_type_list, station_size +from config import passenger_size, airport_color, airport_shape_type_list, airport_size from geometry.circle import Circle from geometry.cross import Cross from geometry.point import Point @@ -42,12 +42,12 @@ def get_random_shape( return get_shape_from_type(shape_type, color, size) -def get_random_station_shape() -> Shape: - return get_random_shape(station_shape_type_list, station_color, station_size) +def get_random_airport_shape() -> Shape: + return get_random_shape(airport_shape_type_list, airport_color, airport_size) def get_random_passenger_shape() -> Shape: - return get_random_shape(station_shape_type_list, get_random_color(), passenger_size) + return get_random_shape(airport_shape_type_list, get_random_color(), passenger_size) def tuple_to_point(tuple: Tuple[int, int]) -> Point: diff --git a/test/test_gameplay.py b/test/test_gameplay.py index 66e8156b..90b104e9 100644 --- a/test/test_gameplay.py +++ b/test/test_gameplay.py @@ -8,7 +8,7 @@ import pygame from config import screen_height, screen_width -from entity.get_entity import get_random_stations +from entity.get_entity import get_random_airports from event.keyboard import KeyboardEvent from event.mouse import MouseEvent from event.type import KeyboardEventType, MouseEventType @@ -27,36 +27,36 @@ def setUp(self): pygame.draw = MagicMock() self.mediator.render(self.screen) - def connect_stations(self, station_idx): + def connect_airports(self, airport_idx): self.mediator.react( MouseEvent( MouseEventType.MOUSE_DOWN, - self.mediator.stations[station_idx[0]].position, + self.mediator.airports[airport_idx[0]].position, ) ) - for idx in station_idx[1:]: + for idx in airport_idx[1:]: self.mediator.react( MouseEvent( - MouseEventType.MOUSE_MOTION, self.mediator.stations[idx].position + MouseEventType.MOUSE_MOTION, self.mediator.airports[idx].position ) ) self.mediator.react( MouseEvent( MouseEventType.MOUSE_UP, - self.mediator.stations[station_idx[-1]].position, + self.mediator.airports[airport_idx[-1]].position, ) ) def test_react_mouse_down_start_path(self): - self.mediator.start_path_on_station = MagicMock() + self.mediator.start_path_on_airport = MagicMock() self.mediator.react( MouseEvent( MouseEventType.MOUSE_DOWN, - self.mediator.stations[3].position + Point(1, 1), + self.mediator.airports[3].position + Point(1, 1), ) ) - self.mediator.start_path_on_station.assert_called_once() + self.mediator.start_path_on_airport.assert_called_once() def test_mouse_down_and_up_at_the_same_point_does_not_create_path(self): self.mediator.react(MouseEvent(MouseEventType.MOUSE_DOWN, Point(-1, -1))) @@ -64,41 +64,41 @@ def test_mouse_down_and_up_at_the_same_point_does_not_create_path(self): self.assertEqual(len(self.mediator.paths), 0) - def test_mouse_dragged_between_stations_creates_path(self): + def test_mouse_dragged_between_airports_creates_path(self): self.mediator.react( MouseEvent( MouseEventType.MOUSE_DOWN, - self.mediator.stations[0].position + Point(1, 1), + self.mediator.airports[0].position + Point(1, 1), ) ) self.mediator.react(MouseEvent(MouseEventType.MOUSE_MOTION, Point(2, 2))) self.mediator.react( MouseEvent( MouseEventType.MOUSE_UP, - self.mediator.stations[1].position + Point(1, 1), + self.mediator.airports[1].position + Point(1, 1), ) ) self.assertEqual(len(self.mediator.paths), 1) self.assertSequenceEqual( - self.mediator.paths[0].stations, - [self.mediator.stations[0], self.mediator.stations[1]], + self.mediator.paths[0].airports, + [self.mediator.airports[0], self.mediator.airports[1]], ) - def test_mouse_dragged_between_non_station_points_does_not_create_path(self): + def test_mouse_dragged_between_non_airport_points_does_not_create_path(self): self.mediator.react(MouseEvent(MouseEventType.MOUSE_DOWN, Point(0, 0))) self.mediator.react(MouseEvent(MouseEventType.MOUSE_MOTION, Point(2, 2))) self.mediator.react(MouseEvent(MouseEventType.MOUSE_UP, Point(0, 1))) self.assertEqual(len(self.mediator.paths), 0) - def test_mouse_dragged_between_station_and_non_station_points_does_not_create_path( + def test_mouse_dragged_between_airport_and_non_airport_points_does_not_create_path( self, ): self.mediator.react( MouseEvent( MouseEventType.MOUSE_DOWN, - self.mediator.stations[0].position + Point(1, 1), + self.mediator.airports[0].position + Point(1, 1), ) ) self.mediator.react(MouseEvent(MouseEventType.MOUSE_MOTION, Point(2, 2))) @@ -106,26 +106,26 @@ def test_mouse_dragged_between_station_and_non_station_points_does_not_create_pa self.assertEqual(len(self.mediator.paths), 0) - def test_mouse_dragged_between_3_stations_creates_looped_path(self): - self.connect_stations([0, 1, 2, 0]) + def test_mouse_dragged_between_3_airports_creates_looped_path(self): + self.connect_airports([0, 1, 2, 0]) self.assertEqual(len(self.mediator.paths), 1) self.assertTrue(self.mediator.paths[0].is_looped) - def test_mouse_dragged_between_4_stations_creates_looped_path(self): - self.connect_stations([0, 1, 2, 3, 0]) + def test_mouse_dragged_between_4_airports_creates_looped_path(self): + self.connect_airports([0, 1, 2, 3, 0]) self.assertEqual(len(self.mediator.paths), 1) self.assertTrue(self.mediator.paths[0].is_looped) - def test_path_between_2_stations_is_not_looped(self): - self.connect_stations([0, 1]) + def test_path_between_2_airports_is_not_looped(self): + self.connect_airports([0, 1]) self.assertEqual(len(self.mediator.paths), 1) self.assertFalse(self.mediator.paths[0].is_looped) - def test_mouse_dragged_between_3_stations_without_coming_back_to_first_does_not_create_loop( + def test_mouse_dragged_between_3_airports_without_coming_back_to_first_does_not_create_loop( self, ): - self.connect_stations([0, 1, 2]) + self.connect_airports([0, 1, 2]) self.assertEqual(len(self.mediator.paths), 1) self.assertFalse(self.mediator.paths[0].is_looped) @@ -139,10 +139,10 @@ def test_space_key_pauses_and_unpauses_game(self): self.assertFalse(self.mediator.is_paused) def test_path_button_removes_path_on_click(self): - self.mediator.stations = get_random_stations(5) - for station in self.mediator.stations: - station.draw(self.screen) - self.connect_stations([0, 1]) + self.mediator.airports = get_random_airports(5) + for airport in self.mediator.airports: + airport.draw(self.screen) + self.connect_airports([0, 1]) self.mediator.react( MouseEvent(MouseEventType.MOUSE_UP, self.mediator.path_buttons[0].position) ) @@ -150,43 +150,43 @@ def test_path_button_removes_path_on_click(self): self.assertEqual(len(self.mediator.path_to_button.items()), 0) def test_path_buttons_get_assigned_upon_path_creation(self): - self.mediator.stations = get_random_stations(5) - for station in self.mediator.stations: - station.draw(self.screen) - self.connect_stations([0, 1]) + self.mediator.airports = get_random_airports(5) + for airport in self.mediator.airports: + airport.draw(self.screen) + self.connect_airports([0, 1]) self.assertEqual(len(self.mediator.path_to_button.items()), 1) self.assertIn(self.mediator.paths[0], self.mediator.path_to_button) - self.connect_stations([2, 3]) + self.connect_airports([2, 3]) self.assertEqual(len(self.mediator.path_to_button.items()), 2) self.assertIn(self.mediator.paths[0], self.mediator.path_to_button) self.assertIn(self.mediator.paths[1], self.mediator.path_to_button) - self.connect_stations([1, 3]) + self.connect_airports([1, 3]) self.assertEqual(len(self.mediator.path_to_button.items()), 3) self.assertIn(self.mediator.paths[0], self.mediator.path_to_button) self.assertIn(self.mediator.paths[1], self.mediator.path_to_button) self.assertIn(self.mediator.paths[2], self.mediator.path_to_button) def test_more_paths_can_be_created_after_removing_paths(self): - self.mediator.stations = get_random_stations(5) - for station in self.mediator.stations: - station.draw(self.screen) - self.connect_stations([0, 1]) - self.connect_stations([2, 3]) - self.connect_stations([1, 4]) + self.mediator.airports = get_random_airports(5) + for airport in self.mediator.airports: + airport.draw(self.screen) + self.connect_airports([0, 1]) + self.connect_airports([2, 3]) + self.connect_airports([1, 4]) self.mediator.react( MouseEvent(MouseEventType.MOUSE_UP, self.mediator.path_buttons[0].position) ) self.assertEqual(len(self.mediator.paths), 2) - self.connect_stations([1, 3]) + self.connect_airports([1, 3]) self.assertEqual(len(self.mediator.paths), 3) def test_assigned_path_buttons_bubble_to_left(self): - self.mediator.stations = get_random_stations(5) - for station in self.mediator.stations: - station.draw(self.screen) - self.connect_stations([0, 1]) - self.connect_stations([2, 3]) - self.connect_stations([1, 4]) + self.mediator.airports = get_random_airports(5) + for airport in self.mediator.airports: + airport.draw(self.screen) + self.connect_airports([0, 1]) + self.connect_airports([2, 3]) + self.connect_airports([1, 4]) self.mediator.react( MouseEvent(MouseEventType.MOUSE_UP, self.mediator.path_buttons[0].position) ) diff --git a/test/test_graph.py b/test/test_graph.py index 9054b986..d8387ab0 100644 --- a/test/test_graph.py +++ b/test/test_graph.py @@ -3,19 +3,19 @@ import unittest from unittest.mock import create_autospec -from entity.get_entity import get_random_stations +from entity.get_entity import get_random_airports sys.path.append(os.path.dirname(os.path.realpath(__file__)) + "/../src") import pygame -from config import screen_height, screen_width, station_color, station_size -from entity.station import Station +from config import screen_height, screen_width, airport_color, airport_size +from entity.airport import airport from event.mouse import MouseEvent from event.type import MouseEventType from geometry.circle import Circle from geometry.rect import Rect -from graph.graph_algo import bfs, build_station_nodes_dict +from graph.graph_algo import bfs, build_airport_nodes_dict from graph.node import Node from mediator import Mediator from utils import get_random_color, get_random_position @@ -28,114 +28,114 @@ def setUp(self): self.position = get_random_position(self.width, self.height) self.color = get_random_color() self.mediator = Mediator() - for station in self.mediator.stations: - station.draw(self.screen) + for airport in self.mediator.airports: + airport.draw(self.screen) - def connect_stations(self, station_idx): + def connect_airports(self, airport_idx): self.mediator.react( MouseEvent( MouseEventType.MOUSE_DOWN, - self.mediator.stations[station_idx[0]].position, + self.mediator.airports[airport_idx[0]].position, ) ) - for idx in station_idx[1:]: + for idx in airport_idx[1:]: self.mediator.react( MouseEvent( - MouseEventType.MOUSE_MOTION, self.mediator.stations[idx].position + MouseEventType.MOUSE_MOTION, self.mediator.airports[idx].position ) ) self.mediator.react( MouseEvent( MouseEventType.MOUSE_UP, - self.mediator.stations[station_idx[-1]].position, + self.mediator.airports[airport_idx[-1]].position, ) ) - def test_build_station_nodes_dict(self): - self.mediator.stations = [ - Station( + def test_build_airport_nodes_dict(self): + self.mediator.airports = [ + airport( Rect( - color=station_color, - width=2 * station_size, - height=2 * station_size, + color=airport_color, + width=2 * airport_size, + height=2 * airport_size, ), get_random_position(self.width, self.height), ), - Station( + airport( Circle( - color=station_color, - radius=station_size, + color=airport_color, + radius=airport_size, ), get_random_position(self.width, self.height), ), ] - for station in self.mediator.stations: - station.draw(self.screen) + for airport in self.mediator.airports: + airport.draw(self.screen) - self.connect_stations([0, 1]) + self.connect_airports([0, 1]) - station_nodes_dict = build_station_nodes_dict( - self.mediator.stations, self.mediator.paths + airport_nodes_dict = build_airport_nodes_dict( + self.mediator.airports, self.mediator.paths ) - self.assertCountEqual(list(station_nodes_dict.keys()), self.mediator.stations) - for station, node in station_nodes_dict.items(): - self.assertEqual(node.station, station) + self.assertCountEqual(list(airport_nodes_dict.keys()), self.mediator.airports) + for airport, node in airport_nodes_dict.items(): + self.assertEqual(node.airport, airport) - def test_bfs_two_stations(self): - self.mediator.stations = get_random_stations(2) - for station in self.mediator.stations: - station.draw(self.screen) + def test_bfs_two_airports(self): + self.mediator.airports = get_random_airports(2) + for airport in self.mediator.airports: + airport.draw(self.screen) - self.connect_stations([0, 1]) + self.connect_airports([0, 1]) - station_nodes_dict = build_station_nodes_dict( - self.mediator.stations, self.mediator.paths + airport_nodes_dict = build_airport_nodes_dict( + self.mediator.airports, self.mediator.paths ) - start_station = self.mediator.stations[0] - end_station = self.mediator.stations[1] - start_node = station_nodes_dict[start_station] - end_node = station_nodes_dict[end_station] + start_airport = self.mediator.airports[0] + end_airport = self.mediator.airports[1] + start_node = airport_nodes_dict[start_airport] + end_node = airport_nodes_dict[end_airport] node_path = bfs(start_node, end_node) self.assertSequenceEqual( node_path, [start_node, end_node], ) - def test_bfs_five_stations(self): - self.mediator.stations = get_random_stations(5) - for station in self.mediator.stations: - station.draw(self.screen) + def test_bfs_five_airports(self): + self.mediator.airports = get_random_airports(5) + for airport in self.mediator.airports: + airport.draw(self.screen) - self.connect_stations([0, 1, 2]) - self.connect_stations([0, 3]) + self.connect_airports([0, 1, 2]) + self.connect_airports([0, 3]) - station_nodes_dict = build_station_nodes_dict( - self.mediator.stations, self.mediator.paths + airport_nodes_dict = build_airport_nodes_dict( + self.mediator.airports, self.mediator.paths ) - start_node = station_nodes_dict[self.mediator.stations[0]] - end_node = station_nodes_dict[self.mediator.stations[2]] + start_node = airport_nodes_dict[self.mediator.airports[0]] + end_node = airport_nodes_dict[self.mediator.airports[2]] node_path = bfs(start_node, end_node) self.assertSequenceEqual( node_path, [ - Node(self.mediator.stations[0]), - Node(self.mediator.stations[1]), - Node(self.mediator.stations[2]), + Node(self.mediator.airports[0]), + Node(self.mediator.airports[1]), + Node(self.mediator.airports[2]), ], ) - start_node = station_nodes_dict[self.mediator.stations[1]] - end_node = station_nodes_dict[self.mediator.stations[3]] + start_node = airport_nodes_dict[self.mediator.airports[1]] + end_node = airport_nodes_dict[self.mediator.airports[3]] node_path = bfs(start_node, end_node) self.assertSequenceEqual( node_path, [ - Node(self.mediator.stations[1]), - Node(self.mediator.stations[0]), - Node(self.mediator.stations[3]), + Node(self.mediator.airports[1]), + Node(self.mediator.airports[0]), + Node(self.mediator.airports[3]), ], ) - start_node = station_nodes_dict[self.mediator.stations[0]] - end_node = station_nodes_dict[self.mediator.stations[4]] + start_node = airport_nodes_dict[self.mediator.airports[0]] + end_node = airport_nodes_dict[self.mediator.airports[4]] node_path = bfs(start_node, end_node) self.assertSequenceEqual( node_path, diff --git a/test/test_mediator.py b/test/test_mediator.py index e7e132b6..01c0de75 100644 --- a/test/test_mediator.py +++ b/test/test_mediator.py @@ -3,7 +3,7 @@ import unittest from unittest.mock import MagicMock, create_autospec -from entity.get_entity import get_random_stations +from entity.get_entity import get_random_airports from event.mouse import MouseEvent from event.type import MouseEventType from geometry.triangle import Triangle @@ -21,10 +21,10 @@ passenger_spawning_start_step, screen_height, screen_width, - station_color, - station_size, + airport_color, + airport_size, ) -from entity.station import Station +from entity.airport import airport from geometry.circle import Circle from geometry.point import Point from geometry.rect import Rect @@ -41,29 +41,29 @@ def setUp(self): self.mediator = Mediator() self.mediator.render(self.screen) - def connect_stations(self, station_idx): + def connect_airports(self, airport_idx): self.mediator.react( MouseEvent( MouseEventType.MOUSE_DOWN, - self.mediator.stations[station_idx[0]].position, + self.mediator.airports[airport_idx[0]].position, ) ) - for idx in station_idx[1:]: + for idx in airport_idx[1:]: self.mediator.react( MouseEvent( - MouseEventType.MOUSE_MOTION, self.mediator.stations[idx].position + MouseEventType.MOUSE_MOTION, self.mediator.airports[idx].position ) ) self.mediator.react( MouseEvent( MouseEventType.MOUSE_UP, - self.mediator.stations[station_idx[-1]].position, + self.mediator.airports[airport_idx[-1]].position, ) ) def test_react_mouse_down(self): - for station in self.mediator.stations: - station.draw(self.screen) + for airport in self.mediator.airports: + airport.draw(self.screen) self.mediator.react(MouseEvent(MouseEventType.MOUSE_DOWN, Point(-1, -1))) self.assertTrue(self.mediator.is_mouse_down) @@ -71,7 +71,7 @@ def test_react_mouse_down(self): def test_get_containing_entity(self): self.assertTrue( self.mediator.get_containing_entity( - self.mediator.stations[2].position + Point(1, 1) + self.mediator.airports[2].position + Point(1, 1) ) ) @@ -80,10 +80,10 @@ def test_react_mouse_up(self): self.assertFalse(self.mediator.is_mouse_down) - def test_passengers_are_added_to_stations(self): + def test_passengers_are_added_to_airports(self): self.mediator.spawn_passengers() - self.assertEqual(len(self.mediator.passengers), len(self.mediator.stations)) + self.assertEqual(len(self.mediator.passengers), len(self.mediator.airports)) def test_is_passenger_spawn_time(self): self.mediator.spawn_passengers = MagicMock() @@ -98,53 +98,53 @@ def test_is_passenger_spawn_time(self): self.assertEqual(self.mediator.spawn_passengers.call_count, 2) - def test_passengers_spawned_at_a_station_have_a_different_destination(self): + def test_passengers_spawned_at_a_airport_have_a_different_destination(self): # Run the game until first wave of passengers spawn for _ in range(passenger_spawning_start_step): self.mediator.increment_time(ceil(1000 / framerate)) - for station in self.mediator.stations: - for passenger in station.passengers: + for airport in self.mediator.airports: + for passenger in airport.passengers: self.assertNotEqual( - passenger.destination_shape.type, station.shape.type + passenger.destination_shape.type, airport.shape.type ) - def test_passengers_at_connected_stations_have_a_way_to_destination(self): - self.mediator.stations = [ - Station( + def test_passengers_at_connected_airports_have_a_way_to_destination(self): + self.mediator.airports = [ + airport( Rect( - color=station_color, - width=2 * station_size, - height=2 * station_size, + color=airport_color, + width=2 * airport_size, + height=2 * airport_size, ), Point(100, 100), ), - Station( + airport( Circle( - color=station_color, - radius=station_size, + color=airport_color, + radius=airport_size, ), Point(100, 200), ), ] - # Need to draw stations if you want to override them - for station in self.mediator.stations: - station.draw(self.screen) + # Need to draw airports if you want to override them + for airport in self.mediator.airports: + airport.draw(self.screen) # Run the game until first wave of passengers spawn for _ in range(passenger_spawning_start_step): self.mediator.increment_time(ceil(1000 / framerate)) - self.connect_stations([0, 1]) + self.connect_airports([0, 1]) self.mediator.increment_time(ceil(1000 / framerate)) for passenger in self.mediator.passengers: self.assertIn(passenger, self.mediator.travel_plans) self.assertIsNotNone(self.mediator.travel_plans[passenger]) self.assertIsNotNone(self.mediator.travel_plans[passenger].next_path) - self.assertIsNotNone(self.mediator.travel_plans[passenger].next_station) + self.assertIsNotNone(self.mediator.travel_plans[passenger].next_airport) - def test_passengers_at_isolated_stations_have_no_way_to_destination(self): + def test_passengers_at_isolated_airports_have_no_way_to_destination(self): # Run the game until first wave of passengers spawn, then 1 more frame for _ in range(passenger_spawning_start_step + 1): self.mediator.increment_time(ceil(1000 / framerate)) @@ -153,73 +153,73 @@ def test_passengers_at_isolated_stations_have_no_way_to_destination(self): self.assertIn(passenger, self.mediator.travel_plans) self.assertIsNotNone(self.mediator.travel_plans[passenger]) self.assertIsNone(self.mediator.travel_plans[passenger].next_path) - self.assertIsNone(self.mediator.travel_plans[passenger].next_station) + self.assertIsNone(self.mediator.travel_plans[passenger].next_airport) - def test_get_station_for_shape_type(self): - self.mediator.stations = [ - Station( + def test_get_airport_for_shape_type(self): + self.mediator.airports = [ + airport( Rect( - color=station_color, - width=2 * station_size, - height=2 * station_size, + color=airport_color, + width=2 * airport_size, + height=2 * airport_size, ), get_random_position(self.width, self.height), ), - Station( + airport( Circle( - color=station_color, - radius=station_size, + color=airport_color, + radius=airport_size, ), get_random_position(self.width, self.height), ), - Station( + airport( Circle( - color=station_color, - radius=station_size, + color=airport_color, + radius=airport_size, ), get_random_position(self.width, self.height), ), - Station( + airport( Triangle( - color=station_color, - size=station_size, + color=airport_color, + size=airport_size, ), get_random_position(self.width, self.height), ), - Station( + airport( Triangle( - color=station_color, - size=station_size, + color=airport_color, + size=airport_size, ), get_random_position(self.width, self.height), ), - Station( + airport( Triangle( - color=station_color, - size=station_size, + color=airport_color, + size=airport_size, ), get_random_position(self.width, self.height), ), ] - rect_stations = self.mediator.get_stations_for_shape_type(ShapeType.RECT) - circle_stations = self.mediator.get_stations_for_shape_type(ShapeType.CIRCLE) - triangle_stations = self.mediator.get_stations_for_shape_type( + rect_airports = self.mediator.get_airports_for_shape_type(ShapeType.RECT) + circle_airports = self.mediator.get_airports_for_shape_type(ShapeType.CIRCLE) + triangle_airports = self.mediator.get_airports_for_shape_type( ShapeType.TRIANGLE ) - self.assertCountEqual(rect_stations, self.mediator.stations[0:1]) - self.assertCountEqual(circle_stations, self.mediator.stations[1:3]) - self.assertCountEqual(triangle_stations, self.mediator.stations[3:]) + self.assertCountEqual(rect_airports, self.mediator.airports[0:1]) + self.assertCountEqual(circle_airports, self.mediator.airports[1:3]) + self.assertCountEqual(triangle_airports, self.mediator.airports[3:]) - def test_skip_stations_on_same_path(self): - self.mediator.stations = get_random_stations(5) - for station in self.mediator.stations: - station.draw(self.screen) - self.connect_stations([i for i in range(5)]) + def test_skip_airports_on_same_path(self): + self.mediator.airports = get_random_airports(5) + for airport in self.mediator.airports: + airport.draw(self.screen) + self.connect_airports([i for i in range(5)]) self.mediator.spawn_passengers() self.mediator.find_travel_plan_for_passengers() - for station in self.mediator.stations: - for passenger in station.passengers: + for airport in self.mediator.airports: + for passenger in airport.passengers: self.assertEqual( len(self.mediator.travel_plans[passenger].node_path), 1 ) diff --git a/test/test_path.py b/test/test_path.py index 576dac73..104f6075 100644 --- a/test/test_path.py +++ b/test/test_path.py @@ -8,13 +8,13 @@ import pygame -from config import framerate, metro_speed_per_ms -from entity.get_entity import get_random_station, get_random_stations -from entity.metro import Metro +from config import framerate, plane_speed_per_ms +from entity.get_entity import get_random_airport, get_random_airports +from entity.plane import plane from entity.path import Path -from entity.station import Station +from entity.airport import airport from geometry.point import Point -from utils import get_random_color, get_random_position, get_random_station_shape +from utils import get_random_color, get_random_position, get_random_airport_shape class TestPath(unittest.TestCase): @@ -26,17 +26,17 @@ def setUp(self): def test_init(self): path = Path(get_random_color()) - station = get_random_station() - path.add_station(station) + airport = get_random_airport() + path.add_airport(airport) - self.assertIn(station, path.stations) + self.assertIn(airport, path.airports) def test_draw(self): path = Path(get_random_color()) - stations = get_random_stations(5) + airports = get_random_airports(5) pygame.draw.line = MagicMock() - for station in stations: - path.add_station(station) + for airport in airports: + path.add_airport(airport) path.draw(self.screen, 0) self.assertEqual(pygame.draw.line.call_count, 4 + 3) @@ -44,77 +44,77 @@ def test_draw(self): def test_draw_temporary_point(self): path = Path(get_random_color()) pygame.draw.line = MagicMock() - path.add_station(get_random_station()) + path.add_airport(get_random_airport()) path.set_temporary_point(Point(1, 1)) path.draw(self.screen, 0) self.assertEqual(pygame.draw.line.call_count, 1) - def test_metro_starts_at_beginning_of_first_line(self): + def test_plane_starts_at_beginning_of_first_line(self): path = Path(get_random_color()) - path.add_station(get_random_station()) - path.add_station(get_random_station()) + path.add_airport(get_random_airport()) + path.add_airport(get_random_airport()) path.draw(self.screen, 0) - metro = Metro() - path.add_metro(metro) + plane = plane() + path.add_plane(plane) - self.assertEqual(metro.current_segment, path.segments[0]) - self.assertEqual(metro.current_segment_idx, 0) - self.assertTrue(metro.is_forward) + self.assertEqual(plane.current_segment, path.segments[0]) + self.assertEqual(plane.current_segment_idx, 0) + self.assertTrue(plane.is_forward) - def test_metro_moves_from_beginning_to_end(self): + def test_plane_moves_from_beginning_to_end(self): path = Path(get_random_color()) - path.add_station(Station(get_random_station_shape(), Point(0, 0))) - dist_in_one_sec = 1000 * metro_speed_per_ms - path.add_station(Station(get_random_station_shape(), Point(dist_in_one_sec, 0))) + path.add_airport(airport(get_random_airport_shape(), Point(0, 0))) + dist_in_one_sec = 1000 * plane_speed_per_ms + path.add_airport(airport(get_random_airport_shape(), Point(dist_in_one_sec, 0))) path.draw(self.screen, 0) - for station in path.stations: - station.draw(self.screen) - metro = Metro() - path.add_metro(metro) + for airport in path.airports: + airport.draw(self.screen) + plane = plane() + path.add_plane(plane) dt_ms = ceil(1000 / framerate) for _ in range(framerate): - path.move_metro(metro, dt_ms) + path.move_plane(plane, dt_ms) - self.assertTrue(path.stations[1].contains(metro.position)) + self.assertTrue(path.airports[1].contains(plane.position)) - def test_metro_turns_around_when_it_reaches_the_end(self): + def test_plane_turns_around_when_it_reaches_the_end(self): path = Path(get_random_color()) - path.add_station(Station(get_random_station_shape(), Point(0, 0))) - dist_in_one_sec = 1000 * metro_speed_per_ms - path.add_station(Station(get_random_station_shape(), Point(dist_in_one_sec, 0))) + path.add_airport(airport(get_random_airport_shape(), Point(0, 0))) + dist_in_one_sec = 1000 * plane_speed_per_ms + path.add_airport(airport(get_random_airport_shape(), Point(dist_in_one_sec, 0))) path.draw(self.screen, 0) - for station in path.stations: - station.draw(self.screen) - metro = Metro() - path.add_metro(metro) + for airport in path.airports: + airport.draw(self.screen) + plane = plane() + path.add_plane(plane) dt_ms = ceil(1000 / framerate) for _ in range(framerate + 1): - path.move_metro(metro, dt_ms) + path.move_plane(plane, dt_ms) - self.assertFalse(metro.is_forward) + self.assertFalse(plane.is_forward) - def test_metro_loops_around_the_path(self): + def test_plane_loops_around_the_path(self): path = Path(get_random_color()) - path.add_station(Station(get_random_station_shape(), Point(0, 0))) - dist_in_one_sec = 1000 * metro_speed_per_ms - path.add_station(Station(get_random_station_shape(), Point(dist_in_one_sec, 0))) - path.add_station( - Station(get_random_station_shape(), Point(dist_in_one_sec, dist_in_one_sec)) + path.add_airport(airport(get_random_airport_shape(), Point(0, 0))) + dist_in_one_sec = 1000 * plane_speed_per_ms + path.add_airport(airport(get_random_airport_shape(), Point(dist_in_one_sec, 0))) + path.add_airport( + airport(get_random_airport_shape(), Point(dist_in_one_sec, dist_in_one_sec)) ) - path.add_station(Station(get_random_station_shape(), Point(0, dist_in_one_sec))) + path.add_airport(airport(get_random_airport_shape(), Point(0, dist_in_one_sec))) path.set_loop() path.draw(self.screen, 0) - for station in path.stations: - station.draw(self.screen) - metro = Metro() - path.add_metro(metro) + for airport in path.airports: + airport.draw(self.screen) + plane = plane() + path.add_plane(plane) dt_ms = ceil(1000 / framerate) - for station_idx in [1, 2, 3, 0, 1]: + for airport_idx in [1, 2, 3, 0, 1]: for _ in range(framerate): - path.move_metro(metro, dt_ms) + path.move_plane(plane, dt_ms) - self.assertTrue(path.stations[station_idx].contains(metro.position)) + self.assertTrue(path.airports[airport_idx].contains(plane.position)) if __name__ == "__main__": diff --git a/test/test_station.py b/test/test_station.py index 17042238..e770caff 100644 --- a/test/test_station.py +++ b/test/test_station.py @@ -4,20 +4,20 @@ sys.path.append(os.path.dirname(os.path.realpath(__file__)) + "/../src") -from entity.station import Station -from utils import get_random_position, get_random_station_shape +from entity.airport import airport +from utils import get_random_position, get_random_airport_shape -class TestStation(unittest.TestCase): +class Testairport(unittest.TestCase): def setUp(self) -> None: self.position = get_random_position(width=100, height=100) - self.shape = get_random_station_shape() + self.shape = get_random_airport_shape() def test_init(self): - station = Station(self.shape, self.position) + airport = airport(self.shape, self.position) - self.assertEqual(station.shape, self.shape) - self.assertEqual(station.position, self.position) + self.assertEqual(airport.shape, self.shape) + self.assertEqual(airport.position, self.position) if __name__ == "__main__":