Skip to content

Commit

Permalink
Merge pull request #1542 from h2oai/githash
Browse files Browse the repository at this point in the history
Check and version
  • Loading branch information
pseudotensor authored Apr 10, 2024
2 parents ed4d4d8 + 4484c26 commit 0f54a13
Show file tree
Hide file tree
Showing 10 changed files with 80 additions and 15 deletions.
Empty file added .gitattributes
Empty file.
12 changes: 7 additions & 5 deletions gradio_utils/grclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,13 +338,14 @@ def get_endpoints(client, verbose=False):
print("duration endpoints: %s" % (time.time() - t0), flush=True)

def get_server_hash(self):
t0 = time.time()
if self.config is None:
self.setup()
"""
Get server hash using super without any refresh action triggered
Returns: git hash of gradio server
"""
t0 = time.time()
if self.config is None:
self.setup()
t1 = time.time()
ret = "GET_GITHASH"
try:
if self.check_hash:
Expand All @@ -353,7 +354,7 @@ def get_server_hash(self):
finally:
if self.verbose:
print(
"duration server_hash: %s %s" % (time.time() - t0, ret), flush=True
"duration server_hash: %s full time: %s system_hash time: %s" % (ret, time.time() - t0, time.time() - t1), flush=True
)

def refresh_client_if_should(self):
Expand Down Expand Up @@ -394,9 +395,10 @@ def refresh_client(self):
kwargs.pop("check_model_name", None)
ntrials = 3
client = None
for trial in range(0, ntrials + 1):
for trial in range(0, ntrials):
try:
client = Client(*self.args, **kwargs)
break
except ValueError as e:
if trial >= ntrials:
raise
Expand Down
6 changes: 6 additions & 0 deletions openai_server/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import uuid
from collections import deque

import filelock

from log import logger
from openai_server.backend_utils import convert_messages_to_structure

Expand Down Expand Up @@ -88,6 +90,10 @@ def get_client(user=None):
client = get_gradio_client(user=user)
elif hasattr(gradio_client, 'clone'):
client = gradio_client.clone()
if client.get_server_hash() != gradio_client.server_hash:
os.makedirs('locks', exist_ok=True)
with filelock.FileLock(os.path.join('locks', 'openai_gradio_client.lock')):
gradio_client.refresh_client()
else:
print(
"re-get to ensure concurrency ok, slower if API is large, for speed ensure gradio_utils/grclient.py exists.")
Expand Down
5 changes: 5 additions & 0 deletions src/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from datetime import datetime
from random import randint

