import scipy.stats as ss
import copy
import sys
import mido
from time import sleep
import numpy as np
import random
import math
import pickle
import os

#Key: each row is a triple: note value, note length, type (0=continuation, 1=note onset, 2=rest)
################### create sequence
def createSequence(seqLength=50):
    """Build a random melody of notes with varying lengths.

    Returns a (seqLength, 3) int8 array whose rows are
    (note value, note length, type). At each cursor position a note onset
    (type 1, pitch 60-76, length 1-4) is written with probability 0.6 and
    the cursor jumps past the note; otherwise the cursor advances one row.
    Rows that are never written stay all-zero. No onset is started in the
    last four rows.
    """
    melody = np.zeros((seqLength, 3), dtype=np.int8)
    cursor = 0
    lastStart = seqLength - 4
    while cursor < lastStart:
        if random.random() >= 0.6:
            # no note here: move on by a single step
            cursor += 1
            continue
        pitch = random.randint(60, 76)
        duration = random.randint(1, 4)
        melody[cursor] = (pitch, duration, 1)
        cursor += duration
    return melody


################### create even sequence
def createEvenSequence(seqLength=50):
    """Build a sequence of random single-step notes.

    Returns a (seqLength, 3) int8 array of (note value, note length, type)
    rows. Every row except the last holds a note onset (type 1) of length 1
    with a random pitch in 60-76; the last row is left all-zero.
    """
    evenSeq = np.zeros((seqLength, 3), dtype=np.int8)
    for row in range(seqLength - 1):
        evenSeq[row] = (random.randint(60, 76), 1, 1)
    return evenSeq


################### create repeated sequence
def createRepeatedSequence(seqLength=50):
    """Build a sequence that repeats pitch 66 on every step.

    Returns a (seqLength, 3) int8 array of (note value, note length, type)
    rows. Every row except the last is the onset (66, 1, 1); the last row is
    left all-zero.
    """
    repeated = np.zeros((seqLength, 3), dtype=np.int8)
    filledRows = max(seqLength - 1, 0)
    repeated[:filledRows] = (66, 1, 1)
    return repeated


################### create trill
def createTrill(seqlength=50):
    """Build a trill: two neighbouring pitches alternating on every step.

    Args:
        seqlength: number of rows in the returned sequence. (Previously this
            argument was ignored and the length was always 50.)

    Returns:
        A (seqlength, 3) int8 array of (note value, note length, type) rows.
        Every row except the last is a length-1 onset (type 1); even rows
        carry the first pitch, odd rows the second. The last row is all-zero.
    """
    sequence = np.zeros((seqlength, 3), dtype=np.int8)
    firstNote = random.randint(60, 76)
    # NOTE(review): randrange(1,2) can only return 1, so the trill is always a
    # semitone apart. If whole-tone trills were intended this should be
    # randint(1,2) - left unchanged pending confirmation.
    secondNote = firstNote + random.randrange(1, 2)
    for ii in range(seqlength - 1):
        sequence[ii] = (firstNote if ii % 2 == 0 else secondNote, 1, 1)
    return sequence


################### insert silent blocks
def insertSilentBlocks(seq,noOfBlocks):
    """Return a copy of *seq* with blocks of rest rows spliced in.

    Args:
        seq: 2-D numpy array of (note value, note length, type) rows.
        noOfBlocks: how many silent blocks to insert. Each block is placed
            immediately before a distinct, randomly chosen row of *seq* and
            consists of 25-30 rest rows (0, 0, 2).

    Returns:
        A new array; *seq* itself is not modified.

    Raises:
        ValueError: if noOfBlocks exceeds the number of rows (from
            random.sample).
    """
    # Positions are processed in ascending order so that each block ends up in
    # front of the row it was drawn for; handling them in random order while
    # accumulating an offset could drop one block inside another.
    gapPositions = sorted(random.sample(range(seq.shape[0]), noOfBlocks))
    restRow = np.array([0, 0, 2], dtype=seq.dtype)
    pieces = []
    prevPos = 0
    for pos in gapPositions:
        gapSize = random.randint(25, 30)
        pieces.append(seq[prevPos:pos])
        pieces.append(np.tile(restRow, (gapSize, 1)))
        prevPos = pos
    pieces.append(seq[prevPos:])
    # a single concatenate instead of one full-array copy per inserted row
    return np.concatenate(pieces, axis=0)


