Skip to content

Commit

Permalink
fix import, check working, remove redundant
Browse files Browse the repository at this point in the history
  • Loading branch information
bluecoconut committed May 12, 2023
1 parent a80f99d commit 902a1c7
Showing 1 changed file with 3 additions and 11 deletions.
14 changes: 3 additions & 11 deletions lambdaprompt/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,16 @@ def set_backend(backend_name):
def get_backend(method):
    """Return the backend registered under *method*, resolving lazily.

    Resolution order:
      1. a backend already present in the module-level ``backends`` registry,
      2. the backend named by the ``LAMBDAPROMPT_BACKEND`` environment variable,
      3. the ``'OpenAI'`` default (with a console notice).
    """
    if method not in backends:
        # First chance: let the environment pick the backend.
        env_backend = os.environ.get("LAMBDAPROMPT_BACKEND")
        if env_backend:
            set_backend(env_backend)
        # If the env-selected backend (or none) still doesn't cover this
        # method, fall back to the OpenAI default.
        if method not in backends:
            print(f"No backend set for [{method}], setting to default of OpenAI")
            set_backend('OpenAI')
    return backends[method]



class Backend:
class Parameters(BaseModel):
class Config:
def body(self, prompt, **kwargs):
    """Build the JSON payload for a completion request.

    Merges the text *prompt* with the generation parameters produced by
    ``self.parse_param(**kwargs)``, then strips every entry whose value is
    None so that unset options are omitted from the request entirely
    (rather than sent as explicit nulls).
    """
    data = {
        "prompt": prompt,
        **self.parse_param(**kwargs)
    }
    # Drop unset (None) parameters — this generalizes the old special-case
    # handling of 'stop' to every optional field in the payload.
    return {k: v for k, v in data.items() if v is not None}

def parse_response(self, answer):
if "error" in answer:
def body(self, messages, **kwargs):
    """Build the JSON payload for a chat-completion request.

    Merges the chat *messages* list with the generation parameters
    produced by ``self.parse_param(**kwargs)``, then strips every entry
    whose value is None so unset options are omitted from the request
    instead of being sent as explicit nulls.
    """
    data = {
        "messages": messages,
        **self.parse_param(**kwargs)
    }
    # Drop unset (None) parameters — mirrors the completion endpoint's
    # payload construction so both backends stay consistent.
    return {k: v for k, v in data.items() if v is not None}

def parse_response(self, answer):
if "error" in answer:
Expand Down Expand Up @@ -189,6 +180,7 @@ def preprocess(self, prompt):
return prompt

async def __call__(self, prompt, **kwargs):
import torch
s = self.preprocess(prompt)
input_ids = self.tokenizer(s, return_tensors="pt").input_ids
input_ids = input_ids.to(self.model.device)
Expand Down

0 comments on commit 902a1c7

Please sign in to comment.