import base64
import string
import collections
import csv
import random
import operator
from itertools import izip_longest
from math import sqrt

def translator (frm='', to='', delete='', keep=None):
    # Function for stripping out characters we don't care about
    # Python Cookbook Recipe 1.9
    # Chris Perkins, Raymond Hettinger
    #
    # Builds and returns a one-argument function that applies a fixed
    # translation to the string it is given:
    #   frm / to -- characters in frm are mapped to the character at the same
    #               position in to; a single-character to is repeated so that
    #               it covers every character of frm
    #   delete   -- characters removed from the input
    #   keep     -- when not None, every character outside keep is removed
    #               as well (see the comment at the assignment below)
    # Relies on the Python 2 string.maketrans and the two-argument form of
    # str.translate (table, deletechars).
    if len (to) == 1: to = to * len (frm)
    trans = string.maketrans (frm, to)
    if keep is not None:
        # maketrans('', '') is the identity table, i.e. a string holding
        # every character the table covers exactly once
        allchars = string.maketrans ('', '')
        # delete is expanded to delete everything except
        # what is mentioned in set(keep)-set(delete)
        # (inner call: keep with the delete characters stripped out;
        #  outer call: all characters minus that remainder)
        delete = allchars.translate (allchars, keep.translate (allchars, delete))
    def translate (s):
        # trans and delete are captured from the enclosing call
        return s.translate (trans, delete)
    return translate

def xor_string(a_string, b_string):
    """XOR two equal-length strings character by character.

    Each output character is chr(ord(a) ^ ord(b)) for the characters found
    at the same position in the two inputs.

    Returns the XOR-ed string. When the lengths differ, a notice is printed
    to standard output and None is returned instead.
    """
    if len(a_string) != len(b_string):
        # Parenthesised form: prints the same text on Python 2 and also
        # parses on Python 3, unlike the bare print statement.
        print("Unequal length strings!")
        return None
    return ''.join(chr(ord(a) ^ ord(b)) for a, b in zip(a_string, b_string))

def xor_char(s, c):
    """Return s with every character XOR-ed against the single character c.

    Mostly obsolete: xor_str gives the same result when handed a
    one-character key.
    """
    return ''.join(chr(ord(letter) ^ ord(c)) for letter in s)

def xor_str(s, k):
    """XOR the string s against the key k, repeating k as often as needed.

    Character i of s is combined with character (i mod len(k)) of k.
    Applying the function twice with the same key returns the input.
    """
    pieces = []
    key_pos = 0
    for letter in s:
        pieces.append(chr(ord(letter) ^ ord(k[key_pos])))
        # step through the key, wrapping around at its end
        key_pos = (key_pos + 1) % len(k)
    return ''.join(pieces)

def dict_euclidian_dist(a, b):
    """Euclidean distance between two dict-encoded vectors.

    Only the keys of a are visited, so b must contain every key of a
    (extra keys in b are ignored). Values are passed through float(), so
    numeric strings are accepted.
    """
    squared_gap = sum((float(a[key]) - float(b[key])) ** 2 for key in a)
    return sqrt(squared_gap)

def dict_cosine_sim(a, b):
    """Cosine similarity between two dict-encoded vectors.

    Only the keys of a are visited, so b must contain every key of a.
    Values are passed through float(). Returns 0 when either vector has
    zero magnitude over those keys.
    """
    pairs = [(float(a[key]), float(b[key])) for key in a]
    a_norm_sq = sum(left ** 2 for left, _ in pairs)
    b_norm_sq = sum(right ** 2 for _, right in pairs)
    # a zero-length vector has no direction; avoid dividing by zero
    if a_norm_sq == 0.0 or b_norm_sq == 0.0:
        return 0
    dot = sum(left * right for left, right in pairs)
    return dot / (sqrt(a_norm_sq) * sqrt(b_norm_sq))

