现充|junyu33

有限域中乘法在 intel 指令集上的高性能实现

/* multiplication in galois field with reduction */
#ifdef __x86_64__
__attribute__((target("sse2,pclmul")))
#endif
inline void gfmul (__m128i a, __m128i b, __m128i *res){
	__m128i tmp3, tmp6, tmp7, tmp8, tmp9, tmp10, tmp11, tmp12;
	__m128i XMMMASK = _mm_setr_epi32(0xffffffff, 0x0, 0x0, 0x0);
	mul128(a, b, &tmp3, &tmp6);
	tmp7 = _mm_srli_epi32(tmp6, 31);
	tmp8 = _mm_srli_epi32(tmp6, 30);
	tmp9 = _mm_srli_epi32(tmp6, 25);
	tmp7 = _mm_xor_si128(tmp7, tmp8);
	tmp7 = _mm_xor_si128(tmp7, tmp9);
	tmp8 = _mm_shuffle_epi32(tmp7, 147);

	tmp7 = _mm_and_si128(XMMMASK, tmp8);
	tmp8 = _mm_andnot_si128(XMMMASK, tmp8);
	tmp3 = _mm_xor_si128(tmp3, tmp8);
	tmp6 = _mm_xor_si128(tmp6, tmp7);
	tmp10 = _mm_slli_epi32(tmp6, 1);
	tmp3 = _mm_xor_si128(tmp3, tmp10);
	tmp11 = _mm_slli_epi32(tmp6, 2);
	tmp3 = _mm_xor_si128(tmp3, tmp11);
	tmp12 = _mm_slli_epi32(tmp6, 7);
	tmp3 = _mm_xor_si128(tmp3, tmp12);

	*res = _mm_xor_si128(tmp3, tmp6);
}

prerequisite

这里的不可约多项式为x128+x7+x2+x+1.

mul128

首先讨论在不进行约简时,两个 128 bit 多项式的乘法运算。为了编写方便,我们将多项式转化为对应的二进制数,例如x4+x3+x+1对应的二进制为0b11011,即27

intel 中恰好有一个指令可以实现F2n的乘法(即非进位乘法,Carry-less Multiplication),即__m128i _mm_clmulepi64_si128(__m128i a, __m128i b, const int imm8);。它的语法如下:

我们可以借助64位的非进位乘法来实现128位的非进位乘法,具体步骤如下:

/* multiplication in galois field without reduction */
#ifdef __x86_64__
__attribute__((target("sse2,pclmul")))
inline void mul128(__m128i a, __m128i b, __m128i *res1, __m128i *res2) {
	__m128i tmp3, tmp4, tmp5, tmp6;
	tmp3 = _mm_clmulepi64_si128(a, b, 0x00);
	tmp4 = _mm_clmulepi64_si128(a, b, 0x10);
	tmp5 = _mm_clmulepi64_si128(a, b, 0x01);
	tmp6 = _mm_clmulepi64_si128(a, b, 0x11);

	tmp4 = _mm_xor_si128(tmp4, tmp5);
	tmp5 = _mm_slli_si128(tmp4, 8);
	tmp4 = _mm_srli_si128(tmp4, 8);
	tmp3 = _mm_xor_si128(tmp3, tmp5);
	tmp6 = _mm_xor_si128(tmp6, tmp4);
	// initial mul now in tmp3, tmp6
	*res1 = tmp3;
	*res2 = tmp6;
}

以下给出该算法的证明:

证明

设被乘数A的高低位分别为a1a2,同理乘数B的高低位分别为b1b2。在算法前半段,我们有:

tmp6=a1b1

tmp5=a2b1

tmp4=a1b2

tmp3=a2b2

在后半段,我们有:

tmp5=264((tmp4+tmp5)%264)

tmp4=264(tmp4+tmp5)

tmp3=264((tmp4+tmp5)%264)+tmp3

tmp6=264(tmp4+tmp5)+tmp6

故:

res1+res2

=2128tmp6+tmp3

=2128tmp6+2128(264(tmp4+tmp5)264(tmp4+tmp5)%1)+264((tmp4+tmp5)%264)+tmp3

=2128tmp6+264(tmp4+tmp5)264(tmp4+tmp5)%2128+264(tmp4+tmp5)%2128+tmp3

=2128tmp6+264(tmp4+tmp5)+tmp3

