车道线检测

[复制链接]
查看872 | 回复0 | 2023-8-23 12:02:45 | 显示全部楼层 |阅读模式
前言

现在,车道线检测技术已经相称成熟,重要应用在自动驾驶、智能交通等范畴。下面列举一些当下最流行的车道线检测方法:


  • 基于图像处理的车道线检测方法。该方法是通过图像处理技术从摄像头传回的图像中提取车道线信息的一种方法,重要是使用图像处理算法进行车道线的检测和辨认,并输出车道线的位置信息。
  • 基于激光雷达的车道线检测方法。该方法通过激光雷达扫描地面,获取车道线位置信息。这种方法对于在光照较弱、天气恶劣的情况下车道线能更加准确地被检测出来。
  • 基于雷达与摄像头的融合车道线检测方法。此方法是将雷达和摄像头两个传感器的数据进行融合,从而得到更加准确的车道线位置信息,检测的鲁棒性也得到了进步。
  • 基于GPS和舆图的车道线检测方法。该方法重要是使用车辆上的GPS以及舆图数据来检测车道线的位置信息。这种方法可以降服图像处理技术在某些特定情况下(好比光照不敷大概情况光线无法反射)的不敷。
以上这些方法均存在优缺点,差异方法的选择重要取决于详细的技术需求和场景应用。在该章节以基于图像处理的车道线检测方法进行介绍。分别从底子的入门级方法到深度学习的方法进行介绍。
传统图像方法
传统图像方法通过边沿检测滤波等方式分割出车道线地区,然后连合霍夫变换、RANSAC等算法进行车道线检测。这类算法需要人工手动去调滤波算子,根据算法所针对的街道场景特点手动调节参数曲线,工作量大且鲁棒性较差,当行车情况出现明显变革时,车道线的检测效果不佳。基于蹊径特性检测根据提取特性差异,分为:基于颜色特性、纹理特性、多特性融合;
比方:在车道图像中,路面与车道线交汇处的灰度值变换剧烈,使用边沿增强算子突出图像的局部边沿,界说像素的边沿强度,设置阈值方法提取边沿点;
常用的算子:Sobel算子、Prewitt算子、Log算子、Canny算子;
基于灰度特性检测布局简朴,对于路面平整、车道线清楚的布局化蹊径尤为实用;但当光照强烈、有大量异物遮挡、蹊径布局复杂、车道线较为模糊时,检测效果受到很大的影响;
使用openCV的传统算法



  • Canny边沿检测
  • 高斯滤波
  • ROI和mask
  • 霍夫变换
openCV在图片和视频干系常用的代码

  1. # 读取图片 cv2.imread第一个参数是窗口的名字,第二个参数是读取格式(彩色或灰度)
  2. cv2.imread(const String & filename, int flags = IMREAD_COLOR)
  3. #显示图片 cv2.imshow第一个参数是窗口的名字  第二个参数是显示格式,
  4. cv2.imshow(name, img)       
  5. #保持图片
  6. cv2.imwrite(newfile_name, img)
  7. #关闭所有openCV打开的窗口。
  8. cv2.destroyAllWindows()
  9. #-----------------------------#
  10. #打开视频
  11. capture = cv2.VideoCapture('video.mp4')
  12. #按帧读取视频
  13. #capture.read()有两个返回值。其中ret是布尔值,如果读取帧是正确的则返回True,如果文件读取到结尾,它的返回值就为False。frame就是每一帧的图像,是个三维矩阵。
  14. ret, frame = capture.read()
  15. #视频编码格式设置
  16. fourcc = cv2.VideoWriter_fourcc('X', 'V', 'I', 'D')
  17. """
  18. 补充:cv2.VideoWriter_fourcc(‘I’, ‘4’, ‘2’, ‘0’),该参数是YUV编码类型,文件名后缀为.avi
  19. cv2.VideoWriter_fourcc(‘P’, ‘I’, ‘M’, ‘I’),该参数是MPEG-1编码类型,文件名后缀为.avi
  20. cv2.VideoWriter_fourcc(‘X’, ‘V’, ‘I’, ‘D’),该参数是MPEG-4编码类型,文件名后缀为.avi
  21. cv2.VideoWriter_fourcc(‘T’, ‘H’, ‘E’, ‘O’),该参数是Ogg Vorbis,文件名后缀为.ogv
  22. cv2.VideoWriter_fourcc(‘F’, ‘L’, ‘V’, ‘1’),该参数是Flash视频,文件名后缀为.flv
  23. """
