generated from intsystems/SoftwareTemplate-simplified
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
6e8b269
commit 608d1cd
Showing
21 changed files
with
4,082 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import openai | ||
from openai import OpenAI, AsyncOpenAI | ||
import json | ||
|
||
|
||
class BaseExtractior:
    """
    Base class for extraction helpers built on top of OpenAI's GPT models.

    Subclasses use the two protected helpers to run either a plain chat
    completion or a structured ("parsed") completion against the stored client.
    """

    def __init__(self, openai_client: OpenAI | AsyncOpenAI) -> None:
        """
        Stores the client and records whether it is asynchronous.

        Args:
            openai_client (OpenAI | AsyncOpenAI): Client instance used for all
                API calls made by this extractor.
        """
        self.openai_client = openai_client
        # True when an AsyncOpenAI client was supplied, False otherwise.
        self.is_async = isinstance(self.openai_client, AsyncOpenAI)

    def _get_completion_result(
        self, model, messages: list[dict], temperature: float
    ) -> str:
        """
        Runs a plain chat completion and returns the raw message text.

        Args:
            model: Model identifier passed through to the API.
            messages (list[dict]): Chat messages in OpenAI format.
            temperature (float): Sampling temperature.

        Returns:
            str: Content of the first choice's message.
        """
        response = self.openai_client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=temperature,
        )
        return response.choices[0].message.content

    def _get_completion_parsed_result(
        self, model, messages: list[dict], response_format, temperature: float, key: str
    ):
        """
        Runs a structured chat completion and extracts a single field.

        Args:
            model: Model identifier passed through to the API.
            messages (list[dict]): Chat messages in OpenAI format.
            response_format: Pydantic model describing the expected schema.
            temperature (float): Sampling temperature.
            key (str): Field name to pull out of the decoded JSON payload.

        Returns:
            The value stored under ``key`` in the model's JSON answer.
        """
        response = self.openai_client.beta.chat.completions.parse(
            model=model,
            messages=messages,
            response_format=response_format,
            temperature=temperature,
        )
        return json.loads(response.choices[0].message.content)[key]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
from openai import OpenAI, AsyncOpenAI | ||
|
||
from gmg_auto.NLI_node_extraction import NodeExtractor | ||
from gmg_auto.NLI_extract_edges import EdgeExtractor | ||
from gmg_auto.NLI_suggest_node_distribution import NodeDistributer | ||
from gmg_auto.graph_class import GMG_graph | ||
|
||
|
||
class NaturalLanguageInput:
    """
    Builds a graphical model graph (GMG) from a natural-language description.

    Typical usage::

        graph = NaturalLanguageInput(client).fit(text).construct_graph()
    """

    def __init__(self, openai_client: OpenAI | AsyncOpenAI) -> None:
        """
        Args:
            openai_client (OpenAI | AsyncOpenAI): Client shared by all extractors.
        """
        # Description is filled in by fit(); None means fit() was not called yet.
        self.descr: str | None = None
        self.openai_client = openai_client

    def fit(self, description: str):
        """Saves the description of the graph and prepares the extractors.

        Args:
            description (str): description of the graph

        Returns:
            NaturalLanguageInput: self, enabling call chaining
            (previously this method returned None despite the docstring).
        """
        self.descr = description

        # now initialize all extractors
        # NOTE(review): attribute name keeps the original (misspelled)
        # spelling "node_exractor" for backward compatibility.
        self.node_exractor = NodeExtractor(self.openai_client)
        self.edge_extractor = EdgeExtractor(self.openai_client)
        self.node_distributer = NodeDistributer(self.openai_client)

        return self

    def construct_graph(
        self, gpt_model: str = "gpt-4o-mini", temperature: float = 0
    ) -> GMG_graph:
        """Extracts nodes, edges and node distributions, then builds the graph.

        Args:
            gpt_model (str, optional): GPT model used for extraction. Defaults to 'gpt-4o-mini'.
            temperature (float, optional): The randomness of the model's output. Defaults to 0.

        Returns:
            GMG_graph: the constructed graph.

        Raises:
            RuntimeError: if called before :meth:`fit`.
        """
        if self.descr is None:
            # Fail fast with a clear message instead of an AttributeError
            # from the not-yet-created extractor attributes.
            raise RuntimeError("construct_graph() called before fit()")

        nodes = self.node_exractor.extract_nodes_gpt(self.descr, gpt_model, temperature)
        edges = self.edge_extractor.extract_all_edges(
            self.descr, nodes, gpt_model, temperature
        )
        node_distributions = self.node_distributer.suggest_vertex_distributions(
            self.descr, nodes, gpt_model, temperature
        )

        return GMG_graph(nodes, edges, node_distributions)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