################### playback
def playback(sequence,speed,outport):
    """Play *sequence* on a MIDI output port, one row per time step.

    Args:
        sequence: rows of (note value, note length, type); any row with a
            non-zero note value triggers a note_on (velocity 80) at its step
            and a note_off `note length` steps later.
        speed: seconds to sleep between steps - 0.1 is fast, 0.3 is slow.
        outport: object with a send(msg) method (a mido output port).

    A 'stop' message is sent once the whole sequence has been played.
    """
    seqLength = len(sequence)
    # One extra slot collects note-offs that fall due after the final step, so
    # a note reaching the end of the sequence no longer raises IndexError.
    messageList = [[] for _ in range(seqLength + 1)]
    for ii in range(seqLength):
        # plain ints: avoids int8 overflow when the step index exceeds 127
        note = int(sequence[ii][0])
        if note == 0:
            continue
        offStep = min(ii + int(sequence[ii][1]), seqLength)
        messageList[ii].append(mido.Message('note_on',note=note,velocity=80))
        messageList[offStep].append(mido.Message('note_off',note=note))
    for t in range(seqLength):
        for msg in messageList[t]:
            outport.send(msg)
        sleep(speed)
    # release anything still sounding before stopping
    for msg in messageList[seqLength]:
        outport.send(msg)
    outport.send(mido.Message('stop'))

################### analyse slapdash serialism                              
def analyseSlapdashSerialism(seq):
    """Return the standard deviation of the interval histogram of *seq*.

    The histogram has 13 bins: absolute pitch differences 0-11 between
    consecutive note onsets (rows with type 1), plus one bin for differences
    of 12 or more. The population standard deviation of the raw bin counts
    is returned; a sequence with fewer than two onsets gives 0.0.

    Args:
        seq: iterable of (note value, note length, type) rows.
    """
    # Work on the onset pitches only, so every interval is between two
    # successive notes (the first note is never compared with the last note
    # or with itself).
    pitches = [int(row[0]) for row in seq if row[2] == 1]
    intervalList = [0] * (12 + 1)
    for prevNote, currNote in zip(pitches, pitches[1:]):
        intervalList[min(abs(currNote - prevNote), 12)] += 1
    count = len(intervalList)
    sumUp = float(sum(intervalList))
    sumSq = float(sum(iv * iv for iv in intervalList))
    variance = sumSq / count - math.pow(sumUp / count, 2)
    # guard against a tiny negative value from floating-point rounding
    return math.sqrt(max(variance, 0.0))

################### analyse slapdash serialism 2
def analyseSlapdashSerialism2(seq):
    """Score how far the interval usage of *seq* is from an even spread.

    Absolute pitch differences between consecutive note onsets (type 1 rows)
    are counted in 13 bins (0-11, and 12 for an octave or more), divided by
    the number of onsets, and compared with an ideal profile that puts 0.091
    on each of the intervals 1-11. The sum of squared differences is
    returned, so lower means closer to the ideal.

    Args:
        seq: iterable of (note value, note length, type) rows.

    Returns:
        float; for a sequence without onsets the histogram is treated as
        all-zero instead of dividing by zero.
    """
    idealList = [ 0.0, 0.091, 0.091, 0.091, 0.091, 0.091, 0.091, 0.091, 0.091, 0.091, 0.091, 0.091, 0.0]
    # Only onset pitches take part, so a sequence that starts with a rest or
    # continuation row does not compare its first note with itself.
    pitches = [int(row[0]) for row in seq if row[2] == 1]
    intervalList = [0] * (12 + 1)
    for prevNote, currNote in zip(pitches, pitches[1:]):
        intervalList[min(abs(currNote - prevNote), 12)] += 1
    numberOfNotes = len(pitches)
    if numberOfNotes == 0:
        normalisedIntervalList = [0.0] * len(intervalList)
    else:
        normalisedIntervalList = [x / numberOfNotes for x in intervalList]
    return sum(math.pow(a - b, 2) for a, b in zip(normalisedIntervalList, idealList))


