updated plotter to plot metrics
Archly2022 committed Sep 25, 2024
1 parent 679c9ca commit 20e57b0
Showing 4 changed files with 99 additions and 216 deletions.
111 changes: 17 additions & 94 deletions demos/svc-sample.ipynb

Large diffs are not rendered by default.

113 changes: 79 additions & 34 deletions logllm/plot.py
@@ -1,40 +1,85 @@
 import json
+import numpy as np
 import matplotlib.pyplot as plt
 
-def plot_ml_metrics(response_data):
-    # Parse the JSON string into a Python dictionary if needed
-    if isinstance(response_data, str):
-        response_data = json.loads(response_data)
-
-    try:
-        # Define keys to exclude from the plot
-        exclude_keys = {'cache_size', 'random_state_tts', 'random_state',}
-
-        # Extract all keys and numeric values from the response, excluding specific keys
-        metrics = []
-        values = []
-        for key, value in response_data.items():
-            # Check if the key is not in the exclude list and if the value is numeric
-            if key not in exclude_keys and isinstance(value, (int, float)):
-                metrics.append(key)
-                values.append(value)
-
-        # Check if any numeric values were found
-        if not values:
-            print("No numeric values found to plot.")
-            return
-
-        # Plotting
-        plt.figure(figsize=(10, 6))  # Adjusted figure size for better visualization
-        plt.bar(metrics, values, color='pink')
-        plt.title('Model Performance Metrics', fontsize=16)
-        plt.xlabel('Parameter', fontsize=14)
-        plt.ylabel('Value', fontsize=14)
-        plt.ylim(0, max(values) * 1.1)  # Dynamically set the y-limit based on the maximum value
-        plt.xticks(rotation=45, ha='right')  # Rotate x-axis labels for better readability
-        plt.grid(True, axis='y', linestyle='--', alpha=0.7)
-        plt.tight_layout()
-        plt.show()
-
-    except Exception as e:
-        print(f"Error processing the response data: {e}")
+def plot_metrics(*models_results):
+    # Process each model to ensure it's in dictionary format
+    processed_models = []
+    for model in models_results:
+        # Convert to dictionary if the input is a JSON string
+        if isinstance(model, str):
+            model = json.loads(model)  # Convert string to dictionary
+        processed_models.append(model)
+
+    # Define keys to exclude from the plot
+    exclude_keys = {'cache_size', 'random_state_tts', 'random_state', 'random_state_1', 'n_estimators'}
+
+    # Initialize containers for metrics, values, and model names
+    metrics = []
+    model_names = []
+    model_values = {}
+
+    for model in processed_models:
+        model_name = model.get("model_name", "Test Model")
+        model_names.append(model_name)
+
+        for key, value in model.items():
+            if key.startswith("result_name_"):
+                metric_name = value
+                metric_index = key.split("_")[-1]  # Extract the index (e.g., "1" from "result_name_1")
+                metric_value = model.get(f"result_value_{metric_index}", None)
+
+                if metric_value is not None:
+                    if metric_name not in metrics:
+                        metrics.append(metric_name)
+                    if metric_name not in model_values:
+                        model_values[metric_name] = []
+
+                    # Add the metric value to the list
+                    model_values[metric_name].append(metric_value)
+
+        # Handle additional numeric values in the model (not using the result_name format)
+        for key, value in model.items():
+            # Exclude specific keys and ensure the value is numeric
+            if key not in exclude_keys and isinstance(value, (int, float)):
+                if key not in metrics:
+                    metrics.append(key)
+                if key not in model_values:
+                    model_values[key] = []
+                model_values[key].append(value)
+
+    # Handle cases where no valid metrics were provided
+    if not metrics or not model_names:
+        print("No valid metrics or model names found.")
+        return
+
+    # Ensure all models have values for all metrics, filling in with 0 if not available
+    for metric in metrics:
+        for i in range(len(model_names)):
+            if len(model_values[metric]) <= i:
+                model_values[metric].append(0)  # Default value if missing
+
+    # Plotting side-by-side bar chart
+    x = np.arange(len(metrics))  # Label locations
+    bar_width = 0.15  # Width of the bars
+    fig, ax = plt.subplots(figsize=(10, 6))
+
+    # Create a bar for each model's performance metrics
+    for i, model_name in enumerate(model_names):
+        values = [model_values[metric][i] for metric in metrics]
+        ax.bar(x + i * bar_width, values, width=bar_width, label=model_name)
+
+    # Customization of the plot
+    ax.set_xlabel('Metric', fontsize=14)
+    ax.set_ylabel('Value', fontsize=14)
+    ax.set_title('Comparison of Model Performance Metrics', fontsize=16)
+    ax.set_xticks(x + bar_width * (len(model_names) - 1) / 2)
+    ax.set_xticklabels(metrics, fontsize=12)
+    ax.legend(title='Models')
+    ax.grid(True, axis='y', linestyle='--', alpha=0.7)
+
+    plt.xticks(rotation=45, ha='right')  # Rotate x-axis labels for better readability
+    plt.tight_layout()
+    plt.show()
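
For context, a minimal usage sketch of the new plot_metrics API. The payloads below are hypothetical, inferred from the result_name_N / result_value_N keys the function parses; they are not from the repository.

from logllm.plot import plot_metrics

# Hypothetical result dicts following the result_name_N / result_value_N
# convention that plot_metrics looks for; the numbers are illustrative only.
svc_results = {
    "model_name": "SVC",
    "result_name_1": "accuracy",
    "result_value_1": 0.91,
    "result_name_2": "f1_score",
    "result_value_2": 0.89,
}

# JSON strings are accepted too; they are converted with json.loads.
rf_results = '{"model_name": "RandomForest", "result_name_1": "accuracy", "result_value_1": 0.94, "result_name_2": "f1_score", "result_value_2": 0.93}'

plot_metrics(svc_results, rf_results)

Each model's bars are offset by i * bar_width, and the group ticks are centered at x + bar_width * (len(model_names) - 1) / 2. Note that numeric result_value_N entries also satisfy the generic numeric filter in the second loop, so with a payload like the one above they would appear as extra bars unless added to exclude_keys.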


85 changes: 0 additions & 85 deletions logllm/plotter.py

This file was deleted.

6 changes: 3 additions & 3 deletions logllm/query.py
@@ -63,11 +63,11 @@ def query_gemini(user_input: str, code):
 
 # General query function that calls the appropriate provider
 
-def query(user_input: str, provider: str):
+def query(provider):
     if provider == 'openai':
-        return query_openai(user_input)
+        return query_openai()
     elif provider == 'gemini':
-        return query_gemini(user_input)
+        return query_gemini()
     else:
         raise ValueError("Invalid provider specified. Use 'openai' or 'gemini'.")
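
A call sketch under the new provider-only signature (hypothetical: how the user input reaches the provider functions after this change is not visible in this hunk):

from logllm.query import query

# Provider-only API after this commit; the prompt is assumed to be supplied
# inside query_openai()/query_gemini(), whose updated bodies are collapsed here.
response = query('gemini')
print(response)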