复制代码
高斯滤波

Canny边沿检测

有关边沿检测也是计算机视觉。起首使用梯度变革来检测图像中的边,怎样辨认图像的梯度变革呢,答案是卷积核。卷积核是就是不连续的像素上找到梯度变革较大位置。我们知道 sobal 核可以很好检测边沿,那么 canny 就是 sobal 核检测上进行优化。
CV2提供了提取图像边沿的函数canny。其算法头脑如下:

  • 使用高斯模糊,去除噪音点(cv2.GaussianBlur)
  • 灰度转换(cv2.cvtColor)
  • 使用sobel算子,计算出每个点的梯度巨细和梯度方向
  • 使用非极大值抑制(只有最大的保留),消除边沿检测带来的杂散效应
  • 应用双阈值,来确定真实和潜伏的边沿
  • 通过抑制弱边沿来完成终极的边沿检测
  1. #color_img 输入图片
  2. #gaussian_ksize 高斯核大小,可以为方形矩阵,也可以为矩形
  3. #gaussian_sigmax X方向上的高斯核标准偏差
  4. gaussian = cv2.GaussianBlur(color_img, (gaussian_ksize,gaussian_ksize), gaussian_sigmax)
  5. #用于颜色空间转换。input_image为需要转换的图片,flag为转换的类型,返回值为颜色空间转换后的图片矩阵。
  6. #flag对应:
  7. #cv2.COLOR_BGR2GRAY BGR -> Gray
  8. #cv2.COLOR_BGR2RGB BGR -> RGB
  9. #cv2.COLOR_BGR2HSV BGR -> HSV
  10. gray_img = cv2.cvtColor(input_image, flag)
复制代码
输出结果:

  1. #imag为所操作的图片,threshold1为下阈值,threshold2为上阈值,返回值为边缘图。
  2. edge_img = cv2.Canny(gray_img,canny_threshold1,canny_threshold2)
  3. #整成canny_edge_detect的方法
  4. def canny_edge_detect(img):
  5.     gray = cv2.cvtColor(img,cv2.COLOR_RGB2GRAY)
  6.     kernel_size = 5
  7.     blur_gray = cv2.GaussianBlur(gray,(kernel_size, kernel_size),0)
  8.     low_threshold = 180
  9.     high_threshold = 240
  10.     edges = cv2.Canny(blur_gray, low_threshold, high_threshold)
  11.     return edges
复制代码

ROI and mask

在机器视觉或图像处理中,通过图像获取到的信息通常是一个二维数组或矩阵,这些信息中大概包罗需要进一步处理的地区以及不需要处理的地区。为了进步图像处理的服从和准确性,通常会在需要处理的地区内界说一个感爱好的地区(ROI),并对该地区进行下一步的处理。ROI可以通过方框、圆、椭圆、不规则多边形等方式勾勒出需要处理的地区。在机器视觉软件中,常常通过图像处理算子和函数来计算ROI。好比,在OpenCV中可以使用cv::Rect、cv::RotatedRect、cv:oint等进行ROI的范例界说和计算;在Matlab中可以使用imrect、imellipse、impoly等函数实现ROI的界说和计算。
处理ROI的目的是为了便于在图像的地区中进行目的检测、物体跟踪、边沿检测、图像分割等操作。通过使用ROI,可以将不需要处理的地区从原始图像中清撤除,从而减少图像处理的复杂度和耗时,进步计算服从和准确性。

  1. #设置ROI和掩码
  2. poly_pts = numpy.array([[[0,368],[300,210],[340,210],[640,368]]])
  3. mask = np.zeros_like(gray_img)
  4. cv2.fillPoly(mask, pts, color)
  5. img_mask = cv2.bitwise_and(gray_img, mask)
