欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 科技 > IT业 > torch.sparse_csr_tensor

torch.sparse_csr_tensor

2025/11/19 18:40:38 来源:https://blog.csdn.net/m0_63070489/article/details/144807944  浏览:    关键词:torch.sparse_csr_tensor

torch.sparse_csr_tensor

  • 以**CSR格式**构建一个稀疏张量。CSR 格式的稀疏张量乘法运算通常比 COO 格式的稀疏张量更快。
    • CSR(Compressed Sparse Row)格式是一种存储稀疏矩阵的常用格式,它通过三个数组来表示稀疏矩阵:
      行指针数组(crow_indices):每行第一个非零值在values中的索引。最后一个元素是非零值的总个数。
      列索引数组(col_indices):矩阵中非零值所在列的索引
      非零元素数组(values):矩阵中的非零值,按行优先顺序排列。
      假设我们有一个5x5的稀疏矩阵A,如下所示:
      [ 0 0 1 0 0 2 0 0 3 0 0 4 0 0 0 0 0 0 0 5 6 0 7 0 0 ] \begin{bmatrix} 0 & 0 & 1 & 0 & 0 \\ 2 & 0 & 0 & 3 & 0 \\ 0 & 4 & 0 & 0 & 0 \\ 0 & 0 & 0 & 0 & 5 \\ 6 & 0 & 7 & 0 & 0 \\ \end{bmatrix} 0200600400100070300000050
      ​CSR格式数据为:
      crow_indices= [0, 1, 3, 4, 5, 7]
      col_indices= [2, 0, 3, 1, 4, 0, 2]
      values = [1, 2, 3, 4, 5, 6, 7]
      • 解释:
        values = [1, 2, 3, 4, 5, 6, 7]
        行优先顺序存储所有非零元素的值,顺序与列索引数组对应。
        col_indices= [2, 0, 3, 1, 4, 0, 2]
        第1行的第1个非零元素是1,位于第2列,所以第1个元素是2
        第2行的第1个非零元素是2,位于第0列,所以第2个元素是0
        第2行的第2个非零元素是3,位于第3列,所以第3个元素是3
        第3行的第1个非零元素是4,位于第1列,所以第4个元素是1
        第4行的第1个非零元素是5,位于第4列,所以第5个元素是4
        第5行的第1个非零元素是6,位于第0列,所以第6个元素是0
        第5行的第2个非零元素是7,位于第2列,所以第7个元素是2
        crow_indices= [0, 1, 3, 4, 5,7]
        第1行的第一个非零元素1,在values中的下标是0,所以这里是0
        第2行的第一个非零元素2,在values中的下标是1,所以这里是1
        第3行的第一个非零元素4,在values中的下标是3,所以这里是3
        第4行的第一个非零元素5,在values中的下标是4,所以这里是4
        第5行的第一个非零元素6,在values中的下标是5,所以这里是5
        共有7个非零元素[1, 2, 3, 4, 5, 6, 7],故最后一个数是7

torch.sparse_csr_tensor(crow_indices, col_indices, values, size=None, *, dtype=None, device=None, pin_memory=False, requires_grad=False, check_invariants=None)Tensor

参数

  • crow_indices (array_like) - (B+1) 维数组 (*batchsize, nrows + 1)。每行第一个非零值在values中的索引。最后一个元素是非零值的总个数。相邻数字的差值,表示给定行中元素的数量。
  • col_indices (array_like) – values中每个元素的列坐标。长度与 values 相同。
  • values (array_list) – 张量的初始值。可以是列表、元组、NumPy ndarray、标量和其他类型
  • size (list, tuple, torch.Size, optional) – 稀疏张量的大小: (*batchsize, nrows, ncols, *densesize) .如果未提供,则大小将被推断为足够大的最小大小,以容纳所有非零元素。

关键字参数

  • dtype (torch.dtype,可选) – 返回张量的所需数据类型。默认值:如果为 None,则从值推断数据类型。
  • device (torch.device,可选) – 返回的张量的所需设备。默认值:如果为 None,则使用当前设备作为默认张量类型(请参阅 torch.set_default_device() )。device 将是 CPU 张量类型的 CPU,CUDA 张量类型的当前 CUDA 设备。
  • pin_memory (bool,可选) – 如果设置,则返回的张量将在固定内存中分配。仅适用于 CPU 张量。默认值:False。可以使用 .to(‘cuda’) 方法将数据从固定内存移动到 GPU,通常比从常规内存移动到 GPU 更快。
  • requires_grad (bool, 可选) – 如果 autograd 应记录对返回的张量的操作。默认值:False。
  • check_invariants (bool, 可选) – 是否选中稀疏张量不变性。默认值:False(由 torch.sparse.check_sparse_tensor_invariants.is_enabled() 得到)
    • 不变性检查是一种验证机制,用于确保稀疏张量的数据结构和逻辑是正确的。如:压缩稀疏行(CSR)格式中的一个不变性条件,即crow_indices数组的最后一个元素必须等于非零元素的数量(nnz)。

示例代码

import torch  # 导入PyTorch库# 定义CSR格式稀疏矩阵的行索引数组,表示每一行的起始位置
crow_indices = [0, 1, 3, 4, 5, 7]
# 定义CSR格式稀疏矩阵的列索引数组,表示每个非零元素所在的列
col_indices = [2, 0, 3, 1, 4, 0, 2]
# 定义CSR格式稀疏矩阵的非零元素值
values = [1, 2, 3, 4, 5, 6, 7]
# 使用torch.sparse_csr_tensor函数创建CSR格式的稀疏张量
# crow_indices和col_indices分别转换为torch.int64类型的张量
# values转换为默认类型的张量,并指定稀疏张量的数据类型为torch.double
x = torch.sparse_csr_tensor(torch.tensor(crow_indices, dtype=torch.int64),torch.tensor(col_indices, dtype=torch.int64),torch.tensor(values), dtype=torch.double)
print(f'x=\n{x}') # 打印稀疏张量x
print(f'x的稠密矩阵形式=\n{x.to_dense()}') # 将稀疏张量x转换为稠密矩阵形式并打印
'''output
x=
tensor(crow_indices=tensor([0, 1, 3, 4, 5, 7]),col_indices=tensor([2, 0, 3, 1, 4, 0, 2]),values=tensor([1., 2., 3., 4., 5., 6., 7.]), size=(5, 5), nnz=7,dtype=torch.float64, layout=torch.sparse_csr)
x的稠密矩阵形式=
tensor([[0., 0., 1., 0., 0.],[2., 0., 0., 3., 0.],[0., 4., 0., 0., 0.],[0., 0., 0., 0., 5.],[6., 0., 7., 0., 0.]], dtype=torch.float64)
'''

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词