diff --git a/.github/workflows/workflow.yml b/.github/workflows/workflow.yml index 9221f8d..23f7c08 100644 --- a/.github/workflows/workflow.yml +++ b/.github/workflows/workflow.yml @@ -1,4 +1,4 @@ -name: Checking +name: Development on: push jobs: @@ -20,19 +20,31 @@ jobs: - name: Install Hatch run: python -m pip install --upgrade hatch + - name: Get Hatch Dependency Hash + run: echo "HATCH_DEP_HASH=$(hatch dep hash)" >> $GITHUB_ENV + + - name: Cache Hatch environment + uses: actions/cache@v4.1.2 + with: + path: | + ~/.cache/hatch + ~/.local/share/hatch + key: ${{ runner.os }}-hatch-${{ env.HATCH_DEP_HASH }} + - name: Generate Requirements run: python -m hatch dep show requirements > requirements.txt + # Upload requirements to have them - name: Upload Requirements uses: actions/upload-artifact@v4.4.3 with: name: requirements path: requirements.txt - pylint: name: PyLint runs-on: ubuntu-latest + needs: setup-requirements timeout-minutes: 10 steps: @@ -45,17 +57,28 @@ jobs: with: python-version: "3.12.0" - - name: Install pipenv - run: | - python -m pip install --upgrade pipenv wheel + - name: Install Hatch + if: steps.cache-hatch.outputs.cache-hit != 'true' + run: python -m pip install --upgrade hatch - - name: Install dependencies - run: | - pipenv install --deploy --dev + - name: Cache Hatch environment + uses: actions/cache@v4.1.2 + with: + path: | + ~/.cache/hatch + ~/.local/share/hatch + key: ${{ runner.os }}-hatch-${{ env.HATCH_DEP_HASH }} + # Don't fail just output, since we want the score to be above 9 not 10.0 + # Don’t let the Perfect be the Enemy of the Good - name: Pylint on esbmc_ai + run: hatch run pylint esbmc_ai || true + + # Check if pass, the test command only takes integers so truncate decimals + - name: Check If Pass (90%) run: | - pipenv run pylint esbmc_ai + SCORE="$(sed -n '$s/[^0-9]*\([0-9.]*\).*/\1/p' <<< "$(hatch run pylint esbmc_ai)")" + test "${SCORE%.*}" -ge 9 test: name: PyTest @@ -73,27 +96,109 @@ jobs: with: python-version: "3.12.0" - 
- name: Download Requirements - uses: actions/download-artifact@v4.1.8 + - name: Install Hatch + if: steps.cache-hatch.outputs.cache-hit != 'true' + run: python -m pip install --upgrade hatch + + - name: Cache Hatch environment + uses: actions/cache@v4.1.2 with: - name: requirements - path: . + path: | + ~/.cache/hatch + ~/.local/share/hatch + key: ${{ runner.os }}-hatch-${{ env.HATCH_DEP_HASH }} + + - name: Run test suite + run: hatch run pytest + + # increment_version: + # name: Increment Version + # runs-on: ubuntu-latest + # needs: setup-requirements + # timeout-minutes: 10 + # # Configure permissions for git push + # permissions: + # contents: write + + # steps: + # - name: Check out repository code + # uses: actions/checkout@v4.2.2 + # with: + # persist-credentials: false # otherwise, the token used is the GITHUB_TOKEN, instead of your personal access token. + # fetch-depth: 0 # otherwise, there would be errors pushing refs to the destination repository. + + # # Setup Python (faster than using Python container) + # - name: Setup Python + # uses: actions/setup-python@v5.3.0 + # with: + # python-version: "3.12.0" + + # - name: Install Hatch + # if: steps.cache-hatch.outputs.cache-hit != 'true' + # run: python -m pip install --upgrade hatch + + # - name: Cache Hatch environment + # uses: actions/cache@v4.1.2 + # with: + # path: | + # ~/.cache/hatch + # ~/.local/share/hatch + # key: ${{ runner.os }}-hatch-${{ env.HATCH_DEP_HASH }} + + # - name: Increment Version + # run: hatch version dev + + # - name: Configure Git + # run: | + # git config --global user.email "github-actions[bot]@users.noreply.github.com" + # git config --global user.name "github-actions[bot]" + + # # Add and commit without changing message + # - name: Git Add + # run: | + # git add esbmc_ai/__about__.py + # git commit -m "Increment version" + + # - name: GitHub Push + # if: github.ref != 'refs/heads/master' + # uses: ad-m/github-push-action@v0.8.0 + # with: + # github_token: ${{ 
secrets.GITHUB_TOKEN }} + # branch: ${{ github.ref }} + + build: + name: Build + runs-on: ubuntu-latest + needs: setup-requirements + timeout-minutes: 10 - - name: Install Environment - run: python -m pip install --upgrade pipenv wheel + steps: + - name: Check out repository code + uses: actions/checkout@v4.2.2 - - name: Cache Pipenv - id: cache-pipenv - uses: actions/cache@v4.1.2 + # Setup Python (faster than using Python container) + - name: Setup Python + uses: actions/setup-python@v5.3.0 with: - path: ~/.local/share/virtualenvs - key: ${{ runner.os }}-pipenv-${{ hashFiles('**/Pipfile.lock') }} + python-version: "3.12.0" - - name: Install dependencies - if: steps.cache-pipenv.outputs.cache-hit != 'true' - run: | - pipenv install -r requirements.txt - pipenv lock + - name: Install Hatch + if: steps.cache-hatch.outputs.cache-hit != 'true' + run: python -m pip install --upgrade hatch - - name: Run test suite - run: pipenv run pytest -v + - name: Cache Hatch environment + uses: actions/cache@v4.1.2 + with: + path: | + ~/.cache/hatch + ~/.local/share/hatch + key: ${{ runner.os }}-hatch-${{ env.HATCH_DEP_HASH }} + + - name: Hatch build + run: hatch build + + - name: Upload build files + uses: actions/upload-artifact@v4.4.3 + with: + name: build + path: dist diff --git a/.gitignore b/.gitignore index 1e964aa..d88f4f2 100644 --- a/.gitignore +++ b/.gitignore @@ -166,3 +166,5 @@ config_dev.toml # Proprietary source code samples. 
uav_test.sh + +addons/ \ No newline at end of file diff --git a/.vscode/launch.json b/.vscode/launch.json index caf4bfc..2e11c7f 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -26,6 +26,7 @@ "-c", "fix-code", "-v", + "-r", "${file}" ] }, diff --git a/config.toml b/config.toml index 43b9ca5..6ba9e55 100644 --- a/config.toml +++ b/config.toml @@ -5,10 +5,14 @@ allow_successful = false loading_hints = true source_code_format = "full" -[esbmc] +[verifier] +name = "esbmc" + +[verifier.esbmc] path = "~/.local/bin/esbmc" params = [ "--interval-analysis", + "--memory-leak-check", "--goto-unwind", "--unlimited-goto-unwind", "--k-induction", @@ -18,7 +22,6 @@ params = [ "2", "--floatbv", "--unlimited-k-steps", - "--compact-trace", "--context-bound", "2", ] diff --git a/esbmc_ai/__about__.py b/esbmc_ai/__about__.py index 70ab742..2289d7a 100644 --- a/esbmc_ai/__about__.py +++ b/esbmc_ai/__about__.py @@ -1,4 +1,4 @@ # Author: Yiannis Charalambous -__version__ = "v0.6.0" +__version__ = "v0.6.0.dev2" __author__: str = "Yiannis Charalambous" diff --git a/esbmc_ai/__main__.py b/esbmc_ai/__main__.py index 6b58b94..8409a90 100755 --- a/esbmc_ai/__main__.py +++ b/esbmc_ai/__main__.py @@ -8,9 +8,12 @@ # Enables arrow key functionality for input(). Do not remove import. 
import readline -_ = readline +from esbmc_ai.addon_loader import AddonLoader +from esbmc_ai.commands.user_chat_command import UserChatCommand +from esbmc_ai.verifier_runner import VerifierRunner +from esbmc_ai.verifiers.esbmc import ESBMC -from langchain_core.language_models import BaseChatModel +_ = readline from esbmc_ai.command_runner import CommandRunner from esbmc_ai.commands.fix_code_command import FixCodeCommandResult @@ -20,7 +23,7 @@ from esbmc_ai import Config from esbmc_ai import __author__, __version__ -from esbmc_ai.solution import SourceFile, Solution, get_solution +from esbmc_ai.solution import get_solution from esbmc_ai.commands import ( ChatCommand, @@ -32,16 +35,16 @@ from esbmc_ai.loading_widget import BaseLoadingWidget, LoadingWidget from esbmc_ai.chats import UserChat -from esbmc_ai.logging import print_horizontal_line, printv, printvv -from esbmc_ai.esbmc_util import ESBMCUtil -from esbmc_ai.chat_response import FinishReason, ChatResponse +from esbmc_ai.logging import printv, printvv from esbmc_ai.ai_models import _ai_model_names help_command: HelpCommand = HelpCommand() fix_code_command: FixCodeCommand = FixCodeCommand() exit_command: ExitCommand = ExitCommand() -command_runner: CommandRunner = CommandRunner( - [ + +verifier_runner: VerifierRunner = VerifierRunner().init([ESBMC()]) +command_runner: CommandRunner = CommandRunner().init( + builtin_commands=[ help_command, exit_command, fix_code_command, @@ -88,7 +91,7 @@ def check_health() -> None: printv("Performing health check...") # Check that ESBMC exists. 
- esbmc_path: Path = Config.get_value("esbmc.path") + esbmc_path: Path = Config().get_value("verifier.esbmc.path") if esbmc_path.exists(): printv("ESBMC has been located") else: @@ -96,115 +99,27 @@ def check_health() -> None: sys.exit(3) -def print_assistant_response( - chat: UserChat, - response: ChatResponse, - hide_stats: bool = False, -) -> None: - print(f"{response.message.type}: {response.message.content}\n\n") - - if not hide_stats: - print( - "Stats:", - f"total tokens: {response.total_tokens},", - f"max tokens: {chat.ai_model.tokens}", - f"finish reason: {response.finish_reason}", - ) - - -def init_addons() -> None: - command_runner.addon_commands.clear() - command_runner.addon_commands.extend(Config.get_value("addon_modules")) - if len(command_runner.addon_commands): - printv("Addons:\n\t* " + "\t * ".join(command_runner.addon_commands_names)) - - -def update_solution(source_code: str) -> None: - get_solution().files[0].update_content(content=source_code, reset_changes=True) - - -def _run_esbmc(source_file: SourceFile, anim: BaseLoadingWidget) -> str: - assert source_file.file_path - - with anim("ESBMC is processing... Please Wait"): - exit_code, esbmc_output = ESBMCUtil.esbmc( - path=source_file.file_path, - esbmc_params=Config.get_value("esbmc.params"), - timeout=Config.get_value("esbmc.timeout"), - ) - - # ESBMC will output 0 for verification success and 1 for verification - # failed, if anything else gets thrown, it's an ESBMC error. - if not Config.get_value("allow_successful") and exit_code == 0: - printv("Success!") - print(esbmc_output) - sys.exit(0) - elif exit_code != 0 and exit_code != 1: - printv(f"ESBMC exit code: {exit_code}") - printv(f"ESBMC Output:\n\n{esbmc_output}") - sys.exit(1) - - return esbmc_output - - -def init_commands() -> None: - """# Bus Signals - Function that handles initializing commands. 
Each command needs to be added - into the commands array in order for the command to register to be called by - the user and also register in the help system.""" - - # Let the AI model know about the corrected code. - fix_code_command.on_solution_signal.add_listener(chat.set_solution) - fix_code_command.on_solution_signal.add_listener(update_solution) - - -def _execute_fix_code_command(source_file: SourceFile) -> FixCodeCommandResult: - """Shortcut method to execute fix code command.""" - return fix_code_command.execute( - ai_model=Config.get_ai_model(), - source_file=source_file, - generate_patches=Config.generate_patches, - message_history=Config.get_value("fix_code.message_history"), - api_keys=Config.api_keys, - temperature=Config.get_value("fix_code.temperature"), - max_attempts=Config.get_value("fix_code.max_attempts"), - requests_max_tries=Config.get_llm_requests_max_tries(), - requests_timeout=Config.get_llm_requests_timeout(), - esbmc_params=Config.get_value("esbmc.params"), - raw_conversation=Config.raw_conversation, - temp_auto_clean=Config.get_value("temp_auto_clean"), - verifier_timeout=Config.get_value("esbmc.timeout"), - source_code_format=Config.get_value("source_code_format"), - esbmc_output_format=Config.get_value("esbmc.output_type"), - scenarios=Config.get_fix_code_scenarios(), - temp_file_dir=Config.get_value("temp_file_dir"), - output_dir=Config.output_dir, - ) - - def _run_command_mode(command: ChatCommand, args: argparse.Namespace) -> None: - path_arg: Path = Path(args.filename) - - anim: BaseLoadingWidget = ( - LoadingWidget() if Config.get_value("loading_hints") else BaseLoadingWidget() - ) - - solution: Solution = get_solution() - if path_arg.is_dir(): - for path in path_arg.glob("**/*"): - if path.is_file() and path.name: - solution.add_source_file(path, None) - else: - solution.add_source_file(path_arg, None) - match command.command_name: + # Basic fix mode: Supports only 1 file repair. 
case fix_code_command.command_name: - for source_file in solution.files: - # Run ESBMC first round - esbmc_output: str = _run_esbmc(source_file, anim) - source_file.assign_verifier_output(esbmc_output) - - result: FixCodeCommandResult = _execute_fix_code_command(source_file) + print("Reading source code...") + get_solution().load_source_files(Config().filenames) + print(f"Running ESBMC with {Config().get_value('verifier.esbmc.params')}\n") + + anim: BaseLoadingWidget = ( + LoadingWidget() + if Config().get_value("loading_hints") + else BaseLoadingWidget() + ) + for source_file in get_solution().files: + result: FixCodeCommandResult = ( + UserChatCommand._execute_fix_code_command_one_file( + fix_code_command, + source_file, + anim=anim, + ) + ) print(result) case _: @@ -214,7 +129,7 @@ def _run_command_mode(command: ChatCommand, args: argparse.Namespace) -> None: def main() -> None: parser = argparse.ArgumentParser( - prog="ESBMC-ChatGPT", + prog="ESBMC-AI", description=HELP_MESSAGE, # argparse.RawDescriptionHelpFormatter allows the ESBMC_HELP to display # properly. @@ -223,14 +138,20 @@ def main() -> None: ) parser.add_argument( - "filename", - help="The filename to pass to esbmc.", + "filenames", + default=[], + type=str, + nargs=argparse.REMAINDER, + help="The filename(s) to pass to the verifier.", ) parser.add_argument( - "remaining", - nargs=argparse.REMAINDER, - help="Any arguments after the filename will be passed to ESBMC as parameters.", + "--entry-function", + action="store", + default="main", + type=str, + required=False, + help="The name of the entry function to repair, defaults to main.", ) parser.add_argument( @@ -265,14 +186,6 @@ def main() -> None: help="Show the raw conversation at the end of a command. 
Good for debugging...", ) - parser.add_argument( - "-a", - "--append", - action="store_true", - default=False, - help="Any ESBMC parameters passed after the file name will be appended to the ones set in the config file, or the default ones if config file options are not set.", - ) - parser.add_argument( "-c", "--command", @@ -303,23 +216,22 @@ def main() -> None: print(f"Made by {__author__}") print() - Config.init(args) - ESBMCUtil.init(Config.get_value("esbmc.path")) + printvv("Loading main config") + Config().init(args) + printv(f"Config File: {Config().cfg_path}") check_health() - init_addons() - - printv(f"Source code format: {Config.get_value('source_code_format')}") - printv(f"ESBMC output type: {Config.get_value('esbmc.output_type')}") - - anim: BaseLoadingWidget = ( - LoadingWidget() if Config.get_value("loading_hints") else BaseLoadingWidget() - ) - - # Read the source code and esbmc output. - printv("Reading source code...") - print(f"Running ESBMC with {Config.get_value('esbmc.params')}\n") - - assert isinstance(args.filename, str) + # Load addons + printvv("Loading addons") + AddonLoader().init(Config(), verifier_runner.builtin_verifier_names) + # Bind addons to command runner and verifier runner. 
+ command_runner.addon_commands = AddonLoader().chat_command_addons + verifier_runner.addon_verifiers = AddonLoader().verifier_addons + # Set verifier to use + printvv(f"Verifier: {verifier_runner.verfifier.verifier_name}") + verifier_runner.set_verifier_by_name(Config().get_value("verifier.name")) + + printv(f"Source code format: {Config().get_value('source_code_format')}") + printv(f"ESBMC output type: {Config().get_value('verifier.esbmc.output_type')}") # =========================================== # Command mode @@ -331,9 +243,7 @@ def main() -> None: command_names: list[str] = command_runner.command_names if command in command_names: print("Running Command:", command) - for idx, command_name in enumerate(command_names): - if command == command_name: - _run_command_mode(command=command_runner.commands[idx], args=args) + _run_command_mode(command=command_runner.commands[command], args=args) sys.exit(0) else: print( @@ -346,145 +256,11 @@ def main() -> None: # User Mode (Supports only 1 file) # =========================================== - # Init Solution - solution: Solution - # Determine if we are processing one file versus multiple files - path_arg: Path = Path(args.filename) - if path_arg.is_dir(): - # Load only files. - print( - "Processing multiple files is not supported in User Mode." - "Call a command using -c to process directories" - ) - sys.exit(1) - else: - # Add the main source file to the solution explorer. 
- solution: Solution = get_solution() - solution.add_source_file(path_arg, None) - del path_arg - - # Assert that we have one file and one filepath - assert len(solution.files) == 1 - - source_file: SourceFile = solution.files[0] - - esbmc_output: str = _run_esbmc(source_file, anim) - - # Print verbose lvl 2 - print_horizontal_line(2) - printvv(esbmc_output) - print_horizontal_line(2) - - source_file.assign_verifier_output(esbmc_output) - del esbmc_output - - printv(f"Initializing the LLM: {Config.get_ai_model().name}\n") - chat_llm: BaseChatModel = Config.get_ai_model().create_llm( - api_keys=Config.api_keys, - temperature=Config.get_value("user_chat.temperature"), - requests_max_tries=Config.get_value("llm_requests.max_tries"), - requests_timeout=Config.get_value("llm_requests.timeout"), - ) - - printv("Creating user chat") - global chat - chat = UserChat( - ai_model=Config.get_ai_model(), - llm=chat_llm, - source_code=source_file.latest_content, - esbmc_output=source_file.latest_verifier_output, - system_messages=Config.get_user_chat_system_messages(), - set_solution_messages=Config.get_user_chat_set_solution(), - ) - - printv("Initializing commands...") - init_commands() - - # Show the initial output. - response: ChatResponse - if len(str(Config.get_user_chat_initial().content)) > 0: - printv("Using initial prompt from file...\n") - with anim("Model is parsing ESBMC output... 
Please Wait"): - try: - response = chat.send_message( - message=str(Config.get_user_chat_initial().content), - ) - except Exception as e: - print("There was an error while generating a response: {e}") - sys.exit(1) - - if response.finish_reason == FinishReason.length: - raise RuntimeError(f"The token length is too large: {chat.ai_model.tokens}") - else: - raise RuntimeError("User mode initial prompt not found in config.") - - print_assistant_response(chat, response) - print( - "ESBMC-AI: Type '/help' to view the available in-chat commands, along", - "with useful prompts to ask the AI model...", - ) - - while True: - # Get user input. - user_message = input("user>: ") - - # Check if it is a command, if not, then pass it to the chat interface. - if user_message.startswith("/"): - command, command_args = CommandRunner.parse_command(user_message) - command = command[1:] # Remove the / - if command == fix_code_command.command_name: - # Fix Code command - print() - print("ESBMC-AI will generate a fix for the code...") - - result: FixCodeCommandResult = _execute_fix_code_command(source_file) - - if result.successful: - print( - "\n\nESBMC-AI: Here is the corrected code, verified with ESBMC:" - ) - print(f"```\n{result.repaired_source}\n```") - continue - else: - # Commands without parameters or returns are handled automatically. - found: bool = False - for cmd in command_runner.commands: - if cmd.command_name == command: - found = True - cmd.execute() - break - - if not found: - print("Error: Unknown command...") - continue - elif user_message == "": - continue - else: - print() - - # User chat mode send and process current message response. - while True: - # Send user message to AI model and process. - with anim("Generating response... Please Wait"): - response = chat.send_message(user_message) - - if response.finish_reason == FinishReason.stop: - break - elif response.finish_reason == FinishReason.length: - with anim( - "Message stack limit reached. 
Shortening message stack... Please Wait" - ): - chat.compress_message_stack() - continue - else: - raise NotImplementedError( - f"User Chat Mode: Finish Reason: {response.finish_reason}" - ) - - print_assistant_response( - chat, - response, - ) + UserChatCommand( + command_runner=command_runner, + verifier_runner=verifier_runner, + fix_code_command=fix_code_command, + ).execute() if __name__ == "__main__": diff --git a/esbmc_ai/addon_loader.py b/esbmc_ai/addon_loader.py new file mode 100644 index 0000000..4f98ad5 --- /dev/null +++ b/esbmc_ai/addon_loader.py @@ -0,0 +1,255 @@ +# Author: Yiannis Charalambous + +"""This module contains code regarding configuring and loading addon modules.""" + +import inspect +from typing import Any +from typing_extensions import Optional, override +import importlib +from importlib.util import find_spec +from importlib.machinery import ModuleSpec +import sys + +from esbmc_ai.base_config import BaseConfig +from esbmc_ai.command_runner import ChatCommand +from esbmc_ai.logging import printv +from esbmc_ai.verifier_runner import BaseSourceVerifier +from esbmc_ai.config import Config, ConfigField + + +class AddonLoader(BaseConfig): + """The addon loader manages loading addon modules. This includes: + * Managing the config fields of the addons. + * Binding the loaded addons with CommandRunner and VerifierRunner. + + It hooks into the main config loader and adds config fields for selecting + the modules. Additionally it behaves like a config loader, this is because + it puts all the configs that the addons use into a namespace called "addons". + + The exception to this are the built-in modules which directly hook into the + main config object. + """ + + addon_prefix: str = "addons" + + def __new__(cls): + if not hasattr(cls, "instance"): + cls.instance = super(AddonLoader, cls).__new__(cls) + return cls.instance + + def init(self, config: Config, builtin_verifier_names: list[str]): + """Call to initialize the addon loader. 
It will load the addons and + register them with the command runner and verifier runner.""" + + self.base_init(config.cfg_path, []) + + self._config: Config = config + + # Register field with Config to know which modules to load. This will + # load them automatically. + config.add_config_field( + ConfigField( + name="addon_modules", + default_value=[], + validate=self._validate_addon_modules, + on_load=self._init_addon_modules, + error_message="addon_modules must be a list of Python modules to load", + ), + ) + + self.chat_command_addons: dict[str, ChatCommand] = {} + self.verifier_addons: dict[str, BaseSourceVerifier] = {} + + # Load all the addon commands + self._load_chat_command_addons() + + # Load all the addon verifiers + self._load_verifier_addons() + + # Register config with modules + for mod in (self.chat_command_addons | self.verifier_addons).values(): + mod.config = self + + # Ensure no duplicates + field_names: list[str] = [] + for f in self._fields + self._config._fields: + if f.name in field_names: + raise KeyError(f"Field is redefined: {f.name}") + field_names.append(f.name) + del field_names + + # Init the verifier.name field for the main config + self._config.add_config_field( + ConfigField( + name="verifier.name", + default_value="esbmc", + validate=lambda v: isinstance(v, str) + and v in self.verifier_addon_names + builtin_verifier_names, + error_message="Invalid verifier name specified.", + ) + ) + + @property + def chat_command_addon_names(self) -> list[str]: + return list(self.chat_command_addons.keys()) + + @property + def verifier_addon_names(self) -> list[str]: + return list(self.verifier_addons.keys()) + + def _load_chat_command_addons(self) -> None: + self.chat_command_addons.clear() + + for m in self._config.get_value("addon_modules"): + if isinstance(m, ChatCommand): + self.chat_command_addons[m.command_name] = m + + # Init config fields + for field in self._get_chat_command_addon_fields(): + self.add_config_field(field) + + if 
len(self.chat_command_addons) > 0: + printv( + "ChatCommand Addons:\n\t* " + + "\t * ".join(self.chat_command_addon_names) + ) + + def _load_verifier_addons(self) -> None: + """Loads the verifier addons, initializes their config fields.""" + self.verifier_addons.clear() + for m in self._config.get_value("addon_modules"): + if isinstance(m, BaseSourceVerifier): + self.verifier_addons[m.verifier_name] = m + + # Init config fields + for field in self._get_verifier_addon_fields(): + self.add_config_field(field) + + if len(self.verifier_addons) > 0: + printv( + "Verifier Addons:\n\t* " + + "\t * ".join(list(self.verifier_addons.keys())) + ) + + def _validate_addon_modules(self, mods: list[str]) -> bool: + """Validates that all values are string.""" + for m in mods: + if not isinstance(m, str): + return False + spec: Optional[ModuleSpec] = find_spec(m) + if spec is None: + return False + return True + + def _init_addon_modules(self, mods: list[str]) -> list: + """Will import addon modules that exist and iterate through the exposed + attributes, will then get all available exposed classes and store them. + + This method will load classes: + * ChatCommands + * BaseSourceVerifier""" + from esbmc_ai.commands.chat_command import ChatCommand + from esbmc_ai.verifiers import BaseSourceVerifier + from esbmc_ai.testing.base_tester import BaseTester + + allowed_types = ChatCommand | BaseSourceVerifier | BaseTester + + result: list = [] + for module_name in mods: + try: + m = importlib.import_module(module_name) + for attr_name in getattr(m, "__all__"): + # Get the class + attr_class = getattr(m, attr_name) + # Check if valid addon type and import + if issubclass(attr_class, allowed_types): + # Initialize class. 
+ result.append(attr_class()) + printv(f"Loading addon: {attr_name}") + except ModuleNotFoundError as e: + print(f"Addon Loader: Could not import module: {module_name}: {e}") + sys.exit(1) + + return result + + @override + def get_value(self, name: str) -> Any: + """Searches first for a config value in the addon config, if it is + not found, searches the global. + + How does it determine if a config field is part of the addon defined + fields? Well the AddonConfig class will add a prefix followed by the + name of the verifier.""" + + # Check if it references an addon. + split_key: list[str] = name.split(".") + if split_key[0] == AddonLoader.addon_prefix: + if ( + split_key[1] + in self.chat_command_addon_names + self.verifier_addon_names + ): + # Check if key exists. + if name in self._values: + return super().get_value(name) + else: + raise KeyError(f"Key: {name} not in AddonConfig") + # If the key is not in the addon prefix, then get from global config. + return self._config.get_value(name) + + def _resolve_config_field(self, field: ConfigField, prefix: str): + """Resolve the name of each field by prefixing it with the verifier name. + Returns a new config field with the name resolved to the prefix + supplied. Using inspection all the other fields are copied. The returning + field is exactly the same as the original, aside from the resolved name.""" + + # Inspect the signature of the ConfigField which is a named tuple. + signature = inspect.signature(ConfigField) + params: dict[str, Any] = {} + # Iterate and capture all parameters + for param_name, param in signature.parameters.items(): + _ = param + match param_name: + case "name": + params[param_name] = f"{prefix}.{getattr(field, param_name)}" + case _: + params[param_name] = getattr(field, param_name) + + return ConfigField(**params) + + def _get_chat_command_addon_fields(self) -> list[ConfigField]: + """Adds each chat command's config fields to the config. 
After resolving + each chat command's config fields to their namespace.""" + # If an addon prefix is defined, then add a . + addons_prefix: str = ( + AddonLoader.addon_prefix + "." if AddonLoader.addon_prefix else "" + ) + fields_resolved: list[ConfigField] = [] + # Loop through verifier addons + for cmd in self.chat_command_addons.values(): + # Loop through each field of the verifier addon + fields: list[ConfigField] = cmd.get_config_fields() + for field in fields: + new_field: ConfigField = self._resolve_config_field( + field, f"{addons_prefix}{cmd.command_name}" + ) + fields_resolved.append(new_field) + return fields_resolved + + def _get_verifier_addon_fields(self) -> list[ConfigField]: + """Adds each verifier's config fields to the config. After resolving + each verifier's config fields to their namespace.""" + # If an addon prefix is defined, then add a . + addons_prefix: str = ( + AddonLoader.addon_prefix + "." if AddonLoader.addon_prefix else "" + ) + fields_resolved: list[ConfigField] = [] + # Loop through verifier addons + for verifier in self.verifier_addons.values(): + # Loop through each field of the verifier addon + fields: list[ConfigField] = verifier.get_config_fields() + for field in fields: + new_field: ConfigField = self._resolve_config_field( + field, f"{addons_prefix}{verifier.verifier_name}" + ) + fields_resolved.append(new_field) + return fields_resolved diff --git a/esbmc_ai/base_config.py b/esbmc_ai/base_config.py new file mode 100644 index 0000000..9239899 --- /dev/null +++ b/esbmc_ai/base_config.py @@ -0,0 +1,105 @@ +# Author: Yiannis Charalambous 2023 + +"""ABC Config that can be used to load config files.""" + +from abc import ABC +import sys +from pathlib import Path +import tomllib as toml +from typing import ( + Any, + Dict, + List, +) + +from esbmc_ai.config_field import ConfigField + +default_scenario: str = "base" + + +class BaseConfig(ABC): + """Config loader for ESBMC-AI""" + + def base_init(self, cfg_path: Path, fields: 
list[ConfigField]) -> None: + """Initializes the base config structures. Loads the config file and fields.""" + self._fields: List[ConfigField] = fields + self._values: Dict[str, Any] = {} + + self.cfg_path: Path = cfg_path + + if not self.cfg_path.exists() and self.cfg_path.is_file(): + print(f"Error: Config not found: {self.cfg_path}") + sys.exit(1) + + with open(self.cfg_path, "r") as file: + self.original_config_file: dict[str, Any] = toml.loads(file.read()) + + # Flatten dict as the _fields are defined in a flattened format for + # convenience. + self.config_file: dict[str, Any] = self.flatten_dict(self.original_config_file) + + # Load all the config file field entries + for field in self._fields: + self.add_config_field(field) + + def add_config_field(self, field: ConfigField) -> None: + """Loads a new field from the config. Init needs to be called before + calling this to initialize the base config.""" + + # If on_read is overwritten, then the reading process is manually + # defined so fallback to that. + if field.on_read: + self._values[field.name] = field.on_read(self.original_config_file) + return + + # Proceed to default read + + # Is field entry found in config? + if field.name in self.config_file: + # Check if None and not allowed! 
+ if ( + field.default_value is None + and not field.default_value_none + and self.config_file[field.name] is None + ): + raise ValueError( + f"The config entry {field.name} has a None value when it can't be" + ) + + # Validate field + if not field.validate(self.config_file[field.name]): + msg = f"Field: {field.name} is invalid: {self.config_file[field.name]}" + if field.get_error_message is not None: + msg += ": " + field.get_error_message(self.config_file[field.name]) + elif field.error_message: + msg += ": " + field.error_message + raise ValueError(f"Config loading error: {msg}") + + # Assign field from config file + self._values[field.name] = field.on_load(self.config_file[field.name]) + elif field.default_value is None and not field.default_value_none: + raise KeyError(f"{field.name} is missing from config file") + else: + # Use default value + self._values[field.name] = field.default_value + + def get_value(self, name: str) -> Any: + """Gets the value of key name""" + return self._values[name] + + def set_value(self, name: str, value: Any) -> None: + """Sets a value in the config, if it does not exist, it will create one. 
+ This uses toml notation dot notation to namespace the elements.""" + self._values[name] = value + + @classmethod + def flatten_dict(cls, d, parent_key="", sep="."): + """Recursively flattens a nested dictionary.""" + items = {} + for k, v in d.items(): + new_key = parent_key + sep + k if parent_key else k + if isinstance(v, dict): + items.update(cls.flatten_dict(v, new_key, sep=sep)) + else: + items[new_key] = v + return items diff --git a/esbmc_ai/chat_response.py b/esbmc_ai/chat_response.py index 29f1934..641d203 100644 --- a/esbmc_ai/chat_response.py +++ b/esbmc_ai/chat_response.py @@ -49,6 +49,6 @@ def dict_to_base_message(json_string: dict) -> BaseMessage: raise Exception() -def list_to_base_messages(json_messages: list[dict]) -> list[BaseMessage]: +def list_to_base_messages(json_messages: list[dict]) -> tuple[BaseMessage, ...]: """Converts a list of messages from JSON format to a list of BaseMessage.""" - return [dict_to_base_message(msg) for msg in json_messages] + return tuple(dict_to_base_message(msg) for msg in json_messages) diff --git a/esbmc_ai/chats/solution_generator.py b/esbmc_ai/chats/solution_generator.py index 1b23a2e..0557917 100644 --- a/esbmc_ai/chats/solution_generator.py +++ b/esbmc_ai/chats/solution_generator.py @@ -8,33 +8,29 @@ from langchain.schema import BaseMessage, HumanMessage from esbmc_ai.chat_response import ChatResponse, FinishReason -from esbmc_ai.config import FixCodeScenarios, default_scenario +from esbmc_ai.config import FixCodeScenario, default_scenario from esbmc_ai.solution import SourceFile from esbmc_ai.ai_models import AIModel -from esbmc_ai.esbmc_util import ESBMCUtil +from esbmc_ai.verifiers.base_source_verifier import ( + BaseSourceVerifier, + SourceCodeParseError, +) from .base_chat_interface import BaseChatInterface -class ESBMCTimedOutException(Exception): - """Error that means that ESBMC timed out and so the error could not be - determined.""" - - -class SourceCodeParseError(Exception): - """Error that means 
that SolutionGenerator could not parse the source code - to return the right format.""" - - def get_source_code_formatted( - source_code_format: str, source_code: str, esbmc_output: str + verifier: BaseSourceVerifier, + source_code_format: str, + source_code: str, + esbmc_output: str, ) -> str: """Gets the formatted output source code, based on the source_code_format passed.""" match source_code_format: case "single": # Get source code error line from esbmc output - line: Optional[int] = ESBMCUtil.get_source_code_err_line_idx(esbmc_output) + line: Optional[int] = verifier.get_error_line_idx(esbmc_output) if line: return source_code.splitlines(True)[line] @@ -49,34 +45,6 @@ def get_source_code_formatted( ) -def get_esbmc_output_formatted(esbmc_output_type: str, esbmc_output: str) -> str: - """Gets the formatted output ESBMC output, based on the esbmc_output_type - passed.""" - # Check for parsing error - if "ERROR: PARSING ERROR" in esbmc_output: - # Parsing errors are usually small in nature. - raise SourceCodeParseError() - - if "ERROR: Timed out" in esbmc_output: - raise ESBMCTimedOutException() - - match esbmc_output_type: - case "vp": - value: Optional[str] = ESBMCUtil.esbmc_get_violated_property(esbmc_output) - if not value: - raise ValueError("Not found violated property." + esbmc_output) - return value - case "ce": - value: Optional[str] = ESBMCUtil.esbmc_get_counter_example(esbmc_output) - if not value: - raise ValueError("Not found counterexample.") - return value - case "full": - return esbmc_output - case _: - raise ValueError(f"Not a valid ESBMC output type: {esbmc_output_type}") - - class SolutionGenerator(BaseChatInterface): """SolutionGenerator is a simple conversation-based automated program repair class. 
The class works in a cycle, by first calling update_state with the @@ -86,9 +54,10 @@ class supports scenarios to customize the system message and initial prompt def __init__( self, - scenarios: FixCodeScenarios, + scenarios: dict[str, FixCodeScenario], llm: BaseChatModel, ai_model: AIModel, + verifier: BaseSourceVerifier, source_code_format: str = "full", esbmc_output_type: str = "full", ) -> None: @@ -100,9 +69,11 @@ def __init__( system_messages=[], # Empty as it will be updated in the update method. ) - self.scenarios: FixCodeScenarios = scenarios + self.scenarios: dict[str, FixCodeScenario] = scenarios self.scenario: Optional[str] = None + self.verifier: BaseSourceVerifier = verifier + self.esbmc_output_type: str = esbmc_output_type self.source_code_format: str = source_code_format @@ -120,7 +91,7 @@ def compress_message_stack(self) -> None: self.invokations = 0 @classmethod - def get_code_from_solution(cls, solution: str) -> str: + def extract_code_from_solution(cls, solution: str) -> str: """Strip the source code of any leftover text as sometimes the AI model will generate text and formatting despite being told not to.""" try: @@ -150,15 +121,14 @@ def update_state(self, source_code: str, esbmc_output: str) -> None: the scenario, which is the type of error that ESBMC has shown. 
This should be called before generate_solution.""" - self.scenario = ESBMCUtil.esbmc_get_error_type(esbmc_output) - + self.scenario = self.verifier.get_error_scenario(esbmc_output) self.source_code_raw = source_code # Format ESBMC output try: - self.esbmc_output = get_esbmc_output_formatted( - esbmc_output_type=self.esbmc_output_type, - esbmc_output=esbmc_output, + self.esbmc_output = self.verifier.apply_formatting( + verifier_output=esbmc_output, + format=self.esbmc_output_type, ) except SourceCodeParseError: # When clang output is displayed, show it entirely as it doesn't get very @@ -167,6 +137,7 @@ def update_state(self, source_code: str, esbmc_output: str) -> None: # Format source code self.source_code_formatted = get_source_code_formatted( + verifier=self.verifier, source_code_format=self.source_code_format, source_code=source_code, esbmc_output=self.esbmc_output, @@ -176,27 +147,29 @@ def _get_system_messages( self, override_scenario: Optional[str] = None ) -> tuple[BaseMessage, ...]: if override_scenario: - system_messages = self.scenarios[override_scenario]["system"] + system_messages = self.scenarios[override_scenario].system else: assert self.scenario, "Call update or set the scenario" if self.scenario in self.scenarios: - system_messages = self.scenarios[self.scenario]["system"] + system_messages = self.scenarios[self.scenario].system else: - system_messages = self.scenarios[default_scenario]["system"] + system_messages = self.scenarios[default_scenario].system assert isinstance(system_messages, tuple) assert all(isinstance(msg, BaseMessage) for msg in system_messages) return system_messages - def _get_initial_message(self, override_scenario: Optional[str] = None) -> str: + def _get_initial_message( + self, override_scenario: Optional[str] = None + ) -> BaseMessage: if override_scenario: - return str(self.scenarios[override_scenario]["initial"]) + return self.scenarios[override_scenario].initial else: assert self.scenario, "Call update or set the 
scenario" if self.scenario in self.scenarios: - return str(self.scenarios[self.scenario]["initial"]) + return self.scenarios[self.scenario].initial else: - return str(self.scenarios[default_scenario]["initial"]) + return self.scenarios[default_scenario].initial def generate_solution( self, @@ -205,11 +178,11 @@ def generate_solution( ) -> tuple[str, FinishReason]: """Prompts the LLM to repair the source code using the verifier output. If this is the first time the method is called, the system message will - be sent to the LLM, unless ignore_system_message is True, in which case - the initial prompt will be used. + be sent to the LLM, unless ignore_system_message is True. Then the + initial prompt will be sent. In subsequent invokations of generate_solution, the initial prompt will - be used. + be used only. So the system messages and initial message should each include at least {source_code} and {esbmc_output} so that they are substituted into the @@ -222,34 +195,42 @@ def generate_solution( self.source_code_raw is not None and self.source_code_formatted is not None and self.esbmc_output is not None + and self.scenario is not None ), "Call update_state before calling generate_solution." - if ignore_system_message or self.invokations > 0: - # Get scenario initial message and push it to message stack - self.push_to_message_stack( - HumanMessage(content=self._get_initial_message(override_scenario)) - ) - else: + # Show system message + if not ignore_system_message and self.invokations <= 0: # Get scenario system messages and push it to message stack. Don't # push to system message stack because we want to regenerate from # the beginning at every reset. - self.push_to_message_stack(self._get_system_messages(override_scenario)) + system_messages: tuple[BaseMessage, ...] 
= self._get_system_messages( + override_scenario=override_scenario + ) + if len(system_messages) > 0: + self.push_to_message_stack(system_messages) + + # Get scenario initial message and push it to message stack + self.push_to_message_stack( + self._get_initial_message(override_scenario=override_scenario) + ) self.invokations += 1 + error_type: Optional[str] = self.verifier.get_error_type(self.esbmc_output) + # Apply template substitution to message stack self.apply_template_value( source_code=self.source_code_formatted, esbmc_output=self.esbmc_output, - error_line=str(ESBMCUtil.get_source_code_err_line(self.esbmc_output)), - error_type=ESBMCUtil.esbmc_get_error_type(self.esbmc_output), + error_line=str(self.verifier.get_error_line(self.esbmc_output)), + error_type=error_type if error_type else "unknown error", ) # Generate the solution response: ChatResponse = self.send_message() solution: str = str(response.message.content) - solution = SolutionGenerator.get_code_from_solution(solution) + solution = SolutionGenerator.extract_code_from_solution(solution) # Post process source code # If source code passed to LLM is formatted then we need to recombine to @@ -257,10 +238,9 @@ def generate_solution( match self.source_code_format: case "single": # Get source code error line from esbmc output - line: Optional[int] = ESBMCUtil.get_source_code_err_line_idx( + line: Optional[int] = self.verifier.get_error_line_idx( self.esbmc_output ) - assert line, ( "fix code command: error line could not be found to apply " "brutal patch replacement" diff --git a/esbmc_ai/chats/user_chat.py b/esbmc_ai/chats/user_chat.py index 2f1b72d..df88dd5 100644 --- a/esbmc_ai/chats/user_chat.py +++ b/esbmc_ai/chats/user_chat.py @@ -2,6 +2,7 @@ """Contains class that handles the UserChat of ESBMC-AI""" +from typing import Optional from typing_extensions import override from langchain.memory import ConversationSummaryMemory @@ -11,7 +12,7 @@ from esbmc_ai.ai_models import AIModel -from 
esbmc_ai.esbmc_util import ESBMCUtil +from esbmc_ai.verifiers.base_source_verifier import BaseSourceVerifier from .base_chat_interface import BaseChatInterface @@ -26,6 +27,7 @@ def __init__( self, ai_model: AIModel, llm: BaseChatModel, + verifier: BaseSourceVerifier, source_code: str, esbmc_output: str, system_messages: list[BaseMessage], @@ -43,11 +45,13 @@ def __init__( # The messsages for setting a new solution to the source code. self.set_solution_messages = set_solution_messages + error_type: Optional[str] = verifier.get_error_type(self.esbmc_output) + self.apply_template_value( source_code=self.source_code, esbmc_output=self.esbmc_output, - error_line=str(ESBMCUtil.get_source_code_err_line(self.esbmc_output)), - error_type=ESBMCUtil.esbmc_get_error_type(self.esbmc_output), + error_line=str(verifier.get_error_line(self.esbmc_output)), + error_type=error_type if error_type else "unknown error", ) def set_solution(self, source_code: str) -> None: diff --git a/esbmc_ai/command_runner.py b/esbmc_ai/command_runner.py index 38f7716..8b14543 100644 --- a/esbmc_ai/command_runner.py +++ b/esbmc_ai/command_runner.py @@ -6,45 +6,57 @@ class CommandRunner: - """Command runner manages running and storing commands.""" + """Command runner manages running and storing commands. 
Singleton class.""" - def __init__(self, builtin_commands: list[ChatCommand]) -> None: - self._builtin_commands: list[ChatCommand] = builtin_commands.copy() - self._addon_commands: list[ChatCommand] = [] + def __new__(cls): + if not hasattr(cls, "instance"): + cls.instance = super(CommandRunner, cls).__new__(cls) + return cls.instance + + def init(self, builtin_commands: list[ChatCommand]) -> "CommandRunner": + self._builtin_commands: dict[str, ChatCommand] = { + cmd.command_name: cmd for cmd in builtin_commands + } + self._addon_commands: dict[str, ChatCommand] = {} # Set the help command commands - for cmd in self._builtin_commands: - if cmd.command_name == "help": - assert isinstance(cmd, HelpCommand) - cmd.commands = self.commands + if "help" in self._builtin_commands: + assert isinstance(self._builtin_commands["help"], HelpCommand) + self._builtin_commands["help"].commands = list(self.commands.values()) + + return self @property - def commands(self) -> list[ChatCommand]: - """Returns all commands. The list is copied.""" - return self._builtin_commands + self._addon_commands + def commands(self) -> dict[str, ChatCommand]: + """Returns all commands.""" + return self._builtin_commands | self._addon_commands @property def command_names(self) -> list[str]: """Returns a list of built-in commands. 
This is a reference to the internal list.""" - return [cmd.command_name for cmd in self.commands] + return list(self.commands.keys()) @property def builtin_commands_names(self) -> list[str]: """Returns a list of built-in command names.""" - return [cmd.command_name for cmd in self._builtin_commands] + return list(self._builtin_commands.keys()) @property def addon_commands_names(self) -> list[str]: """Returns a list of the addon command names.""" - return [cmd.command_name for cmd in self._addon_commands] + return list(self._addon_commands.keys()) @property - def addon_commands(self) -> list[ChatCommand]: + def addon_commands(self) -> dict[str, ChatCommand]: """Returns a list of the addon commands. This is a reference to the internal list.""" return self._addon_commands + @addon_commands.setter + def addon_commands(self, value: dict[str, ChatCommand]) -> None: + self._addon_commands = value + @staticmethod def parse_command(user_prompt_string: str) -> tuple[str, list[str]]: """Parses a command and returns it based on the command rules outlined in diff --git a/esbmc_ai/commands/chat_command.py b/esbmc_ai/commands/chat_command.py index 64b36bb..aec1b75 100644 --- a/esbmc_ai/commands/chat_command.py +++ b/esbmc_ai/commands/chat_command.py @@ -6,6 +6,8 @@ from typing import Any, Optional from esbmc_ai.commands.command_result import CommandResult +from esbmc_ai.config_field import ConfigField +from esbmc_ai.base_config import BaseConfig class ChatCommand(ABC): @@ -21,6 +23,27 @@ def __init__( self.command_name = command_name self.help_message = help_message self.authors = authors + self._config: BaseConfig + + @property + def config(self) -> BaseConfig: + return self._config + + @config.setter + def config(self, value: BaseConfig) -> None: + self._config: BaseConfig = value + + def get_config_fields(self) -> list[ConfigField]: + """Called during initialization, this is meant to return all config + fields that are going to be loaded from the config. 
The name that each + field has will automatically be prefixed with {verifier name}.""" + return [] + + def get_config_value(self, key: str) -> Any: + """Loads a value from the config. If the value is defined in the namespace + of the verifier name then that value will be returned. + """ + return self._config.get_value(key) @abstractmethod def execute(self, **kwargs: Optional[Any]) -> Optional[CommandResult]: diff --git a/esbmc_ai/commands/fix_code_command.py b/esbmc_ai/commands/fix_code_command.py index 84dc3a4..b67d545 100644 --- a/esbmc_ai/commands/fix_code_command.py +++ b/esbmc_ai/commands/fix_code_command.py @@ -5,23 +5,24 @@ from typing import Any, Optional from typing_extensions import override +from esbmc_ai.solution import Solution, SourceFile from esbmc_ai.ai_models import AIModel from esbmc_ai.api_key_collection import APIKeyCollection from esbmc_ai.chat_response import FinishReason from esbmc_ai.chats import LatestStateSolutionGenerator, SolutionGenerator -from esbmc_ai.chats.solution_generator import ESBMCTimedOutException +from esbmc_ai.verifiers.base_source_verifier import VerifierTimedOutException from esbmc_ai.commands.command_result import CommandResult -from esbmc_ai.config import FixCodeScenarios +from esbmc_ai.config import FixCodeScenario from esbmc_ai.chats.reverse_order_solution_generator import ( ReverseOrderSolutionGenerator, ) -from esbmc_ai.solution import SourceFile +from esbmc_ai.verifier_runner import VerifierRunner +from esbmc_ai.verifiers.base_source_verifier import BaseSourceVerifier, VerifierOutput from .chat_command import ChatCommand from ..msg_bus import Signal from ..loading_widget import BaseLoadingWidget -from ..esbmc_util import ESBMCUtil -from ..logging import print_horizontal_line, printv, printvv +from ..logging import print_horizontal_line, printv, printvvv class FixCodeCommandResult(CommandResult): @@ -66,8 +67,9 @@ def execute(self, **kwargs: Any) -> FixCodeCommandResult: def print_raw_conversation() -> None: 
print_horizontal_line(0) print("ESBMC-AI Notice: Printing raw conversation...") - all_messages = solution_generator._system_messages.copy() - all_messages.extend(solution_generator.messages.copy()) + all_messages = ( + solution_generator._system_messages + solution_generator.messages + ) messages: list[str] = [f"{msg.type}: {msg.content}" for msg in all_messages] print("\n" + "\n\n".join(messages)) print("ESBMC-AI Notice: End of raw conversation") @@ -90,14 +92,8 @@ def print_raw_conversation() -> None: timeout: int = kwargs["requests_timeout"] source_code_format: str = kwargs["source_code_format"] esbmc_output_format: str = kwargs["esbmc_output_format"] - scenarios: FixCodeScenarios = kwargs["scenarios"] + scenarios: dict[str, FixCodeScenario] = kwargs["scenarios"] max_attempts: int = kwargs["max_attempts"] - esbmc_params: list[str] = kwargs["esbmc_params"] - verifier_timeout: int = kwargs["verifier_timeout"] - temp_auto_clean: bool = kwargs["temp_auto_clean"] - temp_file_dir: Optional[Path] = ( - kwargs["temp_file_dir"] if "temp_file_dir" in kwargs else None - ) raw_conversation: bool = ( kwargs["raw_conversation"] if "raw_conversation" in kwargs else False ) @@ -107,8 +103,28 @@ def print_raw_conversation() -> None: anim: BaseLoadingWidget = ( kwargs["anim"] if "anim" in kwargs else BaseLoadingWidget() ) + entry_function: str = ( + kwargs["entry_function"] if "entry_function" in kwargs else "main" + ) # End of handle kwargs + printv(f"Temperature: {temperature}") + printv(f"Verifying function: {entry_function}") + + verifier: BaseSourceVerifier = VerifierRunner().verifier + printv(f"Running verifier: {verifier.verifier_name}") + verifier_result: VerifierOutput = verifier.verify_source(**kwargs) + source_file.assign_verifier_output(verifier_result.output) + + if verifier_result.successful(): + print("File verified successfully") + returned_source: str + if generate_patches: + returned_source = source_file.get_patch(0, -1) + else: + returned_source = 
source_file.latest_content + return FixCodeCommandResult(True, 0, returned_source) + match message_history: case "normal": solution_generator = SolutionGenerator( @@ -119,6 +135,7 @@ def print_raw_conversation() -> None: requests_max_tries=max_tries, requests_timeout=timeout, ), + verifier=verifier, scenarios=scenarios, source_code_format=source_code_format, esbmc_output_type=esbmc_output_format, @@ -132,6 +149,7 @@ def print_raw_conversation() -> None: requests_max_tries=max_tries, requests_timeout=timeout, ), + verifier=verifier, scenarios=scenarios, source_code_format=source_code_format, esbmc_output_type=esbmc_output_format, @@ -145,6 +163,7 @@ def print_raw_conversation() -> None: requests_max_tries=max_tries, requests_timeout=timeout, ), + verifier=verifier, scenarios=scenarios, source_code_format=source_code_format, esbmc_output_type=esbmc_output_format, @@ -159,8 +178,8 @@ def print_raw_conversation() -> None: source_code=source_file.latest_content, esbmc_output=source_file.latest_verifier_output, ) - except ESBMCTimedOutException: - print("error: ESBMC has timed out...") + except VerifierTimedOutException: + print("ESBMC-AI Notice: ESBMC has timed out...") sys.exit(1) print() @@ -181,35 +200,27 @@ def print_raw_conversation() -> None: break # Print verbose lvl 2 - printvv("\nESBMC-AI Notice: Source Code Generation:") - print_horizontal_line(2) - printvv(source_file.latest_content) - print_horizontal_line(2) - printvv("") + printvvv("\nESBMC-AI Notice: Source Code Generation:") + print_horizontal_line(3) + printvvv(source_file.latest_content) + print_horizontal_line(3) + printvvv("") # Pass to ESBMC, a workaround is used where the file is saved # to a temporary location since ESBMC needs it in file format. with anim("Verifying with ESBMC... 
Please Wait"): - exit_code, esbmc_output = ESBMCUtil.esbmc_load_source_code( - source_file=source_file, - source_file_content_index=-1, - esbmc_params=esbmc_params, - auto_clean=temp_auto_clean, - temp_file_dir=temp_file_dir, - timeout=verifier_timeout, - ) + verifier_result: VerifierOutput = verifier.verify_source(**kwargs) - source_file.assign_verifier_output(esbmc_output) - del esbmc_output + source_file.assign_verifier_output(verifier_result.output) # Print verbose lvl 2 - printvv("\nESBMC-AI Notice: ESBMC Output:") - print_horizontal_line(2) - printvv(source_file.latest_verifier_output) - print_horizontal_line(2) + printvvv("\nESBMC-AI Notice: ESBMC Output:") + print_horizontal_line(3) + printvvv(source_file.latest_verifier_output) + print_horizontal_line(3) # Solution found - if exit_code == 0: + if verifier_result.return_code == 0: self.on_solution_signal.emit(source_file.latest_content) if raw_conversation: @@ -237,7 +248,7 @@ def print_raw_conversation() -> None: solution_generator.update_state( source_file.latest_content, source_file.latest_verifier_output ) - except ESBMCTimedOutException: + except VerifierTimedOutException: if raw_conversation: print_raw_conversation() print("ESBMC-AI Notice: error: ESBMC has timed out...") @@ -247,7 +258,7 @@ def print_raw_conversation() -> None: if attempt != max_attempts: print(f"ESBMC-AI Notice: Failure {attempt}/{max_attempts}: Retrying...") else: - print(f"ESBMC-AI Notice: Failure {attempt}/{max_attempts}") + print(f"ESBMC-AI Notice: Failure {attempt}/{max_attempts}: Exiting...") if raw_conversation: print_raw_conversation() diff --git a/esbmc_ai/commands/user_chat_command.py b/esbmc_ai/commands/user_chat_command.py new file mode 100644 index 0000000..bd2392a --- /dev/null +++ b/esbmc_ai/commands/user_chat_command.py @@ -0,0 +1,256 @@ +# Author: Yiannis Charalambous + +import sys +from typing import Any, Optional, override + +from langchain_core.language_models import BaseChatModel +from esbmc_ai.chat_response 
import ChatResponse, FinishReason +from esbmc_ai.chats.user_chat import UserChat +from esbmc_ai.command_runner import CommandRunner +from esbmc_ai.commands.chat_command import ChatCommand +from esbmc_ai.commands.command_result import CommandResult +from esbmc_ai.commands.fix_code_command import FixCodeCommand, FixCodeCommandResult +from esbmc_ai.config import Config +from esbmc_ai.loading_widget import BaseLoadingWidget, LoadingWidget +from esbmc_ai.logging import print_horizontal_line, printv, printvv +from esbmc_ai.solution import SourceFile, get_solution +from esbmc_ai.verifier_runner import VerifierRunner +from esbmc_ai.verifiers.base_source_verifier import VerifierOutput + +"""This module contains the User Chat Command which is the default command that +is executed when no command is specified. It acts as a command line interface +for running the program.""" + + +class UserChatCommand(ChatCommand): + """The user chat command is the default command that is executed when no + other command is specified. It runs with execute and exits the entire program + when the command is finished. It is used to launch other commands.""" + + def __init__( + self, + command_runner: CommandRunner, + verifier_runner: VerifierRunner, + fix_code_command: FixCodeCommand, + ) -> None: + super().__init__("userchat", "Ran automatically and not exposed to the system.") + + self.command_runner: CommandRunner = command_runner + self.verifier_runner: VerifierRunner = verifier_runner + self.fix_code_command: FixCodeCommand = fix_code_command + + self.anim: BaseLoadingWidget = ( + LoadingWidget() + if Config().get_value("loading_hints") + else BaseLoadingWidget() + ) + + def _run_esbmc(self, source_file: SourceFile, anim: BaseLoadingWidget) -> str: + assert source_file.file_path + + with anim("Verifier is processing... 
Please Wait"): + verifier_result: VerifierOutput = ( + self.verifier_runner.verifier.verify_source( + source_file=source_file, + esbmc_params=Config().get_value("verifier.esbmc.params"), + timeout=Config().get_value("verifier.esbmc.timeout"), + ) + ) + + # ESBMC will output 0 for verification success and 1 for verification + # failed, if anything else gets thrown, it's an ESBMC error. + if not Config().get_value("allow_successful") and verifier_result.successful(): + printv(f"Verifier exit code: {verifier_result.return_code}") + printv(f"Verifier Output:\n\n{verifier_result.output}") + print("Sample successfully verified. Exiting...") + sys.exit(0) + + return verifier_result.output + + def init_commands(self) -> None: + """# Bus Signals + Function that handles initializing commands. Each command needs to be added + into the commands array in order for the command to register to be called by + the user and also register in the help system.""" + + # Let the AI model know about the corrected code. 
+ self.fix_code_command.on_solution_signal.add_listener(self.chat.set_solution) + self.fix_code_command.on_solution_signal.add_listener( + lambda source_code: get_solution() + .files[0] + .update_content(content=source_code, reset_changes=True) + ) + + def print_assistant_response( + self, + response: ChatResponse, + hide_stats: bool = False, + ) -> None: + print(f"{response.message.type}: {response.message.content}\n\n") + + if not hide_stats: + print( + "Stats:", + f"total tokens: {response.total_tokens},", + f"max tokens: {self.chat.ai_model.tokens}", + f"finish reason: {response.finish_reason}", + ) + + @staticmethod + def _execute_fix_code_command_one_file( + fix_code_command: FixCodeCommand, + source_file: SourceFile, + anim: Optional[BaseLoadingWidget] = None, + ) -> FixCodeCommandResult: + """Shortcut method to execute fix code command.""" + return fix_code_command.execute( + anim=anim, + ai_model=Config().get_ai_model(), + source_file=source_file, + generate_patches=Config().generate_patches, + message_history=Config().get_value("fix_code.message_history"), + api_keys=Config().api_keys, + temperature=Config().get_value("fix_code.temperature"), + max_attempts=Config().get_value("fix_code.max_attempts"), + requests_max_tries=Config().get_llm_requests_max_tries(), + requests_timeout=Config().get_llm_requests_timeout(), + esbmc_params=Config().get_value("verifier.esbmc.params"), + raw_conversation=Config().raw_conversation, + temp_auto_clean=Config().get_value("temp_auto_clean"), + verifier_timeout=Config().get_value("verifier.esbmc.timeout"), + source_code_format=Config().get_value("source_code_format"), + esbmc_output_format=Config().get_value("verifier.esbmc.output_type"), + scenarios=Config().get_fix_code_scenarios(), + temp_file_dir=Config().get_value("temp_file_dir"), + output_dir=Config().output_dir, + ) + + @override + def execute(self, **kwargs: Optional[Any]) -> Optional[CommandResult]: + # Read the source code and esbmc output. 
+ print("Reading source code...") + get_solution().load_source_files(Config().filenames) + print(f"Running ESBMC with {Config().get_value('verifier.esbmc.params')}\n") + source_file: SourceFile = get_solution().files[0] + + esbmc_output: str = self._run_esbmc(source_file, self.anim) + + # Print verbose lvl 2 + print_horizontal_line(2) + printvv(esbmc_output) + print_horizontal_line(2) + + source_file.assign_verifier_output(esbmc_output) + del esbmc_output + + printv(f"Initializing the LLM: {Config().get_ai_model().name}\n") + chat_llm: BaseChatModel = ( + Config() + .get_ai_model() + .create_llm( + api_keys=Config().api_keys, + temperature=Config().get_value("user_chat.temperature"), + requests_max_tries=Config().get_value("llm_requests.max_tries"), + requests_timeout=Config().get_value("llm_requests.timeout"), + ) + ) + + printv("Creating user chat") + self.chat: UserChat = UserChat( + ai_model=Config().get_ai_model(), + llm=chat_llm, + verifier=self.verifier_runner.verifier, + source_code=source_file.latest_content, + esbmc_output=source_file.latest_verifier_output, + system_messages=Config().get_user_chat_system_messages(), + set_solution_messages=Config().get_user_chat_set_solution(), + ) + + printv("Initializing commands...") + self.init_commands() + + # Show the initial output. + response: ChatResponse + if len(str(Config().get_user_chat_initial().content)) > 0: + printv("Using initial prompt from file...\n") + with self.anim("Model is parsing ESBMC output... 
Please Wait"): + try: + response = self.chat.send_message( + message=str(Config().get_user_chat_initial().content), + ) + except Exception as e: + print(f"There was an error while generating a response: {e}") + sys.exit(1) + + if response.finish_reason == FinishReason.length: + raise RuntimeError( + f"The token length is too large: {self.chat.ai_model.tokens}" + ) + else: + raise RuntimeError("User mode initial prompt not found in config.") + + self.print_assistant_response(response) + print( + "ESBMC-AI: Type '/help' to view the available in-chat commands, along", + "with useful prompts to ask the AI model...", + ) + + while True: + # Get user input. + user_message = input("user>: ") + + # Check if it is a command, if not, then pass it to the chat interface. + if user_message.startswith("/"): + command, command_args = CommandRunner.parse_command(user_message) + command = command[1:] # Remove the / + if command == self.fix_code_command.command_name: + # Fix Code command + print() + print("ESBMC-AI will generate a fix for the code...") + + result: FixCodeCommandResult = ( + self._execute_fix_code_command_one_file( + fix_code_command=self.fix_code_command, + source_file=source_file, + ) + ) + + if result.successful: + print( + "\n\nESBMC-AI: Here is the corrected code, verified with ESBMC:" + ) + print(f"```\n{result.repaired_source}\n```") + continue + else: + # Commands without parameters or returns are handled automatically. + if command in self.command_runner.commands: + self.command_runner.commands[command].execute() + break + else: + print("Error: Unknown command...") + continue + elif user_message == "": + continue + else: + print() + + # User chat mode send and process current message response. + while True: + # Send user message to AI model and process. + with self.anim("Generating response... 
Please Wait"): + response = self.chat.send_message(user_message) + + if response.finish_reason == FinishReason.stop: + break + elif response.finish_reason == FinishReason.length: + with self.anim( + "Message stack limit reached. Shortening message stack... Please Wait" + ): + self.chat.compress_message_stack() + continue + else: + raise NotImplementedError( + f"User Chat Mode: Finish Reason: {response.finish_reason}" + ) + + self.print_assistant_response(response) diff --git a/esbmc_ai/config.py b/esbmc_ai/config.py index 186f40b..3076c19 100644 --- a/esbmc_ai/config.py +++ b/esbmc_ai/config.py @@ -1,27 +1,26 @@ # Author: Yiannis Charalambous 2023 -import importlib -from importlib.util import find_spec -from importlib.machinery import ModuleSpec + +import argparse +from dataclasses import dataclass import os import sys from platform import system as system_name from pathlib import Path -import tomllib as toml from typing import ( Any, - Callable, Dict, List, - NamedTuple, Optional, ) from dotenv import load_dotenv, find_dotenv from langchain.schema import HumanMessage +from esbmc_ai.config_field import ConfigField +from esbmc_ai.base_config import BaseConfig, default_scenario from esbmc_ai.chat_response import list_to_base_messages -from esbmc_ai.logging import printv, set_verbose +from esbmc_ai.logging import set_verbose from .ai_models import ( BaseMessage, is_valid_ai_model, @@ -32,47 +31,13 @@ ) from .api_key_collection import APIKeyCollection -FixCodeScenarios = dict[str, dict[str, str | tuple[BaseMessage, ...]]] -"""Type for scenarios. A single scenario contains initial and system components. 
- -* Initial message can be accessed like so: `x["base"]["initial"]` -* System messages can be accessed like so: `x["base"]["system"]` - -The config loader ensures they conform to the specifications.""" - -default_scenario: str = "base" - - -class ConfigField(NamedTuple): - name: str - """The name of the config field and also namespace""" - default_value: Any - """If a default value is supplied, then it can be omitted from the config. - In order to have a "None" default value, default_value_none must be set.""" - default_value_none: bool = False - """If true, then the default value will be None, so during - validation, if no value is supplied, then None will be the - the default value, instead of failing due to None being the - default value which under normal circumstances means that the - field is not optional.""" - validate: Callable[[Any], bool] = lambda _: True - """Lambda function to validate if field has a valid value. - Default is identity function which is return true.""" - on_load: Callable[[Any], Any] = lambda v: v - """Transform the value once loaded, this allows the value to be saved - as a more complex type than that which is represented in the config - file. - - Is ignored if on_read is defined.""" - on_read: Optional[Callable[[dict[str, Any]], Any]] = None - """If defined, will be called and allows to custom load complex types that - may not match 1-1 in the config. The config file passed as a parameter here - is the original, unflattened version. The value returned should be the value - assigned to this field. - - This is a more versatile version of on_load. So if this is used, the on_load - will be ignored.""" - error_message: Optional[str] = None + +@dataclass +class FixCodeScenario: + """Type for scenarios. A single scenario contains initial and system components.""" + + initial: BaseMessage + system: tuple[BaseMessage, ...] 
def _validate_prompt_template_conversation(prompt_template: List[Dict]) -> bool: @@ -104,312 +69,251 @@ def _validate_prompt_template(conv: Dict[str, List[Dict]]) -> bool: return True -def _validate_addon_modules(mods: list[str]) -> bool: - """Validates that all values are string.""" - for m in mods: - if not isinstance(m, str): - return False - spec: Optional[ModuleSpec] = find_spec(m) - if spec is None: - return False - return True - - -def _init_addon_modules(mods: list[str]) -> list: - """Will import addon modules that exist and iterate through the exposed - attributes, will then get all available ChatCommands and store them.""" - from esbmc_ai.commands.chat_command import ChatCommand - - result: list[ChatCommand] = [] - for module_name in mods: - try: - m = importlib.import_module(module_name) - for attr_name in getattr(m, "__all__"): - attr_class = getattr(m, attr_name) - if issubclass(attr_class, ChatCommand): - result.append(attr_class()) - printv(f"Loading addon: {attr_name}") - except ModuleNotFoundError as e: - print(f"Addon Loader: Could not import module: {module_name}: {e}") - sys.exit(1) - - return result - - -class Config: +class Config(BaseConfig): """Config loader for ESBMC-AI""" - api_keys: APIKeyCollection - raw_conversation: bool = False - cfg_path: Path - generate_patches: bool - output_dir: Optional[Path] = None - - _fields: List[ConfigField] = [ - ConfigField( - name="ai_model", - default_value=None, - # Api keys are loaded from system env so they are already - # available - validate=lambda v: isinstance(v, str) - and is_valid_ai_model(v, Config.api_keys), - on_load=lambda v: get_ai_model_by_name(v, Config.api_keys), - ), - ConfigField( - name="temp_auto_clean", - default_value=True, - validate=lambda v: isinstance(v, bool), - ), - ConfigField( - name="temp_file_dir", - default_value=None, - validate=lambda v: isinstance(v, str) and Path(v).is_file(), - on_load=Path, - default_value_none=True, - ), - ConfigField( - name="allow_successful", - 
default_value=False, - validate=lambda v: isinstance(v, bool), - ), - ConfigField( - name="loading_hints", - default_value=True, - validate=lambda v: isinstance(v, bool), - ), - ConfigField( - name="source_code_format", - default_value="full", - validate=lambda v: isinstance(v, str) and v in ["full", "single"], - error_message="source_code_format can only be 'full' or 'single'", - ), - # Store as a list of commands - ConfigField( - name="addon_modules", - default_value=[], - validate=_validate_addon_modules, - on_load=_init_addon_modules, - error_message="addon_modules must be a list of Python modules to load", - ), - ConfigField( - name="esbmc.path", - default_value=None, - validate=lambda v: isinstance(v, str) and Path(v).expanduser().is_file(), - on_load=lambda v: Path(v).expanduser(), - ), - ConfigField( - name="esbmc.params", - default_value=[ - "--interval-analysis", - "--goto-unwind", - "--unlimited-goto-unwind", - "--k-induction", - "--state-hashing", - "--add-symex-value-sets", - "--k-step", - "2", - "--floatbv", - "--unlimited-k-steps", - "--compact-trace", - "--context-bound", - "2", - ], - validate=lambda v: isinstance(v, List), - ), - ConfigField( - name="esbmc.output_type", - default_value="full", - validate=lambda v: v in ["full", "vp", "ce"], - ), - ConfigField( - name="esbmc.timeout", - default_value=60, - validate=lambda v: isinstance(v, int), - ), - ConfigField( - name="llm_requests.max_tries", - default_value=5, - validate=lambda v: isinstance(v, int), - ), - ConfigField( - name="llm_requests.timeout", - default_value=60, - validate=lambda v: isinstance(v, int), - ), - ConfigField( - name="user_chat.temperature", - default_value=1.0, - validate=lambda v: isinstance(v, float) and 0 <= v <= 2.0, - error_message="Temperature needs to be a value between 0 and 2.0", - ), - ConfigField( - name="fix_code.temperature", - default_value=1.0, - validate=lambda v: isinstance(v, float) and 0 <= v <= 2, - error_message="Temperature needs to be a value between 
0 and 2.0", - ), - ConfigField( - name="fix_code.max_attempts", - default_value=5, - validate=lambda v: isinstance(v, int), - ), - ConfigField( - name="fix_code.message_history", - default_value="normal", - validate=lambda v: v in ["normal", "latest_only", "reverse"], - error_message='fix_code.message_history can only be "normal", "latest_only", "reverse"', - ), - ConfigField( - name="prompt_templates.user_chat.initial", - default_value=None, - validate=lambda v: isinstance(v, str), - on_load=lambda v: HumanMessage(content=v), - ), - ConfigField( - name="prompt_templates.user_chat.system", - default_value=None, - validate=_validate_prompt_template_conversation, - on_load=list_to_base_messages, - ), - ConfigField( - name="prompt_templates.user_chat.set_solution", - default_value=None, - validate=_validate_prompt_template_conversation, - on_load=list_to_base_messages, - ), - # Here we have a list of prompt templates that are for each scenario. - # The base scenario prompt template is required. 
- ConfigField( - name="prompt_templates.fix_code", - default_value=None, - validate=lambda v: default_scenario in v - and all( - _validate_prompt_template(prompt_template) - for prompt_template in v.values() - ), - on_read=lambda config_file: { - scenario: { - "initial": HumanMessage(content=conv["initial"]), - "system": list_to_base_messages(conv["system"]), - } - for scenario, conv in config_file["prompt_templates"][ - "fix_code" - ].items() - }, - ), - ] - _values: Dict[str, Any] = {} + def __new__(cls): + if not hasattr(cls, "instance"): + cls.instance = super(Config, cls).__new__(cls) + return cls.instance # Define some shortcuts for the values here (instead of having to use get_value) - @classmethod - def get_ai_model(cls) -> AIModel: + def get_ai_model(self) -> AIModel: """Value of field: ai_model""" - return cls.get_value("ai_model") + return self.get_value("ai_model") - @classmethod - def get_llm_requests_max_tries(cls) -> int: + def get_llm_requests_max_tries(self) -> int: """Value of field: llm_requests.max_tries""" - return cls.get_value("llm_requests.max_tries") + return self.get_value("llm_requests.max_tries") - @classmethod - def get_llm_requests_timeout(cls) -> float: + def get_llm_requests_timeout(self) -> float: """""" - return cls.get_value("llm_requests.timeout") + return self.get_value("llm_requests.timeout") - @classmethod - def get_user_chat_initial(cls) -> BaseMessage: + def get_user_chat_initial(self) -> BaseMessage: """Value of field: prompt_templates.user_chat.initial""" - return cls.get_value("prompt_templates.user_chat.initial") + return self.get_value("prompt_templates.user_chat.initial") - @classmethod - def get_user_chat_system_messages(cls) -> list[BaseMessage]: + def get_user_chat_system_messages(self) -> list[BaseMessage]: """Value of field: prompt_templates.user_chat.system""" - return cls.get_value("prompt_templates.user_chat.system") + return self.get_value("prompt_templates.user_chat.system") - @classmethod - def 
get_user_chat_set_solution(cls) -> list[BaseMessage]: + def get_user_chat_set_solution(self) -> list[BaseMessage]: """Value of field: prompt_templates.user_chat.set_solution""" - return cls.get_value("prompt_templates.user_chat.set_solution") + return self.get_value("prompt_templates.user_chat.set_solution") - @classmethod - def get_fix_code_scenarios(cls) -> FixCodeScenarios: + def get_fix_code_scenarios(self) -> dict[str, FixCodeScenario]: """Value of field: prompt_templates.fix_code""" - return cls.get_value("prompt_templates.fix_code") - - @classmethod - def init(cls, args: Any) -> None: - """Static init method for the static class. Will load the config from - the args, the env file and then from config file.""" - cls._load_envs() - - if not Config.cfg_path.exists() and Config.cfg_path.is_file(): - print(f"Error: Config not found: {Config.cfg_path}") - sys.exit(1) - - with open(Config.cfg_path, "r") as file: - original_config_file: dict[str, Any] = toml.loads(file.read()) - - # Load custom AIs - if "ai_custom" in original_config_file: - _load_custom_ai(original_config_file["ai_custom"]) - - # Flatten dict as the _fields are defined in a flattened format for - # convenience. - config_file: dict[str, Any] = cls._flatten_dict(original_config_file) - - # Load all the config file field entries - for field in cls._fields: - # If on_read is overwritten, then the reading process is manually - # defined so fallback to that. - if field.on_read: - cls._values[field.name] = field.on_read(original_config_file) - continue - - # Proceed to default read - - # Is field entry found in config? - if field.name in config_file: - # Check if None and not allowed! 
- if ( - field.default_value is None - and not field.default_value_none - and config_file[field.name] is None - ): - raise ValueError( - f"The config entry {field.name} has a None value when it can't be" + return self.get_value("prompt_templates.fix_code") + + @property + def filenames(self) -> list[Path]: + """Gets the filenames that are to be loaded""" + return self.get_value("solution.filenames") + + def init(self, args: Any) -> None: + """Will load the config from the args, the env file and then from config file. + Call once to initialize.""" + + self._args: argparse.Namespace = args + self.api_keys: APIKeyCollection + self.raw_conversation: bool = False + self.generate_patches: bool + self.output_dir: Optional[Path] = None + + fields: list[ConfigField] = [ + ConfigField( + name="ai_custom", + default_value=[], + on_read=lambda cfg: self._load_custom_ai(cfg["ai_custom"]), + error_message="Invalid custom AI specification", + ), + # This needs to be processed after ai_custom + ConfigField( + name="ai_model", + default_value=None, + # Api keys are loaded from system env so they are already + # available + validate=lambda v: isinstance(v, str) + and is_valid_ai_model(v, self.api_keys), + on_load=lambda v: get_ai_model_by_name(v, self.api_keys), + ), + ConfigField( + name="temp_auto_clean", + default_value=True, + validate=lambda v: isinstance(v, bool), + ), + ConfigField( + name="temp_file_dir", + default_value=None, + validate=lambda v: isinstance(v, str) and Path(v).is_file(), + on_load=Path, + default_value_none=True, + ), + ConfigField( + name="allow_successful", + default_value=False, + validate=lambda v: isinstance(v, bool), + ), + ConfigField( + name="loading_hints", + default_value=True, + validate=lambda v: isinstance(v, bool), + ), + ConfigField( + name="source_code_format", + default_value="full", + validate=lambda v: isinstance(v, str) and v in ["full", "single"], + error_message="source_code_format can only be 'full' or 'single'", + ), + ConfigField( + 
name="solution.filenames", + default_value=[], + validate=lambda v: isinstance(v, list) + # Validate config values + and all(isinstance(f, str) and Path(f).exists() for f in v) + # Validate arg values + and all(Path(f).exists() for f in self._args.filenames), + on_load=self._filenames_load, + get_error_message=self._filenames_error_msg, + ), + # If argument is passed, then the config value is ignored. + ConfigField( + name="solution.entry_function", + default_value=None, + validate=lambda v: isinstance(v, str) + and ( + # This implements logical implication A => B + # So if entry_function arg is set then it must be a string + not self._args.entry_function + or isinstance(self._args.entry_function, str) + ), + on_load=lambda v: ( + self._args.entry_function if self._args.entry_function else v + ), + error_message="The entry function name needs to be a string", + ), + ConfigField( + name="verifier.esbmc.path", + default_value=None, + validate=lambda v: isinstance(v, str) + and Path(v).expanduser().is_file(), + on_load=lambda v: Path(v).expanduser(), + ), + ConfigField( + name="verifier.esbmc.params", + default_value=[ + "--interval-analysis", + "--goto-unwind", + "--unlimited-goto-unwind", + "--k-induction", + "--state-hashing", + "--add-symex-value-sets", + "--k-step", + "2", + "--floatbv", + "--unlimited-k-steps", + "--compact-trace", + "--context-bound", + "2", + ], + validate=lambda v: isinstance(v, list), + ), + ConfigField( + name="verifier.esbmc.output_type", + default_value="full", + validate=lambda v: v in ["full", "vp", "ce"], + ), + ConfigField( + name="verifier.esbmc.timeout", + default_value=60, + validate=lambda v: isinstance(v, int), + ), + ConfigField( + name="tester", + default_value="simple", + ), + ConfigField( + name="llm_requests.max_tries", + default_value=5, + validate=lambda v: isinstance(v, int), + ), + ConfigField( + name="llm_requests.timeout", + default_value=60, + validate=lambda v: isinstance(v, int), + ), + ConfigField( + 
name="user_chat.temperature", + default_value=1.0, + validate=lambda v: isinstance(v, float) and 0 <= v <= 2.0, + error_message="Temperature needs to be a value between 0 and 2.0", + ), + ConfigField( + name="fix_code.temperature", + default_value=1.0, + validate=lambda v: isinstance(v, float) and 0 <= v <= 2, + error_message="Temperature needs to be a value between 0 and 2.0", + ), + ConfigField( + name="fix_code.max_attempts", + default_value=5, + validate=lambda v: isinstance(v, int), + ), + ConfigField( + name="fix_code.message_history", + default_value="normal", + validate=lambda v: v in ["normal", "latest_only", "reverse"], + error_message='fix_code.message_history can only be "normal", "latest_only", "reverse"', + ), + ConfigField( + name="prompt_templates.user_chat.initial", + default_value=None, + validate=lambda v: isinstance(v, str), + on_load=lambda v: HumanMessage(content=v), + ), + ConfigField( + name="prompt_templates.user_chat.system", + default_value=None, + validate=_validate_prompt_template_conversation, + on_load=list_to_base_messages, + ), + ConfigField( + name="prompt_templates.user_chat.set_solution", + default_value=None, + validate=_validate_prompt_template_conversation, + on_load=list_to_base_messages, + ), + # Here we have a list of prompt templates that are for each scenario. + # The base scenario prompt template is required. 
+ ConfigField( + name="prompt_templates.fix_code", + default_value=None, + validate=lambda v: default_scenario in v + and all( + _validate_prompt_template(prompt_template) + for prompt_template in v.values() + ), + on_read=lambda config_file: { + scenario: FixCodeScenario( + initial=HumanMessage(content=conv["initial"]), + system=list_to_base_messages(conv["system"]), ) + for scenario, conv in config_file["prompt_templates"][ + "fix_code" + ].items() + }, + ), + ] - # Validate field - assert field.validate(config_file[field.name]), ( - field.error_message - if field.error_message - else f"Field: {field.name} is invalid: {config_file[field.name]}" - ) - - # Assign field from config file - cls._values[field.name] = field.on_load(config_file[field.name]) - elif field.default_value is None and not field.default_value_none: - raise KeyError(f"{field.name} is missing from config file") - else: - # Use default value - cls._values[field.name] = field.default_value - - cls._load_args(args) - - @classmethod - def get_value(cls, name: str) -> Any: - """Gets the value of key name""" - return cls._values[name] + self._load_envs() - @classmethod - def set_value(cls, name: str, value: Any) -> None: - """Sets a value in the config, if it does not exist, it will create one. - This uses toml notation dot notation to namespace the elements.""" - cls._values[name] = value + # Base init needs to be called last (only before load args) + super().base_init(self.cfg_path, fields) + self._load_args() - @classmethod - def _load_envs(cls) -> None: + def _load_envs(self) -> None: """Environment variables are loaded in the following order: 1. Environment variables already loaded. 
Any variable not present will be looked for in @@ -467,118 +371,133 @@ def get_env_vars() -> None: print(f"Error: No ${key} in environment.") sys.exit(1) - cls.api_keys = APIKeyCollection( + self.api_keys = APIKeyCollection( openai=str(os.getenv("OPENAI_API_KEY")), ) - cls.cfg_path = Path( + self.cfg_path: Path = Path( os.path.expanduser( os.path.expandvars(str(os.getenv("ESBMCAI_CONFIG_PATH"))) ) ) - @classmethod - def _load_args(cls, args) -> None: + def _load_args(self) -> None: + args: argparse.Namespace = self._args + set_verbose(args.verbose) # AI Model -m if args.ai_model != "": - if is_valid_ai_model(args.ai_model, cls.api_keys): - ai_model = get_ai_model_by_name(args.ai_model, cls.api_keys) - cls.set_value("ai_model", ai_model) + if is_valid_ai_model(args.ai_model, self.api_keys): + ai_model = get_ai_model_by_name(args.ai_model, self.api_keys) + self.set_value("ai_model", ai_model) else: print(f"Error: invalid --ai-model parameter {args.ai_model}") sys.exit(4) - # If append flag is set, then append. 
- if args.append: - esbmc_params: List[str] = cls.get_value("esbmc.params") - esbmc_params.extend(args.remaining) - cls.set_value("esbmc_params", esbmc_params) - elif len(args.remaining) != 0: - cls.set_value("esbmc_params", args.remaining) - - Config.raw_conversation = args.raw_conversation - Config.generate_patches = args.generate_patches + self.raw_conversation = args.raw_conversation + self.generate_patches = args.generate_patches if args.output_dir: path: Path = Path(args.output_dir).expanduser() if path.is_dir(): - Config.output_dir = path + self.output_dir = path else: print( "Error while parsing arguments: output_dir: dir does not exist:", - Config.output_dir, + self.output_dir, ) sys.exit(1) - @classmethod - def _flatten_dict(cls, d, parent_key="", sep="."): - """Recursively flattens a nested dictionary.""" - items = {} - for k, v in d.items(): - new_key = parent_key + sep + k if parent_key else k - if isinstance(v, dict): - items.update(cls._flatten_dict(v, new_key, sep=sep)) - else: - items[new_key] = v - return items + def _validate_custom_ai(self, ai_config_list: dict) -> bool: + for name, ai_config in ai_config_list.items(): + # Max tokens + if "max_tokens" not in ai_config: + raise KeyError( + f'max_tokens field not found in "ai_custom" entry "{name}".' + ) + elif not isinstance(ai_config["max_tokens"], int): + raise TypeError( + f'custom_ai_max_tokens in ai_custom entry "{name}" needs to ' + "be an int and greater than 0." + ) + elif ai_config["max_tokens"] <= 0: + raise ValueError( + f'custom_ai_max_tokens in ai_custom entry "{name}" needs to ' + "be an int and greater than 0." 
+ ) + # URL + if "url" not in ai_config: + raise KeyError(f'url field not found in "ai_custom" entry "{name}".') -def _load_custom_ai(config: dict) -> None: - """Loads custom AI defined in the config and ascociates it with the AIModels - module.""" + # Server type + if "server_type" not in ai_config: + raise KeyError( + f"server_type for custom AI '{name}' is invalid, it needs to be a valid string" + ) - def _load_config_value( - config_file: dict, name: str, default: object = None - ) -> tuple[Any, bool]: - if name in config_file: - return config_file[name], True + return True - print(f"Warning: {name} not found in config... Using default value: {default}") - return default, False + def _load_custom_ai(self, ai_config_list: dict) -> list[AIModel]: + """Loads custom AI defined in the config and associates it with the AIModels + module.""" - for name, ai_data in config.items(): - # Load the max tokens - custom_ai_max_tokens, ok = _load_config_value( - config_file=ai_data, - name="max_tokens", - ) - assert ok, f'max_tokens field not found in "ai_custom" entry "{name}".' - assert ( - isinstance(custom_ai_max_tokens, int) and custom_ai_max_tokens > 0 - ), f'custom_ai_max_tokens in ai_custom entry "{name}" needs to be an int and greater than 0.' - - # Load the URL - custom_ai_url, ok = _load_config_value( - config_file=ai_data, - name="url", - ) - assert ok, f'url field not found in "ai_custom" entry "{name}".' 
+ self._validate_custom_ai(ai_config_list) - # Get provider type - server_type, ok = _load_config_value( - config_file=ai_data, - name="server_type", - default="localhost:11434", - ) - assert ( - ok - ), f"server_type for custom AI '{name}' is invalid, it needs to be a valid string" - - # Create correct type of LLM - llm: AIModel - match server_type: - case "ollama": - llm = OllamaAIModel( - name=name, - tokens=custom_ai_max_tokens, - url=custom_ai_url, - ) - case _: - raise NotImplementedError( - f"The custom AI server type is not implemented: {server_type}" - ) + custom_ai: list[AIModel] = [] + for name, ai_config in ai_config_list.items(): + # Load the max tokens + max_tokens: int = ai_config["max_tokens"] + + # Load the URL + url: str = ai_config["url"] + + # Get provider type + server_type = ai_config["server_type"] + + # Create correct type of LLM + llm: AIModel + match server_type: + case "ollama": + llm = OllamaAIModel( + name=name, + tokens=max_tokens, + url=url, + ) + case _: + raise NotImplementedError( + f"The custom AI server type is not implemented: {server_type}" + ) + + # Add the custom AI. + custom_ai.append(llm) + add_custom_ai_model(llm) + + return custom_ai + + def _filenames_load(self, file_names: list[str]) -> list[Path]: + """Loads the filenames from the command line first then from the config.""" + + results: list[Path] = [] + + if len(self._args.filenames): + results.extend(Path(f) for f in self._args.filenames) + + for file in file_names: + results.append(Path(file)) + return results + + @staticmethod + def _filenames_error_msg(file_names: list) -> str: + """Gets the error message for an invalid list of file_names specified in + the config.""" + + wrong: list[str] = [] + for file_name in file_names: + if not isinstance(file_name, str) or not ( + Path(file_name).is_file() and Path(file_name).is_dir() + ): + wrong.append(file_name) - # Add the custom AI. 
- add_custom_ai_model(llm) + return "The following files cannot be found: " + ", ".join(wrong) diff --git a/esbmc_ai/config_field.py b/esbmc_ai/config_field.py new file mode 100644 index 0000000..9d33ba0 --- /dev/null +++ b/esbmc_ai/config_field.py @@ -0,0 +1,47 @@ +# Author: Yiannis Charalambous + +"""This module can be used by other modules to declare config entries.""" + +from typing import ( + Any, + Callable, + NamedTuple, + Optional, +) + + +class ConfigField(NamedTuple): + """Represents a loadable entry in the config.""" + + name: str + """The name of the config field and also namespace""" + default_value: Any + """If a default value is supplied, then it can be omitted from the config. + In order to have a "None" default value, default_value_none must be set.""" + default_value_none: bool = False + """If true, then the default value will be None, so during + validation, if no value is supplied, then None will be the + the default value, instead of failing due to None being the + default value which under normal circumstances means that the + field is not optional.""" + validate: Callable[[Any], bool] = lambda _: True + """Lambda function to validate if field has a valid value. + Default is identity function which is return true.""" + on_load: Callable[[Any], Any] = lambda v: v + """Transform the value once loaded, this allows the value to be saved + as a more complex type than that which is represented in the config + file. + + Is ignored if on_read is defined.""" + on_read: Optional[Callable[[dict[str, Any]], Any]] = None + """If defined, will be called and allows to custom load complex types that + may not match 1-1 in the config. The config file passed as a parameter here + is the original, unflattened version. The value returned should be the value + assigned to this field. + + This is a more versatile version of on_load. 
So if this is used, the on_load + will be ignored.""" + error_message: Optional[str] = None + """Optional string to provide a generic error message.""" + get_error_message: Optional[Callable[[Any], str]] = None + """Optional function to get more verbose output than error_message.""" diff --git a/esbmc_ai/solution.py b/esbmc_ai/solution.py index 75fbc2c..4eac4ea 100644 --- a/esbmc_ai/solution.py +++ b/esbmc_ai/solution.py @@ -235,11 +235,31 @@ def files(self) -> tuple[SourceFile, ...]: return tuple(self._files) @property - def files_mapped(self) -> dict[Path, SourceFile]: + def files_mapped(self) -> dict[str, SourceFile]: """Will return the files mapped to their directory. Returns by value.""" - return {source_file.file_path: source_file for source_file in self._files} + return {str(source_file.file_path): source_file for source_file in self._files} + + def add_source_file(self, source_file: SourceFile) -> None: + """Adds a source file to the solution.""" + self._files.append(source_file) + + def add_source_files(self, source_files: list[SourceFile]) -> None: + """Adds multiple source files to the solution""" + for f in source_files: + self._files.append(f) + + def load_source_files(self, file_paths: list[Path]) -> None: + """Loads multiple source files from disk.""" + for f in file_paths: + assert isinstance(f, Path), f"Invalid type: {type(f)}" + if f.is_dir(): + for path in f.glob("**/*"): + if path.is_file() and path.name: + self.load_source_file(path, None) + else: + self.load_source_file(f, None) - def add_source_file(self, file_path: Path, content: Optional[str]) -> None: + def load_source_file(self, file_path: Path, content: Optional[str] = None) -> None: """Add a source file to the solution. 
If content is provided then it will not be loaded.""" assert file_path diff --git a/esbmc_ai/testing/__init__.py b/esbmc_ai/testing/__init__.py new file mode 100644 index 0000000..1f91aa3 --- /dev/null +++ b/esbmc_ai/testing/__init__.py @@ -0,0 +1 @@ +# Author: Yiannis Charalambous diff --git a/esbmc_ai/testing/base_tester.py b/esbmc_ai/testing/base_tester.py new file mode 100644 index 0000000..18a3b51 --- /dev/null +++ b/esbmc_ai/testing/base_tester.py @@ -0,0 +1,5 @@ +# Author: Yiannis Charalambous + + +class BaseTester: + pass diff --git a/esbmc_ai/verifier_runner.py b/esbmc_ai/verifier_runner.py new file mode 100644 index 0000000..a35bee9 --- /dev/null +++ b/esbmc_ai/verifier_runner.py @@ -0,0 +1,57 @@ +from esbmc_ai.verifiers import BaseSourceVerifier +from esbmc_ai.config import ConfigField + + +class VerifierRunner: + def __new__(cls): + if not hasattr(cls, "instance"): + cls.instance = super(VerifierRunner, cls).__new__(cls) + return cls.instance + + def init(self, builtin_verifiers: list[BaseSourceVerifier]) -> "VerifierRunner": + self._builtin_verifiers: dict[str, BaseSourceVerifier] = { + v.verifier_name: v for v in builtin_verifiers + } + """Builtin loaded verifiers""" + self._addon_verifiers: dict[str, BaseSourceVerifier] = {} + """Additional loaded verifiers""" + self._verifier: BaseSourceVerifier = builtin_verifiers[0] + """Default verifier""" + return self + + @property + def verfifier(self) -> BaseSourceVerifier: + return self._verifier + + @verfifier.setter + def verifier(self, value: BaseSourceVerifier) -> None: + assert ( + value not in self.verifiers + ), f"Unregistered verifier set: {value.verifier_name}" + self._verifier = value + + @property + def builtin_verifier_names(self) -> list[str]: + """Gets the names of the builtin verifiers""" + return list(self._builtin_verifiers.keys()) + + @property + def verifiers(self) -> dict[str, BaseSourceVerifier]: + """Gets all verifiers""" + return self._builtin_verifiers | self._addon_verifiers + + 
@property + def addon_verifiers(self) -> dict[str, BaseSourceVerifier]: + return self._addon_verifiers + + @addon_verifiers.setter + def addon_verifiers(self, vers: dict[str, BaseSourceVerifier]) -> None: + self._addon_verifiers = vers + + @property + def addon_verifier_names(self) -> list[str]: + """Gets all addon verifier names""" + return list(self._addon_verifiers.keys()) + + def set_verifier_by_name(self, value: str) -> None: + self.verifier = self.verifiers[value] diff --git a/esbmc_ai/verifiers/__init__.py b/esbmc_ai/verifiers/__init__.py new file mode 100644 index 0000000..02092af --- /dev/null +++ b/esbmc_ai/verifiers/__init__.py @@ -0,0 +1,4 @@ +from .base_source_verifier import BaseSourceVerifier +from .esbmc import ESBMC + +__all__ = ["BaseSourceVerifier", "ESBMC"] diff --git a/esbmc_ai/verifiers/base_source_verifier.py b/esbmc_ai/verifiers/base_source_verifier.py new file mode 100644 index 0000000..28978da --- /dev/null +++ b/esbmc_ai/verifiers/base_source_verifier.py @@ -0,0 +1,124 @@ +# Author: Yiannis Charalambous + +"""This module holds the code for the base source code verifier.""" + +from dataclasses import dataclass +import re +from abc import ABC, abstractmethod +from typing import Any, Optional + +from esbmc_ai.solution import SourceFile +from esbmc_ai.base_config import BaseConfig, ConfigField, default_scenario + + +class SourceCodeParseError(Exception): + """Error that means that SolutionGenerator could not parse the source code + to return the right format.""" + + +class VerifierTimedOutException(Exception): + """Error that means that ESBMC timed out and so the error could not be + determined.""" + + +@dataclass +class VerifierOutput: + """Class that represents the verifier output.""" + + return_code: int + """The return code of the verifier.""" + output: str + """The output of the verifier.""" + + @abstractmethod + def successful(self) -> bool: + """If the verification was successful.""" + raise NotImplementedError() + + +class 
BaseSourceVerifier(ABC): + """The base class for creating a source verifier for ESBMC-AI. In order for + this class to work with ESBMC-AI, the constructor must have default values + to all arguments because it will be invoked without passing anything. + + Loading from the config will be permitted but it would be preferred if you + use the base class method `get_config_value`. The fields of the config that + are going to be loaded need to be declared and returned by the + `get_config_fields` method. The config loader will automatically preppend + the verifier_name declared to each key so that there are no clashes in the + config. So the verifier "esbmc" will have for key "timeout" the following + field in the config: "esbmc.timeout".""" + + def __init__(self, verifier_name: str) -> None: + """Verifier name needs to be a valid TOML key.""" + super().__init__() + pattern = re.compile(r"[a-zA-Z_]\w*") + assert pattern.match( + verifier_name + ), f"Invalid toml-friendly verifier name: {verifier_name}" + + self.verifier_name: str = verifier_name + self._config: BaseConfig + + @property + def config(self) -> BaseConfig: + return self._config + + @config.setter + def config(self, value: BaseConfig) -> None: + self._config: BaseConfig = value + + def get_config_fields(self) -> list[ConfigField]: + """Called during initialization, this is meant to return all config + fields that are going to be loaded from the config. The name that each + field has will automatically be prefixed with {verifier name}.""" + return [] + + def get_config_value(self, key: str) -> Any: + """Loads a value from the config. If the value is defined in the namespace + of the verifier name then that value will be returned. + """ + return self._config.get_value(key) + + def verify_source( + self, + source_file: SourceFile, + source_file_iteration: int = -1, + **kwargs: Any, + ) -> VerifierOutput: + """Verifies source_file, the kwargs are optional arguments that are + child dependent. 
For API purposes, the overriden method can provide the + abilitiy to override values that would be loaded from the config by + specifying them in the kwargs.""" + _ = source_file + _ = source_file_iteration + _ = kwargs + raise NotImplementedError() + + def apply_formatting(self, verifier_output: str, format: str) -> str: + """Applies a formatting style to the verifier output. This is used to + change the output to a different form for when it is supplied to the + LLM.""" + _ = verifier_output + _ = format + raise NotImplementedError() + + def get_error_line(self, verifier_output: str) -> Optional[int]: + """Returns the line number of where the error as occurred.""" + _ = verifier_output + raise NotImplementedError() + + def get_error_line_idx(self, verifier_output: str) -> Optional[int]: + """Returns the line index of where the error as occurred.""" + _ = verifier_output + raise NotImplementedError() + + def get_error_type(self, verifier_output: str) -> Optional[str]: + """Returns a string of the type of error found by the verifier output.""" + _ = verifier_output + raise NotImplementedError() + + def get_error_scenario(self, verifier_output: str) -> str: + """Gets the scenario for fixing the error from verifier output""" + _ = verifier_output + return default_scenario diff --git a/esbmc_ai/verifiers/dummy_verifier.py b/esbmc_ai/verifiers/dummy_verifier.py new file mode 100644 index 0000000..1def731 --- /dev/null +++ b/esbmc_ai/verifiers/dummy_verifier.py @@ -0,0 +1,108 @@ +# Author: Yiannis Charalambous + +"""This module holds the code for a dummy source code verifier.""" + +from typing import Any, Optional, override + +from esbmc_ai.solution import SourceFile +from esbmc_ai.config import ConfigField +from esbmc_ai.verifiers.base_source_verifier import BaseSourceVerifier, VerifierOutput + + +class DummyVerifierOutput(VerifierOutput): + """Class that represents the dummy verifier output. 
The output is going to + be blank, and the return code can be set to whatever value. If it is 0, it + will be successful.""" + + return_code: int + """The return code of the verifier.""" + output: str + """The output of the verifier.""" + + @override + def successful(self) -> bool: + return self.return_code == 0 + + +class DummyVerifier(BaseSourceVerifier): + """Dummy verifier with pre-configured responses. Used for testing.""" + + def __init__( + self, responses: Optional[list[str]] = None, load_config: bool = True + ) -> None: + """Creates a new dummy verifier.""" + super().__init__(verifier_name="dummy_verifier") + self._responses: Optional[list[str]] = responses + self._current_response: int = 0 + self._load_config = load_config + + @property + def responses(self) -> list[str]: + if self._load_config: + return ( + self._responses + if self._responses + else self.get_config_value("responses") + ) + else: + return self._responses if self._responses else [] + + @responses.setter + def responses(self, value: Optional[list[str]]) -> None: + self._responses = value + + def set_response_counter(self, value: Optional[int] = None) -> None: + self._current_response = value if value else 0 + assert ( + 0 <= self._current_response < len(self.responses) + ), f"Responses index set out of range: 0 <= {self._current_response} < {len(self.responses)}" + + @override + def get_config_fields(self) -> list[ConfigField]: + return [ + ConfigField( + name="responses", + default_value=[], + error_message="Invalid, needs to be an array of strings.", + validate=lambda v: isinstance(v, str), + ) + ] + + @override + def verify_source( + self, + source_file: SourceFile, + source_file_iteration: int = -1, + **kwargs: Any, + ) -> DummyVerifierOutput: + """Verifies source_file, the kwargs are optional arguments that are + child dependent. 
For API purposes, the overriden method can provide the + abilitiy to override values that would be loaded from the config by + specifying them in the kwargs.""" + _ = source_file + _ = source_file_iteration + value = kwargs["value"] if "value" in kwargs else 0 + return DummyVerifierOutput(value, self.responses[self._current_response]) + + @override + def apply_formatting(self, verifier_output: str, format: str) -> str: + _ = format + return verifier_output + + @override + def get_error_line(self, verifier_output: str) -> Optional[int]: + """1""" + _ = verifier_output + return 1 + + @override + def get_error_line_idx(self, verifier_output: str) -> Optional[int]: + """0""" + _ = verifier_output + return 0 + + @override + def get_error_type(self, verifier_output: str) -> Optional[str]: + """Returns empty string""" + _ = verifier_output + return "" diff --git a/esbmc_ai/esbmc_util.py b/esbmc_ai/verifiers/esbmc.py similarity index 57% rename from esbmc_ai/esbmc_util.py rename to esbmc_ai/verifiers/esbmc.py index cb83280..9f2eead 100644 --- a/esbmc_ai/esbmc_util.py +++ b/esbmc_ai/verifiers/esbmc.py @@ -4,17 +4,27 @@ import sys from subprocess import PIPE, STDOUT, run, CompletedProcess from pathlib import Path -from typing import Optional +from typing_extensions import Any, Optional, override +from esbmc_ai.config import Config from esbmc_ai.solution import SourceFile -from esbmc_ai.config import default_scenario +from esbmc_ai.base_config import BaseConfig, default_scenario +from esbmc_ai.verifiers.base_source_verifier import ( + BaseSourceVerifier, + SourceCodeParseError, + VerifierOutput, + VerifierTimedOutException, +) + + +class ESBMCOutput(VerifierOutput): + @override + def successful(self) -> bool: + return self.return_code == 0 -class ESBMCUtil: - @classmethod - def init(cls, esbmc_path: Path) -> None: - cls.esbmc_path: Path = esbmc_path +class ESBMC(BaseSourceVerifier): @classmethod def esbmc_get_violated_property(cls, esbmc_output: str) -> Optional[str]: """Gets 
the violated property line of the ESBMC output.""" @@ -31,47 +41,7 @@ def esbmc_get_counter_example(cls, esbmc_output: str) -> Optional[str]: idx: int = esbmc_output.find("[Counterexample]\n") if idx == -1: return None - else: - return esbmc_output[idx:] - - @classmethod - def esbmc_get_error_type(cls, esbmc_output: str) -> str: - """Gets the error of violated property, the entire line.""" - # TODO Test me - # Start search from the marker. - marker: str = "Violated property:\n" - violated_property_index: int = esbmc_output.rfind(marker) + len(marker) - from_loc_error_msg: str = esbmc_output[violated_property_index:] - # Find second new line which contains the location of the violated - # property and that should point to the line with the type of error. - # In this case, the type of error is the "scenario". - scenario_index: int = from_loc_error_msg.find("\n") - scenario: str = from_loc_error_msg[scenario_index + 1 :] - scenario_end_l_index: int = scenario.find("\n") - scenario = scenario[:scenario_end_l_index].strip() - - if not scenario: - return default_scenario - - return scenario - - @classmethod - def get_source_code_err_line(cls, esbmc_output: str) -> Optional[int]: - """Gets the error line of the esbmc_output, regardless if it is a - counterexample or clang output.""" - line: Optional[int] = cls.get_esbmc_err_line(esbmc_output) - if not line: - line = ESBMCUtil.get_clang_err_line(esbmc_output) - return line - - @classmethod - def get_source_code_err_line_idx(cls, esbmc_output: str) -> Optional[int]: - """Gets the error line index of the esbmc_output regardless if it is a - counterexample or clang output.""" - line: Optional[int] = cls.get_source_code_err_line_idx(esbmc_output) - if not line: - return ESBMCUtil.get_clang_err_line_idx(esbmc_output) - return line - 1 + return esbmc_output[idx:] @classmethod def get_esbmc_err_line(cls, esbmc_output: str) -> Optional[int]: @@ -93,8 +63,7 @@ def get_esbmc_err_line_idx(cls, esbmc_output: str) -> Optional[int]: 
line: Optional[int] = cls.get_esbmc_err_line(esbmc_output) if line: return line - 1 - else: - return None + return None @classmethod def get_clang_err_line(cls, clang_output: str) -> Optional[int]: @@ -118,74 +87,60 @@ def get_clang_err_line_idx(cls, clang_output: str) -> Optional[int]: line: Optional[int] = cls.get_clang_err_line(clang_output) if line: return line - 1 - else: - return None - - @classmethod - def esbmc( - cls, - path: Path, - esbmc_params: list, - timeout: Optional[int] = None, - ): - """Exit code will be 0 if verification successful, 1 if verification - failed. And any other number for compilation error/general errors.""" - # Build parameters - esbmc_cmd = [str(cls.esbmc_path)] - esbmc_cmd.extend(esbmc_params) - esbmc_cmd.append(str(path)) - - if "--timeout" in esbmc_cmd: - print( - 'Do not add --timeout to ESBMC parameters, instead specify it in "verifier_timeout".' - ) - sys.exit(1) - - esbmc_cmd.extend(["--timeout", str(timeout)]) + return None - # Add slack time to process to allow verifier to timeout and end gracefully. 
- process_timeout: Optional[float] = timeout + 10 if timeout else None + def __init__(self) -> None: + super().__init__("esbmc") + self.config = Config() - # Run ESBMC and get output - process: CompletedProcess = run( - esbmc_cmd, - stdout=PIPE, - stderr=STDOUT, - timeout=process_timeout, - ) + @property + def esbmc_path(self) -> Path: + return self.get_config_value("verifier.esbmc.path") - output: str = process.stdout.decode("utf-8") - return process.returncode, output - - @classmethod - def esbmc_load_source_code( - cls, + @override + def verify_source( + self, source_file: SourceFile, - source_file_content_index: int, - esbmc_params: list, - auto_clean: bool, + source_file_iteration: int = -1, + esbmc_params: tuple = (), + auto_clean: bool = False, + entry_function: str = "main", temp_file_dir: Optional[Path] = None, timeout: Optional[int] = None, - ): + **kwargs: Any, + ) -> ESBMCOutput: + _ = kwargs + + if "--timeout" in esbmc_params: + print( + "Do not add --timeout to ESBMC parameters, instead specify it in its own field." + ) + sys.exit(1) + if "--function" in esbmc_params: + print( + "Don't add --function to ESBMC parameters, instead specify it in its own field." + ) + sys.exit(1) file_path: Path if temp_file_dir: file_path = source_file.save_file( file_path=Path(temp_file_dir), temp_dir=False, - index=source_file_content_index, + index=source_file_iteration, ) else: file_path = source_file.save_file( file_path=None, temp_dir=True, - index=source_file_content_index, + index=source_file_iteration, ) # Call ESBMC to temporary folder. 
- results = cls.esbmc( + results = self._esbmc( path=file_path, esbmc_params=esbmc_params, + entry_function=entry_function, timeout=timeout, ) @@ -195,4 +150,114 @@ def esbmc_load_source_code( os.remove(file_path) # Return - return results + return_code, output = results + return ESBMCOutput( + return_code=return_code, + output=output, + ) + + @override + def apply_formatting(self, verifier_output: str, format: str) -> str: + """Gets the formatted output ESBMC output, based on the esbmc_output_type + passed.""" + # Check for parsing error + if "ERROR: PARSING ERROR" in verifier_output: + # Parsing errors are usually small in nature. + raise SourceCodeParseError() + + if "ERROR: Timed out" in verifier_output: + raise VerifierTimedOutException() + + match format: + case "vp": + value: Optional[str] = self.esbmc_get_violated_property(verifier_output) + if not value: + raise ValueError("Not found violated property." + verifier_output) + return value + case "ce": + value: Optional[str] = self.esbmc_get_counter_example(verifier_output) + if not value: + raise ValueError("Not found counterexample.") + return value + case "full": + return verifier_output + case _: + raise ValueError(f"Not a valid ESBMC output type: {format}") + + @override + def get_error_type(self, verifier_output: str) -> Optional[str]: + """Gets the error of violated property, the entire line.""" + # TODO Test me + # Start search from the marker. + marker: str = "Violated property:\n" + violated_property_index: int = verifier_output.rfind(marker) + len(marker) + from_loc_error_msg: str = verifier_output[violated_property_index:] + # Find second new line which contains the location of the violated + # property and that should point to the line with the type of error. + # In this case, the type of error is the "scenario". 
+ scenario_index: int = from_loc_error_msg.find("\n") + scenario: str = from_loc_error_msg[scenario_index + 1 :] + scenario_end_l_index: int = scenario.find("\n") + scenario = scenario[:scenario_end_l_index].strip() + + return scenario + + @override + def get_error_scenario(self, verifier_output: str) -> str: + scenario: Optional[str] = self.get_error_type(verifier_output) + if not scenario: + return default_scenario + return scenario + + @override + def get_error_line(self, verifier_output: str) -> Optional[int]: + """Gets the error line of the esbmc_output, regardless if it is a + counterexample or clang output.""" + line: Optional[int] = self.get_esbmc_err_line(verifier_output) + if not line: + line = self.get_clang_err_line(verifier_output) + return line + + @override + def get_error_line_idx(self, verifier_output: str) -> Optional[int]: + """Gets the error line index of the esbmc_output regardless if it is a + counterexample or clang output.""" + line: Optional[int] = self.get_esbmc_err_line_idx(verifier_output) + if not line: + return self.get_clang_err_line_idx(verifier_output) + return line - 1 + + def _esbmc( + self, + path: Path, + esbmc_params: tuple, + entry_function: str, + timeout: Optional[int] = None, + ): + """Exit code will be 0 if verification successful, 1 if verification + failed. And any other number for compilation error/general errors.""" + # TODO verify_source + # Build parameters + esbmc_cmd = [str(self.esbmc_path)] + esbmc_cmd.extend(esbmc_params) + esbmc_cmd.append(str(path)) + + # Add timeout suffix for parameter. + esbmc_cmd.extend(["--timeout", str(timeout) + "s"]) + # Add entry function for parameter. + esbmc_cmd.extend(["--function", entry_function]) + + # Add slack time to process to allow verifier to timeout and end gracefully. 
+ process_timeout: Optional[float] = timeout + 10 if timeout else None + + # Run ESBMC and get output + process: CompletedProcess = run( + esbmc_cmd, + stdout=PIPE, + stderr=STDOUT, + timeout=process_timeout, + check=False, + ) + + output: str = process.stdout.decode("utf-8") + return process.returncode, output diff --git a/notebooks/ast.ipynb b/notebooks/ast.ipynb deleted file mode 100644 index cd05e5f..0000000 --- a/notebooks/ast.ipynb +++ /dev/null @@ -1,435 +0,0 @@ -{ - "cells": [ - { - "attachments": {}, - "cell_type": "markdown", - "metadata": { - "tags": [] - }, - "source": [ - "# Test LLVM AST Notebook\n", - "\n", - "## Author: Yiannis Charalambous\n" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "from clang.cindex import Config\n", - "import clang.native\n", - "import clang.cindex\n", - "import sys\n", - "from typing import NamedTuple\n" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "# Connect the Python API of Clang to the libclang.so file bundled in the libclang PyPI package.\n", - "Config.library_file = os.path.join(\n", - " os.path.dirname(clang.native.__file__),\n", - " \"libclang.so\",\n", - ")\n", - "\n", - "module_path = os.path.abspath(os.path.join(\"..\"))\n", - "if module_path not in sys.path:\n", - " sys.path.append(module_path)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Found a [line=4, col=5]\n", - " Token int TokenKind.KEYWORD\n", - " Token a TokenKind.IDENTIFIER\n", - "Start: 42, End: 47, Range: 5\n", - "\n", - "Found b [line=4, col=8]\n", - " Token int TokenKind.KEYWORD\n", - " Token a TokenKind.IDENTIFIER\n", - " Token , TokenKind.PUNCTUATION\n", - " Token b TokenKind.IDENTIFIER\n", - "Start: 42, End: 50, Range: 8\n", - "\n", - "Found __VERIFIER_atomic_acquire 
[line=5, col=6]\n", - " Token void TokenKind.KEYWORD\n", - " Token __VERIFIER_atomic_acquire TokenKind.IDENTIFIER\n", - " Token ( TokenKind.PUNCTUATION\n", - " Token void TokenKind.KEYWORD\n", - " Token ) TokenKind.PUNCTUATION\n", - " Token { TokenKind.PUNCTUATION\n", - " Token __VERIFIER_assume TokenKind.IDENTIFIER\n", - " Token ( TokenKind.PUNCTUATION\n", - " Token a TokenKind.IDENTIFIER\n", - " Token == TokenKind.PUNCTUATION\n", - " Token 0 TokenKind.LITERAL\n", - " Token ) TokenKind.PUNCTUATION\n", - " Token ; TokenKind.PUNCTUATION\n", - " Token a TokenKind.IDENTIFIER\n", - " Token = TokenKind.PUNCTUATION\n", - " Token 1 TokenKind.LITERAL\n", - " Token ; TokenKind.PUNCTUATION\n", - " Token } TokenKind.PUNCTUATION\n", - "Start: 52, End: 134, Range: 82\n", - "\n", - "Found c [line=10, col=7]\n", - " Token void TokenKind.KEYWORD\n", - " Token * TokenKind.PUNCTUATION\n", - " Token c TokenKind.IDENTIFIER\n", - " Token ( TokenKind.PUNCTUATION\n", - " Token void TokenKind.KEYWORD\n", - " Token * TokenKind.PUNCTUATION\n", - " Token arg TokenKind.IDENTIFIER\n", - " Token ) TokenKind.PUNCTUATION\n", - " Token { TokenKind.PUNCTUATION\n", - " Token ; TokenKind.PUNCTUATION\n", - " Token __VERIFIER_atomic_acquire TokenKind.IDENTIFIER\n", - " Token ( TokenKind.PUNCTUATION\n", - " Token ) TokenKind.PUNCTUATION\n", - " Token ; TokenKind.PUNCTUATION\n", - " Token b TokenKind.IDENTIFIER\n", - " Token = TokenKind.PUNCTUATION\n", - " Token 1 TokenKind.LITERAL\n", - " Token ; TokenKind.PUNCTUATION\n", - " Token return TokenKind.KEYWORD\n", - " Token NULL TokenKind.IDENTIFIER\n", - " Token ; TokenKind.PUNCTUATION\n", - " Token } TokenKind.PUNCTUATION\n", - "Start: 135, End: 224, Range: 89\n", - "\n", - "Found d [line=17, col=11]\n", - " Token pthread_t TokenKind.IDENTIFIER\n", - " Token d TokenKind.IDENTIFIER\n", - "Start: 225, End: 236, Range: 11\n", - "\n", - "Found main [line=18, col=5]\n", - " Token int TokenKind.KEYWORD\n", - " Token main TokenKind.IDENTIFIER\n", - " Token ( 
TokenKind.PUNCTUATION\n", - " Token ) TokenKind.PUNCTUATION\n", - " Token { TokenKind.PUNCTUATION\n", - " Token pthread_create TokenKind.IDENTIFIER\n", - " Token ( TokenKind.PUNCTUATION\n", - " Token & TokenKind.PUNCTUATION\n", - " Token d TokenKind.IDENTIFIER\n", - " Token , TokenKind.PUNCTUATION\n", - " Token 0 TokenKind.LITERAL\n", - " Token , TokenKind.PUNCTUATION\n", - " Token c TokenKind.IDENTIFIER\n", - " Token , TokenKind.PUNCTUATION\n", - " Token 0 TokenKind.LITERAL\n", - " Token ) TokenKind.PUNCTUATION\n", - " Token ; TokenKind.PUNCTUATION\n", - " Token __VERIFIER_atomic_acquire TokenKind.IDENTIFIER\n", - " Token ( TokenKind.PUNCTUATION\n", - " Token ) TokenKind.PUNCTUATION\n", - " Token ; TokenKind.PUNCTUATION\n", - " Token if TokenKind.KEYWORD\n", - " Token ( TokenKind.PUNCTUATION\n", - " Token ! TokenKind.PUNCTUATION\n", - " Token b TokenKind.IDENTIFIER\n", - " Token ) TokenKind.PUNCTUATION\n", - " Token assert TokenKind.IDENTIFIER\n", - " Token ( TokenKind.PUNCTUATION\n", - " Token 0 TokenKind.LITERAL\n", - " Token ) TokenKind.PUNCTUATION\n", - " Token ; TokenKind.PUNCTUATION\n", - " Token return TokenKind.KEYWORD\n", - " Token 0 TokenKind.LITERAL\n", - " Token ; TokenKind.PUNCTUATION\n", - " Token } TokenKind.PUNCTUATION\n", - "Start: 238, End: 363, Range: 125\n", - "\n", - "Total 6\n" - ] - } - ], - "source": [ - "FILE = \"../samples/threading.c\"\n", - "\n", - "\n", - "def get_declarations_local(root: clang.cindex.Cursor) -> list[clang.cindex.Cursor]:\n", - " declarations: list[clang.cindex.Cursor] = []\n", - " declarations_raw: set[str] = {}\n", - " tokens: list[clang.cindex.Token] = []\n", - " # Scan all direct symbols in root.\n", - " for child in root.get_children():\n", - " # print(f\"Scanning: {child.spelling}\")\n", - " node: clang.cindex.Cursor = child\n", - " kind: clang.cindex.CursorKind = node.kind\n", - " # Check if it is actually from the file.\n", - " if (\n", - " kind.is_declaration()\n", - " and node.storage_class == 
clang.cindex.StorageClass.NONE\n", - " ):\n", - " print(\n", - " f\"Found {node.spelling} [line={node.location.line}, col={node.location.column}]\"\n", - " )\n", - " tokens: clang.cindex.Token = node.get_tokens()\n", - " for token in tokens:\n", - " print(f\" Token {token.spelling} {token.kind}\")\n", - " loc: clang.cindex.SourceRange = node.extent\n", - " end: clang.cindex.SourceLocation = loc.end\n", - " start: clang.cindex.SourceLocation = loc.start\n", - " print(\n", - " f\"Start: {start.offset}, End: {end.offset}, Range: {end.offset - start.offset}\"\n", - " )\n", - " print()\n", - " declarations.append(node)\n", - " return declarations\n", - "\n", - "\n", - "index: clang.cindex.Index = clang.cindex.Index.create()\n", - "tu: clang.cindex.TranslationUnit = index.parse(FILE)\n", - "root: clang.cindex.Cursor = tu.cursor\n", - "declarations: clang.cindex.Cursor = get_declarations_local(root)\n", - "\n", - "print(f\"Total {len(declarations)}\")\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": { - "tags": [] - }, - "source": [ - "## Reversing Reach AST To Source Code\n", - "\n", - "The only issue I have found, multiple declarations in one statement need to be recognized and the nodes combined:\n", - "\n", - "```c\n", - "int a, b;\n", - "```\n" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Code for a:\n", - "```\n", - "int a\n", - "```\n", - "\n", - "Code for b:\n", - "```\n", - "int a, b\n", - "```\n", - "\n", - "Code for __VERIFIER_atomic_acquire():\n", - "```\n", - "void __VERIFIER_atomic_acquire(void)\n", - "{\n", - " __VERIFIER_assume(a == 0);\n", - " a = 1;\n", - "}\n", - "```\n", - "\n", - "Code for c(void *):\n", - "```\n", - "void *c(void *arg)\n", - "{\n", - " ;\n", - " __VERIFIER_atomic_acquire();\n", - " b = 1;\n", - " return NULL;\n", - "}\n", - "```\n", - "\n", - "Code for d:\n", - "```\n", - "pthread_t d\n", - 
"```\n", - "\n", - "Code for main():\n", - "```\n", - "int main()\n", - "{\n", - " pthread_create(&d, 0, c, 0);\n", - " __VERIFIER_atomic_acquire();\n", - " if (!b)\n", - " assert(0);\n", - " return 0;\n", - "}\n", - "```\n", - "\n" - ] - } - ], - "source": [ - "with open(FILE) as file:\n", - " source_code: str = file.read()\n", - "\n", - "\n", - "def get_node_source_code(source_code: str, node: clang.cindex.Cursor) -> str:\n", - " loc: clang.cindex.SourceRange = node.extent\n", - " start: clang.cindex.SourceLocation = loc.start\n", - " end: clang.cindex.SourceLocation = loc.end\n", - " return source_code[start.offset : end.offset]\n", - "\n", - "\n", - "for node in declarations:\n", - " print(f\"Code for {node.displayname}:\")\n", - " print(\"```\")\n", - " print(get_node_source_code(source_code, node))\n", - " print(\"```\")\n", - " print()\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Test Code\n" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "from esbmc_ai.frontend.ast import ClangAST\n" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "__VERIFIER_atomic_acquire()\n", - "\n", - "\n", - "\n", - "c(arg: void *)\n", - "\n", - "\n", - "\n", - "main()\n", - "\n", - "\n", - "\n" - ] - } - ], - "source": [ - "file = \"../samples/threading.c\"\n", - "cast = ClangAST(file)\n", - "functions = cast.get_fn_decl()\n", - "\n", - "for fn in functions:\n", - " print(str(fn) + \"\\n\")\n", - " # Seems like different cursors have the same translation unit...\n", - " print(fn.cursor)\n", - " print(fn.cursor.translation_unit)\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Test Code 2\n" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - 
"output_type": "stream", - "text": [ - "struct linear {value: int}\n", - "typedef (LinearTypeDef) struct linear {struct linear: struct linear}\n", - "Point {x: int, y: int}\n", - "typedef (Point) Point {Point: Point}\n", - "enum Types {ONE: int, TWO: int, THREE: int}\n", - "typedef (Typest) enum Types {Types: enum Types}\n", - "union Combines {a: int, b: int, c: int}\n", - "typedef (CombinesTypeDef) union Combines {union Combines: union Combines}\n" - ] - } - ], - "source": [ - "file = \"./samples/typedefs.c\"\n", - "cast = ClangAST(file)\n", - "functions = cast.get_type_decl()\n", - "\n", - "for fn in functions:\n", - " print(fn)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "#include \"/usr/include/stdlib.h\"\n", - "#include \"/usr/include/assert.h\"\n" - ] - } - ], - "source": [ - "file = \"./samples/typedefs.c\"\n", - "cast: ClangAST = ClangAST(file)\n", - "includes = cast.get_include_directives()\n", - "\n", - "for include in includes:\n", - " print(include)\n" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "esbmc-ai-awqrJrdH", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.3" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} \ No newline at end of file diff --git a/notebooks/samples/typedefs.c b/notebooks/samples/typedefs.c deleted file mode 100644 index d0ea9f2..0000000 --- a/notebooks/samples/typedefs.c +++ /dev/null @@ -1,54 +0,0 @@ -#include -#include - -#define NUM_ONE 1 -#define NUM_TWO 2 - -struct linear -{ - int value; -}; - -typedef struct linear LinearTypeDef; - -typedef struct -{ - int x; - int y; -} Point; - -Point a; -Point *b; - -int c; - -char *d; - -typedef enum 
Types -{ - ONE, - TWO, - THREE -} Typest; - -enum Types e = ONE; - -Typest f = TWO; - -union Combines -{ - int a; - int b; - int c; -}; - -typedef union Combines CombinesTypeDef; - -int main() -{ - Point *a = (Point *)malloc(sizeof(Point)); - if (a != NULL) - return -1; - free(a); - return 0; -} \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index f613127..69a81b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,6 +64,8 @@ dependencies = [ "hatch", "transformers", "torch", + "langchain", + "langchain_community" ] [project.scripts] diff --git a/tests/regtest/_regtest_outputs/test_solution_generator.test_generate_solution.out b/tests/regtest/_regtest_outputs/test_solution_generator.test_generate_solution.out new file mode 100644 index 0000000..32b9755 --- /dev/null +++ b/tests/regtest/_regtest_outputs/test_solution_generator.test_generate_solution.out @@ -0,0 +1,4 @@ +FinishReason.stop +22222 +FinishReason.stop +33333 diff --git a/tests/regtest/test_solution_generator.py b/tests/regtest/test_solution_generator.py new file mode 100644 index 0000000..e911a23 --- /dev/null +++ b/tests/regtest/test_solution_generator.py @@ -0,0 +1,53 @@ +# Author: Yiannis Charalambous + + +from langchain.schema import HumanMessage, SystemMessage +from langchain_core.language_models import FakeListChatModel + +from esbmc_ai.ai_models import AIModel +from esbmc_ai.verifiers.dummy_verifier import DummyVerifier +from esbmc_ai.chats.solution_generator import SolutionGenerator +from esbmc_ai.config import FixCodeScenario + + +def test_generate_solution(regtest) -> None: + with open( + "tests/samples/esbmc_output/line_test/cartpole_95_safe.c-amalgamation-80.c", "r" + ) as file: + esbmc_output: str = file.read() + + verifier = DummyVerifier(responses=[esbmc_output] * 2) + + chat = SolutionGenerator( + scenarios={ + "base": FixCodeScenario( + initial=HumanMessage( + "{source_code}{esbmc_output}{error_line}{error_type}" + ), + system=( + SystemMessage( + 
content="System:{source_code}{esbmc_output}{error_line}{error_type}" + ), + ), + ) + }, + verifier=verifier, + ai_model=AIModel("test", 10000000), + llm=FakeListChatModel(responses=["22222", "33333"]), + source_code_format="full", + esbmc_output_type="full", + ) + + chat.update_state("11111", esbmc_output) + sol, res = chat.generate_solution(ignore_system_message=False) + + with regtest: + print(res) + print(sol) + + chat.update_state("11111", esbmc_output) + sol, res = chat.generate_solution(ignore_system_message=False) + + with regtest: + print(res) + print(sol) diff --git a/tests/test_config.py b/tests/test_config.py index b663717..255f5cf 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -2,8 +2,7 @@ from pytest import raises -import esbmc_ai.config as config - +from esbmc_ai.config import Config from esbmc_ai.ai_models import is_valid_ai_model @@ -16,7 +15,7 @@ def test_load_custom_ai() -> None: } } - config._load_custom_ai(custom_ai_config) + Config()._load_custom_ai(custom_ai_config) assert is_valid_ai_model("example_ai") @@ -27,59 +26,59 @@ def test_load_custom_ai_fail() -> None: "example_ai_2": { "max_tokens": "1024", "url": "www.example.com", - "config_message": "example", + "server_type": "ollama", } } - with raises(AssertionError): - config._load_custom_ai(ai_conf) + with raises(TypeError): + Config()._validate_custom_ai(ai_conf) # Wrong max_tokens value ai_conf: dict = { "example_ai_2": { "max_tokens": 0, "url": "www.example.com", - "config_message": "example", + "server_type": "ollama", } } - with raises(AssertionError): - config._load_custom_ai(ai_conf) + with raises(ValueError): + Config()._validate_custom_ai(ai_conf) # Missing max_tokens ai_conf: dict = { "example_ai_2": { "url": "www.example.com", - "config_message": "example", + "server_type": "ollama", } } - with raises(AssertionError): - config._load_custom_ai(ai_conf) + with raises(KeyError): + Config()._validate_custom_ai(ai_conf) # Missing url ai_conf: dict = { "example_ai_2": 
{ - "max_tokens": 0, - "config_message": "example", + "max_tokens": 1000, + "server_type": "ollama", } } - with raises(AssertionError): - config._load_custom_ai(ai_conf) + with raises(KeyError): + Config()._validate_custom_ai(ai_conf) - # Missing config message + # Missing server type ai_conf: dict = { "example_ai_2": { - "max_tokens": 0, + "max_tokens": 100, "url": "www.example.com", } } - with raises(AssertionError): - config._load_custom_ai(ai_conf) + with raises(KeyError): + Config()._validate_custom_ai(ai_conf) # Test load empty ai_conf: dict = {} - config._load_custom_ai(ai_conf) + Config()._validate_custom_ai(ai_conf) diff --git a/tests/test_esbmc_util.py b/tests/test_esbmc_util.py index 09f78bd..d4cedff 100644 --- a/tests/test_esbmc_util.py +++ b/tests/test_esbmc_util.py @@ -3,7 +3,7 @@ import pytest from os import listdir -from esbmc_ai.esbmc_util import ESBMCUtil +from esbmc_ai.verifiers import ESBMC as ESBMCUtil @pytest.fixture(scope="module") diff --git a/tests/test_latest_state_solution_generator.py b/tests/test_latest_state_solution_generator.py index f81371b..9d37d6b 100644 --- a/tests/test_latest_state_solution_generator.py +++ b/tests/test_latest_state_solution_generator.py @@ -1,15 +1,16 @@ # Author: Yiannis Charalambous -from typing import Any, Optional +from typing import Optional from langchain_core.language_models import FakeListChatModel import pytest -from langchain.schema import HumanMessage, AIMessage, SystemMessage +from langchain.schema import BaseMessage, HumanMessage, AIMessage, SystemMessage -from esbmc_ai.config import default_scenario +from esbmc_ai.config import FixCodeScenario, default_scenario from esbmc_ai.ai_models import AIModel from esbmc_ai.chat_response import ChatResponse from esbmc_ai.chats.latest_state_solution_generator import LatestStateSolutionGenerator +from esbmc_ai.verifiers import ESBMC @pytest.fixture(scope="function") @@ -29,25 +30,26 @@ def test_send_message(setup_llm_model) -> None: llm, model = 
setup_llm_model solution_generator = LatestStateSolutionGenerator( + verifier=ESBMC(), scenarios={ - "base": { - "initial": "Initial test message", - "system": ( + "base": FixCodeScenario( + initial=HumanMessage(content="Initial test message"), + system=( SystemMessage(content="Test message 1"), HumanMessage(content="Test message 2"), AIMessage(content="Test message 3"), ), - } + ) }, llm=llm, ai_model=model, ) - initial_prompt = solution_generator.scenarios[default_scenario]["initial"] + initial_prompt: BaseMessage = solution_generator.scenarios[default_scenario].initial def send_message_mock(message: Optional[str] = None) -> ChatResponse: assert len(solution_generator.messages) == 1 - assert solution_generator.messages[0].content == initial_prompt + assert solution_generator.messages[0].content == initial_prompt.content assert solution_generator.messages[0].type == HumanMessage(content="").type return ChatResponse() @@ -60,18 +62,18 @@ def send_message_mock(message: Optional[str] = None) -> ChatResponse: # Check now if the message stack is wiped per generate solution call. 
solution_generator.generate_solution(ignore_system_message=True) - solution_generator.scenarios[default_scenario]["initial"] = initial_prompt = ( - "aaaaaaa" + solution_generator.scenarios[default_scenario].initial = initial_prompt = ( + HumanMessage("aaaaaaa") ) solution_generator.generate_solution(ignore_system_message=True) - solution_generator.scenarios[default_scenario]["initial"] = initial_prompt = ( - "bbbbbbb" + solution_generator.scenarios[default_scenario].initial = initial_prompt = ( + HumanMessage("bbbbbbb") ) solution_generator.generate_solution(ignore_system_message=True) - solution_generator.scenarios[default_scenario]["initial"] = initial_prompt = ( - "ccccccc" + solution_generator.scenarios[default_scenario].initial = initial_prompt = ( + HumanMessage("ccccccc") ) @@ -81,15 +83,16 @@ def test_message_stack(setup_llm_model) -> None: solution_generator = LatestStateSolutionGenerator( llm=llm, ai_model=model, + verifier=ESBMC(), scenarios={ - "base": { - "initial": "Initial test message", - "system": ( + "base": FixCodeScenario( + initial=HumanMessage("Initial test message"), + system=( SystemMessage(content="Test message 1"), HumanMessage(content="Test message 2"), AIMessage(content="Test message 3"), ), - } + ) }, ) @@ -100,10 +103,14 @@ def test_message_stack(setup_llm_model) -> None: solution, _ = solution_generator.generate_solution(ignore_system_message=True) assert solution == llm.responses[0] - solution_generator.scenarios[default_scenario]["initial"] = "Test message 2" + solution_generator.scenarios[default_scenario].initial = HumanMessage( + "Test message 2" + ) solution, _ = solution_generator.generate_solution(ignore_system_message=True) assert solution == llm.responses[1] - solution_generator.scenarios[default_scenario]["initial"] = "Test message 3" + solution_generator.scenarios[default_scenario].initial = HumanMessage( + "Test message 3" + ) solution, _ = solution_generator.generate_solution(ignore_system_message=True) assert solution 
== llm.responses[2] diff --git a/tests/test_reverse_order_solution_generator.py b/tests/test_reverse_order_solution_generator.py index 000fb64..d661d80 100644 --- a/tests/test_reverse_order_solution_generator.py +++ b/tests/test_reverse_order_solution_generator.py @@ -9,11 +9,12 @@ SystemMessage, ) -from esbmc_ai.config import default_scenario +from esbmc_ai.config import FixCodeScenario, default_scenario from esbmc_ai.ai_models import AIModel from esbmc_ai.chats.reverse_order_solution_generator import ( ReverseOrderSolutionGenerator, ) +from esbmc_ai.verifiers import ESBMC @pytest.fixture(scope="function") @@ -41,15 +42,16 @@ def test_message_stack(setup_llm_model) -> None: solution_generator = ReverseOrderSolutionGenerator( llm=llm, ai_model=model, + verifier=ESBMC(), scenarios={ - "base": { - "initial": "Initial test message", - "system": ( + "base": FixCodeScenario( + initial=HumanMessage("Initial test message"), + system=( SystemMessage(content="Test message 1"), HumanMessage(content="Test message 2"), AIMessage(content="Test message 3"), ), - } + ) }, ) @@ -60,10 +62,14 @@ def test_message_stack(setup_llm_model) -> None: solution, _ = solution_generator.generate_solution(ignore_system_message=True) assert solution == llm.responses[0] - solution_generator.scenarios[default_scenario]["initial"] = "Test message 2" + solution_generator.scenarios[default_scenario].initial = HumanMessage( + "Test message 2" + ) solution, _ = solution_generator.generate_solution(ignore_system_message=True) assert solution == llm.responses[1] - solution_generator.scenarios[default_scenario]["initial"] = "Test message 3" + solution_generator.scenarios[default_scenario].initial = HumanMessage( + "Test message 3" + ) solution, _ = solution_generator.generate_solution(ignore_system_message=True) assert solution == llm.responses[2] diff --git a/tests/test_solution.py b/tests/test_solution.py index 4ab3c98..2f4bdd9 100644 --- a/tests/test_solution.py +++ b/tests/test_solution.py @@ -1,5 
+1,6 @@ # Author: Yiannis Charalambous +from pathlib import Path import pytest from esbmc_ai.solution import Solution, SourceFile @@ -15,32 +16,32 @@ def solution() -> Solution: def test_add_source_file(solution) -> None: src = '#include int main(int argc, char** argv) { printf("hello world\n"); return 0;}' - solution.add_source_file("Testfile1", src) - solution.add_source_file("Testfile2", src) - solution.add_source_file("Testfile3", src) + solution.add_source_file(SourceFile(Path("Testfile1"), src)) + solution.add_source_file(SourceFile(Path("Testfile2"), src)) + solution.add_source_file(SourceFile(Path("Testfile3"), src)) assert len(solution.files) == 3 assert ( - solution.files[0].file_path == "Testfile1" + solution.files[0].file_path == Path("Testfile1") and solution.files[0].latest_content == src ) assert ( - solution.files[1].file_path == "Testfile2" + solution.files[1].file_path == Path("Testfile2") and solution.files[1].latest_content == src ) assert ( - solution.files[2].file_path == "Testfile3" + solution.files[2].file_path == Path("Testfile3") and solution.files[2].latest_content == src ) assert ( len(solution.files_mapped) == 3 - and solution.files_mapped["Testfile1"].file_path == "Testfile1" + and solution.files_mapped["Testfile1"].file_path == Path("Testfile1") and solution.files_mapped["Testfile1"].initial_content == src - and solution.files_mapped["Testfile2"].file_path == "Testfile2" + and solution.files_mapped["Testfile2"].file_path == Path("Testfile2") and solution.files_mapped["Testfile2"].initial_content == src - and solution.files_mapped["Testfile3"].file_path == "Testfile3" + and solution.files_mapped["Testfile3"].file_path == Path("Testfile3") and solution.files_mapped["Testfile3"].initial_content == src ) diff --git a/tests/test_solution_generator.py b/tests/test_solution_generator.py index 040e9f0..0fcd5cd 100644 --- a/tests/test_solution_generator.py +++ b/tests/test_solution_generator.py @@ -4,8 +4,10 @@ from 
langchain_core.language_models import FakeListChatModel, FakeListLLM import pytest +from esbmc_ai.config import FixCodeScenario from esbmc_ai.ai_models import AIModel from esbmc_ai.chats.solution_generator import SolutionGenerator +from esbmc_ai.verifiers import ESBMC @pytest.fixture(scope="function") @@ -27,15 +29,16 @@ def test_call_update_state_first(setup_llm_model) -> None: solution_generator = SolutionGenerator( llm=llm, ai_model=model, + verifier=ESBMC(), scenarios={ - "base": { - "initial": "Initial test message", - "system": ( + "base": FixCodeScenario( + initial=HumanMessage("Initial test message"), + system=( SystemMessage(content="Test message 1"), HumanMessage(content="Test message 2"), AIMessage(content="Test message 3"), ), - } + ) }, ) @@ -45,14 +48,14 @@ def test_call_update_state_first(setup_llm_model) -> None: def test_get_code_from_solution(): assert ( - SolutionGenerator.get_code_from_solution( + SolutionGenerator.extract_code_from_solution( "This is a code block:\n\n```c\naaa\n```" ) == "aaa" ) assert ( - SolutionGenerator.get_code_from_solution( + SolutionGenerator.extract_code_from_solution( "This is a code block:\n\n```\nabc\n```" ) == "abc" @@ -60,12 +63,14 @@ def test_get_code_from_solution(): # Edge case assert ( - SolutionGenerator.get_code_from_solution("This is a code block:```abc\n```") + SolutionGenerator.extract_code_from_solution("This is a code block:```abc\n```") == "" ) assert ( - SolutionGenerator.get_code_from_solution("The repaired C code is:\n\n```\n```") + SolutionGenerator.extract_code_from_solution( + "The repaired C code is:\n\n```\n```" + ) == "" ) @@ -78,15 +83,18 @@ def test_substitution() -> None: chat = SolutionGenerator( scenarios={ - "base": { - "initial": "{source_code}{esbmc_output}{error_line}{error_type}", - "system": ( + "base": FixCodeScenario( + initial=HumanMessage( + "{source_code}{esbmc_output}{error_line}{error_type}" + ), + system=( SystemMessage( 
content="System:{source_code}{esbmc_output}{error_line}{error_type}" ), ), - } + ) }, + verifier=ESBMC(), ai_model=AIModel("test", 10000000), llm=FakeListChatModel(responses=["22222", "33333"]), source_code_format="full", @@ -104,17 +112,25 @@ def test_substitution() -> None: + "dereference failure: Access to object out of bounds" ) - assert chat.messages[1].content == "22222" + assert ( + chat.messages[1].content + == "11111" + + esbmc_output + + str(285) + + "dereference failure: Access to object out of bounds" + ) + + assert chat.messages[2].content == "22222" chat.update_state("11111", esbmc_output) chat.generate_solution(ignore_system_message=False) assert ( - chat.messages[2].content + chat.messages[3].content == "11111" + esbmc_output + str(285) + "dereference failure: Access to object out of bounds" ) - assert chat.messages[3].content == "33333" + assert chat.messages[4].content == "33333" diff --git a/tests/test_user_chat.py b/tests/test_user_chat.py index 2df2049..a528376 100644 --- a/tests/test_user_chat.py +++ b/tests/test_user_chat.py @@ -8,6 +8,7 @@ from esbmc_ai.ai_models import AIModel from esbmc_ai.chat_response import ChatResponse, FinishReason from esbmc_ai.chats.user_chat import UserChat +from esbmc_ai.verifiers import ESBMC as ESBMCUtil @pytest.fixture @@ -22,6 +23,7 @@ def setup(): system_messages=system_messages, ai_model=AIModel(name="test", tokens=12), llm=FakeListChatModel(responses=[summary_text]), + verifier=ESBMCUtil(), source_code="This is source code", esbmc_output="This is esbmc output", set_solution_messages=[ @@ -85,6 +87,7 @@ def test_substitution() -> None: system_messages=[ SystemMessage(content="{source_code}{esbmc_output}{error_line}{error_type}") ], + verifier=ESBMCUtil(), llm=FakeListChatModel(responses=["THIS IS A SUMMARY OF THE CONVERSATION"]), set_solution_messages=[HumanMessage(content="")], )