﻿from pyscript import document
from pyscript.ffi import create_proxy
from js import setTimeout
from js import setInterval
import qm_nbit as qm
import geo2d
import numpy as np
import random
import svgobjects
from js import Object
import numpy as np


# --- Global variables -------------------------------------------------------------------

print("Startup begin!")

# Drag state: id of the particle being dragged (None while no drag is in
# progress) and the pointer offset relative to that particle's translation.
dragging = None
offset_x, offset_y = 0, 0

# Quantum state of the lab and the number of particle glyphs currently shown.
g_state = None
g_currentlyDisplayedParticles = 0

# The two measurement devices; update_preview() refreshes both of them.
g_measurementDevice1 = g_measurementDevice2 = None

# State-vector visualization (a StateVis instance once particles exist).
g_stateVis = None

# Animation clock: g_time counts ticks modulo g_timeModulus; the displayed
# state is multiplied by a global phase that completes one turn per cycle.
g_time = 0
g_timeModulus = 30

# --- SVG classes ----------------------------------------------------------------------

class Group:
    """SVG ``<g>`` element translated to (x, y) and attached to *parent*.

    Children appended through ``appendChild`` are positioned relative to
    (x, y). The translation is also kept as a ``geo2d.Vec2d`` so payload areas
    can shift their local bounding boxes by it.
    """

    def __init__(self, parent, x, y):
        svgNamespace = "http://www.w3.org/2000/svg"
        self.translation = geo2d.Vec2d(x, y)
        element = document.createElementNS(svgNamespace, "g")
        element.setAttribute("transform", f"translate({x}, {y})")
        self.group = element
        parent.appendChild(element)

    def appendChild(self, child):
        """Append *child* to the translated group element."""
        self.group.appendChild(child)

    def getTranslation(self):
        """Return the translation given to the constructor."""
        return self.translation


class Preview:
    """Two-line red SVG text label, e.g. for measurement probabilities.

    The text is centred horizontally at x + 60. Its first line sits at y + 20
    and the second line 1.2em below it.
    """

    def __init__(self, parent, x, y):
        svgNamespace = "http://www.w3.org/2000/svg"
        centerX = x + 60
        self.text = document.createElementNS(svgNamespace, "text")
        textAttributes = {
            "id": "preview",
            "x": centerX,
            "y": y + 20,
            "fill": "red",
            "font-size": 16,
            "font-family": "Arial",
            "text-anchor": "middle",
            "class": "noselect",
        }
        for name, value in textAttributes.items():
            self.text.setAttribute(name, value)
        # One <tspan> per line; both share the same x so each line is centred.
        self.tspan1 = self._makeLine(svgNamespace, centerX, "0")
        self.tspan2 = self._makeLine(svgNamespace, centerX, "1.2em")
        self.text.appendChild(self.tspan1)
        self.text.appendChild(self.tspan2)
        parent.appendChild(self.text)

    @staticmethod
    def _makeLine(svgNamespace, x, dy):
        """Create an empty <tspan> at horizontal position *x* with offset *dy*."""
        line = document.createElementNS(svgNamespace, "tspan")
        line.setAttribute("x", x)
        line.setAttribute("dy", dy)
        line.textContent = ""
        return line

    def setText(self, text):
        """Show *text* in the label.

        A 2-element list/tuple fills both lines. A string fills the first line
        and clears the second. Anything else clears both lines.
        """
        first, second = "", ""
        if isinstance(text, (list, tuple)) and len(text) == 2:
            first, second = text
        elif isinstance(text, str):
            first = text
        self.tspan1.textContent = first
        self.tspan2.textContent = second


class Display:
    """120x30 read-out box showing a short status text (placeholder "---").

    Text set with ``autoClear=True`` reverts to the placeholder after 1000 ms.
    """

    def __init__(self, parent, x, y):
        self.rect = document.createElementNS("http://www.w3.org/2000/svg", "rect")
        self.rect.setAttribute("x", x)
        self.rect.setAttribute("y", y)
        self.rect.setAttribute("width", 120)
        self.rect.setAttribute("height", 30)
        self.rect.setAttribute("class", "device")
        parent.appendChild(self.rect)
        self.text = document.createElementNS("http://www.w3.org/2000/svg", "text")
        self.text.setAttribute("x", x + 60)
        self.text.setAttribute("y", y + 20)
        self.text.setAttribute("class", "noselect")
        self.text.setAttribute("font-size", "20")
        self.text.setAttribute("font-family", "Arial")
        self.text.setAttribute("text-anchor", "middle")
        self.text.textContent = "---"
        parent.appendChild(self.text)
        # Id of the pending auto-clear timer, or None if none is scheduled.
        self._clearTimerId = None
        # One long-lived proxy for the timer callback. Creating a new proxy in
        # every setText() call would leak one proxy per call, since none of them
        # is ever destroyed.
        self._clearProxy = create_proxy(self._onClearTimer)

    def setText(self, text, autoClear=True):
        """Show *text*; if *autoClear*, revert to "---" after 1000 ms.

        Any previously scheduled auto-clear is cancelled first. This stops a
        timer started for an earlier message from wiping out a newer one early.
        """
        self._cancelPendingClear()
        self.text.textContent = text
        if autoClear:
            self._clearTimerId = setTimeout(self._clearProxy, 1000)

    def clear(self):
        """Show the placeholder text "---"."""
        self.text.textContent = "---"

    def _cancelPendingClear(self):
        """Cancel the scheduled auto-clear, if any."""
        if self._clearTimerId is not None:
            from js import clearTimeout
            clearTimeout(self._clearTimerId)
            self._clearTimerId = None

    def _onClearTimer(self):
        """Timer callback: forget the timer id and show the placeholder."""
        self._clearTimerId = None
        self.clear()


