calculate weil-pairing from 0 to 1
为了更好理解 weil-pairing 以及 miller 算法,这里使用 python 从 0 开始实现了一遍。
用到的依赖只有 pow
函数,math.inf
和深拷贝函数 copy.deepcopy
,且使用了本人并不顺手的 OOP 编程范式。
虽然写了之后对 weil-pairing 和 OOP 的理解同时加深了(
从 开始
我们在高中竞赛时学过
然而今天没有必要再讲一遍 exgcd,因为先前有一篇 post 已经详细介绍了其原理,这里我选择直接使用 python 的语法糖:
pow(x, -1, p)
来求
于是对于分数
引入椭圆曲线
这里引出本 post 的主角,一个定义在
这里我们使用一个类 EpCurve
,来完成对它的抽象。目前,这个类已经可以放入求逆元和除法两个方法,为了体现这个曲线的特性,我们可以加一个判断点
class EpCurve:
def __init__(self, a, b, p) -> None:
self.A = a
self.B = b
self.mod = p
def inverse(self, x:int) -> int:
return pow(x, -1, self.mod)
def frac(self, x:int, y:int) -> int:
return x * self.inverse(y) % self.mod
def isOnCurve(self, x, y) -> bool:
return y*y % self.mod == (x*x*x + self.A*x + self.B) % self.mod
以及对应的实例 curve = EpCurve(30, 34, 631)
。
然后我们可以在 EpCurve
的基础上,抽象椭圆曲线上的点 EpPoint
:
class EpPoint(EpCurve):
def __init__(self, x, y, curve) -> None:
super().__init__(curve.A, curve.B, curve.mod)
if not self.isOnCurve(x, y):
raise ValueError(f"Point ({x}, {y}) is not on the curve")
self.x = x % self.mod
self.y = y % self.mod
self.O = False
self.curve = curve
def __str__(self) -> str:
if self.O:
return 'O'
return f"({self.x}, {self.y})"
这里 self.x self.y
就是点对应的坐标 self.O
对应了它是不是无穷远点 __str__
方法就是 python 内置的一个魔术方法,用于在 print 这个实例时调用这个函数输出对应的字符串。
点的运算
在大二下期的应用密码学课上,我们系统学习了椭圆曲线上点的运算,知道了如何进行点加运算与倍点运算:
对于两个点
于是我们得到
在几何意义上
对于无穷远点
def calcLam(self, other) -> int:
x1, x2, y1, y2 = self.x, other.x, self.y, other.y
if x2 == x1 and y2 == y1:
# P == Q
if y1 == 0:
return math.inf
return self.frac(3*x1*x1 + self.A, 2*y1) # s = lambda
else:
# P != Q
if x2 == x1:
return math.inf
return self.frac(y2 - y1, x2 - x1)
def add(self, other):
if self.O: # is the point at infinity
# O + P = P
return EpPoint(other.x, other.y, self.curve)
if other.O:
# P + O = P
return EpPoint(self.x, self.y, self.curve)
s = self.calcLam(other)
if s == math.inf:
res = copy.deepcopy(self)
res.O = True
return res
x1, x2, y1, y2 = self.x, other.x, self.y, other.y
x3 = (s*s - x1 - x2) % self.mod
y3 = (s * (x1 - x3) - y1) % self.mod
return EpPoint(x3, y3, self.curve)
对于点的乘法,当然可以使用类似于快速幂的算法解决。但只不过之后的内容跟点乘运算关系不大,这里就直接使用普通的累加解决问题,用于调试的正确性检查:
def mul(self, x): # check only, no performance, calc x*P
res = self
for i in range(x - 1):
res = res.add(self)
return res
阶的计算
原理也很简单,对一个点进行累加,如果变成无穷远点就代表这个点的阶(也就是累加次数)找到了。
def calcOrder(self) -> int:
back = self; cnt = 1
while True:
cnt += 1
back = back.add(self)
if back.O == 1:
return cnt
应用密码学教材的内容就到此为止了。
miller 算法
我们首先简要介绍一下有理函数和除子
对于任意单变量有理函数:
根据代数基本定理,我们总可以在复平面上将其因式分解为:
这里
为了简便,我们可以使用除子将式子简化:
这里的
当然,这里的
e.g. 我们令
, , . 那么
肯定在 上。因此: 这里
很好理解,因为此时 。虽然 并没有分母,但极点 的重数是可以通过射影变换与无穷小分析证明是 的,这里篇幅原因不做展开。
经过上述例子我们可以猜想
,当且仅当存在常数 , . - 展开式系数(度数)之和为
. - 如果对这几个点做点加运算,结果为
. - 特别地,如果不存在零点和极点,那么
为常数。
前者被称为除子的唯一性定理,后三者被称为主除子定理。结论与有理函数
基础知识铺垫完成,现在引入 miller 算法中位于核心地位的 line function
显然,
回到代码实现,由于这里
def lineFunc(self, other, S) -> int: # P, Q, S. S is the f_p(S)
xp, xq, yp, yq = self.x, other.x, self.y, other.y
if self.calcLam(other) == math.inf:
return S.x - xp
lam = self.calcLam(other)
return self.frac(S.y - yp - lam*(S.x - xp), S.x + xp + xq - lam*lam)
然后来到了 miller 算法本体:
设正整数
T = P, f = 1
for i from (n-2) to 0:
f = f^2 * g_{T,T}
T = 2T
if b[i] == 1:
f = f * g_{T,P}
T = T + P
return f
使用归纳法证明如下:
首先
时,算法返回 且 , 是常数,既没有零点也没有极点。将 代入 ,所有的项都消掉了,因此 成立。 设
。有 , 成立。我们需要证明 时的情况成立: 假设
,则不走 if 分支,我们实际的运算为 。则此时新的 : 结果正确,并且此时
,也符合归纳要求,因此归纳在 分支下成立。同理也可证得 分支下成立。 综上所述,该算法可以生成有理函数
, 使得 。Q.E.D
特别的,如果点
回到代码实现,由于 miller 算法是基于 line function
def miller(self, S) -> int: # P, S, calc f_P(S)
if self.O:
return 1
T = self; f = 1
n = bin(self.calcOrder())[3:] # 0b1 01001..101
for bit in n:
f = f * f * T.lineFunc(T, S) % self.mod
T = T.add(T)
if bit == '1':
f = f * T.lineFunc(self, S) % self.mod
T = T.add(self)
return f
weil pairing
我们设有理函数
weil pairing 的值除了与有理函数
- 单位根:
- 双线性:
, - 非退化性:
- 交错性:
, .
由于我们已经使用 miller 算法计算出了
def weil(self, Q, S) -> int: # P, Q, S
negS = EpPoint(S.x, -S.y, self.curve) # -S
res1 = self.frac(self.miller(Q.add(S)), self.miller(S))
res2 = self.frac(Q.miller(self.add(negS)), Q.miller(negS))
return self.frac(res1, res2)
例如取:
P = EpPoint(36, 60, curve) # Order 5
Q = EpPoint(121, 387, curve) # Order 5
S = EpPoint(0, 36, curve) # Order 130
这里的
print('P, Q\'s weil pairing:', P.weil(Q, S))
# P, Q's weil pairing: 242
print(242**5 % 631) # property 1 holds
# 1
print('Q, P\'s weil pairing:', Q.weil(P, S))
# Q, P's weil pairing: 279
print(242*279 % 631) # property 4 holds
# 1
然后我们再举另一个例子:
P3 = P.mul(3)
Q4 = Q.mul(4)
print('P3, Q4\'s weil pairing:', P3.weil(Q4, S))
# P3, Q4's weil pairing: 512
print(242**12 % 631) # property 2 holds
# 512
举例说明性质 3 也很简单:
O = P.mul(5)
print('P, O\'s weil pairing:', P.weil(O, S)) # property 3 holds
# P, O's weil pairing: 1
tate pairing
tate pairing 与 weil pairing 的区别是其椭圆曲线定义在有限域
其中
满足 weil pairing 中的单位根与双线性这两个性质。
可以发现,tate pairing 的计算量比 weil pairing 少一半,所以在密码学中更受青睐。例如 tate pairing 的变种 ate pairing 在以太坊中被广泛使用。
def tate(self, Q, S) -> int:
q = self.mod; l = self.calcOrder()
if q % l != 1:
raise ValueError('q and l don\'t suuport q == 1 mod l')
tau = self.frac(self.miller(Q.add(S)), self.miller(S))
return pow(tau, self.frac(q-1, l), q)
然后我们验证一下性质:
print('P, Q\'s tate pairing:', P.tate(Q, S))
# P, Q's tate pairing: 279
print('Q, P\'s tate pairing:', Q.tate(P, S))
# Q, P's tate pairing: 228
显然
然后看一下双线性:
P3 = P.mul(3)
Q2 = Q.mul(2)
print('P, Q\'s tate pairing:', P.tate(Q, S))
# P, Q's tate pairing: 279
print('3P, 2Q\'s tate pairing:', P3.tate(Q2, S))
# 3P, 2Q's tate pairing: 279
因为
完整代码
# y^2 = x^3 + 30x + 34
import copy, math
class EpCurve:
def __init__(self, a, b, p) -> None:
self.A = a
self.B = b
self.mod = p
def inverse(self, x:int) -> int:
return pow(x, -1, self.mod)
def frac(self, x:int, y:int) -> int:
return x * self.inverse(y) % self.mod
def isOnCurve(self, x, y) -> bool:
return y*y % self.mod == (x*x*x + self.A*x + self.B) % self.mod
class EpPoint(EpCurve):
def __init__(self, x, y, curve) -> None:
super().__init__(curve.A, curve.B, curve.mod)
if not self.isOnCurve(x, y):
raise ValueError(f"Point ({x}, {y}) is not on the curve")
self.x = x % self.mod
self.y = y % self.mod
self.O = False
self.curve = curve
def __str__(self) -> str:
if self.O:
return 'O'
return f"({self.x}, {self.y})"
def calcLam(self, other) -> int:
x1, x2, y1, y2 = self.x, other.x, self.y, other.y
if x2 == x1 and y2 == y1:
if y1 == 0:
return math.inf
return self.frac(3*x1*x1 + self.A, 2*y1) # s = lambda
else:
if x2 == x1:
return math.inf
return self.frac(y2 - y1, x2 - x1)
def add(self, other):
if self.O:
return EpPoint(other.x, other.y, self.curve)
if other.O:
return EpPoint(self.x, self.y, self.curve)
s = self.calcLam(other)
if s == math.inf:
res = copy.deepcopy(self)
res.O = True
return res
x1, x2, y1, y2 = self.x, other.x, self.y, other.y
x3 = (s*s - x1 - x2) % self.mod
y3 = (s * (x1 - x3) - y1) % self.mod
return EpPoint(x3, y3, self.curve)
def mul(self, x): # check only, no performance
res = self
for i in range(x - 1):
res = res.add(self)
return res
def calcOrder(self) -> int:
back = self; cnt = 1
while True:
cnt += 1
back = back.add(self)
if back.O == 1:
return cnt
def lineFunc(self, other, S) -> int: # S is the f_p(S)
xp, xq, yp, yq = self.x, other.x, self.y, other.y
if self.calcLam(other) == math.inf:
return S.x - xp
lam = self.calcLam(other)
return self.frac(S.y - yp - lam*(S.x - xp), S.x + xp + xq - lam*lam)
def miller(self, S) -> int: # P, S
if self.O:
return 1
T = self; f = 1
n = bin(self.calcOrder())[3:]
for bit in n:
f = f * f * T.lineFunc(T, S) % self.mod
T = T.add(T)
if bit == '1':
f = f * T.lineFunc(self, S) % self.mod
T = T.add(self)
return f
def weil(self, Q, S) -> int: # P, Q, S
negS = EpPoint(S.x, -S.y, self.curve) # -S
res1 = self.frac(self.miller(Q.add(S)), self.miller(S))
res2 = self.frac(Q.miller(self.add(negS)), Q.miller(negS))
return self.frac(res1, res2)
def tate(self, Q, S) -> int:
q = self.mod; l = self.calcOrder()
if q % l != 1:
raise ValueError('q and l don\'t suuport q == 1 mod l')
tau = self.frac(self.miller(Q.add(S)), self.miller(S))
return pow(tau, self.frac(q-1, l), q)
curve = EpCurve(30, 34, 631)
P = EpPoint(36, 60, curve) # Order 5
Q = EpPoint(121, 387, curve)
S = EpPoint(0, 36, curve)
# print('P + Q =', P.add(Q))
# print('P + Q =', P.add(P))
# print('Q + S =', Q.add(S))
# print('P\'s order:', P.calcOrder())
# print('P\'s miller function related to S:', P.miller(S))
# print('P + Q\'s miller function related to S:', P.miller(Q.add(S)))
# print('P, Q\'s weil pairing:', P.weil(Q, S))
P3 = P.mul(3)
Q2 = Q.mul(2)
O = P.mul(5)
print('P, Q\'s tate pairing:', P.tate(Q, S))
print('3P, 2Q\'s tate pairing:', P3.tate(Q2, S))
参考资料
- https://github.com/WTFAcademy/WTF-zk/tree/main
- https://crypto.stanford.edu/pbc/notes/ep/
- An Introduction to Mathematical Cryptography