Skip to content

Commit

Permalink
some small fixes
Browse files Browse the repository at this point in the history
Signed-off-by: Alexander Piskun <bigcat88@icloud.com>
  • Loading branch information
bigcat88 committed Dec 6, 2024
1 parent 1cbe5bb commit 20a2b2e
Showing 1 changed file with 35 additions and 72 deletions.
107 changes: 35 additions & 72 deletions models_catalog_editor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
QApplication,
QCheckBox,
QDialog,
QFormLayout,
QGridLayout,
QGroupBox,
QHBoxLayout,
Expand Down Expand Up @@ -257,7 +256,6 @@ def process_huggingface_url(self):
self.default_filename_found.emit(default_filename)

# Fetch hash for HuggingFace models
model_hash = ""
if "huggingface.co" in corrected_url:
model_hash = self.prefill_hash_from_huggingface(corrected_url)
if model_hash:
Expand Down Expand Up @@ -393,7 +391,8 @@ def search_civitai_by_hash(self, model_hash):
self.model_type_found.emit(types_value)
else:
self.status_update.emit(
f"Unknown combination of model_type '{model_type}' and file_type '{file_type}', please report this",
f"Unknown combination of model_type '{model_type}' and "
f"file_type '{file_type}', please report this",
True,
)
else:
Expand Down Expand Up @@ -463,23 +462,24 @@ def create_stretching_line_edit(placeholder_text):
form_layout.addWidget(self.homepage_edit, 3, 1)

# HuggingFace Filename
self.filename_edit = create_stretching_line_edit("HuggingFace Filename")
self.filename_edit.setReadOnly(True)
self.filename_hg = create_stretching_line_edit("HuggingFace Filename")
self.filename_hg.setReadOnly(True)
form_layout.addWidget(QLabel("HugFace name:"), 4, 0)
form_layout.addWidget(self.filename_edit, 4, 1)
form_layout.addWidget(self.filename_hg, 4, 1)

# CivitAI Filename
self.second_filename_label = create_stretching_line_edit("CivitAI Filename")
self.second_filename_label.setReadOnly(True)
self.filename_ca = create_stretching_line_edit("CivitAI Filename")
self.filename_ca.setReadOnly(True)
form_layout.addWidget(QLabel("CivitAI name:"), 5, 0)
form_layout.addWidget(self.second_filename_label, 5, 1)
form_layout.addWidget(self.filename_ca, 5, 1)

# Overridden Filename
self.overridden_filename_edit = create_stretching_line_edit(
self.overridden_filename = create_stretching_line_edit(
"Enter overridden filename"
)
self.overridden_filename.textChanged.connect(self.update_input_value_regex)
form_layout.addWidget(QLabel("Force Filename:"), 6, 0)
form_layout.addWidget(self.overridden_filename_edit, 6, 1)
form_layout.addWidget(self.overridden_filename, 6, 1)

# Hash
self.hash_edit = create_stretching_line_edit("Enter SHA256 hash")
Expand Down Expand Up @@ -591,9 +591,9 @@ def save_model(self):
# Gather data from fields
homepage = self.homepage_edit.text().strip()
url = self.download_url_edit.text().strip()
filename = self.filename_edit.text().strip()
second_filename = self.second_filename_label.text().strip()
overridden_filename = self.overridden_filename_edit.text().strip()
filename = self.filename_hg.text().strip()
second_filename = self.filename_ca.text().strip()
overridden_filename = self.overridden_filename.text().strip()
model_hash = self.hash_edit.text().strip()
gated = self.gated_checkbox.isChecked()

Expand Down Expand Up @@ -633,10 +633,7 @@ def save_model(self):
)
return

# Validate regex captures filenames
filenames = {filename, second_filename, overridden_filename} - {
""
} # Remove empty strings
filenames = {filename, second_filename, overridden_filename} - {""}
invalid_filenames = [
fname for fname in filenames if not input_value_pattern.search(fname)
]
Expand Down Expand Up @@ -672,10 +669,6 @@ def save_model(self):

if overridden_filename:
model_entry["filename"] = overridden_filename
else:
default_filename = self.get_default_filename(url)
if filename and filename != default_filename:
model_entry["filename"] = filename

if types:
model_entry["types"] = types
Expand All @@ -687,7 +680,7 @@ def save_model(self):
existing_keys_with_hash = [
key
for key, entry in self.models_catalog.items()
if entry.get("hash") == model_hash
if entry.get("hash", "").lower() == model_hash.lower()
]

# Ensure no filename matches more than one existing regex in the catalog
Expand Down Expand Up @@ -745,7 +738,7 @@ def get_model_key(self):
existing_keys_with_hash = [
key
for key, entry in self.models_catalog.items()
if entry.get("hash") == model_hash
if entry.get("hash", "").lower() == model_hash.lower()
]