class ActionButton:
    """60x60 clickable SVG button with a centred text label.

    *on_pressed* is registered as the rect's click listener and receives the
    DOM click event.
    """

    def __init__(self, parent, x, y, label, on_pressed):
        self.rect = document.createElementNS("http://www.w3.org/2000/svg", "rect")
        self.rect.setAttribute("x", x)
        self.rect.setAttribute("y", y)
        self.rect.setAttribute("width", 60)
        self.rect.setAttribute("height", 60)
        self.rect.setAttribute("class", "device deviceButton")
        self.rect.addEventListener("click", create_proxy(on_pressed))
        parent.appendChild(self.rect)
        self.text = document.createElementNS("http://www.w3.org/2000/svg", "text")
        self.text.setAttribute("x", x + 30)
        self.text.setAttribute("y", y + 30)
        self.text.setAttribute("class", "noselect")
        self.text.setAttribute("font-size", "12")
        self.text.setAttribute("font-family", "Arial")
        self.text.setAttribute("text-anchor", "middle")  # centre horizontally
        # Centre vertically. "alignment-baseline" is not honoured on <text>
        # by Firefox, whereas "dominant-baseline" is.
        self.text.setAttribute("dominant-baseline", "middle")
        self.text.textContent = label
        parent.appendChild(self.text)

class TextToggleButton:
    """Clickable button that cycles through a list of text labels.

    Each click advances to the next label, wrapping around. ``choiceIds[i]``
    is the value returned by ``getChoice()`` while label ``i`` is shown.
    A label of the form ``"$A_b$"`` is rendered as "A" followed by the
    subscript "b" (everything after the first underscore).

    *ysize* scales the button height and the font size; *textsize* scales
    only the font size.
    """

    def __init__(self, parent, x, y, choiceLabels, choiceIds, *, ysize=1, textsize=1):
        self.choiceLabels = choiceLabels
        self.choiceIds = choiceIds
        self.choice = 0
        self.rect = document.createElementNS("http://www.w3.org/2000/svg", "rect")
        self.rect.setAttribute("x", x)
        self.rect.setAttribute("y", y)
        self.rect.setAttribute("width", 60)
        self.rect.setAttribute("height", round(60*ysize))
        self.rect.setAttribute("class", "device deviceButton")
        self.rect.addEventListener("click", create_proxy(lambda event: self.on_pressed()))
        parent.appendChild(self.rect)
        self.text = document.createElementNS("http://www.w3.org/2000/svg", "text")
        self.text.setAttribute("x", x + 30)
        self.text.setAttribute("y", y + round(30*ysize))  # vertical centre of the button
        self.text.setAttribute("class", "noselect")
        self.text.setAttribute("font-size", str(round(40*ysize*textsize)))
        self.text.setAttribute("font-family", "Arial")
        self.text.setAttribute("text-anchor", "middle")  # centre horizontally
        # Centre vertically. "alignment-baseline" does not work on <text> in
        # Firefox, "dominant-baseline" does.
        self.text.setAttribute("dominant-baseline", "middle")
        parent.appendChild(self.text)
        self.updateChoiceDisplay()

    def updateChoiceDisplay(self):
        """Render the label of the current choice into the text element."""
        label: str = self.choiceLabels[self.choice]
        if label.startswith("$") and label.endswith("$") and "_" in label:
            # partition keeps everything after the first "_" as the subscript
            # (split("_")[1] would silently drop parts after a second "_").
            base, _, subscript = label[1:-1].partition("_")
            self._showSubscriptedLabel(base, subscript)
        else:
            self.text.textContent = label

    def _showSubscriptedLabel(self, base, subscript):
        """Replace the text content with *base* followed by a lowered *subscript*.

        The <tspan> elements are built through the DOM instead of innerHTML,
        so characters like "<" or "&" in a label are shown literally.
        """
        svgNamespace = "http://www.w3.org/2000/svg"
        self.text.textContent = ""  # drop the previous label's nodes
        normal = document.createElementNS(svgNamespace, "tspan")
        normal.setAttribute("class", "normal")
        normal.textContent = base
        sub = document.createElementNS(svgNamespace, "tspan")
        sub.setAttribute("class", "sub")
        sub.setAttribute("dy", "8")
        sub.textContent = subscript
        self.text.appendChild(normal)
        self.text.appendChild(sub)

    def getChoice(self):
        """Return the id belonging to the current choice."""
        return self.choiceIds[self.choice]

    def on_pressed(self):
        """Advance to the next choice (wrapping around) and redraw the label."""
        self.choice = (self.choice + 1) % len(self.choiceLabels)
        self.updateChoiceDisplay()
        