def freq_array_load():
    """Load the reference English n-gram frequency tables from CSV files.

    Reads letter_freq.csv, bigram_freq.csv, trigram_freq.csv and
    quadgram_freq.csv from the current working directory. Each file needs
    a 'Frequency' column plus a key column named 'Letter', 'Bigram',
    'Trigram' or 'Quadgram' respectively.

    Returns a list of (n, table) tuples for n = 1..4, where table maps an
    n-gram to its 'Frequency' cell exactly as read from the file (a string,
    not converted to a number). Bigram, trigram and quadgram keys are
    lower-cased; letter keys are stored as they appear in the file.

    Sources for the frequency data:
    http://www.cryptograms.org/letter-frequencies.php
    http://www3.nd.edu/~busiforc/handouts/cryptography/Letter%20Frequencies.html
    http://www.cse.chalmers.se/edu/year/2010/course/TDA351/ass1/en_stat.html
    """
    def load_table(path, column, lowercase):
        # Read one CSV into {n-gram: frequency}; the with-block guarantees
        # the file handle is closed once the rows have been consumed.
        table = {}
        with open(path, 'rb') as handle:
            for row in csv.DictReader(handle, delimiter=','):
                ngram = row[column]
                if lowercase:
                    ngram = ngram.lower()
                table[ngram] = row['Frequency']
        return table

    eng_freq = load_table('letter_freq.csv', 'Letter', False)
    bigram_freq = load_table('bigram_freq.csv', 'Bigram', True)
    trigram_freq = load_table('trigram_freq.csv', 'Trigram', True)
    quadgram_freq = load_table('quadgram_freq.csv', 'Quadgram', True)
    return [(1, eng_freq), (2, bigram_freq), (3, trigram_freq), (4, quadgram_freq)]


def ngram_freq_cos_sim(s, freq_array):
    """Find the single-character XOR key most likely to decrypt s.

    Every character of string.printable is tried as the key. The candidate
    plaintext is reduced to lower-cased ASCII letters and scored by summing,
    over each (n, table) entry of freq_array, the cosine similarity between
    its n-gram frequencies and the reference table.

    Returns (key, value): the best-scoring key character and its summed
    cosine similarity.
    """
    # drop everything that is not an ASCII letter before counting n-grams
    letters_only = translator(keep=string.ascii_letters)
    scores = {}
    for candidate in string.printable:
        plain = letters_only(xor_str(s, candidate)).lower()
        scores[candidate] = sum(
            dict_cosine_sim(ngram_freq(n, plain, table), table)
            for n, table in freq_array
        )

    best = max(scores, key=scores.get)
    return (best, scores[best])
           
            
def old_ngram_freq_cos_sim(s):
    """Legacy variant of ngram_freq_cos_sim that loads its own tables.

    Loads the letter, bigram, trigram and quadgram reference tables through
    freq_array_load() instead of repeating the CSV-reading code, then tries
    every character of string.printable as a single-character XOR key
    (via xor_char). Each candidate plaintext is reduced to lower-cased
    ASCII letters and scored by the sum of the cosine similarities between
    its 1- to 4-gram frequencies and the reference tables.

    Returns (key, value): the best-scoring key character and its summed
    cosine similarity.
    """
    freq_tables = freq_array_load()

    # remove non-ascii letters
    text_filter = translator(keep=string.ascii_letters)

    # generate 'decrypted' text using each printable character as the key
    sum_cos_sim = {}
    for char in string.printable:
        char_dec = text_filter(xor_char(s, char)).lower()
        score = 0
        # single character, bigram, trigram and quadgram tests, in that order
        for n, table in freq_tables:
            score += dict_cosine_sim(ngram_freq(n, char_dec, table), table)
        sum_cos_sim[char] = score

    key = max(sum_cos_sim, key=sum_cos_sim.get)
    value = sum_cos_sim[key]
    return (key, value)

def ngram_freq(n, s, freq_dict):
    """Relative frequency in s of every n-gram listed in freq_dict.

    n         -- n-gram length
    s         -- text to scan; every window s[i:i+n] is counted
    freq_dict -- only its keys are used: they select which n-grams are
                 reported

    Returns a dict mapping each key of freq_dict to (occurrences in s) /
    (number of windows in s). When s is too short to hold a single window,
    every frequency is 0.0.
    """
    windows = len(s) - (n - 1)
    # "<= 0" rather than "== 0": a string shorter than n - 1 would otherwise
    # be divided by a negative window count and yield -0.0.
    if windows <= 0:
        return dict((key, 0.0) for key in freq_dict)
    count = collections.Counter(s[i:i + n] for i in range(windows))
    total = float(windows)
    # Counter reports 0 for n-grams that never occurred
    return dict((key, count[key] / total) for key in freq_dict)