复制代码
霍夫变换

霍夫变换是一种常用的图像处理算法,用于在图像中检测几何形状(如直线、圆、椭圆等)。该算法最初是由保罗·霍夫于 1962 年提出的。简朴来说,霍夫变换可以将在直角坐标系中表示的图形转换为极坐标系中的线或曲线,从而方便地进行形状的检测和辨认。所以霍夫变换实际上一种由繁到简(类似降维)的操作。在应用中,霍夫变换的过程可以分为以下几个步调:

  • 针对待检测的形状,选择相应的霍夫曼变换方法。好比,如果要检测直线,可以使用尺度霍夫变换;如果要检测圆形,可以使用圆霍夫变换。
  • 将图像转换为灰度图像,并进行边沿检测,以得到待检测的形状目的的轮廓。
  • 以肯定的步长和角度范围在霍夫空间中进行投票,将所有大概的直线或曲线与它们大概在的极坐标空间中的位置相对应。
  • 找到霍夫空间中的峰值,这些峰值表示形状的参数空间中存在原始图像中形状的大概性。
  • 通过峰值位置在原始图像中绘制直线、圆等形状。
当使用 canny 进行边沿检测后图像可以交给霍夫变换进行简朴图形(线、圆)等的辨认。这里用霍夫变换在 canny 边沿检测结果中寻找直线。
  1. # 示例代码,作者丹成学长:Q746876041
  2. mask = np.zeros_like(edges)
  3.     ignore_mask_color = 255
  4.     # 获取图片尺寸
  5.     imshape = img.shape
  6.     # 定义 mask 顶点
  7.     vertices = np.array([[(0,imshape[0]),(450, 290), (490, 290), (imshape[1],imshape[0])]], dtype=np.int32)
  8.     # 使用 fillpoly 来绘制 mask
  9.     cv2.fillPoly(mask, vertices, ignore_mask_color)
  10.     masked_edges = cv2.bitwise_and(edges, mask)
  11.     # 定义Hough 变换的参数
  12.     rho = 1
  13.     theta = np.pi/180
  14.     threshold = 2
  15.     min_line_length = 4 # 组成一条线的最小像素数
  16.     max_line_gap = 5    # 可连接线段之间的最大像素间距
  17.     # 创建一个用于绘制车道线的图片
  18.     line_image = np.copy(img)*0
  19.     # 对于 canny 边缘检测结果应用 Hough 变换
  20.     # 输出“线”是一个数组,其中包含检测到的线段的端点
  21.     lines = cv2.HoughLinesP(masked_edges, rho, theta, threshold, np.array([]),
  22.                                 min_line_length, max_line_gap)
  23.     # 遍历“线”的数组来在 line_image 上绘制
  24.     for line in lines:
  25.         for x1,y1,x2,y2 in line:
  26.             cv2.line(line_image,(x1,y1),(x2,y2),(255,0,0),10)
  27.     color_edges = np.dstack((edges, edges, edges))
  28. import math
  29. import cv2
  30. import numpy as np
  31. """
  32. Gray Scale
  33. Gaussian Smoothing
  34. Canny Edge Detection
  35. Region Masking
  36. Hough Transform
  37. Draw Lines [Mark Lane Lines with different Color]
  38. """
  39. class SimpleLaneLineDetector(object):
  40.     def __init__(self):
  41.         pass
  42.     def detect(self,img):
  43.         # 图像灰度处理
  44.         gray_img = self.grayscale(img)
  45.         print(gray_img)
  46.         #图像高斯平滑处理
  47.         smoothed_img = self.gaussian_blur(img = gray_img, kernel_size = 5)
  48.         #canny 边缘检测
  49.         canny_img = self.canny(img = smoothed_img, low_threshold = 180, high_threshold = 240)
  50.         #区域 Mask
  51.         masked_img = self.region_of_interest(img = canny_img, vertices = self.get_vertices(img))
  52.         #霍夫变换
  53.         houghed_lines = self.hough_lines(img = masked_img, rho = 1, theta = np.pi/180, threshold = 20, min_line_len = 20, max_line_gap = 180)
  54.         # 绘制车道线
  55.         output = self.weighted_img(img = houghed_lines, initial_img = img, alpha=0.8, beta=1., gamma=0.)
  56.         
  57.         return output
  58.     def grayscale(self,img):
  59.         return cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
  60.     def canny(self,img, low_threshold, high_threshold):
  61.         return cv2.Canny(img, low_threshold, high_threshold)
  62.     def gaussian_blur(self,img, kernel_size):
  63.         return cv2.GaussianBlur(img, (kernel_size, kernel_size), 0)
  64.     def region_of_interest(self,img, vertices):
  65.         mask = np.zeros_like(img)   
  66.    
  67.         if len(img.shape) > 2:
  68.             channel_count = img.shape[2]  
  69.             ignore_mask_color = (255,) * channel_count
  70.         else:
  71.             ignore_mask_color = 255
  72.             
  73.         cv2.fillPoly(mask, vertices, ignore_mask_color)
  74.         
  75.         masked_image = cv2.bitwise_and(img, mask)
  76.         return masked_image
  77.     def draw_lines(self,img, lines, color=[255, 0, 0], thickness=10):
  78.         for line in lines:
  79.             for x1,y1,x2,y2 in line:
  80.                 cv2.line(img, (x1, y1), (x2, y2), color, thickness)
  81.     def slope_lines(self,image,lines):
  82.         img = image.copy()
  83.         poly_vertices = []
  84.         order = [0,1,3,2]
  85.         left_lines = []
  86.         right_lines = []
  87.         for line in lines:
  88.             for x1,y1,x2,y2 in line:
  89.                 if x1 == x2:
  90.                     pass
  91.                 else:
  92.                     m = (y2 - y1) / (x2 - x1)
  93.                     c = y1 - m * x1
  94.                     if m < 0:
  95.                         left_lines.append((m,c))
  96.                     elif m >= 0:
  97.                         right_lines.append((m,c))
  98.         left_line = np.mean(left_lines, axis=0)
  99.         right_line = np.mean(right_lines, axis=0)
  100.         for slope, intercept in [left_line, right_line]:
  101.             rows, cols = image.shape[:2]
  102.             y1= int(rows)
  103.             y2= int(rows*0.6)
  104.             x1=int((y1-intercept)/slope)
  105.             x2=int((y2-intercept)/slope)
  106.             poly_vertices.append((x1, y1))
  107.             poly_vertices.append((x2, y2))
  108.             self.draw_lines(img, np.array([[[x1,y1,x2,y2]]]))
  109.         
  110.         poly_vertices = [poly_vertices[i] for i in order]
  111.         cv2.fillPoly(img, pts = np.array([poly_vertices],'int32'), color = (0,255,0))
  112.         return cv2.addWeighted(image,0.7,img,0.4,0.)
  113.    
  114.     def hough_lines(self,img, rho, theta, threshold, min_line_len, max_line_gap):
  115.         """
  116.         edge_img: 要检测的图片矩阵
  117.                 参数2: 距离r的精度,值越大,考虑越多的线
  118.                 参数3: 距离theta的精度,值越大,考虑越多的线
  119.                 参数4: 累加数阈值,值越小,考虑越多的线
  120.                 minLineLength: 最短长度阈值,短于这个长度的线会被排除
  121.                 maxLineGap:同一直线两点之间的最大距离
  122.         """
  123.         lines = cv2.HoughLinesP(img, rho, theta, threshold, np.array([]), minLineLength=min_line_len, maxLineGap=max_line_gap)
  124.         line_img = np.zeros((img.shape[0], img.shape[1], 3), dtype=np.uint8)
  125.         line_img = self.slope_lines(line_img,lines)
  126.         return line_img
  127.     def weighted_img(self,img, initial_img, alpha=0.1, beta=1., gamma=0.):
  128.         lines_edges = cv2.addWeighted(initial_img, alpha, img, beta, gamma)
  129.         return lines_edges
  130.         
  131.     def get_vertices(self,image):
  132.         rows, cols = image.shape[:2]
  133.         bottom_left  = [cols*0.15, rows]
  134.         top_left     = [cols*0.45, rows*0.6]
  135.         bottom_right = [cols*0.95, rows]
  136.         top_right    = [cols*0.55, rows*0.6]
  137.         
  138.         ver = np.array([[bottom_left, top_left, top_right, bottom_right]], dtype=np.int32)
  139.         return ver
