Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Add import control check to avoid leaking optional langchain stuff into generate/gradio. Add test #146

Merged
merged 1 commit into from
May 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,10 @@ def main(
dbs = {k: v for k, v in dbs.items() if v is not None}
else:
dbs = {}
# import control
if os.environ.get("TEST_LANGCHAIN_IMPORT"):
assert 'gpt_langchain' not in sys.modules, "Dev bug, import of langchain when should not have"
assert 'langchain' not in sys.modules, "Dev bug, import of langchain when should not have"

if not gradio:
if eval_sharegpt_prompts_only > 0:
Expand Down
11 changes: 10 additions & 1 deletion gradio_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,11 @@ def upload_file(files, x):
with upload_row:
fileup_output = gr.File()
with gr.Row():
from gpt_langchain import file_types
# import control
if kwargs['langchain_mode'] != 'Disabled':
from gpt_langchain import file_types
else:
file_types = []
upload_button = gr.UploadButton("Upload %s" % file_types,
file_types=file_types,
file_count="multiple",
Expand Down Expand Up @@ -1090,6 +1094,11 @@ def get_system_info():
scheduler.add_job(func=ping, trigger="interval", seconds=60)
scheduler.start()

# import control
if kwargs['langchain_mode'] == 'Disabled' and os.environ.get("TEST_LANGCHAIN_IMPORT"):
assert 'gpt_langchain' not in sys.modules, "Dev bug, import of langchain when should not have"
assert 'langchain' not in sys.modules, "Dev bug, import of langchain when should not have"

demo.launch(share=kwargs['share'], server_name="0.0.0.0", show_error=True,
favicon_path=favicon_path, prevent_thread_lock=True,
auth=kwargs['auth'])
Expand Down
17 changes: 16 additions & 1 deletion tests/test_client_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ def test_client1():


def run_client1():
import os, sys
os.environ['TEST_LANGCHAIN_IMPORT'] = "1"
sys.modules.pop('gpt_langchain', None)
sys.modules.pop('langchain', None)

from generate import main
main(base_model='h2oai/h2ogpt-oig-oasst1-512-6.9b', prompt_type='human_bot', chat=False,
stream_output=False, gradio=True, num_beams=1, block_gradio_exit=False)
Expand All @@ -27,6 +32,11 @@ def test_client_chat_nostream():


def run_client_chat(prompt='Who are you?', stream_output=False, max_new_tokens=256):
import os, sys
os.environ['TEST_LANGCHAIN_IMPORT'] = "1"
sys.modules.pop('gpt_langchain', None)
sys.modules.pop('langchain', None)

from generate import main
main(base_model='h2oai/h2ogpt-oig-oasst1-512-6.9b', prompt_type='human_bot', chat=True,
stream_output=stream_output, gradio=True, num_beams=1, block_gradio_exit=False,
Expand Down Expand Up @@ -60,6 +70,11 @@ def test_client_long():


def run_client_long():
import os, sys
os.environ['TEST_LANGCHAIN_IMPORT'] = "1"
sys.modules.pop('gpt_langchain', None)
sys.modules.pop('langchain', None)

from generate import main
main(base_model='mosaicml/mpt-7b-storywriter', prompt_type='plain', chat=False,
stream_output=False, gradio=True, num_beams=1, block_gradio_exit=False)
Expand All @@ -68,5 +83,5 @@ def run_client_long():
prompt = f.readlines()

from client_test import run_client_nochat
res_dict = run_client_nochat(prompt=prompt, prompt_type='plain')
res_dict = run_client_nochat(prompt=prompt, prompt_type='plain', max_new_tokens=86000)
print(res_dict['response'])
5 changes: 5 additions & 0 deletions tests/test_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ def test_eval1_cpu():


def run_eval1(cpu=False):
import os, sys
os.environ['TEST_LANGCHAIN_IMPORT'] = "1"
sys.modules.pop('gpt_langchain', None)
sys.modules.pop('langchain', None)

if cpu:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''
Expand Down