class PolarisationSetting:
    """60x60 clickable control that displays and sets a polarisation angle.

    The angle (degrees, initially 45) is stored in the ``transform``
    attribute, "rotate(<deg>)", of an inner group of parallel lines.
    A click adds 22.5 degrees; a shift-click subtracts 22.5 degrees.
    """

    def __init__(self, parent, x, y):
        svgNamespace = "http://www.w3.org/2000/svg"

        # Clickable background square.
        self.rect = document.createElementNS(svgNamespace, "rect")
        for name, value in (("x", x), ("y", y), ("width", 60), ("height", 60),
                            ("class", "device deviceButton")):
            self.rect.setAttribute(name, value)
        self.rect.addEventListener("click", create_proxy(self.on_click))
        parent.appendChild(self.rect)

        # Nested <svg> whose viewBox puts the origin at the centre of the square,
        # so rotate(<deg>) turns the line pattern around the button's centre.
        self.svg = document.createElementNS(svgNamespace, "svg")
        for name, value in (("x", x), ("y", y), ("width", 60), ("height", 60),
                            ("viewBox", "-30 -30 60 60")):
            self.svg.setAttribute(name, value)
        parent.appendChild(self.svg)

        # NOTE(review): every instance uses the same id "directionDisplay", so the
        # id is not unique in the document -- confirm nothing looks it up by id.
        self.g = document.createElementNS(svgNamespace, "g")
        self.g.setAttribute("id", "directionDisplay")
        self.g.setAttribute("transform", "rotate(45)")
        self.g.setAttribute("stroke", "black")
        self.svg.appendChild(self.g)

        # Horizontal lines 8 units apart, extending well beyond the 60x60 view
        # so the pattern still fills it after rotation.
        for lineY in range(-32, 33, 8):
            line = document.createElementNS(svgNamespace, "line")
            for name, value in (("x1", -100), ("y1", lineY), ("x2", 100), ("y2", lineY),
                                ("class", "noselect")):
                line.setAttribute(name, value)
            self.g.appendChild(line)

    def getRotation(self):
        """Return the current angle in degrees, parsed from the group's transform."""
        transform = self.g.getAttribute("transform")
        afterRotate = transform.split("rotate(")[1]
        return float(afterRotate.split(")")[0])

    def on_click(self, event):
        """Click handler: +22.5 degrees, or -22.5 degrees with Shift held."""
        event.preventDefault()
        step = -22.5 if event.shiftKey else 22.5
        self.increaseRotation(step)

    def increaseRotation(self, delta):
        """Add *delta* degrees to the angle and refresh the measurement previews."""
        newRotation = self.getRotation() + delta
        self.g.setAttribute("transform", f"rotate({newRotation})")
        update_preview()

    
class PayloadArea:
    """Rectangular drop zone of a device, 120 wide and 120*ysize high.

    *translationInfo* is the translation of the enclosing Group. It is added
    to the rect's local bounding box before testing particle positions
    (as returned by get_particle_position) against it.
    """

    def __init__(self, parent, x, y, translationInfo, *, ysize=1):
        self.translationInfo = translationInfo
        rect = document.createElementNS("http://www.w3.org/2000/svg", "rect")
        for name, value in (("x", x), ("y", y), ("width", 120), ("height", 120*ysize),
                            ("class", "device")):
            rect.setAttribute(name, value)
        self.rect = rect
        parent.appendChild(rect)

    def is_within_measurement_area(self, pos):
        """Return whether point *pos* lies inside this (translated) area."""
        area = geo2d.Region(self.rect.getBBox())
        area.move_(self.translationInfo)
        return area.contains(pos)

    def getContainedParticles(self):
        """Return, in ascending order, the ids of displayed particles whose centre is inside."""
        return [
            particleId
            for particleId in range(g_currentlyDisplayedParticles)
            if self.is_within_measurement_area(get_particle_position(particleId))
        ]

class MeasurementDevice:
    """Device that measures the polarisation of a single particle.

    Parts, relative to (x, y): a probability preview above, the result display,
    the polarisation setting with a "measure!" button, and the payload area
    below them.
    """

    def __init__(self, parent, x, y):
        group = Group(parent, x, y)
        self.g = group
        self.preview = Preview(group, 0, -50)
        self.display = Display(group, 0, 0)
        self.polarisationSetting = PolarisationSetting(group, 0, 30)
        self.button = ActionButton(group, 60, 30, "measure!", lambda event: self.on_button_pressed())
        self.payloadArea = PayloadArea(group, 0, 90, group.getTranslation())

    def _measurementAngle(self):
        """Return the selected polarisation angle converted with qm.deg."""
        return self.polarisationSetting.getRotation() * qm.deg

    def on_button_pressed(self):
        """Measure the single particle in the payload area and show the outcome.

        Shows "error" unless the area contains exactly one particle.
        """
        contained = self.payloadArea.getContainedParticles()
        if len(contained) != 1:
            self.display.setText("error")
            return
        (particleId,) = contained
        g_state.createCheckpoint()
        outcome = g_state.measurePolarisation(particleId, self._measurementAngle())
        self.display.setText(str(outcome))
        updateAll()

    def update_preview(self):
        """Show the outcome shares for the contained particle, as percentages.

        The preview is blank when previews are switched off or the payload area
        does not hold exactly one particle.
        """
        contained = self.payloadArea.getContainedParticles() if previewIsActive() else []
        if len(contained) != 1:
            self.preview.setText("")
            return
        (particleId,) = contained
        pOne, pZero = g_state.measurePolarisation(particleId, self._measurementAngle(), preview=True)
        self.preview.setText([f"1: {round(pOne*100)}%", f"0: {round(pZero*100)}%"])


