"""Convert a dataset of GeoTiffs to NPYs."""

import logging
from typing import Literal

import numpy as np
import rasterio as rio
from pydantic import Field
from tqdm.auto import tqdm

from geographer.creator_from_source_dataset_base import DSCreatorFromSourceWithBands
from geographer.raster_bands_getter_mixin import RasterBandsGetterMixIn

# Module logger. Obtained via getLogger (not by instantiating logging.Logger
# directly) so that it is registered in the logger hierarchy and honours the
# application's logging configuration (handlers, levels, propagation).
log = logging.getLogger(__name__)
class DSConverterGeoTiffToNpy(DSCreatorFromSourceWithBands, RasterBandsGetterMixIn):
    """Convert a dataset of GeoTiffs to NPYs.

    For every raster data dir of the source dataset (the rasters dir and the
    labels dir), each GeoTiff is read band by band, the bands are stacked
    into a single numpy array and the array is saved as ``<name>.npy`` in
    the corresponding target dir. Target labels whose raster intersects
    newly added vector geometries are deleted beforehand so that they get
    re-converted from the source.
    """

    squeeze_label_channel_dim_if_single_channel: bool = Field(
        default=True,
        description="whether to squeeze the label channel dim/axis if possible",
    )
    channels_first_or_last_in_npy: Literal["last", "first"] = Field(
        default="last",
        description="Ignoring squeezing: 'last' -> (height, width, channels), "
        "'first' -> (channels, height, width).",
    )

    def _create(self):
        """Create the target npy dataset from the source tif dataset."""
        self._create_or_update()

    def _update(self):
        """Update the target npy dataset from the source tif dataset."""
        self._create_or_update()

    def _create_or_update(self):
        """Create or update the target npy dataset.

        Steps (in this order):
            1. register the source rasters (renamed to .npy) and the source
               vector geometries in the target associator,
            2. delete target labels that must be regenerated because a new
               geometry intersects their raster,
            3. convert every source tif (rasters and labels) whose .npy
               counterpart does not exist yet,
            4. save the target associator and this creator.

        Returns:
            The target associator ``self.target_assoc``.
        """
        # Must be computed before the target geoms_df is extended below.
        new_geom_ids = set(self.source_assoc.geoms_df.index) - set(
            self.target_assoc.geoms_df.index
        )

        # Build the npy associator.
        self.target_assoc.add_to_rasters(self._get_npy_rasters())
        self.target_assoc.add_to_geoms_df(self.source_assoc.geoms_df)

        self._delete_labels_intersecting_new_geoms(new_geom_ids)

        # Pairs up the source (tif) and target (npy) raster data dirs,
        # i.e. rasters dir with rasters dir and labels dir with labels dir.
        for tif_dir, npy_dir in zip(
            self.source_assoc.raster_data_dirs, self.target_assoc.raster_data_dirs
        ):
            self._convert_tifs_in_dir(tif_dir, npy_dir)

        self.target_assoc.save()
        self.save()

        return self.target_assoc

    def _delete_labels_intersecting_new_geoms(self, new_geom_ids: set) -> None:
        """Delete target labels of existing rasters touched by new geometries.

        Args:
            new_geom_ids: ids of the (vector) geometries that are being added
                to the target dataset.

        A deleted label has no .npy file any more, so the conversion step
        re-creates it from the source label.
        """
        existing_raster_names = {
            raster_path.name for raster_path in self.target_assoc.rasters_dir.iterdir()
        }
        for raster_name in existing_raster_names:
            intersecting_geom_ids = self.target_assoc.geoms_intersecting_raster(
                raster_name
            )
            # A *new* geometry intersects this raster -> its label is stale.
            if not new_geom_ids.isdisjoint(intersecting_geom_ids):
                (self.target_assoc.labels_dir / raster_name).unlink(missing_ok=True)

    def _convert_tifs_in_dir(self, tif_dir, npy_dir) -> None:
        """Convert each source tif in ``tif_dir`` lacking an .npy in ``npy_dir``.

        Args:
            tif_dir: source raster data dir containing the GeoTiffs.
            npy_dir: target raster data dir the .npy files are written to.
        """
        # Loop invariants: stacking axis and whether to squeeze.
        axis = 2 if self.channels_first_or_last_in_npy == "last" else 0
        squeeze_if_single_channel = (
            str(tif_dir.name) == "labels"
            and self.squeeze_label_channel_dim_if_single_channel
        )

        for tif_raster_name in tqdm(
            self.source_assoc.rasters.index, desc=f"Converting {tif_dir.name}"
        ):
            npy_raster_path = npy_dir / self._npy_filename_from_tif(tif_raster_name)
            if npy_raster_path.is_file():
                continue  # already converted

            tif_raster_path = tif_dir / tif_raster_name
            # Keep the file open only while reading the bands.
            with rio.open(tif_raster_path) as src:
                raster_bands = self._get_bands_for_raster(
                    self.bands,
                    tif_raster_path,
                )
                band_arrays = [src.read(band) for band in raster_bands]

            np_raster = np.stack(band_arrays, axis=axis)
            if squeeze_if_single_channel and len(raster_bands) == 1:
                np_raster = np.squeeze(np_raster, axis=axis)

            np.save(npy_raster_path, np_raster)

    def _get_npy_rasters(self):
        """Return a copy of the source rasters table indexed by .npy filenames.

        The source associator's table is copied first so that renaming the
        index does not mutate ``self.source_assoc.rasters`` (the conversion
        loop still needs the original .tif names from it).
        """
        npy_rasters = self.source_assoc.rasters.copy()
        # Assigning a plain list to ``.index`` creates a fresh, unnamed
        # Index, so remember the name and restore it afterwards.
        index_name = npy_rasters.index.name
        npy_rasters.index = [
            self._npy_filename_from_tif(tif_raster_name)
            for tif_raster_name in npy_rasters.index
        ]
        npy_rasters.index.name = index_name
        return npy_rasters

    @staticmethod
    def _npy_filename_from_tif(tif_filename: str) -> str:
        """Return .npy filename from .tif filename.

        Assumes ``tif_filename`` ends with the 4-character suffix ``.tif``;
        the last four characters are replaced by ``.npy``.
        """
        return tif_filename[:-4] + ".npy"

    @staticmethod
    def _check_rasters_are_tifs(raster_names: list[str]):
        """Make sure all rasters are GeoTiffs.

        Raises:
            ValueError: if any raster name does not end with ``.tif``.
        """
        non_tif_rasters = [name for name in raster_names if not name.endswith(".tif")]
        if non_tif_rasters:
            raise ValueError("Only works with dataset of GeoTiff rasters!")