AES(Advanced Encryption Standard)

一、AES概述

AES加密算法,即Rijndael算法,是一种对称分组密码,它可以使用长度为128、192和256位的密钥处理128位的数据块。这些不同的“风格”可以称为“AES-128”、“AES-192” 和 “AES-256”。

本文使用python实现AES-128,使用CBC模式。

二、AES算法框架

  • 字节代替(SubBytes):用一个S盒完成分组的字节到字节的代替。
  • 行移位(ShiftRows):一个简单的置换。
  • 列混淆(MixColumns):利用域GF(2^8)上的算术特性的一个代替。
  • 轮密钥加(AddRoundKey):当前分组和扩展密钥的一部分进行按位异或XOR。

三、密钥扩展

解释

AES首先将初始密钥输入到一个4*4的状态矩阵中,如下图所示:

这个4*4矩阵的每一列的4个字节组成一个字,矩阵4列的4个字依次命名为W[0]、W[1]、W[2]和W[3],它们构成一个以字为单位的数组W。例如,设密钥K为”abcdefghijklmnop”,则K0 = ‘a’,K1 = ‘b’, K2 = ‘c’,K3 = ‘d’,W[0] = “abcd”。

接着,对W数组扩充40个新列,构成总共44列的扩展密钥数组。新列以如下的递归方式产生:

1.如果i不是4的倍数,那么第i列由如下等式确定:

W[i]=W[i-4]⨁W[i-1]

2.如果i是4的倍数,那么第i列由如下等式确定:

W[i]=W[i-4]⨁T(W[i-1])

其中,T是一个有点复杂的函数。

函数T由3部分组成:字循环、字节代换和轮常量异或,这3部分的作用分别如下。

a.字循环:将1个字中的4个字节循环左移1个字节。即将输入字[b0, b1, b2, b3]变换成[b1,b2,b3,b0]。

b.字节代换:对字循环的结果使用S盒进行字节代换。

c.轮常量异或:将前两步的结果同轮常量Rcon[j]进行异或,其中j表示轮数。

轮常量Rcon[j]是一个字,其值见下表。

j 1 2 3 4 5
Rcon[j] 01 00 00 00 02 00 00 00 04 00 00 00 08 00 00 00 10 00 00 00
j 6 7 8 9 10
Rcon[j] 20 00 00 00 40 00 00 00 80 00 00 00 1B 00 00 00 36 00 00 00

下面举个例子:

设初始的128位密钥为:

3C A1 0B 21 57 F0 19 16 90 2E 13 80 AC C1 07 BD

那么4个初始值为:

W[0] = 3C A1 0B 21

W[1] = 57 F0 19 16

W[2] = 90 2E 13 80

W[3] = AC C1 07 BD

下面求扩展的第1轮的子密钥(W[4],W[5],W[6],W[7])。

由于4是4的倍数,所以:

W[4] = W[0] ⨁ T(W[3])

T(W[3])的计算步骤如下:

循环地将W[3]的元素移位:AC C1 07 BD变成C1 07 BD AC;

将 C1 07 BD AC 作为S盒的输入,输出为78 C5 7A 91;

将78 C5 7A 91与第一轮轮常量Rcon[1]进行异或运算,将得到79 C5 7A 91,因此,T(W[3])=79 C5 7A 91,故

W[4] = 3C A1 0B 21 ⨁ 79 C5 7A 91 = 45 64 71 B0

其余的3个子密钥段的计算如下:

W[5] = W[1] ⨁ W[4] = 57 F0 19 16 ⨁ 45 64 71 B0 = 12 94 68 A6

W[6] = W[2] ⨁ W[5] =90 2E 13 80 ⨁ 12 94 68 A6 = 82 BA 7B 26

W[7] = W[3] ⨁ W[6] = AC C1 07 BD ⨁ 82 BA 7B 26 = 2E 7B 7C 9B

