Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

benchmark improvements #40

Merged
merged 4 commits into from
Sep 18, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 128 additions & 72 deletions scripts/benchmarks/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
BASIC_AUTH = httpx.BasicAuth(USER_NAME, USER_PASS) if USER_NAME and USER_PASS else None
if BASIC_AUTH:
print("Using authentication for connect")
PAUSE_INTERVAL = int(os.environ.get("PAUSE_INTERVAL", "0"))
FIRST_TEST_FLAG = True

SELECTED_TEST_FLOW_SUITE = []

Expand Down Expand Up @@ -192,37 +194,7 @@ def validate_flows_test_cases(flows_test_cases: list[FlowTest]):
],
),
]
TEST_CASES_OTHER = [
FlowTest(
flow_name="stable_cascade",
test_cases=[
TestCase(
name="one_pass",
input_params={"prompt": "green apple", "pass_count": "One pass"},
),
TestCase(
name="two_pass",
input_params={"prompt": "green apple", "pass_count": "Two pass"},
),
TestCase(
name="three_pass",
input_params={"prompt": "green apple", "pass_count": "Three pass"},
),
],
),
FlowTest(
flow_name="hunyuan_dit",
test_cases=[
TestCase(
name="20steps",
input_params={"prompt": "green apple", "steps_count": 20},
),
TestCase(
name="40steps",
input_params={"prompt": "green apple", "steps_count": 40},
),
],
),
TEST_CASES_PORTRAITS = [
FlowTest(
flow_name="vintage_portrait",
test_cases=[
Expand Down Expand Up @@ -284,16 +256,48 @@ def validate_flows_test_cases(flows_test_cases: list[FlowTest]):
],
),
FlowTest(
flow_name="remove_background_birefnet",
flow_name="photo_stickers",
test_cases=[
TestCase(
name="1024x1024",
name="default",
input_files=["man.png"],
),
],
),
]
TEST_CASES_OTHER = [
FlowTest(
flow_name="remove_background_bria",
flow_name="stable_cascade",
test_cases=[
TestCase(
name="one_pass",
input_params={"prompt": "green apple", "pass_count": "One pass"},
),
TestCase(
name="two_pass",
input_params={"prompt": "green apple", "pass_count": "Two pass"},
),
TestCase(
name="three_pass",
input_params={"prompt": "green apple", "pass_count": "Three pass"},
),
],
),
FlowTest(
flow_name="hunyuan_dit",
test_cases=[
TestCase(
name="20steps",
input_params={"prompt": "green apple", "steps_count": 20},
),
TestCase(
name="40steps",
input_params={"prompt": "green apple", "steps_count": 40},
),
],
),
FlowTest(
flow_name="remove_background_birefnet",
test_cases=[
TestCase(
name="1024x1024",
Expand All @@ -302,20 +306,20 @@ def validate_flows_test_cases(flows_test_cases: list[FlowTest]):
],
),
FlowTest(
flow_name="supir_upscaler",
flow_name="remove_background_bria",
test_cases=[
TestCase(
name="1MPx1.5",
input_params={"scale_factor": 1.5},
name="1024x1024",
input_files=["man.png"],
),
],
),
FlowTest(
flow_name="photo_stickers",
flow_name="supir_upscaler",
test_cases=[
TestCase(
name="default",
name="1MPx1.5",
input_params={"scale_factor": 1.5},
input_files=["man.png"],
),
],
Expand Down Expand Up @@ -358,30 +362,38 @@ def validate_flows_test_cases(flows_test_cases: list[FlowTest]):
),
]
validate_flows_test_cases(
TEST_CASES_SDXL + TEST_CASES_OTHER + TEST_CASES_FLUX + TEST_CASES_HEAVY
TEST_CASES_SDXL
+ TEST_CASES_PORTRAITS
+ TEST_CASES_OTHER
+ TEST_CASES_FLUX
+ TEST_CASES_HEAVY
)


async def select_test_flow_suite():
global SELECTED_TEST_FLOW_SUITE
print("Please select the test suite you want to run:")
print("1. SDXL Suite")
print("2. FLUX Suite")
print("3. OTHER Suite")
print("4. HEAVY(24GB+ VRAM) Suite")
print("2. PORTRAITS Suite")
print("3. FLUX Suite")
print("4. OTHER Suite")
print("5. HEAVY(24GB+ VRAM) Suite")

user_choice = input("Enter the number of the suite (1/2/3/4): ")
user_choice = input("Enter the number of the suite (1/2/3/4/5): ")

if user_choice == "1":
SELECTED_TEST_FLOW_SUITE = TEST_CASES_SDXL
print("Selected SDXL Suite.")
elif user_choice == "2":
SELECTED_TEST_FLOW_SUITE = TEST_CASES_FLUX
print("Selected FLUX Suite.")
SELECTED_TEST_FLOW_SUITE = TEST_CASES_PORTRAITS
print("Selected PORTRAITS Suite.")
elif user_choice == "3":
SELECTED_TEST_FLOW_SUITE = TEST_CASES_OTHER
print("Selected OTHER Suite.")
elif user_choice == "4":
SELECTED_TEST_FLOW_SUITE = TEST_CASES_FLUX
print("Selected FLUX Suite.")
elif user_choice == "5":
SELECTED_TEST_FLOW_SUITE = TEST_CASES_HEAVY
print("Selected HEAVY Suite.")
else:
Expand Down Expand Up @@ -450,6 +462,7 @@ async def wait_for_installation_to_complete(
flow_name: str, poll_interval: int = 5, timeout: int = FLOW_INSTALL_TIMEOUT
) -> bool:
elapsed_time = 0
max_read_timeout_count = 20
async with httpx.AsyncClient(auth=BASIC_AUTH) as client:
while elapsed_time < timeout:
try:
Expand All @@ -459,27 +472,41 @@ async def wait_for_installation_to_complete(

for flow in install_progress:
if flow["name"] == flow_name:
if flow["error"]:
print(
f"Error during installation of flow '{flow_name}': {flow['error']}"
)
return False
if flow["progress"] == 100:
if flow["error"]:
print(
f"Error during installation of flow '{flow_name}': {flow['error']}"
)
return False
print(f"Flow '{flow_name}' installation completed.")
return True
rounded_flow_progress = math.floor(flow["progress"] * 10) / 10
print(
f"Flow '{flow_name}' installation progress: {rounded_flow_progress}%"
)
break

await asyncio.sleep(poll_interval)
elapsed_time += poll_interval
max_read_timeout_count = 20
except httpx.ReadTimeout:
max_read_timeout_count -= 1
print(
f"Flow '{flow_name}': ReadTimeout error during installation progress check, "
f"continuing... {max_read_timeout_count} tries left"
)
if not max_read_timeout_count:
print(
f"Installation of flow '{flow_name}' failed due to repeated timeouts."
)
return False
except httpx.RequestError as exc:
print(
f"An error occurred while checking installation progress: {exc.request.url!r}: {exc}"
)
return False

print(f"Installation of flow '{flow_name}' timed out.")
print(f"Installation of flow '{flow_name}' timed out after {timeout} seconds.")
return False


Expand Down Expand Up @@ -523,8 +550,9 @@ async def create_task(
return []


async def get_task_progress(task_id: int, poll_interval: int = 3) -> dict:
async def get_task_progress(task_id: int, poll_interval: int = 5) -> dict:
max_read_timeout_count = 20
previous_progress = 0.0
async with httpx.AsyncClient(auth=BASIC_AUTH) as client:
while True:
try:
Expand All @@ -533,20 +561,31 @@ async def get_task_progress(task_id: int, poll_interval: int = 3) -> dict:
)
if response.status_code == 200:
task_data = response.json()
if task_data.get("progress") == 100.0 or task_data.get("error"):
if task_data.get("error"):
print(
f"Task with id={task_id} failed with error: {task_data['error']}"
)
return task_data
rounded_task_progress = (
math.floor(task_data.get("progress") * 10) / 10
)
print(
f"Task `{task_data['name']}` with id={task_id}, progress: {rounded_task_progress}%"
)
progress = task_data.get("progress", 0.0)
if progress == 100.0:
print(
f"Task `{task_data['name']}` with id={task_id}, progress: 100%"
)
return task_data
if progress > 0.0:
# Only print progress if it has increased
if progress != previous_progress:
rounded_task_progress = math.floor(progress * 10) / 10
print(
f"Task `{task_data['name']}` with id={task_id}, progress: {rounded_task_progress}%"
)
previous_progress = progress
max_read_timeout_count = 20
await asyncio.sleep(poll_interval)
except httpx.ReadTimeout:
max_read_timeout_count -= 1
print(
f"Task with id={task_id}: ReadTimeout error, continue... {max_read_timeout_count} tries left"
f"Task with id={task_id}: ReadTimeout error, continuing... {max_read_timeout_count} tries left"
)
if not max_read_timeout_count:
return {"error": "read timeout error"}
Expand Down Expand Up @@ -575,6 +614,17 @@ async def run_test_case(

# Use the test_semaphore to limit concurrent tests
async with test_semaphore:
global FIRST_TEST_FLAG

if FIRST_TEST_FLAG:
FIRST_TEST_FLAG = False
else:
if PAUSE_INTERVAL:
print(
f"Pausing for {PAUSE_INTERVAL} seconds before starting next test."
)
await asyncio.sleep(PAUSE_INTERVAL)

input_params = test_case.input_params
input_files = test_case.input_files
warm_up = test_case.warm_up
Expand All @@ -588,16 +638,15 @@ async def run_test_case(
)
return

# Wait for tasks to finish and save results
task_results = []
flow_comfy_saved = False
for task_id in task_ids:
result = await get_task_progress(task_id)
task_results.append(result)
# Create a list of coroutines for task progress
task_progress_coroutines = [get_task_progress(task_id) for task_id in task_ids]
task_results = await asyncio.gather(*task_progress_coroutines)

if not flow_comfy_saved and "flow_comfy" in result:
# Save the flow_comfy from the first task that has it
for result in task_results:
if "flow_comfy" in result:
await save_flow_comfy(flow_name, test_case.name, result["flow_comfy"])
flow_comfy_saved = True
break # No need to check further once saved

if warm_up:
task_results = task_results[1:]
Expand All @@ -606,7 +655,12 @@ async def run_test_case(
summary = await save_results(flow_name, test_case.name, test_case, task_results)
if summary:
results_summary[flow_name].append(
{"test_case": test_case.name, "avg_exec_time": summary["avg_exec_time"]}
{
"test_case": test_case.name,
"avg_exec_time": summary["avg_exec_time"],
"min_exec_time": summary["min_exec_time"],
"max_exec_time": summary["max_exec_time"],
}
)

# Save the updated results summary after each test case
Expand Down Expand Up @@ -836,10 +890,12 @@ async def save_output(
def get_suite_identifier() -> str:
if SELECTED_TEST_FLOW_SUITE == TEST_CASES_SDXL:
return "SDXL"
elif SELECTED_TEST_FLOW_SUITE == TEST_CASES_FLUX:
return "FLUX"
elif SELECTED_TEST_FLOW_SUITE == TEST_CASES_PORTRAITS:
return "PORTRAITS"
elif SELECTED_TEST_FLOW_SUITE == TEST_CASES_OTHER:
return "OTHER"
elif SELECTED_TEST_FLOW_SUITE == TEST_CASES_FLUX:
return "FLUX"
elif SELECTED_TEST_FLOW_SUITE == TEST_CASES_HEAVY:
return "HEAVY"
raise RuntimeError("Unknown TEST SUITE")
Expand Down