diff --git a/dewy/chunks/models.py b/dewy/chunks/models.py
index db2a40d..4981fd4 100644
--- a/dewy/chunks/models.py
+++ b/dewy/chunks/models.py
@@ -8,6 +8,7 @@ class TextChunk(BaseModel):
     kind: Literal["text"] = "text"
 
     raw: bool
+    text: str
     start_char_idx: Optional[int] = Field(
         default=None, description="Start char index of the chunk."
     )
diff --git a/dewy/chunks/router.py b/dewy/chunks/router.py
index 16de872..61d49f7 100644
--- a/dewy/chunks/router.py
+++ b/dewy/chunks/router.py
@@ -5,7 +5,7 @@ from dewy.common.collection_embeddings import CollectionEmbeddings
 from dewy.common.db import PgPoolDep
 
-from .models import Chunk, RetrieveRequest, RetrieveResponse
+from .models import Chunk, RetrieveRequest, RetrieveResponse, TextChunk
 
 router = APIRouter(prefix="/chunks")
 
@@ -14,22 +14,30 @@ async def list_chunks(
     pg_pool: PgPoolDep,
     collection_id: Annotated[int | None, Query(description="Limit to chunks associated with this collection")] = None,
     document_id: Annotated[int | None, Query(description="Limit to chunks associated with this document")] = None,
+    page: int | None = 1,
+    perPage: int | None = 10,
 ) -> List[Chunk]:
     """List chunks."""
 
     # TODO: handle collection & document ID
     results = await pg_pool.fetch(
         """
-        SELECT chunk.id, chunk.document_id, chunk.kind, chunk.text
+        SELECT chunk.id, chunk.document_id, chunk.kind, TRUE as raw, chunk.text
         FROM chunk
+        JOIN document ON document.id = chunk.document_id
         WHERE document.collection_id = coalesce($1, document.collection_id)
         AND chunk.document_id = coalesce($2, chunk.document_id)
-        JOIN document ON document.id = chunk.document_id
+        ORDER BY chunk.id
+        OFFSET $4
+        LIMIT $3
         """,
         collection_id,
         document_id,
+        perPage,
+        # Pages are 1-based while OFFSET counts rows; a NULL offset skips nothing.
+        None if page is None or perPage is None else max(page - 1, 0) * perPage,
     )
-    return [Chunk.model_validate(dict(result)) for result in results]
+    return [TextChunk.model_validate(dict(result)) for result in results]
 
 
 PathChunkId = Annotated[int, Path(..., description="The chunk ID.")]
