Source code for geographer.utils.cluster_rasters

"""Cluster rasters.

Given a dataset and an optional list of rasters, partition the rasters
into equivalence classes ('clusters') that need to be respected when
generating the train-validation split.
"""

from __future__ import annotations

import itertools
from pathlib import Path
from typing import Any, Literal, Tuple

import networkx as nx
import pandas as pd
from geopandas import GeoDataFrame
from networkx import Graph

from geographer import Connector
from geographer.utils.utils import deepcopy_gdf


def get_raster_clusters(
    connector: Connector | Path | str,
    clusters_defined_by: Literal[
        "rasters_that_share_vectors",
        "rasters_that_share_vectors_or_overlap",
    ],
    raster_names: list[str] | None = None,
    preclustering_method: Literal["x then y-axis", "y then x-axis", "x-axis", "y-axis"]
    | None = "y then x-axis",
) -> list[set[str]]:
    """Partition rasters into clusters.

    A cluster is a connected component of the graph whose nodes are the
    rasters and whose edges are given by the ``clusters_defined_by``
    relation. An optional pre-clustering along the x- and/or y-axis first
    splits the rasters into groups; the pairwise raster graph is then only
    built within each non-singleton group.

    Args:
        connector: connector, or path/str to a data dir containing a connector
        clusters_defined_by: relation between rasters defining the clusters,
            either "rasters_that_share_vectors" or
            "rasters_that_share_vectors_or_overlap"
        raster_names: optional list of raster names. Defaults to all rasters
            in the connector.
        preclustering_method: optional pre-clustering method used to speed up
            the clustering: "x then y-axis", "y then x-axis", "x-axis",
            "y-axis", or None for no pre-clustering

    Returns:
        list of clusters, each a set of raster names

    Raises:
        ValueError: if clusters_defined_by or preclustering_method is unknown
    """
    if clusters_defined_by not in {
        "rasters_that_share_vectors",
        "rasters_that_share_vectors_or_overlap",
    }:
        raise ValueError(f"Unknown clusters_defined_by arg: {clusters_defined_by}")

    if not isinstance(connector, Connector):
        connector = Connector.from_data_dir(connector)

    if raster_names is None:
        raster_names = connector.rasters.index.tolist()

    if preclustering_method is None:
        # No pre-clustering: all rasters form a single group.
        singletons = []
        non_singletons = [set(raster_names)]
    elif preclustering_method in {"x-axis", "y-axis"}:
        geoms = _get_preclustering_geoms(connector=connector, raster_names=raster_names)
        # The first character of the method name is the axis ('x' or 'y').
        singletons, non_singletons = _separate_non_singletons(
            _pre_cluster_along_axis(geoms, preclustering_method[0])
        )
    elif preclustering_method in {"x then y-axis", "y then x-axis"}:
        first_axis = preclustering_method[0]
        second_axis = {"x": "y", "y": "x"}[first_axis]
        geoms = _get_preclustering_geoms(connector=connector, raster_names=raster_names)
        singletons, non_singletons = _refine_preclustering_along_second_axis(
            _pre_cluster_along_axis(geoms, first_axis), second_axis, connector
        )
    else:
        raise ValueError(f"Unknown preclustering_method: {preclustering_method}")

    # Singleton groups already are clusters. Every other group is split into
    # the connected components of its raster graph.
    raster_clusters = list(singletons)
    for group in non_singletons:
        raster_graph = _extract_graph_of_rasters(
            connector=connector,
            clusters_defined_by=clusters_defined_by,
            raster_names=list(group),
        )
        raster_clusters.extend(nx.connected_components(raster_graph))
    return raster_clusters
def _refine_preclustering_along_second_axis(
    preclusters: list[set[str]], second_axis: Literal["x", "y"], connector: Connector
) -> Tuple[list[set[str]], list[set[str]]]:
    """Refine a pre-clustering by pre-clustering each pre-cluster along another axis.

    Singleton pre-clusters cannot be split further and are passed through
    unchanged. Every other pre-cluster is pre-clustered again along
    ``second_axis``.

    Args:
        preclusters: pre-clusters (sets of raster names)
        second_axis: name of the axis ('x' or 'y') along which to refine
        connector: connector containing the rasters and vectors

    Returns:
        tuple (singletons, non_singletons) of refined pre-clusters
    """
    singletons: list[set[str]] = []
    refined_preclusters: list[set[str]] = []
    for precluster in preclusters:
        if len(precluster) == 1:
            singletons.append(precluster)
            continue
        precluster_geoms = _get_preclustering_geoms(
            connector=connector, raster_names=list(precluster)
        )
        refined_preclusters.extend(
            _pre_cluster_along_axis(precluster_geoms, second_axis)
        )

    additional_singletons, non_singletons = _separate_non_singletons(
        refined_preclusters
    )
    singletons.extend(additional_singletons)
    return singletons, non_singletons