################### analyse melodicity                             
def analyseMelodicity(seq):
    """Score how far the interval usage of *seq* is from a melodic profile.

    Absolute pitch differences between consecutive note onsets (type 1 rows)
    are counted in 13 bins (0-11, and 12 for an octave or more), divided by
    the number of onsets, and compared with an ideal profile dominated by
    the 2-semitone step (0.7). The sum of squared differences is returned,
    so lower means closer to the ideal.

    Args:
        seq: iterable of (note value, note length, type) rows.

    Returns:
        float; for a sequence without onsets the histogram is treated as
        all-zero instead of dividing by zero.
    """
    idealList = [ 0.0, 0.1, 0.7, 0.05, 0.05, 0.02, 0.01, 0.01, 0.03, 0.01, 0.01, 0.01, 0.0]
    # Only onset pitches take part, so a sequence that starts with a rest or
    # continuation row does not compare its first note with itself.
    pitches = [int(row[0]) for row in seq if row[2] == 1]
    intervalList = [0] * (12 + 1)
    for prevNote, currNote in zip(pitches, pitches[1:]):
        intervalList[min(abs(currNote - prevNote), 12)] += 1
    numberOfNotes = len(pitches)
    if numberOfNotes == 0:
        normalisedIntervalList = [0.0] * len(intervalList)
    else:
        normalisedIntervalList = [x / numberOfNotes for x in intervalList]
    return sum(math.pow(a - b, 2) for a, b in zip(normalisedIntervalList, idealList))


################### analyse trillicity
def analyseTrillicity(seq):
    """Score how un-trill-like *seq* is (0.0 = perfect trill, 1.0 = none).

    A window of 8 consecutive onset pitches slides over the sequence. In
    each window the first two pitches define a candidate trill, provided
    they differ by 1 or 2 semitones; each of the remaining 6 pitches that
    matches the expected alternation scores a point. The result is
    1 - points / (6 * number of windows).

    Args:
        seq: iterable of (note value, note length, type) rows.

    Returns:
        float in [0, 1]; 1.0 when there are too few pitches for one window.
    """
    pitchSeq = justThePitches(seq)
    windowLength = 8
    trillicity = 0
    length = 0
    # +1 so the last full window (ending on the final pitch) is scored too
    for ws in range(len(pitchSeq) - windowLength + 1):
        window = pitchSeq[ws:ws+windowLength]
        pitch1 = int(window[0])
        pitch2 = int(window[1])
        length += windowLength - 2
        if abs(pitch1 - pitch2) > 2 or pitch1 == pitch2:
            # first two pitches are not a trill pair: window scores nothing
            continue
        for offset, pp in enumerate(window[2:]):
            expected = pitch1 if offset % 2 == 0 else pitch2
            if pp == expected:
                trillicity += 1
    if length == 0:
        # fewer pitches than one window: nothing trill-like was observed
        return 1.0
    return 1.0 - trillicity / length

################### analyse large gaps
def analyseLargeGaps(seq):
    """Score *seq* by how many large silent gaps it contains.

    A large gap is a run of at least 15 consecutive rest rows (type 2),
    including a run that extends to the end of the sequence.

    Args:
        seq: iterable of (note value, note length, type) rows.

    Returns:
        1.0 for no large gap, 0.3 for exactly one, 0.1 for two or more.
    """
    countGaps = 0
    countGapSize = 0
    for row in seq:
        if row[2] == 2:
            countGapSize += 1
            continue
        if countGapSize >= 15:
            countGaps += 1
        countGapSize = 0
    # a gap that is still open when the sequence ends counts as well
    if countGapSize >= 15:
        countGaps += 1
    if countGaps == 0:
        return 1.0
    if countGaps == 1:
        return 0.3
    return 0.1
    
################### analyse repeated
def analyseRepeated(seq):
    """Score how little immediate pitch repetition *seq* contains.

    Consecutive note onsets (type 1 rows) are compared pairwise; the result
    is 1 - repeated_pairs / total_pairs, so 0.0 means every note repeats its
    predecessor and 1.0 means none does.

    Args:
        seq: iterable of (note value, note length, type) rows.

    Returns:
        float in [0, 1]; 1.0 when there are fewer than two onsets (no pair
        to compare) instead of raising ZeroDivisionError.
    """
    # compare onsets with onsets only, wherever the first onset happens to be
    pitches = [row[0] for row in seq if row[2] == 1]
    totalLength = len(pitches) - 1
    if totalLength <= 0:
        return 1.0
    countRepeated = sum(1 for prevNote, currNote in zip(pitches, pitches[1:]) if prevNote == currNote)
    return 1.0 - (countRepeated / totalLength)