复制代码
深度学习方法



  • 基于基于分割的方法
  • 基于检测的方法
  • 基于关键点的方法
  • 基于参数曲线的方法
LaneNet+H-Net车道线检测

论文链接:Towards End-to-End Lane Detection: an Instance Segmentation Approach
代码链接
LaneNet+H-Net的车道线检测方法,通过深度学习方法实现端到端的车道线检测,该方法包括两个重要组件,一个是实例分割网络,另一个是车道线检测网络。
该方法的重要贡献在于使用实例分割技术来区分差异车道线之间的重叠和交织,而且应用多任务学习方法同时实现车道线检测和实例分割。详细来说,该方法将车道线检测标题转化为实例分割标题,使用 Mask R-CNN 实现车道线的分割和检测。通过融合两个任务的 loss 函数,同时对车道线检测和实例分割网络进行练习,实现端到端的车道线检测。
论文中的模子布局重要包括两部分:实例分割网络和车道线检测网络。实例分割网络接纳 Mask R-CNN,由主干网络和 Mask R-CNN 网络两部分组成。车道线检测网络接纳了 U-Net 布局,用于对掩码图像进行后处理,得到车道线检测结果。


  • LaneNet将车道线检测标题转为实例分割标题,即:每个车道线形成独立的实例,但都属于车道线这一种别;H-Net由卷积层和全毗连层组成,使用转换矩阵H对同一车道线的像素点进行回归;
  • 对于一张输入图片,LaneNet负责输出实例分割结果,每条车道线一个标识ID,H-Net输出一个转换矩阵,对车道线像素点进行修正,并对修正后的结果拟合出一个三阶多项式作为猜测的车道线;


    测试效果
