Skip to content

Commit

Permalink
Fix class promotion (#6187)
Browse files Browse the repository at this point in the history
In LangChain, all module classes are enumerated in the `__init__.py`
file of the correspondent module. But some classes were missed and were
not included in the module `__init__.py`

This PR:
- added the missed classes to the module `__init__.py` files
- `__init__.py:__all_` variable value (a list of the class names) was
sorted
- `langchain.tools.sql_database.tool.QueryCheckerTool` was renamed into
the `QuerySQLCheckerTool` because it conflicted with
`langchain.tools.spark_sql.tool.QueryCheckerTool`
- changes to `pyproject.toml`:
  - added `pgvector` to `pyproject.toml:extended_testing`
- added `pandas` to
`pyproject.toml:[tool.poetry.group.test.dependencies]`
- commented out the `streamlit` from `collbacks/__init__.py`, It is
because now the `streamlit` requires Python >=3.7, !=3.9.7
- fixed duplicate names in `tools`
- fixed correspondent ut-s

#### Who can review?
@hwchase17
@dev2049
  • Loading branch information
leo-gan authored Jun 18, 2023
1 parent c0c2fd0 commit c7ca350
Show file tree
Hide file tree
Showing 25 changed files with 683 additions and 1,115 deletions.
2 changes: 2 additions & 0 deletions langchain/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
)
from langchain.agents.loading import load_agent
from langchain.agents.mrkl.base import MRKLChain, ZeroShotAgent
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
from langchain.agents.react.base import ReActChain, ReActTextWorldAgent
from langchain.agents.self_ask_with_search.base import SelfAskWithSearchChain
from langchain.agents.structured_chat.base import StructuredChatAgent
Expand All @@ -47,6 +48,7 @@
"ConversationalChatAgent",
"LLMSingleActionAgent",
"MRKLChain",
"OpenAIFunctionsAgent",
"ReActChain",
"ReActTextWorldAgent",
"SelfAskWithSearchChain",
Expand Down
4 changes: 2 additions & 2 deletions langchain/agents/agent_toolkits/sql/toolkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from langchain.tools.sql_database.tool import (
InfoSQLDatabaseTool,
ListSQLDatabaseTool,
QueryCheckerTool,
QuerySQLCheckerTool,
QuerySQLDataBaseTool,
)

Expand Down Expand Up @@ -55,5 +55,5 @@ def get_tools(self) -> List[BaseTool]:
db=self.db, description=info_sql_database_tool_description
),
ListSQLDatabaseTool(db=self.db),
QueryCheckerTool(db=self.db, llm=self.llm),
QuerySQLCheckerTool(db=self.db, llm=self.llm),
]
27 changes: 20 additions & 7 deletions langchain/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from langchain.callbacks.aim_callback import AimCallbackHandler
from langchain.callbacks.argilla_callback import ArgillaCallbackHandler
from langchain.callbacks.arize_callback import ArizeCallbackHandler
from langchain.callbacks.clearml_callback import ClearMLCallbackHandler
from langchain.callbacks.comet_ml_callback import CometCallbackHandler
from langchain.callbacks.file import FileCallbackHandler
Expand All @@ -15,23 +16,35 @@
from langchain.callbacks.openai_info import OpenAICallbackHandler
from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.callbacks.streaming_aiter import AsyncIteratorCallbackHandler
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.callbacks.streaming_stdout_final_only import (
FinalStreamingStdOutCallbackHandler,
)

# now streamlit requires Python >=3.7, !=3.9.7 So, it is commented out here.
# from langchain.callbacks.streamlit import StreamlitCallbackHandler
from langchain.callbacks.wandb_callback import WandbCallbackHandler
from langchain.callbacks.whylabs_callback import WhyLabsCallbackHandler

