python gdal + skimage实现基于遥感影像的传统图像分割及合并外加矢量化

Karli ·
更新时间:2024-11-10
· 671 次阅读

根据我前述博客中对传统图像分割算法及图像块合并方法的实验探究,在此将这些方法应用于遥感影像,并尝试进行矢量化。
这个过程中我自己遇到了一个棘手的问题,在最后的结果那里有描述,希望知道的朋友帮忙解答一下,谢谢!
直接上代码:

# -*- coding: utf-8 -*-
"""Traditional segmentation of a remote-sensing image with GDAL + scikit-image.

Pipeline: quickshift over-segmentation -> hierarchical merge on a
mean-colour region adjacency graph (RAG) -> boundary rasters (marked image,
boundary grid, skeleton) -> polygonization of the merged label raster.
"""
import os

import cv2
import numpy as np
# `from osgeo import gdal` replaces the legacy top-level `import gdal`; the
# osgeo package is required anyway for ogr/osr.
from osgeo import gdal, ogr, osr
from skimage import color, morphology
from skimage.segmentation import felzenszwalb, slic, quickshift
from skimage.segmentation import find_boundaries, mark_boundaries
from skimage.util import img_as_float

# The RAG API (rag_mean_color, merge_hierarchical) lives in skimage.graph in
# recent scikit-image; older releases only expose it via skimage.future.graph.
from skimage import graph

if not hasattr(graph, 'rag_mean_color'):
    from skimage.future import graph


def read_img(filename):
    """Read a raster with GDAL.

    Returns (width, height, projection_wkt, geotransform, data), where data
    is the array returned by ``Dataset.ReadAsArray`` -- (bands, rows, cols)
    for a multi-band raster, (rows, cols) for a single band.

    Raises IOError if GDAL cannot open *filename*.
    """
    dataset = gdal.Open(filename)
    if dataset is None:
        raise IOError("Cannot open raster %s" % filename)
    im_width = dataset.RasterXSize
    im_height = dataset.RasterYSize
    im_geotrans = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()
    im_data = dataset.ReadAsArray(0, 0, im_width, im_height)
    del dataset  # release the GDAL handle
    return im_width, im_height, im_proj, im_geotrans, im_data


def _gdal_datatype(array):
    """Return the GDAL data type used to store *array* in a GeoTIFF."""
    name = array.dtype.name
    if name in ('bool', 'uint8', 'int8'):
        # int8 stays Byte as before (GDT_Int8 only exists in newer GDAL).
        return gdal.GDT_Byte
    integer_types = {
        'uint16': gdal.GDT_UInt16,
        'int16': gdal.GDT_Int16,  # previously UInt16, which broke negatives
        'uint32': gdal.GDT_UInt32,
        'int32': gdal.GDT_Int32,
    }
    # Floats and 64-bit integers (e.g. segment labels) fall back to Float32,
    # which holds integers exactly only up to 2**24.
    return integer_types.get(name, gdal.GDT_Float32)


def write_img(filename, im_proj, im_geotrans, im_data):
    """Write *im_data* to a GeoTIFF with the given projection/geotransform.

    im_data is (bands, rows, cols), or (rows, cols) for a single band.
    Raises IOError if the output file cannot be created.
    """
    datatype = _gdal_datatype(im_data)
    if im_data.ndim == 3:
        im_bands, im_height, im_width = im_data.shape
        bands = im_data
    else:
        im_bands, (im_height, im_width) = 1, im_data.shape
        bands = im_data[np.newaxis, ...]
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
    if dataset is None:
        raise IOError("Cannot create raster %s" % filename)
    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)
    # Iterating a (bands, rows, cols) stack also covers a 3-D single band.
    for index, band in enumerate(bands, start=1):
        dataset.GetRasterBand(index).WriteArray(band)
    dataset.FlushCache()
    del dataset  # closing the dataset finalizes the file


def DoesDriverHandleExtension(drv, ext):
    """Return True if *drv* lists *ext* among its DMD_EXTENSIONS.

    Extensions are compared as whole space-separated tokens, so e.g. "sql"
    does not match a driver that only lists "sqlite".
    """
    exts = drv.GetMetadataItem(gdal.DMD_EXTENSIONS)
    return exts is not None and ext.lower() in exts.lower().split()


def GetExtension(filename):
    """Return the file extension of *filename* without the leading dot."""
    ext = os.path.splitext(filename)[1]
    if ext.startswith('.'):
        ext = ext[1:]
    return ext


