Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use SAVE_DIR and --save_dir instead of SAVE_PATH and --save_path. #71

Merged
merged 1 commit into from
Apr 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def main(

llama_type: bool = None,
debug: bool = False,
save_path: str = None,
save_dir: str = None,
share: bool = True,
local_files_only: bool = False,
resume_download: bool = True,
Expand Down Expand Up @@ -112,7 +112,7 @@ def main(
if is_hf:
# must override share if in spaces
share = False
save_path = os.getenv('SAVE_PATH', save_path)
save_dir = os.getenv('SAVE_DIR', save_dir)

# get defaults
model_lower = base_model.lower()
Expand Down Expand Up @@ -180,7 +180,7 @@ def main(
if not eval_sharegpt_as_output:
model, tokenizer, device = get_model(**locals())
model_state = [model, tokenizer, device, base_model]
fun = partial(evaluate, model_state, debug=debug, chat=chat, save_path=save_path)
fun = partial(evaluate, model_state, debug=debug, chat=chat, save_dir=save_dir)
else:
assert eval_sharegpt_prompts_only > 0

Expand Down Expand Up @@ -814,7 +814,7 @@ def _postprocess_chat_messages(self, chat_message: str):
file_output = gr.File()

# Get flagged data
zip_data1 = functools.partial(zip_data, root_dirs=['flagged_data_points', kwargs['save_path']])
zip_data1 = functools.partial(zip_data, root_dirs=['flagged_data_points', kwargs['save_dir']])
zip_btn.click(zip_data1, inputs=None, outputs=file_output)

def check_admin_pass(x):
Expand Down Expand Up @@ -1141,7 +1141,7 @@ def get_system_info():


input_args_list = ['model_state']
inputs_kwargs_list = ['debug', 'chat', 'save_path', 'hard_stop_list', 'sanitize_bot_response', 'model_state0']
inputs_kwargs_list = ['debug', 'chat', 'save_dir', 'hard_stop_list', 'sanitize_bot_response', 'model_state0']


def get_inputs_list(inputs_dict, model_lower):
Expand Down Expand Up @@ -1204,7 +1204,7 @@ def evaluate(
src_lang=None,
tgt_lang=None,
debug=False,
save_path=None,
save_dir=None,
chat=False,
hard_stop_list=None,
sanitize_bot_response=True,
Expand Down Expand Up @@ -1421,16 +1421,16 @@ def generate(callback=None, **kwargs):
raise StopIteration
yield prompter.get_response(decoded_output, prompt=inputs_decoded,
sanitize_bot_response=sanitize_bot_response)
if save_path and decoded_output:
save_generate_output(output=decoded_output, base_model=base_model, json_file_path=save_path)
if save_dir and decoded_output:
save_generate_output(output=decoded_output, base_model=base_model, save_dir=save_dir)
else:
outputs = model.generate(**gen_kwargs)
outputs = [decoder(s) for s in outputs.sequences]
yield prompter.get_response(outputs, prompt=inputs_decoded,
sanitize_bot_response=sanitize_bot_response)
if save_path and outputs and len(outputs) >= 1:
if save_dir and outputs and len(outputs) >= 1:
decoded_output = prompt + outputs[0]
save_generate_output(output=decoded_output, base_model=base_model, json_file_path=save_path)
save_generate_output(output=decoded_output, base_model=base_model, save_dir=save_dir)


def get_generate_params(model_lower, chat,
Expand Down
17 changes: 9 additions & 8 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,30 +118,31 @@ def _zip_data(root_dirs=None, zip_path='data.zip', base_dir='./'):
return "data.zip"


def save_generate_output(output=None, base_model=None, json_file_path=None):
def save_generate_output(output=None, base_model=None, save_dir=None):
try:
return _save_generate_output(output=output, base_model=base_model, json_file_path=json_file_path)
return _save_generate_output(output=output, base_model=base_model, save_dir=save_dir)
except Exception as e:
traceback.print_exc()
print('Exception in saving: %s' % str(e))


def _save_generate_output(output=None, base_model=None, json_file_path=None):
def _save_generate_output(output=None, base_model=None, save_dir=None):
"""
Save conversation to .json, row by row.
save_dir is the directory in which the JSON history file is written; it is created if it does not exist.
Appends if file exists
"""
assert isinstance(json_file_path, str), "must provide save_path"
if os.path.dirname(json_file_path):
os.makedirs(os.path.dirname(json_file_path), exist_ok=True)
assert save_dir, "save_dir must be provided"
if os.path.exists(save_dir) and not os.path.isdir(save_dir):
raise RuntimeError("save_dir already exists and is not a directory!")
os.makedirs(save_dir, exist_ok=True)
import json
if output[-10:] == '\n\n<human>:':
# remove trailing <human>:
output = output[:-10]
with filelock.FileLock("save_path.lock"):
with filelock.FileLock("save_dir.lock"):
# lock logging in case have concurrency
with open(json_file_path, "a") as f:
with open(os.path.join(save_dir, "history.json"), "a") as f:
# just add [ at start, and ] at end, and have proper JSON dataset
f.write(
" " + json.dumps(
Expand Down