Ultra-Fast-Lane-Detection-V2

论文链接:Ultra Fast Deep Lane Detection with Hybrid Anchor Driven Ordinal Classification
代码:Ultra-Fast-Lane-Detection-V2

解说模子部分的代码



  • backbone
  • layer
  • model_culane
  • model_tusimple
  • seg_model
backbone

backbone有两类主干方法,VGG和ResNet


  • class vgg16bn
  • class resnet
  1. import torch,pdb
  2. import torchvision
  3. import torch.nn.modules
  4. class vgg16bn(torch.nn.Module):
  5.     def __init__(self,pretrained = False):
  6.         super(vgg16bn,self).__init__()
  7.         model = list(torchvision.models.vgg16_bn(pretrained=pretrained).features.children())
  8.         model = model[:33]+model[34:43]
  9.         self.model = torch.nn.Sequential(*model)
  10.         
  11.     def forward(self,x):
  12.         return self.model(x)
  13. class resnet(torch.nn.Module):
  14.     def __init__(self,layers,pretrained = False):
  15.         super(resnet,self).__init__()
  16.         #resnet有以下几种选择方式
  17.         if layers == '18':
  18.             model = torchvision.models.resnet18(pretrained=pretrained)
  19.         elif layers == '34':
  20.             model = torchvision.models.resnet34(pretrained=pretrained)
  21.         elif layers == '50':
  22.             model = torchvision.models.resnet50(pretrained=pretrained)
  23.         elif layers == '101':
  24.             model = torchvision.models.resnet101(pretrained=pretrained)
  25.         elif layers == '152':
  26.             model = torchvision.models.resnet152(pretrained=pretrained)
  27.         elif layers == '50next':
  28.             model = torchvision.models.resnext50_32x4d(pretrained=pretrained)
  29.         elif layers == '101next':
  30.             model = torchvision.models.resnext101_32x8d(pretrained=pretrained)
  31.         elif layers == '50wide':
  32.             model = torchvision.models.wide_resnet50_2(pretrained=pretrained)
  33.         elif layers == '101wide':
  34.             model = torchvision.models.wide_resnet101_2(pretrained=pretrained)
  35.         elif layers == '34fca':
  36.             model = torch.hub.load('cfzd/FcaNet', 'fca34' ,pretrained=True)
  37.         else:
  38.             raise NotImplementedError
  39.         
  40.         self.conv1 = model.conv1
  41.         self.bn1 = model.bn1
  42.         self.relu = model.relu
  43.         self.maxpool = model.maxpool
  44.         self.layer1 = model.layer1
  45.         self.layer2 = model.layer2
  46.         self.layer3 = model.layer3
  47.         self.layer4 = model.layer4
  48.     def forward(self,x):
  49.         x = self.conv1(x)
  50.         x = self.bn1(x)
  51.         x = self.relu(x)
  52.         x = self.maxpool(x)
  53.         x = self.layer1(x)
  54.         x2 = self.layer2(x)
  55.         x3 = self.layer3(x2)
  56.         x4 = self.layer4(x3)
  57.         return x2,x3,x4