def test_xor(s, k):
    """Round-trip check for xor_str using plaintext s and key k.

    Asserts that decrypting the ciphertext restores s and that encrypting
    the recovered text reproduces the same ciphertext. Returns True when
    both assertions hold.
    """
    ciphertext = xor_str(s, k)
    recovered = xor_str(ciphertext, k)
    ciphertext_again = xor_str(recovered, k)
    assert ciphertext == ciphertext_again
    assert recovered == s
    return True

def str_tips(s, n):
    """Abbreviate s to its first and last n characters joined by " ... ".

    s is returned unchanged when n exceeds half of its length, since the
    two ends would then overlap.
    """
    half = len(s) / 2
    if n > half:
        return s
    head, tail = s[:n], s[-n:]
    return head + " ... " + tail

def hamming_dist(s1, s2):
    """Bit-level Hamming distance between two equal-length strings.

    http://en.wikipedia.org/wiki/Hamming_distance
    Unlike the character-level version shown there, this one counts the
    differing bits between the characters at each position.

    Raises AssertionError when the strings differ in length.
    """
    assert len(s1) == len(s2), "hamming_distance requires equal length strings"
    # The set bits of a XOR b are exactly the bit positions where a and b
    # differ. This also stays correct for code points above 255, where
    # comparing zfill(8)-padded bit strings would misalign the two sides.
    return sum(bin(ord(ch1) ^ ord(ch2)).count('1') for ch1, ch2 in zip(s1, s2))
 
def key_len_finder(s, max_keysize=40, blocks=4):
    """Rank candidate repeating-key lengths for the ciphertext s.

    For each keysize from 1 to max_keysize, `blocks` pairs of adjacent
    keysize-long chunks taken from the start of s are compared with
    hamming_dist. Each distance is normalised by keysize * 8 (the
    distance is counted in bits, 8 per character) and the results are
    averaged.

    s           -- ciphertext to analyse
    max_keysize -- largest key length to try (default 40)
    blocks      -- number of adjacent chunk pairs averaged per key length
                   (default 4); must be at least 1

    Returns a list of (keysize, normalised_distance) tuples sorted by
    ascending distance. Key lengths for which s is too short to supply
    every chunk are left out rather than compared on truncated or empty
    chunks.

    Raises ValueError when blocks is smaller than 1.
    """
    if blocks < 1:
        raise ValueError("blocks must be at least 1")
    scores = []
    for keysize in range(1, max_keysize + 1):
        # The last pair ends at (blocks + 1) * keysize; beyond the end of s
        # the slices would be unequal (hamming_dist asserts) or empty
        # (a meaningless distance of 0), and larger keysizes only need more.
        if (blocks + 1) * keysize > len(s):
            break
        sum_nhd = 0
        for i in range(blocks):
            start = i * keysize
            # normalized by keysize*8 since hamming dist is binary (char = 8 bits)
            sum_nhd += hamming_dist(
                s[start:start + keysize],
                s[start + keysize:start + keysize * 2]
            ) / float(keysize * 8)
        scores.append((keysize, sum_nhd / float(blocks)))
    return sorted(scores, key=operator.itemgetter(1))

def grouper(iterable, n, fillvalue=None):
    "Split iterable into consecutive n-tuples, padding the last with fillvalue"
    # itertools recipe: http://docs.python.org/2/library/itertools.html#recipes
    # e.g. grouper('ABCDEFG', 3, 'x') yields ABC, DEF, Gxx
    shared = iter(iterable)
    # n references to ONE iterator: each output tuple pulls n successive items
    slots = [shared] * n
    return izip_longest(*slots, fillvalue=fillvalue)

def group_and_trans(s, key_len):
    """Split s into key_len-sized chunks and transpose them.

    Returns a list of key_len strings; entry i is the concatenation of the
    i-th character of every chunk. The final, shorter chunk is padded with
    empty strings by grouper, so it contributes nothing to the columns it
    does not reach.
    """
    columns = [[] for _ in range(key_len)]
    for chunk in grouper(s, key_len, ''):
        for position, piece in enumerate(chunk):
            columns[position].append(piece)
    return [''.join(column) for column in columns]