class PreparationDevice:
    """Device that sets the polarisation of every particle in its payload area."""

    def __init__(self, parent, x, y):
        group = Group(parent, x, y)
        self.g = group
        self.display = Display(group, 0, 0)
        self.polarisationSetting = PolarisationSetting(group, 0, 30)
        self.button = ActionButton(group, 60, 30, "prepare!", lambda event: self.on_button_pressed())
        self.payloadArea = PayloadArea(group, 0, 90, group.getTranslation())

    def on_button_pressed(self):
        """Prepare all contained particles with the selected polarisation.

        Shows "error" if the payload area is empty.
        """
        targets = self.payloadArea.getContainedParticles()
        if not targets:
            self.display.setText("error")
            return
        g_state.createCheckpoint()
        angle = self.polarisationSetting.getRotation() * qm.deg
        for particleId in targets:
            g_state.setPolarisation(particleId, angle)
        self.display.setText("OK")
        updateAll()


class EntanglementDevice:
    """Device that entangles exactly two particles.

    The toggle selects the relation passed to ``g_state.entangle``: "!="
    (shown as ≠) or "==" (shown as =).
    """

    def __init__(self, parent, x, y):
        group = Group(parent, x, y)
        self.g = group
        self.display = Display(group, 0, 0)
        self.entanglementSetting = TextToggleButton(group, 0, 30, ["\u2260", "="], ["!=", "=="])
        self.button = ActionButton(group, 60, 30, "entangle!", lambda event: self.on_button_pressed())
        self.payloadArea = PayloadArea(group, 0, 90, group.getTranslation())

    def on_button_pressed(self):
        """Entangle the two contained particles; show "error" unless there are exactly two."""
        pair = self.payloadArea.getContainedParticles()
        if len(pair) != 2:
            self.display.setText("error")
            return
        first, second = pair
        g_state.createCheckpoint()
        g_state.entangle(self.entanglementSetting.getChoice(), first, second)
        self.display.setText("OK")
        updateAll()


class Transform1Device:
    """Device that applies a single-particle rotation to each particle in its payload area.

    The upper toggle picks the rotation (R_X, R_Y, R_Z or R_H). The lower
    toggle picks the power passed to qm.getTransformation: 1 (labelled 180°),
    ±0.5 (±90°) or ±0.25 (±45°).
    """

    def __init__(self, parent, x, y):
        self.g = Group(parent, x, y)
        self.display = Display(self.g, 0, 0)
        choiceLabels = ["$R_X$", "$R_Y$", "$R_Z$", "$R_H$"]
        choiceIDs = ["X", "Y", "Z", "H"]
        powerLabels = ["180°", "+90°", "-90°", "+45°", "-45°"]
        powerValue = [1, 0.5, -0.5, 0.25, -0.25]
        self.trafoSetting = TextToggleButton(self.g, 0, 30, choiceLabels, choiceIDs, ysize=0.5)
        self.powerSetting = TextToggleButton(self.g, 0, 60, powerLabels, powerValue, ysize=0.5)
        self.button = ActionButton(self.g, 60, 30, "transform!", lambda event: self.on_button_pressed())
        self.payloadArea = PayloadArea(self.g, 0, 90, self.g.getTranslation())

    def on_button_pressed(self):
        """Apply the selected rotation to every contained particle.

        Shows "error" if the payload area is empty.
        """
        particleIds = self.payloadArea.getContainedParticles()
        if not particleIds:
            self.display.setText("error")
            return
        g_state.createCheckpoint()
        # The transformation depends only on the two toggle settings, so it is
        # built once rather than once per particle.
        transformation = qm.getTransformation(self.trafoSetting.getChoice(), self.powerSetting.getChoice())
        for particleId in particleIds:
            g_state.applyTransform(transformation, [particleId])
        self.display.setText("OK")
        updateAll()


class Transform2Device:
    """Device that applies a two-particle gate (CNOT or SWAP).

    It has two stacked payload areas, each 60 high. The particle in the upper
    area is passed first to ``g_state.applyTransform``, and the particle in
    the lower area second.
    """

    def __init__(self, parent, x, y):
        self.g = Group(parent, x, y)
        self.display = Display(self.g, 0, 0)
        choiceLabels = ["CNOT", "SWAP"]
        choiceValues = [qm.CNOT_matrix, qm.SWAP_matrix]
        self.trafoSetting = TextToggleButton(self.g, 0, 30, choiceLabels, choiceValues, textsize=0.5)
        self.button = ActionButton(self.g, 60, 30, "transform!", lambda event: self.on_button_pressed())
        self.payloadArea1 = PayloadArea(self.g, 0, 90, self.g.getTranslation(), ysize=0.5)
        self.payloadArea2 = PayloadArea(self.g, 0, 90+60, self.g.getTranslation(), ysize=0.5)

    def on_button_pressed(self):
        """Apply the selected gate to the particles in the two payload areas.

        Shows "error" unless each area holds exactly one particle and the two
        particles are different.
        """
        particleIds1 = self.payloadArea1.getContainedParticles()
        particleIds2 = self.payloadArea2.getContainedParticles()
        if len(particleIds1) != 1 or len(particleIds2) != 1:
            self.display.setText("error")
            return
        # The two areas share a border (y = 150). Depending on how
        # geo2d.Region.contains treats edges, a particle centred on that border
        # may be reported by both. The gate must never get the same qubit twice.
        if particleIds1[0] == particleIds2[0]:
            self.display.setText("error")
            return
        g_state.createCheckpoint()
        g_state.applyTransform(self.trafoSetting.getChoice(), particleIds1 + particleIds2)
        self.display.setText("OK")
        updateAll()





