现充|junyu33

Generalized MT19937 PRG reverse analysis

Problem Description

Given the ten MT19937 parameters N, M, A, U, S, B, T, C, L, F and the first N pseudorandom numbers the generator outputs right after seeding, recover the seed that produced them.

Approach Analysis

20pts

Brute force enumeration of the seed is sufficient. Code omitted.
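Although the code is omitted, the idea can be sketched as follows. This is only an illustration: the helper names mt_outputs and brute_force_seed and the search bound limit are assumptions, since the actual seed range of this subtask is not stated here, and the standard MT19937 constants are used purely as sample parameters.

```python
def mt_outputs(seed, params, count):
    """Seed a generalized MT, twist once, and return the first `count` outputs."""
    N, M, A, U, S, B, T, C, L, F = params
    mask32 = 0xFFFFFFFF
    mt = [seed & mask32]
    for i in range(1, N):
        prev = mt[i - 1]
        mt.append((F * (prev ^ (prev >> 30)) + i) & mask32)
    # One in-place twist, same update order as the generator
    for i in range(N):
        y = (mt[i] & 0x80000000) | (mt[(i + 1) % N] & 0x7FFFFFFF)
        mt[i] = (y >> 1) ^ mt[(i + M) % N]
        if y & 1:
            mt[i] ^= A
    out = []
    for x in mt[:count]:
        # Tempering
        x ^= x >> U
        x ^= (x << S) & B
        x ^= (x << T) & C
        x ^= x >> L
        out.append(x & mask32)
    return out

def brute_force_seed(params, observed, limit):
    """Try every seed in [0, limit); return the first one that reproduces `observed`."""
    for seed in range(limit):
        if mt_outputs(seed, params, len(observed)) == observed:
            return seed
    return None

# Standard MT19937 parameters, used here only as sample input
STD = (624, 397, 0x9908B0DF, 11, 7, 0x9D2C5680, 15, 0xEFC60000, 18, 1812433253)
target = mt_outputs(321, STD, 5)
print(brute_force_seed(STD, target, 1000))
```

Each candidate costs O(N) work, so this only scales to small seed ranges, which is why the full solution inverts the generator instead.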

100pts

The general approach is: generated random numbers → state after twist → possible state[-1] before twist → obtain the seed.
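The first arrow (generated random numbers → state after twist) amounts to inverting the tempering step. A minimal sketch, assuming 32-bit words, shifts of at least 1, and the standard tempering constants as sample values; the function names are chosen here for illustration:

```python
U, S, B, T, C, L = 11, 7, 0x9D2C5680, 15, 0xEFC60000, 18  # sample constants

def temper(x):
    x ^= x >> U
    x ^= (x << S) & B
    x ^= (x << T) & C
    x ^= x >> L
    return x & 0xFFFFFFFF

def undo_right(y, shift, bits=32):
    """Invert x -> x ^ (x >> shift)."""
    x = y
    # Each pass makes `shift` more bits correct from the top; extra passes are harmless
    for _ in range(bits // shift):
        x = y ^ (x >> shift)
    return x

def undo_left(y, shift, mask, bits=32):
    """Invert x -> x ^ ((x << shift) & mask)."""
    x = y
    # Each pass makes `shift` more bits correct from the bottom
    for _ in range(bits // shift):
        x = y ^ ((x << shift) & mask)
    return x & ((1 << bits) - 1)

def untemper(y):
    # Undo the four tempering steps in reverse order
    y = undo_right(y, L)
    y = undo_left(y, T, C)
    y = undo_left(y, S, B)
    y = undo_right(y, U)
    return y

x = 0xDEADBEEF
print(untemper(temper(x)) == x)
```

Applying untemper to each of the N outputs gives the internal state right after the twist.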

We first consider the standard case, where the parameters are exactly the same as in the paper. There is an existing solution available.

The specific principles will not be elaborated here, as the provided link and Mivik's solution explain it clearly.

In the generalized problem, however, the highest bit of A is not necessarily 1. When it is 0, the highest bit of tmp no longer tells us whether the value that was shifted right during the twist was odd, i.e. whether A was XORed in.

The code for reversing the twist in the original solution is as follows:

def backtrace(cur):
    high = 0x80000000
    low = 0x7fffffff
    mask = 0x9908b0df
    state = cur
    for i in range(623,-1,-1):
        tmp = state[i]^state[(i+397)%624]
        # recover Y,tmp = Y
        if tmp & high == high:
            tmp ^= mask
            tmp <<= 1
            tmp |= 1
        else:
            tmp <<=1
        # recover highest bit
        res = tmp&high
        # recover the other 31 bits; when i = 0, state[i-1] wraps around to state[-1], so the same method applies
        tmp = state[i-1]^state[(i+396)%624]
        # recover Y,tmp = Y
        if tmp & high == high:
            tmp ^= mask
            tmp <<= 1
            tmp |= 1
        else:
            tmp <<=1
        res |= (tmp)&low
        state[i] = res    
    return state

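The problem is the condition if tmp & high == high:. The twist computes tmp = (y >> 1) ^ A when y is odd and tmp = y >> 1 otherwise, and y >> 1 always has its highest bit cleared. The test therefore identifies an odd y only when the highest bit of A is 1; for any other A it is no longer valid, and the parity of y cannot be read off tmp.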
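The ambiguity can be seen in a small experiment. The sketch below is illustrative only: twist_mix and preimages are names made up here, and A_low is a hypothetical constant obtained by clearing the top bit of the standard one.

```python
def twist_mix(y, A):
    """Per-word mixing inside twist(): y >> 1, XORed with A when y is odd."""
    return (y >> 1) ^ (A if y & 1 else 0)

def preimages(tmp, A):
    """All 32-bit y with twist_mix(y, A) == tmp (one even and one odd candidate)."""
    cands = [tmp << 1, ((tmp ^ A) << 1) | 1]
    return [y for y in cands if y < (1 << 32) and twist_mix(y, A) == tmp]

A_std = 0x9908B0DF  # standard constant, highest bit set
A_low = 0x1908B0DF  # hypothetical constant, highest bit cleared
y = 0xCAFEBABE

print(len(preimages(twist_mix(y, A_std), A_std)))  # 1: the parity of y is determined
print(len(preimages(twist_mix(y, A_low), A_low)))  # 2: both parities are consistent
```

With the highest bit of A cleared, every tmp has both an even and an odd preimage, so each parity decision has to be enumerated rather than read off.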

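A straightforward workaround is to try both branches for each of the two conditions, which yields four candidate values for mt[N - 1] before the twist. The initialization step mt[i] = F * (mt[i - 1] ^ mt[i - 1] >> 30) + i is invertible: the map x → x ^ x >> 30 is reversible, and since F is odd it is coprime with 2 ** 32 and therefore has a modular inverse. We can thus walk back from each candidate mt[N - 1] to mt[0], which is the seed.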
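The backward walk over the initialization recurrence can be sketched on its own. This version assumes Python 3.8+ so that the built-in pow can compute the modular inverse (the full code uses gmpy2 instead); init_state and seed_from_last are names chosen here for illustration.

```python
MASK32 = 0xFFFFFFFF

def init_state(seed, N, F):
    """Forward initialization: mt[i] = F * (mt[i-1] ^ mt[i-1] >> 30) + i mod 2**32."""
    mt = [seed & MASK32]
    for i in range(1, N):
        prev = mt[-1]
        mt.append((F * (prev ^ (prev >> 30)) + i) & MASK32)
    return mt

def seed_from_last(last, N, F):
    """Walk the recurrence backwards from mt[N - 1] to mt[0]."""
    inv = pow(F, -1, 1 << 32)            # exists because F is odd
    for i in range(N - 1, 0, -1):
        z = ((last - i) * inv) & MASK32  # z = prev ^ (prev >> 30)
        last = z ^ (z >> 30)             # one pass suffices since 2 * 30 >= 32
    return last

N, F = 624, 1812433253                   # sample values
seed = 20240229
print(seed_from_last(init_state(seed, N, F)[-1], N, F) == seed)
```

Because every step is a bijection on 32-bit words, each candidate for mt[N - 1] maps to exactly one candidate seed.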

To verify these four seeds, we can use each seed to regenerate a few random numbers and compare them with the input.

Complete Code

from gmpy2 import invert
def _int32(x):
    return int(0xFFFFFFFF & x)
class mt19937:
    def __init__(self, seed=0):# magic method (run code below automatically when an object is created) 
        self.mt = [0] * N
        self.mt[0] = seed
        self.mti = 0
        for i in range(1, N):
            self.mt[i] = _int32(F * (self.mt[i - 1] ^ self.mt[i - 1] >> 30) + i)
    def getstate(self,op=False):
        if self.mti == 0 and op==False:
            self.twist()
        y = self.mt[self.mti]
        y = y ^ y >> U
        y = y ^ y << S & B
        y = y ^ y << T & C
        y = y ^ y >> L
        self.mti = (self.mti + 1) % N
        return _int32(y)
    def twist(self):
        for i in range(0, N):
            y = _int32((self.mt[i] & 0x80000000) + (self.mt[(i + 1) % N] & 0x7fffffff))
            self.mt[i] = (y >> 1) ^ self.mt[(i + M) % N]
            if y % 2 != 0:
                self.mt[i] = self.mt[i] ^ A
    def inverse_right(self,res, shift, mask=0xffffffff, bits=32):
        tmp = res
        for i in range(bits // shift):
            tmp = res ^ tmp >> shift & mask
        return tmp
    def inverse_left(self,res, shift, mask=0xffffffff, bits=32):
        tmp = res
        for i in range(bits // shift):
            tmp = res ^ tmp << shift & mask
        return tmp
    def extract_number(self,y): # namely "temper" in Mivik's code
        y = y ^ y >> U
        y = y ^ y << S & B
        y = y ^ y << T & C
        y = y ^ y >> L
        return y&0xffffffff
    def recover(self,y): # inverse of extract_number
        y = self.inverse_right(y,L)
        y = self.inverse_left(y,T,C)
        y = self.inverse_left(y,S,B)
        y = self.inverse_right(y,U)
        return y&0xffffffff
    def setstate(self,s): # N generated random numbers -> mt[] after twisting 
        if(len(s)!=N):
            raise ValueError("The length of prediction must be N!")
        for i in range(N):
            self.mt[i]=self.recover(s[i])
        #self.mt=s
        self.mti=0
    ''' 
    def predict(self,s): # a method to predict other pseudo random numbers after given N of them (useless in this problem)
        self.setstate(s)
        self.twist()
        return self.getstate(True)
    '''
    def invtwist(self): # mt[] after twisting -> 4 possible values of mt[-1] before twisting
        high = 0x80000000
        low = 0x7fffffff
        mask = A
        opt = [0] * 4
        for i in range(N-1,N-2,-1): # only process the last number
            for s in range(2):
                for t in range(2):
                    tmp = self.mt[i]^self.mt[(i+M)%N]
                    if s==0: # two possibilities
                        tmp ^= mask
                        tmp <<= 1
                        tmp |= 1
                    else:
                        tmp <<=1
                    res = tmp&high
                    tmp = self.mt[i-1]^self.mt[(i+M-1)%N]
                    if t==0: # another two
                        tmp ^= mask
                        tmp <<= 1
                        tmp |= 1
                    else:
                        tmp <<=1
                    res |= (tmp)&low
                    opt[s * 2 + t] = res
        return opt
    
    def recover_seed(self,last): # mt[-1] -> mt[0]
        n = 1 << 32
        inv = invert(F, n) # inverse of F mod 2 ^ 32
        for i in range(N-1, 0, -1):
            last = ((last - i) * inv) % n
            last = self.inverse_right(last, 30)
        return last

N, M, A, U, S, B, T, C, L, F = map(int, input().split())
inpt = [0] * N # preallocate space for the N outputs
for i in range (N):
    inpt[i] = int(input())
D = mt19937() 
D.setstate(inpt) # using the input to recover state after twisting
op = D.invtwist() # generate four possibilities of D.mt[-1]
seed = [0] * 4
for i in range(4): # check the seeds one by one
    seed[i] = D.recover_seed(op[i]) 
    E = mt19937(seed[i])
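    E.getstate() # the first call triggers twist(), so E.mt now holds the post-twist state of this second generator
    flag = 1
    for j in range(min(10, N)): # comparing the first 10 numbers is enough; min() avoids an IndexError when N < 10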
        if E.extract_number(E.mt[j]) != inpt[j]:
            flag = 0
    if flag > 0:
        print(seed[i])
        break

Time complexity: O(N).