Table of Contents
- 1. Description
- 2. PyTorch
1. Description
Suppose we have a matrix A, written as:
$$
A=\begin{bmatrix} 1 & 2 & 3 & 4\\ 5 & 6 & 7 & 8\\ 9 & 10 & 11 & 12 \end{bmatrix}
$$
After wrapping with PyTorch's `torch.nn.utils.parametrizations.weight_norm`, the original matrix `weight` is simply decomposed into two parts:
- parametrizations.weight.original0: the L2 norm of each row vector (the legacy API called this `weight_g`):
$$
5.477=\sqrt{1^2+2^2+3^2+4^2};\qquad 13.191=\sqrt{5^2+6^2+7^2+8^2};\qquad 21.119=\sqrt{9^2+10^2+11^2+12^2}
$$
- parametrizations.weight.original1: the weight matrix itself (the legacy API called this `weight_v`):
$$
V=\begin{bmatrix} 1 & 2 & 3 & 4\\ 5 & 6 & 7 & 8\\ 9 & 10 & 11 & 12 \end{bmatrix}
$$
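Putting the two parts together: weight norm re-expresses each row of the weight as $w = g \cdot v/\lVert v \rVert$, so `original0` carries the magnitude and `original1` the direction. A minimal sketch of this reconstruction with the 3×4 matrix above (plain PyTorch, nothing beyond the tensors already shown):

```python
import torch

# the example matrix from above, rows (1..4), (5..8), (9..12)
w = torch.arange(1.0, 13.0).reshape(3, 4)

# decomposition: g is the per-row norm (original0), v is the matrix itself (original1)
g = w.norm(dim=1, keepdim=True)   # tensor([[ 5.477], [13.191], [21.119]])
v = w.clone()

# reconstruction: each row of w equals g * v / ||v||
w_rebuilt = g * v / v.norm(dim=1, keepdim=True)
print(torch.allclose(w, w_rebuilt))  # True
```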
2. PyTorch
- python:
```python
import torch
import torch.nn as nn

torch.set_printoptions(precision=3, sci_mode=False, threshold=torch.inf)

if __name__ == "__main__":
    run_code = 0
    a_matrix = torch.arange(12).reshape((3, 4)).to(torch.float)
    b_matrix = torch.arange(12).reshape((3, 4)).to(torch.float) + 1
    print(f"start")
    my_linear = nn.Linear(in_features=4, out_features=3, bias=False)
    my_linear.weight = nn.Parameter(b_matrix)
    print(my_linear.weight)
    # legacy (deprecated) API: wn_linear = nn.utils.weight_norm(my_linear)
    wn_linear = nn.utils.parametrizations.weight_norm(my_linear, name='weight')
    for p in wn_linear.named_parameters():
        print(p)
    # per-row L2 norms of the weight matrix
    c_matrix = torch.sqrt(torch.sum(b_matrix ** 2, dim=-1))
    print(f"c_matrix=\n{c_matrix}")
    a_result_linear = my_linear(a_matrix)
    a_result_wn = wn_linear(a_matrix)
    print(f"a_result_linear=\n{a_result_linear}")
    print(f"a_result_wn=\n{a_result_wn}")
    # direction = v / g: each row of original1 divided by its norm
    weight_direction = wn_linear.parametrizations.weight.original1 / wn_linear.parametrizations.weight.original0
    print(f"weight_direction=\n{weight_direction}")
    c_matrix_ones = c_matrix.reshape(-1, 1) @ torch.ones(1, 4)
    print(f"c_matrix_ones=\n{c_matrix_ones}")
    test_weight_direction = b_matrix / c_matrix_ones
    print(f"test_weight_direction=\n{test_weight_direction}")
    check_weight = torch.allclose(weight_direction, test_weight_direction)
    print(f"weight_direction is the {check_weight} same with test_weight_direction")
    # each direction row should have unit norm
    check_sum = torch.sum(weight_direction ** 2, dim=-1)
    print(f"check_sum=\n{check_sum}")
```
- result:
```
start
Parameter containing:
tensor([[ 1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12.]], requires_grad=True)
('parametrizations.weight.original0', Parameter containing:
tensor([[ 5.477],
        [13.191],
        [21.119]], requires_grad=True))
('parametrizations.weight.original1', Parameter containing:
tensor([[ 1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12.]], requires_grad=True))
c_matrix=
tensor([ 5.477, 13.191, 21.119])
a_result_linear=
tensor([[ 20.,  44.,  68.],
        [ 60., 148., 236.],
        [100., 252., 404.]], grad_fn=<MmBackward0>)
a_result_wn=
tensor([[ 20.,  44.,  68.],
        [ 60., 148., 236.],
        [100., 252., 404.]], grad_fn=<MmBackward0>)
weight_direction=
tensor([[0.183, 0.365, 0.548, 0.730],
        [0.379, 0.455, 0.531, 0.606],
        [0.426, 0.474, 0.521, 0.568]], grad_fn=<DivBackward0>)
c_matrix_ones=
tensor([[ 5.477,  5.477,  5.477,  5.477],
        [13.191, 13.191, 13.191, 13.191],
        [21.119, 21.119, 21.119, 21.119]])
test_weight_direction=
tensor([[0.183, 0.365, 0.548, 0.730],
        [0.379, 0.455, 0.531, 0.606],
        [0.426, 0.474, 0.521, 0.568]])
weight_direction is the True same with test_weight_direction
check_sum=
tensor([1.000, 1.000, 1.000], grad_fn=<SumBackward1>)
```