################### analyse continuous
def analyseContinuous(seq):
    """Return the fraction of rows in *seq* that are rests (type 2)."""
    restRows = [row for row in seq if row[2] == 2]
    return float(len(restRows)) / float(len(seq))


################### just the pitches
def justThePitches(seq):
    """Return the note values of all onset rows (type 1) of *seq*, in order."""
    return [row[0] for row in seq if row[2] == 1]


################### simpleGA
def simpleGA(outport,initialisationFunction,fitnessFunction,timescale=101,popSize=1000,playSamples=True,samplePoint=10,playbackSpeed=0.1,\
             storeSamples=False,samplePrefix="sample",silentOn=True):
    """Evolve note sequences with a simple tournament-selection GA.

    Fitness is minimised: the population is sorted ascending and the first
    individual is treated as the best.

    Args:
        outport: MIDI output port handed to playback().
        initialisationFunction: zero-argument callable returning one
            sequence (array of (note value, note length, type) rows).
        fitnessFunction: callable mapping a sequence to a number (lower is
            better).
        timescale: number of generations.
        popSize: number of individuals.
        playSamples: play the best individual every samplePoint generations.
        samplePoint: sampling period, in generations, for playing/storing.
        playbackSpeed: seconds per step passed to playback().
        storeSamples: pickle the best individual every samplePoint
            generations to 00samples/<samplePrefix>/<samplePrefix>_<t>.p.
        samplePrefix: sub-directory and file-name prefix for stored samples.
        silentOn: allow the mutation that turns a note into a rest.

    Returns:
        The sequence in slot 0 of the final population (the elite carried
        over from the last evaluated generation).
    """
    sampleDir = os.path.join("00samples", samplePrefix)
    if storeSamples:
        # makedirs also creates the parent "00samples" directory if needed
        os.makedirs(sampleDir, exist_ok=True)

    ## mutation distribution: discretised normal over -4..+4 semitones
    ## (the upper bound is exclusive in arange, so 5 keeps the range symmetric)
    mutRange = np.arange(-4,5)
    xU, xL = mutRange + 0.5, mutRange - 0.5
    prob = ss.norm.cdf(xU, scale = 3) - ss.norm.cdf(xL, scale = 3)
    prob = prob / prob.sum() #normalize the probabilities so their sum is 1
    ## from https://stackoverflow.com/questions/37411633/how-to-generate-a-random-normal-distribution-of-integers

    # random.sample cannot draw more contestants than there are individuals
    tournamentSize = min(7, popSize)
    population = [[initialisationFunction(),0.0] for _ in range(popSize)]
    for t in range(timescale):
        for individual in population:
            individual[1] = fitnessFunction(individual[0])
        population.sort(key=lambda x: x[1])
        best = population[0]
        print("mean fitness: ",sum([x[1] for x in population])/popSize, " at time ",t)
        print("  melodicity: ",analyseMelodicity(best[0]))
        print("  trillicity: ",analyseTrillicity(best[0]))
        print("  slapdash:   ",analyseSlapdashSerialism2(best[0]))
        print("  continuity: ",analyseContinuous(best[0]))
        print("  repeated: ",analyseRepeated(best[0]))
        print("  large gaps: ",analyseLargeGaps(best[0]))
        print(best[1])
        if (t % samplePoint) == 0 and playSamples:
            playback(best[0],playbackSpeed,outport)
        if (t % samplePoint) == 0 and storeSamples:
            with open(os.path.join(sampleDir, samplePrefix+"_"+str(t)+".p"),"wb") as f:
                pickle.dump(best[0],f)

        newPop = []
        for _ in range(popSize):
            # tournament selection: lowest fitness among a random sample wins
            parent1 = min(random.sample(population,tournamentSize),key=lambda x: x[1])
            child = copy.deepcopy(parent1)
            child[1] = 0.0
            rows = child[0]
            for pos in range(np.shape(rows)[0]):
                if rows[pos][2]==1:
                    if random.random()<0.1:
                        # nudge the pitch; keep the old value if it would leave 60-76
                        currVal = rows[pos][0]
                        candVal = currVal + np.random.choice(mutRange, p = prob)
                        if candVal>76 or candVal<60:
                            candVal = currVal
                        rows[pos][0] = candVal
                    elif random.random()<0.1 and silentOn:
                        #make silent
                        rows[pos][0] = 0
                        rows[pos][2] = 2
                # NOTE(review): a note silenced just above is immediately
                # eligible for this rest->note mutation; kept as in the original.
                if rows[pos][2]==2:
                    if random.random()<0.05:
                        #make silence into a note
                        rows[pos][0] = random.randint(60,76)
                        rows[pos][1] = 1
                        rows[pos][2] = 1
            newPop.append(child)

        # elitism: population is already sorted, so slot 0 is the current best
        newPop[0] = population[0]
        # children are deep copies already, so no second copy is needed
        population = newPop
    return population[0][0]

