基于 Transformer 的多目标跟踪算法

1696 字
9 分钟

本文主要介绍TransTrack, TrackFormer, MORT三种算法,但考虑到三篇文章中都与 DETR 高度相关,这里对DETR也做一个简单的介绍。

DETR#

DETR 通过解码器得到固定 N 个预测集(N设置的远大于实际图像中的目标数量)

网络结构#

DETR的网络结构

import torch
from torch import nn
from torchvision.models import resnet50
class DETR(nn.Module):
def __init__(self, num_classes, hidden_dim, nheads,
num_encoder_layers, num_decoder_layers):
super().__init__()
# We take only convolutional layers from ResNet-50 model
self.backbone = nn.Sequential(*list(resnet50(pretrained=True).children())[:-2])
self.conv = nn.Conv2d(2048, hidden_dim, 1)
self.transformer = nn.Transformer(hidden_dim, nheads,
num_encoder_layers, num_decoder_layers)
self.linear_class = nn.Linear(hidden_dim, num_classes + 1)
self.linear_bbox = nn.Linear(hidden_dim, 4)
self.query_pos = nn.Parameter(torch.rand(100, hidden_dim))
self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
def forward(self, inputs):
x = self.backbone(inputs)
h = self.conv(x)
H, W = h.shape[-2:]
pos = torch.cat([
self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
], dim=-1).flatten(0, 1).unsqueeze(1)
h = self.transformer(pos + h.flatten(2).permute(2, 0, 1),
self.query_pos.unsqueeze(1))
return self.linear_class(h), self.linear_bbox(h).sigmoid()
detr = DETR(num_classes=91, hidden_dim=256, nheads=8, num_encoder_layers=6, num_decoder_layers=6)
detr.eval()
inputs = torch.randn(1, 3, 800, 1200)
logits, bboxes = detr(inputs)
  1. 通过一个 CNN 网络提取图像特征,并将特征图展平后加入位置编码,输入到 Transformer 编码器中得到 全局特征 信息;
  2. 将一组可学习的对象查询作为 Query,与编码器输出的特征(作为 Key 和 Value)一起输入到 Transformer 解码器,得到隐藏状态;
  3. 将解码器输出的隐藏状态分别送入两个 FFN ,直接并行预测出最终的检测结果(类别和边界框坐标)。

实现细节#

目标查询:与传统方法中的锚框作用类似,但是这是一个可学习的参数,用于查询特定位置的目标框。 全连接层:编码器中的全连接层实际上由两层 1×11 \times 1 的卷积层组成。

TransTrack#

注意力机制在MOT的应用#

下面分别是基于检测的多目标跟踪、基于注意力机制的单目标跟踪、基于注意力机制的多目标跟踪流程。

TBD多目标跟踪的流水线

基于注意力机制的单目标跟踪流水现

基于注意力机制的多目标跟踪流水现

在单目标跟踪领域,注意力机制将上一帧的目标作为Query,下一帧作为Value,很容易就能实现对于单目标的跟踪。 但是对于多目标,这样的方法无法处理新生的目标,存在局限性。需要设计额外的流程来处理新生的轨迹。

主要流程#

TransTrack的结构

  1. 通过一个 CNN 网络 (ResNet50) 提取图像特征,特别的,将每个时刻特征保存到下一时刻;
  2. 再将 两个 连续帧的特征通过编码器 (自注意力机制) 得到 组合特征
  3. 将组合特征作为 Key 输入到解码器中,将 对象查询 和上一帧的对象特征作为 轨迹查询 作为 Query 并行 输入到两个解码器中,分别得到检测特征和跟踪特征;
  4. 再通过并行的全连接层得到检测框和跟踪框,使用匈牙利算法进行IoU匹配得到最后的预测框。
TIP

实际上,可以将该过程视为分别从当前帧的检测框和通过上一帧轨迹得到的检测框,再进行IoU匹配,得到当前帧的跟踪结果。
其根本的思想与传统的TBD范式较为相似,可简单的将其理解为,使用 STEP3-2 中的代替了TBD中的卡尔曼滤波器的预测功能(但是预测的依据基于外观信息而非运动模型)。

TrackFormer#

主要流程#

TrackFormer的结构

  1. 通过一个 CNN 网络 (ResNet50) 提取图像特征,并将特征通过编码器进一步 全局特征 信息;
  2. 将全局特征作为 Key 和 Value ,(上一帧得到的) 轨迹查询对象查询 拼接后作为 Query 将输入到解码器中;
  3. 解码器的输出作为下一帧的 轨迹查询 ,同时通过全连接层得到预测框和类别预测。
  4. 解码器自注意力无法解决预测框重叠问题,因此需要通过删除低置信度框和 NMS 来解决强烈重叠的预测框。
TIP

TrackFormer的思路很简单,对比TrackFormer和DETR的网络结构,我们就能发现,二者的主要区别就是解码器的输入不同: DETR的输入为 Object Query,而TrackFormer就是在此基础上将 Object Query 和 Track Query 拼接起来。

MORT#

在 DETR 的基础上设计,与同期的 TransTrack 和 TrackFormer 相比,无需非极大值抑制和IoU匹配的后处理。

轨迹块感知标签分配 (TALA)#

在 DETR 中,使用的是固定长度的对象查询,检测可以分配给任意对象。
而 MOTR 中引入了轨迹块感知标签分配 (TALA) 使检测查询仅用于检测新生成的对象,跟踪查询预测跟踪对象,如下图所示。
TALA策略图示

查询交互模块 (QIM)#

查询交互模块 QIM 的输入是 Transformer 解码器产生的隐藏状态和对应的预测分数。在训练过程中,对于跟踪查询得到的对象,如果匹配的对象在真实值中消失或预测边界框与目标之间的交并比(IoU)低于0.5的阈值,则移除已终止对象的隐藏状态;对于对象查询的到的对象,只保留得分高于入门阈值的结果。
特别的,过滤后的对于跟踪查询得到的对象,通过时间聚合网络(TAN)后,与新生对象连接。

主要流程#

MOTR主要流程

  1. 通过一个 CNN 网络 (ResNet50) 提取图像特征,并将特征通过编码器得到 全局特征 信息;
  2. 将全局特征作为 Key 和 Value ,检测查询和跟踪查询连接起来作为 Query 输入到解码器,生成 隐藏状态
  3. 隐状态通过 全连接层 得到当前时刻的跟踪结果,同时通过 查询交互模块 (QIM) 得到下一帧的轨迹查询。
TIP

跟踪查询集动态更新,初始化为空;检测查询用于查询新出现的对象。当前时刻跟踪查询和检测查询得到的所有跟踪框在下一时刻用于跟踪查询,而当前时刻已终止的跟踪对象将会从跟踪查询集中删除。
将检测查询和跟踪查询连接起来,输入到解码器中。在实践中,检测查询将只会检测新生成的对象,因为Transformer解码器中自注意力机制的查询交互将抑制检测跟踪对象的查询。

基于 Transformer 的多目标跟踪算法
https://blog.rinne05.top/blog/research/mot-e2e-intro/
作者
Rinne
发布于
10/22/2025
许可协议
CC BY-NC-SA 4.0

主题设置

主题模式
主题色
透明度
模糊
© 2025 Rinne,采用 CC BY-NC-SA 4.0 许可
ICP备案号: 豫ICP备2025156598号
输入以搜索...
通过 Fuse.js 搜索