class Particle:
    """Draggable particle glyph: a circle of radius 25 labelled ``particleId+1``.

    The glyph is an SVG group with id "particle<particleId>", positioned by a
    "translate(x, y)" transform. The drag handlers find and move it by that id.
    """

    def __init__(self, parent, particleId, x, y):
        svgNamespace = "http://www.w3.org/2000/svg"
        self.particleId = particleId

        self.element = document.createElementNS(svgNamespace, "g")
        self.element.setAttribute("id", f"particle{particleId}")
        self.element.setAttribute("transform", f"translate({x}, {y})")

        self.circle = document.createElementNS(svgNamespace, "circle")
        circleAttributes = (
            ("cx", 0), ("cy", 0), ("r", 25),
            ("class", "particle particleSensitive"),
            ("stroke-width", 2), ("fill", "none"),
        )
        for name, value in circleAttributes:
            self.circle.setAttribute(name, value)
        self.element.appendChild(self.circle)

        self.text = document.createElementNS(svgNamespace, "text")
        textAttributes = (
            ("class", "noselect"), ("x", 0), ("y", 5),
            ("font-size", 30), ("font-family", "Arial"),
            ("text-anchor", "middle"), ("dominant-baseline", "middle"),
        )
        for name, value in textAttributes:
            self.text.setAttribute(name, value)
        self.text.textContent = str(particleId + 1)  # 1-based label for the user
        self.element.appendChild(self.text)

        parent.appendChild(self.element)



def compute_transform_from_bboxes(sourceBox, targetBox):
    """Return an SVG transform string that maps *sourceBox* onto *targetBox*.

    Each axis is scaled by the ratio of the box sizes, then translated so the
    source's top-left corner lands on the target's top-left corner. Both
    arguments need ``x``, ``y``, ``width`` and ``height`` attributes.
    """
    sx, sy = targetBox.width / sourceBox.width, targetBox.height / sourceBox.height
    tx = targetBox.x - sourceBox.x * sx
    ty = targetBox.y - sourceBox.y * sy
    return f"translate({tx}, {ty}) scale({sx}, {sy})"



class BaseVectorPartVis:
    """Glyph for one particle of a basis vector.

    It shows a "<id>:" label, a circle and a direction line. Everything is
    drawn in a local 80x50 box that is scaled into *targetBox*.
    """

    _SVG_NS = "http://www.w3.org/2000/svg"

    @classmethod
    def _element(cls, tag, attributes):
        """Create an SVG element *tag* and set *attributes* on it in order."""
        element = document.createElementNS(cls._SVG_NS, tag)
        for name, value in attributes.items():
            element.setAttribute(name, value)
        return element

    def __init__(self, parent, targetBox, particleId):
        # Natural aspect ratio is 80:50
        naturalBox = geo2d.Region(0, 0, 80, 50)
        self.group = self._element("g", {"transform": compute_transform_from_bboxes(naturalBox, targetBox)})
        parent.appendChild(self.group)

        self.rect = self._element("rect", {
            "x": 0, "y": 0, "width": 80, "height": 50,
            "stroke": "black", "fill": "transparent",
        })
        self.group.appendChild(self.rect)

        self.text = self._element("text", {
            "class": "noselect", "x": 15, "y": 25, "font-size": 30,
            "font-family": "Arial", "text-anchor": "middle", "dominant-baseline": "middle",
        })
        self.text.textContent = f"{particleId+1}:"
        self.group.appendChild(self.text)

        self.circle = self._element("circle", {
            "cx": 50, "cy": 25, "r": 20, "stroke": "black", "class": "particle",
        })
        self.group.appendChild(self.circle)

        # Horizontal line through the circle centre (50, 25); setAngle rotates it.
        self.line = self._element("line", {
            "x1": 30, "y1": 25, "x2": 70, "y2": 25,
            "stroke": "black", "stroke-width": 2,
        })
        self.group.appendChild(self.line)

    def setAngle(self, angle):
        """Rotate the direction line to *angle* degrees; no-op once removed."""
        if self.group is None:
            return
        # Negated to match the orientation convention of the spin visualizer.
        self.line.setAttribute("transform", f"rotate({-angle}, 50, 25)")

    def removeFromCanvas(self):
        """Detach the glyph from the DOM; later calls do nothing."""
        if self.group is None:
            return
        self.group.remove()
        self.group = None
        
    

class BaseVectorVis:
    """Row of N BaseVectorPartVis glyphs showing basis vector *baseVectorNumber*."""

    # Extra rotation (degrees) of a particle's line for bit '0' and bit '1'.
    _BIT_ANGLE = {'0': 0, '1': 90}

    def __init__(self, parent, baseVectorNumber, targetBox, N):
        self.N = N
        self.basisState = qm.BasisState(baseVectorNumber, self.N)
        fittedBox = targetBox.shrinkToAR(self.N * 80, 50)
        grid = geo2d.Grid(fittedBox, numCellsX=self.N)
        self.parts = [BaseVectorPartVis(parent, grid.getCell(i), particleId=i) for i in range(self.N)]
        self.setBasisAngle(0)

    def setBasisAngle(self, basisAngle):
        """Orient glyph i at *basisAngle* plus 0 or 90 degrees, according to bit i."""
        bits = self.basisState.str
        for i in range(self.N):
            self.parts[i].setAngle(basisAngle + self._BIT_ANGLE[bits[i]])

    def removeFromCanvas(self):
        """Remove all glyphs from the DOM."""
        for part in self.parts:
            part.removeFromCanvas()
    