所以,第一轮的密钥为 45 64 71 B0 12 94 68 A6 82 BA 7B 26 2E 7B 7C 9B。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
def gen_key(key):
key_bytes = key.encode('utf-8')
if len(key_bytes) != 16:
raise ValueError("Key must be 16 bytes after UTF-8 encoding.")
key_hex = [hex(b) for b in key_bytes]
key_rotate = []
w = [[] for i in range(0, 44)]
for i in range(0, 16):
w[i // 4].append(key_hex[i])
for i in range(4, 44):
gw = copy.deepcopy(w[i - 1])
if i % 4 == 0:
gw[0], gw[1], gw[2], gw[3] = gw[1], gw[2], gw[3], gw[0]
gw = substitute(gw) #g(w(i-1))
gw[0] = hex(int(gw[0], 16) ^ rcon[i // 4 - 1])
for j in range(0, 4):
w[i].append(hex(int(gw[j], 16) ^ int(w[i-4][j], 16)))
key_rotate = [w[i * 4] + w[i * 4 + 1] + w[i * 4 + 2] + w[ i* 4 + 3] for i in range(0, 11)] # 轮密钥列表,每个元素都是有16个字节的列表
return key_rotate

四、加解密过程

字节代换(SubBytes)

解释

AES的字节代换其实就是一个简单的查表操作。AES定义了一个S盒和一个逆S盒。
AES的S盒和逆S盒:

状态矩阵中的元素按照下面的方式映射为一个新的字节:把该字节的高4位作为行值,低4位作为列值,取出S盒或者逆S盒中对应的行的元素作为输出。例如,加密时,输出的字节S1为0x12,则查S盒的第0x01行和0x02列,得到值0xc9,然后替换S1原有的0x12为0xc9。状态矩阵经字节代换后的图如下:

S盒是按照这个公式计算出来的:GF(2^8) = GF(2) [x]/(x^8 + x^4 + x^3 + x + 1)

AES128加密-S盒和逆S盒构造推导及代码实现Rijndael S-box初识白盒密码有限域算术

代码
1
2
3
4
5
6
7
8
def substitute(m_hex, inverse=False):
m_s = []
box = s_box if not inverse else i_s_box
for i in m_hex:
x, y = int(i, 16) // 16, int(i, 16) % 16
temp = hex(box[x*16+y])
m_s.append(temp)
return m_s

行移位(ShiftRows)

解释

行移位是一个4x4的矩阵内部字节之间的置换,用于提供算法的扩散性。

行移位变换完成基于行的循环移位操作,变换方法为:第0行不变,第1行循环左移1个字节,第2行循环左移两个字节,第3行循环左移3个字节。如下图所示:(从上往下读)

逆向行移位即是相反的操作,即:第一行保持不变,第二行循环右移1个字节,第三行循环右移两个字节,第四行循环左移3个字节。

代码
1
2
3
4
5
6
7
8
9
10
11
12
def shiftrows(a, inverse=False): #inverse为True时表示为逆操作,默认为False
if not inverse:
return [ a[0], a[5], a[10], a[15],
a[4], a[9], a[14], a[3],
a[8], a[13], a[2], a[7],
a[12], a[1], a[6], a[11] ]
else :
return[ a[0], a[13], a[10], a[7],
a[4], a[1], a[14], a[11],
a[8], a[5], a[2], a[15],
a[12], a[9], a[6], a[3] ]

列混淆(MixColumns)

解释

列混淆是将状态数组的每一列乘以一个矩阵,其中乘法是在有限域GF(2^8)上进行的,分为正向列混淆和列混淆逆变换两种操作。正向列混淆用于加密操作,列混淆逆变换用于解密操作。
正向列混淆变换过程:

逆向列混淆变换可以再乘以矩阵的逆得到。

代码
1
2
3
4
5
6
7
8
9
10
11
12
def mixcolumn(m_row, inverse=False):
matrix = mix_column_matrix if not inverse else i_mix_column_matrix
m_col = []
for i in range(0, 16):
x, y = i % 4, i // 4
result = 0
for j in range(0, 4):
result ^= (mul(matrix[x * 4 + j], int(m_row[y * 4 + j], 16)))
result = mod(result)
m_col.append(hex(result))
return m_col

轮密钥加(AddRoundKey)

解释

轮密钥加是将轮密钥简单地与状态进行逐比特异或。这个操作相对简单,其依据的原理是“任何数和自身的异或结果为0”。加密过程中,每轮的输入与轮子密钥异或一次;因此,解密时再异或上该轮的轮子密钥即可恢复。

轮密钥加过程可以看成是字逐位异或的结果,也可以看成字节级别或者位级别的操作。也就是说,可以看成S0 S1 S2 S3 组成的32位字与W[4i]的异或运算。

五、完整代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
#分组长度为16个字节,加密10轮
#AES-128
import copy
import os

# 密钥扩展
def gen_key(key):
key_bytes = key.encode('utf-8')
if len(key_bytes) != 16:
raise ValueError("Key must be 16 bytes after UTF-8 encoding.")
key_hex = [hex(b) for b in key_bytes]
key_rotate = []
w = [[] for i in range(0, 44)]
for i in range(0, 16):
w[i // 4].append(key_hex[i])
for i in range(4, 44):
gw = copy.deepcopy(w[i - 1])
if i % 4 == 0:
gw[0], gw[1], gw[2], gw[3] = gw[1], gw[2], gw[3], gw[0]
gw = substitute(gw) #g(w(i-1))
gw[0] = hex(int(gw[0], 16) ^ rcon[i // 4 - 1])
for j in range(0, 4):
w[i].append(hex(int(gw[j], 16) ^ int(w[i-4][j], 16)))
key_rotate = [w[i * 4] + w[i * 4 + 1] + w[i * 4 + 2] + w[ i* 4 + 3] for i in range(0, 11)] # 轮密钥列表,每个元素都是有16个字节的列表
return key_rotate

# 两个多项式相乘
def mul(poly1, poly2):
result = 0
for index in range(poly2.bit_length()):
if poly2 & (1 << index):
result ^= (poly1 << index)
return result

# 多项式poly模多项式100011011
def mod(poly, mod = 0b100011011):
while poly.bit_length() > 8:
poly ^= (mod << (poly.bit_length() - 9))
return poly

#对输入的十六进制列表 m_hex 进行字节代换操作。
def substitute(m_hex, inverse=False):
m_s = []
box = s_box if not inverse else i_s_box
for i in m_hex:
x, y = int(i, 16) // 16, int(i, 16) % 16
temp = hex(box[x*16+y])
m_s.append(temp)
return m_s

s_box = [0x63, 0x7C, 0x77, 0x7B, 0xF2, 0x6B, 0x6F, 0xC5, 0x30, 0x01, 0x67, 0x2B, 0xFE, 0xD7, 0xAB, 0x76,
0xCA, 0x82, 0xC9, 0x7D, 0xFA, 0x59, 0x47, 0xF0, 0xAD, 0xD4, 0xA2, 0xAF, 0x9C, 0xA4, 0x72, 0xC0,
0xB7, 0xFD, 0x93, 0x26, 0x36, 0x3F, 0xF7, 0xCC, 0x34, 0xA5, 0xE5, 0xF1, 0x71, 0xD8, 0x31, 0x15,
0x04, 0xC7, 0x23, 0xC3, 0x18, 0x96, 0x05, 0x9A, 0x07, 0x12, 0x80, 0xE2, 0xEB, 0x27, 0xB2, 0x75,
0x09, 0x83, 0x2C, 0x1A, 0x1B, 0x6E, 0x5A, 0xA0, 0x52, 0x3B, 0xD6, 0xB3, 0x29, 0xE3, 0x2F, 0x84,
0x53, 0xD1, 0x00, 0xED, 0x20, 0xFC, 0xB1, 0x5B, 0x6A, 0xCB, 0xBE, 0x39, 0x4A, 0x4C, 0x58, 0xCF,
0xD0, 0xEF, 0xAA, 0xFB, 0x43, 0x4D, 0x33, 0x85, 0x45, 0xF9, 0x02, 0x7F, 0x50, 0x3C, 0x9F, 0xA8,
0x51, 0xA3, 0x40, 0x8F, 0x92, 0x9D, 0x38, 0xF5, 0xBC, 0xB6, 0xDA, 0x21, 0x10, 0xFF, 0xF3, 0xD2,
0xCD, 0x0C, 0x13, 0xEC, 0x5F, 0x97, 0x44, 0x17, 0xC4, 0xA7, 0x7E, 0x3D, 0x64, 0x5D, 0x19, 0x73,
0x60, 0x81, 0x4F, 0xDC, 0x22, 0x2A, 0x90, 0x88, 0x46, 0xEE, 0xB8, 0x14, 0xDE, 0x5E, 0x0B, 0xDB,
0xE0, 0x32, 0x3A, 0x0A, 0x49, 0x06, 0x24, 0x5C, 0xC2, 0xD3, 0xAC, 0x62, 0x91, 0x95, 0xE4, 0x79,
0xE7, 0xC8, 0x37, 0x6D, 0x8D, 0xD5, 0x4E, 0xA9, 0x6C, 0x56, 0xF4, 0xEA, 0x65, 0x7A, 0xAE, 0x08,
0xBA, 0x78, 0x25, 0x2E, 0x1C, 0xA6, 0xB4, 0xC6, 0xE8, 0xDD, 0x74, 0x1F, 0x4B, 0xBD, 0x8B, 0x8A,
0x70, 0x3E, 0xB5, 0x66, 0x48, 0x03, 0xF6, 0x0E, 0x61, 0x35, 0x57, 0xB9, 0x86, 0xC1, 0x1D, 0x9E,
0xE1, 0xF8, 0x98, 0x11, 0x69, 0xD9, 0x8E, 0x94, 0x9B, 0x1E, 0x87, 0xE9, 0xCE, 0x55, 0x28, 0xDF,
0x8C, 0xA1, 0x89, 0x0D, 0xBF, 0xE6, 0x42, 0x68, 0x41, 0x99, 0x2D, 0x0F, 0xB0, 0x54, 0xBB, 0x16]
i_s_box = [0x52, 0x09, 0x6A, 0xD5, 0x30, 0x36, 0xA5, 0x38, 0xBF, 0x40, 0xA3, 0x9E, 0x81, 0xF3, 0xD7, 0xFB,
0x7C, 0xE3, 0x39, 0x82, 0x9B, 0x2F, 0xFF, 0x87, 0x34, 0x8E, 0x43, 0x44, 0xC4, 0xDE, 0xE9, 0xCB,
0x54, 0x7B, 0x94, 0x32, 0xA6, 0xC2, 0x23, 0x3D, 0xEE, 0x4C, 0x95, 0x0B, 0x42, 0xFA, 0xC3, 0x4E,
0x08, 0x2E, 0xA1, 0x66, 0x28, 0xD9, 0x24, 0xB2, 0x76, 0x5B, 0xA2, 0x49, 0x6D, 0x8B, 0xD1, 0x25,
0x72, 0xF8, 0xF6, 0x64, 0x86, 0x68, 0x98, 0x16, 0xD4, 0xA4, 0x5C, 0xCC, 0x5D, 0x65, 0xB6, 0x92,
0x6C, 0x70, 0x48, 0x50, 0xFD, 0xED, 0xB9, 0xDA, 0x5E, 0x15, 0x46, 0x57, 0xA7, 0x8D, 0x9D, 0x84,
0x90, 0xD8, 0xAB, 0x00, 0x8C, 0xBC, 0xD3, 0x0A, 0xF7, 0xE4, 0x58, 0x05, 0xB8, 0xB3, 0x45, 0x06,
0xD0, 0x2C, 0x1E, 0x8F, 0xCA, 0x3F, 0x0F, 0x02, 0xC1, 0xAF, 0xBD, 0x03, 0x01, 0x13, 0x8A, 0x6B,
0x3A, 0x91, 0x11, 0x41, 0x4F, 0x67, 0xDC, 0xEA, 0x97, 0xF2, 0xCF, 0xCE, 0xF0, 0xB4, 0xE6, 0x73,
0x96, 0xAC, 0x74, 0x22, 0xE7, 0xAD, 0x35, 0x85, 0xE2, 0xF9, 0x37, 0xE8, 0x1C, 0x75, 0xDF, 0x6E,
0x47, 0xF1, 0x1A, 0x71, 0x1D, 0x29, 0xC5, 0x89, 0x6F, 0xB7, 0x62, 0x0E, 0xAA, 0x18, 0xBE, 0x1B,
0xFC, 0x56, 0x3E, 0x4B, 0xC6, 0xD2, 0x79, 0x20, 0x9A, 0xDB, 0xC0, 0xFE, 0x78, 0xCD, 0x5A, 0xF4,
0x1F, 0xDD, 0xA8, 0x33, 0x88, 0x07, 0xC7, 0x31, 0xB1, 0x12, 0x10, 0x59, 0x27, 0x80, 0xEC, 0x5F,
0x60, 0x51, 0x7F, 0xA9, 0x19, 0xB5, 0x4A, 0x0D, 0x2D, 0xE5, 0x7A, 0x9F, 0x93, 0xC9, 0x9C, 0xEF,
0xA0, 0xE0, 0x3B, 0x4D, 0xAE, 0x2A, 0xF5, 0xB0, 0xC8, 0xEB, 0xBB, 0x3C, 0x83, 0x53, 0x99, 0x61,
0x17, 0x2B, 0x04, 0x7E, 0xBA, 0x77, 0xD6, 0x26, 0xE1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0C, 0x7D]

rcon = [0x1, 0x2, 0x4, 0x8, 0x10, 0x20, 0x40, 0x80, 0x1B, 0x36]

def xor(a, key):
return [hex(int(ai, 16) ^ int(ki, 16)) for ai, ki in zip(a, key)]

# 行移位
def shiftrows(a, inverse=False): #inverse为True时表示为逆操作,默认为False
if not inverse:
return [ a[0], a[5], a[10], a[15],
a[4], a[9], a[14], a[3],
a[8], a[13], a[2], a[7],
a[12], a[1], a[6], a[11] ]
else :
return[ a[0], a[13], a[10], a[7],
a[4], a[1], a[14], a[11],
a[8], a[5], a[2], a[15],
a[12], a[9], a[6], a[3] ]

# 列混淆
def mixcolumn(m_row, inverse=False):
matrix = mix_column_matrix if not inverse else i_mix_column_matrix
m_col = []
for i in range(0, 16):
x, y = i % 4, i // 4
result = 0
for j in range(0, 4):
result ^= (mul(matrix[x * 4 + j], int(m_row[y * 4 + j], 16)))
result = mod(result)
m_col.append(hex(result))
return m_col

# 列混合乘的矩阵
mix_column_matrix = [0x2, 0x3, 0x1, 0x1,
0x1, 0x2, 0x3, 0x1,
0x1, 0x1, 0x2, 0x3,
0x3, 0x1, 0x1, 0x2]
# 列混合乘的逆矩阵
i_mix_column_matrix = [0xe, 0xb, 0xd, 0x9,
0x9, 0xe, 0xb, 0xd,
0xd, 0x9, 0xe, 0xb,
0xb, 0xd, 0x9, 0xe]


def aes_encrypt_block(block, key_rotate):
state = block
state = xor(state, key_rotate[0])
for rnd in range(1, 10):
state = substitute(state)
state = shiftrows(state)
state = mixcolumn(state)
state = xor(state, key_rotate[rnd])
state = substitute(state)
state = shiftrows(state)
state = xor(state, key_rotate[10])
return [int(b, 16) for b in state]

def aes_decrypt_block(block, key_rotate):
state = block
state = xor(state, key_rotate[10])
state = shiftrows(state, inverse=True)
state = substitute(state, inverse=True)
for rnd in range(9, 0, -1):
state = xor(state, key_rotate[rnd])
state = mixcolumn(state, inverse=True)
state = shiftrows(state, inverse=True)
state = substitute(state, inverse=True)
state = xor(state, key_rotate[0])
return [int(b, 16) for b in state]

def pad(data):
pad_len = 16 - (len(data) % 16)
return data + bytes([pad_len] * pad_len)

def unpad(data):
pad_len = data[-1]
return data[:-pad_len]


def aes_cbc_encrypt(plaintext, key):
key_rotate = gen_key(key)
iv = os.urandom(16)
print(f"\n[生成随机IV] (16字节): {iv.hex()}")
plaintext = pad(plaintext.encode())
blocks = [plaintext[i:i+16] for i in range(0, len(plaintext), 16)]
ciphertext = iv
prev = iv
for block in blocks:
block = bytes([b ^ p for b, p in zip(block, prev)])
encrypted = aes_encrypt_block([hex(b)[2:].zfill(2) for b in block], key_rotate)
encrypted_bytes = bytes(encrypted)
ciphertext += encrypted_bytes
prev = encrypted_bytes
return ciphertext.hex()

def aes_cbc_decrypt(ciphertext, key):
key_rotate = gen_key(key)
ciphertext = bytes.fromhex(ciphertext)
iv = ciphertext[:16]
blocks = [ciphertext[i:i+16] for i in range(16, len(ciphertext), 16)]
plaintext = b''
prev = iv
for block in blocks:
decrypted = aes_decrypt_block([hex(b)[2:].zfill(2) for b in block], key_rotate)
decrypted = bytes([d ^ p for d, p in zip(decrypted, prev)])
plaintext += decrypted
prev = block
return unpad(plaintext).decode()

if __name__ == '__main__':
mode = input("请输入模式 (encrypt/decrypt): ").strip().lower()
key = input("请输入16字节密钥: ").strip()
if len(key.encode('utf-8')) != 16:
print("错误:密钥长度必须为16字节!")
else:
if mode == "encrypt":
plaintext = input("请输入明文: ").strip()
print("加密后密文:", aes_cbc_encrypt(plaintext, key))
elif mode == "decrypt":
ciphertext = input("请输入密文 (16进制字符串): ").strip()
try:
print("解密后明文:", aes_cbc_decrypt(ciphertext, key))
except Exception as e:
print(f"解密失败: {e}")
else:
print("无效模式,请输入 encrypt 或 decrypt")

aes.py

aes1.py –有每轮结果的版本