__all__ = [
"AimCallbackHandler",
"ArgillaCallbackHandler",
"ArizeCallbackHandler",
"AsyncIteratorCallbackHandler",
"ClearMLCallbackHandler",
"CometCallbackHandler",
"FileCallbackHandler",
"FinalStreamingStdOutCallbackHandler",
"HumanApprovalCallbackHandler",
"MlflowCallbackHandler",
"OpenAICallbackHandler",
"StdOutCallbackHandler",
"FileCallbackHandler",
"AimCallbackHandler",
"StreamingStdOutCallbackHandler",
# now streamlit requires Python >=3.7, !=3.9.7 So, it is commented out here.
# "StreamlitCallbackHandler",
"WandbCallbackHandler",
"MlflowCallbackHandler",
"ClearMLCallbackHandler",
"CometCallbackHandler",
"WhyLabsCallbackHandler",
"AsyncIteratorCallbackHandler",
"get_openai_callback",
"tracing_enabled",
"wandb_tracing_enabled",
"HumanApprovalCallbackHandler",
]
60 changes: 37 additions & 23 deletions langchain/chains/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from langchain.chains.loading import load_chain
from langchain.chains.mapreduce import MapReduceChain
from langchain.chains.moderation import OpenAIModerationChain
from langchain.chains.natbot.base import NatBotChain
from langchain.chains.openai_functions import (
create_extraction_chain,
create_extraction_chain_pydantic,
Expand All @@ -34,6 +35,13 @@
from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain
from langchain.chains.qa_with_sources.vector_db import VectorDBQAWithSourcesChain
from langchain.chains.retrieval_qa.base import RetrievalQA, VectorDBQA
from langchain.chains.router import (
LLMRouterChain,
MultiPromptChain,
MultiRetrievalQAChain,
MultiRouteChain,
RouterChain,
)
from langchain.chains.sequential import SequentialChain, SimpleSequentialChain
from langchain.chains.sql_database.base import (
SQLDatabaseChain,
Expand All @@ -42,41 +50,47 @@
from langchain.chains.transform import TransformChain

__all__ = [
"APIChain",
"AnalyzeDocumentChain",
"ChatVectorDBChain",
"ConstitutionalChain",
"ConversationChain",
"LLMChain",
"ConversationalRetrievalChain",
"FlareChain",
"GraphCypherQAChain",
"GraphQAChain",
"HypotheticalDocumentEmbedder",
"LLMBashChain",
"LLMChain",
"LLMCheckerChain",
"LLMSummarizationCheckerChain",
"LLMMathChain",
"LLMRequestsChain",
"LLMRouterChain",
"LLMSummarizationCheckerChain",
"MapReduceChain",
"MultiPromptChain",
"MultiRetrievalQAChain",
"MultiRouteChain",
"NatBotChain",
"NebulaGraphQAChain",
"OpenAIModerationChain",
"OpenAPIEndpointChain",
"PALChain",
"QAGenerationChain",
"QAWithSourcesChain",
"RetrievalQA",
"RetrievalQAWithSourcesChain",
"RouterChain",
"SQLDatabaseChain",
"SQLDatabaseSequentialChain",
"SequentialChain",
"SimpleSequentialChain",
"TransformChain",
"VectorDBQA",
"VectorDBQAWithSourcesChain",
"APIChain",
"LLMRequestsChain",
"TransformChain",
"MapReduceChain",
"OpenAIModerationChain",
"SQLDatabaseSequentialChain",
"load_chain",
"AnalyzeDocumentChain",
"HypotheticalDocumentEmbedder",
"ChatVectorDBChain",
"GraphQAChain",
"GraphCypherQAChain",
"ConstitutionalChain",
"QAGenerationChain",
"RetrievalQA",
"RetrievalQAWithSourcesChain",
"ConversationalRetrievalChain",
"OpenAPIEndpointChain",
"FlareChain",
"NebulaGraphQAChain",
"create_extraction_chain",
"create_tagging_chain",
"create_extraction_chain_pydantic",
"create_tagging_chain",
"create_tagging_chain_pydantic",
"load_chain",
]
3 changes: 2 additions & 1 deletion langchain/docstore/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Wrappers on top of docstores."""
from langchain.docstore.arbitrary_fn import DocstoreFn
from langchain.docstore.in_memory import InMemoryDocstore
from langchain.docstore.wikipedia import Wikipedia

__all__ = ["InMemoryDocstore", "Wikipedia"]
__all__ = ["DocstoreFn", "InMemoryDocstore", "Wikipedia"]
20 changes: 15 additions & 5 deletions langchain/document_loaders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@
from langchain.document_loaders.bigquery import BigQueryLoader
from langchain.document_loaders.bilibili import BiliBiliLoader
from langchain.document_loaders.blackboard import BlackboardLoader
from langchain.document_loaders.blob_loaders import (
Blob,
BlobLoader,
FileSystemBlobLoader,
YoutubeAudioLoader,
)
from langchain.document_loaders.blockchain import BlockchainDocumentLoader
from langchain.document_loaders.chatgpt import ChatGPTLoader
from langchain.document_loaders.college_confidential import CollegeConfidentialLoader
Expand Down Expand Up @@ -150,6 +156,8 @@
"BigQueryLoader",
"BiliBiliLoader",
"BlackboardLoader",
"Blob",
"BlobLoader",
"BlockchainDocumentLoader",
"CSVLoader",
"ChatGPTLoader",
Expand All @@ -163,10 +171,13 @@
"DocugamiLoader",
"Docx2txtLoader",
"DuckDBLoader",
"FaunaLoader",
"EmbaasBlobLoader",
"EmbaasLoader",
"EverNoteLoader",
"FacebookChatLoader",
"FaunaLoader",
"FigmaFileLoader",
"FileSystemBlobLoader",
"GCSDirectoryLoader",
"GCSFileLoader",
"GitHubIssuesLoader",
Expand Down Expand Up @@ -194,8 +205,8 @@
"NotionDBLoader",
"NotionDirectoryLoader",
"ObsidianLoader",
"OneDriveLoader",
"OneDriveFileLoader",
"OneDriveLoader",
"OnlinePDFLoader",
"OutlookMessageLoader",
"PDFMinerLoader",
Expand All @@ -219,6 +230,7 @@
"SeleniumURLLoader",
"SitemapLoader",
"SlackDirectoryLoader",
"SnowflakeLoader",
"SpreedlyLoader",
"StripeLoader",
"TelegramChatApiLoader",
Expand Down Expand Up @@ -251,8 +263,6 @@
"WebBaseLoader",
"WhatsAppChatLoader",
"WikipediaLoader",
"YoutubeAudioLoader",
"YoutubeLoader",
"SnowflakeLoader",
"EmbaasLoader",
"EmbaasBlobLoader",
]
56 changes: 29 additions & 27 deletions langchain/llms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from langchain.llms.huggingface_text_gen_inference import HuggingFaceTextGenInference
from langchain.llms.human import HumanInputLLM
from langchain.llms.llamacpp import LlamaCpp
from langchain.llms.manifest import ManifestWrapper
from langchain.llms.modal import Modal
from langchain.llms.mosaicml import MosaicML
from langchain.llms.nlpcloud import NLPCloud
Expand All @@ -47,25 +48,34 @@
from langchain.llms.writer import Writer

__all__ = [
"Anthropic",
"AI21",
"AlephAlpha",
"Anthropic",
"Anyscale",
"Aviary",
"AzureOpenAI",
"Banana",
"Baseten",
"Beam",
"Bedrock",
"CTransformers",
"CerebriumAI",
"Cohere",
"CTransformers",
"Databricks",
"DeepInfra",
"FakeListLLM",
"ForefrontAI",
"GPT4All",
"GooglePalm",
"GooseAI",
"GPT4All",
"HuggingFaceEndpoint",
"HuggingFaceHub",
"HuggingFacePipeline",
"HuggingFaceTextGenInference",
"HumanInputLLM",
"LlamaCpp",
"TextGen",
"ManifestWrapper",
"Modal",
"MosaicML",
"NLPCloud",
Expand All @@ -74,25 +84,17 @@
"OpenLM",
"Petals",
"PipelineAI",
"HuggingFaceEndpoint",
"HuggingFaceHub",
"SagemakerEndpoint",
"HuggingFacePipeline",
"AI21",
"AzureOpenAI",
"Replicate",
"SelfHostedPipeline",
"SelfHostedHuggingFaceLLM",
"PredictionGuard",
"PromptLayerOpenAI",
"PromptLayerOpenAIChat",
"StochasticAI",
"Writer",
"RWKV",
"PredictionGuard",
"HumanInputLLM",
"HuggingFaceTextGenInference",
"FakeListLLM",
"Replicate",
"SagemakerEndpoint",
"SelfHostedHuggingFaceLLM",
"SelfHostedPipeline",
"StochasticAI",
"VertexAI",
"Writer",
]

type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
Expand All @@ -101,6 +103,7 @@
"anthropic": Anthropic,
"anyscale": Anyscale,
"aviary": Aviary,
"azure": AzureOpenAI,
"bananadev": Banana,
"baseten": Baseten,
"beam": Beam,
Expand All @@ -109,32 +112,31 @@
"ctransformers": CTransformers,
"databricks": Databricks,
"deepinfra": DeepInfra,
"fake-list": FakeListLLM,
"forefrontai": ForefrontAI,
"google_palm": GooglePalm,
"gooseai": GooseAI,
"gpt4all": GPT4All,
"huggingface_hub": HuggingFaceHub,
"huggingface_endpoint": HuggingFaceEndpoint,
"huggingface_hub": HuggingFaceHub,
"huggingface_pipeline": HuggingFacePipeline,
"huggingface_textgen_inference": HuggingFaceTextGenInference,
"human-input": HumanInputLLM,
"llamacpp": LlamaCpp,
"textgen": TextGen,
"modal": Modal,
"mosaic": MosaicML,
"sagemaker_endpoint": SagemakerEndpoint,
"nlpcloud": NLPCloud,
"human-input": HumanInputLLM,
"openai": OpenAI,
"openlm": OpenLM,
"petals": Petals,
"pipelineai": PipelineAI,
"huggingface_pipeline": HuggingFacePipeline,
"azure": AzureOpenAI,
"replicate": Replicate,
"rwkv": RWKV,
"sagemaker_endpoint": SagemakerEndpoint,
"self_hosted": SelfHostedPipeline,
"self_hosted_hugging_face": SelfHostedHuggingFaceLLM,
"stochasticai": StochasticAI,
"writer": Writer,
"rwkv": RWKV,
"huggingface_textgen_inference": HuggingFaceTextGenInference,
"fake-list": FakeListLLM,
"vertexai": VertexAI,
"writer": Writer,
}
Loading

0 comments on commit c7ca350

Please sign in to comment.