Warning: Undefined array key "HTTP_ACCEPT_LANGUAGE" in /www/wwwroot/blog/wp-content/plugins/UEditor-KityFormula-for-wordpress/main.php on line 13
【算法设计与分析】Strassen矩阵乘法(分治、剪枝)[附Python源码] – Machine World

【背景】

        矩阵乘法是线性代数中最常见的问题之一,它不仅在数值计算中具有广泛的应用,还是现代机器学习技术中必不可少的基石。

【定义】

ABn×nABn×n

即:

An×nBn×n=Cn×n

ABCCij

Cij=k=1nAikBkj

【分析】

        若按照上述提及的公式一次对矩阵A、B进行乘积运算。计算C中每一个元素Cij 需做n次乘法和n-1次加法运算,因此,欲计算出C中每一个元素的时间复杂度为O(n3)

其源码如下:

1
2
3
4
5
6
7
8
9
10
11
def traditional(matrix1, matrix2):
    matrix3 = []
    for in range(0len(matrix1)):
        temp = []
        for in range(0len(matrix2)):
            = 0
            for in range(0len(matrix1)):
                += matrix1[i][k] * matrix2[k][j]
            temp.append(t)
        matrix3.append(temp)
    return matrix3

【算法引出-分治法】

        根据n阶(此处为了方便叙述,我们假设n是2的幂)矩阵的相关特性,我们可以将每一块矩阵都分为4个大小相等的子矩阵,每一个子矩阵都是n/2×n/2的方阵。于是我们将方程C= AB重写为下述形式:

[C11C12C21C22]=[A11A12A21A22][B11B12B21B22]

由此可得:

C11=A11B11+A12B21C12=A11B12+A12B22C21=A21B11+A22B21C22=A21B12+A22B22

其分治递推式如下:

T(n)={O(1)n=28T(n/2)+O(n2)n>2

利用扩展递归求解得出:

T(n)=O(n3)

初次分治,得出的结果与传统公式求解的时间复杂度并没有改变,即这样的分治是徒劳的。

【Strassen矩阵乘法-分治、剪枝】

        Strassen算法的核心思想是令递归树稍微不那么茂盛一点儿, 即只递归进行7次而不是8次n/2×n/2 矩阵的乘法。减少一次矩阵乘法带来的代价可能是额外几次n/2×n/2矩阵的加法,但只是常数次 。

算法描述如下:

先按照先前的分治思想中矩阵分解的方法将A,B,C进行分解

创建如下7个n/2×n/2的矩阵M1,M2,M3,,M7:

M1=A11(B12B22)M2=(A11+A12)B22M3=(A21+A22)B11M4=A22(B21B11)M5=(A11+A22)(B11+B22)M6=(A12A22)(B21+B22)M7=(A11A21)(B11+B12)

做完这7次乘法后,再做若干次加减法就可以得到C11,C12,C21,C22,他们的计算公式如下:

C11=M5+M4M2+M6C12=M1+M2C21=M3+M4C22=M5+M1M3M7

其分治递推式如下:

T(n)={O(1)n=27T(n/2)+O(n2)n>2

求出:

T(n)=O(nlog7)O(n2.81)

其Python源代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# -*- coding:utf-8 -*-
 
def mergeMatrix(A11, A12, A13, A14):
    = len(A11)
    for in range(0, n):
        A11[i].extend(A12[i])
        A13[i].extend(A14[i])
    for in range(0,n):
        A11.append(A13[i])
    return A11
def division(matrix):
    A11 = []
    A12 = []
    A21 = []
    A22 = []
    half_size = int(len(matrix) / 2)
    for in range(0, half_size):
        A11.append(matrix[i][:half_size])
        A12.append(matrix[i][half_size:])
    for in range(half_size, len(matrix)):
        A21.append(matrix[j][:half_size])
        A22.append(matrix[j][half_size:])
    return A11, A12, A21, A22
 
 
class Strassen:
    def add(self, m1, m2, size):
        matrix = []
        for in range(0, size):
            temp = []
            for in range(0, size):
                temp.append(m1[i][j] + m2[i][j])
            matrix.append(temp)
        return matrix
 
    def sub(self, m1, m2, size):
        matrix = []
        for in range(0, size):
            temp = []
            for in range(0, size):
                temp.append(m1[i][j] - m2[i][j])
            matrix.append(temp)
        return matrix
 
    def multiply(self, m1, m2, size):
        if(size == 1):
            return [[m1[0][0* m2[0][0]]]
        A11, A12, A21, A22 = division(m1)
        B11, B12, B21, B22 = division(m2)
        size = int(size / 2)
 
 
        #calculate M1 = A11(B12 - B22)
        m1 = self.multiply(A11, self.sub(B12, B22, size), size)
 
        #calculate M2 = (A11 + A12)B22
        m2 = self.multiply(self.add(A11, A12, size), B22, size)
 
        #calculate M3 = (A21 + A22)B11
        m3 = self.multiply(self.add(A21, A22, size), B11, size)
 
        #calculate M4 = A22(B21 - B11)
        m4 = self.multiply(A22, self.sub(B21, B11,size), size)
 
        #calculate M5 = (A11 + A22)(B11 + B22)
        m5 = self.multiply(self.add(A11, A22, size), self.add(B11, B22, size), size)
 
        #calculate M6 = (A12 - A22)(B21 + B22)
        m6 = self.multiply(self.sub(A12, A22, size), self.add(B21, B22, size), size)
 
        #calculate M7 = (A11 - A21)(B11 + B12)
        m7 = self.multiply(self.sub(A11, A21, size), self.add(B11, B12, size), size)
 
        #calculate C11 = M5 + M4 - M2 + M6
        C11 = self.add(self.sub(self.add(m5, m4, size), m2, size), m6, size)
 
        #calculate C12 = M1 + M2
        C12 = self.add(m1, m2, size)
 
        #calculate C21 = M3 + M4
        C21 = self.add(m3, m4, size)
 
        #calculate C22 = M5 + M1 - M3 -M7
        C22 = self.sub(self.sub(self.add(m5,m1,size), m3,size), m7,size)
 
        return mergeMatrix(C11, C12, C21, C22)
 
 
= Strassen()
#print(s.sub([[1,2],[3,4]],[[1,2],[3,4]]))
matrixA = [
    [1111],
    [1234],
    [1234],
    [1234]
]
matrixB = [
    [1234],
    [1234],
    [1234],
    [1234]
]
print(s.multiply(matrixA, matrixB, 4))

【总结】

  • 在此问题中,相对于传统算法的复杂度O(n3),使用分治+剪枝的Strassen算法的表现O(n2.81)更胜一筹。因此我们知道,在具有某些可以分治处理性质的问题中,也可以多利用分治法思想对问题进行求解。

  • 对矩阵乘法问题的研究中,Hopcroft和Kerr已证明,要计算2个2×2 矩阵的乘积,7次乘法是必要的。因此,要想进一步改进矩阵乘法的时间复杂性,就不能再基于计算2×2矩阵的7次乘法这样的方法了。或许应当研究3×3 或5×5矩阵的更好算法。

  • 在Strassen之后又有许多算法改进了矩阵乘法的计算时间 复杂性。目前最好的计算时间上界是 O(n2.376)

  • 到目前为止仍无法确切知道矩阵乘法的时间复杂性,关于这一研究课题还有许多工作可做。

【参考文献】

  • 王红梅,胡明.《算法设计与分析》[M].清华大学出版社

  • 王晓东.《算法设计与分析(第五版)》[M].电子工业出版社

作者 WellLee

《【算法设计与分析】Strassen矩阵乘法(分治、剪枝)[附Python源码]》有2条评论

回复 【2020年话】假期目标 - Machine World 取消回复

您的邮箱地址不会被公开。 必填项已用 * 标注