复制代码
layer

在这部分代码,是设置网络层的功能模块。其中有两个模块 : AddCoordinates和CoordConv,它们都是用于卷积操作的。这些模块是用于办理卷积神经网络中的"hole problem"(洞穴标题)的。


  • AddCoordinates用于叠加在输入上的坐标信息。详细来说,它将坐标x、y和r(若设置了with_r参数为True)与输入张量相毗连。其中,x和y坐标在[-1, 1]范围内进行缩放,坐标原点在中央。r是间隔中央的欧几里得间隔,并缩放为[0,1]范围内。如许做的目的是为了给卷积层提供额外的位置信息,以便进步其对位置信息的感知。
  • CoordConv模块则使用AddCoordinates模块中到场的坐标信息进行卷积操作。在当前张量和坐标信息合并后,将结果输入到卷积层中进行卷积操作。别的参数与torch.nn.Conv2d类似。
需要注意的是,这些模块需要连合使用,AddCoordinates模块在CoordConv模块之前使用,以确保卷积层可以大概获取到足够的位置信息。
  1. import torch
  2. from torch import nn
  3. class AddCoordinates(object):
  4.     r"""Coordinate Adder Module as defined in 'An Intriguing Failing of
  5.     Convolutional Neural Networks and the CoordConv Solution'
  6.     (https://arxiv.org/pdf/1807.03247.pdf).
  7.     This module concatenates coordinate information (`x`, `y`, and `r`) with
  8.     given input tensor.
  9.     `x` and `y` coordinates are scaled to `[-1, 1]` range where origin is the
  10.     center. `r` is the Euclidean distance from the center and is scaled to
  11.     `[0, 1]`.
  12.     Args:
  13.         with_r (bool, optional): If `True`, adds radius (`r`) coordinate
  14.             information to input image. Default: `False`
  15.     Shape:
  16.         - Input: `(N, C_{in}, H_{in}, W_{in})`
  17.         - Output: `(N, (C_{in} + 2) or (C_{in} + 3), H_{in}, W_{in})`
  18.     Examples:
  19.         >>> coord_adder = AddCoordinates(True)
  20.         >>> input = torch.randn(8, 3, 64, 64)
  21.         >>> output = coord_adder(input)
  22.         >>> coord_adder = AddCoordinates(True)
  23.         >>> input = torch.randn(8, 3, 64, 64).cuda()
  24.         >>> output = coord_adder(input)
  25.         >>> device = torch.device("cuda:0")
  26.         >>> coord_adder = AddCoordinates(True)
  27.         >>> input = torch.randn(8, 3, 64, 64).to(device)
  28.         >>> output = coord_adder(input)
  29.     """
  30.     def __init__(self, with_r=False):
  31.         self.with_r = with_r
  32.     def __call__(self, image):
  33.         batch_size, _, image_height, image_width = image.size()
  34.         y_coords = 2.0 * torch.arange(image_height).unsqueeze(
  35.             1).expand(image_height, image_width) / (image_height - 1.0) - 1.0
  36.         x_coords = 2.0 * torch.arange(image_width).unsqueeze(
  37.             0).expand(image_height, image_width) / (image_width - 1.0) - 1.0
  38.         coords = torch.stack((y_coords, x_coords), dim=0)
  39.         if self.with_r:
  40.             rs = ((y_coords ** 2) + (x_coords ** 2)) ** 0.5
  41.             rs = rs / torch.max(rs)
  42.             rs = torch.unsqueeze(rs, dim=0)
  43.             coords = torch.cat((coords, rs), dim=0)
  44.         coords = torch.unsqueeze(coords, dim=0).repeat(batch_size, 1, 1, 1)
  45.         image = torch.cat((coords.to(image.device), image), dim=1)
  46.         return image
  47. class CoordConv(nn.Module):
  48.     r"""2D Convolution Module Using Extra Coordinate Information as defined
  49.     in 'An Intriguing Failing of Convolutional Neural Networks and the
  50.     CoordConv Solution' (https://arxiv.org/pdf/1807.03247.pdf).
  51.     Args:
  52.         Same as `torch.nn.Conv2d` with two additional arguments
  53.         with_r (bool, optional): If `True`, adds radius (`r`) coordinate
  54.             information to input image. Default: `False`
  55.     Shape:
  56.         - Input: `(N, C_{in}, H_{in}, W_{in})`
  57.         - Output: `(N, C_{out}, H_{out}, W_{out})`
  58.     Examples:
  59.         >>> coord_conv = CoordConv(3, 16, 3, with_r=True)
  60.         >>> input = torch.randn(8, 3, 64, 64)
  61.         >>> output = coord_conv(input)
  62.         >>> coord_conv = CoordConv(3, 16, 3, with_r=True).cuda()
  63.         >>> input = torch.randn(8, 3, 64, 64).cuda()
  64.         >>> output = coord_conv(input)
  65.         >>> device = torch.device("cuda:0")
  66.         >>> coord_conv = CoordConv(3, 16, 3, with_r=True).to(device)
  67.         >>> input = torch.randn(8, 3, 64, 64).to(device)
  68.         >>> output = coord_conv(input)
  69.     """
  70.     def __init__(self, in_channels, out_channels, kernel_size,
  71.                  stride=1, padding=0, dilation=1, groups=1, bias=True,
  72.                  with_r=False):
  73.         super(CoordConv, self).__init__()
  74.         in_channels += 2
  75.         if with_r:
  76.             in_channels += 1
  77.         self.conv_layer = nn.Conv2d(in_channels, out_channels,
  78.                                     kernel_size, stride=stride,
  79.                                     padding=padding, dilation=dilation,
  80.                                     groups=groups, bias=bias)
  81.         self.coord_adder = AddCoordinates(with_r)
  82.     def forward(self, x):
  83.         x = self.coord_adder(x)
  84.         x = self.conv_layer(x)
  85.         return x