而:

AB

=(264a1+a2)(264b1+b2)

=2128tmp6+264(tmp4+tmp5)+tmp3

res1+res2=AB

gfmul

实际上,我们还需要将进行非进位乘法后的结果对x128+x7+x2+x+1进行约简,而sagemath有专门针对有限域处理的轮子,因此我们可以编写一个sagemath的程序来验证结果:

# 定义有限域上的多项式环
P.<x> = PolynomialRing(GF(2))

# 定义模多项式
mod_poly = x^128 + x^7 + x^2 + x + 1

# 定义有限域 GF(2^128)
K.<a> = GF(2**128, modulus=mod_poly)

# 将整数转换为多项式
def int_to_poly(n):
    bin_str = bin(n)[2:]  # 将整数转换为二进制字符串,去掉'0b'前缀
    return sum([int(bit) * x^i for i, bit in enumerate(reversed(bin_str))])

# 将多项式转换回整数
def poly_to_int(p):
    return sum([int(coeff) * (2**i) for i, coeff in enumerate(p.list())])

# 示例:整数表示的元素
f_int, g_int = 98195696920426533817649554218743231661, 43027262476631949179376797970948942433

# 转换为多项式
f_poly = int_to_poly(f_int)
g_poly = int_to_poly(g_int)

# 在 GF(2^128) 上进行运算
f = K(f_poly)
g = K(g_poly)

# 运算示例
add_result = f + g
mul_result = f * g

# 将结果转换回整数
add_result_int = poly_to_int(add_result.polynomial())
mul_result_int = poly_to_int(mul_result.polynomial())

print("f (as int) =", f_int)
print("g (as int) =", g_int)
print("f + g (as int) =", add_result_int)
print("f * g (as int) =", mul_result_int)
'''
f (as int) = 98195696920426533817649554218743231661
g (as int) = 43027262476631949179376797970948942433
f + g (as int) = 140241067422779931769655503331313065676
f * g (as int) = 30853704161780158484268560045100192027
'''

当有了一个正常的对拍程序后,我们就可以分析并调试文章开头给出的代码了。

证明

我们先分析,如何用tmp3,tmp6的位运算来表示F2128(ab)的结果:

tmp3, tmp6 = mul128(a, b)

// 假设 T = ((tmp6 >> 127) ^ (tmp6 >> 126) ^ (tmp6 >> 121))

// 这里的 nr() 表示未约简的意思
res = nr(tmp3 ^ ((tmp6) << 128))
    = tmp3 ^ nr(tmp6 << 128)
    = tmp3 ^ nr(tmp6) ^ nr(tmp6 << 1) ^ nr(tmp6 << 2) ^ nr(tmp6 << 7)
    = tmp3 ^ nr(tmp6) ^ (nr(tmp6) << 1) ^ (nr(tmp6) << 2) ^ (nr(tmp6) << 7)
    // 这里 tmp6 >> 128 肯定是0,所以省略
    = tmp3 ^ (tmp6 ^ T)
           ^ ((tmp6 ^ T) << 1)
           ^ ((tmp6 ^ T) << 2)
           ^ ((tmp6 ^ T) << 7)

但事实上为了计算方便,我们取T = ((tmp6 >> 31) ^ (tmp6 >> 30) ^ (tmp6 >> 25)),因此我们有:

def gfmul(a, b):
    tmp3, tmp6 = mul128(a, b)
    T = ((tmp6 >> 31) ^ (tmp6 >> 30) ^ (tmp6 >> 25))

    res = tmp3 ^ \
      (T >> 96) ^ tmp6 ^ \
      (((T >> 96) ^ tmp6) << 1) ^ \
      (((T >> 96) ^ tmp6) << 2) ^ \
      (((T >> 96) ^ tmp6) << 7)
    return res % (2 ** 128 - 1)

经过与sagemath的比对,它们的输出完全一致。

指令集化

但在128bit数中,并没有将一个128bit数直接右移或左移的操作,如果使用C语言的<<或者>>运算符只会将组成__m128i的四个32位向量一起左移或者右移,从而达不到正确的结果。

实际上,这两个运算符对应的是_mm_slli_epi32_mm_srli_epi32这两个指令,前者的语法如下,后者类似:

__m128i _mm_slli_epi32(__m128i a, int imm8);