class StatePartVis:
    def __init__(self, parent, baseVectorNumber, targetBox:geo2d.Region, N):
        self.N = N
        self.absValueDisplayMode = "normal"

        #spaceFractionTop = 0.7
        #yMiddle = targetBox.y + spaceFractionTop * targetBox.height
        #numberRegion = geo2d.Region(targetBox.x, targetBox.y, targetBox.width, targetBox.height*spaceFractionTop)
        #baseVectorRegion = geo2d.Region(targetBox.x, yMiddle, targetBox.width, targetBox.height*(1-spaceFractionTop))

        layout = geo2d.Layout(targetBox, axis="y", spaceDistribution=[0.3, 0.3, 0.7, 0.3])
        absValueRegion = layout.getCell(0)
        phaseAngleRegion = layout.getCell(1)
        numberRegion = layout.getCell(2)
        baseVectorRegion = layout.getCell(3)

        self.absValueVis = svgobjects.Label(parent, "absValue", absValueRegion.shrinkToAR(3,1))
        self.phaseAngleVis = svgobjects.Label(parent, "phase", phaseAngleRegion.shrinkToAR(3,1))
        self.numberVis = svgobjects.ComplexNumberVis(parent, numberRegion)        
        self.baseVectorVis = BaseVectorVis(parent, baseVectorNumber, baseVectorRegion.shrink(0, 5, 0, 0) , self.N)

        # Note: The shrinkToAR transformation is required, because the font size of svgobjects.Labels is
        # only determined by the region height. So shrinkToAR is needed to make the font smaller 
        # when we use more qbits and therefore the width gets lower.


    def setValue(self, value: complex):
        self.numberVis.setValue(value)
        if self.absValueDisplayMode == "sqrt(x)":
            percent = 100 * abs(value)**2
            self.absValueVis.setValue(f"sqrt( {percent:.1f}% )")
        else:
            self.absValueVis.setValue(f"{abs(value):.3f}")
        if abs(value) > 0.01:
            phase_deg = round(np.angle(value, deg=True)) % 360
            self.phaseAngleVis.setValue(f"{phase_deg}°")
        else:
            self.phaseAngleVis.setValue("--")


    def setAbsValueDisplayMode(self, value: str):
        assert value in ["normal", "sqrt(x)"]
        self.absValueDisplayMode = value


    def setBasisAngle(self, basisAngle: float):
        self.baseVectorVis.setBasisAngle(basisAngle)

        
    def removeFromCanvas(self):
        self.numberVis.removeFromCanvas()
        self.baseVectorVis.removeFromCanvas()
        self.absValueVis.removeFromCanvas()
        self.phaseAngleVis.removeFromCanvas()


class StateVis:
    """State-vector display: one StatePartVis column per basis vector of N particles."""

    def __init__(self, parent, targetBox, N):
        self.N = N  # number of particles
        dimension = 2**self.N
        grid = geo2d.Grid(targetBox, numCellsX=dimension, padding=20)
        self.parts = []
        for baseVectorNumber in range(dimension):
            self.parts.append(StatePartVis(parent, baseVectorNumber, grid.getCell(baseVectorNumber), self.N))

    def setState(self, stateVector):
        """Show *stateVector*; entry i's first element is the amplitude of basis vector i.

        (Presumably a 2**N x 1 column vector -- confirm with qm.State.)
        Raises Exception if the length does not equal 2**N.
        """
        if len(stateVector) != 2**self.N:
            raise Exception("State size doesn't fit current N value")
        for part, row in zip(self.parts, stateVector):
            part.setValue(row[0])

    def setBasisAngle(self, basisAngle):
        """Rotate the basis-vector glyphs of every column to *basisAngle* degrees."""
        for part in self.parts:
            part.setBasisAngle(basisAngle)

    def setAbsValueDisplayMode(self, value: str):
        """Apply the absolute-value format ("normal" or "sqrt(x)") to every column."""
        assert value in ["normal", "sqrt(x)"]
        for part in self.parts:
            part.setAbsValueDisplayMode(value)

    def removeFromCanvas(self):
        """Remove every column from the DOM."""
        for part in self.parts:
            part.removeFromCanvas()



# --- Functions --------------------------------------------------------------------------


def previewIsActive():
    """Return whether the "previewSwitch" checkbox is checked."""
    previewSwitch = document.getElementById("previewSwitch")
    return previewSwitch.checked


def getSelectedVisBasis():
    """Return the basis angle in degrees chosen in "basisSelection", as a float."""
    selectedValue = document.getElementById("basisSelection").value
    return float(selectedValue)


def get_mouse_position(evt):
    """Return the pointer position of *evt* in the "laboratory" SVG's user coordinates.

    Mouse events use their own clientX/clientY; touch events use the first
    touch point. On ``touchend`` the ``touches`` list is already empty, so the
    first entry of ``changedTouches`` is used instead.
    """
    svgArea = document.getElementById("laboratory")
    pt = svgArea.createSVGPoint()

    touches = getattr(evt, "touches", None)
    if touches is not None:  # Touch event
        if touches.length == 0:
            touches = evt.changedTouches
        touch = touches.item(0)
        pt.x = touch.clientX
        pt.y = touch.clientY
    else:  # Mouse event
        pt.x = evt.clientX
        pt.y = evt.clientY

    # Convert from screen (client) coordinates into the SVG's user space.
    cursorpt = pt.matrixTransform(svgArea.getScreenCTM().inverse())
    return cursorpt.x, cursorpt.y


