Skip to content

Commit

Permalink
[feat] add model to replicate
Browse files Browse the repository at this point in the history
  • Loading branch information
floflokie committed Oct 30, 2023
1 parent 81d6ca9 commit 8554dc6
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 16 deletions.
5 changes: 5 additions & 0 deletions edenai_apis/apis/replicate/config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
get_model_id = {
"llama-2-70b" : "14ce4448d5e7e9ed0c37745ac46eca157aab09061f0c179ac2b323b5de56552b",
"llama-2-70b-chat" :"58d078176e02c219e11eb4da5a02a7830a283b14cf8f94537af893ccff5ee781"
}
get_model_id_image = {
"anime-style" : "09a5805203f4c12da649ec1923bb7729517ca25fcac790e640eaa9ed66573b65",
"classic" : "c0259010b93e7a4102a4ba946d70e06d7d0c7dc007201af443cfc8f943ab1d3c",
"vintedois-diffusion" : "28cea91bdfced0e2dc7fda466cc0a46501c0edc84905b2120ea02e0707b967fd",
}
8 changes: 7 additions & 1 deletion edenai_apis/apis/replicate/info.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,13 @@
"256x256",
"512x512",
"1024x1024"
]
],
"models": [
"anime-style",
"vintedois-diffusion",
"classic"
],
"default_model" : "classic"
},
"version" : "v1"
}
Expand Down
26 changes: 14 additions & 12 deletions edenai_apis/apis/replicate/outputs/image/generation_output.json

Large diffs are not rendered by default.

9 changes: 6 additions & 3 deletions edenai_apis/apis/replicate/replicate_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from edenai_apis.utils.exception import ProviderException
from edenai_apis.utils.types import ResponseType
import base64
from .config import get_model_id
from .config import get_model_id, get_model_id_image


class ReplicateApi(ProviderInterface, ImageInterface, TextInterface):
Expand Down Expand Up @@ -96,7 +96,6 @@ def __get_response(
status = response_dict["status"]
while status != "succeeded":
response = requests.get(url_get_response, headers=self.headers)

try:
response_dict = response.json()
except requests.JSONDecodeError:
Expand All @@ -120,17 +119,21 @@ def image__generation(
) -> ResponseType[GenerationDataClass]:
url = f"{self.base_url}/predictions"
size = resolution.split("x")
version = get_model_id_image[model]

payload = {
"input": {
"prompt": text,
"width": int(size[0]),
"height": int(size[1]),
},
"version": "c0259010b93e7a4102a4ba946d70e06d7d0c7dc007201af443cfc8f943ab1d3c",
"version": version,
}

response_dict = ReplicateApi.__get_response(self, url, payload)
image_url = response_dict.get("output")
if isinstance(image_url, list):
image_url = image_url[0]
image_bytes = base64.b64encode(requests.get(image_url).content)

return ResponseType[GenerationDataClass](
Expand Down

0 comments on commit 8554dc6

Please sign in to comment.