
Quick Notes on weight_norm


Contents

  • 1. Description
  • 2. PyTorch

1. Description

Suppose we have a matrix A:
\begin{equation}
A=\begin{bmatrix} 1 & 2 & 3 & 4 \\ 5 & 6 & 7 & 8 \\ 9 & 10 & 11 & 12 \end{bmatrix}
\end{equation}
Wrapping a layer with torch.nn.utils.parametrizations.weight_norm in PyTorch does nothing more than decompose the original weight matrix into two parts (they recombine via the formula after this list):

  • parametrizations.weight.original0: the norm of each row vector (weight_g):
    \begin{equation}
    5.477=\sqrt{1^2+2^2+3^2+4^2};\quad 13.191=\sqrt{5^2+6^2+7^2+8^2};\quad 21.119=\sqrt{9^2+10^2+11^2+12^2}
    \end{equation}
  • parametrizations.weight.original1: the weight matrix itself (weight_v):
    \begin{equation}
    V=\begin{bmatrix} 1 & 2 & 3 & 4 \\ 5 & 6 & 7 & 8 \\ 9 & 10 & 11 & 12 \end{bmatrix}
    \end{equation}
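
The effective weight is rebuilt from these two pieces by the standard weight-norm reparametrization, applied per output row here (the default dim=0 for a Linear weight):
\begin{equation}
w = g\,\frac{v}{\lVert v \rVert}
\end{equation}
Because g stores exactly the row norms of v, the rebuilt w equals the original weight, which is why the wrapped and unwrapped layers in the script below produce identical outputs.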

2. PyTorch

  • python:
import torch
import torch.nn as nn

torch.set_printoptions(precision=3, sci_mode=False, threshold=torch.inf)

if __name__ == "__main__":
    run_code = 0
    # a_matrix: the input batch; b_matrix: the weight we install in the layer
    a_matrix = torch.arange(12).reshape((3, 4)).to(torch.float)
    b_matrix = torch.arange(12).reshape((3, 4)).to(torch.float) + 1
    print(f"start")
    my_linear = nn.Linear(in_features=4, out_features=3, bias=False)
    my_linear.weight = nn.Parameter(b_matrix)
    print(my_linear.weight)
    # the old nn.utils.weight_norm API is deprecated in favor of the parametrizations version
    # wn_linear = nn.utils.weight_norm(my_linear)
    wn_linear = nn.utils.parametrizations.weight_norm(my_linear, name='weight')
    for p in wn_linear.named_parameters():
        print(p)
    # per-row L2 norms of the weight; should match original0
    c_matrix = torch.sqrt(torch.sum(b_matrix ** 2, dim=-1))
    print(f"c_matrix=\n{c_matrix}")
    # the wrapped layer computes exactly the same forward pass as the plain one
    a_result_linear = my_linear(a_matrix)
    a_result_wn = wn_linear(a_matrix)
    print(f"a_result_linear=\n{a_result_linear}")
    print(f"a_result_wn=\n{a_result_wn}")
    # direction = v / g; each row should be a unit vector
    weight_direction = wn_linear.parametrizations.weight.original1 / wn_linear.parametrizations.weight.original0
    print(f"weight_direction=\n{weight_direction}")
    # broadcast the row norms to the full (3, 4) shape
    c_matrix_ones = c_matrix.reshape(-1, 1) @ torch.ones(1, 4)
    print(f"c_matrix_ones=\n{c_matrix_ones}")
    test_weight_direction = b_matrix / c_matrix_ones
    print(f"test_weight_direction=\n{test_weight_direction}")
    check_weight = torch.allclose(weight_direction, test_weight_direction)
    print(f"weight_direction matches test_weight_direction: {check_weight}")
    # squared row sums of the direction matrix should all be 1
    check_sum = torch.sum(weight_direction ** 2, dim=-1)
    print(f"check_sum=\n{check_sum}")
  • result:
start
Parameter containing:
tensor([[ 1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12.]], requires_grad=True)
('parametrizations.weight.original0', Parameter containing:
tensor([[ 5.477],
        [13.191],
        [21.119]], requires_grad=True))
('parametrizations.weight.original1', Parameter containing:
tensor([[ 1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12.]], requires_grad=True))
c_matrix=
tensor([ 5.477, 13.191, 21.119])
a_result_linear=
tensor([[ 20.,  44.,  68.],
        [ 60., 148., 236.],
        [100., 252., 404.]], grad_fn=<MmBackward0>)
a_result_wn=
tensor([[ 20.,  44.,  68.],
        [ 60., 148., 236.],
        [100., 252., 404.]], grad_fn=<MmBackward0>)
weight_direction=
tensor([[0.183, 0.365, 0.548, 0.730],
        [0.379, 0.455, 0.531, 0.606],
        [0.426, 0.474, 0.521, 0.568]], grad_fn=<DivBackward0>)
c_matrix_ones=
tensor([[ 5.477,  5.477,  5.477,  5.477],
        [13.191, 13.191, 13.191, 13.191],
        [21.119, 21.119, 21.119, 21.119]])
test_weight_direction=
tensor([[0.183, 0.365, 0.548, 0.730],
        [0.379, 0.455, 0.531, 0.606],
        [0.426, 0.474, 0.521, 0.568]])
weight_direction matches test_weight_direction: True
check_sum=
tensor([1.000, 1.000, 1.000], grad_fn=<SumBackward1>)
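
As a final sanity check, the formula from section 1 can be verified directly: rebuilding the weight by hand from original0 and original1 should reproduce the parametrized wn_linear.weight. A minimal sketch, meant to be appended to the script above (it reuses wn_linear and torch from there):

g = wn_linear.parametrizations.weight.original0     # per-row norms g, shape (3, 1)
v = wn_linear.parametrizations.weight.original1     # unnormalized directions v, shape (3, 4)
w_rebuilt = g * v / v.norm(dim=-1, keepdim=True)    # w = g * v / ||v||, row-wise
print(torch.allclose(w_rebuilt, wn_linear.weight))  # expected: True

During training, gradients flow to g and v separately, which is the point of the reparametrization: the magnitude and the direction of each row are learned independently.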