if existing_keys_with_hash:
Expand All @@ -760,42 +753,16 @@ def get_model_key(self):
if dialog.exec() == QDialog.Accepted:
result_key = dialog.get_result()
# If the hash exists and the user changes the key, we need to delete the old record
if existing_keys_with_hash and result_key != existing_key:
if existing_keys_with_hash and result_key != existing_keys_with_hash[0]:
# Delete old record
del self.models_catalog[existing_key]
# Log this action
del self.models_catalog[existing_keys_with_hash[0]]
self.append_log(
f"Deleted existing model with key '{existing_key}' due to hash conflict.",
f"Deleted existing model with key '{existing_keys_with_hash[0]}' due to hash conflict.",
True,
)
return result_key
return None

def get_default_filename(self, url):
try:
headers = {}
if "huggingface.co" in url and self.huggingface_token:
headers["Authorization"] = f"Bearer {self.huggingface_token}"
response = requests.head(url, headers=headers, allow_redirects=True)
if "Content-Disposition" in response.headers:
cd = response.headers["Content-Disposition"]
if "filename*=" in cd:
match = re.search(
r"filename\*\s*=\s*[^']*'[^']*'([^;]+)", cd, re.IGNORECASE
)
if match:
return match.group(1)
match = re.search(r'filename\s*=\s*"([^"]+)"', cd, re.IGNORECASE)
if match:
return match.group(1)
match = re.search(r"filename\s*=\s*([^;]+)", cd, re.IGNORECASE)
if match:
return match.group(1).strip()
return os.path.basename(response.url)
except Exception as e:
print(f"Error getting default filename: {e}")
return ""

def process_url(self):
self.clear_form()

Expand Down Expand Up @@ -846,23 +813,21 @@ def handle_homepage_extracted(self, homepage):
self.homepage_edit.setText(homepage)

def handle_default_filename_found(self, filename):
self.filename = filename
self.filename_edit.setText(filename)
self.filename_hg.setText(filename)
self.update_input_value_regex()

def handle_second_filename_found(self, second_filename):
self.second_filename = second_filename
self.second_filename_label.setText(second_filename)
self.filename_ca.setText(second_filename)
self.update_input_value_regex()

def update_input_value_regex(self):
# Combine filenames into a set to remove duplicates
filenames = set()
if self.filename:
filenames.add(self.filename)
if self.second_filename:
filenames.add(self.second_filename)
overridden_filename = self.overridden_filename_edit.text().strip()
if self.filename_hg.text().strip():
filenames.add(self.filename_hg.text().strip())
if self.filename_ca.text().strip():
filenames.add(self.filename_ca.text().strip())
overridden_filename = self.overridden_filename.text().strip()
if overridden_filename:
filenames.add(overridden_filename)

Expand Down Expand Up @@ -929,7 +894,6 @@ def handle_files_found(self, files, model_type):
size = metadata.get("size", "Unknown size")
fp = metadata.get("fp", "Unknown FP")

# Add additional metadata to differentiate files with the same name
description = f"{name} (Type: {file_type}, Size: {size}, FP: {fp})"
file_descriptions.append(description)

Expand Down Expand Up @@ -966,8 +930,7 @@ def handle_files_found(self, files, model_type):
return

if filename:
self.second_filename = filename
self.second_filename_label.setText(filename)
self.filename_ca.setText(filename)
self.update_input_value_regex()
self.append_log(f"Set CivitAI filename: {filename}", False)
else:
Expand Down Expand Up @@ -1027,11 +990,9 @@ def clear_full_form(self):
def clear_form(self):
self.homepage_edit.clear()
self.download_url_edit.clear()
self.filename_edit.clear()
self.second_filename_label.setText("")
self.filename = ""
self.second_filename = ""
self.overridden_filename_edit.clear()
self.filename_hg.clear()
self.filename_ca.clear()
self.overridden_filename.clear()
self.hash_edit.clear()
self.gated_checkbox.setChecked(False)
self.class_type_edit.clear()
Expand Down Expand Up @@ -1122,9 +1083,11 @@ def validate_input(self):
else:
self.warning_label.clear()
if self.existing_key:
self.warning_label.setText(
f"A model with the same hash exists under the name '{self.existing_key}."
)
self.info_label.setText(
f"A model with the same hash exists under the name '{self.existing_key}'.\n"
f"Proceeding will delete the old entry and save under the new name."
"Proceeding will delete the old entry and save under the new name."
)
else:
self.info_label.clear()
Expand Down

0 comments on commit 20a2b2e

Please sign in to comment.