-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
105 lines (83 loc) · 3.45 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import tkinter as tk
from tkinter import ttk, filedialog
import requests
import io
from PIL import Image, ImageTk
# Module-level holder for the image most recently shown by generate_image;
# save_image reads it to write that image to disk. Stays None until a
# generation succeeds.
pil_image = None
# Function to query the model and display the image
def generate_image():
global pil_image
prompt = entry.get()
width = int(width_entry.get() or 1080) # Default to 1080 if input is empty
height = int(height_entry.get() or 1080) # Default to 1080 if input is empty
payload = {"inputs": prompt}
# Query the model
try:
image_bytes, response_headers = query(payload)
# Check if the content type is correct
if 'image/jpeg' in response_headers.get('Content-Type'):
pil_image = Image.open(io.BytesIO(image_bytes))
pil_image.thumbnail((width, height)) # Resize image to fit in the specified resolution
img = ImageTk.PhotoImage(pil_image) # Convert to PhotoImage for display
image_label.config(image=img)
image_label.image = img # Keep a reference to avoid garbage collection
result_label.config(text="")
else:
result_label.config(text="The response is not an image.")
except Exception as e:
result_label.config(text=f"Error: {e}")
def get_token():
with open("token.txt", "r") as file:
return file.read().strip() # Reads the token and removes any extra spaces/newlines
def query(payload):
API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-dev"
token = get_token() # Get token from the file
headers = {"Authorization": f"Bearer {token}"}
response = requests.post(API_URL, headers=headers, json=payload)
if response.status_code != 200:
raise Exception(f"Request failed: {response.status_code}, {response.text}")
return response.content, response.headers
# Function to save the image
def save_image():
global pil_image
if pil_image:
# Open file dialog to select the save location
file_path = filedialog.asksaveasfilename(defaultextension=".jpg", filetypes=[("JPEG files", "*.jpg")])
if file_path:
pil_image.save(file_path)
result_label.config(text="Image saved successfully.")
else:
result_label.config(text="No image to save.")
# Create the main window
root = tk.Tk()
root.title("Image Generator")
# Create and pack widgets
frame = ttk.Frame(root, padding="10")
frame.pack(fill="both", expand=True)
entry = ttk.Entry(frame, width=50)
entry.pack(pady=10)
# Frame for resolution inputs
resolution_frame = ttk.Frame(frame)
resolution_frame.pack(pady=5)
width_label = ttk.Label(resolution_frame, text="Width:")
width_label.pack(side="left")
width_entry = ttk.Entry(resolution_frame, width=10)
width_entry.pack(side="left", padx=5)
height_label = ttk.Label(resolution_frame, text="Height:")
height_label.pack(side="left")
height_entry = ttk.Entry(resolution_frame, width=10)
height_entry.pack(side="left", padx=5)
# Frame for buttons
button_frame = ttk.Frame(frame)
button_frame.pack(pady=5)
generate_button = ttk.Button(button_frame, text="Generate Image", command=generate_image)
generate_button.pack(side="left")
download_button = ttk.Button(button_frame, text="Download Image", command=save_image)
download_button.pack(side="left", padx=5)
image_label = tk.Label(frame)
image_label.pack(pady=10)
result_label = ttk.Label(frame, text="")
result_label.pack(pady=5)
# Start the Tkinter event loop
root.mainloop()