1.torch.nonzero(input, *, as_tuple=False)
作用:在PyTorch中用于返回输入张量中非零元素的位置索引。
返回值:返回一个张量,每行代表一个非零元素的索引。
参数含义:
(1)input:输入的PyTorch 张量。
(2)as_tuple:一个布尔值,指定返回结果的格式。默认为 False,返回一个张量。如果设置为 True,则返回一个元组,其中每个元素代表一个维度上的索引。
应用场景:
(1)高级索引:
使用as_tuple=True返回的元组可以用于对原始张量进行高级索引,例如提取所有非零元素;
(2)掩码操作:
结合torch.nonzero()和其他函数来创建掩码,例如选择特定条件下的元素;
示例:
index_list = torch.nonzero(scores > det_thr, as_tuple=True)[0]
其中scores是torch.Tensor ,维度ndim=1; det_thr类型为<float>
代码含义:
其他函数scores>det_thr: 是逐元素比较,返回一个布尔型张量,得到result = tensor([True,True, False,False,…])
然后torch.nonzero(result) 得到所有值为True的元素的位置索引。例如tensor([[0],[2]])
参数as_tuple用于控制输出的格式:
- 默认 (as_tuple=False): 返回形如 二维张量 的结果,比如 tensor([[0], [2]]);
- 加上 as_tuple=True: 返回一个元组,其中每一维是一个 1D 张量,表示每个轴的索引
最后的[0]用于从元组中提取索引值张量;
最后整句代码含义:
取出所有 scores 中 大于阈值 det_thr 的元素的索引,并保存为 index_list
(3)稀疏张量处理:
在处理稀疏张量时,torch.nonzero()可以用于获取非零元素的索引。