################### main zone
# Everything below runs at module level: it opens the MIDI output port and
# then starts one GA run (example 4).
# NOTE(review): the port name is machine-specific - presumably a macOS IAC bus
# called "pioneer"; confirm it exists locally before running.
outport = mido.open_output('IAC Driver pioneer')

## initialisation options: createSequence, createEvenSequence, createRepeatedSequence, createTrill,
##                         optionally wrapped in insertSilentBlocks (see examples 3 and 4)
## fitness options: analyseSlapdashSerialism2, analyseMelodicity, analyseTrillicity, analyseContinuous,
##                  analyseRepeated, analyseLargeGaps (simpleGA minimises; terms may be summed/weighted)

#example 1: starting from a random sequence, making it more trilly - will this fail?
#seq2 = simpleGA(outport,createEvenSequence,lambda x: analyseTrillicity(x),silentOn=False, \
#                 storeSamples="True", samplePrefix="example1")

##example 2: starting from a pure trill, making it more melodic
#seq2 = simpleGA(outport,createTrill,lambda x: analyseMelodicity(x)+analyseTrillicity(x),silentOn=False,
#                storeSamples="True", samplePrefix="example2",playbackSpeed=0.08)

##example 2bis: starting from a pure trill, making it more melodic
#seq2 = simpleGA(outport,createTrill,lambda x: analyseMelodicity(x)+analyseTrillicity(x)+analyseContinuous(x),silentOn=True,
# storeSamples="True", samplePrefix="example2",playbackSpeed=0.08)

##example 3: including gaps
#seq2  = simpleGA(outport, lambda: insertSilentBlocks(createEvenSequence(),2),\
#                 lambda x: analyseMelodicity(x)+analyseLargeGaps(x), playSamples=True,\
#                 samplePoint = 100, popSize = 50, timescale=301, silentOn=True, \
#                 storeSamples = True, samplePrefix = "example3")

#example 4: even sequence with two silent blocks inserted; fitness is
#           melodicity + large-gap score + 5*trillicity + 3*rest fraction.
#           Best individual is played and pickled every 100 generations.
seq2  = simpleGA(outport, lambda: insertSilentBlocks(createEvenSequence(),2), \
                 lambda x: analyseMelodicity(x) + analyseLargeGaps(x) + 5* analyseTrillicity(x) + 3*analyseContinuous(x), \
                 playSamples=True, silentOn=True, \
                 storeSamples=True, samplePrefix="example4", samplePoint = 100, popSize = 50, timescale=1001)


################### faff zone


#seq2  = simpleGA(outport, lambda: insertSilentBlocks(createEvenSequence(),2), \
#                 lambda x: analyseLargeGaps(x),timescale=100,playSamples=True)

#seq2 = simpleGA(outport,createTrill,lambda x: analyseTrillicity(x),timescale=100,playSamples=True)

#seq1 = createEvenSequence()
#seq1 = insertSilentBlocks(seq1,2)
#print(seq1)
#playback(seq1,0.1,outport)
#sleep(1.0)
#playback(seq2,0.2,outport)

#tr = createTrill()
#tr = createSequence()
#print(analyseTrillicity(tr))

#playback(seq,0.2,outport,0.2)
#print('{:.2f}'.format(analyseSlapdashSerialism(seq)))
#seq1 = simpleGA(outport,analyseSlapdashSerialism2)
#seq2 = simpleGA(outport,analyseMelodicity)

