diff --git a/04_ml.ipynb b/04_ml.ipynb index efdd023..a0f608b 100644 --- a/04_ml.ipynb +++ b/04_ml.ipynb @@ -363,7 +363,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Highway2Vec Clustering and similarity search\n", + "## Highway2Vec - Clustering and similarity search\n", "\n", "In this part we will see:\n", "" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from srai.loaders import OSMWayLoader, OSMNetworkType\n", + "from srai.regionalizers import H3Regionalizer, geocode_to_region_gdf\n", + "from srai.joiners import IntersectionJoiner\n", + "from srai.embedders import Highway2VecEmbedder\n", + "\n", + "area = geocode_to_region_gdf(\"Wrocław, Poland\")\n", + "nodes, edges = OSMWayLoader(OSMNetworkType.DRIVE).load(area)\n", + "regions = H3Regionalizer(resolution=9).transform(area) \n", + "joint = IntersectionJoiner().transform(regions, edges)\n", + "\n", + " \n", + "embedder = Highway2VecEmbedder()\n", + "embedder.fit(regions, edges, joint)\n", + "embeddings = embedder.transform(regions, edges, joint)\n", + "embeddings" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from clustering import (\n", + " scale_embeddings,\n", + " generate_clustering_model,\n", + " generate_linkage_matrix,\n", + " plot_dendrogram,\n", + " cluster_regions,\n", + " plot_clustered_regions_with_roads\n", + ")\n", + "import matplotlib.pyplot as plt\n", + "\n", + "embeddings_scaled = scale_embeddings(embeddings)\n", + "ac_model = generate_clustering_model(\n", + " embeddings_scaled,\n", + " {\n", + " \"n_clusters\": None,\n", + " \"distance_threshold\": 0,\n", + " \"metric\": \"euclidean\",\n", + " \"linkage\": \"ward\",\n", + " },\n", + ")\n", + "\n", + "linkage_matrix = generate_linkage_matrix(ac_model)\n", + "plot_dendrogram(linkage_matrix, {\"truncate_mode\": \"level\", \"p\": 3})\n", + "plt.show()\n", + "clusters = [6]\n", + "regions_clustered = cluster_regions(\n", + " linkage_matrix, embeddings, regions, clusters #[2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]\n", + ")\n", + "plot_clustered_regions_with_roads(regions_clustered, edges, area, clusters)\n" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/clustering.py b/clustering.py new file mode 100644 index 0000000..deda0c6 --- /dev/null +++ b/clustering.py @@ -0,0 +1,140 @@ +""" +This is a boilerplate pipeline 'visualizations' +generated using Kedro 0.18.7 +""" +from typing import Any, Dict, List, Tuple + +import contextily as ctx +import geopandas as gpd +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from matplotlib.figure import Figure +from scipy.cluster.hierarchy import cut_tree, dendrogram +from sklearn.cluster import AgglomerativeClustering +from sklearn.preprocessing import StandardScaler +from tqdm.auto import tqdm + + +MAP_SOURCE = ctx.providers.CartoDB.Positron +MATPLOTLIB_COLORMAP = "tab20" +PLOTLY_COLORMAP = list( + map( + lambda color: f"rgb{tuple(map(lambda color_compound: color_compound * 255, color))}", + matplotlib.colormaps[MATPLOTLIB_COLORMAP].colors, + ) +) + + +def scale_embeddings(embeddings: pd.DataFrame) -> pd.DataFrame: + return pd.DataFrame( + StandardScaler().fit_transform(embeddings), + index=embeddings.index, + columns=embeddings.columns, + ) + + +def generate_clustering_model( + embeddings: pd.DataFrame, clustering_params: Dict[str, Any] +): + model = AgglomerativeClustering( + n_clusters=clustering_params["n_clusters"], + distance_threshold=clustering_params["distance_threshold"], + metric=clustering_params["metric"], + linkage=clustering_params["linkage"], + ) + model.fit(embeddings) + + return model + + +def generate_linkage_matrix(model: AgglomerativeClustering) -> np.ndarray: + counts = np.zeros(model.children_.shape[0]) + n_samples = len(model.labels_) + for i, merge in enumerate(model.children_): + current_count = 0 + for child_idx in merge: + if child_idx < n_samples: + current_count += 1 # leaf node + else: + current_count += counts[child_idx - n_samples] + counts[i] = current_count + + linkage_matrix = np.column_stack( + [model.children_, model.distances_, counts] + ).astype(float) + + return linkage_matrix + + +def plot_dendrogram( + linkage_matrix: np.ndarray, dendrogram_params: Dict[str, Any] +) -> Figure: + fig, _ = plt.subplots(figsize=(12, 7)) + plt.xlabel("Number of microregions") + dendrogram(linkage_matrix, **dendrogram_params) + plt.tight_layout() + return fig + + +def cluster_regions( + linkage_matrix: np.ndarray, + embeddings: gpd.GeoDataFrame, + regions: gpd.GeoDataFrame, + clusters: List[int], +) -> gpd.GeoDataFrame: + regions_clustered = regions.loc[embeddings.index, :] + + cut_tree_results = cut_tree(linkage_matrix, n_clusters=clusters) + for index, c in tqdm(list(enumerate(clusters))): + assigned_clusters = cut_tree_results[:, index] + regions_clustered[f"cluster_{c}"] = pd.Series( + assigned_clusters, index=regions_clustered.index + ).astype("category") + + return regions_clustered + + +def plot_clustered_regions_with_roads( + regions_clustered: gpd.GeoDataFrame, + roads: gpd.GeoDataFrame, + area: gpd.GeoDataFrame, + clusters: List[int], +) -> Dict[str, Figure]: + plots = {} + for c in clusters: + cluster_column = f"cluster_{c}" + fig, ax = _pyplot_clustered_regions_with_roads( + regions_clustered.sjoin(area), + roads.sjoin(area), + cluster_column, + title=cluster_column, + ) + ax.set_axis_off() + plt.tight_layout() + plots[cluster_column] = fig + # plt.close() + + return plots + + +def _pyplot_clustered_regions_with_roads( + regions: gpd.GeoDataFrame, roads: gpd.GeoDataFrame, column: str, title: str = "" +) -> Tuple[Figure, plt.Axes]: + fig, ax = plt.subplots(figsize=(10, 9)) + ax.set_aspect("equal") + ax.set_title(title) + regions.to_crs(epsg=3857).plot( + column=column, + ax=ax, + alpha=0.9, + legend=True, + cmap=MATPLOTLIB_COLORMAP, + vmin=0, + vmax=len(PLOTLY_COLORMAP), + linewidth=0, + ) + roads.to_crs(epsg=3857).plot(ax=ax, color="black", alpha=0.5, linewidth=0.2) + ctx.add_basemap(ax, source=MAP_SOURCE) + return fig, ax diff --git a/requirements.txt b/requirements.txt index 0b10781..b90742c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,3 +4,7 @@ geopandas==0.13.2 notebook>=6,<7 # RISE not compatible with version 7 RISE==5.7.1 osmnx==1.6.0 +contextily==1.3.0 +scikit-learn==1.3.0 +tqdm==4.65.0 +matplotlib==3.7.2 \ No newline at end of file