less imports
csinva committed Jun 7, 2023
1 parent 598437e commit 56575fa
Showing 1 changed file with 4 additions and 8 deletions.
imodelsx/iprompt/llm.py: 12 changes (4 additions & 8 deletions)
@@ -7,20 +7,15 @@
 )
 import transformers
 from transformers import AutoConfig, AutoModel, AutoTokenizer, AutoModelForCausalLM
-from langchain.cache import InMemoryCache
 import re
 from transformers import LlamaForCausalLM, LlamaTokenizer
 from typing import Any, Dict, List, Mapping, Optional
 import numpy as np
-import openai
 import os.path
 from os.path import join, dirname
 import os
 import pickle as pkl
-import langchain
 from scipy.special import softmax
-import openai
-from langchain.llms.base import LLM
 import hashlib
 import torch
 import time
@@ -47,7 +42,8 @@ def get_llm(checkpoint, seed=1, role: str = None):
             checkpoint, seed=seed
         ) # warning: this sets torch.manual_seed(seed)

-def llm_openai(checkpoint="text-davinci-003", seed=1) -> LLM:
+def llm_openai(checkpoint="text-davinci-003", seed=1):
+    import openai
     class LLM_OpenAI:
         def __init__(self, checkpoint, seed):
             self.cache_dir = join(
@@ -86,7 +82,7 @@ def __call__(self, prompt: str, max_new_tokens=250, do_sample=True, stop=None):
     return LLM_OpenAI(checkpoint, seed)


-def llm_openai_chat(checkpoint="gpt-3.5-turbo", seed=1, role=None) -> LLM:
+def llm_openai_chat(checkpoint="gpt-3.5-turbo", seed=1, role=None):
     class LLM_Chat:
         """Chat models take a different format: https://platform.openai.com/docs/guides/chat/introduction"""

@@ -167,7 +163,7 @@ def __call__(
     return LLM_Chat(checkpoint, seed, role)


-def llm_hf(checkpoint="google/flan-t5-xl", seed=1) -> LLM:
+def llm_hf(checkpoint="google/flan-t5-xl", seed=1):
     LLAMA_DIR = "/home/chansingh/llama"

     class LLM_HF:
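
The pattern behind this change: rather than importing openai (and, previously, langchain) at module level, the dependency is imported lazily inside the factory function that actually needs it, and the -> LLM return annotations tied to langchain.llms.base.LLM are dropped. Below is a minimal sketch of that deferred-import idea, not the repository's implementation; the name llm_openai_sketch is hypothetical, and the __call__ body assumes the pre-1.0 openai client that this 2023 code targets.

# Sketch of the lazy-import pattern (illustrative, not the repository's code).
# Importing openai inside the factory means importing imodelsx.iprompt.llm
# succeeds even when openai is not installed.
def llm_openai_sketch(checkpoint: str = "text-davinci-003", seed: int = 1):
    import openai  # deferred import: only needed when an OpenAI-backed LLM is built

    class LLM_OpenAI:
        def __init__(self, checkpoint: str, seed: int):
            self.checkpoint = checkpoint
            self.seed = seed

        def __call__(self, prompt: str, max_new_tokens: int = 250) -> str:
            # Assumes the pre-1.0 openai client; the real LLM_OpenAI also
            # caches responses under its cache_dir.
            response = openai.Completion.create(
                model=self.checkpoint,
                prompt=prompt,
                max_tokens=max_new_tokens,
            )
            return response["choices"][0]["text"]

    return LLM_OpenAI(checkpoint, seed)

# Hypothetical usage: llm = llm_openai_sketch(); print(llm("Q: 2 + 2 ="))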