import filelock
import httpx
import requests
from requests import ConnectTimeout, JSONDecodeError
Expand Down Expand Up @@ -4923,6 +4924,10 @@ def evaluate(
break
else:
raise RuntimeError("Failed to get client: %s" % inference_server)
if isinstance(model, GradioClient) and not regenerate_gradio_clients and gr_client is not None:
if gr_client.server_hash != model.server_hash:
with filelock.FileLock(os.path.join('locks', 'gradio_client.lock')):
model.refresh_client()
else:
raise RuntimeError("No such inference_server %s" % inference_server)

Expand Down
11 changes: 11 additions & 0 deletions src/gpt_langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -832,6 +832,11 @@ def _call(
if self.verbose:
print("end _call", flush=True)
self.use_gradio_return(res_dict, prompt)

# ensure parent client is updated if remote server changed
if client.server_hash != self.client.server_hash:
self.client.refresh_client()

return ret
else:
text_callback = None
Expand Down Expand Up @@ -903,6 +908,12 @@ def _call(
if self.verbose:
print("end _call", flush=True)
self.use_gradio_return(res_dict, prompt)

# ensure parent client is updated if remote server changed
if client.server_hash != self.client.server_hash:
with filelock.FileLock(os.path.join('locks', 'gradio_client.lock')):
self.client.refresh_client()

return ret

def use_gradio_return(self, res_dict, prompt_raw):
Expand Down
19 changes: 10 additions & 9 deletions src/gradio_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2395,6 +2395,16 @@ def add_role_func(name, file, mic, roles1, use_mic):
system_btn3 = gr.Button(value='Get Hash', visible=not is_public, size='sm')
system_text3 = gr.Textbox(label='Hash', interactive=False,
visible=not is_public, show_copy_button=True)

def get_hash():
return kwargs['git_hash']

system_event = system_btn3.click(get_hash,
outputs=system_text3,
api_name='system_hash' if allow_api else None,
**noqueue_kwargs_curl,
)

system_btn4 = gr.Button(value='Get Model Names', visible=not is_public, size='sm')
system_text4 = gr.Textbox(label='Model Names', interactive=False,
visible=not is_public, show_copy_button=True)
Expand Down Expand Up @@ -6325,15 +6335,6 @@ def get_system_info_dict(system_input1, **kwargs1):
**noqueue_kwargs, # queue to avoid spam
)

def get_hash():
return kwargs['git_hash']

system_event = system_btn3.click(get_hash,
outputs=system_text3,
api_name='system_hash' if allow_api else None,
**noqueue_kwargs,
)

def get_model_names():
key_list = ['base_model', 'prompt_type', 'prompt_dict'] + list(kwargs['other_model_state_defaults'].keys())
# don't want to expose backend inference server IP etc.
Expand Down
16 changes: 16 additions & 0 deletions src/pre-commit
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#!/bin/sh

# The path to the utils.py file relative to the root of the repository
FILE_PATH="src/version.py"

# Get the current git commit hash
GITHASH=$(git rev-parse HEAD)

# Update the __version__ variable in utils.py
# This uses a Perl one-liner to find the __version__ line and replace it with the current GITHASH
perl -pi -e "s/__version__ = \".*\"/__version__ = \"$GITHASH\"/" $FILE_PATH

# Add the modified utils.py file to the commit
git add $FILE_PATH

# End of script
16 changes: 16 additions & 0 deletions src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,17 @@ def get_githash():
githash = f.read()
except:
githash = "GET_GITHASH"

if githash == "GET_GITHASH":
try:
from src.version import __version__
githash = __version__
except:
pass

if os.getenv('HARD_ASSERTS'):
assert is_full_git_hash(githash)

return githash


Expand Down Expand Up @@ -2116,6 +2127,11 @@ def is_uuid4(string):
return bool(pattern.match(string))


def is_full_git_hash(s):
# This regex checks for exactly 40 hexadecimal characters.
return bool(re.fullmatch(r'[0-9a-f]{40}', s))


def get_show_username(username1):
if split_google in username1:
show_username = split_google.join(username1.split(split_google)[0:1])
Expand Down
1 change: 1 addition & 0 deletions src/version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
__version__ = "d2ef601d9f2ea68b8b321870eb135d9eeb0e3f58"
9 changes: 8 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from src.enums import invalid_json_str
from src.gen import apply_chat_template
from src.utils import get_list_or_str, read_popen_pipes, get_token_count, reverse_ucurve_list, undo_reverse_ucurve_list, \
is_uuid4, has_starting_code_block, extract_code_block_content, looks_like_json, get_json
is_uuid4, has_starting_code_block, extract_code_block_content, looks_like_json, get_json, is_full_git_hash
from tests.utils import wrap_test_forked
import subprocess as sp

Expand Down Expand Up @@ -205,6 +205,13 @@ def test_is_uuid4():
assert [is_uuid4(s) for s in test_strings] == [True, False, False, False]


def test_is_git_hash():
# Example usage:
hashes = ["1a3b5c7d9e1a3b5c7d9e1a3b5c7d9e1a3b5c7d9e", "1G3b5c7d9e1a3b5c7d9e1a3b5c7d9e1a3b5c7d9e", "1a3b5c7d"]

assert [is_full_git_hash(h) for h in hashes] == [True, False, False]


def test_chat_template():
instruction = "Who are you?"
system_prompt = "Be kind"
Expand Down

0 comments on commit 0f54a13

Please sign in to comment.