@@ -64,5 +72,6 @@ async def retrieve_chunks(
     return RetrieveResponse(
         summary=None,
text_results=text_results if request.include_text_chunks else [],
+        image_results=[],
         # image_results=image_results if request.include_image_chunks else [],
     )
diff --git a/dewy/common/collection_embeddings.py b/dewy/common/collection_embeddings.py
index e252187..d7ff311 100644
--- a/dewy/common/collection_embeddings.py
+++ b/dewy/common/collection_embeddings.py
@@ -164,8 +164,9 @@ async def retrieve_text_chunks(self, query: str, n: int = 10) -> List[TextResult
 
         async with self._pg_pool.acquire() as conn:
             logger.info("Executing SQL query for chunks from {}", self.collection_id)
-            embeddings = await conn.fetch(self.collection_id,
-                                          self._retrieve_chunks,
+            embeddings = await conn.fetch(
+                self._retrieve_chunks,
+                self.collection_id,
                                           embedded_query,
                                           n)
             embeddings = [
@@ -213,7 +214,9 @@ async def ingest(self, document_id: int, url: str) -> None:
                 INSERT INTO chunk (document_id, kind, text)
                 VALUES ($1, $2, $3);
                 """,
-                [(document_id, "text", text_chunk) for text_chunk in text_chunks],
+                # Drop code points that cannot be encoded as UTF-8 (e.g. lone
+                # surrogates) and replace NUL, which PostgreSQL text cannot store.
+                [(document_id, "text", text_chunk.encode('utf-8', 'ignore').decode('utf-8').replace("\x00", "\uFFFD")) for text_chunk in text_chunks],
             )
 
             # Then, embed each of those chunks.
diff --git a/dewy/common/extract.py b/dewy/common/extract.py
index 12754c0..0cc8e7d 100644
--- a/dewy/common/extract.py
+++ b/dewy/common/extract.py
@@ -61,7 +61,7 @@ async def extract(
     """Extract documents from a local or remote URL."""
     import httpx
 
-    async with httpx.AsyncClient() as client:
+    async with httpx.AsyncClient(follow_redirects=True) as client:
         # Determine the extension by requesting the headers.
         response = await client.head(url)
         response.raise_for_status()
@@ -69,7 +69,7 @@ async def extract(
         logger.debug("Content type of {} is {}", url, content_type)
 
         # Load the content.
-        if content_type == "application/pdf":
+        if content_type.startswith("application/pdf"):
             from tempfile import NamedTemporaryFile
 
             with NamedTemporaryFile(suffix=".pdf") as temp_file:
diff --git a/dewy/documents/models.py b/dewy/documents/models.py
index f8342c2..04f7e97 100644
--- a/dewy/documents/models.py
+++ b/dewy/documents/models.py
@@ -5,9 +5,11 @@
 
 
 class CreateRequest(BaseModel):
-    """The name of the collection the document should be added to."""
+    """The name of the collection the document should be added to. Either `collection` or `collection_id` must be provided"""
+    collection: Optional[str] = None
 
-    collection_id: int
+    """The id of the collection the document should be added to. Either `collection` or `collection_id` must be provided"""
+    collection_id: Optional[int] = None
 
     """The URL of the document to add."""
     url: str
diff --git a/dewy/documents/router.py b/dewy/documents/router.py
index 582d153..c46fa46 100644
--- a/dewy/documents/router.py
+++ b/dewy/documents/router.py
@@ -25,15 +25,25 @@ async def add_document(
 ) -> Document:
     """Add a document."""
 
-    row = None
+    collection_id = req.collection_id
     async with pg_pool.acquire() as conn:
+        # Only a collection name was given: resolve it to the collection's id.
+        if collection_id is None:
+            collection_id = await conn.fetchval(
+                """
+                SELECT id
+                FROM collection
+                WHERE name = $1
+                """,
+                req.collection
+            )
         row = await conn.fetchrow(
             """
             INSERT INTO document (collection_id, url, ingest_state)
             VALUES ($1, $2, 'pending')
             RETURNING id, collection_id, url, ingest_state, ingest_error
             """,
-            req.collection_id,
+            collection_id,
             req.url,
         )
 
diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx
index b993498..9034161 100644
--- a/frontend/src/App.tsx
+++ b/frontend/src/App.tsx
@@ -1,13 +1,21 @@
 import {
   Admin,
   Resource,
-  ListGuesser,
-  EditGuesser,
-  ShowGuesser,
+  CustomRoutes,
   houseLightTheme as lightTheme,
   houseDarkTheme as darkTheme,
+  Menu
 } from "react-admin";
+import FolderIcon from '@mui/icons-material/Folder';
+import ArticleIcon from
'@mui/icons-material/Article'; +import SegmentIcon from '@mui/icons-material/Segment'; +import { Route } from "react-router-dom"; import { dataProvider } from "./dataProvider"; +import { CollectionList, CollectionCreate, CollectionEdit } from "./Collection"; +import { DocumentList, DocumentCreate, DocumentEdit } from "./Document"; +import { ChunkList } from "./Chunk"; +import { Search } from "./Search"; +import { MyLayout } from "./MyLayout"; export const App = () => ( ( theme={lightTheme} darkTheme={darkTheme} defaultTheme="light" + layout={MyLayout} > record.name} + icon={FolderIcon} /> record.url} + icon={ArticleIcon} /> + + } /> + ); diff --git a/frontend/src/Chunk.tsx b/frontend/src/Chunk.tsx new file mode 100644 index 0000000..61829ac --- /dev/null +++ b/frontend/src/Chunk.tsx @@ -0,0 +1,69 @@ +import { + List, + ListBase, + TopToolbar, + FilterButton, + Pagination, + SearchInput, + TextInput, + WithListContext, + useListContext, + TextField, + RichTextField, + ChipField, + ReferenceInput, + RecordContextProvider, + WrapperField, + ReferenceField, + ListToolbar, + Title, + SimpleShowLayout, + simpleList +} from 'react-admin'; + +import { Stack, Typography, Paper, Card, Accordion } from '@mui/material'; + +type Chunk = { + id: number; + kind: string; +}; + +const ListActions = () => ( + + + +); + +const listFilters = [ + , + , + , +]; + +const ChunkListView = () => { + const { data, isLoading } = useListContext(); + if (isLoading) return null; + + return ( + <> + {data.map((chunk) => + + + + + + + + )} + + ) +}; + +export const ChunkList = () => ( + + + <ListToolbar actions={<ListActions/>} filters={listFilters}/> + <ChunkListView /> + <Pagination /> + </ListBase> +); \ No newline at end of file diff --git a/frontend/src/Collection.tsx b/frontend/src/Collection.tsx new file mode 100644 index 0000000..1faa6ab --- /dev/null +++ b/frontend/src/Collection.tsx @@ -0,0 +1,80 @@ +import { + List, + Datagrid, + TextField, + TopToolbar, + FilterButton, + SearchInput, + 
EditButton, + CreateButton, + Create, + Edit, + SimpleForm, + CheckboxGroupInput, + TextInput, + SelectInput, + required +} from 'react-admin'; + +const ListActions = () => ( + <TopToolbar> + <CreateButton/> + </TopToolbar> +); +export const CollectionList = () => ( + <List actions={<ListActions/>} > + <Datagrid> + <TextField source="name" /> + <TextField source="text_embedding_model" /> + <TextField source="text_distance_metric" /> + <TextField source="llm_model" /> + </Datagrid> + </List> +); + +export const ChunkingConfig = () => ( + <> + <CheckboxGroupInput label="Chunks to Extract" source="extract" defaultValue={["snippets"]} choices={[ + {id: "snippets", name: "Snippets"}, + {id: "summaries", name: "Summaries"}, + {id: "images", name: "Images"} + ]}/> + <CheckboxGroupInput label="Retrieve Using" source="index" defaultValue={["text"]} choices={[ + {id: "text", name: "Text"}, + {id: "questions_answered", name: "Questions Answered"}, + {id: "statements", name: "Statements"}, + ]}/> + </> +) + +const Form = () => ( + <SimpleForm> + <TextInput source="name" validate={[required()]} fullWidth /> + <SelectInput source="text_embedding_model" defaultValue="hf:BAAI/bge-small-en" choices={[ + {id: 'hf:BAAI/bge-small-en', name: 'BAAI/bge-small-en'}, + {id: 'openai:text-embedding-ada-002', name: 'OpenAI/text_embedding_ada_002'}, + ]}/> + <SelectInput source="text_distance_metric" defaultValue="cosine" choices={[ + {id: 'cosine', name: 'Cosine'}, + {id: 'ip', name: 'Inner Product'}, + {id: 'l2', name: 'L2-Norm'}, + ]}/> + <SelectInput source="llm_model" defaultValue="huggingface:StabilityAI/stablelm-tuned-alpha-3b" choices={[ + {id: 'huggingface:StabilityAI/stablelm-tuned-alpha-3b', name: 'stablelm-tuned-alpha-3b'}, + {id: 'openai:gpt-3.5-turbo', name: 'gpt-3.5-turbo'}, + ]}/> + <ChunkingConfig /> + </SimpleForm> +) + +export const CollectionCreate = () => ( + <Create redirect="list"> + <Form/> + </Create> +); + +export const CollectionEdit = () => ( + <Edit 
redirect="list"> + <Form/> + </Edit> +); diff --git a/frontend/src/Document.tsx b/frontend/src/Document.tsx new file mode 100644 index 0000000..1203755 --- /dev/null +++ b/frontend/src/Document.tsx @@ -0,0 +1,63 @@ +import { + List, + Datagrid, + TextField, + ReferenceField, + BooleanField, + FileField, + TopToolbar, + EditButton, + FilterButton, + CreateButton, + SearchInput, + Create, + Edit, + SimpleForm, + SelectInput, + ReferenceInput, + FileInput, + required, + TextInput +} from 'react-admin'; + +import { ChunkingConfig } from "./Collection"; + +const ListActions = () => ( + <TopToolbar> + <FilterButton/> + <CreateButton/> + </TopToolbar> +); + +const listFilters = [ + <ReferenceInput source="collection_id" reference="collections"/>, +]; + +export const DocumentList = () => ( + <List actions={<ListActions/>} filters={listFilters} > + <Datagrid> + <TextField source="url" /> + <> + <EditButton /> + </> + </Datagrid> + </List> +); + + +export const DocumentCreate = () => ( + <Create redirect="list"> + <SimpleForm> + <ReferenceInput source="collection_id" reference="collections" /> + <TextInput source="url"/> + </SimpleForm> + </Create> +); + +export const DocumentEdit = () => ( + <Edit redirect="list"> + <SimpleForm> + <TextInput source="url"/> + </SimpleForm> + </Edit> +); \ No newline at end of file diff --git a/frontend/src/data.json b/frontend/src/data.json index 56c9c21..e13f012 100644 --- a/frontend/src/data.json +++ b/frontend/src/data.json @@ -33,16 +33,20 @@ { "id": 0, "collection_id": 0, + "url": "https://arxiv.org/pdf/2005.11401.pdf", "title": "Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks", + "chunks": {"tokens": {"index": {"statements": true}}}, "kind": "PDF", - "chunks": {"tokens": {"index": {"statements": true}}} + "indexed": false }, { "id": 1, "collection_id": 1, + "url": "https://arxiv.org/pdf/2305.14283.pdf", "title": "Query Rewriting for Retrieval-Augmented Large Language Models", + "chunks": null, "kind": "PDF", - 
"chunks": null
+      "indexed": true
     }
   ],
   "chunks": [
diff --git a/frontend/src/dataProvider.ts b/frontend/src/dataProvider.ts
index a73a52c..00cedf6 100644
--- a/frontend/src/dataProvider.ts
+++ b/frontend/src/dataProvider.ts
@@ -1,4 +1,72 @@
+import { fetchUtils } from 'react-admin';
+import { stringify } from 'query-string';
 import fakeRestDataProvider from "ra-data-fakerest";
 import data from "./data.json";
 
-export const dataProvider = fakeRestDataProvider(data, true);
+export const fakeDataProvider = fakeRestDataProvider(data, true);
+
+const apiUrl = 'http://localhost:8000';
+const httpClient = fetchUtils.fetchJson;
+
+
+export const dataProvider = {
+    // get a list of records based on sort, filter, and pagination
+    getList: async (resource, params) => {
+        // TODO: Handle pagination and sorting
+        // page/perPage, field/order and every filter key are forwarded
+        // to the API as flat query-string parameters.
+        const queryparams = {...params.pagination, ...params.sort, ...params.filter};
+
+        const url = `${apiUrl}/api/${resource}/?${stringify(queryparams)}`;
+        const { json } = await httpClient(url);
+        // pageInfo is hard-coded until paging is wired up (see TODO above).
+        return {
+            data: json,
+            pageInfo: {hasNextPage: false, hasPreviousPage: false},
+        };
+    },
+    getOne: async(resource, params) => {
+        const url = `${apiUrl}/api/${resource}/${params.id}`
+        const { json } = await httpClient(url);
+        return { data: json };
+    },
+    getMany: async (resource, params) => {
+        const query = {
+            filter: JSON.stringify({ ids: params.ids }),
+        };
+        const url = `${apiUrl}/api/${resource}?${stringify(query)}`;
+        const { json } = await httpClient(url);
+        return { data: json };
+    },
+    create: async (resource, params) => {
+        const { json } = await httpClient(`${apiUrl}/api/${resource}/`, {
+            method: 'PUT',
+            body: JSON.stringify(params.data),
+        })
+        return { data: json };
+    },
+    update: async (resource, params) => {
+        const url = `${apiUrl}/api/${resource}/${params.id}`;
+        const { json } = await httpClient(url, {
+            method: 'PUT',
+            body: JSON.stringify(params.data),
+        })
+        return { data: json };
+    },
+    delete: async (resource, params) => {
+        const url = `${apiUrl}/api/${resource}/${params.id}`;
+        const { json } = await httpClient(url, {
+            method: 'DELETE',
+        });
+        return { data: json };
+    },
+    deleteMany: async (resource, params) => {
+        for (const id of params.ids) {
+            const url = `${apiUrl}/api/${resource}/${id}`;
+            await httpClient(url, {
+                method: 'DELETE',
+            });
+        }
+        return { data: params.ids };
+    },
+}
diff --git a/migrations/0001_schema.sql b/migrations/0001_schema.sql
index c51f3cd..9c950a3 100644
--- a/migrations/0001_schema.sql
+++ b/migrations/0001_schema.sql
@@ -95,4 +95,11 @@ CREATE TABLE embedding(
 
     PRIMARY KEY (id),
     FOREIGN KEY(chunk_id) REFERENCES chunk (id)
-);
\ No newline at end of file
+);
+
+-- Default collection
+INSERT INTO collection (name, text_embedding_model, text_distance_metric) VALUES ('main', 'openai:text-embedding-ada-002', 'cosine');
+CREATE INDEX embedding_collection_1_index
+ON embedding
+USING hnsw ((embedding::vector(1536)) vector_cosine_ops)
+WHERE collection_id = 1;
\ No newline at end of file