From 8554dc6b566feda670a46c48524ee6ee40cf768f Mon Sep 17 00:00:00 2001 From: Florine Kieraga Date: Mon, 30 Oct 2023 11:06:44 +0100 Subject: [PATCH] [feat] add model to replicate --- edenai_apis/apis/replicate/config.py | 5 ++++ edenai_apis/apis/replicate/info.json | 8 +++++- .../outputs/image/generation_output.json | 26 ++++++++++--------- edenai_apis/apis/replicate/replicate_api.py | 9 ++++--- 4 files changed, 32 insertions(+), 16 deletions(-) diff --git a/edenai_apis/apis/replicate/config.py b/edenai_apis/apis/replicate/config.py index f866a9ef..b09188d1 100644 --- a/edenai_apis/apis/replicate/config.py +++ b/edenai_apis/apis/replicate/config.py @@ -1,4 +1,9 @@ get_model_id = { "llama-2-70b" : "14ce4448d5e7e9ed0c37745ac46eca157aab09061f0c179ac2b323b5de56552b", "llama-2-70b-chat" :"58d078176e02c219e11eb4da5a02a7830a283b14cf8f94537af893ccff5ee781" +} +get_model_id_image = { + "anime-style" : "09a5805203f4c12da649ec1923bb7729517ca25fcac790e640eaa9ed66573b65", + "classic" : "c0259010b93e7a4102a4ba946d70e06d7d0c7dc007201af443cfc8f943ab1d3c", + "vintedois-diffusion" : "28cea91bdfced0e2dc7fda466cc0a46501c0edc84905b2120ea02e0707b967fd", } \ No newline at end of file diff --git a/edenai_apis/apis/replicate/info.json b/edenai_apis/apis/replicate/info.json index 44747f6a..c9ad1a2f 100644 --- a/edenai_apis/apis/replicate/info.json +++ b/edenai_apis/apis/replicate/info.json @@ -6,7 +6,13 @@ "256x256", "512x512", "1024x1024" - ] + ], + "models": [ + "anime-style", + "vintedois-diffusion", + "classic" + ], + "default_model" : "classic" }, "version" : "v1" } diff --git a/edenai_apis/apis/replicate/outputs/image/generation_output.json b/edenai_apis/apis/replicate/outputs/image/generation_output.json index 6a3864f2..6176a2d2 100644 --- a/edenai_apis/apis/replicate/outputs/image/generation_output.json +++ b/edenai_apis/apis/replicate/outputs/image/generation_output.json @@ -1,32 +1,34 @@ { "original_response": { - "id": "ovcvsidbo2fxbwqsdypw4ewwbq", - "version": "c0259010b93e7a4102a4ba946d70e06d7d0c7dc007201af443cfc8f943ab1d3c", + "id": "qscb5xzbkbolvhpkgmgux3e2iy", + "version": "28cea91bdfced0e2dc7fda466cc0a46501c0edc84905b2120ea02e0707b967fd", "input": { "height": 512, "prompt": "A huge red ballon flying outside the city.", "width": 512 }, - "logs": " 0%| | 0/20 [00:00 ResponseType[GenerationDataClass]: url = f"{self.base_url}/predictions" size = resolution.split("x") + version = get_model_id_image[model] + payload = { "input": { "prompt": text, "width": int(size[0]), "height": int(size[1]), }, - "version": "c0259010b93e7a4102a4ba946d70e06d7d0c7dc007201af443cfc8f943ab1d3c", + "version": version, } response_dict = ReplicateApi.__get_response(self, url, payload) image_url = response_dict.get("output") + if isinstance(image_url, list): + image_url = image_url[0] image_bytes = base64.b64encode(requests.get(image_url).content) return ResponseType[GenerationDataClass](