def get_particle_bbox(particleId):
    """Return the bounding box of particle *particleId* as a geo2d.Region.

    The corners of the element's local bounding box are mapped through its
    CTM, so the result is in the user space of its nearest SVG viewport.
    """
    particle = document.getElementById(f"particle{particleId}")
    localBox = particle.getBBox()
    ctm = particle.getCTM()

    corner = document.getElementById("laboratory").createSVGPoint()
    corner.x = localBox.x
    corner.y = localBox.y
    topLeft = corner.matrixTransform(ctm)
    corner.x += localBox.width
    corner.y += localBox.height
    bottomRight = corner.matrixTransform(ctm)
    return geo2d.Region(topLeft.x, topLeft.y, bottomRight.x - topLeft.x, bottomRight.y - topLeft.y)


def get_particle_position(particleId):
    """Return the centre of particle *particleId*'s bounding box."""
    bbox = get_particle_bbox(particleId)
    return bbox.getCenter()



def updateStateDisplay():
    """Redraw the state vector in the selected basis, times the animated global phase."""
    stateInBasis = g_state.getStateVectorInBasis(getSelectedVisBasis() * qm.deg)
    # Global phase factor that turns once per g_timeModulus ticks.
    rotationFactor = np.exp(-1j * g_time * 2 * np.pi / g_timeModulus)
    g_stateVis.setState(rotationFactor * stateInBasis)


def update_preview():
    """Refresh the probability previews of both measurement devices."""
    for device in (g_measurementDevice1, g_measurementDevice2):
        device.update_preview()


def updateAll():
    """Refresh everything that depends on the quantum state or the display settings."""
    # NOTE(review): the glyphs are rotated by the *negated* basis angle; the
    # original author was unsure why this minus sign is needed.
    g_stateVis.setBasisAngle(-getSelectedVisBasis())
    updateStateDisplay()
    update_preview()

    showStateVis = document.getElementById("stateVisualizationSwitch").checked
    # "inline" rather than "block" when shown.
    document.getElementById("stateVisualizationDiv").style.display = "inline" if showStateVis else "none"


def updateNumberOfParticles(newNumber):
    """Reset the lab to *newNumber* particles with a fresh ``qm.State``.

    This rebuilds the state visualization and re-applies the display mode
    currently selected in "absValueDisplayModeSelect"; a new StateVis would
    otherwise always start in "normal" mode. It then replaces all particle
    glyphs with new ones laid out in a row.
    """
    global g_currentlyDisplayedParticles, g_state, g_stateVis

    stateVisualization = document.getElementById("stateVisualizationSVG")
    if g_stateVis is not None:
        g_stateVis.removeFromCanvas()
    g_stateVis = StateVis(stateVisualization, geo2d.Region(0, 0, 700, 200).shrink(10, 10, 10, 10), newNumber)
    # The selector may show "sqrt(x)" (changed earlier, or possibly restored by
    # the browser on reload), so keep the new visualization consistent with it.
    g_stateVis.setAbsValueDisplayMode(document.getElementById("absValueDisplayModeSelect").value)

    # Replace the old particle glyphs by newNumber fresh ones.
    for particleId in range(g_currentlyDisplayedParticles):
        document.getElementById(f"particle{particleId}").remove()
    g_currentlyDisplayedParticles = 0

    parent = document.getElementById("laboratory")
    for particleId in range(newNumber):
        Particle(parent, particleId, 100 + 80*particleId, 300)
    g_currentlyDisplayedParticles = newNumber

    g_state = qm.State(newNumber)


# --- Events --------------------------------------------------------------------------


def help(event):
    """Open help.html in a new browser tab.

    NOTE: this shadows the builtin ``help``; the name is presumably referenced
    as an event handler from the HTML page, so it is kept.
    """
    import js
    js.window.open("help.html", "_blank")


def undo(event):
    """Roll the state back to the last checkpoint and refresh the display."""
    print("Undoing last state change....")
    state = g_state
    state.undo()
    updateAll()


def toggle_controls(event):
    """Show or hide the second control group and set the toggle button's sign."""
    controls = document.getElementById('laboratory-controls-group-2')
    toggleButton = document.getElementById('toggleControlsBtn')
    isHidden = controls.style.display == 'none'
    controls.style.display = 'flex' if isHidden else 'none'
    toggleButton.textContent = '-' if isHidden else '+'


def on_particle_number_change(event):
    """Rebuild the lab with the particle count chosen in "particleNumberSelect"."""
    selected = document.getElementById("particleNumberSelect").value
    updateNumberOfParticles(int(selected))
    updateAll()


def onTimeTick():
    """Advance the animation clock by one tick and redraw the state.

    While "freezeTimeSwitch" is checked, the clock keeps running until it
    reaches 0 again and then stays there.
    """
    global g_time
    frozen = document.getElementById("freezeTimeSwitch").checked
    if frozen and g_time == 0:
        return
    g_time = (g_time + 1) % g_timeModulus
    updateStateDisplay()


def on_state_visualization_switch():
    """Handle toggling of the state visualization switch by refreshing everything."""
    updateAll()