def _get_preclustering_geoms(
    connector: Connector, raster_names: list[str]
) -> GeoDataFrame:
    """Return the geometries used to pre-cluster the given rasters.

    The result contains two kinds of rows:
    - the geometries of the rasters in ``raster_names``;
    - the geometries of those vectors for which at least two rasters
      intersect the vector without containing it, at least one of which is
      in ``raster_names``.

    Such vectors can link rasters whose footprints are disjoint, so their
    extents must be taken into account when pre-clustering.

    Args:
        connector: connector containing the rasters and vectors
        raster_names: names of the rasters to pre-cluster

    Returns:
        GeoDataFrame with columns 'geometry', 'name' (raster or vector name),
        'raster_or_polygon' ('raster' or 'polygon') and the bounds columns
        'minx', 'miny', 'maxx', 'maxy'

    Raises:
        ValueError: if a selected vector has the same name as a raster
    """
    rasters = deepcopy_gdf(connector.rasters[["geometry"]].loc[raster_names])
    rasters["name"] = rasters.index
    rasters["raster_or_polygon"] = "raster"

    raster_names_set = set(raster_names)

    # A raster A that contains a vector which also intersects raster B
    # intersects B itself, so A's footprint already links it to B. Only rasters
    # that intersect a vector without containing it need the vector as a link.
    polygons_overlapping_rasters = []
    for polygon_name in connector.vectors.index:
        rasters_intersect_but_dont_contain_polygon = set(
            connector.rasters_intersecting_vector(polygon_name)
        ) - set(connector.rasters_containing_vector(polygon_name))
        if len(
            rasters_intersect_but_dont_contain_polygon
        ) >= 2 and not raster_names_set.isdisjoint(
            rasters_intersect_but_dont_contain_polygon
        ):
            polygons_overlapping_rasters.append(polygon_name)

    vectors = deepcopy_gdf(
        connector.vectors.loc[polygons_overlapping_rasters][["geometry"]]
    )
    vectors["name"] = vectors.index
    vectors["raster_or_polygon"] = "polygon"

    # Rows are identified by name, so vector and raster names must not collide.
    name_collisions = set(vectors["name"]) & set(rasters["name"])
    if name_collisions:
        raise ValueError(
            "Vector and raster names must be distinct, but these names are "
            f"used for both: {name_collisions}"
        )

    if vectors.empty:
        # Avoid concatenating an empty frame.
        geoms = rasters
    else:
        geoms = GeoDataFrame(
            pd.concat([rasters, vectors]), crs=rasters.crs, geometry="geometry"
        )

    # TODO: don't recompute the bounds when we cluster along 2 axes
    return GeoDataFrame(
        pd.concat([geoms, geoms.geometry.bounds], axis=1),  # column axis
        crs=geoms.crs,
        geometry="geometry",
    )


def _separate_non_singletons(
    preclusters: list[set[Any]],
) -> tuple[list[set[Any]], list[set[Any]]]:
    """Split pre-clusters into singletons and non-singletons.

    Both output lists preserve the input order. Empty pre-clusters count as
    non-singletons.

    Returns:
        tuple (singletons, non_singletons)
    """
    singletons: list[set[Any]] = []
    non_singletons: list[set[Any]] = []
    for precluster in preclusters:
        if len(precluster) == 1:
            singletons.append(precluster)
        else:
            non_singletons.append(precluster)
    return singletons, non_singletons


