Skip to content

Commit

Permalink
Merge pull request #40 from Visionatrix/feat/bench-improvements
Browse files Browse the repository at this point in the history
benchmark improvements
bigcat88 authored Sep 18, 2024
2 parents 0960658 + 0b977d4 commit c3c8153
Showing 1 changed file with 128 additions and 72 deletions.
200 changes: 128 additions & 72 deletions scripts/benchmarks/benchmark.py
Original file line number Diff line number Diff line change
@@ -28,6 +28,8 @@
BASIC_AUTH = httpx.BasicAuth(USER_NAME, USER_PASS) if USER_NAME and USER_PASS else None
if BASIC_AUTH:
print("Using authentication for connect")
PAUSE_INTERVAL = int(os.environ.get("PAUSE_INTERVAL", "0"))
FIRST_TEST_FLAG = True

SELECTED_TEST_FLOW_SUITE = []

@@ -192,37 +194,7 @@ def validate_flows_test_cases(flows_test_cases: list[FlowTest]):
],
),
]
TEST_CASES_OTHER = [
FlowTest(
flow_name="stable_cascade",
test_cases=[
TestCase(
name="one_pass",
input_params={"prompt": "green apple", "pass_count": "One pass"},
),
TestCase(
name="two_pass",
input_params={"prompt": "green apple", "pass_count": "Two pass"},
),
TestCase(
name="three_pass",
input_params={"prompt": "green apple", "pass_count": "Three pass"},
),
],
),
FlowTest(
flow_name="hunyuan_dit",
test_cases=[
TestCase(
name="20steps",
input_params={"prompt": "green apple", "steps_count": 20},
),
TestCase(
name="40steps",
input_params={"prompt": "green apple", "steps_count": 40},
),
],
),
TEST_CASES_PORTRAITS = [
FlowTest(
flow_name="vintage_portrait",
test_cases=[
@@ -284,16 +256,48 @@ def validate_flows_test_cases(flows_test_cases: list[FlowTest]):
],
),
FlowTest(
flow_name="remove_background_birefnet",
flow_name="photo_stickers",
test_cases=[
TestCase(
name="1024x1024",
name="default",
input_files=["man.png"],
),
],
),
]
TEST_CASES_OTHER = [
FlowTest(
flow_name="remove_background_bria",
flow_name="stable_cascade",
test_cases=[
TestCase(
name="one_pass",
input_params={"prompt": "green apple", "pass_count": "One pass"},
),
TestCase(
name="two_pass",
input_params={"prompt": "green apple", "pass_count": "Two pass"},
),
TestCase(
name="three_pass",
input_params={"prompt": "green apple", "pass_count": "Three pass"},
),
],
),
FlowTest(
flow_name="hunyuan_dit",
test_cases=[
TestCase(
name="20steps",
input_params={"prompt": "green apple", "steps_count": 20},
),
TestCase(
name="40steps",
input_params={"prompt": "green apple", "steps_count": 40},
),
],
),
FlowTest(
flow_name="remove_background_birefnet",
test_cases=[
TestCase(
name="1024x1024",
@@ -302,20 +306,20 @@ def validate_flows_test_cases(flows_test_cases: list[FlowTest]):
],
),
FlowTest(
flow_name="supir_upscaler",
flow_name="remove_background_bria",
test_cases=[
TestCase(
name="1MPx1.5",
input_params={"scale_factor": 1.5},
name="1024x1024",
input_files=["man.png"],
),
],
),
FlowTest(
flow_name="photo_stickers",
flow_name="supir_upscaler",
test_cases=[
TestCase(
name="default",
name="1MPx1.5",
input_params={"scale_factor": 1.5},
input_files=["man.png"],
),
],
@@ -358,30 +362,38 @@ def validate_flows_test_cases(flows_test_cases: list[FlowTest]):
),
]
validate_flows_test_cases(
TEST_CASES_SDXL + TEST_CASES_OTHER + TEST_CASES_FLUX + TEST_CASES_HEAVY
TEST_CASES_SDXL
+ TEST_CASES_PORTRAITS
+ TEST_CASES_OTHER
+ TEST_CASES_FLUX
+ TEST_CASES_HEAVY
)


async def select_test_flow_suite():
global SELECTED_TEST_FLOW_SUITE
print("Please select the test suite you want to run:")
print("1. SDXL Suite")
print("2. FLUX Suite")
print("3. OTHER Suite")
print("4. HEAVY(24GB+ VRAM) Suite")
print("2. PORTRAITS Suite")
print("3. FLUX Suite")
print("4. OTHER Suite")
print("5. HEAVY(24GB+ VRAM) Suite")

user_choice = input("Enter the number of the suite (1/2/3/4): ")
user_choice = input("Enter the number of the suite (1/2/3/4/5): ")

if user_choice == "1":
SELECTED_TEST_FLOW_SUITE = TEST_CASES_SDXL
print("Selected SDXL Suite.")
elif user_choice == "2":
SELECTED_TEST_FLOW_SUITE = TEST_CASES_FLUX
print("Selected FLUX Suite.")
SELECTED_TEST_FLOW_SUITE = TEST_CASES_PORTRAITS
print("Selected PORTRAITS Suite.")
elif user_choice == "3":
SELECTED_TEST_FLOW_SUITE = TEST_CASES_OTHER
print("Selected OTHER Suite.")
elif user_choice == "4":
SELECTED_TEST_FLOW_SUITE = TEST_CASES_FLUX
print("Selected FLUX Suite.")
elif user_choice == "5":
SELECTED_TEST_FLOW_SUITE = TEST_CASES_HEAVY
print("Selected HEAVY Suite.")
else:
@@ -450,6 +462,7 @@ async def wait_for_installation_to_complete(
flow_name: str, poll_interval: int = 5, timeout: int = FLOW_INSTALL_TIMEOUT
) -> bool:
elapsed_time = 0
max_read_timeout_count = 20
async with httpx.AsyncClient(auth=BASIC_AUTH) as client:
while elapsed_time < timeout:
try:
@@ -459,27 +472,41 @@ async def wait_for_installation_to_complete(

for flow in install_progress:
if flow["name"] == flow_name:
if flow["error"]:
print(
f"Error during installation of flow '{flow_name}': {flow['error']}"
)
return False
if flow["progress"] == 100:
if flow["error"]:
print(
f"Error during installation of flow '{flow_name}': {flow['error']}"
)
return False
print(f"Flow '{flow_name}' installation completed.")
return True
rounded_flow_progress = math.floor(flow["progress"] * 10) / 10
print(
f"Flow '{flow_name}' installation progress: {rounded_flow_progress}%"
)
break

await asyncio.sleep(poll_interval)
elapsed_time += poll_interval
max_read_timeout_count = 20
except httpx.ReadTimeout:
max_read_timeout_count -= 1
print(
f"Flow '{flow_name}': ReadTimeout error during installation progress check, "
f"continuing... {max_read_timeout_count} tries left"
)
if not max_read_timeout_count:
print(
f"Installation of flow '{flow_name}' failed due to repeated timeouts."
)
return False
except httpx.RequestError as exc:
print(
f"An error occurred while checking installation progress: {exc.request.url!r}: {exc}"
)
return False

print(f"Installation of flow '{flow_name}' timed out.")
print(f"Installation of flow '{flow_name}' timed out after {timeout} seconds.")
return False


@@ -523,8 +550,9 @@ async def create_task(
return []


async def get_task_progress(task_id: int, poll_interval: int = 3) -> dict:
async def get_task_progress(task_id: int, poll_interval: int = 5) -> dict:
max_read_timeout_count = 20
previous_progress = 0.0
async with httpx.AsyncClient(auth=BASIC_AUTH) as client:
while True:
try:
@@ -533,20 +561,31 @@ async def get_task_progress(task_id: int, poll_interval: int = 3) -> dict:
)
if response.status_code == 200:
task_data = response.json()
if task_data.get("progress") == 100.0 or task_data.get("error"):
if task_data.get("error"):
print(
f"Task with id={task_id} failed with error: {task_data['error']}"
)
return task_data
rounded_task_progress = (
math.floor(task_data.get("progress") * 10) / 10
)
print(
f"Task `{task_data['name']}` with id={task_id}, progress: {rounded_task_progress}%"
)
progress = task_data.get("progress", 0.0)
if progress == 100.0:
print(
f"Task `{task_data['name']}` with id={task_id}, progress: 100%"
)
return task_data
if progress > 0.0:
# Only print progress if it has increased
if progress != previous_progress:
rounded_task_progress = math.floor(progress * 10) / 10
print(
f"Task `{task_data['name']}` with id={task_id}, progress: {rounded_task_progress}%"
)
previous_progress = progress
max_read_timeout_count = 20
await asyncio.sleep(poll_interval)
except httpx.ReadTimeout:
max_read_timeout_count -= 1
print(
f"Task with id={task_id}: ReadTimeout error, continue... {max_read_timeout_count} tries left"
f"Task with id={task_id}: ReadTimeout error, continuing... {max_read_timeout_count} tries left"
)
if not max_read_timeout_count:
return {"error": "read timeout error"}
@@ -575,6 +614,17 @@ async def run_test_case(

# Use the test_semaphore to limit concurrent tests
async with test_semaphore:
global FIRST_TEST_FLAG

if FIRST_TEST_FLAG:
FIRST_TEST_FLAG = False
else:
if PAUSE_INTERVAL:
print(
f"Pausing for {PAUSE_INTERVAL} seconds before starting next test."
)
await asyncio.sleep(PAUSE_INTERVAL)

input_params = test_case.input_params
input_files = test_case.input_files
warm_up = test_case.warm_up
@@ -588,16 +638,15 @@ async def run_test_case(
)
return

# Wait for tasks to finish and save results
task_results = []
flow_comfy_saved = False
for task_id in task_ids:
result = await get_task_progress(task_id)
task_results.append(result)
# Create a list of coroutines for task progress
task_progress_coroutines = [get_task_progress(task_id) for task_id in task_ids]
task_results = await asyncio.gather(*task_progress_coroutines)

if not flow_comfy_saved and "flow_comfy" in result:
# Save the flow_comfy from the first task that has it
for result in task_results:
if "flow_comfy" in result:
await save_flow_comfy(flow_name, test_case.name, result["flow_comfy"])
flow_comfy_saved = True
break # No need to check further once saved

if warm_up:
task_results = task_results[1:]
@@ -606,7 +655,12 @@ async def run_test_case(
summary = await save_results(flow_name, test_case.name, test_case, task_results)
if summary:
results_summary[flow_name].append(
{"test_case": test_case.name, "avg_exec_time": summary["avg_exec_time"]}
{
"test_case": test_case.name,
"avg_exec_time": summary["avg_exec_time"],
"min_exec_time": summary["min_exec_time"],
"max_exec_time": summary["max_exec_time"],
}
)

# Save the updated results summary after each test case
@@ -836,10 +890,12 @@ async def save_output(
def get_suite_identifier() -> str:
if SELECTED_TEST_FLOW_SUITE == TEST_CASES_SDXL:
return "SDXL"
elif SELECTED_TEST_FLOW_SUITE == TEST_CASES_FLUX:
return "FLUX"
elif SELECTED_TEST_FLOW_SUITE == TEST_CASES_PORTRAITS:
return "PORTRAITS"
elif SELECTED_TEST_FLOW_SUITE == TEST_CASES_OTHER:
return "OTHER"
elif SELECTED_TEST_FLOW_SUITE == TEST_CASES_FLUX:
return "FLUX"
elif SELECTED_TEST_FLOW_SUITE == TEST_CASES_HEAVY:
return "HEAVY"
raise RuntimeError("Unknown TEST SUITE")

0 comments on commit c3c8153

Please sign in to comment.