def on_preview_switch():
    """Handle toggling of the preview switch by refreshing the measurement previews."""
    update_preview()


def on_basis_selection_change():
    """Handle a change of the visualization basis by refreshing everything."""
    updateAll()


def on_absValueDisplayModechange():
    """Apply the selected absolute-value display mode and redraw the state.

    The explicit redraw is needed because onTimeTick stops redrawing while
    time is frozen at g_time == 0. Without it, the new mode would stay
    invisible until the next state change.
    """
    value = document.getElementById("absValueDisplayModeSelect").value
    g_stateVis.setAbsValueDisplayMode(value)
    updateStateDisplay()


def on_mouse_down(evt):
    """Start dragging the topmost particle under the pointer, if any.

    Sets the module-level ``dragging`` to the grabbed particle's id, or None.
    It also records the pointer offset from the particle's translation, so the
    particle does not jump when it starts moving.
    """
    global dragging, offset_x, offset_y
    # Suppresses default actions, including touch scrolling for touchstart.
    evt.preventDefault()
    mouse_x, mouse_y = get_mouse_position(evt)

    # Particles are appended in id order, so the highest hit id is topmost.
    dragging = None
    for particleId in range(g_currentlyDisplayedParticles):
        if get_particle_bbox(particleId).contains(mouse_x, mouse_y):
            dragging = particleId
    if dragging is None:
        return

    # The transform has the form "translate(<x>, <y>)".
    particle = document.getElementById(f"particle{dragging}")
    transform = particle.getAttribute("transform")
    translate_x = float(transform.split(",")[0].split("(")[1])
    translate_y = float(transform.split(",")[1].split(")")[0])

    offset_x = mouse_x - translate_x
    offset_y = mouse_y - translate_y

def on_touch_start(evt):
    """Touch counterpart of on_mouse_down."""
    return on_mouse_down(evt)

def on_mouse_up(evt):
    """End any drag in progress."""
    global dragging
    evt.preventDefault()
    dragging = None

def on_touch_end(evt):
    """Touch counterpart of on_mouse_up: end any drag in progress.

    Previously this called itself, recursing until RecursionError on every
    touchend, so touch drags never ended.
    """
    on_mouse_up(evt)

def on_mouse_move(evt):
    """Move the dragged particle (if any) with the pointer and refresh the previews."""
    if dragging is None:
        return
    pointerX, pointerY = get_mouse_position(evt)
    newX, newY = pointerX - offset_x, pointerY - offset_y
    draggedElement = document.getElementById(f"particle{dragging}")
    draggedElement.setAttribute("transform", f"translate({newX}, {newY})")
    update_preview()

def on_touch_move(evt):
    """Touch counterpart of on_mouse_move; blocks page scrolling while dragging."""
    if dragging is None:
        return
    evt.preventDefault()
    on_mouse_move(evt)


def on_svg_dblclick(evt):
    """Cancel the default double-click action (text selection) in the laboratory."""
    evt.preventDefault()
    print("Doubleclick!")


# --- Initialization -----------------------------------------------------------


def init():
    """Build the laboratory UI, register all event handlers and start the clock."""
    global g_measurementDevice1, g_measurementDevice2, g_stateVis
    document.getElementById("loading").style.visibility="hidden"
    lab = document.getElementById("laboratory")

    deviceY = 45

    def column(index):
        # Devices sit side by side, 150 units apart.
        return 10 + index*150

    PreparationDevice(lab, column(0), deviceY)
    EntanglementDevice(lab, column(1), deviceY)
    g_measurementDevice1 = MeasurementDevice(lab, column(2), deviceY)
    g_measurementDevice2 = MeasurementDevice(lab, column(3), deviceY)
    Transform1Device(lab, column(4), deviceY)
    Transform2Device(lab, column(5), deviceY)

    updateNumberOfParticles(2)

    # Particle dragging with mouse and touch input.
    pointerHandlers = (
        ("mousedown", on_mouse_down),
        ("mouseup", on_mouse_up),
        ("mousemove", on_mouse_move),
        ("touchstart", on_touch_start),
        ("touchend", on_touch_end),
        ("touchmove", on_touch_move),
    )
    for eventName, handler in pointerHandlers:
        lab.addEventListener(eventName, create_proxy(handler))

    # Capture double clicks with a non-passive listener so on_svg_dblclick can
    # prevent the default action (text selection).
    dblclickOptions = Object()
    dblclickOptions.capture = True
    dblclickOptions.passive = False
    lab.addEventListener("dblclick", create_proxy(on_svg_dblclick), dblclickOptions)

    def onChange(elementId, callback):
        # Register *callback* for "change" events of the element with *elementId*.
        document.getElementById(elementId).addEventListener("change", create_proxy(callback))

    onChange("stateVisualizationSwitch", lambda evt: on_state_visualization_switch())
    onChange("basisSelection", lambda evt: on_basis_selection_change())
    onChange("absValueDisplayModeSelect", lambda evt: on_absValueDisplayModechange())
    onChange("previewSwitch", lambda evt: on_preview_switch())
    onChange("particleNumberSelect", lambda evt: on_particle_number_change(evt))

    setInterval(create_proxy(onTimeTick), 100)  # animation tick every 100 ms
    updateAll()
    random.seed(42)
    print("Startup done!")


# Build the laboratory UI as soon as this module is executed.
init()

