Background

构架3DMOT的Dataset类和Dataloader

670
3 分钟阅读

Dataset

在实例化Dataloader之前,我们需要先创建一个Dataset类。

下面是一个简单的Dataset类示例:

class NusceneseDataset(Dataset):
def __init__(self, ann_file):
super().__init__()
self.stride = torch.tensor(0.8, dtype=torch.float32)
self.point_range = torch.tensor(51.2, dtype=torch.float32)
print("NuScene train open.")
with open(ann_file, "rb") as handle:
video_clips = pickle.load(handle)
print("Convert instances")
video_instances = []
for seq in tqdm.tqdm(video_clips):
seq_length = len(seq)
seq_instances = []
for frame_id in range(seq_length):
frame_detected = seq[frame_id]
translation = torch.from_numpy(frame_detected['translation'])
size = torch.from_numpy(frame_detected['size'])
yaw = torch.from_numpy(frame_detected['yaw']).unsqueeze(-1)
rotation = torch.from_numpy(frame_detected['rotation'])
velocity = torch.from_numpy(frame_detected['velocity'])
score = torch.from_numpy(frame_detected['score']).unsqueeze(-1)
classes = torch.from_numpy(frame_detected['classes']).unsqueeze(-1)
tracking_id = torch.from_numpy(frame_detected['tracking_id']).unsqueeze(-1)
scene_token = frame_detected['scene_token']
sample_token = frame_detected['sample_token']
timestamp = frame_detected['timestamp']
times = torch.zeros([len(size), 1], dtype=torch.float64) + timestamp
img_meta = {
"scene_token": scene_token,
"sample_token": sample_token,
"timestamp": timestamp,
}
instance_ = Instances([1, 1], img_meta)
instance_.set("translation", translation)
instance_.set("size", size)
instance_.set("yaw", yaw)
instance_.set("times", times)
instance_.set("rotation", rotation)
instance_.set("velocity", velocity)
instance_.set("score", score)
instance_.set("classes", classes)
instance_.set("tracking_id", tracking_id.clone())
instance_inds = torch.zeros_like(tracking_id) - 1
instance_.set("instance_inds", instance_inds)
instance_ = instance_[instance_.score[..., 0] >= 0.1]
seq_instances.append(instance_)
video_instances.append(seq_instances)
print("Convert mini batch")
clip_video = []
for seq in tqdm.tqdm(video_instances):
start_frame = 0
end_frame = len(seq)
clip_len_frames = 10
for idx in range(start_frame, end_frame - 1, 3):
clip_frame_ids = []
for frame_idx in range(idx, idx + clip_len_frames):
if frame_idx < end_frame:
input_dict = seq[frame_idx]
if len(input_dict) > 0:
clip_frame_ids.append(input_dict)
if len(clip_frame_ids) >= 10:
seq_times = [s._img_meta['timestamp'] for s in clip_frame_ids]
inverse_sequence_instances = copy.deepcopy(clip_frame_ids[::-1])
for iidx, inv_seq in enumerate(inverse_sequence_instances):
inv_seq.times = inv_seq.times.to(torch.float64)
inv_seq.times[inv_seq.times > 0] = torch.tensor(seq_times[iidx], dtype=torch.float64)
inv_seq._img_meta['timestamp'] = seq_times[iidx]
clip_video.append(copy.deepcopy(inverse_sequence_instances))
clip_video.append(copy.deepcopy(clip_frame_ids))
self.data_infos = clip_video
def __len__(self):
return len(self.data_infos)
def get_instance(self, sequence):
return {
"instances": sequence
}
def __getitem__(self, idx):
idx = min(idx, len(self) - 1)
sequence = self.data_infos[idx]
return self.get_instance(sequence)

标准的检测结果的格式,为:

sample_result {
"sample_token": <str> # 样本的唯一标识符。
"translation": <float> [3] # 估计的边界框位置,单位为米,格式为[x, y, z]。
"size": <float> [3] # 估计的边界框尺寸,单位为米,格式为[length, width, height]。
"rotation": <float> [4] # 估计的边界框朝向,使用四元数表示,格式为[x, y, z, w]。
"velocity": <float> [2] # 估计的边界框速度,单位为m/s,格式为[vx, vy]。
"acceleration": <float> [2] # 估计的边界框加速度,单位为m/s²,格式为[ax, ay]。
"detection_name": <str> # 预测的目标类别名称,例如car, pedestrian等。
"detection_score": <float> # 预测的目标类别置信度分数,范围为[0.0, 1.0]。
"attribute_name": <str> # 预测的目标属性名称,例如parked, moving等。
}
构架3DMOT的Dataset类和Dataloader
/blog/26025918
作者
发布于
2026/1/15
许可协议
CC BY-NC-SA 4.0