欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 新闻 > 焦点 > torch.matmul() VS torch.einsum()

torch.matmul() VS torch.einsum()

2025/5/25 23:47:59 来源:https://blog.csdn.net/m0_74413554/article/details/148114492  浏览:    关键词:torch.matmul() VS torch.einsum()

torch.matmul():标准的矩阵乘法

  • 向量-向量(点积)

    a = torch.randn(3)  # [3]
    b = torch.randn(3)  # [3]
    c = torch.matmul(a, b)  # 点积,标量输出
    
  • 矩阵-向量

    A = torch.randn(3, 4)  # [3, 4]
    x = torch.randn(4)     # [4]
    y = torch.matmul(A, x) # [3]
    
  • 矩阵-矩阵

    A = torch.randn(3, 4)  # [3, 4]
    B = torch.randn(4, 5)  # [4, 5]
    C = torch.matmul(A, B) # [3, 5]
    
  • 批量矩阵乘法(更高维张量)

    A = torch.randn(2, 3, 4)  # [B, M, K]
    B = torch.randn(2, 4, 5)  # [B, K, N]
    C = torch.matmul(A, B)     # [B, M, N]
    

    torch.einsum:爱因斯坦求和约定(更通用的张量运算工具)

  • 矩阵乘法

    A = torch.randn(3, 4)
    B = torch.randn(4, 5)
    C = torch.einsum("ik,kj->ij", A, B)  # 等价于 A @ BA = torch.randn(2, 3, 4)  # [B, M, K]
    B = torch.randn(2, 4, 5)  # [B, K, N]
    C = torch.einsum("bik,bkj->bij", A, B)  # [B, M, N]a = torch.randn(3)
    b = torch.randn(3)
    c = torch.einsum("i,i->", a, b)  # 点积,标量输出
    
  • 转置

    A = torch.randn(3, 4)
    B = torch.einsum("ij->ji", A)  # 等价于 A.T
    
  • 对角线提取

  • 张量收缩(Tensor Contraction)(高阶张量乘法)

    A = torch.randn(2, 3, 4, 5)
    B = torch.randn(2, 4, 5, 6)
    C = torch.einsum("abcd,abde->abce", A, B)  # 对 d 维度收缩
    
  • 广播运算

torch.matmultorch.einsum
灵活性仅支持矩阵乘法类操作支持任意张量运算(转置、收缩等)
可读性直观(A @ B需要熟悉爱因斯坦求和约定
性能高度优化(推荐用于标准矩阵乘法)灵活但可能稍慢
广播支持
批量处理自动支持需显式指定批量维度

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词