Release 2.0.1 #36

Merged 5 commits on Aug 2, 2024
5 changes: 5 additions & 0 deletions server/config.py
@@ -47,6 +47,9 @@
#--
#-- 23/01/2024 Lyaaaaa
#-- - Removed TOKENIZERS_PATH.
#--
#-- 08/05/2024 Lyaaaaa
#-- - Added TORCH_DTYPE_SAFETY.
#---------------------------------------------------------------------------

import logging
@@ -79,3 +82,5 @@
DEVICE_MAP = None # None/see documentation
TORCH_DTYPE = None # "Auto"/None/torch.dtype/See torch_dtype.py for more info.

# Safeguards
TORCH_DTYPE_SAFETY = True # True/False. If CUDA isn't available, torch_dtype is forced to float32 to avoid errors. See issue https://github.com/LyaaaaaGames/AIdventure_Server/issues/31
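
A minimal standalone sketch of what this safeguard amounts to (illustration only, not the PR's code; the real check lands in model.py's _load_model further down):

import torch

TORCH_DTYPE_SAFETY = True
torch_dtype = torch.float16  # whatever dtype the user configured

if TORCH_DTYPE_SAFETY and not torch.cuda.is_available():
    # Many half-precision kernels are unavailable on CPU, so generating in
    # float16 without a GPU fails; fall back to full precision instead.
    torch_dtype = torch.float32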
18 changes: 15 additions & 3 deletions server/generator.py
@@ -40,6 +40,10 @@
#-- - 31/01/2024 Lyaaaaa
#-- - generate_text no longer receives memory and context as parameters.
#-- They are embedded in the prompt parameter by the client.
#--
#-- - 07/05/2024 Lyaaaaa
#-- - Updated generate_text so it can censor generation. Words passed
#-- in the p_banned_words parameter will no longer be generated.
#------------------------------------------------------------------------------

from model import Model
@@ -48,16 +52,24 @@
import logger

class Generator(Model):

#------------------------------------------------------------------------------
#-- generate_text
#------------------------------------------------------------------------------
def generate_text(self,
-                  p_prompt     = None,
-                  p_parameters = None):
+                  p_prompt       = None,
+                  p_parameters   = None,
+                  p_banned_words = []):

    model_input = self._Tokenizer(p_prompt, return_tensors = "pt")

    if p_banned_words:
        banned_words_ids = self._Tokenizer(
            p_banned_words,
            add_special_tokens=False
        ).input_ids

        p_parameters["bad_words_ids"] = banned_words_ids

    if self.is_cuda_available:
        logger.log.info("Loading inputs to GPU")
        model_input.to("cuda")
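For context, a self-contained sketch of the Hugging Face mechanism this hunk relies on; the model name, prompt, and banned words are examples only:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2", add_prefix_space=True)
model = AutoModelForCausalLM.from_pretrained("gpt2")

# One id sequence per banned word; generate() refuses to emit these sequences.
bad_words_ids = tokenizer(["dragon", "sword"],
                          add_special_tokens=False).input_ids

inputs = tokenizer("The knight reached for his", return_tensors="pt")
output = model.generate(**inputs,
                        bad_words_ids=bad_words_ids,
                        max_new_tokens=20)
print(tokenizer.decode(output[0], skip_special_tokens=True))

Note that p_banned_words defaults to a mutable list ([]); the guard only tokenizes when the list is non-empty, so the shared default is read but never mutated here.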
21 changes: 20 additions & 1 deletion server/model.py
@@ -202,6 +202,17 @@
#-- - p_model_path is now the second parameter of __init__. p_parameters the third.
#-- - Added a log message to display the model's name and its path.
#-- - Added a log message to display if cuda is supported.
#--
#-- - 07/05/2024 Lyaaaaa
#-- - Updated _load_tokens to set add_prefix_space to True. It is needed
#-- to use the bad_words_ids generation parameter.
#--
#-- - 08/05/2024 Lyaaaaa
#-- - Updated _load_model to force (when config.TORCH_DTYPE_SAFETY is True)
#-- torch_dtype to float32 if CUDA isn't available, because otherwise
#-- generation fails with an error.
#-- See https://github.com/LyaaaaaGames/AIdventure_Server/issues/31
#------------------------------------------------------------------------------

from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
@@ -296,7 +307,10 @@ def _load(self):
#------------------------------------------------------------------------------
def _load_tokens(self):
    try:
-        self._Tokenizer = AutoTokenizer.from_pretrained(self._model_path)
+        self._Tokenizer = AutoTokenizer.from_pretrained(
+            self._model_path,
+            add_prefix_space=True
+        )
    except Exception as e:
        logger.log.error("Error loading tokens in " + self._model_path)
        logger.log.error(e)
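
Why add_prefix_space matters, as a quick sketch ("gpt2" is just an example): a BPE tokenizer encodes a word at the start of a text differently from the same word after a space, and bad_words_ids must match the mid-sentence form to block it.

from transformers import AutoTokenizer

tok_default = AutoTokenizer.from_pretrained("gpt2")
tok_prefix = AutoTokenizer.from_pretrained("gpt2", add_prefix_space=True)

# Different ids for the same surface word: "dragon" vs " dragon".
print(tok_default("dragon", add_special_tokens=False).input_ids)
print(tok_prefix("dragon", add_special_tokens=False).input_ids)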
@@ -320,6 +334,11 @@ def _load_model(self):
        logger.log.debug("Model settings:")
        logger.log.debug(args)

        if not self.is_cuda_available and config.TORCH_DTYPE_SAFETY:
            logger.log.warn("Cuda isn't available.")
            logger.log.warn("Setting torch_dtype to float32 to avoid errors.")
            args["torch_dtype"] = Torch_Dtypes.dtypes.value[Torch_Dtypes.FLOAT_32.value]

        self._Model = AutoModelForCausalLM.from_pretrained(self._model_path,
                                                           **args)
    except Exception as e:
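The Torch_Dtypes.dtypes.value[Torch_Dtypes.FLOAT_32.value] indexing above suggests an enum wrapping a list of torch dtypes. A guess at its shape, since torch_dtypes.py is not part of this diff:

from enum import Enum
import torch

class Torch_Dtypes(Enum):
    FLOAT_32 = 0
    FLOAT_16 = 1
    BFLOAT_16 = 2
    # A list as an enum value, indexed by the members above (assumed layout).
    dtypes = [torch.float32, torch.float16, torch.bfloat16]

assert Torch_Dtypes.dtypes.value[Torch_Dtypes.FLOAT_32.value] is torch.float32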
11 changes: 8 additions & 3 deletions server/server.py
@@ -118,6 +118,10 @@
#-- - 31/01/2024 Lyaaaaa
#-- - generate_text no longer receives memory and context as parameters.
#-- They are embedded in the prompt parameter by the client.
#--
#-- - 07/05/2024 Lyaaaaa
#-- - Updated handle_request and the generation case to receive a
#-- banned_words parameter and pass it to generator.generate_text.
#------------------------------------------------------------------------------

import asyncio
@@ -177,10 +181,11 @@ def handle_request(p_websocket, p_data : dict):
    request = p_data['request']

    if request == Request.TEXT_GENERATION.value:
-       prompt     = p_data['prompt']
-       parameters = p_data['parameters']
+       prompt       = p_data['prompt']
+       parameters   = p_data['parameters']
+       banned_words = p_data['banned_words']

-       generated_text = generator.generate_text(prompt, parameters)
+       generated_text = generator.generate_text(prompt, parameters, banned_words)

        p_data["generated_text"] = generated_text

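Taken together, a TEXT_GENERATION request now carries three payload fields. A sketch of what a client might send (values are illustrative; the exact Request enum value is assumed):

payload = {
    "request": 0,  # Request.TEXT_GENERATION.value; exact value assumed
    "prompt": "The tavern door creaks open and",
    "parameters": {"max_new_tokens": 40},
    "banned_words": ["dragon"],
}

Since handle_request reads p_data['banned_words'] unconditionally, clients must always include the key, even as an empty list, or the lookup raises KeyError.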