def GetOutputDriversFor(filename):
    """Return short names of vector drivers able to create *filename*.

    A driver qualifies when it advertises vector support plus Create or
    CreateCopy, and either handles the file extension or (when it does not)
    *filename* starts with the driver's connection prefix.
    """
    drv_list = []
    ext = GetExtension(filename)
    for i in range(gdal.GetDriverCount()):
        drv = gdal.GetDriver(i)
        can_create = (drv.GetMetadataItem(gdal.DCAP_CREATE) is not None
                      or drv.GetMetadataItem(gdal.DCAP_CREATECOPY) is not None)
        if not can_create or drv.GetMetadataItem(gdal.DCAP_VECTOR) is None:
            continue
        if ext and DoesDriverHandleExtension(drv, ext):
            drv_list.append(drv.ShortName)
        else:
            prefix = drv.GetMetadataItem(gdal.DMD_CONNECTION_PREFIX)
            if prefix is not None and filename.lower().startswith(prefix.lower()):
                drv_list.append(drv.ShortName)
    return drv_list


def GetOutputDriverFor(filename):
    """Pick one vector driver short name for *filename*.

    Falls back to 'ESRI Shapefile' when *filename* has no extension, raises
    Exception when an extension matches no driver, and reports (then uses
    the first) when several drivers match.
    """
    drv_list = GetOutputDriversFor(filename)
    ext = GetExtension(filename)
    if not drv_list:
        if not ext:
            return 'ESRI Shapefile'
        raise Exception("Cannot guess driver for %s" % filename)
    elif len(drv_list) > 1:
        print("Several drivers matching %s extension. Using %s"
              % (ext if ext else '', drv_list[0]))
    return drv_list[0]


def _weight_mean_color(graph, src, dst, n):
    """Callback to handle merging nodes by recomputing mean color.

    The method expects that the mean color of `dst` is already computed.

    Parameters
    ----------
    graph : RAG
        The graph under consideration.
    src, dst : int
        The vertices in `graph` to be merged.
    n : int
        A neighbor of `src` or `dst` or both.

    Returns
    -------
    data : dict
        A dictionary with the `"weight"` attribute set as the absolute
        difference of the mean color between node `dst` and `n`.
    """
    diff = graph.nodes[dst]['mean color'] - graph.nodes[n]['mean color']
    diff = np.linalg.norm(diff)
    return {'weight': diff}


def merge_mean_color(graph, src, dst):
    """Callback called before merging two nodes of a mean color distance graph.

    This method computes the mean color of `dst`.

    Parameters
    ----------
    graph : RAG
        The graph under consideration.
    src, dst : int
        The vertices in `graph` to be merged.
    """
    graph.nodes[dst]['total color'] += graph.nodes[src]['total color']
    graph.nodes[dst]['pixel count'] += graph.nodes[src]['pixel count']
    graph.nodes[dst]['mean color'] = (graph.nodes[dst]['total color'] /
                                      graph.nodes[dst]['pixel count'])