"""Code for vertices extraction. | ||
This module provides functionality to extract edges between nodes in a graph using OpenAI's GPT models. It includes methods to determine the existence and direction of edges based on a textual description. | ||
The module contains the following functions: | ||
- _extract_one_edge_gpt: Determines the existence and direction of an edge between a pair of nodes. | ||
- extract_all_edges: Extracts all possible edges from a set of nodes based on a given description. | ||
""" | ||
|
||
from pydantic import BaseModel | ||
from enum import Enum | ||
from openai import OpenAI, AsyncOpenAI | ||
from gmg_auto.NLI_base_extractor import BaseExtractior | ||
from gmg_auto.config import node_extraction_sys_message, edge_extraction_str_template | ||
|
||
|
||
class ArrowEnum(str, Enum):
    """Possible edge-direction answers the model may give for a node pair."""

    no = "no arrow"
    forward = "forward arrow"
    backward = "backward arrow"
|
||
|
||
class ArrowType(BaseModel):
    """Structured-output schema: the arrow direction chosen for one edge."""

    arrow_type: ArrowEnum
|
||
|
||
class EdgeExtractor(BaseExtractior):
    """
    Class to handle edge extraction from a description using OpenAI's GPT models.
    """

    def __init__(self, openai_client: OpenAI | AsyncOpenAI) -> None:
        """
        Initializes the EdgeExtractor with an OpenAI client.

        Args:
            openai_client (OpenAI | AsyncOpenAI): An instance of OpenAI or
                AsyncOpenAI client to interact with the API.
        """
        super().__init__(openai_client)

    @staticmethod
    def _get_messages_for_edge_direction(
        description: str,
        set_of_nodes: str,
        pair_of_nodes: str,
    ) -> tuple[dict, dict]:
        """
        Prepares the message payload for edge direction identification.

        Args:
            description (str): The description text to be processed.
            set_of_nodes (str): The graph's nodes, already rendered as a
                bracketed comma-separated string (see _extract_one_edge_gpt).
            pair_of_nodes (str): The candidate node pair, rendered the same way.

        Returns:
            tuple[dict, dict]: System and user messages formatted for the API request.
        """
        # NOTE(review): this reuses node_extraction_sys_message for an edge
        # task — confirm a dedicated edge-extraction system prompt is not
        # intended here.
        messages = (
            {"role": "system", "content": node_extraction_sys_message},
            {
                "role": "user",
                "content": edge_extraction_str_template.format(
                    description=description,
                    set_of_nodes=set_of_nodes,
                    pair_of_nodes=pair_of_nodes,
                ),
            },
        )

        return messages

    def _extract_one_edge_gpt(
        self,
        description: str,
        set_of_nodes: list[str],
        pair_of_nodes: tuple[str, str],
        gpt_model: str = "gpt-4o-mini",
        temperature: float = 0,
    ) -> tuple[str | None, str | None]:
        """
        Determines the existence and direction of an edge between a pair of nodes using a GPT model.

        Args:
            description (str): The description text to be processed.
            set_of_nodes (list[str]): The nodes of the graph.
            pair_of_nodes (tuple[str, str]): A pair of nodes to check for an edge.
            gpt_model (str, optional): The GPT model to use for extraction. Defaults to 'gpt-4o-mini'.
            temperature (float, optional): The randomness of the model's output. Defaults to 0.

        Returns:
            tuple[str | None, str | None]: (None, None) if no edge exists,
            otherwise the pair ordered source-to-target.
        """
        arrow_type = self._get_completion_parsed_result(
            model=gpt_model,
            messages=self._get_messages_for_edge_direction(
                description,
                f"[{', '.join(set_of_nodes)}]",
                f"[{', '.join(pair_of_nodes)}]",
            ),
            response_format=ArrowType,
            temperature=temperature,
            key="arrow_type",
        )

        # Substring match keeps this robust to casing in the model's answer.
        answer = arrow_type.lower()
        if "forward" in answer:
            return pair_of_nodes
        if "backward" in answer:
            return pair_of_nodes[::-1]

        return (None, None)

    def extract_all_edges(
        self,
        description: str,
        set_of_nodes: list[str],
        gpt_model: str = "gpt-4o-mini",
        temperature: float = 0,
        verbose=False,
    ) -> list[tuple[str | None, str | None]]:
        """
        Extracts all possible edges from a set of nodes based on a given description.

        Args:
            description (str): The description text to be processed.
            set_of_nodes (list[str]): The nodes of the graph.
            gpt_model (str, optional): The GPT model to use for extraction. Defaults to 'gpt-4o-mini'.
            temperature (float, optional): The randomness of the model's output. Defaults to 0.
            verbose (bool, optional): If True, prints each queried node pair. Defaults to False.

        Returns:
            list[tuple[str | None, str | None]]: A list of tuples representing edges with identified directions.
        """
        edge_list = []

        # Query every unordered pair once; direction is decided per query.
        for i, node_a in enumerate(set_of_nodes):
            for node_b in set_of_nodes[i + 1 :]:
                if verbose:
                    print(f"{node_a} # {node_b}")
                edge = self._extract_one_edge_gpt(
                    description, set_of_nodes, (node_a, node_b), gpt_model, temperature
                )
                if edge[0] is not None:
                    edge_list.append(edge)

        return edge_list
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
from typing import List | ||
from pydantic import BaseModel | ||
from openai import OpenAI, AsyncOpenAI | ||
from gmg_auto.config import node_extraction_str_template, node_extraction_sys_message | ||
from gmg_auto.NLI_base_extractor import BaseExtractior | ||
|
||
|
||
class Nodes(BaseModel):
    """Structured-output schema: the list of node names found in a description."""

    list_of_nodes: list[str]
