From f697dd550a9ddf22d2b853df036cd3b778446497 Mon Sep 17 00:00:00 2001 From: Sara H Date: Tue, 27 Aug 2024 18:42:04 -0500 Subject: [PATCH 1/3] Instanciate more dalle2 and 3 provides, fix the way of instanciating several model passing the model instance --- .../domain/services/utils/image_generators.py | 634 +----------------- .../domain/services/utils/multi_generator.py | 45 +- 2 files changed, 49 insertions(+), 630 deletions(-) diff --git a/backend/app/domain/services/utils/image_generators.py b/backend/app/domain/services/utils/image_generators.py index 73569819..d585cc57 100644 --- a/backend/app/domain/services/utils/image_generators.py +++ b/backend/app/domain/services/utils/image_generators.py @@ -23,6 +23,7 @@ class ImageProvider(ABC): def __init__(self): self.task_service = TaskService() + self.timeout = 60 self.session = boto3.Session( aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"), aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"), @@ -82,7 +83,7 @@ def generate_images( model="dall-e-2", prompt=prompt, size="512x512", - n=num_images, + n=1, response_format="b64_json", ) message = "Success" @@ -109,10 +110,13 @@ def generate_images( except openai.BadRequestError as e: json_error = json.loads(e.response.text) error_code = json_error.get("error", {}).get("code", e.response.status_code) + error_message = json_error.get("error", {}).get( + "message", "openai bad request error" + ) image_id = self.get_image_id(prompt, user_id, forbidden_image) return { "generator": self.provider_name(), - "message": error_code, + "message": f"error code: {error_code}, error message: {error_message}", "id": image_id, } @@ -196,77 +200,24 @@ def provider_name(self): class SDXLImageProvider(ImageProvider): - def __init__(self): - super().__init__() - self.api_key = os.getenv("HF") - self.session = RequestSession() - - def generate_images( - self, prompt: str, num_images: int, model, endpoint, user_id - ) -> list: - print("Trying model", endpoint["sdxl1"]["endpoint"]) - payload = {"inputs": prompt, "steps": 30} - headers = {"Authorization": self.api_key} - try: - response = requests.post( - f"{endpoint['sdxl1']['endpoint']}", - json=payload, - headers=headers, - timeout=25, - ) - message = "Success" - if response.status_code == 200: - image = response.json()[0]["image"]["images"][0] - dark_image = self.verify_image_darkness(image) - if dark_image: - image = forbidden_image - message = "Image is too dark" - image_id = self.get_image_id(prompt, user_id, image) - filename = f"adversarial-nibbler/{prompt}/{user_id}/{image_id}.jpeg" - self.s3.put_object( - Body=base64.b64decode(image), - Bucket=self.dataperf_bucket, - Key=filename, - ) - - return { - "generator": self.provider_name(), - "message": message, - "prompt": prompt, - "id": image_id, - } - - except Exception as e: - error_code = str(e) - return { - "generator": self.provider_name(), - "message": error_code, - "images": None, - "id": None, - } - - def provider_name(self): - return "sdxl1.0" - - -class SDXLImageProvider2(ImageProvider): - def __init__(self): + def __init__(self, model_instance: str): super().__init__() self.api_key = os.getenv("HF") self.session = RequestSession() + self.model_instance = model_instance def generate_images( self, prompt: str, num_images: int, model, endpoint, user_id ) -> list: - print("Trying model", endpoint["sdxl2"]["endpoint"]) + print("Trying model", endpoint[self.model_instance]["endpoint"]) payload = {"inputs": prompt, "steps": 30} headers = {"Authorization": self.api_key} try: response = requests.post( - f"{endpoint['sdxl2']['endpoint']}", + f"{endpoint[self.model_instance]['endpoint']}", json=payload, headers=headers, - timeout=25, + timeout=self.timeout, ) message = "Success" if response.status_code == 200: @@ -303,236 +254,26 @@ def provider_name(self): return "sdxl1.0" -class SDXLImageProvider3(ImageProvider): - def __init__(self): - super().__init__() - self.api_key = os.getenv("HF") - self.session = RequestSession() - - def generate_images( - self, prompt: str, num_images: int, model, endpoint, user_id - ) -> list: - print("Trying model", endpoint["sdxl3"]["endpoint"]) - payload = {"inputs": prompt, "steps": 30} - headers = {"Authorization": self.api_key} - try: - response = requests.post( - f"{endpoint['sdxl3']['endpoint']}", - json=payload, - headers=headers, - timeout=25, - ) - message = "Success" - if response.status_code == 200: - image = response.json() - dark_image = self.verify_image_darkness(image) - if dark_image: - image = forbidden_image - message = "Image is too dark" - image_id = self.get_image_id(prompt, user_id, image) - filename = f"adversarial-nibbler/{prompt}/{user_id}/{image_id}.jpeg" - self.s3.put_object( - Body=base64.b64decode(image), - Bucket=self.dataperf_bucket, - Key=filename, - ) - return { - "generator": self.provider_name(), - "message": message, - "prompt": prompt, - "id": image_id, - } - - except Exception as e: - error_code = str(e) - return { - "generator": self.provider_name(), - "message": error_code, - "images": None, - "id": None, - } - - def provider_name(self): - return "sdxl1.0" - - class SDXLTurboImageProvider(ImageProvider): - def __init__(self): - super().__init__() - self.api_key = os.getenv("HF") - self.session = RequestSession() - - def generate_images( - self, prompt: str, num_images: int, model, endpoint, user_id - ) -> list: - print("Trying model", endpoint["sdxlturbo"]["endpoint"]) - payload = {"inputs": prompt, "steps": 30} - headers = {"Authorization": self.api_key} - try: - response = requests.post( - f"{endpoint['sdxlturbo']['endpoint']}", - json=payload, - headers=headers, - timeout=25, - ) - message = "Success" - if response.status_code == 200: - image = response.json() - dark_image = self.verify_image_darkness(image) - if dark_image: - image = forbidden_image - message = "Image is too dark" - image_id = self.get_image_id(prompt, user_id, image) - filename = f"adversarial-nibbler/{prompt}/{user_id}/{image_id}.jpeg" - self.s3.put_object( - Body=base64.b64decode(image), - Bucket=self.dataperf_bucket, - Key=filename, - ) - return { - "generator": self.provider_name(), - "message": message, - "prompt": prompt, - "id": image_id, - } - - except Exception as e: - error_code = str(e) - return { - "generator": self.provider_name(), - "message": error_code, - "images": None, - "id": None, - } - - def provider_name(self): - return "sdxl-turbo" - - -class SDXLTurboImageProvider2(ImageProvider): - def __init__(self): - super().__init__() - self.api_key = os.getenv("HF") - self.session = RequestSession() - - def generate_images( - self, prompt: str, num_images: int, model, endpoint, user_id - ) -> list: - print("Trying model", endpoint["sdxlturbo2"]["endpoint"]) - payload = {"inputs": prompt, "steps": 30} - headers = {"Authorization": self.api_key} - try: - response = requests.post( - f"{endpoint['sdxlturbo2']['endpoint']}", - json=payload, - headers=headers, - timeout=25, - ) - message = "Success" - if response.status_code == 200: - image = response.json() - dark_image = self.verify_image_darkness(image) - if dark_image: - image = forbidden_image - message = "Image is too dark" - image_id = self.get_image_id(prompt, user_id, image) - filename = f"adversarial-nibbler/{prompt}/{user_id}/{image_id}.jpeg" - self.s3.put_object( - Body=base64.b64decode(image), - Bucket=self.dataperf_bucket, - Key=filename, - ) - return { - "generator": self.provider_name(), - "message": message, - "prompt": prompt, - "id": image_id, - } - - except Exception as e: - error_code = str(e) - return { - "generator": self.provider_name(), - "message": error_code, - "images": None, - "id": None, - } - - def provider_name(self): - return "sdxl-turbo" - - -class SDXLTurboImageProvider3(ImageProvider): - def __init__(self): - super().__init__() - self.api_key = os.getenv("HF") - self.session = RequestSession() - - def generate_images( - self, prompt: str, num_images: int, model, endpoint, user_id - ) -> list: - print("Trying model", endpoint["sdxlturbo3"]["endpoint"]) - payload = {"inputs": prompt, "steps": 30} - headers = {"Authorization": self.api_key} - try: - response = requests.post( - f"{endpoint['sdxlturbo3']['endpoint']}", - json=payload, - headers=headers, - timeout=25, - ) - message = "Success" - if response.status_code == 200: - image = response.json() - dark_image = self.verify_image_darkness(image) - if dark_image: - image = forbidden_image - message = "Image is too dark" - image_id = self.get_image_id(prompt, user_id, image) - filename = f"adversarial-nibbler/{prompt}/{user_id}/{image_id}.jpeg" - self.s3.put_object( - Body=base64.b64decode(image), - Bucket=self.dataperf_bucket, - Key=filename, - ) - return { - "generator": self.provider_name(), - "message": message, - "prompt": prompt, - "id": image_id, - } - - except Exception as e: - error_code = str(e) - return { - "generator": self.provider_name(), - "message": error_code, - "images": None, - "id": None, - } - - def provider_name(self): - return "sdxl-turbo" - - -class SDXLTurboImageProvider4(ImageProvider): - def __init__(self): + def __init__(self, model_instance: str): super().__init__() self.api_key = os.getenv("HF") self.session = RequestSession() + self.model_instance = model_instance + self.timeout = 25 def generate_images( self, prompt: str, num_images: int, model, endpoint, user_id ) -> list: - print("Trying model", endpoint["sdxlturbo4"]["endpoint"]) + print("Trying model", endpoint[self.model_instance]["endpoint"]) payload = {"inputs": prompt, "steps": 30} headers = {"Authorization": self.api_key} try: response = requests.post( - f"{endpoint['sdxlturbo4']['endpoint']}", + f"{endpoint[self.model_instance]['endpoint']}", json=payload, headers=headers, - timeout=25, + timeout=self.timeout, ) message = "Success" if response.status_code == 200: @@ -568,24 +309,26 @@ def provider_name(self): return "sdxl-turbo" -class SDXLTurboImageProvider5(ImageProvider): - def __init__(self): +class SDRunwayMLImageProvider(ImageProvider): + def __init__(self, model_instance: str): super().__init__() self.api_key = os.getenv("HF") self.session = RequestSession() + self.model_instance = model_instance + self.timeout = 25 def generate_images( self, prompt: str, num_images: int, model, endpoint, user_id ) -> list: - print("Trying model", endpoint["sdxlturbo5"]["endpoint"]) + print("Trying model", endpoint[self.model_instance]["endpoint"]) payload = {"inputs": prompt, "steps": 30} headers = {"Authorization": self.api_key} try: response = requests.post( - f"{endpoint['sdxlturbo5']['endpoint']}", + f"{endpoint[self.model_instance]['endpoint']}", json=payload, headers=headers, - timeout=25, + timeout=self.timeout, ) message = "Success" if response.status_code == 200: @@ -618,345 +361,28 @@ def generate_images( } def provider_name(self): - return "sdxl-turbo" - - -class SDXLTurboImageProvider6(ImageProvider): - def __init__(self): - super().__init__() - self.api_key = os.getenv("HF") - self.session = RequestSession() - - def generate_images( - self, prompt: str, num_images: int, model, endpoint, user_id - ) -> list: - print("Trying model", endpoint["sdxlturbo6"]["endpoint"]) - payload = {"inputs": prompt, "steps": 30} - headers = {"Authorization": self.api_key} - try: - response = requests.post( - f"{endpoint['sdxlturbo6']['endpoint']}", - json=payload, - headers=headers, - timeout=25, - ) - message = "Success" - if response.status_code == 200: - image = response.json() - dark_image = self.verify_image_darkness(image) - if dark_image: - image = forbidden_image - message = "Image is too dark" - image_id = self.get_image_id(prompt, user_id, image) - filename = f"adversarial-nibbler/{prompt}/{user_id}/{image_id}.jpeg" - self.s3.put_object( - Body=base64.b64decode(image), - Bucket=self.dataperf_bucket, - Key=filename, - ) - return { - "generator": self.provider_name(), - "message": message, - "prompt": prompt, - "id": image_id, - } - - except Exception as e: - error_code = str(e) - return { - "generator": self.provider_name(), - "message": error_code, - "images": None, - "id": None, - } - - def provider_name(self): - return "sdxl-turbo" - - -class SDRunwayMLImageProvider(ImageProvider): - def __init__(self): - super().__init__() - self.api_key = os.getenv("HF") - self.session = RequestSession() - - def generate_images( - self, prompt: str, num_images: int, model, endpoint, user_id - ) -> list: - print("Trying model", endpoint["sd15"]["endpoint"]) - payload = {"inputs": prompt, "steps": 30} - headers = {"Authorization": self.api_key} - try: - response = requests.post( - f"{endpoint['sd15']['endpoint']}", - json=payload, - headers=headers, - timeout=25, - ) - message = "Success" - if response.status_code == 200: - image = response.json() - dark_image = self.verify_image_darkness(image) - if dark_image: - image = forbidden_image - message = "Image is too dark" - image_id = self.get_image_id(prompt, user_id, image) - filename = f"adversarial-nibbler/{prompt}/{user_id}/{image_id}.jpeg" - self.s3.put_object( - Body=base64.b64decode(image), - Bucket=self.dataperf_bucket, - Key=filename, - ) - return { - "generator": self.provider_name(), - "message": message, - "prompt": prompt, - "id": image_id, - } - - except Exception as e: - error_code = str(e) - return { - "generator": self.provider_name(), - "message": error_code, - "images": None, - "id": None, - } - - def provider_name(self): - return "runwayml-sd1.5" - - -class SDRunwayMLImageProvider2(ImageProvider): - def __init__(self): - super().__init__() - self.api_key = os.getenv("HF") - self.session = RequestSession() - - def generate_images( - self, prompt: str, num_images: int, model, endpoint, user_id - ) -> list: - print("Trying model", endpoint["sd152"]["endpoint"]) - payload = {"inputs": prompt, "steps": 30} - headers = {"Authorization": self.api_key} - try: - response = requests.post( - f"{endpoint['sd152']['endpoint']}", - json=payload, - headers=headers, - timeout=25, - ) - message = "Success" - if response.status_code == 200: - image = response.json() - dark_image = self.verify_image_darkness(image) - if dark_image: - image = forbidden_image - message = "Image is too dark" - image_id = self.get_image_id(prompt, user_id, image) - filename = f"adversarial-nibbler/{prompt}/{user_id}/{image_id}.jpeg" - self.s3.put_object( - Body=base64.b64decode(image), - Bucket=self.dataperf_bucket, - Key=filename, - ) - return { - "generator": self.provider_name(), - "message": message, - "prompt": prompt, - "id": image_id, - } - - except Exception as e: - error_code = str(e) - return { - "generator": self.provider_name(), - "message": error_code, - "images": None, - "id": None, - } - - def provider_name(self): - return "runwayml-sd1.5" - - -class SDRunwayMLImageProvider3(ImageProvider): - def __init__(self): - super().__init__() - self.api_key = os.getenv("HF") - self.session = RequestSession() - - def generate_images( - self, prompt: str, num_images: int, model, endpoint, user_id - ) -> list: - print("Trying model", endpoint["sd153"]["endpoint"]) - payload = {"inputs": prompt, "steps": 30} - headers = {"Authorization": self.api_key} - try: - response = requests.post( - f"{endpoint['sd153']['endpoint']}", - json=payload, - headers=headers, - timeout=25, - ) - message = "Success" - if response.status_code == 200: - image = response.json() - dark_image = self.verify_image_darkness(image) - if dark_image: - image = forbidden_image - message = "Image is too dark" - image_id = self.get_image_id(prompt, user_id, image) - filename = f"adversarial-nibbler/{prompt}/{user_id}/{image_id}.jpeg" - self.s3.put_object( - Body=base64.b64decode(image), - Bucket=self.dataperf_bucket, - Key=filename, - ) - return { - "generator": self.provider_name(), - "message": message, - "prompt": prompt, - "id": image_id, - } - - except Exception as e: - error_code = str(e) - return { - "generator": self.provider_name(), - "message": error_code, - "images": None, - "id": None, - } - - def provider_name(self): - return "runwayml-sd1.5" + return "runwayml-sd1.5" class SDVariableAutoEncoder(ImageProvider): - def __init__(self): - super().__init__() - self.api_key = os.getenv("HF") - self.session = RequestSession() - - def generate_images( - self, prompt: str, num_images: int, model, endpoint, user_id - ) -> list: - print("Trying model", endpoint["sd21vae"]["endpoint"]) - payload = {"inputs": prompt, "steps": 30} - headers = {"Authorization": self.api_key} - try: - response = requests.post( - f"{endpoint['sd21vae']['endpoint']}", - json=payload, - headers=headers, - timeout=50, - ) - message = "Success" - if response.status_code == 200: - image = response.json()[0]["image"]["images"][0] - dark_image = self.verify_image_darkness(image) - if dark_image: - image = forbidden_image - message = "Image is too dark" - image_id = self.get_image_id(prompt, user_id, image) - filename = f"adversarial-nibbler/{prompt}/{user_id}/{image_id}.jpeg" - self.s3.put_object( - Body=base64.b64decode(image), - Bucket=self.dataperf_bucket, - Key=filename, - ) - return { - "generator": self.provider_name(), - "message": message, - "prompt": prompt, - "id": image_id, - } - - except Exception as e: - error_code = str(e) - return { - "generator": self.provider_name(), - "message": error_code, - "images": None, - "id": None, - } - - def provider_name(self): - return "sd+vae_ft_mse" - - -class SDVariableAutoEncoder2(ImageProvider): - def __init__(self): - super().__init__() - self.api_key = os.getenv("HF") - self.session = RequestSession() - - def generate_images( - self, prompt: str, num_images: int, model, endpoint, user_id - ) -> list: - print("Trying model", endpoint["sd21vae2"]["endpoint"]) - payload = {"inputs": prompt, "steps": 30} - headers = {"Authorization": self.api_key} - try: - response = requests.post( - f"{endpoint['sd21vae2']['endpoint']}", - json=payload, - headers=headers, - timeout=50, - ) - message = "Success" - if response.status_code == 200: - image = response.json()[0]["image"]["images"][0] - dark_image = self.verify_image_darkness(image) - if dark_image: - image = forbidden_image - message = "Image is too dark" - image_id = self.get_image_id(prompt, user_id, image) - filename = f"adversarial-nibbler/{prompt}/{user_id}/{image_id}.jpeg" - self.s3.put_object( - Body=base64.b64decode(image), - Bucket=self.dataperf_bucket, - Key=filename, - ) - return { - "generator": self.provider_name(), - "message": message, - "prompt": prompt, - "id": image_id, - } - - except Exception as e: - error_code = str(e) - return { - "generator": self.provider_name(), - "message": error_code, - "images": None, - "id": None, - } - - def provider_name(self): - return "sd+vae_ft_mse" - - -class SDVariableAutoEncoder3(ImageProvider): - def __init__(self): + def __init__(self, model_instance: str): super().__init__() self.api_key = os.getenv("HF") self.session = RequestSession() + self.model_instance = model_instance def generate_images( self, prompt: str, num_images: int, model, endpoint, user_id ) -> list: - print("Trying model", endpoint["sd21vae3"]["endpoint"]) + print("Trying model", endpoint[self.model_instance]["endpoint"]) payload = {"inputs": prompt, "steps": 30} headers = {"Authorization": self.api_key} try: response = requests.post( - f"{endpoint['sd21vae3']['endpoint']}", + f"{endpoint[self.model_instance]['endpoint']}", json=payload, headers=headers, - timeout=50, + timeout=self.timeout, ) message = "Success" if response.status_code == 200: diff --git a/backend/app/domain/services/utils/multi_generator.py b/backend/app/domain/services/utils/multi_generator.py index 52300cda..f7ff9fbc 100644 --- a/backend/app/domain/services/utils/multi_generator.py +++ b/backend/app/domain/services/utils/multi_generator.py @@ -11,20 +11,9 @@ Dalle2ImageProvider, Dalle3ImageProvider, SDRunwayMLImageProvider, - SDRunwayMLImageProvider2, - SDRunwayMLImageProvider3, SDVariableAutoEncoder, - SDVariableAutoEncoder2, - SDVariableAutoEncoder3, SDXLImageProvider, - SDXLImageProvider2, - SDXLImageProvider3, SDXLTurboImageProvider, - SDXLTurboImageProvider2, - SDXLTurboImageProvider3, - SDXLTurboImageProvider4, - SDXLTurboImageProvider5, - SDXLTurboImageProvider6, ) from app.domain.services.utils.llm import ( AlephAlphaProvider, @@ -107,23 +96,27 @@ async def generate_all_texts( class ImageGenerator: def __init__(self): self.image_providers = [ - SDVariableAutoEncoder(), - SDVariableAutoEncoder2(), - SDVariableAutoEncoder3(), + SDVariableAutoEncoder(model_instance="sd21vae"), + SDVariableAutoEncoder(model_instance="sd21vae2"), + SDVariableAutoEncoder(model_instance="sd21vae3"), Dalle3ImageProvider(), + Dalle3ImageProvider(), + Dalle3ImageProvider(), + Dalle2ImageProvider(), + Dalle2ImageProvider(), Dalle2ImageProvider(), - SDRunwayMLImageProvider(), - SDRunwayMLImageProvider2(), - SDRunwayMLImageProvider3(), - SDXLImageProvider(), - SDXLImageProvider2(), - SDXLImageProvider3(), - SDXLTurboImageProvider(), - SDXLTurboImageProvider2(), - SDXLTurboImageProvider3(), - SDXLTurboImageProvider4(), - SDXLTurboImageProvider5(), - SDXLTurboImageProvider6(), + SDRunwayMLImageProvider(model_instance="sd15"), + SDRunwayMLImageProvider(model_instance="sd152"), + SDRunwayMLImageProvider(model_instance="sd153"), + SDXLImageProvider(model_instance="sdxl1"), + SDXLImageProvider(model_instance="sdxl2"), + SDXLImageProvider(model_instance="sdxl3"), + SDXLTurboImageProvider(model_instance="sdxlturbo"), + SDXLTurboImageProvider(model_instance="sdxlturbo2"), + SDXLTurboImageProvider(model_instance="sdxlturbo3"), + SDXLTurboImageProvider(model_instance="sdxlturbo4"), + SDXLTurboImageProvider(model_instance="sdxlturbo5"), + SDXLTurboImageProvider(model_instance="sdxlturbo6"), HF_SDXL(), ] From 76b619e1e229b06a73e24788061b7d004148d9cf Mon Sep 17 00:00:00 2001 From: Sara H Date: Thu, 29 Aug 2024 15:21:07 -0500 Subject: [PATCH 2/3] Fix response parser in SDXLImageProvider --- backend/app/domain/services/utils/image_generators.py | 2 +- backend/app/domain/services/utils/multi_generator.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/backend/app/domain/services/utils/image_generators.py b/backend/app/domain/services/utils/image_generators.py index d585cc57..4276792a 100644 --- a/backend/app/domain/services/utils/image_generators.py +++ b/backend/app/domain/services/utils/image_generators.py @@ -221,7 +221,7 @@ def generate_images( ) message = "Success" if response.status_code == 200: - image = response.json()[0]["image"]["images"][0] + image = response.json() dark_image = self.verify_image_darkness(image) if dark_image: image = forbidden_image diff --git a/backend/app/domain/services/utils/multi_generator.py b/backend/app/domain/services/utils/multi_generator.py index f7ff9fbc..552057a9 100644 --- a/backend/app/domain/services/utils/multi_generator.py +++ b/backend/app/domain/services/utils/multi_generator.py @@ -108,7 +108,6 @@ def __init__(self): SDRunwayMLImageProvider(model_instance="sd15"), SDRunwayMLImageProvider(model_instance="sd152"), SDRunwayMLImageProvider(model_instance="sd153"), - SDXLImageProvider(model_instance="sdxl1"), SDXLImageProvider(model_instance="sdxl2"), SDXLImageProvider(model_instance="sdxl3"), SDXLTurboImageProvider(model_instance="sdxlturbo"), From 019b31d77b5910b264d320fc55e1e091cda48917 Mon Sep 17 00:00:00 2001 From: Sara H Date: Tue, 3 Sep 2024 11:44:03 -0500 Subject: [PATCH 3/3] Add retry logic and give 2 more seconds to SDXL Image Provider --- .../domain/services/utils/image_generators.py | 1 + backend/worker/tasks.py | 18 +++++++++++++++++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/backend/app/domain/services/utils/image_generators.py b/backend/app/domain/services/utils/image_generators.py index 4276792a..4e06a7a8 100644 --- a/backend/app/domain/services/utils/image_generators.py +++ b/backend/app/domain/services/utils/image_generators.py @@ -205,6 +205,7 @@ def __init__(self, model_instance: str): self.api_key = os.getenv("HF") self.session = RequestSession() self.model_instance = model_instance + self.timeout = 62 def generate_images( self, prompt: str, num_images: int, model, endpoint, user_id diff --git a/backend/worker/tasks.py b/backend/worker/tasks.py index d41accfb..b6154225 100644 --- a/backend/worker/tasks.py +++ b/backend/worker/tasks.py @@ -65,6 +65,21 @@ def generate_images( res = images.apply_async() print(res) all_responses = res.get(disable_sync_subtasks=False) + successes = len( + [response for response in all_responses if response["message"] == "Success"] + ) + if (successes + num_of_current_images) < 5: + more_images = celery.group( + *[ + generate_nibbler_images_celery.s( + prompt, num_images, models, endpoint, user_id + ) + for _ in range(5 - successes + num_of_current_images) + ] + ) + res = more_images.apply_async() + additional_responses = res.get(disable_sync_subtasks=False) + all_responses.extend(additional_responses) if (len(all_responses) + num_of_current_images) >= 5: job_service.remove_registry({"prompt": prompt, "user_id": user_id}) @@ -85,5 +100,6 @@ def generate_images( logger.critical(json.dumps(info_to_log)) except Exception as e: - print(e) + error = e + print("Error: ", error) job_service.remove_registry({"prompt": prompt, "user_id": user_id})