def _save_boundary_products(image, labels, out_dir, suffix, im_proj, im_geotrans):
    """Write qs_seg<suffix>, qs_grid<suffix> and qs_border<suffix> GeoTIFFs.

    *image* (x, y, band) and *labels* (x, y) use the script's transposed
    layout; every output is transposed back to GDAL (band, row, col) order.
    """
    marked = mark_boundaries(image, labels)
    write_img(os.path.join(out_dir, "qs_seg%s.tif" % suffix),
              im_proj, im_geotrans, marked.transpose((2, 1, 0)))

    # Exact boundary mask: the same find_boundaries call mark_boundaries makes
    # internally. Re-reading it from the marked red channel also flagged
    # saturated pixels and broke on float rasters with values above 1.
    grid = find_boundaries(labels, mode='outer').T.astype(np.uint8)
    write_img(os.path.join(out_dir, "qs_grid%s.tif" % suffix),
              im_proj, im_geotrans, grid)

    # Boundaries are several pixels wide; keep only their centre line.
    skeleton = morphology.skeletonize(grid)
    border = np.multiply(grid, skeleton)
    # Otsu on a {0, 1} image simply re-binarizes it to {0, 1} (uint8).
    _, border = cv2.threshold(border, 0, 1, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    write_img(os.path.join(out_dir, "qs_border%s.tif" % suffix),
              im_proj, im_geotrans, border)


def _polygonize_labels(raster_path, dst_filename, layer_name='out', field_name='DN'):
    """Polygonize band 1 of *raster_path* into a new vector dataset.

    Each polygon stores its pixel value in integer field *field_name*. An
    existing *dst_filename* is deleted first so the script can be re-run.
    Raises IOError/RuntimeError when opening, creating or polygonizing fails.
    """
    src_ds = gdal.Open(raster_path)
    if src_ds is None:
        raise IOError("Cannot open raster %s" % raster_path)
    src_band = src_ds.GetRasterBand(1)
    mask_band = src_band.GetMaskBand()

    drv = ogr.GetDriverByName(GetOutputDriverFor(dst_filename))
    if os.path.exists(dst_filename):
        drv.DeleteDataSource(dst_filename)
    dst_ds = drv.CreateDataSource(dst_filename)
    if dst_ds is None:
        raise IOError("Cannot create vector dataset %s" % dst_filename)

    srs = osr.SpatialReference(wkt=src_ds.GetProjection())
    # ogr.wkbPolygon (= 3); a line layer would use ogr.wkbLineString (= 2).
    dst_layer = dst_ds.CreateLayer(layer_name, geom_type=ogr.wkbPolygon, srs=srs)
    dst_layer.CreateField(ogr.FieldDefn(field_name, ogr.OFTInteger))

    options = ['DATASET_FOR_GEOREF=' + raster_path]
    result = gdal.Polygonize(src_band, mask_band, dst_layer, 0, options,
                             callback=gdal.TermProgress_nocb)

    # Drop references child-first so OGR flushes the output and GDAL closes
    # the source raster.
    dst_layer = None
    dst_ds = None
    mask_band = None
    src_band = None
    src_ds = None
    if result != 0:
        raise RuntimeError("gdal.Polygonize failed for %s" % raster_path)


if __name__ == '__main__':
    img_path = "E:/geo_test/test.tif"
    temp_path = "E:/geo_test/temp/"
    if not os.path.isdir(temp_path):
        os.makedirs(temp_path)

    im_width, im_height, im_proj, im_geotrans, im_data = read_img(img_path)
    # (band, row, col) -> (col, row, band): channels-last for scikit-image.
    # This also swaps x/y; every result is transposed back before writing.
    # Assumes a multi-band (RGB-like) image -- quickshift needs 3 channels.
    temp = im_data.transpose((2, 1, 0))

    # Coarse over-segmentation.
    segments_quick = quickshift(temp, kernel_size=3, max_dist=6, ratio=0.5)
    _save_boundary_products(temp, segments_quick, temp_path, "0",
                            im_proj, im_geotrans)

    # Merge adjacent segments whose mean colours are closer than `thresh`.
    g = graph.rag_mean_color(temp, segments_quick)
    labels2 = graph.merge_hierarchical(segments_quick, g, thresh=5,
                                       rag_copy=False, in_place_merge=True,
                                       merge_func=merge_mean_color,
                                       weight_func=_weight_mean_color)
    # bg_label=-1: label 0 is a real segment. Recent scikit-image defaults to
    # bg_label=0, which would paint that segment with bg_color.
    label_rgb2 = color.label2rgb(labels2, temp, kind='avg', bg_label=-1)

    rgb_path = os.path.join(temp_path, "qs_label.tif")
    write_img(rgb_path, im_proj, im_geotrans, labels2.transpose((1, 0)))
    _save_boundary_products(label_rgb2, labels2, temp_path, "",
                            im_proj, im_geotrans)

    # Vectorize the merged label map as polygons (one per segment).
    _polygonize_labels(rgb_path, os.path.join(temp_path, 'temp.shp'))

对应的结果图如下:
原图:
原图
粗分割结果(代码中的qs_seg0.tif)
粗分割结果
粗分割格网(代码中的qs_grid0.tif)
粗分割格网
粗分割格网骨架(代码中的qs_border0.tif):由于格网线并非单像素宽的单线,这里提取了其中心线(骨架)。
粗分割格网骨架
合并后的分割结果(代码中的qs_seg.tif):
合并后的粗分割结果
合并后的格网结果(代码中的qs_grid.tif)
合并后的格网结果
合并后的格网骨架结果(代码中的qs_border.tif):
合并后的格网骨架结果
下面是矢量化后的最终结果,由代码中的qs_label.tif矢量化得到。这里说明一下:之所以没有直接把栅格线转为矢量线,是因为我在GDAL中没有找到直接实现的方法。用目前的方法强行转换,只能得到双线,结果完全不对。找了很久也没有找到解决办法,只能折中处理:先转出矢量面,之后再由面转线。如果有朋友知道如何在不依赖ArcGIS的前提下,把栅格线直接转为矢量线,烦请告知,谢谢!
矢量化以后的结果

TO DO:
1.矢量面转线
2.线简化
3.线平滑
做完更新,感兴趣的朋友可以关注一下。


作者:如雾如电



遥感影像 矢量 遥感 图像分割 矢量化 Python

需要 登录 后方可回复, 如果你还没有账号请 注册新账号