Source code for geographer.label_makers.seg_label_maker_soft_categorical
"""Label maker for soft-categorical segmentation labels.Soft-categorical are probabilistic multi-class labels."""importloggingimportnumpyasnpimportrasterioasriofromrasterio.featuresimportrasterizefromgeographer.connectorimportConnectorfromgeographer.label_makers.seg_label_maker_baseimportSegLabelMakerfromgeographer.utils.utilsimporttransform_shapely_geometrylog=logging.getLogger(__name__)
[docs]classSegLabelMakerSoftCategorical(SegLabelMaker):"""Label maker for soft-categorical segmentation labels. Soft-categorical are probabilistic multi-class labels. Assumes the connector's vectors contains for each segmentation class a "prob_seg_class<seg_class>" column containing the probabilities for that class. """add_background_band:bool@propertydeflabel_type(self):"""Return label_type."""return"soft-categorical"def_make_label_for_raster(self,connector:Connector,raster_name:str,)->None:"""Create (pixel) label for a raster. Args: connector : calling Connector raster_name: name of raster for which a label should be created Returns: None: """# pathsraster_path=connector.rasters_dir/raster_namelabel_path=connector.labels_dir/raster_name# If the raster does not exist ...ifnotraster_path.is_file():# ... log error to file.log.error("_make_geotif_label_soft_categorical: input raster %s does not exist!",raster_path,)# Else, if the label already exists ...eliflabel_path.is_file():# ... log error to file.log.error("_make_geotif_label_soft_categorical: label %s already exists!",label_path,)# Else, ...else:label_bands_count=self._get_label_bands_count(connector)# ...open the raster, ...withrio.open(raster_path)assrc:# Create profile for the label.profile=src.profileprofile.update({"count":label_bands_count,"dtype":rio.float32})# Open the label ...withrio.open(label_path,"w+",**profile)asdst:# ... and create one band in the label for each segmentation class.# (if an implicit background band is to be included,# it will go in band/channel 1.)start_band=1ifnotself.add_background_bandelse2forcount,seg_classinenumerate(connector.task_vector_classes,start=start_band):# To do that, first find (the df of)# the geoms intersecting the raster ...vectors_intersecting_raster_df=connector.vectors.loc[connector.vectors_intersecting_raster(raster_name)]# ... extract the geometries ...vector_geoms_in_std_crs=list(vectors_intersecting_raster_df["geometry"])# ... and convert them to the crs of the source raster.vector_geoms_in_src_crs=list(map(lambdageom:transform_shapely_geometry(geom,connector.vectors.crs.to_epsg(),src.crs.to_epsg(),),vector_geoms_in_std_crs,))# Extract the class probabilities ...class_probabilities=list(vectors_intersecting_raster_df[f"prob_of_class_{seg_class}"])# .. and combine with the geometries# to a list of (geometry, value) pairs.geom_value_pairs=list(zip(vector_geoms_in_src_crs,class_probabilities))# If there are no geoms of seg_type intersecting the raster ...iflen(vector_geoms_in_src_crs)==0:# ... the label raster is empty.mask=np.zeros((src.height,src.width),dtype=np.uint8)# Else, burn the values for those geoms into the band.else:mask=rasterize(shapes=geom_value_pairs,# or the other way around?out_shape=(src.height,src.width),fill=0.0,#transform=src.transform,dtype=rio.float32,)# Write the band to the label file.dst.write(mask,count)# If the background is not included in the segmentation classes ...ifself.add_background_band:# ... add background band.non_background_band_indices=list(range(start_band,2+len(connector.task_vector_classes),))# The probability of a pixel belonging to# the background is the complement of it# belonging to some segmentation class.background_band=1-np.add.reduce([dst.read(band_index)forband_indexinnon_background_band_indices])dst.write(background_band,1)def_get_label_bands_count(self,connector:Connector)->bool:# If the background is not included in the segmentation classes (default) ...ifself.add_background_band:# ... add a band for the implicit background segmentation class, ...label_bands_count=1+len(connector.task_vector_classes)# ... if the background *is* included, ...elifnotself.add_background_band:# ... don't.label_bands_count=len(connector.task_vector_classes)returnlabel_bands_countdef_run_safety_checks(self,connector:Connector):"""Run safety checks. Check existence of 'prob_of_class_<class name>' columns in connector.vectors. """# check required columns existrequired_cols={f"prob_of_class_{class_}"forclass_inconnector.all_vector_classes}ifnotset(required_cols)<=set(connector.vectors.columns):missing_cols=set(required_cols)-set(connector.vectors.columns)raiseValueError("connector.vectors.columns is missing required columns: "f"{', '.join(missing_cols)}")# check no other columns will be mistaken forvector_classes_in_vectors={col_name[(1+len("prob_of_class")):]# noqa: E203forcol_nameinconnector.vectors.columnsifcol_name.startswith("prob_of_class_")}ifnotvector_classes_in_vectors<=set(connector.all_vector_classes):log.warning("Ignoring columns: %s. The corresponding classes ""are not in connector.all_vector_classes",vector_classes_in_vectors-set(connector.all_vector_classes),)