【背景】
矩阵乘法是线性代数中最常见的问题之一,它不仅在数值计算中具有广泛的应用,还是现代机器学习技术中必不可少的基石。
【定义】
即:
【分析】
若按照上述提及的公式一次对矩阵A、B进行乘积运算。计算C中每一个元素
其源码如下:
1 2 3 4 5 6 7 8 9 10 11 | def traditional(matrix1, matrix2): matrix3 = [] for i in range ( 0 , len (matrix1)): temp = [] for j in range ( 0 , len (matrix2)): t = 0 for k in range ( 0 , len (matrix1)): t + = matrix1[i][k] * matrix2[k][j] temp.append(t) matrix3.append(temp) return matrix3 |
【算法引出-分治法】
根据n阶(此处为了方便叙述,我们假设n是2的幂)矩阵的相关特性,我们可以将每一块矩阵都分为4个大小相等的子矩阵,每一个子矩阵都是n/2×n/2的方阵。于是我们将方程C= AB重写为下述形式:
由此可得:
其分治递推式如下:
利用扩展递归求解得出:
初次分治,得出的结果与传统公式求解的时间复杂度并没有改变,即这样的分治是徒劳的。
【Strassen矩阵乘法-分治、剪枝】
Strassen算法的核心思想是令递归树稍微不那么茂盛一点儿, 即只递归进行7次而不是8次n/2×n/2 矩阵的乘法。减少一次矩阵乘法带来的代价可能是额外几次n/2×n/2矩阵的加法,但只是常数次 。
算法描述如下:
先按照先前的分治思想中矩阵分解的方法将A,B,C进行分解
创建如下7个
做完这7次乘法后,再做若干次加减法就可以得到
其分治递推式如下:
求出:
其Python源代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 | # -*- coding:utf-8 -*- def mergeMatrix(A11, A12, A13, A14): n = len (A11) for i in range ( 0 , n): A11[i].extend(A12[i]) A13[i].extend(A14[i]) for i in range ( 0 ,n): A11.append(A13[i]) return A11 def division(matrix): A11 = [] A12 = [] A21 = [] A22 = [] half_size = int ( len (matrix) / 2 ) for i in range ( 0 , half_size): A11.append(matrix[i][:half_size]) A12.append(matrix[i][half_size:]) for j in range (half_size, len (matrix)): A21.append(matrix[j][:half_size]) A22.append(matrix[j][half_size:]) return A11, A12, A21, A22 class Strassen: def add( self , m1, m2, size): matrix = [] for i in range ( 0 , size): temp = [] for j in range ( 0 , size): temp.append(m1[i][j] + m2[i][j]) matrix.append(temp) return matrix def sub( self , m1, m2, size): matrix = [] for i in range ( 0 , size): temp = [] for j in range ( 0 , size): temp.append(m1[i][j] - m2[i][j]) matrix.append(temp) return matrix def multiply( self , m1, m2, size): if (size = = 1 ): return [[m1[ 0 ][ 0 ] * m2[ 0 ][ 0 ]]] A11, A12, A21, A22 = division(m1) B11, B12, B21, B22 = division(m2) size = int (size / 2 ) #calculate M1 = A11(B12 - B22) m1 = self .multiply(A11, self .sub(B12, B22, size), size) #calculate M2 = (A11 + A12)B22 m2 = self .multiply( self .add(A11, A12, size), B22, size) #calculate M3 = (A21 + A22)B11 m3 = self .multiply( self .add(A21, A22, size), B11, size) #calculate M4 = A22(B21 - B11) m4 = self .multiply(A22, self .sub(B21, B11,size), size) #calculate M5 = (A11 + A22)(B11 + B22) m5 = self .multiply( self .add(A11, A22, size), self .add(B11, B22, size), size) #calculate M6 = (A12 - A22)(B21 + B22) m6 = self .multiply( self .sub(A12, A22, size), self .add(B21, B22, size), size) #calculate M7 = (A11 - A21)(B11 + B12) m7 = self .multiply( self .sub(A11, A21, size), self .add(B11, B12, size), size) #calculate C11 = M5 + M4 - M2 + M6 C11 = self .add( self .sub( self .add(m5, m4, size), m2, size), m6, size) #calculate C12 = M1 + M2 C12 = self .add(m1, m2, size) #calculate C21 = M3 + M4 C21 = self .add(m3, m4, size) #calculate C22 = M5 + M1 - M3 -M7 C22 = self .sub( self .sub( self .add(m5,m1,size), m3,size), m7,size) return mergeMatrix(C11, C12, C21, C22) s = Strassen() #print(s.sub([[1,2],[3,4]],[[1,2],[3,4]])) matrixA = [ [ 1 , 1 , 1 , 1 ], [ 1 , 2 , 3 , 4 ], [ 1 , 2 , 3 , 4 ], [ 1 , 2 , 3 , 4 ] ] matrixB = [ [ 1 , 2 , 3 , 4 ], [ 1 , 2 , 3 , 4 ], [ 1 , 2 , 3 , 4 ], [ 1 , 2 , 3 , 4 ] ] print (s.multiply(matrixA, matrixB, 4 )) |
【总结】
-
在此问题中,相对于传统算法的复杂度
,使用分治+剪枝的Strassen算法的表现 更胜一筹。因此我们知道,在具有某些可以分治处理性质的问题中,也可以多利用分治法思想对问题进行求解。 -
对矩阵乘法问题的研究中,Hopcroft和Kerr已证明,要计算2个2×2 矩阵的乘积,7次乘法是必要的。因此,要想进一步改进矩阵乘法的时间复杂性,就不能再基于计算2×2矩阵的7次乘法这样的方法了。或许应当研究3×3 或5×5矩阵的更好算法。
-
在Strassen之后又有许多算法改进了矩阵乘法的计算时间 复杂性。目前最好的计算时间上界是
。 -
到目前为止仍无法确切知道矩阵乘法的时间复杂性,关于这一研究课题还有许多工作可做。
【参考文献】
-
王红梅,胡明.《算法设计与分析》[M].清华大学出版社
-
王晓东.《算法设计与分析(第五版)》[M].电子工业出版社
[…] Strassen矩阵乘法(分治、剪枝)[附Python源码] […]
[…] 2、分治法 20190304 20190304 20190304 […]