|
||
|
||
class NodeExtractor(BaseExtractior):
    """
    A class to handle node extraction from a given text description using OpenAI's GPT models.
    """

    def __init__(self, openai_client: OpenAI | AsyncOpenAI) -> None:
        """
        Initializes the NodeExtractor with an OpenAI client.

        Args:
            openai_client (OpenAI | AsyncOpenAI): An instance of OpenAI or
                AsyncOpenAI client to interact with the API.
        """
        super().__init__(openai_client)

    def extract_nodes_gpt(
        self, description: str, gpt_model: str = "gpt-4o-mini", temperature: float = 0.0
    ) -> list[str]:
        """
        Extracts nodes from a given description using a specified GPT model.

        Args:
            description (str): The text description from which nodes need to be extracted.
            gpt_model (str, optional): The GPT model to use for extraction. Defaults to 'gpt-4o-mini'.
            temperature (float, optional): The randomness of the model's output. Defaults to 0.0.

        Returns:
            list[str]: A list of extracted nodes from the description.
        """
        list_of_nodes = self._get_completion_parsed_result(
            model=gpt_model,
            messages=self._get_messages_for_node_extraction(description),
            response_format=Nodes,
            temperature=temperature,
            key="list_of_nodes",
        )

        return list_of_nodes

    @staticmethod
    def _get_messages_for_node_extraction(description: str) -> tuple[dict, dict]:
        """
        Prepares the message payload for node extraction using the GPT model.

        Args:
            description (str): The description text to be processed.

        Returns:
            tuple[dict, dict]: A tuple containing system and user messages formatted for the API request.
        """
        messages = (
            {"role": "system", "content": node_extraction_sys_message},
            {
                "role": "user",
                "content": node_extraction_str_template.format(description=description),
            },
        )

        return messages
Oops, something went wrong.