_mm_slli_epi32 指令对寄存器 a 中的每个 32 位整数执行逻辑左移操作,移位的位数由立即数 imm8 指定。所有的整数都按照相同的位数进行左移,移出的高位被丢弃,低位补零。

通过相关推导,对于任何128bit数x,我们有以下恒等式(0<i<32):

x << i === mm_slli_epi32(x, i) ^ (mm_srli_epi32(x, 32-i) << 32)

证明比较容易,这里略去。

因此,我们可以修改原来的函数:

def gfmul(a, b):
    tmp3, tmp6 = mul128(a, b)
    T = ((tmp6 >> 31) ^ (tmp6 >> 30) ^ (tmp6 >> 25))
    # T = (srli(tmp6, 31) ^ srli(tmp6, 30) ^ srli(tmp6, 25))

    res = tmp3 ^ \
      (T >> 96) ^ tmp6 ^ \
      (((T >> 96) ^ tmp6) << 1) ^ \
      (((T >> 96) ^ tmp6) << 2) ^ \
      (((T >> 96) ^ tmp6) << 7)

    # res    = tmp3 ^ tmp6 ^ \
    #        = slli(tmp6, 1) ^ (srli(tmp6, 31) << 32) ^ \
    #          slli(tmp6, 2) ^ (srli(tmp6, 30) << 32) ^ \
    #          slli(tmp6, 7) ^ (srli(tmp6, 25) << 32) ^ \
    #          (T >> 96) ^ (T >> 96 << 1) ^ (T >> 96 << 2) ^ (T >> 96 << 7)
    # 这里因为(T >> 96 << 7)本身最多只有14位,比32位小,因此不用进行上文公式的变形
    #        = (T >> 96) ^ tmp6 \
    #          slli(tmp6 ^ (T >> 96), 1) \
    #          slli(tmp6 ^ (T >> 96), 2) \
    #          slli(tmp6 ^ (T >> 96), 7) \
    #          (T << 32) ^ tmp3

    return res & (2 ** 128 - 1)

现在我们可以回到本文开头的源码了:

void gfmul (__m128i a, __m128i b, __m128i *res){
	__m128i tmp3, tmp6, tmp7, tmp8, tmp9, tmp10, tmp11, tmp12;
	__m128i XMMMASK = _mm_setr_epi32(0xffffffff, 0x0, 0x0, 0x0);
	mul128(a, b, &tmp3, &tmp6);

	tmp7 = _mm_srli_epi32(tmp6, 31);
	tmp8 = _mm_srli_epi32(tmp6, 30);
	tmp9 = _mm_srli_epi32(tmp6, 25);
	tmp7 = _mm_xor_si128(tmp7, tmp8);
	tmp7 = _mm_xor_si128(tmp7, tmp9);
	// tmp7 = T = (srli(tmp6, 31) ^ srli(tmp6, 30) ^ srli(tmp6, 25))
	tmp8 = _mm_shuffle_epi32(tmp7, 147);

	tmp7 = _mm_and_si128(XMMMASK, tmp8);
	// tmp7 = T >> 96
	tmp8 = _mm_andnot_si128(XMMMASK, tmp8);
	// tmp8 = T << 32

	tmp3 = _mm_xor_si128(tmp3, tmp8);
	// tmp3 = ((T << 32) ^ tmp3)
	tmp6 = _mm_xor_si128(tmp6, tmp7);
	// tmp6 = ((T >> 96) ^ tmp6)

	tmp10 = _mm_slli_epi32(tmp6, 1);
	tmp3 = _mm_xor_si128(tmp3, tmp10);
	tmp11 = _mm_slli_epi32(tmp6, 2);
	tmp3 = _mm_xor_si128(tmp3, tmp11);
	tmp12 = _mm_slli_epi32(tmp6, 7);
	tmp3 = _mm_xor_si128(tmp3, tmp12);

	*res = _mm_xor_si128(tmp3, tmp6);
	// res = ((T << 32) ^ tmp3) ^ \
		slli((T >> 96) ^ tmp6), 1) ^ \
		slli((T >> 96) ^ tmp6), 2) ^ \
		slli((T >> 96) ^ tmp6), 7) ^ \
		((T >> 96) ^ tmp6)             
}

因此,我们就实现了使用 intel 指令集在F2128上进行乘法运算的目标。