复制代码
seg_model



  • 重要包罗conv_bn_relu和SegHead
  • conv_bn_relu模块包罗一个卷积层、一个BatchNorm层和ReLU激活函数。这些层的目的是将输入张量x进行卷积、BN和ReLU激活操作,并将结果返回。
  • SegHead模块实现了一个带有分支的分割头。它包括三个分支,它们分别对应于主干网络的差异层级。每个分支都由卷积、BN和ReLU激活函数组成,并使用双线性插值对它们的输出进行上采样。然后它们的输出会被拼接在一起,输入到一个包罗一系列卷积层的组合中。最后,它输出一个张量,其中包罗num_lanes + 1个通道,表示每个车道的掩码以及背景。
  1. import torch
  2. from utils.common import initialize_weights
  3. class conv_bn_relu(torch.nn.Module):
  4.     def __init__(self,in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,bias=False):
  5.         super(conv_bn_relu,self).__init__()
  6.         self.conv = torch.nn.Conv2d(in_channels,out_channels, kernel_size,
  7.             stride = stride, padding = padding, dilation = dilation,bias = bias)
  8.         self.bn = torch.nn.BatchNorm2d(out_channels)
  9.         self.relu = torch.nn.ReLU()
  10.     def forward(self,x):
  11.         x = self.conv(x)
  12.         x = self.bn(x)
  13.         x = self.relu(x)
  14.         return x
  15.         
  16. class SegHead(torch.nn.Module):
  17.     def __init__(self,backbone, num_lanes):
  18.         super(SegHead, self).__init__()
  19.         self.aux_header2 = torch.nn.Sequential(
  20.             conv_bn_relu(128, 128, kernel_size=3, stride=1, padding=1) if backbone in ['34','18'] else conv_bn_relu(512, 128, kernel_size=3, stride=1, padding=1),
  21.             conv_bn_relu(128,128,3,padding=1),
  22.             conv_bn_relu(128,128,3,padding=1),
  23.             conv_bn_relu(128,128,3,padding=1),
  24.         )
  25.         self.aux_header3 = torch.nn.Sequential(
  26.             conv_bn_relu(256, 128, kernel_size=3, stride=1, padding=1) if backbone in ['34','18'] else conv_bn_relu(1024, 128, kernel_size=3, stride=1, padding=1),
  27.             conv_bn_relu(128,128,3,padding=1),
  28.             conv_bn_relu(128,128,3,padding=1),
  29.         )
  30.         self.aux_header4 = torch.nn.Sequential(
  31.             conv_bn_relu(512, 128, kernel_size=3, stride=1, padding=1) if backbone in ['34','18'] else conv_bn_relu(2048, 128, kernel_size=3, stride=1, padding=1),
  32.             conv_bn_relu(128,128,3,padding=1),
  33.         )
  34.         self.aux_combine = torch.nn.Sequential(
  35.             conv_bn_relu(384, 256, 3,padding=2,dilation=2),
  36.             conv_bn_relu(256, 128, 3,padding=2,dilation=2),
  37.             conv_bn_relu(128, 128, 3,padding=2,dilation=2),
  38.             conv_bn_relu(128, 128, 3,padding=4,dilation=4),
  39.             torch.nn.Conv2d(128, num_lanes+1, 1)
  40.             # output : n, num_of_lanes+1, h, w
  41.         )
  42.         initialize_weights(self.aux_header2,self.aux_header3,self.aux_header4,self.aux_combine)
  43.         # self.droput = torch.nn.Dropout(0.1)
  44.     def forward(self,x2,x3,fea):
  45.         x2 = self.aux_header2(x2)
  46.         x3 = self.aux_header3(x3)
  47.         x3 = torch.nn.functional.interpolate(x3,scale_factor = 2,mode='bilinear')
  48.         x4 = self.aux_header4(fea)
  49.         x4 = torch.nn.functional.interpolate(x4,scale_factor = 4,mode='bilinear')
  50.         aux_seg = torch.cat([x2,x3,x4],dim=1)
  51.         aux_seg = self.aux_combine(aux_seg)
  52.         return aux_seg
复制代码
未完待续!

来源:https://blog.csdn.net/qq_38915354/article/details/130340073
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则