diff --git a/elgamal.py b/elgamal.py index 7d2aee4..b3e8dba 100644 --- a/elgamal.py +++ b/elgamal.py @@ -206,9 +206,8 @@ def find_prime(iNumBits, iConfidence): return p #encodes bytes to integers mod p. reads bytes from file -def encode(sPlaintext, iNumBits): +def encode(sPlaintext, iNumBits, p): byte_array = bytearray(sPlaintext, 'utf-16') - #z is the array of integers mod p z = [] @@ -232,6 +231,9 @@ def encode(sPlaintext, iNumBits): #add the byte multiplied by 2 raised to a multiple of 8 z[j//k] += byte_array[i]*(2**(8*(i%k))) + for i in range( len(z) ): + z[i] = modexp(z[i], 2, p) + #example #if n = 24, k = n / 8 = 3 #z[0] = (summation from i = 0 to i = k)m[i]*(2^(8*i)) @@ -241,9 +243,11 @@ def encode(sPlaintext, iNumBits): return z #decodes integers to the original message bytes -def decode(aiPlaintext, iNumBits): +def decode(aiPlaintext, iNumBits, p): #bytes array will hold the decoded original message bytes bytes_array = [] + q = (p - 1) // 2 + sqrtfactor = (q + 1) // 2 #same deal as in the encode function. #each encoded integer is a linear combination of k message bytes @@ -253,10 +257,14 @@ def decode(aiPlaintext, iNumBits): #num is an integer in list aiPlaintext for num in aiPlaintext: + #revnum = num + revnum = modexp(num, sqrtfactor, p) + if revnum > q: + revnum = p - revnum #get the k message bytes from the integer, i counts from 0 to k-1 for i in range(k): #temporary integer - temp = num + temp = revnum #j goes from i+1 to k-1 for j in range(i+1, k): #get remainder from dividing integer by 2^(8*j) @@ -267,7 +275,7 @@ def decode(aiPlaintext, iNumBits): bytes_array.append(letter) #subtract the letter multiplied by the power of two from num so #so the next message byte can be found - num = num - (letter*(2**(8*i))) + revnum = revnum - (letter*(2**(8*i))) #example #if "You" were encoded. @@ -306,7 +314,7 @@ def generate_keys(iNumBits=256, iConfidence=32): #encrypts a string sPlaintext using the public key k def encrypt(key, sPlaintext): - z = encode(sPlaintext, key.iNumBits) + z = encode(sPlaintext, key.iNumBits - 2, key.p) #cipher_pairs list will hold pairs (c, d) corresponding to each integer in z cipher_pairs = [] @@ -349,7 +357,7 @@ def decrypt(key, cipher): #add plain to list of plaintext integers plaintext.append( plain ) - decryptedText = decode(plaintext, key.iNumBits) + decryptedText = decode(plaintext, key.iNumBits - 2, key.p) #remove trailing null bytes decryptedText = "".join([ch for ch in decryptedText if ch != '\x00'])