def _pre_cluster_along_axis(
    geoms: GeoDataFrame, axis: Literal["x", "y"]
) -> list[set[str]]:
    """Pre-cluster rasters by chaining their extents along one axis.

    Each row of ``geoms`` defines the closed interval
    ``[min{axis}, max{axis}]``. Rows whose intervals overlap or touch,
    directly or through a chain of other rows, land in the same pre-cluster.
    Rows with ``raster_or_polygon == 'polygon'`` can link rasters, but only
    raster names are put into the pre-clusters.

    Args:
        geoms: frame with columns 'name', 'raster_or_polygon', 'min{axis}'
            and 'max{axis}'
        axis: 'x' or 'y'

    Returns:
        list of pre-clusters (sets of raster names), ordered by descending
        interval position along the axis

    Raises:
        ValueError: if axis is not 'x' or 'y', or if any bound along the axis
            is missing (NaN)
    """
    if axis not in {"x", "y"}:
        raise ValueError("axis arg should be one of 'x', 'y'.")

    # NaN bounds cannot be ordered and would silently corrupt the sweep below.
    if geoms[[f"min{axis}", f"max{axis}"]].isna().to_numpy().any():
        raise ValueError(f"Cannot pre-cluster geometries with missing {axis}-bounds.")

    interval_endpoints = []
    for endpoint_type in ("min", "max"):
        endpoints = geoms[["name", f"{endpoint_type}{axis}", "raster_or_polygon"]].copy()
        endpoints["type"] = endpoint_type
        endpoints = endpoints.rename(columns={f"{endpoint_type}{axis}": "value"})
        interval_endpoints += endpoints.to_dict(orient="records")

    # Sort ascending by value. At equal values, mins sort before maxes. The
    # sweep below pops endpoints from the right, so at a shared value a max is
    # processed before a min. Intervals that merely touch therefore end up in
    # the same pre-cluster. This is intentional: touching geometries count as
    # intersecting, and a pre-clustering must never separate rasters that
    # could be connected.
    interval_endpoints.sort(
        key=lambda endpoint: (endpoint["value"], 0 if endpoint["type"] == "min" else 1)
    )

    raster_clusters_along_axis = []
    while interval_endpoints:
        # The largest remaining endpoint always closes an interval.
        rightmost_endpoint = interval_endpoints.pop()
        assert rightmost_endpoint["type"] == "max"
        current_cluster = set()
        # Number of intervals the sweep is currently inside of. The rightmost
        # interval's name is added when its min endpoint is reached.
        open_intervals = 1
        while open_intervals > 0:
            endpoint = interval_endpoints.pop()
            if endpoint["raster_or_polygon"] == "raster":
                current_cluster.add(endpoint["name"])
            if endpoint["type"] == "max":
                open_intervals += 1
            elif endpoint["type"] == "min":
                open_intervals -= 1
            else:
                raise RuntimeError(f"Unexpected endpoint type: {endpoint['type']}")
        raster_clusters_along_axis.append(current_cluster)
    return raster_clusters_along_axis


def _extract_graph_of_rasters(
    connector: Connector,
    clusters_defined_by: str,
    raster_names: list[str] | None = None,
) -> Graph:
    """Extract graph of rasters determined by clusters_defined_by.

    Nodes are raster names. Two rasters are joined by an edge if
    ``_are_connected_by_an_edge`` holds for them.

    Args:
        connector: connector containing the rasters and vectors
        clusters_defined_by: relation between rasters defining the edges
        raster_names: rasters to use as nodes. Defaults to all rasters in the
            connector.

    Returns:
        graph of rasters
    """
    if raster_names is None:
        raster_names = connector.rasters.index.tolist()

    raster_graph = Graph()
    raster_graph.add_nodes_from(raster_names)
    raster_graph.add_edges_from(
        (raster, other_raster)
        for raster, other_raster in itertools.combinations(raster_names, 2)
        if _are_connected_by_an_edge(
            raster,
            other_raster,
            clusters_defined_by=clusters_defined_by,
            connector=connector,
        )
    )
    return raster_graph


def _are_connected_by_an_edge(
    raster: str,
    another_raster: str,
    clusters_defined_by: str,
    connector: Connector,
) -> bool:
    """Return True if rasters are connected, else False.

    Return True if there is an edge in the graph of rasters determined by the
    clusters_defined_by relation, else return False.

    Args:
        raster: name of a raster
        another_raster: name of another raster
        clusters_defined_by: one of "rasters_that_overlap",
            "rasters_that_share_vectors",
            "rasters_that_share_vectors_or_overlap"
        connector: connector containing the rasters and vectors

    Raises:
        ValueError: if clusters_defined_by is unknown
    """
    if clusters_defined_by == "rasters_that_overlap":
        raster_bbox = connector.rasters.loc[raster].geometry
        other_raster_bbox = connector.rasters.loc[another_raster].geometry
        return raster_bbox.intersects(other_raster_bbox)

    if clusters_defined_by == "rasters_that_share_vectors":
        vectors_in_raster = set(connector.vectors_intersecting_raster(raster))
        return not vectors_in_raster.isdisjoint(
            connector.vectors_intersecting_raster(another_raster)
        )

    if clusters_defined_by == "rasters_that_share_vectors_or_overlap":
        # Short-circuit: shared vectors are only looked up if the rasters
        # don't overlap.
        return _are_connected_by_an_edge(
            raster, another_raster, "rasters_that_overlap", connector
        ) or _are_connected_by_an_edge(
            raster, another_raster, "rasters_that_share_vectors", connector
        )

    raise ValueError(f"Unknown clusters_defined_by arg: {clusters_defined_by}")