Add DODO-GPT
Add various new features relating to the ChatGPT and Whisper APIs, namely "Dodo Hint", "Dodo Query", "Import image as ZX-diagram", and text-to-speech/speech-to-text support.
mjsutcliffe99 committed Jul 26, 2024
1 parent f1fb7b2 commit 5d9f41b
Showing 9 changed files with 327 additions and 5 deletions.
6 changes: 5 additions & 1 deletion requirements.txt
@@ -14,4 +14,8 @@ sphinx_autodoc_typehints>=1.10
sphinx_rtd_theme>=0.4
sphinxcontrib-svg2pdfconverter>=1.2.2
myst-parser>=3.0.0
imageio
imageio
openai
pydub
sounddevice
ffmpeg-python
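
Worth noting: pydub performs the mp3 export used in zxlive/dodo.py by shelling out to a system ffmpeg binary, which pip does not install alongside these packages. A minimal sketch of a startup check, not part of this commit (the warning text is only illustrative):

import shutil

if shutil.which("ffmpeg") is None:
    # pydub needs ffmpeg on the PATH to export/read mp3 files.
    print("Warning: ffmpeg binary not found; DODO-GPT's audio features will not work.")
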
20 changes: 18 additions & 2 deletions zxlive/dialogs.py
@@ -21,14 +21,16 @@
class FileFormat(Enum):
"""Supported formats for importing/exporting diagrams."""

All = "zxg *.json *.qasm *.tikz *.zxp *.zxr *.gif", "All Supported Formats"
All = "zxg *.json *.qasm *.tikz *.zxp *.zxr *.gif *.png *.jpg", "All Supported Formats"
QGraph = "zxg", "QGraph" # "file extension", "format name"
QASM = "qasm", "QASM"
TikZ = "tikz", "TikZ"
Json = "json", "JSON"
ZXProof = "zxp", "ZXProof"
ZXRule = "zxr", "ZXRule"
Gif = "gif", "Gif"
PNG = "png", "PNG"
JPEG = "jpg", "JPEG"
_value_: str

def __new__(cls, *args, **kwds): # type: ignore
@@ -77,7 +79,6 @@ class ImportRuleOutput:
file_path: str
r: CustomRule


def show_error_msg(title: str, description: Optional[str] = None, parent: Optional[QWidget] = None) -> None:
"""Displays an error message box."""
msg = QMessageBox(parent) #Set the parent of the QMessageBox
@@ -258,6 +259,21 @@ def export_gif_dialog(parent: QWidget) -> Optional[str]:
if file_path_and_format is None or not file_path_and_format[0]:
return None
return file_path_and_format[0]

def import_image_dialog(parent: QWidget) -> Optional[str]:
    """Shows a dialog to import a diagram from an image on disk.
    Returns the path of the selected image file, or `None` if the import was cancelled."""
file_path, selected_filter = QFileDialog.getOpenFileName(
parent=parent,
caption="Select Image",
filter=FileFormat.PNG.filter
)
if selected_filter == "":
# This happens if the user clicks on cancel
return None

return file_path
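
# A hedged sketch (not taken from this commit) of how a caller elsewhere, e.g. a MainWindow
# slot in one of the changed files not rendered on this page, might wire this dialog to the
# image-to-ZX pipeline in zxlive/dodo.py; the method and helper names here are assumptions:
#
#     def _import_image_to_zx(self) -> None:
#         from .dodo import action_dodo_image_to_zx
#         file_path = import_image_dialog(self)
#         if not file_path:
#             return  # the user cancelled the dialog
#         new_graph = action_dodo_image_to_zx(file_path)
#         self.show_imported_graph(new_graph)  # hypothetical helper that displays the new graph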

def get_lemma_name_and_description(parent: MainWindow) -> tuple[Optional[str], Optional[str]]:
dialog = QDialog(parent)
261 changes: 261 additions & 0 deletions zxlive/dodo.py
@@ -0,0 +1,261 @@
# DODO-GPT is brought to you by Dave the Dodo of Picturing Quantum Processes fame

import pyzx as zx
import openai
import sounddevice as sd
import os
import numpy as np
import base64
import requests
from pydub import AudioSegment
from .utils import GET_MODULE_PATH
from enum import IntEnum

API_KEY = 'no-key'

# Need to add error handling still! (e.g. if invalid key, or if no connection to chat-gpt server, etc.)
# Maybe add a config (i.e. to set the api_key, the gpt model, the OpenAI Whisper voice, etc.)
# Maybe also record your full transcripts with DODO-GPT (and have the option of whether to save them to file when you close zxlive)

# TEMP - this should just be calling zx.utils.VertexType instead (not sure why that doesn't work though?), e.g. VertexType(1).name
class VertexType(IntEnum):
"""Type of a vertex in the graph."""
BOUNDARY = 0
Z = 1
X = 2
H_BOX = 3
W_INPUT = 4
W_OUTPUT = 5
Z_BOX = 6

# TEMP - this should just be calling zx.utils.EdgeType instead (not sure why that doesn't work though?), e.g. EdgeType(1).name
class EdgeType(IntEnum):
"""Type of an edge in the graph."""
SIMPLE = 1
HADAMARD = 2
W_IO = 3

def get_local_api_key():
    """Get the API key from the user/key.txt file."""
    global API_KEY
    with open(GET_MODULE_PATH() + "/user/key.txt", "r") as f:
        API_KEY = f.read().strip()

def query_chatgpt(prompt,model="gpt-3.5-turbo"):
"""Send a prompt to Chat-GPT and return its response."""
client = openai.OpenAI(
api_key = API_KEY
)

chat_completion = client.chat.completions.create(
messages=[
{
"role":"user",
"content":prompt
}
],
        model=model  # e.g. "gpt-3.5-turbo" or "gpt-4o"
)
return chat_completion.choices[0].message.content
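
# The TODO near the top of this file flags the missing error handling (invalid key, no
# connection, etc.). A minimal sketch of a wrapper, assuming the exception classes exposed
# by the openai 1.x client (the name safe_query_chatgpt is an invented suggestion and not
# part of this commit):

def safe_query_chatgpt(prompt, model="gpt-3.5-turbo"):
    """Sketch only: call query_chatgpt but turn common API failures into readable messages."""
    try:
        return query_chatgpt(prompt, model=model)
    except openai.AuthenticationError:
        return "DODO-GPT error: the API key in user/key.txt was rejected."
    except openai.APIConnectionError:
        return "DODO-GPT error: could not reach the OpenAI servers."
    except openai.OpenAIError as e:
        return f"DODO-GPT error: {e}"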

def prep_hint(g):
"""Generate the default/generic 'ask for hint' type prompt for DODO-GPT."""
strPrime = describe_graph(g)
strQuery = strPrime + "Please advise me as to what ONE simplification step I should take to help immediately simplify this ZX-diagram. Please be specific to this case and not give general simplification tips. Please also keep your answer simple and do not write any diagrams or data in return."
return strQuery

def describe_graph(g):
"""Prime DODO-GPT before making a query. Returns a string for describing to DODO-GPT the current ZX-diagram."""

    # Local int-to-name maps; these shadow the module-level VertexType/EdgeType enums and only
    # cover the types described to DODO-GPT (boundary/Z/X spiders, simple/Hadamard edges).
    VertexType = {0: 'BOUNDARY', 1: 'Z', 2: 'X'}
    EdgeType = {1: 'SIMPLE', 2: 'HADAMARD'}

strPrime = "\nConsider a ZX-diagram defined by the following spiders:\n\nlabel, type, phase\n"
for v in g.vertices(): strPrime += str(v) + ', ' + str(VertexType[g.type(v)]) + ', ' + str(g.phase(v)) + '\n'

strPrime += "\nwith the following edges:\n\nsource,target,type\n"
for e in g.edges(): strPrime += str(e[0]) + ', ' + str(e[1]) + ', ' + str(EdgeType[g.edge_type(e)]) + '\n'
strPrime += '\n'

#Follows the format...
#
#"""
#Consider a ZX-diagram defined by the following spiders:
#
#label, type, phase
#0, Z, 0.25
#1, Z, 0.5
#2, X, 0.5
#
#with the following edges:
#
#source,target,type
#0, 1, SIMPLE
#1, 2, HADAMARD
#
#"""

return strPrime

def text_to_speech(text):
"""Generates an mp3 file reading the given text (via OpenAI Whisper)."""
client = openai.OpenAI(
api_key = API_KEY
)

response = client.audio.speech.create(
model="tts-1",
voice="nova",
input=text,
)

response.stream_to_file(GET_MODULE_PATH() + "/temp/Dodo_Dave_latest.mp3")
os.system(GET_MODULE_PATH() + "/temp/Dodo_Dave_latest.mp3") #TEMP/TODO - THIS SHOULD BE USING A PROPER IN-APP AUDIO PLAYER RATHER THAN OS
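
# A hedged sketch for the TODO above: play the clip via Qt's own multimedia classes (shipped
# with the PySide6 dependency zxlive already has) instead of handing the file to the OS.
# play_mp3 is an invented name and not part of this commit.

def play_mp3(path):
    """Sketch only: play an mp3 inside the Qt event loop (assumes a running QApplication)."""
    from PySide6.QtCore import QUrl
    from PySide6.QtMultimedia import QAudioOutput, QMediaPlayer
    player = QMediaPlayer()
    audio_output = QAudioOutput()
    player.setAudioOutput(audio_output)
    player.setSource(QUrl.fromLocalFile(path))
    player.play()
    return player, audio_output  # the caller must keep these alive while the clip plays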

def record_audio(duration=5, sample_rate=44100):
    """Record `duration` seconds of stereo audio from the default input device."""
    recording = sd.rec(int(duration * sample_rate), samplerate=sample_rate, channels=2, dtype='int16')
sd.wait()
return recording

def save_as_mp3(audio_data, sample_rate=44100):
    """Save the recorded audio to a temporary mp3 file and return its path."""
    file_path = GET_MODULE_PATH() + "/temp/user_query_latest.mp3"
audio_segment = AudioSegment(
data=np.array(audio_data).tobytes(),
sample_width=2,
frame_rate=sample_rate,
channels=2
)
audio_segment.export(file_path, format='mp3')
return file_path

def transcribe_audio(file_path):
    """Transcribe the given audio file to text via the OpenAI Whisper API."""
    client = openai.OpenAI(api_key=API_KEY)
with open(file_path, "rb") as audio_file:
transcription = client.audio.transcriptions.create(
model="whisper-1",
file=audio_file,
language='en'
)
#print(f'Transcription: {transcription.text}') #TEMP
return transcription.text

def speech_to_text():
    """Record a short voice clip from the user and return its transcription."""
    sample_rate = 44100  # Sample rate in Hz
    duration = 5  # Duration of recording in seconds
audio_data = record_audio(duration, sample_rate)
file_path = save_as_mp3(audio_data, sample_rate)
txt = transcribe_audio(file_path)
return txt

def encode_image(image_path):
    """Read an image file and return its contents as a base64-encoded string."""
    with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode('utf-8')

def get_image_prompt():
"""Returns the prompt that encourages DODO-GPT to describe the given image like a ZX-diagram."""

return """
Please convert this image into a ZX-diagram.
Then please provide a csv that lists of the spiders of this ZX-diagram, given the column headers:
index,type,phase,x-pos,y-pos
The type here should be given as either 'Z' or 'X' (ignore boundary spiders). The indexing should start from 0. And the phases should be written in terms of pi. x-pos and y-pos should respectively refer to their horizontal and vertical positions in the image, normalized from 0,0 (top-left) to 1,1 (bottom-right).
Then please provide a csv that lists the edges of this ZX-diagram, given the column headers:
source,target,type
The type here should be given as 1 for a normal (i.e. black) edge and 2 for a Hadamard (i.e. blue) edge, and the sources and targets should refer to the indices of the relevant spiders. Be sure to only include direct edges connecting two spiders.
Please ensure the csv's are expressed with comma separators and not in a table format.
"""
#After that, under a clearly marked heading "HINT", please advise me as to what ONE simplification step I should take to help immediately simplify this ZX-diagram. Please be specific to this case and not give general simplification tips.
#"""

def image_to_text(image_path):
"""Takes a ZX-diagram-like image and returns DODO-GPT's structured description of it."""

query = get_image_prompt()

# Getting the base64 string
base64_image = encode_image(image_path)

headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {API_KEY}"
}

payload = {
"model": "gpt-4o-mini",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": query
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}"
}
}
]
}
],
"max_tokens": 300
}

response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload)

#return response.json()
return response.json()['choices'][0]['message']['content']

def response_to_zx(strResponse):
    """Parse DODO-GPT's csv-style description of a ZX-diagram into a pyzx graph."""
    scale = 2.5

strResponse = strResponse[strResponse.index('index,type,phase'):]
strResponse = strResponse[strResponse.index('\n')+1:]
str_csv_verts = strResponse[:strResponse.index('```')-1]
strResponse = strResponse[strResponse.index('source,target,type'):]
strResponse = strResponse[strResponse.index('\n')+1:]
str_csv_edges = strResponse[:strResponse.index('```')-1]

g = zx.Graph()

for line in str_csv_verts.split('\n'):
idx,ty,ph,x,y = line.split(',')
g.add_vertex(qubit=float(y)*scale,row=float(x)*scale,ty=VertexType[ty],phase=ph)

for line in str_csv_edges.split('\n'):
source,target,ty = line.split(',')
g.add_edge((int(source),int(target)),int(ty))

return g
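
# For reference, the slicing above expects the model's reply to wrap each csv in ``` fences,
# spiders first and edges second. An illustrative reply it can parse (values invented):
#
# ```
# index,type,phase,x-pos,y-pos
# 0,Z,pi/4,0.2,0.5
# 1,X,0,0.8,0.5
# ```
# ```
# source,target,type
# 0,1,2
# ```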

def action_dodo_hint(active_graph) -> None:
"""Queries DODO-GPT for a hint as to what simplification step should be taken next."""
#print("\n\nQUERY...\n\n", prep_hint(active_graph), "\n\nANSWER...\n\n") #TEMP
dodoResponse = query_chatgpt(prep_hint(active_graph))
#print(dodoResponse) #TEMP
text_to_speech(dodoResponse)

def action_dodo_query(active_graph) -> None:
"""Records the user's voice (plus the current ZX-diagram) and prompts DODO-GPT for a response."""
doIncludeGraph = True # Whether or not to pass information about the current ZX-diagram in with the DODO-GPT query
    strPrime = describe_graph(active_graph) if doIncludeGraph else ""
userQuery = speech_to_text()
#print("\n\nQUERY...\n\n", strPrime+userQuery, "\n\nANSWER...\n\n") #TEMP
dodoResponse = query_chatgpt(strPrime+userQuery)
#print(dodoResponse) #TEMP
text_to_speech(dodoResponse)

def action_dodo_image_to_zx(path):
    """Queries DODO-GPT to generate a ZX-diagram from an image, returning the resulting pyzx graph."""
strResponse = image_to_text(path)
#print(strResponse) #TEMP
new_graph = response_to_zx(strResponse)
return new_graph
21 changes: 20 additions & 1 deletion zxlive/edit_panel.py
@@ -4,7 +4,7 @@
from typing import Iterator

from PySide6.QtCore import Signal, QSettings
from PySide6.QtGui import QAction
from PySide6.QtGui import (QAction, QIcon)
from PySide6.QtWidgets import (QToolButton)
from pyzx import EdgeType, VertexType, sqasm
from pyzx.circuit.qasmparser import QASMParser
@@ -18,6 +18,7 @@
from .graphscene import EditGraphScene
from .graphview import GraphView
from .settings_dialog import input_circuit_formats
from .dodo import action_dodo_hint, action_dodo_query


class GraphEditPanel(EditorBasePanel):
@@ -60,6 +61,18 @@ def _toolbar_sections(self) -> Iterator[ToolbarSection]:
self.start_derivation.setText("Start Derivation")
self.start_derivation.clicked.connect(self._start_derivation)
yield ToolbarSection(self.start_derivation)

self.dodo_hint = QToolButton(self)
self.dodo_hint.setIcon(QIcon(get_data("icons/dodo.png")))
self.dodo_hint.setToolTip("Dodo Hint")
self.dodo_hint.clicked.connect(self._dodo_hint)

self.dodo_query = QToolButton(self)
self.dodo_query.setIcon(QIcon(get_data("icons/mic.svg")))
self.dodo_query.setToolTip("Dodo Query")
self.dodo_query.clicked.connect(self._dodo_query)
yield ToolbarSection(self.dodo_hint, self.dodo_query)


def _start_derivation(self) -> None:
if not self.graph_scene.g.is_well_formed():
@@ -107,3 +120,9 @@ def _input_circuit(self) -> None:
cmd = UpdateGraph(self.graph_view, new_g)
self.undo_stack.push(cmd)
self.graph_scene.select_vertices(new_verts)

def _dodo_hint(self) -> None:
action_dodo_hint(self.graph_scene.g)

def _dodo_query(self) -> None:
action_dodo_query(self.graph_scene.g)
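
Both new slots call straight into dodo.py and therefore block the GUI thread on network requests (and, for Dodo Query, on five seconds of audio recording). A hedged sketch of pushing that work onto Qt's global thread pool instead (not how this commit does it; _DodoTask and _dodo_hint_async are invented names, and the slot is meant as a method on the same panel class as _dodo_hint above):

from PySide6.QtCore import QRunnable, QThreadPool

class _DodoTask(QRunnable):
    """Sketch only: run a blocking DODO-GPT action off the GUI thread."""
    def __init__(self, fn, graph):
        super().__init__()
        self.fn, self.graph = fn, graph

    def run(self):
        self.fn(self.graph)

def _dodo_hint_async(self) -> None:
    # Drop-in alternative to _dodo_hint; the audio playback itself still happens in dodo.py.
    QThreadPool.globalInstance().start(_DodoTask(action_dodo_hint, self.graph_scene.g))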
Binary file added zxlive/icons/dodo.png
1 change: 1 addition & 0 deletions zxlive/icons/mic.svg
